Hello authors, thank you for sharing your insightful work!
I have a question regarding the parameter-sharing strategy described in the paper and implemented in the code. On page 7 of the paper, you write: "In experiments, we notice that for rational function, we share the denominator coefficient b_n among all groups and use different a_m for each group. It gets better performance."
So I understand that the model has a single set of denominator weights shared by all groups and multiple (=8) sets of numerator weights, one per group. However, the code seems to do the opposite, with a single set of numerator weights and eight sets of denominator weights:
https://github.com/Adamdad/rational_kat_cu/blob/181bae8baf19075bef94b5f62dac320b3d4b27d3/kat_rational/kat_1dgroup_triton.py#L59-L61
weight_numerator = torch.tensor(data[mode]["init_w_numerator"]).view(1, -1)
weight_denominator = torch.tensor(data[mode]["init_w_denominator"])
weight_denominator = torch.cat([weight_denominator] * self.num_groups).view(self.num_groups, -1)
https://github.com/Adamdad/rational_kat_cu/blob/181bae8baf19075bef94b5f62dac320b3d4b27d3/kat_rational/kat_1dgroup_triton.py#L85-L87
# Repeat the weights for all groups
weight_numerator = self.weight_numerator.repeat(self.num_groups, 1)
return self.rational(input, weight_numerator, self.weight_denominator, self.num_groups)
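To make the mismatch concrete, here is a minimal sketch of how I read the two layouts. It only compares shapes, using plain Python lists as stand-ins for the weight tensors. The coefficient counts (6 for the numerator, 4 for the denominator) and the helper names `code_layout`, `paper_layout`, and `shape` are my own assumptions for illustration; the real counts come from the init file loaded in the code above.

```python
# Illustrative sizes only (assumed); the real values come from the init data.
NUM_GROUPS = 8
NUM_NUMERATOR_COEFFS = 6
NUM_DENOMINATOR_COEFFS = 4


def shape(rows):
    """Return (rows, cols) of a list of lists standing in for a 2-D tensor."""
    return (len(rows), len(rows[0]))


def code_layout(num_groups, n_num, n_den):
    # Mirrors the quoted snippet: the numerator is viewed as (1, -1),
    # the denominator is concatenated and viewed as (num_groups, -1).
    numerator = [[0.0] * n_num]
    denominator = [[0.0] * n_den for _ in range(num_groups)]
    return numerator, denominator


def paper_layout(num_groups, n_num, n_den):
    # My reading of page 7: one a_m set per group, one b_n set shared by all.
    numerator = [[0.0] * n_num for _ in range(num_groups)]
    denominator = [[0.0] * n_den]
    return numerator, denominator


code_num, code_den = code_layout(NUM_GROUPS, NUM_NUMERATOR_COEFFS, NUM_DENOMINATOR_COEFFS)
paper_num, paper_den = paper_layout(NUM_GROUPS, NUM_NUMERATOR_COEFFS, NUM_DENOMINATOR_COEFFS)

print("code :", shape(code_num), shape(code_den))    # code : (1, 6) (8, 4)
print("paper:", shape(paper_num), shape(paper_den))  # paper: (8, 6) (1, 4)
```

Under these assumed sizes, the two layouts swap which weight has the group dimension, which is the difference I am asking about.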
Which of the two did you actually intend?
Thank you in advance!