Skip to content

Commit ed5abe1

Browse files
committed
fix: export gemma3-text is now working thanks to
attention vmap patch as in here huggingface#2319
1 parent 099fde9 commit ed5abe1

5 files changed

Lines changed: 185 additions & 84 deletions

File tree

optimum/exporters/onnx/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,7 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
519519
and self.PAD_ATTENTION_MASK_TO_PAST
520520
and self.use_cache_branch is not False
521521
and "attention_mask" in dummy_inputs
522+
and not isinstance(dummy_inputs["attention_mask"], dict)
522523
):
523524
# Obtain the past sequence length from the value instead of the key (Bloom).
524525
past_present_length = dummy_inputs["input_ids"].shape[1] + dummy_inputs["past_key_values"][0][1].shape[-2]

optimum/exporters/onnx/model_patcher.py

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import transformers
2525
from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet
2626

27-
from ...utils import is_transformers_version, logging
27+
from ...utils import is_torch_version, is_transformers_version, logging
2828
from ._traceable_cache import TraceableCache
2929

3030

@@ -40,6 +40,8 @@
4040
from transformers.cache_utils import DynamicCache, EncoderDecoderCache
4141
from transformers.integrations.sdpa_attention import repeat_kv, sdpa_attention_forward
4242
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
43+
if is_transformers_version(">=", "4.53"):
44+
from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, _ignore_causal_mask_sdpa, prepare_padding_mask
4345

4446

4547
if TYPE_CHECKING:
@@ -218,6 +220,72 @@ def onnx_compatible_linalg_norm(x, ord=2, dim=None, keepdim=False, *, dtype=None
218220
return original_linal_norm(x, ord=ord, dim=dim, keepdim=keepdim, dtype=dtype, out=out)
219221

220222

223+
def sdpa_mask_without_vmap(
224+
batch_size: int,
225+
cache_position: torch.Tensor,
226+
kv_length: int,
227+
kv_offset: int = 0,
228+
attention_mask: Optional[torch.Tensor] = None,
229+
local_size: Optional[int] = None,
230+
allow_is_causal_skip: bool = True,
231+
allow_torch_fix: bool = True,
232+
**kwargs,
233+
) -> Optional[torch.Tensor]:
234+
q_length = cache_position.shape[0]
235+
# Potentially pad the 2D mask, and slice it correctly
236+
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
237+
238+
# Under specific conditions, we can avoid materializing the mask, instead relying on the `is_causal` argument
239+
if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, local_size):
240+
return None
241+
242+
# Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
243+
# but without data-dependent slicing (i.e. torch.compile friendly)
244+
kv_arange = torch.arange(kv_length, device=cache_position.device)
245+
kv_arange += kv_offset
246+
reshaped_cache_position = cache_position.view(-1, 1)
247+
248+
# This is a bit hacky to know what pattern we are using, but all mask creation function actually forward
249+
# the config through kwargs anyway, so it allows to rely on it
250+
# Usually, the `mask_function` is the only entry-point to define the pattern - we could do for loops over it,
251+
# but this is more efficient
252+
sliding_window = getattr(kwargs["config"], "sliding_window", None)
253+
chunk_size = getattr(kwargs["config"], "attention_chunk_size", None)
254+
255+
if sliding_window is not None and chunk_size is not None:
256+
raise ValueError("Cannot use both `sliding_window` and `attention_chunk_size`")
257+
258+
# Simplest and most efficient way to obtain a causal mask
259+
causal_mask = kv_arange <= reshaped_cache_position
260+
# If using sliding window, add the sliding mask
261+
if sliding_window is not None:
262+
sliding_mask_overlay = kv_arange > reshaped_cache_position - sliding_window
263+
causal_mask *= sliding_mask_overlay
264+
# If using chunk attention, add the chunked mask
265+
elif chunk_size is not None:
266+
chunked_mask_overlay = kv_arange // chunk_size == reshaped_cache_position // chunk_size
267+
causal_mask *= chunked_mask_overlay
268+
269+
causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1)
270+
if padding_mask is not None:
271+
causal_mask = causal_mask * padding_mask[:, None, None, :]
272+
273+
# Due to a bug in some older torch version, we need to update the mask in case a query is not attending to any
274+
# tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213
275+
if is_torch_version("<", "2.5") and allow_torch_fix:
276+
causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
277+
return causal_mask
278+
279+
280+
def eager_mask_without_vmap(*args, **kwargs) -> Optional[torch.Tensor]:
281+
kwargs.pop("allow_torch_fix", None)
282+
kwargs.pop("allow_is_causal_skip", None)
283+
dtype = kwargs.get("dtype", torch.float32)
284+
mask = sdpa_mask_without_vmap(*args, **kwargs, allow_is_causal_skip=False, allow_torch_fix=False) # type: ignore
285+
mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), torch.finfo(dtype).min) # type: ignore
286+
return mask
287+
288+
221289
UNSUPPORTED_OPS_PATCHING_SPEC = [
222290
PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold),
223291
PatchingSpec(torch.linalg, "norm", onnx_compatible_linalg_norm, original_linal_norm),
@@ -355,10 +423,20 @@ def __enter__(self):
355423
self.patch_ops()
356424
setattr(self._model, self.orig_forward_name, self.patched_forward)
357425

