Commit 3304438
Update KJT stride calculation logic to be based off of inverse_indices for VBE KJTs. (#2949)
Summary:
Update the `_maybe_compute_stride_kjt` logic to calculate stride based off of `inverse_indices` for VBE KJTs.
Currently, stride of VBE KJT with `stride_per_key_per_rank` is calculated as the max "stride per key". This is different from the batch size of the EBC output KeyedTensor which is based off of inverse_indices. This causes issues in IR module serialization: debug doc.
Differential Revision: D742730831 parent 3e2737e commit 3304438
File tree
2 files changed
+17
-0
lines changed- torchrec/sparse
- tests
2 files changed
+17
-0
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1097 | 1097 | | |
1098 | 1098 | | |
1099 | 1099 | | |
| 1100 | + | |
1100 | 1101 | | |
1101 | 1102 | | |
1102 | 1103 | | |
1103 | 1104 | | |
1104 | 1105 | | |
| 1106 | + | |
| 1107 | + | |
| 1108 | + | |
1105 | 1109 | | |
1106 | 1110 | | |
1107 | 1111 | | |
| |||
2165 | 2169 | | |
2166 | 2170 | | |
2167 | 2171 | | |
| 2172 | + | |
2168 | 2173 | | |
2169 | 2174 | | |
2170 | 2175 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1017 | 1017 | | |
1018 | 1018 | | |
1019 | 1019 | | |
| 1020 | + | |
| 1021 | + | |
| 1022 | + | |
| 1023 | + | |
| 1024 | + | |
| 1025 | + | |
| 1026 | + | |
| 1027 | + | |
| 1028 | + | |
| 1029 | + | |
| 1030 | + | |
| 1031 | + | |
1020 | 1032 | | |
1021 | 1033 | | |
1022 | 1034 | | |
| |||
0 commit comments