Skip to content

Commit 5e2916b

Browse files
authored
tests: fix pytorch tensor placement errors (#33485)
This commit fixes the following errors: * Fix "expected all tensors to be on the same device" error * Fix "can't convert device type tensor to numpy" According to pytorch documentation torch.Tensor.numpy(force=False) performs conversion only if tensor is on CPU (plus few other restrictions) which is not the case. For our case we need force=True since we just need a data and don't care about tensors coherency. Fixes: #33517 See: https://pytorch.org/docs/2.4/generated/torch.Tensor.numpy.html Signed-off-by: Dmitry Rogozhkin <[email protected]>
1 parent 52daf4e commit 5e2916b

File tree

8 files changed

+29
-26
lines changed

8 files changed

+29
-26
lines changed

src/transformers/modeling_flax_pytorch_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
163163
# numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision
164164
if v.dtype == bfloat16:
165165
v = v.float()
166-
pt_state_dict[k] = v.numpy()
166+
pt_state_dict[k] = v.cpu().numpy()
167167

168168
model_prefix = flax_model.base_model_prefix
169169

tests/models/clip/test_modeling_clip.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,7 @@ def test_equivalence_pt_to_flax(self):
848848
with self.subTest(model_class.__name__):
849849
# load PyTorch class
850850
pt_model = model_class(config).eval()
851+
pt_model.to(torch_device)
851852
# Flax models don't use the `use_cache` option and cache is not returned as a default.
852853
# So we disable `use_cache` here for PyTorch model.
853854
pt_model.config.use_cache = False
@@ -881,7 +882,7 @@ def test_equivalence_pt_to_flax(self):
881882
fx_outputs = fx_model(**fx_inputs).to_tuple()
882883
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
883884
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
884-
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
885+
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
885886

886887
with tempfile.TemporaryDirectory() as tmpdirname:
887888
pt_model.save_pretrained(tmpdirname)
@@ -892,7 +893,7 @@ def test_equivalence_pt_to_flax(self):
892893
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
893894
)
894895
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
895-
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
896+
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)
896897

897898
# overwrite from common since FlaxCLIPModel returns nested output
898899
# which is not supported in the common test
@@ -921,6 +922,7 @@ def test_equivalence_flax_to_pt(self):
921922
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
922923

923924
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
925+
pt_model.to(torch_device)
924926

925927
# make sure weights are tied in PyTorch
926928
pt_model.tie_weights()
@@ -940,11 +942,12 @@ def test_equivalence_flax_to_pt(self):
940942
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
941943

942944
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
943-
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
945+
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
944946

945947
with tempfile.TemporaryDirectory() as tmpdirname:
946948
fx_model.save_pretrained(tmpdirname)
947949
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
950+
pt_model_loaded.to(torch_device)
948951

949952
with torch.no_grad():
950953
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
@@ -953,7 +956,7 @@ def test_equivalence_flax_to_pt(self):
953956
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
954957
)
955958
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
956-
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
959+
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
957960

958961
@slow
959962
def test_model_from_pretrained(self):

tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -297,15 +297,15 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
297297

298298
# prepare inputs
299299
flax_inputs = inputs_dict
300-
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
300+
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
301301

302302
with torch.no_grad():
303303
pt_outputs = pt_model(**pt_inputs).to_tuple()
304304

305305
fx_outputs = fx_model(**inputs_dict).to_tuple()
306306
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
307307
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
308-
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)
308+
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)
309309

310310
# PT -> Flax
311311
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -315,7 +315,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
315315
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
316316
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
317317
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
318-
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5)
318+
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)
319319

320320
# Flax -> PT
321321
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -330,7 +330,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
330330

331331
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
332332
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
333-
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5)
333+
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)
334334

335335
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
336336
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)

tests/models/informer/test_modeling_informer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def check_encoder_decoder_model_standalone(self, config, inputs_dict):
170170

171171
embed_positions = InformerSinusoidalPositionalEmbedding(
172172
config.context_length + config.prediction_length, config.d_model
173-
)
173+
).to(torch_device)
174174
self.parent.assertTrue(torch.equal(model.encoder.embed_positions.weight, embed_positions.weight))
175175
self.parent.assertTrue(torch.equal(model.decoder.embed_positions.weight, embed_positions.weight))
176176

tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -412,15 +412,15 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
412412

413413
# prepare inputs
414414
flax_inputs = inputs_dict
415-
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
415+
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
416416

417417
with torch.no_grad():
418418
pt_outputs = pt_model(**pt_inputs).to_tuple()
419419

420420
fx_outputs = fx_model(**inputs_dict).to_tuple()
421421
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
422422
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
423-
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)
423+
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)
424424

425425
# PT -> Flax
426426
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -430,7 +430,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
430430
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
431431
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
432432
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
433-
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5)
433+
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)
434434

435435
# Flax -> PT
436436
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -445,7 +445,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
445445

446446
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
447447
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
448-
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5)
448+
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)
449449

450450
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
451451
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)

tests/models/vision_encoder_decoder/test_modeling_flax_vision_encoder_decoder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -241,15 +241,15 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
241241

242242
# prepare inputs
243243
flax_inputs = inputs_dict
244-
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
244+
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
245245

246246
with torch.no_grad():
247247
pt_outputs = pt_model(**pt_inputs).to_tuple()
248248

249249
fx_outputs = fx_model(**inputs_dict).to_tuple()
250250
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
251251
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
252-
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)
252+
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)
253253

254254
# PT -> Flax
255255
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -259,7 +259,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
259259
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
260260
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
261261
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
262-
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5)
262+
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)
263263

264264
# Flax -> PT
265265
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -274,7 +274,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
274274

275275
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
276276
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
277-
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5)
277+
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)
278278

279279
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
280280
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)

tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,15 +160,15 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
160160

161161
# prepare inputs
162162
flax_inputs = inputs_dict
163-
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
163+
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
164164

165165
with torch.no_grad():
166166
pt_outputs = pt_model(**pt_inputs).to_tuple()
167167

168168
fx_outputs = fx_model(**inputs_dict).to_tuple()
169169
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
170170
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
171-
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
171+
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
172172

173173
# PT -> Flax
174174
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -178,7 +178,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
178178
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
179179
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
180180
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
181-
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
181+
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)
182182

183183
# Flax -> PT
184184
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -193,7 +193,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
193193

194194
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
195195
for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
196-
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 4e-2)
196+
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2)
197197

198198
def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
199199
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)

tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,15 +179,15 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, input_ids, attention_mas
179179
# prepare inputs
180180
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values}
181181
pt_inputs = inputs_dict
182-
flax_inputs = {k: v.numpy() for k, v in pt_inputs.items()}
182+
flax_inputs = {k: v.numpy(force=True) for k, v in pt_inputs.items()}
183183

184184
with torch.no_grad():
185185
pt_outputs = pt_model(**pt_inputs).to_tuple()
186186

187187
fx_outputs = fx_model(**flax_inputs).to_tuple()
188188
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
189189
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
190-
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
190+
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
191191

192192
# PT -> Flax
193193
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -197,7 +197,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, input_ids, attention_mas
197197
fx_outputs_loaded = fx_model_loaded(**flax_inputs).to_tuple()
198198
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
199199
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
200-
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
200+
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)
201201

202202
# Flax -> PT
203203
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -212,7 +212,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, input_ids, attention_mas
212212

213213
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
214214
for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
215-
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 4e-2)
215+
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2)
216216

217217
def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
218218
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)

0 commit comments

Comments
 (0)