Skip to content

Commit ec16b3c

Browse files
gramalingambmehta001
authored and committed
Optimize away zero-length concat operands (microsoft#2150)
We optimize `Concat (x1, x2, x3)` when one or more of the concat operands have zero length along the concatenation axis. This pattern shows up, for example, in Phi models. See [this line](https://github.com/huggingface/transformers/blob/786d9c5ed920a099573ea7b6dbf265f1aeb32fc0/src/transformers/models/phi3/modeling_phi3.py#L152) in the implementation of partial-rotary-embedding:

```py
q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1)
```

In the special case of total-rotary-embedding, the second operand `q_pass` of the concat is empty. The resulting redundant Concat also interferes with the pattern-matching for GQA in the generated graph, so optimizing it away will help with GQA fusion as well.

Also handle the edge case when all operands have zero size.
1 parent 72c673a commit ec16b3c

File tree

2 files changed

+80
-4
lines changed

2 files changed

+80
-4
lines changed

onnxscript/optimizer/_constant_folding.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -558,21 +558,59 @@ def sequence_construct(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
558558
@register("Concat")
559559
def concat(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
560560
"""Replace a Concat node with a single input by Identity"""
561+
562+
# Replace Concat(x) by Identity(x)
561563
inputs = node.inputs
562564
if len(inputs) == 1:
563565
return op.Identity(inputs[0])
564-
# Track value of tensors that carry a shape value:
565-
output = node.outputs[0]
566-
if output is None:
566+
567+
axis = _get_int_attribute(node, "axis", None)
568+
if axis is None:
567569
return None
570+
571+
# Eliminate zero-length operands from Concat
572+
def has_zero_size(operand: ir.Value | None) -> bool:
573+
if operand is None:
574+
return False # Invalid model
575+
if (shape := operand.shape) is None:
576+
return False
577+
try:
578+
# We have already checked that axis is an int value (!= None)
579+
dim_size = shape[axis] # type: ignore[index]
580+
except IndexError:
581+
return False
582+
return dim_size == 0 # return False if symbolic or None or non-zero int value
583+
584+
new_inputs = [x for x in inputs if not has_zero_size(x)]
585+
if len(new_inputs) != len(inputs):
586+
if new_inputs:
587+
# Remove zero-length operands from Concat
588+
logger.debug(
589+
"Concat: removing zero-length operand(s) %s => %s", inputs, new_inputs
590+
)
591+
return op.Concat(*new_inputs, axis=axis)
592+
elif inputs:
593+
# All operands are zero-length. Concat is a no-op, but we need to use one of the
594+
# inputs to get the other dimensions correct:
595+
logger.debug("Concat: removing all zero-length operands %s", inputs)
596+
return op.Identity(inputs[0])
597+
else:
598+
# No inputs: invalid model.
599+
return None
600+
601+
# Track value of tensors that carry a shape value:
602+
568603
# Check axis attribute is 0
569-
axis = _get_int_attribute(node, "axis", None)
604+
570605
if axis != 0:
571606
return None
572607
shapes = [state.get_shape_value(input) for input in inputs]
573608
if any(shape is None for shape in shapes):
574609
return None
575610
concatenated = ir.Shape(dim for shape in shapes for dim in shape.dims) # type: ignore[union-attr]
611+
output = node.outputs[0]
612+
if output is None:
613+
return None
576614
state.set_sym_value(output, concatenated)
577615
return None
578616

onnxscript/optimizer/_constant_folding_test.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,44 @@ def test_concat_identity(self):
479479
self.assertEqual(len(optimized.graph), 1)
480480
self.assertEqual(optimized.graph.node(0).op_type, "Identity")
481481

482+
def test_concat_zero_length(self):
483+
model = """
484+
<ir_version: 7, opset_import: [ "" : 17]>
485+
agraph (float[N, 128] x1, float[N, 0] x2, float[N, 128] x3) => (float[N, M] z)
486+
{
487+
z = Concat <axis=-1> (x1, x2, x3)
488+
}
489+
"""
490+
optimized = self._fold(model)
491+
self.assertEqual(len(optimized.graph), 1)
492+
self.assertEqual([x.name for x in optimized.graph.node(0).inputs], ["x1", "x3"])
493+
494+
def test_concat_zero_length_identity(self):
495+
model = """
496+
<ir_version: 7, opset_import: [ "" : 17]>
497+
agraph (float[N, 0] x1, float[N, 128] x2, float[N, 0] x3) => (float[N, M] z)
498+
{
499+
z = Concat <axis=-1> (x1, x2, x3)
500+
}
501+
"""
502+
optimized = self._fold(model)
503+
self.assertEqual(len(optimized.graph), 1)
504+
self.assertEqual(optimized.graph.node(0).op_type, "Identity")
505+
self.assertEqual([x.name for x in optimized.graph.node(0).inputs], ["x2"])
506+
507+
def test_concat_zero_length_output(self):
508+
model = """
509+
<ir_version: 7, opset_import: [ "" : 17]>
510+
agraph (float[N, 0] x1, float[N, 0] x2, float[N, 0] x3) => (float[N, M] z)
511+
{
512+
z = Concat <axis=-1> (x1, x2, x3)
513+
}
514+
"""
515+
optimized = self._fold(model)
516+
self.assertEqual(len(optimized.graph), 1)
517+
self.assertEqual(optimized.graph.node(0).op_type, "Identity")
518+
self.assertEqual([x.name for x in optimized.graph.node(0).inputs], ["x1"])
519+
482520
def test_expand_identity(self):
483521
model = """
484522
<ir_version: 7, opset_import: [ "" : 17]>

0 commit comments

Comments
 (0)