
Commit a216d23

vfdev-5 authored and xuhancn committed
Fixed cat uint8 lowering (pytorch#112753)
Description:
- Fixed `cat` uint8 lowering.

Otherwise, it gives the following issue on the repro code:

```python
import torch

def func(x):
    batch_shape = x.shape[:1]
    out = torch.cat([x.new_zeros(1).expand(batch_shape + (1,)), x], dim=-1)
    return out

cfunc = torch.compile(func)
x = torch.randint(0, 256, size=(3, 255), dtype=torch.uint8)
out = cfunc(x)
```

Error message:

```
  File "/pytorch/torch/_inductor/lowering.py", line 1037, in <genexpr>
    if all(len(input.layout.size) == 4 for input in inputs):
  File "/pytorch/torch/_inductor/ir.py", line 5795, in __getattr__
    fn = getattr(self.data, name)
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AttributeError: 'ExpandView' object has no attribute 'layout'
  target: aten.cat.default
  args[0]: [TensorBox(
    ExpandView(data=StorageBox(
      ComputedBuffer(name='buf0', layout=FlexibleLayout('cpu', torch.uint8, size=[1], stride=[1]), data=Pointwise(
        'cpu',
        torch.uint8,
        def inner_fn(index):
            _ = index
            tmp0 = ops.constant(0, torch.uint8)
            return tmp0
        ,
        ranges=[1],
        origin_node=full,
        origins={full}
      ))
    ), size=[3, 1])
  ), TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cpu', torch.uint8, size=[3, 255], stride=[255, 1]))
  ))]
  args[1]: 1

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
```

Context: compiling is not working for torchvision's `F.equalize` op: pytorch/vision#8056

Pull Request resolved: pytorch#112753
Approved by: https://github.com/peterbell10
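The traceback comes from attribute delegation: the `TensorBox` wrapper forwards unknown attributes to the IR node it wraps (the `__getattr__` in `ir.py` above), and `ExpandView`, being a view rather than a realized buffer, defines no `layout`. Below is a minimal sketch of that failure mode using simplified stand-in classes; the names mirror Inductor's but the bodies are illustrative assumptions, not the real implementations:

```python
from types import SimpleNamespace


class Buffer:
    """Stand-in for a realized buffer: it owns a concrete layout."""

    def __init__(self, size):
        self.layout = SimpleNamespace(size=size)

    def get_size(self):
        return self.layout.size


class ExpandView:
    """Stand-in for a view node: it tracks its size but has no `.layout`."""

    def __init__(self, data, size):
        self.data, self.size = data, size

    def get_size(self):
        return self.size


class TensorBox:
    """Stand-in for the box that forwards unknown attributes to its node."""

    def __init__(self, data):
        self.data = data

    def __getattr__(self, name):
        return getattr(self.data, name)  # mirrors ir.py line 5795 in the traceback


box = TensorBox(ExpandView(Buffer([1]), size=[3, 1]))
print(box.get_size())   # [3, 1] -- the view answers get_size(), so delegation works
print(box.layout.size)  # AttributeError: 'ExpandView' object has no attribute 'layout'
```

This is why swapping the attribute chain `input.layout.size` for the method call `input.get_size()` in the lowering makes the check view-safe.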
1 parent 933a54e · commit a216d23

3 files changed (+15 −1 lines changed)


test/inductor/test_torchinductor.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -3455,6 +3455,17 @@ def fn(a):
             (torch.randn([1, 3, 3, 16]).to(memory_format=torch.channels_last),),
         )
 
+    def test_cat_uint8(self):
+        def fn(x):
+            batch_shape = x.shape[:1]
+            out = torch.cat([x.new_zeros(1).expand(batch_shape + (1,)), x], dim=-1)
+            return out
+
+        self.common(
+            fn,
+            (torch.randint(0, 256, size=(3, 255), dtype=torch.uint8),),
+        )
+
     def test_cat_empty(self):
         def fn_2(*tensors):
             return torch.cat(tensors)
```
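For reference, the new test's `fn` can be sanity-checked eagerly; the shape arithmetic follows directly from the repro in the commit message (prepending one zero column turns `(3, 255)` into `(3, 256)`):

```python
import torch


def fn(x):
    batch_shape = x.shape[:1]
    # Expand the single zero to (batch, 1), then prepend it along the last dim.
    return torch.cat([x.new_zeros(1).expand(batch_shape + (1,)), x], dim=-1)


x = torch.randint(0, 256, size=(3, 255), dtype=torch.uint8)
out = fn(x)
assert out.shape == (3, 256)
assert out.dtype == torch.uint8
assert (out[:, 0] == 0).all()
```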

test/inductor/test_torchinductor_codegen_dynamic_shapes.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -195,6 +195,9 @@ def run(*ex, **kwargs):
         ("cpu", "cuda")
     ),
     "test_zero_element_mutation_dynamic_shapes": TestFailure(("cpu", "cuda")),
+    "test_cat_uint8_dynamic_shapes": TestFailure(
+        ("cpu",)
+    ),  # cat on uint8 input is using aten fallback on cpu
     #
     # Tests not using 'common' or directly calling 'assertEqual':
     #
```
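Why this entry is needed: as the lowering change below shows, `cat` on uint8 CPU inputs is routed through `fallback_handler(aten.cat.default)` rather than being code-generated, so this suite, which appears to verify generated Inductor code under dynamic shapes, has nothing to inspect on CPU; the test is registered as an expected failure there, not as a correctness problem.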

torch/_inductor/lowering.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1034,7 +1034,7 @@ def cat(inputs, dim=0):
         # code gen with uint8 data type directly.
         for input in inputs:
             input.realize()
-        if all(len(input.layout.size) == 4 for input in inputs):
+        if all(len(input.get_size()) == 4 for input in inputs):
             inputs, _ = require_channels_last(aten.cat, *inputs)
         return fallback_handler(aten.cat.default)(inputs, dim)
 
```
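Why the one-line fix suffices: as the traceback in the description shows, `.layout` only exists on realized buffers, while a view node such as `ExpandView` tracks its size directly and exposes it through `get_size()`. Reading sizes via `get_size()` therefore works uniformly for the boxed `ExpandView` produced by `expand` and for the plain `InputBuffer`, and the check itself (only routing 4-D inputs through `require_channels_last`) is unchanged.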
