@@ -69,6 +69,8 @@ def generate(
69
69
long_indices : bool = True ,
70
70
tables_pooling : Optional [List [int ]] = None ,
71
71
weighted_tables_pooling : Optional [List [int ]] = None ,
72
+ randomize_indices : bool = True ,
73
+ device : Optional [torch .device ] = None ,
72
74
) -> Tuple ["ModelInput" , List ["ModelInput" ]]:
73
75
"""
74
76
Returns a global (single-rank training) batch
@@ -132,15 +134,16 @@ def _validate_pooling_factor(
132
134
idlist_pooling_factor [idx ],
133
135
idlist_pooling_factor [idx ] / 10 ,
134
136
[batch_size * world_size ],
137
+ device = device ,
135
138
),
136
- torch .tensor (1.0 ),
139
+ torch .tensor (1.0 , device = device ),
137
140
).int ()
138
141
else :
139
142
lengths_ = torch .abs (
140
- torch .randn (batch_size * world_size ) + pooling_avg
143
+ torch .randn (batch_size * world_size , device = device ) + pooling_avg ,
141
144
).int ()
142
145
if variable_batch_size :
143
- lengths = torch .zeros (batch_size * world_size ).int ()
146
+ lengths = torch .zeros (batch_size * world_size , device = device ).int ()
144
147
for r in range (world_size ):
145
148
lengths [r * batch_size : r * batch_size + batch_size_by_rank [r ]] = (
146
149
lengths_ [
@@ -150,12 +153,20 @@ def _validate_pooling_factor(
150
153
else :
151
154
lengths = lengths_
152
155
num_indices = cast (int , torch .sum (lengths ).item ())
153
- indices = torch .randint (
154
- 0 ,
155
- ind_range ,
156
- (num_indices ,),
157
- dtype = torch .long if long_indices else torch .int32 ,
158
- )
156
+ if randomize_indices :
157
+ indices = torch .randint (
158
+ 0 ,
159
+ ind_range ,
160
+ (num_indices ,),
161
+ dtype = torch .long if long_indices else torch .int32 ,
162
+ device = device ,
163
+ )
164
+ else :
165
+ indices = torch .zeros (
166
+ (num_indices ),
167
+ dtype = torch .long if long_indices else torch .int32 ,
168
+ device = device ,
169
+ )
159
170
global_idlist_lengths .append (lengths )
160
171
global_idlist_indices .append (indices )
161
172
global_idlist_kjt = KeyedJaggedTensor (
@@ -167,15 +178,15 @@ def _validate_pooling_factor(
167
178
for idx in range (len (idscore_ind_ranges )):
168
179
ind_range = idscore_ind_ranges [idx ]
169
180
lengths_ = torch .abs (
170
- torch .randn (batch_size * world_size )
181
+ torch .randn (batch_size * world_size , device = device )
171
182
+ (
172
183
idscore_pooling_factor [idx ]
173
184
if idscore_pooling_factor
174
185
else pooling_avg
175
186
)
176
187
).int ()
177
188
if variable_batch_size :
178
- lengths = torch .zeros (batch_size * world_size ).int ()
189
+ lengths = torch .zeros (batch_size * world_size , device = device ).int ()
179
190
for r in range (world_size ):
180
191
lengths [r * batch_size : r * batch_size + batch_size_by_rank [r ]] = (
181
192
lengths_ [
@@ -185,13 +196,21 @@ def _validate_pooling_factor(
185
196
else :
186
197
lengths = lengths_
187
198
num_indices = cast (int , torch .sum (lengths ).item ())
188
- indices = torch .randint (
189
- 0 ,
190
- ind_range ,
191
- (num_indices ,),
192
- dtype = torch .long if long_indices else torch .int32 ,
193
- )
194
- weights = torch .rand ((num_indices ,))
199
+ if randomize_indices :
200
+ indices = torch .randint (
201
+ 0 ,
202
+ ind_range ,
203
+ (num_indices ,),
204
+ dtype = torch .long if long_indices else torch .int32 ,
205
+ device = device ,
206
+ )
207
+ else :
208
+ indices = torch .zeros (
209
+ (num_indices ),
210
+ dtype = torch .long if long_indices else torch .int32 ,
211
+ device = device ,
212
+ )
213
+ weights = torch .rand ((num_indices ,), device = device )
195
214
global_idscore_lengths .append (lengths )
196
215
global_idscore_indices .append (indices )
197
216
global_idscore_weights .append (weights )
@@ -206,8 +225,10 @@ def _validate_pooling_factor(
206
225
else None
207
226
)
208
227
209
- global_float = torch .rand ((batch_size * world_size , num_float_features ))
210
- global_label = torch .rand (batch_size * world_size )
228
+ global_float = torch .rand (
229
+ (batch_size * world_size , num_float_features ), device = device
230
+ )
231
+ global_label = torch .rand (batch_size * world_size , device = device )
211
232
212
233
# Split global batch into local batches.
213
234
local_inputs = []
0 commit comments