426+
if is_transformers_version(">=", "4.53"):
427+
self.original_sdpa_mask = ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
428+
self.original_eager_mask = ALL_MASK_ATTENTION_FUNCTIONS["eager"]
429+
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", sdpa_mask_without_vmap)
430+
ALL_MASK_ATTENTION_FUNCTIONS.register("eager", eager_mask_without_vmap)
431+
358432
def __exit__(self, exc_type, exc_value, traceback):
359433
self.restore_ops()
360434
setattr(self._model, self.orig_forward_name, self.orig_forward)
361435

436+
if is_transformers_version(">=", "4.53"):
437+
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", self.original_sdpa_mask)
438+
ALL_MASK_ATTENTION_FUNCTIONS.register("eager", self.original_eager_mask)
439+
362440
def __call__(self, *args, **kwargs):
363441
if getattr(self._model, self.orig_forward_name) is self.orig_forward:
364442
logger.warning("Running the non-patched model")
@@ -368,14 +446,14 @@ def __call__(self, *args, **kwargs):
368446
class Seq2SeqModelPatcher(ModelPatcher):
369447
def __enter__(self):
370448
super().__enter__()
371-
if is_transformers_version(">=", "4.48"):
449+
if is_transformers_version(">=", "4.48") and is_transformers_version("<", "4.53"):
372450
# this is required when gpt2 is used as decoder in any
373451
# encoder-decoder model with cross attention blocks
374452
ALL_ATTENTION_FUNCTIONS["sdpa"] = patched_sdpa_attention_forward
375453

376454
def __exit__(self, exc_type, exc_value, traceback):
377455
super().__exit__(exc_type, exc_value, traceback)
378-
if is_transformers_version(">=", "4.48"):
456+
if is_transformers_version(">=", "4.48") and is_transformers_version("<", "4.53"):
379457
ALL_ATTENTION_FUNCTIONS["sdpa"] = sdpa_attention_forward
380458

