
Commit 159e5bc

[torchlib] Add torchlib operator for glu (#1695)
Fix #1665

Co-authored-by: Justin Chu <[email protected]>
1 parent be00339 commit 159e5bc

File tree: 2 files changed (+6 −2 lines)


onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -565,10 +565,13 @@ def aten_gelu_backward(
     raise NotImplementedError()
 
 
-def aten_glu(self: TensorType, dim: int = -1) -> TensorType:
+@torch_op("aten::glu", traceable=True)
+def aten_glu(self: TFloat, dim: int = -1) -> TFloat:
     """glu(Tensor self, int dim=-1) -> Tensor"""
 
-    raise NotImplementedError()
+    first, second = op.Split(self, axis=dim, num_outputs=2)
+    result = op.Mul(first, op.Sigmoid(second))
+    return result
 
 
 def aten_glu_backward(grad_output: TensorType, self: TensorType, dim: int) -> TensorType:
```
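For context, GLU (the gated linear unit) splits its input into two halves `a` and `b` along `dim` and returns `a * sigmoid(b)`, which is exactly the `Split`/`Sigmoid`/`Mul` graph above. A minimal sketch (not part of the commit; assumes a local PyTorch install) checking that decomposition against PyTorch's reference GLU:

```python
import torch

x = torch.randn(2, 6)
dim = -1

# Same decomposition as aten_glu: split into two equal halves along `dim`,
# then gate the first half with the sigmoid of the second.
first, second = torch.chunk(x, 2, dim=dim)
decomposed = first * torch.sigmoid(second)

# PyTorch's built-in GLU as the reference.
assert torch.allclose(decomposed, torch.nn.functional.glu(x, dim=dim))
```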

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -1868,6 +1868,7 @@ def _where_input_wrangler(
         nn_ops.aten_gelu,
         tolerance={torch.float16: (8e-2, 1e-4)},
     ),
+    TorchLibOpInfo("nn.functional.glu", nn_ops.aten_glu),
     TorchLibOpInfo("nn.functional.linear", nn_ops.aten_linear).skip(
         # input: input, args: weight, bias; so len(args) == 2 means bias is provided
         matcher=lambda sample: len(sample.args) != 1,
```
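The new `TorchLibOpInfo` entry wires `aten_glu` to PyTorch's `nn.functional.glu` OpInfo samples, so the test suite compares the torchlib implementation against torch's output. As a hypothetical standalone illustration (plain NumPy, not the torchlib test harness) of the graph under test:

```python
import numpy as np

def glu_reference(x: np.ndarray, dim: int = -1) -> np.ndarray:
    # Mirrors op.Split(self, axis=dim, num_outputs=2).
    first, second = np.split(x, 2, axis=dim)
    # Mirrors op.Mul(first, op.Sigmoid(second)).
    return first * (1.0 / (1.0 + np.exp(-second)))

x = np.random.randn(4, 8).astype(np.float32)
out = glu_reference(x)
assert out.shape == (4, 4)  # the gated dimension is halved
```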
