Skip to content

Commit 8a742c0

Browse files
authored
[torchlib] Update linear implementation to support 1d weights (#2340)
This case is possible when users call `F.linear()` directly in PyTorch.
1 parent a3ce145 commit 8a742c0

File tree

1 file changed

+7
-2
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+7
-2
lines changed

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -825,10 +825,15 @@ def aten_leaky_relu_backward(
825825
def aten_linear(input: TFloat, weight: TFloat, bias: Optional[TFloat] = None) -> TFloat:
826826
"""linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor"""
827827

828-
if len(input.shape) == 2:
828+
if len(input.shape) == 2 and len(weight.shape) == 2:
829829
# Use Gemm for the rank 2 input
830830
return op.Gemm(input, weight, bias, transB=True)
831-
weight_transposed = op.Transpose(weight, perm=[1, 0])
831+
if len(weight.shape) == 1:
832+
# In rare cases the weight can be 1d
833+
weight_transposed = op.Unsqueeze(weight, [1])
834+
else:
835+
assert len(weight.shape) == 2
836+
weight_transposed = op.Transpose(weight, perm=[1, 0])
832837
mul = op.MatMul(input, weight_transposed)
833838
if bias is None:
834839
return mul

0 commit comments

Comments
 (0)