File tree Expand file tree Collapse file tree 1 file changed +7
-2
lines changed
onnxscript/function_libs/torch_lib/ops Expand file tree Collapse file tree 1 file changed +7
-2
lines changed Original file line number Diff line number Diff line change @@ -825,10 +825,15 @@ def aten_leaky_relu_backward(
825
825
def aten_linear (input : TFloat , weight : TFloat , bias : Optional [TFloat ] = None ) -> TFloat :
826
826
"""linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor"""
827
827
828
- if len (input .shape ) == 2 :
828
+ if len (input .shape ) == 2 and len ( weight . shape ) == 2 :
829
829
# Use Gemm for the rank 2 input
830
830
return op .Gemm (input , weight , bias , transB = True )
831
- weight_transposed = op .Transpose (weight , perm = [1 , 0 ])
831
+ if len (weight .shape ) == 1 :
832
+ # In rare cases the weight can be 1d
833
+ weight_transposed = op .Unsqueeze (weight , [1 ])
834
+ else :
835
+ assert len (weight .shape ) == 2
836
+ weight_transposed = op .Transpose (weight , perm = [1 , 0 ])
832
837
mul = op .MatMul (input , weight_transposed )
833
838
if bias is None :
834
839
return mul
You can’t perform that action at this time.
0 commit comments