Merged
49 commits
61d0f3e  [OpenVINO] Support Qwen3-next (rkazants, Nov 16, 2025)
ea6b4b3  Fix config and add base patching (rkazants, Nov 16, 2025)
7e37aae  Extend patching (rkazants, Nov 17, 2025)
8bc1c5a  Initial patching for linear attention (rkazants, Nov 18, 2025)
26a4b65  Patch recurrent gated delta rule (rkazants, Nov 18, 2025)
a0e8d3c  Use module extension for conversion of chunked_attention_cell (rkazants, Nov 20, 2025)
486a4f8  Implement conversion extension for chunked gated delta rule cell (rkazants, Nov 23, 2025)
f623e57  Patch sparse moe block (rkazants, Nov 23, 2025)
e76f243  Use core_attn_out (rkazants, Nov 23, 2025)
b191d59  Fix use of mask (rkazants, Nov 24, 2025)
0b1bb21  Correct shape for recurrent_state in config file (rkazants, Nov 24, 2025)
6a3d22f  Re-write patch for MoE (rkazants, Nov 28, 2025)
9df28e3  (rkazants, Nov 28, 2025)
9ddaad9  Merge remote-tracking branch 'upstream/main' into support_qwen3_next (rkazants, Jan 12, 2026)
6384b9f  Apply code-formatting (rkazants, Jan 12, 2026)
f66862a  Fix previous commit with main merge (rkazants, Jan 12, 2026)
f4af348  Re-patch sparse MoE (rkazants, Jan 12, 2026)
500810f  Merge remote-tracking branch 'upstream/main' into support_qwen3_next (rkazants, Jan 30, 2026)
aee20f4  Merge remote-tracking branch 'upstream/main' into support_qwen3_next (rkazants, Mar 2, 2026)
92ec0e5  Fix code formatting (rkazants, Mar 2, 2026)
3e1c66f  Add tests for qwen3 next (rkazants, Mar 2, 2026)
5f45761  Unify representation for CausalConv1d (rkazants, Mar 3, 2026)
d49c7cd  Apply code-formatting (rkazants, Mar 3, 2026)
9874d8c  Leave only one GatedDeltaNet representation (rkazants, Mar 3, 2026)
f1dd676  Fix support for other models (rkazants, Mar 3, 2026)
2665dc9  Fix test_decoder.py (rkazants, Mar 3, 2026)
162bb72  Use chunk size equal to one (rkazants, Mar 3, 2026)
f940262  Apply suggestion from @rkazants (rkazants, Mar 3, 2026)
0112422  Move to recurrent gated delta net (rkazants, Mar 4, 2026)
6a6bb5b  Merge remote-tracking branch 'origin/support_qwen3_next' into support… (rkazants, Mar 4, 2026)
0bbc2a1  Apply code-formatting (rkazants, Mar 4, 2026)
15a3aee  Fix inference (rkazants, Mar 4, 2026)
70d75ed  Add comments to patching and config code (rkazants, Mar 4, 2026)
d9c233c  Apply code-formatting (rkazants, Mar 4, 2026)
2956e0a  Use the right decoder patcher (rkazants, Mar 4, 2026)
5657efd  Fix test_export test (rkazants, Mar 4, 2026)
1f32570  No beam search support for Qwen3-next (rkazants, Mar 4, 2026)
203726f  Update tests/openvino/test_decoder.py (rkazants, Mar 4, 2026)
b4d7505  Remove unneeded cached function calls for chunked gdn (rkazants, Mar 4, 2026)
b453da7  Update optimum/exporters/openvino/model_patcher.py (rkazants, Mar 4, 2026)
33cb551  Handle bf16 weights (rkazants, Mar 4, 2026)
c0e1311  Apply suggestion from @rkazants (rkazants, Mar 4, 2026)
e8c4702  Apply suggestion from @rkazants (rkazants, Mar 4, 2026)
111e69d  Apply suggestion from @rkazants (rkazants, Mar 4, 2026)
1a91c7b  Apply suggestion from @rkazants (rkazants, Mar 4, 2026)
5238cff  Apply suggestion from @rkazants (rkazants, Mar 4, 2026)
e789e75  Comment patch_recurrent_gated_delta_rule (rkazants, Mar 5, 2026)
169033e  Move convert_recurrent_attention_cell to internal module _ov_ops.py (rkazants, Mar 5, 2026)
1b835bb  Apply code formatting (rkazants, Mar 5, 2026)
1 change: 1 addition & 0 deletions docs/source/openvino/models.mdx
@@ -131,6 +131,7 @@ Here is the list of the supported architectures :
- Qwen2VL
- Qwen2.5VL
- Qwen3VL
- Qwen3-Next
- ResNet
- Roberta
- Roformer
5 changes: 5 additions & 0 deletions optimum/exporters/openvino/convert.py
@@ -435,11 +435,16 @@ def ts_patched_forward(*args, **kwargs):

__make_16bit_traceable(model)

conversion_extensions = getattr(patcher, "conversion_extensions", None)
module_extensions = getattr(patcher, "module_extensions", None)
if module_extensions is not None:
ts_decoder_kwargs["module_extensions"] = module_extensions
ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
ov_model = convert_model(
ts_decoder,
example_input=dummy_inputs,
input=[(item.shape, item.type) for item in input_info],
extension=conversion_extensions,
)

ov_model.validate_nodes_and_infer_types() # TODO: remove as unnecessary validation?
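
The hunk above only forwards whatever the active patcher chooses to expose. Below is a minimal sketch of the patcher side, assuming the two attribute names read via `getattr` above; `ModuleExtension` and `ConversionExtension` are real OpenVINO frontend APIs, but `GatedDeltaNetCell`, `convert_recurrent_cell`, and the patcher wiring are hypothetical stand-ins, not the PR's actual `model_patcher.py` code:

```python
# Sketch only: the contract that convert.py relies on, not the PR's real patcher.
# ModuleExtension swaps a python submodule for a single custom op while the
# TorchScript decoder traces the model; ConversionExtension then tells
# convert_model how to lower that custom op into OpenVINO operations.
from openvino.frontend import ConversionExtension
from openvino.frontend.pytorch import ModuleExtension


class SketchedQwen3NextPatcher:
    def __init__(self, model):
        self.model = model
        # Picked up in convert.py via getattr(patcher, "module_extensions", None)
        # and passed to TorchScriptPythonDecoder(..., module_extensions=...).
        self.module_extensions = [
            ModuleExtension(
                GatedDeltaNetCell,          # hypothetical submodule class to replace
                target_op="RecurrentCell",  # custom op name inserted into the graph
                evaluate=lambda module, *args, **kwargs: module(*args, **kwargs),
                convert=lambda module, target_op, *args, **kwargs: target_op(*args),
            )
        ]
        # Picked up via getattr(patcher, "conversion_extensions", None) and
        # passed to convert_model(..., extension=...).
        self.conversion_extensions = [
            ConversionExtension("RecurrentCell", convert_recurrent_cell)  # hypothetical converter
        ]
```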
140 changes: 140 additions & 0 deletions optimum/exporters/openvino/model_configs.py
@@ -195,6 +195,7 @@
Qwen2VLLanguageModelPatcher,
Qwen2VLVisionEmbMergerPatcher,
Qwen3MoeModelPatcher,
Qwen3NextModelPatcher,
Qwen3VLLanguageModelPatcher,
Qwen3VLVisionEmbMergerPatcher,
QwenModelPatcher,
@@ -5314,3 +5315,142 @@ class HunyuanV1DenseOpenVINOConfig(LlamaOpenVINOConfig):
MIN_TRANSFORMERS_VERSION = "4.57.0"
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator


class Qwen3NextDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
"""
Generates dummy cache_params inputs for Qwen3-Next architectures.
"""

SUPPORTED_INPUT_NAMES = ("cache_params",)

def __init__(
self,
task: str,
normalized_config,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
**kwargs,
):
super().__init__(
task=task,
normalized_config=normalized_config,
batch_size=batch_size,
sequence_length=sequence_length,
**kwargs,
)

config = normalized_config.config
self.num_full_attn_layers = config.layer_types.count("full_attention")
self.num_linear_attn_layers = config.layer_types.count("linear_attention")
self.conv_kernel_size = config.linear_conv_kernel_dim
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.head_k_dim = config.linear_key_head_dim
self.head_v_dim = config.linear_value_head_dim
self.num_v_heads = config.linear_num_value_heads
self.num_k_heads = config.linear_num_key_heads
self.num_key_value_heads = config.num_key_value_heads

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
cache_params = []

for idx in range(self.num_linear_attn_layers):
# (batch_size, d_inner, d_conv)
d_inner = self.num_k_heads * (2 * self.head_k_dim + self.head_v_dim * self.num_v_heads // self.num_k_heads)
conv_state_shape = (
self.batch_size,
d_inner,
self.conv_kernel_size,
)
conv_state = self.random_float_tensor(conv_state_shape, framework=framework, dtype=float_dtype)
cache_params.append(conv_state)
num_heads = self.num_v_heads
recurrent_state_shape = (self.batch_size, num_heads, self.head_k_dim, self.head_v_dim)
recurrent_state = self.random_float_tensor(recurrent_state_shape, framework=framework, dtype=float_dtype)
cache_params.append(recurrent_state)

for idx in range(self.num_full_attn_layers):
kv_shape = (self.batch_size, self.num_key_value_heads, self.sequence_length, self.head_dim)
k = self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype)
v = self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype)
cache_params.append(k)
cache_params.append(v)

return cache_params
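
To make the state shapes concrete, here is a small worked example. The numbers loosely follow the published Qwen3-Next-80B-A3B config (16 linear key heads, 32 linear value heads, head dims of 128, conv kernel of 4); treat them as illustrative rather than read from a checkpoint:

```python
# Worked example of the per-layer cache_params shapes produced by generate().
# Config values are hard-coded here for illustration only.
num_k_heads, num_v_heads = 16, 32
head_k_dim, head_v_dim = 128, 128
conv_kernel_size = 4
batch_size = 1

# d_inner packs the q, k and v projections that the causal conv runs over:
# q and k contribute head_k_dim each per key head, v contributes
# num_v_heads // num_k_heads value heads of head_v_dim per key head.
d_inner = num_k_heads * (2 * head_k_dim + head_v_dim * num_v_heads // num_k_heads)

conv_state_shape = (batch_size, d_inner, conv_kernel_size)
recurrent_state_shape = (batch_size, num_v_heads, head_k_dim, head_v_dim)
print(d_inner)                # 16 * (256 + 256) = 8192
print(conv_state_shape)       # (1, 8192, 4)
print(recurrent_state_shape)  # (1, 32, 128, 128)
```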


@register_in_tasks_manager(
"qwen3_next",
*["text-generation", "text-generation-with-past"],
library_name="transformers",
)
class Qwen3NextOpenVINOConfig(Qwen3OpenVINOConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, Qwen3NextDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = Qwen3NextDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
MIN_TRANSFORMERS_VERSION = "4.57.0"
_MODEL_PATCHER = Qwen3NextModelPatcher

def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
if direction not in ["inputs", "outputs"]:
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

if direction == "inputs":
decoder_sequence_name = "past_sequence_length"
cache_name_prefix = "cache_params.past"
else:
decoder_sequence_name = "past_sequence_length + sequence_length"
cache_name_prefix = "cache_params.present"

self.num_full_attn_layers = self._normalized_config.layer_types.count("full_attention")
self.num_linear_attn_layers = self._normalized_config.layer_types.count("linear_attention")

for i in range(self.num_linear_attn_layers):
            # conv state: [batch_size, d_inner, conv_kernel_size], only batch is dynamic
            inputs_or_outputs[f"{cache_name_prefix}.conv.{i}"] = {0: "batch_size"}
            # recurrent (ssm) state: [batch_size, num_v_heads, head_k_dim, head_v_dim], only batch is dynamic
            inputs_or_outputs[f"{cache_name_prefix}.ssm.{i}"] = {0: "batch_size"}

for i in range(self.num_full_attn_layers):
inputs_or_outputs[f"{cache_name_prefix}.key.{i}"] = {0: "batch_size", 2: decoder_sequence_name}
inputs_or_outputs[f"{cache_name_prefix}.value.{i}"] = {0: "batch_size", 2: decoder_sequence_name}

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = {
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
}
if self.use_past_in_inputs:
self.add_past_key_values(common_inputs, direction="inputs")
return common_inputs

def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
        # Override `generate_dummy_inputs`: besides the KV cache of the full-attention layers,
        # the model keeps conv and recurrent (ssm) states for the linear-attention layers,
        # and all of them are passed together under a single `cache_params` input
dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)

dummy_inputs = {}
input_names = [key for key in self.inputs.keys() if not key.startswith("cache_params")]
if self.use_past_in_inputs:
input_names.extend(["cache_params"])

for input_name in input_names:
input_was_inserted = False
for dummy_input_gen in dummy_inputs_generators:
if dummy_input_gen.supports_input(input_name):
dummy_inputs[input_name] = self.overwrite_shape_and_generate_input(
dummy_input_gen,
input_name,
framework,
input_shapes=kwargs,
)
input_was_inserted = True
break
if not input_was_inserted:
raise RuntimeError(
f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.'
)

return dummy_inputs
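
With the `qwen3_next` config and patcher registered, the export can be exercised end to end through the regular optimum-intel API. A minimal usage sketch (the model ID names the published Qwen3-Next checkpoint and is shown as an example):

```python
from transformers import AutoTokenizer

from optimum.intel import OVModelForCausalLM

model_id = "Qwen/Qwen3-Next-80B-A3B-Instruct"  # example checkpoint
# export=True converts the PyTorch model to OpenVINO IR on the fly,
# using the qwen3_next export config and patcher registered above.
model = OVModelForCausalLM.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("The capital of France is", return_tensors="pt")
# Per the commit history, beam search is not supported for Qwen3-Next,
# so stick to greedy or sampling-based decoding.
output = model.generate(**inputs, max_new_tokens=16, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```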