Skip to content

Commit 0b166c1

Browse files
committed
[Rewriter] improve docstring and rename objects (#2301)
1 parent 45958da commit 0b166c1

File tree

3 files changed

+63
-13
lines changed

3 files changed

+63
-13
lines changed

onnxscript/rewriter/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
broadcast_to_matmul,
2020
cast_constant_of_shape,
2121
collapse_slices,
22+
fuse_pad_into_conv,
2223
no_op,
2324
pattern,
24-
fuse_pad_into_conv,
2525
)
2626

2727
_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)

onnxscript/rewriter/fuse_pad_into_conv.py

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,20 @@
1818

1919

2020
def fill_pads_with_axes(pads: Sequence[int], axes: Sequence[int], rank: int) -> List[int]:
21+
"""Converts the parameters of the ONNX Pad operator into an explicit list of values.
22+
23+
A filled list of pads will be returned following the format:
24+
[x1_begin, x2_begin, ..., x{rank}_begin, x1_end, x2_end, ..., x{rank}_end]
25+
26+
Args:
27+
pads: list of integers indicating the number of padding elements to add at
28+
the beginning and end of each axis.
29+
axes: list of axes that pads apply to.
30+
rank: value to compute the size of the filled list (2 * rank).
31+
32+
Returns:
33+
The filled list of pads.
34+
"""
2135
new_pads = []
2236
for axis in range(rank):
2337
if axis not in axes:
@@ -47,11 +61,13 @@ def read_conv_attributes(ir_conv: ir.Node) -> dict[str, Sequence[int] | str]:
4761
return attributes
4862

4963

50-
class _FusePadConvBase(orp.RewriteRuleClassBase):
64+
class _FuseConvPadBase(orp.RewriteRuleClassBase):
5165
"""Interface for PadConv nodes fusion."""
5266

5367
def __init__(self, as_function: bool = False):
54-
# Remove nodes is set to False to remove unused nodes after the rewrite.
68+
# Remove nodes is set to False to remove unused nodes after the rewrite, since
69+
# Pad or Conv inputs can come from constant nodes.
70+
# With remove_nodes=False these nodes are removed if these nodes are no longer needed.
5571
super().__init__(remove_nodes=False, as_function=as_function)
5672

5773
def rewrite(
@@ -89,6 +105,22 @@ def rewrite(
89105
)
90106

91107
def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.MatchResult:
108+
"""Condition to check if we need to replace the pattern.
109+
110+
If Pad inputs can be added in 'pads' attribute of the Conv operator.
111+
112+
To validate this, we need to check the following:
113+
1. `Pad<mode>` attribute has 'constant' as value
114+
2. `Pad` operator inputs are constants ('pads', 'constant_value', 'axes')
115+
3. 'constant_value' is equal to 0.0.
116+
4. `Pad` operator is only used for the spatial dimensions (batch dimension and channels
117+
remain unchanged).
118+
119+
If the above are true, then we don't need the reshapes.
120+
121+
Returns:
122+
True if we need to replace the pattern, False otherwise.
123+
"""
92124
del context # Unused
93125
check_result = orp.MatchResult()
94126
pad_node = pad.producer()
@@ -123,8 +155,8 @@ def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.Matc
123155
return check_result
124156

125157

126-
class FusePadConv(_FusePadConvBase):
127-
"""Replaces ``Pad(Conv(x))`` with ``Conv(x)``."""
158+
class FuseConvPad(_FuseConvPadBase):
159+
"""Replaces ``Conv(Pad(x))`` with ``Conv(x)``."""
128160

129161
def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
130162
return op.Conv(
@@ -147,8 +179,8 @@ def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.Matc
147179
return check_result
148180

149181

150-
class FusePadConvInteger(FusePadConv):
151-
"""Replaces ``Pad(ConvInteger(x))`` with ``ConvInteger(x)``."""
182+
class FuseConvIntegerPad(FuseConvPad):
183+
"""Replaces ``ConvInteger(Pad(x))`` with ``ConvInteger(x)``."""
152184

153185
def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
154186
return op.ConvInteger(
@@ -195,6 +227,24 @@ def rewrite(self, op: ir.tape.Tape, conv: ir.Value, **__) -> ir.Value:
195227
)
196228

197229
def check(self, context, conv: ir.Value, **__) -> orp.MatchResult:
230+
"""Condition to check if we need to replace the pattern.
231+
232+
If it is possible to deduce 'pads'.
233+
234+
To validate this, we need to check the following:
235+
1. `Conv<auto_pad != "NOTSET">` (nothing to do in this case, since 'pads' are
236+
already explicit)
237+
2. it is possible to deduce the input rank when `Conv<auto_pad == "VALID">`
238+
3. When `Conv<auto_pad != "VALID">`:
239+
* spatial input/output shapes are static
240+
* it is possible to infer `kernel_shape` either from the `Conv` operator attribute
241+
or from the kernel input
242+
243+
If the above are true, then we don't need the reshapes.
244+
245+
Returns:
246+
True if we need to replace the pattern, False otherwise.
247+
"""
198248
del context
199249
check_result = orp.MatchResult()
200250

@@ -274,8 +324,8 @@ def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
274324

275325
normalize_pad_format_conv = NormalizePadFormatConv.rule()
276326
normalize_pad_format_conv_integer = NormalizePadFormatConvInteger.rule()
277-
fuse_pad_into_conv = FusePadConv.rule()
278-
fuse_pad_into_conv_integer = FusePadConvInteger.rule()
327+
fuse_pad_into_conv = FuseConvPad.rule()
328+
fuse_pad_into_conv_integer = FuseConvIntegerPad.rule()
279329

280330

281331
def fuse_pad_into_conv_rule_set() -> orp.RewriteRuleSet:

onnxscript/rewriter/fuse_pad_into_conv_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def _clone_model(model: ir.Model) -> ir.Model:
2323
return ir.from_proto(ir.to_proto(model))
2424

2525

26-
class FusePadConvBaseTest(unittest.TestCase):
26+
class FuseConvPadBaseTest(unittest.TestCase):
2727
@property
2828
def rng(self):
2929
return np.random.default_rng(20250522)
@@ -89,7 +89,7 @@ def build_model(
8989
return ir_model
9090

9191

92-
class FusePadConvTest(FusePadConvBaseTest):
92+
class FuseConvPadTest(FuseConvPadBaseTest):
9393
@parameterized.parameterized.expand(
9494
[
9595
(pad_pads, const_value, axes, conv_pads, conv_auto_pad)
@@ -218,7 +218,7 @@ def test_unsupported_fuse_pad_into_conv(
218218
self.assertRegex(tracer_match.match_result.reason, err_msg)
219219

220220

221-
class FusePadConvIntegerTest(FusePadConvBaseTest):
221+
class FuseConvIntegerPadTest(FuseConvPadBaseTest):
222222
def get_conv_weights(self, shape: Sequence[int], tape: ir.tape.Tape = None):
223223
w = ir.tensor(self.rng.integers(0, 256, shape).astype("uint8"), name="W")
224224
if tape is not None:
@@ -267,7 +267,7 @@ def test_fuse_pad_into_conv_integer(
267267
testing.assert_numerically_equal(base_model, updated_model, (inputs,), atol=0, rtol=0)
268268

269269

270-
class NormalizePadFormatTest(FusePadConvBaseTest):
270+
class NormalizePadFormatTest(FuseConvPadBaseTest):
271271
def build_model(
272272
self,
273273
input_shape: ir.Shape,

0 commit comments

Comments
 (0)