381459
def __init__(

optimum/utils/input_generators.py

Lines changed: 89 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,87 +1345,95 @@ def __init__(
13451345
)
13461346
self.sliding_window_size = getattr(normalized_config, "sliding_window", sequence_length)
13471347

1348-
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
1349-
if input_name in ["input_ids", "token_type_ids", "position_ids"]:
1350-
return super().generate(
1351-
input_name=input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype
1352-
)
1353-
if input_name == "attention_mask":
1354-
return {
1355-
"full_causal_mask": self._generate_full_causal_mask(framework, float_dtype),
1356-
"sliding_causal_mask": self._generate_sliding_causal_mask(framework, float_dtype),
1357-
}
1358-
# if input_name == "full_causal_mask":
1359-
# return self._generate_full_causal_mask(framework, float_dtype)
1360-
# elif input_name == "sliding_causal_mask":
1361-
# return self._generate_sliding_causal_mask(framework, float_dtype)
1362-
else:
1363-
raise ValueError(f"What happened? This is not supported and should not be here: {input_name}")
1364-
1365-
def _generate_full_causal_mask(self, framework: str = "pt", float_dtype: str = "float32"):
1366-
if framework == "pt":
1367-
mask = torch.triu(
1368-
torch.ones((self.sequence_length, self.sequence_length), dtype=DTYPE_MAPPER.pt(float_dtype)),
1369-
diagonal=1,
1370-
)
1371-
mask = mask.masked_fill(mask == 1, float("-inf"))
1372-
mask = mask.unsqueeze(0).expand(self.batch_size, -1, -1)
1373-
return mask
1374-
elif framework == "tf":
1375-
mask = tf.linalg.band_part(
1376-
tf.ones((self.sequence_length, self.sequence_length), dtype=DTYPE_MAPPER.tf(float_dtype)), -1, 0
1377-
)
1378-
mask = tf.where(mask == 0, float("-inf"), 0.0)
1379-
mask = tf.expand_dims(mask, 0)
1380-
mask = tf.tile(mask, [self.batch_size, 1, 1])
1381-
return mask
1382-
else:
1383-
mask = np.triu(
1384-
np.ones((self.sequence_length, self.sequence_length), dtype=DTYPE_MAPPER.np(float_dtype)), k=1
1385-
)
1386-
mask = np.where(mask == 1, float("-inf"), 0.0)
1387-
mask = np.expand_dims(mask, 0)
1388-
mask = np.tile(mask, (self.batch_size, 1, 1))
1389-
return mask
1390-
1391-
def _generate_sliding_causal_mask(self, framework: str = "pt", float_dtype: str = "fp32"):
1392-
if framework == "pt":
1393-
mask = torch.full(
1394-
(self.sequence_length, self.sequence_length), float("-inf"), dtype=DTYPE_MAPPER.pt(float_dtype)
1395-
)
1396-
for i in range(self.sequence_length):
1397-
start = max(0, i - self.sliding_window_size + 1)
1398-
mask[i, start : i + 1] = 0.0
1399-
mask = mask.unsqueeze(0).expand(self.batch_size, -1, -1)
1400-
return mask
1401-
elif framework == "tf":
1402-
mask = tf.fill((self.sequence_length, self.sequence_length), float("-inf"))
1403-
mask = tf.cast(mask, DTYPE_MAPPER.tf(float_dtype))
1404-
1405-
updates = []
1406-
indices = []
1407-
for i in range(self.sequence_length):
1408-
start = max(0, i - self.sliding_window_size + 1)
1409-
for j in range(start, i + 1):
1410-
indices.append([i, j])
1411-
updates.append(0.0)
1412-
if indices:
1413-
indices = tf.constant(indices)
1414-
updates = tf.constant(updates, dtype=DTYPE_MAPPER.tf(float_dtype))
1415-
mask = tf.tensor_scatter_nd_update(mask, indices, updates)
1416-
mask = tf.expand_dims(mask, 0)
1417-
mask = tf.tile(mask, [self.batch_size, 1, 1])
1418-
return mask
1419-
else:
1420-
mask = np.full(
1421-
(self.sequence_length, self.sequence_length), float("-inf"), dtype=DTYPE_MAPPER.np(float_dtype)
1422-
)
1423-
for i in range(self.sequence_length):
1424-
start = max(0, i - self.sliding_window_size + 1)
1425-
mask[i, start : i + 1] = 0.0
1426-
mask = np.expand_dims(mask, 0)
1427-
mask = np.tile(mask, (self.batch_size, 1, 1))
1428-
return mask
1348+
# def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
1349+
# if input_name in ["input_ids", "token_type_ids", "position_ids"]:
1350+
# return super().generate(
1351+
# input_name=input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype
1352+
# )
1353+
# if input_name == "attention_mask":
1354+
# return {
1355+
# "full_attention": self._generate_full_causal_mask(framework, float_dtype),
1356+
# "sliding_attention": self._generate_sliding_causal_mask(framework, float_dtype),
1357+
# }
1358+
# # if input_name == "full_causal_mask":
1359+
# # return self._generate_full_causal_mask(framework, float_dtype)
1360+
# # elif input_name == "sliding_causal_mask":
1361+
# # return self._generate_sliding_causal_mask(framework, float_dtype)
1362+
# else:
1363+
# raise ValueError(f"What happened? This is not supported and should not be here: {input_name}")
1364+
1365+
# def _generate_full_causal_mask(self, framework: str = "pt", float_dtype: str = "float32"):
1366+
# if framework == "pt":
1367+
# row_indices = torch.arange(self.sequence_length).view(-1, 1)
1368+
# col_indices = torch.arange(self.sequence_length).view(1, -1)
1369+
# causal_mask = row_indices >= col_indices
1370+
# dtype = getattr(torch, float_dtype)
1371+
# mask = torch.zeros((self.sequence_length, self.sequence_length), dtype=dtype)
1372+
# mask[~causal_mask] = float("-inf")
1373+
# mask = mask.unsqueeze(0).expand(self.batch_size, -1, -1)
1374+
# return mask
1375+
# elif framework == "tf":
1376+
# row_indices, col_indices = tf.meshgrid(
1377+
# tf.range(self.sequence_length), tf.range(self.sequence_length), indexing="ij"
1378+
# )
1379+
# causal_mask = row_indices >= col_indices
1380+
# dtype = getattr(tf, float_dtype)
1381+
# mask = tf.where(
1382+
# causal_mask,
1383+
# tf.zeros((self.sequence_length, self.sequence_length), dtype=dtype),
1384+
# tf.fill((self.sequence_length, self.sequence_length), float("-inf")),
1385+
# )
1386+
# mask = tf.expand_dims(mask, 0)
1387+
# mask = tf.tile(mask, [self.batch_size, 1, 1])
1388+
# return mask
1389+
1390+
# else:
1391+
# row_indices = np.arange(self.sequence_length).reshape(-1, 1)
1392+
# col_indices = np.arange(self.sequence_length).reshape(1, -1)
1393+
# causal_mask = row_indices >= col_indices
1394+
# dtype = getattr(np, float_dtype)
1395+
# mask = np.full((self.sequence_length, self.sequence_length), float("-inf"), dtype=dtype)
1396+
# mask[causal_mask] = 0.0
1397+
# mask = np.expand_dims(mask, 0)
1398+
# mask = np.repeat(mask, self.batch_size, axis=0)
1399+
# return mask
1400+
1401+
# def _generate_sliding_causal_mask(self, window_size: int, framework: str = "pt", float_dtype: str = "float32"):
1402+
# if framework == "pt":
1403+
# row_indices = torch.arange(self.sequence_length).view(-1, 1)
1404+
# col_indices = torch.arange(self.sequence_length).view(1, -1)
1405+
# causal_mask = (row_indices >= col_indices) & (row_indices - col_indices < window_size)
1406+
# dtype = getattr(torch, float_dtype)
1407+
# mask = torch.zeros((self.sequence_length, self.sequence_length), dtype=dtype)
1408+
# mask[~causal_mask] = float("-inf")
1409+
# mask = mask.unsqueeze(0).expand(self.batch_size, -1, -1)
1410+
# return mask
1411+
# elif framework == "tf":
1412+
# row_indices, col_indices = tf.meshgrid(
1413+
# tf.range(self.sequence_length), tf.range(self.sequence_length), indexing="ij"
1414+
# )
1415+
# causal_condition = row_indices >= col_indices
1416+
# window_condition = (row_indices - col_indices) < window_size
1417+
# sliding_mask = causal_condition & window_condition
1418+
# dtype = getattr(tf, float_dtype)
1419+
# mask = tf.where(
1420+
# sliding_mask,
1421+
# tf.zeros((self.sequence_length, self.sequence_length), dtype=dtype),
1422+
# tf.fill((self.sequence_length, self.sequence_length), float("-inf")),
1423+
# )
1424+
# mask = tf.expand_dims(mask, 0)
1425+
# mask = tf.tile(mask, [self.batch_size, 1, 1])
1426+
# return mask
1427+
# else:
1428+
# row_indices = np.arange(self.sequence_length).reshape(-1, 1)
1429+
# col_indices = np.arange(self.sequence_length).reshape(1, -1)
1430+
# causal_mask = (row_indices >= col_indices) & (row_indices - col_indices < window_size)
1431+
# dtype = getattr(np, float_dtype)
1432+
# mask = np.full((self.sequence_length, self.sequence_length), float("-inf"), dtype=dtype)
1433+
# mask[causal_mask] = 0.0
1434+
# mask = np.expand_dims(mask, 0)
1435+
# mask = np.repeat(mask, self.batch_size, axis=0)
1436+
# return mask
14291437

14301438

14311439
class DummySpeechT5InputGenerator(DummyInputGenerator):

optimum/utils/normalized_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ class NormalizedConfigManager:
253253
"electra": NormalizedTextConfig,
254254
"encoder-decoder": NormalizedEncoderDecoderConfig,
255255
"gemma": NormalizedTextConfigWithGQA,
256+
"gemma3_text": NormalizedTextConfigWithGQA,
256257
"gpt2": GPT2LikeNormalizedTextConfig,
257258
"gpt_bigcode": GPTBigCodeNormalizedTextConfig,
258259
"gpt_neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),

test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from optimum.onnxruntime import ORTModelForCausalLM
2+
3+
4+
model_name = "google/gemma-3-1b-it"
5+
6+
onnx_model = ORTModelForCausalLM.from_pretrained(
7+
model_name,
8+
export=True,
9+
trust_remote_code=True,
10+
)
11+
12+
13+
print("done")

0 commit comments

Comments
 (0)