@@ -1205,51 +1205,59 @@ def aten_bilinear(
12051205 # bias shape: (out_features) - optional
12061206 # output shape: (..., out_features)
12071207
1208- # Decompose bilinear into MatMul operations:
1209- # 1. Create outer product of input1 and input2
1210- # 2. Reshape to flatten feature dimensions
1211- # 3. Use MatMul with reshaped weight
1212-
1213- # Get shapes for reshaping
1214- input1_shape = op .Shape (input1 )
1215- weight_shape = op .Shape (weight )
1208+ # Leveraging N-dimensional MatMul, we can compute this as:
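1209+ # 1. weight_2d @ input2_2d.T -> [out_features * in1_features, batch_size], reshaped to [out_features, in1_features, batch_size]
1210+ # 2. input1 @ transposed result -> [batch_size, out_features], reshaped to [...batch_dims, out_features]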
12161211
12171212 # Get dimensions
1213+ weight_shape = op .Shape (weight )
12181214 out_features = op .Gather (weight_shape , 0 , axis = 0 )
12191215 in1_features = op .Gather (weight_shape , 1 , axis = 0 )
12201216 in2_features = op .Gather (weight_shape , 2 , axis = 0 )
12211217
1222- # Get batch dimensions (everything except the last dimension)
1223- input1_rank = Rank (input1 )
1224- batch_dims = op .Slice (input1_shape , [0 ], [input1_rank - 1 ])
1225- batch_size = op .ReduceProd (batch_dims , keepdims = False )
1218+ # Step 1: Reshape weight for matrix multiplication
1219+ # weight: [out_features, in1_features, in2_features] -> [out_features * in1_features, in2_features]
1220+ weight_2d = op .Reshape (weight , op .Concat ([op .Mul (out_features , in1_features )], [in2_features ], axis = 0 ))
12261221
1227- # Create outer product: input1[..., i] * input2[..., j] -> [..., i, j]
1228- # Reshape inputs to [batch_size, features] for easier handling
1229- input1_2d = op .Reshape (input1 , op .Concat ([batch_size ], [in1_features ], axis = 0 ))
1222+ # Get input2 shape for transpose
1223+ input2_shape = op .Shape (input2 )
1224+ input2_rank = Rank (input2 )
1225+ batch_dims = op .Slice (input2_shape , [0 ], [input2_rank - 1 ])
1226+
1227+ # Reshape input2 to 2D: [...batch_dims, in2_features] -> [batch_size, in2_features]
1228+ batch_size = op .ReduceProd (batch_dims , keepdims = False )
12301229 input2_2d = op .Reshape (input2 , op .Concat ([batch_size ], [in2_features ], axis = 0 ))
12311230
1232- # Create outer product using unsqueeze and broadcasting
1233- input1_expanded = op .Unsqueeze (input1_2d , axes = [2 ]) # [batch_size, in1_features, 1]
1234- input2_expanded = op .Unsqueeze (input2_2d , axes = [1 ]) # [batch_size, 1, in2_features]
1231+ # Transpose input2_2d: [batch_size, in2_features] -> [in2_features, batch_size]
1232+ input2_t = op .Transpose (input2_2d , perm = [1 , 0 ])
1233+
1234+ # First MatMul: weight_2d @ input2_t
1235+ # [out_features * in1_features, in2_features] @ [in2_features, batch_size]
1236+ # -> [out_features * in1_features, batch_size]
1237+ temp = op .MatMul (weight_2d , input2_t )
12351238
1236- # Outer product via broadcasting multiplication
1237- outer_product = op .Mul ( input1_expanded , input2_expanded ) # [batch_size, in1_features, in2_features]
1239+ # Reshape temp: [out_features * in1_features, batch_size] -> [out_features, in1_features, batch_size]
1240+ temp = op .Reshape ( temp , op . Concat ([ out_features ], [ in1_features ], [ batch_size ], axis = 0 ))
12381241
1239- # Flatten the feature dimensions
1240- features_total = op .Mul (in1_features , in2_features )
1241- outer_flat = op .Reshape (outer_product , op .Concat ([batch_size ], [features_total ], axis = 0 ))
1242+ # Transpose temp for second matmul: [out_features, in1_features, batch_size] -> [batch_size, in1_features, out_features]
1243+ temp_t = op .Transpose (temp , perm = [2 , 1 , 0 ])
1244+
1245+ # Step 2: Prepare input1 for second MatMul
1246+ # Reshape input1 to 2D: [...batch_dims, in1_features] -> [batch_size, in1_features]
1247+ input1_2d = op .Reshape (input1 , op .Concat ([batch_size ], [in1_features ], axis = 0 ))
12421248
1243- # Reshape weight to 2D: [out_features, in1_features * in2_features]
1244- weight_2d = op .Reshape ( weight , op . Concat ([ out_features ], [ features_total ], axis = 0 ) )
1249+ # Expand input1 for batch matrix multiplication: [batch_size, in1_features] -> [batch_size, 1, in1_features]
1250+ input1_expanded = op .Unsqueeze ( input1_2d , axes = [ 1 ] )
12451251
1246- # Transpose weight for MatMul: [in1_features * in2_features, out_features]
1247- weight_t = op .Transpose (weight_2d , perm = [1 , 0 ])
1252+ # Second MatMul: input1_expanded @ temp_t
1253+ # [batch_size, 1, in1_features] @ [batch_size, in1_features, out_features]
1254+ # -> [batch_size, 1, out_features]
1255+ result = op .MatMul (input1_expanded , temp_t )
12481256
1249- # Matrix multiplication: [batch_size, out_features]
1250- result = op .MatMul ( outer_flat , weight_t )
1257+ # Remove singleton dimension: [batch_size, 1, out_features] -> [batch_size, out_features]
1258+ result = op .Squeeze ( result , axes = [ 1 ] )
12511259
1252- # Reshape back to original batch dimensions + out_features
1260+ # Reshape back to original batch dimensions
12531261 output_shape = op .Concat (batch_dims , [out_features ], axis = 0 )
12541262 result = op .Reshape (result , output_shape )
12551263
0 commit comments