Skip to content

Commit a2be078

Browse files
TroyGardenfacebook-github-bot
authored andcommitted
fix missing device in fake operator
Summary: # context * when we have the input tensors on the meta device, it calls the fake operator * however the device information is unintentionally missed so the output tensor is on the default device (cpu) * this is an incorrect behavior Reviewed By: gnahzg, iamzainhuda Differential Revision: D57077813
1 parent 144fba9 commit a2be078

File tree

2 files changed

+26
-4
lines changed

2 files changed

+26
-4
lines changed

torchrec/ir/tests/test_serializer.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,23 @@ def test_dynamic_shape_ebc(self) -> None:
295295
self.assertEqual(eager_out[i].shape, tensor.shape)
296296
assert torch.allclose(eager_out[i], tensor)
297297

298+
def test_ir_custom_op_device(self) -> None:
299+
model = self.generate_model()
300+
model.fpebc1 = copy.deepcopy(model.ebc1)
301+
model.fpebc2 = copy.deepcopy(model.ebc1)
302+
feature1 = KeyedJaggedTensor.from_offsets_sync(
303+
keys=["f1", "f2", "f3"],
304+
values=torch.tensor([0, 1, 2, 3, 2, 3]),
305+
offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]),
306+
)
307+
308+
model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
309+
for device in ["cpu", "cuda", "meta"]:
310+
device = torch.device(device)
311+
outputs = model.to(device)(feature1.to(device))
312+
for output in outputs:
313+
self.assertEqual(output.device.type, device.type)
314+
298315
def test_deserialized_device(self) -> None:
299316
model = self.generate_model()
300317
id_list_features = KeyedJaggedTensor.from_offsets_sync(

torchrec/ir/utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import torch
1717

1818
from torch import nn
19-
from torch.export import Dim, ExportedProgram, ShapesCollection
19+
from torch.export import Dim, ShapesCollection
2020
from torch.export.dynamic_shapes import _Dim as DIM
2121
from torchrec.ir.types import SerializerInterface
2222
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@@ -37,16 +37,21 @@ def ir_custom_op_impl(
3737
if t is not None:
3838
device = t.device
3939
break
40-
logger.info(f"torch.ops.torchrec.ir_custom_op -> ({batch_size}, {dim})")
40+
logger.info(f"torch.ops.torchrec.ir_custom_op -> ({batch_size}, {dim}) {device}")
4141
return torch.empty(batch_size, dim, device=device)
4242

4343

4444
@torch.library.register_fake("torchrec::ir_custom_op")
4545
def ir_custom_op_fake(
4646
tensors: List[Optional[torch.Tensor]], batch_size: int, dim: int
4747
) -> torch.Tensor:
48-
logger.info(f"ir_custom_op_fake -> ({batch_size}, {dim})")
49-
return torch.empty(batch_size, dim)
48+
device = None
49+
for t in tensors:
50+
if t is not None:
51+
device = t.device
52+
break
53+
logger.info(f"ir_custom_op_fake -> ({batch_size}, {dim}) {device}")
54+
return torch.empty(batch_size, dim, device=device)
5055

5156

5257
def encapsulate_ir_modules(

0 commit comments

Comments
 (0)