diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 85b3052bfb..b35977a81c 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -1208,12 +1208,34 @@ def aten_bilinear( # bias shape: (out_features) - optional # output shape: (..., out_features) - # Use Einsum to compute the bilinear transformation - # "...i,oij,...j->...o" means: - # - input1[..., i] * weight[o, i, j] * input2[..., j] -> output[..., o] - result = op.Einsum(input1, weight, input2, equation="...i,oij,...j->...o") + # input1 and input2 must have identical batch dimensions + # Use MatMul to compute the bilinear transformation + batch_size = op.Shape(input1, start=0, end=-1) + input1_shape = op.Shape(input1, start=-1) + input2_shape = op.Shape(input2, start=-1) + output_shape = op.Shape(weight, start=0, end=1) + neg_1 = op.Constant(value_ints=[-1]) + + # (out_features, in1_features, in2_features) -> (in1_features, out_features, in2_features) + W_permute = op.Transpose(weight, perm=[1, 0, 2]) + + # (in1_features, out_features, in2_features) -> (in1_features, out_features * in2_features) + W_flat = op.Reshape( + W_permute, + op.Concat(input1_shape, op.Mul(output_shape, input2_shape), axis=0), + ) + + # (..., in1_features) @ (in1_features, out_features * in2_features) -> (..., out_features * in2_features) + tmp = op.MatMul(input1, W_flat) + + # (..., out_features * in2_features) -> (..., out_features, in2_features) + tmp = op.Reshape(tmp, op.Concat(batch_size, output_shape, input2_shape, axis=0)) + + # (..., in2_features) -> (..., in2_features, 1) + # -> (..., out_features, in2_features) @ (..., in2_features, 1) + # -> (..., out_features, 1) -> (..., out_features) + result = op.Squeeze(op.MatMul(tmp, op.Unsqueeze(input2, neg_1)), neg_1) - # Add bias if provided if bias is not None: result = op.Add(result, bias)