Commit dac17a1

sarckk authored and facebook-github-bot committed
Improve CPU benchmark for KJT
Summary:
Add benchmark for KJT methods:
- `permute`
- `to_dict`
- `split`
- `__getitem__`
- `dist_init`

Reviewed By: gnahzg

Differential Revision: D57314675
1 parent 7924c6f commit dac17a1
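
The methods listed in the summary are the hot KJT entry points on the input path. A minimal sketch of timing them on CPU with `timeit` (illustrative key names and sizes, not the harness added by this commit; `dist_init` is omitted because it requires a process group):

import timeit

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Build a KJT with 4 keys and random lengths/values (sizes are illustrative).
keys = ["f0", "f1", "f2", "f3"]
lengths = torch.randint(0, 10, (len(keys) * 4096,))
values = torch.randint(0, 1000, (int(lengths.sum().item()),))
kjt = KeyedJaggedTensor.from_lengths_sync(keys=keys, values=values, lengths=lengths)

# Time each op for a fixed number of iterations and report total seconds.
ops = {
    "permute": lambda: kjt.permute([3, 2, 1, 0]),
    "to_dict": lambda: kjt.to_dict(),
    "split": lambda: kjt.split([2, 2]),
    "__getitem__": lambda: kjt["f1"],
}
for name, fn in ops.items():
    print(f"{name}: {timeit.timeit(fn, number=50):.4f}s")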


3 files changed: 274 additions & 126 deletions


torchrec/distributed/benchmark/benchmark_utils.py

Lines changed: 10 additions & 4 deletions
@@ -106,14 +106,20 @@ class BenchmarkResult:
     max_mem_allocated: List[int]  # megabytes
     rank: int = -1
 
-    def runtime_percentile(self, percentile: int = 50) -> torch.Tensor:
+    def runtime_percentile(
+        self, percentile: int = 50, interpolation: str = "nearest"
+    ) -> torch.Tensor:
         return torch.quantile(
-            self.elapsed_time, percentile / 100.0, interpolation="nearest"
+            self.elapsed_time,
+            percentile / 100.0,
+            interpolation=interpolation,
         )
 
-    def max_mem_percentile(self, percentile: int = 50) -> torch.Tensor:
+    def max_mem_percentile(
+        self, percentile: int = 50, interpolation: str = "nearest"
+    ) -> torch.Tensor:
         max_mem = torch.tensor(self.max_mem_allocated, dtype=torch.float)
-        return torch.quantile(max_mem, percentile / 100.0, interpolation="nearest")
+        return torch.quantile(max_mem, percentile / 100.0, interpolation=interpolation)
 
 
 class ECWrapper(torch.nn.Module):
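
The diff above exposes the previously hard-coded `interpolation="nearest"` as an argument passed straight through to `torch.quantile`. A small standalone illustration of what that choice changes (values are made up; the real callers use `BenchmarkResult.elapsed_time`):

import torch

elapsed_time = torch.tensor([10.0, 12.0, 13.0, 40.0])  # illustrative runtimes (ms)
# "nearest" (the old behavior) snaps the percentile to an observed sample,
# while "linear" interpolates between the two neighboring samples.
p99_nearest = torch.quantile(elapsed_time, 0.99, interpolation="nearest")  # 40.0
p99_linear = torch.quantile(elapsed_time, 0.99, interpolation="linear")    # ~39.2
print(p99_nearest.item(), p99_linear.item())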

torchrec/distributed/test_utils/test_model.py

Lines changed: 41 additions & 20 deletions
@@ -69,6 +69,8 @@ def generate(
         long_indices: bool = True,
         tables_pooling: Optional[List[int]] = None,
         weighted_tables_pooling: Optional[List[int]] = None,
+        randomize_indices: bool = True,
+        device: Optional[torch.device] = None,
     ) -> Tuple["ModelInput", List["ModelInput"]]:
         """
         Returns a global (single-rank training) batch
@@ -132,15 +134,16 @@ def _validate_pooling_factor(
                         idlist_pooling_factor[idx],
                         idlist_pooling_factor[idx] / 10,
                         [batch_size * world_size],
+                        device=device,
                     ),
-                    torch.tensor(1.0),
+                    torch.tensor(1.0, device=device),
                 ).int()
             else:
                 lengths_ = torch.abs(
-                    torch.randn(batch_size * world_size) + pooling_avg
+                    torch.randn(batch_size * world_size, device=device) + pooling_avg,
                 ).int()
             if variable_batch_size:
-                lengths = torch.zeros(batch_size * world_size).int()
+                lengths = torch.zeros(batch_size * world_size, device=device).int()
                 for r in range(world_size):
                     lengths[r * batch_size : r * batch_size + batch_size_by_rank[r]] = (
                         lengths_[
@@ -150,12 +153,20 @@ def _validate_pooling_factor(
             else:
                 lengths = lengths_
             num_indices = cast(int, torch.sum(lengths).item())
-            indices = torch.randint(
-                0,
-                ind_range,
-                (num_indices,),
-                dtype=torch.long if long_indices else torch.int32,
-            )
+            if randomize_indices:
+                indices = torch.randint(
+                    0,
+                    ind_range,
+                    (num_indices,),
+                    dtype=torch.long if long_indices else torch.int32,
+                    device=device,
+                )
+            else:
+                indices = torch.zeros(
+                    (num_indices),
+                    dtype=torch.long if long_indices else torch.int32,
+                    device=device,
+                )
             global_idlist_lengths.append(lengths)
             global_idlist_indices.append(indices)
         global_idlist_kjt = KeyedJaggedTensor(
@@ -167,15 +178,15 @@ def _validate_pooling_factor(
         for idx in range(len(idscore_ind_ranges)):
             ind_range = idscore_ind_ranges[idx]
             lengths_ = torch.abs(
-                torch.randn(batch_size * world_size)
+                torch.randn(batch_size * world_size, device=device)
                 + (
                     idscore_pooling_factor[idx]
                     if idscore_pooling_factor
                     else pooling_avg
                 )
             ).int()
             if variable_batch_size:
-                lengths = torch.zeros(batch_size * world_size).int()
+                lengths = torch.zeros(batch_size * world_size, device=device).int()
                 for r in range(world_size):
                     lengths[r * batch_size : r * batch_size + batch_size_by_rank[r]] = (
                         lengths_[
@@ -185,13 +196,21 @@ def _validate_pooling_factor(
             else:
                 lengths = lengths_
             num_indices = cast(int, torch.sum(lengths).item())
-            indices = torch.randint(
-                0,
-                ind_range,
-                (num_indices,),
-                dtype=torch.long if long_indices else torch.int32,
-            )
-            weights = torch.rand((num_indices,))
+            if randomize_indices:
+                indices = torch.randint(
+                    0,
+                    ind_range,
+                    (num_indices,),
+                    dtype=torch.long if long_indices else torch.int32,
+                    device=device,
+                )
+            else:
+                indices = torch.zeros(
+                    (num_indices),
+                    dtype=torch.long if long_indices else torch.int32,
+                    device=device,
+                )
+            weights = torch.rand((num_indices,), device=device)
             global_idscore_lengths.append(lengths)
             global_idscore_indices.append(indices)
             global_idscore_weights.append(weights)
@@ -206,8 +225,10 @@ def _validate_pooling_factor(
             else None
        )
 
-        global_float = torch.rand((batch_size * world_size, num_float_features))
-        global_label = torch.rand(batch_size * world_size)
+        global_float = torch.rand(
+            (batch_size * world_size, num_float_features), device=device
+        )
+        global_label = torch.rand(batch_size * world_size, device=device)
 
         # Split global batch into local batches.
         local_inputs = []
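
With `randomize_indices=False` the generated indices are deterministic all-zero tensors, and `device` places the generated tensors on the given device. A hypothetical call sketch; the table config, sizes, and other arguments here are illustrative assumptions, not taken from this commit:

import torch
from torchrec.distributed.test_utils.test_model import ModelInput
from torchrec.modules.embedding_configs import EmbeddingBagConfig

# One small unweighted table; real benchmarks would use more realistic configs.
tables = [
    EmbeddingBagConfig(
        num_embeddings=1000, embedding_dim=64, name="t0", feature_names=["f0"]
    )
]
global_input, per_rank_inputs = ModelInput.generate(
    batch_size=512,
    world_size=2,
    num_float_features=10,
    tables=tables,
    weighted_tables=[],
    randomize_indices=False,        # new: deterministic all-zero indices
    device=torch.device("cpu"),     # new: generate tensors on this device
)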
