
Commit 3c1f215

zejunh authored and facebook-github-bot committed
support permute_multi_embedding_function on torch.export (#3897)
Summary:
X-link: facebookresearch/FBGEMM#988
Pull Request resolved: #3897

Support fbgemm.permute_multi_embedding_function.default for the LPV model:
- registered with a separate entry in kernel.yaml for graph-mode lowering
- added fp16 ref kernel (sitecao)

Reviewed By: StellarrZ
Differential Revision: D71821354
fbshipit-source-id: daaf882f70c102f133d1a64be96386a8acecdde9
1 parent: f8322d7

2 files changed (+26, -0)

fbgemm_gpu/fbgemm_gpu/sparse_ops.py

Lines changed: 25 additions & 0 deletions
@@ -41,6 +41,9 @@
 torch.ops.load_library(
     "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_split_cpu"
 )
+torch.ops.load_library(
+    "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_cpu"
+)
 
 
 import torch.utils._pytree as pytree
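
The new load_library call matters because an abstract (meta) implementation can only be attached to an operator whose schema has already been registered; loading permute_multi_embedding_ops_cpu makes the fbgemm::permute_multi_embedding_function schema visible from Python. A minimal sanity check, assuming an fbgemm_gpu build that ships this target:

import torch
import fbgemm_gpu.sparse_ops  # noqa: F401 -- triggers the load_library calls above

# If the library loaded, the operator resolves on the torch.ops namespace.
assert hasattr(torch.ops.fbgemm, "permute_multi_embedding_function")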
@@ -1122,6 +1125,24 @@ def generic_histogram_binning_calibration_by_feature(
     )
 
 
+def permute_multi_embedding_function_impl_abstract(
+    pooled_embs: List[Tensor],
+    permutes: Tensor,
+    in_shapes: Tensor,
+    out_shapes: Tensor,
+    out_lengths: List[int],
+    reverse: bool = False,
+) -> List[Tensor]:
+    out_dtype = pooled_embs[0].dtype
+    bs = pooled_embs[0].shape[0]
+    torch._check(permutes.shape[1] == 6, lambda: "permutes must have 6 columns")
+
+    output = []
+    for i in range(len(out_lengths)):
+        output.append(torch.empty([bs, out_lengths[i]], dtype=out_dtype))
+    return output
+
+
 def _setup() -> None:
     # pyre-ignore[16]
     _setup.done = getattr(_setup, "done", False)
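
The abstract implementation runs in place of the real kernel during tracing (torch.export, torch.compile): it only propagates batch size, per-output lengths, and dtype, allocating empty placeholders instead of moving data. A quick way to see the contract it satisfies — a sketch assuming the function is importable from fbgemm_gpu.sparse_ops as defined above; the tensor contents are dummies, since only shapes matter here:

import torch
from fbgemm_gpu.sparse_ops import permute_multi_embedding_function_impl_abstract

# Two pooled-embedding tensors with batch size 8; meta tensors carry no data.
pooled = [torch.empty(8, 16, device="meta"), torch.empty(8, 32, device="meta")]
permutes = torch.empty(3, 6, dtype=torch.int64, device="meta")  # 6 columns, as checked above
in_shapes = torch.empty(2, dtype=torch.int64, device="meta")
out_shapes = torch.empty(2, dtype=torch.int64, device="meta")

outs = permute_multi_embedding_function_impl_abstract(
    pooled, permutes, in_shapes, out_shapes, out_lengths=[24, 24]
)
# Outputs are shape/dtype-only placeholders: [batch, out_length] each.
assert [tuple(o.shape) for o in outs] == [(8, 24), (8, 24)]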
@@ -1259,6 +1280,10 @@ def impl_autograd(op_name, fn, setup_context: Optional[Callable] = None) -> None
         "fbgemm::generic_histogram_binning_calibration_by_feature",
         generic_histogram_binning_calibration_by_feature,
     )
+    impl_abstract(
+        "fbgemm::permute_multi_embedding_function",
+        permute_multi_embedding_function_impl_abstract,
+    )
     impl_abstract(
         "fbgemm::FloatToHFP8Quantized",
         float_to_hfp8_quantized,
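
impl_abstract here is fbgemm's local registration helper; the effect matches PyTorch's public torch.library API, which ties an operator name to its fake kernel so FakeTensor tracing can call it. A sketch of the same pattern on a toy operator (the demo namespace and op are hypothetical, purely illustrative):

import torch
from torch import Tensor

# Declare a schema, a real CPU kernel, and a shape-only abstract kernel.
torch.library.define("demo::double_rows", "(Tensor x) -> Tensor")

@torch.library.impl("demo::double_rows", "CPU")
def double_rows_cpu(x: Tensor) -> Tensor:
    return torch.cat([x, x], dim=0)

@torch.library.impl_abstract("demo::double_rows")
def double_rows_abstract(x: Tensor) -> Tensor:
    # Mirrors the real kernel's output shape without touching data.
    return x.new_empty((x.shape[0] * 2, x.shape[1]))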

fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp

Lines changed: 1 addition & 0 deletions
@@ -344,6 +344,7 @@ std::vector<Tensor> regroup_keyed_tensor_meta(
 } // namespace fbgemm_gpu
 
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
+  m.set_python_module("fbgemm_gpu.sparse_ops");
   // register the forward function for internal (autograd) usage
   m.def(
       "permute_multi_embedding_function(Tensor[] pooled_embs, Tensor permutes, Tensor in_shapes, Tensor out_shapes, SymInt[] out_lengths, bool reverse=False) -> Tensor[]");
