Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/transformers/models/janus/modeling_janus.py
Original file line number Diff line number Diff line change
Expand Up @@ -1419,8 +1419,8 @@ def generate(
model_inputs = self.prepare_inputs_for_generation(
inputs_embeds=inputs_embeds, input_ids=input_tokens, **model_kwargs
)

model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device)
if "attention_mask" in model_inputs:
model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device)
model_inputs["cache_position"] = model_inputs["cache_position"].to(inputs_embeds.device)

outputs = self.model.language_model(
Expand Down
6 changes: 6 additions & 0 deletions tests/models/janus/test_modeling_janus.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,12 @@ def test_model_generate_images(self):
15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297, 1097, 12108, 15038, 311, 14998, 15165,
897, 4044, 1762, 4676
],
("xpu", None): [
4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548, 1820, 1465,
13529, 12761, 10503, 12761, 14303, 6155, 4015, 11766, 705, 15736, 14146, 10417, 1951, 7713, 14305,
15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297, 1097, 12108, 15038, 311, 14998, 15165,
897, 4044, 1762, 4676
],
}
)
expected_tokens = torch.tensor(expected_tokens.get_expectation()).to(model.device)
Expand Down