Skip to content

Add VBE inverse indices pass-through support to KJT concat #2366

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1656,6 +1656,9 @@ def concat(
length_list: List[torch.Tensor] = []
stride_per_key_per_rank: List[List[int]] = []
stride: Optional[int] = None
inv_idx_keys: List[str] = []
inv_idx_tensors: List[torch.Tensor] = []

variable_stride_per_key_list = [
kjt.variable_stride_per_key() for kjt in kjt_list
]
Expand All @@ -1664,7 +1667,7 @@ def concat(
), "variable stride per key must be consistent for all KJTs"
variable_stride_per_key = all(variable_stride_per_key_list)

for kjt in kjt_list:
for i, kjt in enumerate(kjt_list):
curr_is_weighted: bool = kjt.weights_or_none() is not None
if is_weighted != curr_is_weighted:
raise ValueError("Can't merge weighted KJT with unweighted KJT")
Expand All @@ -1686,6 +1689,16 @@ def concat(
stride = kjt.stride()
else:
assert stride == kjt.stride(), "strides must be consistent for all KJTs"
if kjt.inverse_indices_or_none() is not None:
assert (
len(inv_idx_tensors) == i
), "inverse indices must be consistent for all KJTs"
inv_idx_keys += kjt.inverse_indices()[0]
inv_idx_tensors.append(kjt.inverse_indices()[1])
else:
assert (
len(inv_idx_tensors) == 0
), "inverse indices must be consistent for all KJTs"

return KeyedJaggedTensor(
keys=keys,
Expand All @@ -1697,6 +1710,11 @@ def concat(
stride_per_key_per_rank if variable_stride_per_key else None
),
length_per_key=length_per_key if has_length_per_key else None,
inverse_indices=(
(inv_idx_keys, torch.cat(inv_idx_tensors))
if len(inv_idx_tensors) == len(kjt_list)
else None
),
)

@staticmethod
Expand Down
Loading