Zentorch 5.1 AVX2 Paths #2

@meven3000

Description

Hi, based on ZenDNN 5.1, BLIS 5.1 can now emulate BF16 acceleration on AVX2.

However, from testing, Zentorch appears to run an AVX512 check even on AVX2 hardware, which fails if the CPU does not support AVX512.
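
For reference, here is a minimal sketch of how I verify what the host actually supports (the /proc/cpuinfo check is Linux-only; `torch.backends.cpu.get_cpu_capability()` is available in recent PyTorch builds):

```python
import torch

# ISA level PyTorch detects on this host: reports "AVX2" on Zen 2, not "AVX512".
print(torch.backends.cpu.get_cpu_capability())

# Confirm the avx512bf16 flag is absent from the CPU flags (Linux-only check).
with open("/proc/cpuinfo") as f:
    print("avx512bf16" in f.read())  # False on Zen 2
```

Full log and traceback from the failing run: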

[WARNING zentorch.llm._checks - essential_checks:92] Intel Extension for PyTorch not installed. So, the ZenTorch specific optimizations for LLMs might not be triggered.
Setting pad_token_id to eos_token_id:None for open-end generation.
[INFO zentorch._compile_backend - zentorch_compile:123] Called the zentorch backend.
[INFO zentorch._compile_backend - zentorch_compile_fx_inner:88] Optimizing the model with zentorch ops.
[INFO zentorch._optimize - optimize:48] Optimizing the fx_graph with zentorch ops.
[INFO zentorch._graph_cleanup - unused_node_elimination:30] Removing unused nodes from the fx_graph.
[INFO zentorch._eltwise_binary_fusions - zentorch_eltwise_binary_fusions:25] Fusing the zentorch binary elementwise ops in fx graph.
[INFO zentorch._eltwise_binary_fusions - zentorch_eltwise_binary_fusions:78] Recompiling the fx_graph with fusion changes made.
[INFO zentorch._eltwise_unary_fusions - zentorch_eltwise_unary_fusions:1527] Fusing the zentorch unary elementwise ops in fx graph.
[INFO zentorch._custom_op_replacement - qlinear_reorder_optimizations:642] Reorder optimization for serialized qlinear_* ops.
[INFO zentorch._custom_op_replacement - emb_ops_horizontal_fusion:130] Fusing horizontal parallel embedding ops.
[INFO zentorch._custom_op_replacement - qkv_fusion:850] Detecting and executing QKV parallel ops.
[INFO zentorch._custom_op_replacement - eb_group_mlp_group_fusion:982] Fusing the horizontally fused EmbeddingBag op and the vertically fused MLP op
[INFO zentorch._compile_backend - zentorch_compile_fx_inner:92] Model is passed to compile_fx_inner.
W0723 06:48:58.894000 1 torch/_inductor/debug.py:454] [0/0] model__0_inference_0 debug trace: /app/torch_compile_debug/run_2025_07_23_06_48_32_528769-pid_1/torchinductor/model__0_inference_0.0
[API:I][0.000007] CPU Engine create
[CORE:V0][0.000006] CPU Engine created [engine]
[CORE:I][0.000017] CPU Engine created [cpu/engine]
[API:I][0.000001] Memory create
[CORE:V0][0.000001] Memory desc init by tag [memory]
[CORE:I][0.000024] Memory created [memory]
[API:I][0.000044] Memory create
[CORE:V0][0.000031] Memory desc init by tag [memory]
[CORE:I][0.000036] Memory created [memory]
[API:I][0.000053] Memory create
[CORE:V0][0.000041] Memory desc init by tag [memory]
[CORE:I][0.000045] Memory created [memory]
[API:I][0.000001] matmul desc create - no bias
[CORE:I][0.000001] matmul desc init [matmul]
[CORE:I][0.000001] CPU Engine: primitive_cache_capacity: 1024
[CORE:V0][0.000001] zendnn_f32_matmul_t::pd_t::init()
[CORE:V0][0.000129] Memory desc init by tag [memory]
[CORE:V0][0.000132] Memory desc init by tag [memory]
[CORE:V0][0.000135] Memory desc init by tag [memory]
ZenDNN is Running......
[CORE:V0][0.000001] ZenDNN Ref gemm_f32_matmul_t::pd_t::init()
[CORE:V0][0.000008] ZenDNN Ref gemm_f32_matmul_t::pd_t::check_and_configure_attributes
[API:I][0.000162] matmul primitive_desc create - attr
[PROF:I][0.000015] zendnn_primitive_create,cache_miss,cpu,plugin_op:zentorch::zentorch_bmm,matmul,gemm:jit,undef,src_f32::blocked:abc:f0 wei_f32::blocked:abc:f0 dst_f32::blocked:abc:f0,,,1x64x1:1x1x8:1x64x8,0.03392,ms
[API:I][0.000227] matmul primitive create
[API:I][0.000230] CPU Stream create
[CORE:I][0.000001] CPU Stream created [stream]
[CORE:V0][0.000001] CPU Stream created [cpu/stream]
[CORE:I][0.000118] ZenDNN Ref gemm_f32_matmul_t::execute_ref
[PROF:I][0.013998] zendnn_primitive_execute,cpu,plugin_op:zentorch::zentorch_bmm,matmul,gemm:jit,undef,src_f32::blocked:abc:f0 wei_f32::blocked:abc:f0 dst_f32::blocked:abc:f0,,,1x64x1:1x1x8:1x64x8,13.9558,ms

Traceback (most recent call last):
  File "/app/startup/test_llm_bf16.py", line 115, in <module>
    output = model.generate(
             ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/generation/utils.py", line 2215, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/generation/utils.py", line 3206, in _sample
    outputs = self(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 655, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/models/llama/modeling_llama.py", line 1135, in forward
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py", line 1209, in forward
    return compiled_fn(full_args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 328, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
                            ^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 689, in inner_fn
    outs = compiled_fn(args)
           ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 495, in wrapper
    return compiled_fn(runtime_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/output_code.py", line 460, in __call__
    return self.current_callable(inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/torchinductor_root/ss/cssupy3etd3u5fmf5pwt7zwvqk464ctj7moo6mmd6mmsikzepmnl.py", line 35114, in call
    buf8 = torch.ops.zentorch.zentorch_attn_qkv_fusion.default([buf6, buf6, buf6], [reinterpret_tensor(buf7, (8, 8192), (8192, 1), 0), reinterpret_tensor(buf7, (8, 8192), (8192, 1), 0), reinterpret_tensor(buf7, (8, 8192), (8192, 1), 0)], [reinterpret_tensor(arg7_1, (8192, 8192), (1, 8192), 0), reinterpret_tensor(arg8_1, (8192, 1024), (1, 8192), 0), reinterpret_tensor(arg9_1, (8192, 1024), (1, 8192), 0)], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0, 0, 0], [1, 1, 1])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 756, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: /tmp/zentorch/src/cpu/cpp/MatmulUtils.hpp:220 check_valid_dtypes_for_matmul : zentorch_matmul bf16 path needs the cpu support avx512bf16

[CORE:I][0.772005] CPU Stream deleted [stream]
[CORE:I][0.772400] CPU Engine deleted [engine]

Can you please advise whether this is expected, and whether it can be modified to allow inference using only AVX2 (Zen 2), or whether additional parameters are required for this to work?
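
In the meantime, the only workaround I see on my side is to gate the dtype on the detected CPU capability and fall back to FP32 on AVX2-only hosts. A minimal sketch, assuming there is no official AVX2 BF16 override in zentorch (model path is a placeholder):

```python
import torch
from transformers import AutoModelForCausalLM

# zentorch's BF16 matmul path requires avx512bf16 (MatmulUtils.hpp:220), so
# fall back to FP32 when PyTorch only detects AVX2. Note that a bare "AVX512"
# result still does not guarantee avx512bf16; grepping /proc/cpuinfo for the
# exact flag would be more precise.
dtype = (torch.bfloat16
         if torch.backends.cpu.get_cpu_capability() == "AVX512"
         else torch.float32)
model = AutoModelForCausalLM.from_pretrained(
    "/models/llama-7b-bf16",  # placeholder path, not the real one
    torch_dtype=dtype,
)
```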

REF: the model is Llama 7B, converted with transformers from FP32 to BF16.

Note: interestingly, with TorchDynamo disabled (uncommenting os.environ["TORCHDYNAMO_DISABLE"] = "1"), the model will run, but based on the logs presented there is no Zentorch processing.
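
For completeness, the test script does roughly the following (the model path, prompt, and generation arguments here are placeholders, not the exact contents of test_llm_bf16.py):

```python
import os
# os.environ["TORCHDYNAMO_DISABLE"] = "1"  # uncommented, the model runs but zentorch is bypassed
import torch
import zentorch  # registers the "zentorch" torch.compile backend
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "/models/llama-7b-bf16"  # placeholder: Llama 7B converted FP32 -> BF16
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model = torch.compile(model, backend="zentorch")

inputs = tokenizer("Hello, world", return_tensors="pt")
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```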
