diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 7e67ec6d0c94e..5805333a0868c 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -90,6 +90,7 @@ Do not modify directly.* * com.microsoft.RemovePadding * com.microsoft.RestorePadding * com.microsoft.Rfft + * com.microsoft.RotaryEmbedding * com.microsoft.SampleOp * com.microsoft.Sampling * com.microsoft.SkipLayerNormalization @@ -2834,7 +2835,7 @@ This version of the operator has been available since version 1 of the 'com.micr
bias (optional) : T
Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection
key_padding_mask (optional) : M
-
Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)
+
Key padding mask with shape (batch_size), (3 * batch_size + 2), (batch_size, kv_sequence_length), (batch_size, total_sequence_length), or (batch_size, sequence_length, total_sequence_length)
relative_position_bias (optional) : T
relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)
past_key (optional) : T
@@ -4796,6 +4797,54 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.RotaryEmbedding** + + RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices + that are multiplied to query and key before the inner product of query and key is taken. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
interleaved : int
+
Rotate using interleaved pattern. Default value is 0 (False).
+
scale : float
+
Custom scale will be used if specified. Default value is 1.0
+
+ +#### Inputs + +
+
input : T
+
3D tensor with shape (batch_size, sequence_length, hidden_size)
+
position_ids : M
+
1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)
+
cos_cache : T
+
2D tensor with shape (max_sequence_length, head_size / 2).
+
sin_cache : T
+
2D tensor with shape (max_sequence_length, head_size / 2).
+
+ +#### Outputs + +
+
output : T
+
3D tensor with shape (batch_size, sequence_length, hidden_size)
+
+ +#### Type Constraints + +
+
T : tensor(float), tensor(float16)
+
Constrain input and output types to float tensors.
+
M : tensor(int64)
+
Constrain input and output types to integer tensors
+
+ + ### **com.microsoft.SampleOp** Sample echo operator. diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index e2d500006b05f..dea71d81f8df5 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -477,9 +477,11 @@ Do not modify directly.* |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)| |QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)| +|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float)| |SampleOp|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)| +|SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)| |SparseToDenseMatMul|*in* A:**T**
*in* B:**T1**
*out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)
**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |Tokenizer|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(string)| |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| @@ -866,6 +868,7 @@ Do not modify directly.* |RemovePadding|*in* input:**T**
*in* sequence_token_count:**M**
*out* output:**T**
*out* token_offset:**M**
*out* cumulated_seq_len:**M**
*out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)| |RestorePadding|*in* input:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)| |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index 0b55cb7804c61..694c40bf3eda6 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -16,7 +16,6 @@ #include #include -#include using onnxruntime::concurrency::ThreadPool; diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index 73b83057bdbe9..00e82c9844b3d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -206,6 +206,7 @@ Status CheckInputs(const T* query, } } + int total_sequence_length = past_sequence_length + kv_sequence_length; AttentionMaskType mask_type = AttentionMaskType::MASK_NONE; if (key_padding_mask != nullptr) { mask_type = AttentionMaskType::MASK_UNKNOWN; @@ -216,13 +217,21 @@ Status CheckInputs(const T* query, } else if (mask_dims[0] == static_cast(3) * static_cast(batch_size) + static_cast(2)) { mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START; } - } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && mask_dims[1] == static_cast(kv_sequence_length)) { + } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && + mask_dims[1] == static_cast(kv_sequence_length)) { + mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; + } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && + mask_dims[1] == static_cast(total_sequence_length)) { mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; + } else if (mask_dims.size() == 3 && mask_dims[0] == static_cast(batch_size) && + mask_dims[1] == static_cast(sequence_length) && + mask_dims[2] == static_cast(total_sequence_length)) { + mask_type = AttentionMaskType::MASK_3D_ATTENTION; } if (mask_type == AttentionMaskType::MASK_UNKNOWN) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key_padding_mask' shape shall be (batch_size) or (batch_size, kv_sequence_length)"); + "Input 'key_padding_mask' shape shall be 1D, 2D, or 3D"); } } @@ -257,7 +266,6 @@ Status CheckInputs(const T* query, } } - int total_sequence_length = past_sequence_length + kv_sequence_length; bool broadcast_res_pos_bias = false; if (relative_position_bias != nullptr) { const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc new file mode 100644 index 0000000000000..4a266af789250 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/bert/rotary_embedding.h" +#include "contrib_ops/cpu/bert/rotary_embedding_helper.h" + +#include "core/platform/threadpool.h" + +using onnxruntime::concurrency::ThreadPool; +using namespace onnxruntime::contrib::rotary_embedding_helper; + +namespace onnxruntime { +namespace contrib { + +// These ops are internal-only, so register outside of onnx +ONNX_OPERATOR_TYPED_KERNEL_EX( + RotaryEmbedding, + kMSDomain, + 1, + float, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("M", DataTypeImpl::GetTensorType()), + RotaryEmbedding); + +template +RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) { + scale = info.GetAttrOrDefault("scale", 1.0); + interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); +} + +template +Status RotaryEmbedding::Compute(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* position_ids = context->Input(1); + const Tensor* cos_cache = context->Input(2); + const Tensor* sin_cache = context->Input(3); + + RotaryParameters parameters = {}; + ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs(input, + position_ids, + cos_cache, + sin_cache, + ¶meters)); + + Tensor* output = context->Output(0, input->Shape()); + + if (parameters.sequence_length > parameters.max_sequence_length) { + // Launch update_cos_sin_cache kernel with scale + ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported"); + } + + const T* input_src = input->Data(); + const int64_t* pos_ids_data = position_ids->Data(); + const T* cos_cache_data = cos_cache->Data(); + const T* sin_cache_data = sin_cache->Data(); + T* output_dest = output->MutableData(); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int head_size = parameters.head_size; + const int position_ids_format = parameters.position_ids_format; + const int half_head_size = head_size / 2; + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + auto* tp = context->GetOperatorThreadPool(); + + const int loop_len = batch_size * sequence_length * num_heads; + const double cost = static_cast(head_size); + ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) { + const int b = static_cast((ptr / num_heads) / sequence_length); + const int s = static_cast((ptr / num_heads) % sequence_length); + const int n = static_cast(ptr % num_heads); + + const int block_offset = b * sequence_length * num_heads + s * num_heads + n; + const int data_offset = block_offset * head_size; + + const T* input_data = input_src + data_offset; + T* output_data = output_dest + data_offset; + + // Cache is (M, H/2) + const int position_id = (position_ids_format == 0) + ? static_cast(pos_ids_data[0]) + s + : static_cast(pos_ids_data[b * sequence_length + s]); + const int cache_offset = position_id * half_head_size; + const T* cos_data = cos_cache_data + cache_offset; + const T* sin_data = sin_cache_data + cache_offset; + + int cache_idx = 0; + T sign = 0; + int j = 0; + for (int i = 0; i < head_size; i++) { + if (interleaved) { + cache_idx = (i / 2) % half_head_size; + sign = (i % 2 == 0) ? static_cast(-1) : static_cast(1); + j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign + } else { + cache_idx = i % half_head_size; + sign = (i < half_head_size) ? static_cast(-1) : static_cast(1); + j = (i + half_head_size) % head_size; + } + output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; + } + } + }); + + return Status::OK(); +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h new file mode 100644 index 0000000000000..be834a66cdc69 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace contrib { + +template +class RotaryEmbedding final : public OpKernel { + public: + RotaryEmbedding(const OpKernelInfo& info); + Status Compute(OpKernelContext* context) const override; + + protected: + float scale; + bool interleaved; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h new file mode 100644 index 0000000000000..cf8080800e072 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h @@ -0,0 +1,121 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/common.h" + +namespace onnxruntime { +namespace contrib { +namespace rotary_embedding_helper { + +// Parameters deduced from node attributes and inputs/outputs. +struct RotaryParameters { + int batch_size; // Batch size used by input + int sequence_length; // Sequence length used by input + int hidden_size; // Hidden size used by input + int head_size; // Head size used by cos/sin cache * 2 + int num_heads; // num_heads = hidden_size / head_size + int max_sequence_length; // Sequence length used by cos/sin cache + int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length) +}; + +template +Status CheckInputs(const T* input, + const T* position_ids, + const T* cos_cache, + const T* sin_cache, + void* parameters) { + // input : (batch_size, sequence_length, hidden_size) + // position ids : (1) or (batch_size, sequence_length) + // cos cache : (max_sequence_length, head_size / 2) + // sin cache : (max_sequence_length, head_size / 2) + + // Check input + const auto& input_dims = input->Shape().GetDims(); + if (input_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 dimensions, got ", + input_dims.size()); + } + // Check position_ids + const auto& position_ids_dims = position_ids->Shape().GetDims(); + if (!onnxruntime::IsScalarOr1ElementVector(position_ids) && position_ids_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' is expected to have 0, 1, or 2 ", + "dimensions, got ", position_ids_dims.size()); + } + // Check cos_cache and sin_cache + const auto& cos_cache_dims = cos_cache->Shape().GetDims(); + if (cos_cache_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' is expected to have 2 dimensions, got ", + cos_cache_dims.size()); + } + const auto& sin_cache_dims = sin_cache->Shape().GetDims(); + if (sin_cache_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' is expected to have 2 dimensions, got ", + sin_cache_dims.size()); + } + if (cos_cache_dims[0] != sin_cache_dims[0] || cos_cache_dims[1] != sin_cache_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'cos_cache' and 'sin_cache' are expected to have ", + "the same shape"); + } + + // Get attributes from inputs + int batch_size = static_cast(input_dims[0]); + int sequence_length = static_cast(input_dims[1]); + int hidden_size = static_cast(input_dims[2]); + int max_sequence_length = static_cast(cos_cache_dims[0]); + int head_size = static_cast(cos_cache_dims[1]) * 2; + int num_heads = hidden_size / head_size; + int position_ids_format = -1; + + // Check position_ids input shapes + if (!onnxruntime::IsScalarOr1ElementVector(position_ids)) { + if (batch_size != static_cast(position_ids_dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 0 should be of size ", + "batch_size, got ", position_ids_dims[0]); + } + if (sequence_length != static_cast(position_ids_dims[1])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 1 should be of size ", + "sequence_length, got ", position_ids_dims[1]); + } + position_ids_format = 1; + } else { + position_ids_format = 0; + } + // Check cos_cache input shapes + if (max_sequence_length != static_cast(cos_cache_dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as ", + "max_sequence_length, got ", cos_cache_dims[0]); + } + if ((head_size / 2) != static_cast(cos_cache_dims[1])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ", + "head_size / 2, got ", cos_cache_dims[1]); + } + // Check sin_cache input shapes + if (max_sequence_length != static_cast(sin_cache_dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 0 should be same as ", + "max_sequence_length, got ", sin_cache_dims[0]); + } + if ((head_size / 2) != static_cast(sin_cache_dims[1])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 1 should be same as ", + "head_size / 2, got ", sin_cache_dims[1]); + } + + // Set rotary parameters + if (parameters != nullptr) { + RotaryParameters* output_parameters = reinterpret_cast(parameters); + output_parameters->batch_size = batch_size; + output_parameters->sequence_length = sequence_length; + output_parameters->hidden_size = hidden_size; + output_parameters->head_size = head_size; + output_parameters->num_heads = num_heads; + output_parameters->max_sequence_length = max_sequence_length; + output_parameters->position_ids_format = position_ids_format; + } + + return Status::OK(); +} + +} // namespace rotary_embedding_helper +} // namespace contrib +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index b4c51ab290eb7..f77e403f26dde 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -20,6 +20,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer); @@ -124,6 +125,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipSimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu); @@ -253,6 +256,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -299,6 +303,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc index e86a12d9fb873..4e103c2556a7a 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc @@ -20,20 +20,29 @@ namespace contrib { kCpuExecutionProvider, \ KernelDefBuilder() \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - SkipLayerNorm); + SkipLayerNorm); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + SkipSimplifiedLayerNormalization, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + SkipLayerNorm); REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(double) -template -SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) +template +SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info) { ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK()); ORT_ENFORCE(epsilon_ >= 0); } -template -Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { +template +Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { const Tensor* input = p_ctx->Input(0); const Tensor* skip = p_ctx->Input(1); const Tensor* gamma = p_ctx->Input(2); @@ -102,10 +111,16 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { } mean = mean / hidden_size; - mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_); + if (simplified) { + mean_square = sqrt(mean_square / hidden_size + epsilon_); + } else { + mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_); + } for (int64_t h = 0; h < hidden_size; h++) { - if (nullptr == beta_data) { + if (simplified) { + p_output[h] = p_output[h] / mean_square * gamma_data[h]; + } else if (nullptr == beta_data) { p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h]; } else { p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h] + beta_data[h]; diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h index 7723541cb6b18..69edf4609e340 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h @@ -10,7 +10,7 @@ namespace onnxruntime { namespace contrib { -template +template class SkipLayerNorm final : public OpKernel { public: SkipLayerNorm(const OpKernelInfo& op_kernel_info); diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc new file mode 100644 index 0000000000000..b4b5dac1fbe19 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cpu/bert/rotary_embedding_helper.h" +#include "contrib_ops/cuda/bert/rotary_embedding.h" +#include "contrib_ops/cuda/bert/rotary_embedding_impl.h" + +using namespace onnxruntime::cuda; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::contrib::rotary_embedding_helper; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + RotaryEmbedding, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", DataTypeImpl::GetTensorType()), \ + RotaryEmbedding); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +template +RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) { + scale = info.GetAttrOrDefault("scale", 1.0); + interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); +} + +template +Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* position_ids = context->Input(1); + const Tensor* cos_cache = context->Input(2); + const Tensor* sin_cache = context->Input(3); + + RotaryParameters parameters = {}; + ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs(input, + position_ids, + cos_cache, + sin_cache, + ¶meters)); + + Tensor* output = context->Output(0, input->Shape()); + + if (parameters.sequence_length > parameters.max_sequence_length) { + // Launch update_cos_sin_cache kernel with scale + ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported"); + } + + // Launch rotary embedding kernel + typedef typename ToCudaType::MappedType CudaT; + auto& device_prop = GetDeviceProp(); + return LaunchRotaryEmbeddingKernel( + Stream(context), + reinterpret_cast(output->template MutableData()), + reinterpret_cast(input->template Data()), + position_ids->Data(), + reinterpret_cast(cos_cache->template Data()), + reinterpret_cast(sin_cache->template Data()), + parameters.batch_size, + parameters.sequence_length, + parameters.num_heads, + parameters.head_size, + parameters.max_sequence_length, + parameters.position_ids_format, + interleaved, + device_prop.maxThreadsPerBlock); + + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h new file mode 100644 index 0000000000000..6dab2ad56749e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class RotaryEmbedding final : public CudaKernel { + public: + RotaryEmbedding(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + float scale; + bool interleaved; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu new file mode 100644 index 0000000000000..c54e72dcfce13 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -0,0 +1,141 @@ +/* +Copyright (c) Microsoft Corporation. +Licensed under the MIT License. +*/ + +/* +Kernel implementation for rotary embeddings. +*/ + +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "contrib_ops/cuda/bert/rotary_embedding_impl.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH + const T* input, // BxSxNxH + const T* cos_cache, // Mx(H/2) + const T* sin_cache, // Mx(H/2) + const int64_t* position_ids, // (1) or BxS + const int sequence_length, + const int num_heads, + const int head_size, + const int position_ids_format, + const bool interleaved) { + // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length + // Use .x in innermost loop to access global memory efficiently + + const int b = blockIdx.z; + const int s = blockIdx.y; + const int n = blockIdx.x; + + const int i = threadIdx.x; + + const int block_offset = b * sequence_length * num_heads + s * num_heads + n; + const int data_offset = block_offset * head_size; + + const T* input_data = input + data_offset; + T* output_data = output + data_offset; + + // Cache is (M, H/2) + const int half_head_size = head_size / 2; + const int position_id = (position_ids_format == 0) ? \ + static_cast(position_ids[0]) + s \ + : static_cast(position_ids[b * sequence_length + s]); + const int cache_offset = position_id * half_head_size; + const T* cos_data = cos_cache + cache_offset; + const T* sin_data = sin_cache + cache_offset; + + int cache_idx = 0; + T sign = 0; + int j = 0; + if (interleaved) { + cache_idx = (i / 2) % half_head_size; + sign = (i % 2 == 0) ? -1 : 1; + j = (i % 2 == 0) ? i+1 : i-1; // i - sign + } else { + cache_idx = i % half_head_size; + sign = (i < half_head_size) ? -1 : 1; + j = (i + half_head_size) % head_size; + } + output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; +} + + +template +Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + T* output, + const T* input, + const int64_t* position_ids, + const T* cos_cache, + const T* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block) { + + constexpr int smem_size = 0; + const dim3 grid(num_heads, sequence_length, batch_size); + const dim3 block(head_size, 1, 1); + + // Note: Current implementation assumes head_size <= max_threads_per_block + // because head_size is currently large for LLaMA-2. For smaller head_size + // and num_heads values, we can create a block as `block(num_heads, head_size, 1)` + // instead. This will require kernel changes to support. + + assert(head_size <= max_threads_per_block); + RotaryEmbeddingBSNH<<>>( + output, input, cos_cache, sin_cache, position_ids, + sequence_length, num_heads, head_size, position_ids_format, interleaved + ); + + return CUDA_CALL(cudaGetLastError()); +} + +template Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + float* output, + const float* input, + const int64_t* position_ids, + const float* cos_cache, + const float* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block); + +template Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + half* output, + const half* input, + const int64_t* position_ids, + const half* cos_cache, + const half* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block); + + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h new file mode 100644 index 0000000000000..29ff48a8ad0fb --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + T* output, + const T* input, + const int64_t* position_ids, + const T* cos_cache, + const T* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 52ff285539360..c52f869d6a9d2 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -91,6 +91,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); @@ -250,6 +252,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 3a75b29ffe3c7..76c3f8716ff09 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -946,7 +946,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(4, "key_padding_mask", - "Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)", + "Key padding mask with shape (batch_size), (3 * batch_size + 2), (batch_size, kv_sequence_length), (batch_size, total_sequence_length), " + "or (batch_size, sequence_length, total_sequence_length)", "M", OpSchema::Optional) .Input(5, @@ -1129,6 +1130,49 @@ ONNX_MS_OPERATOR_SET_SCHEMA( DecoderAttentionTypeAndShapeInference(ctx); })); +constexpr const char* RotaryEmbedding_ver1_doc = R"DOC( +RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices +that are multiplied to query and key before the inner product of query and key is taken. +)DOC"; +ONNX_MS_OPERATOR_SET_SCHEMA( + RotaryEmbedding, 1, + OpSchema() + .SetDoc(RotaryEmbedding_ver1_doc) + .Attr("scale", + "Custom scale will be used if specified. Default value is 1.0", + AttributeProto::FLOAT, + OPTIONAL_VALUE) + .Attr("interleaved", + "Rotate using interleaved pattern. Default value is 0 (False).", + AttributeProto::INT, + OPTIONAL_VALUE) + .Input(0, + "input", + "3D tensor with shape (batch_size, sequence_length, hidden_size)", + "T") + .Input(1, + "position_ids", + "1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)", + "M") + .Input(2, + "cos_cache", + "2D tensor with shape (max_sequence_length, head_size / 2).", + "T") + .Input(3, + "sin_cache", + "2D tensor with shape (max_sequence_length, head_size / 2).", + "T") + .Output(0, + "output", + "3D tensor with shape (batch_size, sequence_length, hidden_size)", + "T") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") + .TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + propagateShapeFromInputToOutput(ctx, 0, 0); + })); + constexpr const char* EmbedLayerNormalization_ver1_doc = R"DOC( EmbedLayerNormalization is the fusion of embedding layer in BERT model, with optional mask processing. The embedding layer takes input_ids (word IDs) and segment_ids (sentence IDs) to look up word_embedding, position_embedding, @@ -1500,4 +1544,4 @@ ONNX_MS_OPERATOR_SET_SCHEMA( })); } // namespace contrib -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index afa5d101bbd8d..afaa380d6ac79 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -95,6 +95,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GatedRelativePositionBia class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RotaryEmbedding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization); @@ -200,6 +201,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 67e9f1b55e9ae..272727a9f5375 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -206,9 +206,11 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "PackedAttention": self._infer_PackedAttention, "PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention, "PythonOp": self._infer_PythonOp, + "QuickGelu": self._infer_FastGelu, "RelativePositionBias": self._infer_RelativePositionBias, "RemovePadding": self._infer_RemovePadding, "RestorePadding": self._infer_RestorePadding, + "RotaryEmbedding": self._infer_RotaryEmbedding, "SimplifiedLayerNormalization": self._infer_LayerNormalization, "SkipLayerNormalization": self._infer_SkipLayerNormalization, "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization, @@ -462,6 +464,8 @@ def _onnx_infer_single_node(self, node): "BiasSplitGelu", "BiasAdd", "NhwcConv", + "QuickGelu", + "RotaryEmbedding", ] if not skip_infer: @@ -2307,6 +2311,9 @@ def _infer_FastGelu(self, node): # noqa: N802 def _infer_Gelu(self, node): # noqa: N802 self._propagate_shape_and_type(node) + def _infer_QuickGelu(self, node): # noqa: N802 + self._propagate_shape_and_type(node) + def _infer_GemmFastGelu(self, node): # noqa: N802 self._compute_matmul_shape(node) @@ -2378,6 +2385,19 @@ def _infer_BiasSplitGelu(self, node): # noqa: N802 def _infer_BiasAdd(self, node): # noqa: N802 self._propagate_shape_and_type(node) + def _infer_RotaryEmbedding(self, node): # noqa: N802 + if len(node.output) == 1: + self._propagate_shape_and_type(node) + elif len(node.output) == 2: + # Extraneous constant nodes outputted by RotaryEmbedding function made with `export_modules_as_functions` + self._propagate_shape_and_type(node, input_index=1, output_index=0) + self._propagate_shape_and_type(node, input_index=0, output_index=1) # true output + elif len(node.output) == 3: + # Extraneous constant nodes outputted by RotaryEmbedding function made with `export_modules_as_functions` + self._propagate_shape_and_type(node, input_index=1, output_index=0) + self._propagate_shape_and_type(node, input_index=1, output_index=1) + self._propagate_shape_and_type(node, input_index=0, output_index=2) # true output + def _infer_PythonOp(self, node): # noqa: N802 output_tensor_types = get_attribute(node, "output_tensor_types") assert output_tensor_types @@ -2583,12 +2603,19 @@ def get_prereq(node): self._check_merged_dims(in_dims, allow_broadcast=True) for i_o in range(len(node.output)): - # Special case: We do not care about the training related - # outputs of SkipLayerNormalization + # Special cases: + # 1) We do not care about the training related outputs of SkipLayerNormalization + # 2) We do not care about the extraneous constant outputs in RotaryEmbedding because + # the RotaryEmbedding op created during export can be replaced by the RotaryEmbedding + # contrib op if ( node.op_type == "SkipLayerNormalization" or node.op_type == "SkipSimplifiedLayerNormalization" ) and i_o in [1, 2]: continue + if node.op_type == "RotaryEmbedding" and len(node.output) > 1: + # Skip symbolic shape inference for RotaryEmbedding functions that have extraneous outputs + # generated by `export_modules_as_functions` + continue vi = self.known_vi_[node.output[i_o]] out_type = vi.type @@ -2750,13 +2777,13 @@ def get_prereq(node): if i in self.known_vi_: logger.debug(self.known_vi_[i]) else: - logger.debug(f"not in knwon_vi_ for {i}") + logger.debug(f"not in known_vi_ for {i}") logger.debug("node outputs:") for o in node.output: if o in self.known_vi_: logger.debug(self.known_vi_[o]) else: - logger.debug(f"not in knwon_vi_ for {o}") + logger.debug(f"not in known_vi_ for {o}") if self.auto_merge_ and not out_type_undefined: logger.debug("Merging: " + str(self.suggested_merge_)) return False diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index 4f898245d01bd..b6f7a44450c62 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -33,6 +33,7 @@ class Precision(Enum): FLOAT32 = "fp32" FLOAT16 = "fp16" INT8 = "int8" + INT4 = "int4" def __str__(self): return self.value @@ -610,7 +611,7 @@ def measure_memory(is_gpu, func, monitor_type="cuda", start_memory=None): return memory_before_test with ThreadPoolExecutor() as executor: - monitor = MemoryMonitor() + monitor = memory_monitor_type() mem_thread = executor.submit(monitor.measure_cpu_usage) try: fn_thread = executor.submit(func) diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index c1c709d6d759b..4228c892d03ae 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -1272,6 +1272,38 @@ def find_past_seq_len_usage(subg: GraphProto): return tensor_names_to_rename, nodes_to_remove +def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0): + past_seq_len = past_seq_len_input + if past_seq_len not in model.get_graphs_input_names(): + # Replace model input for past sequence length + new_input = onnx.helper.make_tensor_value_info(past_seq_len, onnx.TensorProto.INT64, shape=[1]) + model.model.graph.input.append(new_input) + + # Replace MultiHeadAttention with GroupQueryAttention + for node in model.model.graph.node: + if node.op_type == "MultiHeadAttention": + gqa_node = onnx.helper.make_node( + "GroupQueryAttention", + inputs=[ + node.input[0], # query + node.input[1], # key + node.input[2], # value + node.input[6], # past_key + node.input[7], # past_value + past_seq_len, # past_sequence_length + ], + outputs=node.output, + name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"), + domain="com.microsoft", + num_heads=node.attribute[0].i, + kv_num_heads=node.attribute[0].i if kv_num_heads == 0 else kv_num_heads, + is_past_bsnh=0, + ) + model.model.graph.node.remove(node) + model.model.graph.node.extend([gqa_node]) + return model + + def update_decoder_subgraph_output_cross_attention(subg: GraphProto): input_self_past_0 = 1 # w/wo attention mask, w/wo hidden_state diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 1dbdf39613cdd..c1b241aa1a5ec 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -111,7 +111,7 @@ def __init__( model: OnnxModel, hidden_size: int, num_heads: int, - attention_mask: AttentionMask, + attention_mask: Optional[AttentionMask] = None, use_multi_head_attention: bool = False, disable_multi_head_attention_bias: bool = False, search_op_types: List[str] = ["SkipLayerNormalization", "LayerNormalization"], # noqa: B006 @@ -120,7 +120,7 @@ def __init__( super().__init__(model, attention_op_name, search_op_types) self.hidden_size = hidden_size self.num_heads = num_heads - self.attention_mask = attention_mask + self.attention_mask = attention_mask if attention_mask else AttentionMask(model) self.use_multi_head_attention = use_multi_head_attention self.disable_multi_head_attention_bias = disable_multi_head_attention_bias self.mask_filter_value = None @@ -219,6 +219,31 @@ def get_add_qk_str(self, add_qk: NodeProto): return add_qk.input[1] + def reshape_add_qk(self, add_qk: str): + # Convert 4D mask from (B,1,S,T) to (B,N,S,T) + # B = batch size, N = num heads, S = source sequence length, T = target sequence length + mask_output_name = add_qk + "_mask" + + # Check if concat node for (B,1,S,T) --> (B,N,S,T) already exists + concat_node = list(filter(lambda node: node.output[0] == mask_output_name, self.nodes_to_add)) + if len(concat_node) == 1: + return mask_output_name + + assert len(concat_node) == 0 + concat_node_name = self.model.create_node_name("Concat") + concat_add_qk_fp32 = helper.make_node( + "Concat", + inputs=[add_qk for _ in range(self.num_heads)], + outputs=[mask_output_name], + name=concat_node_name, + axis=1, + ) + # Add new node to graph + self.nodes_to_add.append(concat_add_qk_fp32) + self.node_name_to_graph_name[concat_node_name] = self.this_graph_name + + return mask_output_name + def concat_kv(self, past_k: str, past_v: str) -> str: """Concatenate past_k and past_v inputs to create past_kv input. @@ -875,21 +900,8 @@ def create_attention_node( past_kv = self.concat_kv(past_k, past_v) attention_inputs.append(past_kv) - if add_qk_str: - # Convert 4d mask from (B,1,M,M) to (B,N,M,M) - # B = batch size, M = max sequence length, N = num heads - concat_node_name = self.model.create_node_name("Concat") - mask_output_name = add_qk_str + "_mask" - concat_add_qk_fp32 = helper.make_node( - "Concat", - inputs=[add_qk_str for _ in range(num_heads)], - outputs=[mask_output_name], - name=concat_node_name, - axis=1, - ) - # Add new nodes to graph - self.nodes_to_add.append(concat_add_qk_fp32) - self.node_name_to_graph_name[concat_node_name] = self.this_graph_name + if add_qk_str is not None: + mask_output_name = self.reshape_add_qk(add_qk_str) # Add attention mask to attention node if not past_exists: diff --git a/onnxruntime/python/tools/transformers/fusion_base.py b/onnxruntime/python/tools/transformers/fusion_base.py index 117468be412fa..c5d7bc16d64f7 100644 --- a/onnxruntime/python/tools/transformers/fusion_base.py +++ b/onnxruntime/python/tools/transformers/fusion_base.py @@ -113,3 +113,20 @@ def add_initializer(self, name: str, data_type: int, dims: Sequence[int], vals: self.model.add_initializer(tensor, self.this_graph_name) return tensor + + def add_nodes_to_remove(self, nodes: List[NodeProto]): + # Some nodes are shared between paths (e.g. rotary embedding nodes in the Q and K paths). + # When path A is fused, its shared nodes are added to `self.nodes_to_remove`. But when path B + # is fused, its shared nodes are also added to `self.nodes_to_remove`. When the nodes are + # iteratively removed from `self.nodes_to_remove`, path A's shared nodes are removed first. + # Since path A's shared nodes are removed, path B's shared nodes are not removed because they + # were previously removed for path A. This causes an error to print in remove_node that a node + # has failed to be removed. + # + # To avoid this error, we pre-emptively check if the shared nodes are already in `self.nodes_to_remove`. + # We could alternatively convert `self.nodes_to_remove` to a set to avoid this issue, but there could + # be scenarios where the nodes need to be removed in a specific order and converting to a set would + # lose this order. + for node in nodes: + if node not in self.nodes_to_remove: + self.nodes_to_remove.append(node) diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index 69b5cd26f4525..8c80fcad0ab49 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -26,6 +26,7 @@ def __init__(self, model_type): self.enable_gelu = True self.enable_layer_norm = True self.enable_attention = True + self.enable_rotary_embeddings = True # Use MultiHeadAttention instead of Attention operator. The difference: # (1) Attention has merged weights for Q/K/V projection, which might be faster in some cases since 3 MatMul is @@ -81,6 +82,8 @@ def parse(args): options.enable_gelu = False if args.disable_layer_norm: options.enable_layer_norm = False + if args.disable_rotary_embeddings: + options.enable_rotary_embeddings = False if args.disable_attention: options.enable_attention = False if args.use_multi_head_attention: @@ -294,3 +297,10 @@ def add_arguments(parser: ArgumentParser): help="Use channels_first (NCHW) instead of channels_last (NHWC) for GroupNorm. Only works for model_type=unet or vae", ) parser.set_defaults(use_group_norm_channels_first=False) + + parser.add_argument( + "--disable_rotary_embeddings", + required=False, + action="store_true", + help="Do not fuse rotary embeddings into RotaryEmbedding op", + ) diff --git a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py new file mode 100644 index 0000000000000..3c5029ac5752f --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py @@ -0,0 +1,1044 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging +from typing import Optional, Union + +from fusion_attention import FusionAttention +from fusion_base import Fusion +from onnx import FunctionProto, NodeProto, TensorProto, helper, numpy_helper +from onnx_model import OnnxModel + +logger = logging.getLogger(__name__) + + +class FusionRotaryAttention(FusionAttention): + """ + Fuse Attention subgraph with rotary positional embeddings into one MultiHeadAttention node. + """ + + def __init__( + self, + model: OnnxModel, + hidden_size: int, + num_heads: int, + ): + super().__init__( + model, + hidden_size, + num_heads, + use_multi_head_attention=True, + search_op_types=["SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization", "Add"], + ) + + def create_mha_node( + self, + input: str, + output: str, + q_rotary: NodeProto, + k_rotary: NodeProto, + v_matmul: NodeProto, + attn_mask: str = "", + add_qk: str = "", + past_k: str = "", + past_v: str = "", + present_k: str = "", + present_v: str = "", + scale: Optional[float] = None, + ) -> Union[NodeProto, None]: + assert self.num_heads > 0 + + if self.hidden_size > 0 and (self.hidden_size % self.num_heads) != 0: + logger.debug( + f"fuse_rotary_attention: input hidden size {self.hidden_size} is not a multiple of num of heads {self.num_heads}" + ) + return None + + mha_node_name = self.model.create_node_name("MultiHeadAttention") + mha_inputs = [ + q_rotary.output[0], + k_rotary.output[0], + v_matmul.output[0], + "", # bias + attn_mask, # key_padding_mask + add_qk, # relative_position_bias + past_k, + past_v, + ] + + mha_outputs = [output] + if present_k and present_v: + mha_outputs.extend([present_k, present_v]) + + mha_node = helper.make_node( + "MultiHeadAttention", + inputs=mha_inputs, + outputs=mha_outputs, + name=mha_node_name, + ) + + mha_node.domain = "com.microsoft" + mha_node.attribute.extend([helper.make_attribute("num_heads", self.num_heads)]) + if scale is not None: + mha_node.attribute.extend([helper.make_attribute("scale", scale)]) + if self.mask_filter_value is not None: + mha_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))]) + + self.increase_counter("MultiHeadAttention") + return mha_node + + def check_runtime_shape_paths_for_function( + self, + reshape_qkv_2, # Reshape after Transpose + reshape_qkv_1, # Reshape before Transpose + reshape_q_2, # Reshape after RotaryEmbedding + reshape_k_2, # Reshape after RotaryEmbedding + reshape_v_2, # Reshape after Transpose + reshape_v_1, # Reshape before Transpose + add_qk, # Add before Softmax + root_input, # Root input to attention subgraph + ): + # Check #1: check paths for qkv nodes + concat_qkv_2_path = self.model.match_parent_path(reshape_qkv_2, ["Concat"], [1]) + concat_qkv_1_path = self.model.match_parent_path(reshape_qkv_1, ["Concat"], [1]) + if concat_qkv_2_path is None or concat_qkv_1_path is None: + return False + concat_qkv_2, concat_qkv_1 = concat_qkv_2_path[0], concat_qkv_1_path[0] + + reshape_qkv_2_path_1 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_qkv_2_path_2 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + reshape_qkv_1_path_1 = self.model.match_parent_path(concat_qkv_1, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_qkv_1_path_2 = self.model.match_parent_path(concat_qkv_1, ["Unsqueeze", "Gather", "Shape"], [2, 0, 0]) + if ( + reshape_qkv_2_path_1 is None + or reshape_qkv_2_path_2 is None + or reshape_qkv_1_path_1 is None + or reshape_qkv_1_path_2 is None + ): + return False + + _, gather_1, shape_1 = reshape_qkv_2_path_1 + _, gather_2, shape_2 = reshape_qkv_2_path_2 + + # Check root_input --> Shape --> Gather connection + if shape_1.input[0] != root_input or shape_2.input[0] != root_input: + return False + + # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_qkv_1_path_1 and reshape_qkv_1_path_2 + if reshape_qkv_1_path_1[1].name != gather_1.name or reshape_qkv_1_path_2[1].name != gather_2.name: + return False + + # Check #2: check paths for v nodes + concat_v_2_path = self.model.match_parent_path(reshape_v_2, ["Concat"], [1]) + concat_v_1_path = self.model.match_parent_path(reshape_v_1, ["Concat"], [1]) + if concat_v_2_path is None or concat_v_1_path is None: + return False + concat_v_2, concat_v_1 = concat_v_2_path[0], concat_v_1_path[0] + + reshape_v_2_path_1 = self.model.match_parent_path( + concat_v_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0] + ) + reshape_v_2_path_2 = self.model.match_parent_path( + concat_v_2, ["Unsqueeze", "Add", "Gather", "Shape"], [1, 0, 0, 0] + ) + reshape_v_1_path_1 = self.model.match_parent_path(concat_v_1, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_v_1_path_2 = self.model.match_parent_path(concat_v_1, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + if ( + reshape_v_2_path_1 is None + or reshape_v_2_path_2 is None + or reshape_v_1_path_1 is None + or reshape_v_1_path_2 is None + ): + return False + + # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_v_2_path_1 + # Check Gather --> Add --> Unsqueeze --> Concat --> Reshape connection for reshape_v_2_path_2 + # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_v_1_path_1 and reshape_v_1_path_2 + if ( + reshape_v_2_path_1[2].name != gather_1.name + or reshape_v_2_path_2[2].name != gather_2.name + or reshape_v_1_path_1[1].name != gather_1.name + or reshape_v_1_path_2[1].name != gather_2.name + ): + return False + + # Check #3: check paths for k nodes + concat_k_2_path = self.model.match_parent_path(reshape_k_2, ["Concat"], [1]) + if concat_k_2_path is None: + return False + concat_k_2 = concat_k_2_path[0] + + reshape_k_2_path_1 = self.model.match_parent_path( + concat_k_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0] + ) + reshape_k_2_path_2 = self.model.match_parent_path( + concat_k_2, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 0, 0] + ) + if reshape_k_2_path_1 is None or reshape_k_2_path_2 is None: + return False + + # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_k_2_path_1 + # Check Gather --> Add --> Unsqueeze --> Concat --> Reshape connection for reshape_k_2_path_2 + if reshape_k_2_path_1[2].name != gather_1.name or reshape_k_2_path_2[2].name != gather_2.name: + return False + + # Check #4: check paths for q nodes + concat_q_2_path = self.model.match_parent_path(reshape_q_2, ["Concat"], [1]) + if concat_q_2_path is None: + return False + concat_q_2 = concat_q_2_path[0] + + reshape_q_2_path_1 = self.model.match_parent_path( + concat_q_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0] + ) + reshape_q_2_path_2 = self.model.match_parent_path(concat_q_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + if reshape_q_2_path_1 is None or reshape_q_2_path_2 is None: + return False + + # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_q_2_path_1 + # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_q_2_path_2 + if reshape_q_2_path_1[2].name != gather_1.name or reshape_q_2_path_2[1].name != gather_2.name: + return False + + # Check #5: check Mul nodes are the same for q, k, v + mul_q = reshape_q_2_path_1[1] + mul_k = reshape_k_2_path_1[1] + mul_v = reshape_v_2_path_1[1] + gather_1_out = gather_1.output[0] + if mul_q.input[0] != gather_1_out or mul_k.input[0] != gather_1_out or mul_v.input[0] != gather_1_out: + return False + + # Check #6: check paths for attention mask nodes + attn_mask_path_1 = self.model.match_parent_path(add_qk, ["Concat", "Slice", "Slice"], [1, 0, 0]) + attn_mask_path_2 = self.model.match_parent_path(add_qk, ["Cast", "Concat", "Slice", "Slice"], [1, 0, 0, 0]) + if attn_mask_path_1 is not None: + _, slice_qk_2, slice_qk_1 = attn_mask_path_1 + elif attn_mask_path_2 is not None: + _, _, slice_qk_2, slice_qk_1 = attn_mask_path_2 + else: + return False + # Check first input to Slice #1 is 3D attention mask of shape (B,S,T) + if slice_qk_1.input[0] not in {"attn_mask", "attention_mask"}: + return False + + slice_qk_2_path = self.model.match_parent_path( + slice_qk_2, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 1, 0] + ) + slice_qk_1_path_1 = self.model.match_parent_path( + slice_qk_1, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 1, 0] + ) + slice_qk_1_path_2 = self.model.match_parent_path(slice_qk_1, ["Unsqueeze"], [1]) + if slice_qk_2_path is None or slice_qk_1_path_1 is None or slice_qk_1_path_2 is None: + return False + + # Check Gather --> Add --> Unsqueeze #3 --> Slice #2 connection for slice_qk_2_path + # Check Gather --> Add --> Unsqueeze #2 --> Slice #1 connection for slice_qk_1_path_1 + if slice_qk_2_path[1].name != slice_qk_1_path_1[1].name or slice_qk_2_path[2].name != slice_qk_1_path_1[2].name: + return False + + # Check Unsqueeze #1 --> Slice #1 connection for slice_qk_1_path_2 + # Check if first input to Add and Unsqueeze #1 is position ids + if slice_qk_1_path_1[1].input[0] != slice_qk_1_path_2[0].input[0]: + return False + + return True + + def check_runtime_shape_paths_for_nodes( + self, + reshape_qkv, # Final reshape before o_proj MatMul + reshape_q, # Reshape before q_proj MatMul + reshape_k, # Reshape before k_proj MatMul + reshape_v, # Reshape before v_proj MatMul + root_input, # Root input to attention subgraph + ): + # Check #1: check paths for qkv nodes + concat_qkv_path = self.model.match_parent_path(reshape_qkv, ["Concat"], [1]) + if concat_qkv_path is None: + return False + concat_qkv = concat_qkv_path[0] + + reshape_qkv_path_1 = self.model.match_parent_path(concat_qkv, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_qkv_path_2 = self.model.match_parent_path(concat_qkv, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + if reshape_qkv_path_1 is None or reshape_qkv_path_2 is None: + return False + + _, gather_1, shape_1 = reshape_qkv_path_1 + _, gather_2, shape_2 = reshape_qkv_path_2 + + # Check root_input --> Shape --> Gather connection + if shape_1.input[0] != root_input or shape_2.input[0] != root_input: + return False + + # Check #2: check paths for v nodes + concat_v_path = self.model.match_parent_path(reshape_v, ["Concat"], [1]) + if concat_v_path is None: + return False + concat_v = concat_v_path[0] + + reshape_v_path_1 = self.model.match_parent_path(concat_v, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_v_path_2 = self.model.match_parent_path(concat_v, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + if reshape_v_path_1 is None or reshape_v_path_2 is None: + return False + + # Check Gather --> Unsqueeze --> Concat --> Reshape connection + if reshape_v_path_1[1].name != gather_1.name or reshape_v_path_2[1].name != gather_2.name: + return False + + # Check #3: check paths for k nodes + concat_k_path = self.model.match_parent_path(reshape_k, ["Concat"], [1]) + if concat_k_path is None: + return False + concat_k = concat_k_path[0] + + reshape_k_path_1 = self.model.match_parent_path(concat_k, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_k_path_2 = self.model.match_parent_path(concat_k, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + if reshape_k_path_1 is None or reshape_k_path_2 is None: + return False + + # Check Gather --> Unsqueeze --> Concat --> Reshape connection + if reshape_k_path_1[1].name != gather_1.name or reshape_k_path_2[1].name != gather_2.name: + return False + + # Check #4: check paths for q nodes + concat_q_path = self.model.match_parent_path(reshape_q, ["Concat"], [1]) + if concat_q_path is None: + return False + concat_q = concat_q_path[0] + + reshape_q_path_1 = self.model.match_parent_path(concat_q, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_q_path_2 = self.model.match_parent_path(concat_q, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + if reshape_q_path_1 is None or reshape_q_path_2 is None: + return False + + # Check Gather --> Unsqueeze --> Concat --> Reshape connection + if reshape_q_path_1[1].name != gather_1.name or reshape_q_path_2[1].name != gather_2.name: + return False + + return True + + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + if normalize_node.op_type != "SkipSimplifiedLayerNormalization" and normalize_node.op_type != "Add": + return + + # qkv_nodes_1 is for LLaMA-2 Microsoft + # qkv_nodes_2 is for LLaMA-2 Hugging Face + qkv_nodes = None + qkv_nodes_1 = self.model.match_parent_path( + normalize_node, + ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 0, 0], + ) + qkv_nodes_2 = self.model.match_parent_path( + normalize_node, + ["MatMul", "Reshape", "Transpose", "MatMul"], + [1, 0, 0, 0], + ) + if qkv_nodes_1 is not None: + _, reshape_qkv_2, _, reshape_qkv_1, matmul_qkv = qkv_nodes_1 + qkv_nodes = qkv_nodes_1 + elif qkv_nodes_2 is not None: + _, reshape_qkv, _, matmul_qkv = qkv_nodes_2 + qkv_nodes = qkv_nodes_2 + else: + logger.debug("fuse_rotary_attention: failed to match qkv nodes") + return + + # v_nodes_1 is for LLaMA-2 Microsoft + # v_nodes_3 is for LLaMA-2 Hugging Face + past_v, present_v, past_seq_len = "", "", "" + v_nodes = None + v_nodes_1 = self.model.match_parent_path( + matmul_qkv, + ["Reshape", "Transpose", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 1, 0, 0], + ) + v_nodes_2 = self.model.match_parent_path( + matmul_qkv, + ["Concat", "Transpose", "Reshape", "MatMul"], + [1, 1, 0, 0], + ) + v_nodes_3 = self.model.match_parent_path( + matmul_qkv, + ["Transpose", "Reshape", "MatMul"], + [1, 0, 0], + ) + if v_nodes_1 is not None: + reshape_v_2, _, concat_v, _, reshape_v_1, matmul_v = v_nodes_1 + v_nodes = v_nodes_1 + + concat_v_path = self.model.match_parent_path( + concat_v, + ["Slice", "Unsqueeze"], + [0, 2], + ) + if concat_v_path is None: + logger.debug("fuse_rotary_attention: failed to match past/present concat in v path") + return + + past_v = concat_v_path[0].input[0] + past_seq_len = concat_v_path[-1].input[0] + present_v = concat_v.output[0] + elif v_nodes_2 is not None: + concat_v, transpose_v, reshape_v, matmul_v = v_nodes_2 + v_nodes = v_nodes_2 + past_v = concat_v.input[0] + present_v = concat_v.output[0] + elif v_nodes_3 is not None: + transpose_v, reshape_v, matmul_v = v_nodes_3 + v_nodes = v_nodes_3 + present_v = transpose_v.output[0] + else: + logger.debug("fuse_rotary_attention: failed to match v path") + return + + qk_nodes = self.model.match_parent_path( + matmul_qkv, + ["Softmax", "Add", "Div", "MatMul"], + [0, 0, 0, 0], + ) + add_qk, matmul_qk = None, None + if qk_nodes is not None: + _, add_qk, _, matmul_qk = qk_nodes + else: + logger.debug("fuse_rotary_attention: failed to match qk nodes") + return + + # attn_mask_nodes_1, attn_mask_nodes_2 are for LLaMA-2 Microsoft's 3D attention mask + # attn_mask_nodes_3, attn_mask_nodes_4 are for LLaMA-2 Hugging Face's 2D attention mask + attn_mask, add_qk_str = "", "" + attn_mask_nodes_1 = self.model.match_parent_path( + add_qk, + ["Concat", "Slice", "Slice"], + [1, 0, 0], + ) + attn_mask_nodes_2 = self.model.match_parent_path( + add_qk, + ["Cast", "Concat", "Slice", "Slice"], + [1, 0, 0, 0], + ) + attn_mask_nodes_3 = self.model.match_parent_path( + add_qk, + ["Add", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], + [1, 0, 2, 1, 0, 0, 0], + ) + attn_mask_nodes_4 = self.model.match_parent_path( + add_qk, + ["Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], + [1, 2, 1, 0, 0, 0], + ) + if attn_mask_nodes_1 is not None: + _, slice_mask_1, slice_mask_2 = attn_mask_nodes_1 + attn_mask = slice_mask_1.output[0] + elif attn_mask_nodes_2 is not None: + _, _, slice_mask_1, slice_mask_2 = attn_mask_nodes_2 + attn_mask = slice_mask_1.output[0] + elif attn_mask_nodes_3 is not None: + # Reshape from (B,1,S,T) to (B,N,S,T) + add_qk_str = self.reshape_add_qk(attn_mask_nodes_3[0].output[0]) + elif attn_mask_nodes_4 is not None: + # Reshape from (B,1,S,T) to (B,N,S,T) + add_qk_str = self.reshape_add_qk(attn_mask_nodes_4[0].output[0]) + else: + logger.debug("fuse_rotary_attention: failed to match attention mask nodes") + return + + # k_nodes_1 is for LLaMA-2 Microsoft + # k_nodes_2 is for LLaMA-2 Hugging Face + past_k, present_k = "", "" + k_nodes = None + k_nodes_1 = self.model.match_parent_path( + matmul_qk, + ["Reshape", "Transpose", "Concat", "Transpose", "RotaryEmbedding", "MatMul"], + [1, 0, 0, 1, 0, 0], + ) + k_nodes_2 = self.model.match_parent_path( + matmul_qk, + ["Transpose", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 0, 0], + ) + k_nodes_3 = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], + [1, 0, 1, 0, 0, 0], + ) + if k_nodes_1 is not None: + reshape_k_2, _, concat_k, _, rotary_k, matmul_k = k_nodes_1 + k_nodes = k_nodes_1 + + concat_k_path = self.model.match_parent_path( + concat_k, + ["Slice", "Unsqueeze"], + [0, 2], + ) + if concat_k_path is None: + logger.debug("fuse_rotary_attention: failed to match past/present concat in k path") + return + + past_k = concat_k_path[0].input[0] + shared_past_seq_len = concat_k_path[-1].input[0] + present_k = concat_k.output[0] + + assert past_seq_len == shared_past_seq_len + elif k_nodes_2 is not None: + _, rotary_k, _, reshape_k, matmul_k = k_nodes_2 + k_nodes = k_nodes_2 + present_k = rotary_k.output[0] + elif k_nodes_3 is not None: + _, concat_k, rotary_k, _, reshape_k, matmul_k = k_nodes_3 + k_nodes = k_nodes_3 + past_k = concat_k.input[0] + present_k = concat_k.output[0] + else: + logger.debug("fuse_rotary_attention: failed to match k nodes") + return + + # q_nodes_1 is for LLaMA-2 Microsoft + # q_nodes_2 is for LLaMA-2 Hugging Face + q_nodes = None + q_nodes_1 = self.model.match_parent_path( + matmul_qk, + ["Reshape", "Transpose", "RotaryEmbedding", "MatMul"], + [0, 0, 0, 0], + ) + q_nodes_2 = self.model.match_parent_path( + matmul_qk, + ["RotaryEmbedding", "Transpose", "Reshape", "MatMul"], + [0, 0, 0, 0], + ) + if q_nodes_1 is not None: + reshape_q_2, _, rotary_q, matmul_q = q_nodes_1 + q_nodes = q_nodes_1 + elif q_nodes_2 is not None: + rotary_q, _, reshape_q, matmul_q = q_nodes_2 + q_nodes = q_nodes_2 + else: + logger.debug("fuse_rotary_attention: failed to match q nodes") + return + + if matmul_q.input[0] != matmul_k.input[0] and matmul_k.input[0] != matmul_v.input[0]: + logger.debug("fuse_rotary_attention: failed to find the same root_input for q, k, v paths") + return + + root_output = "" + if qkv_nodes == qkv_nodes_1: + if not self.check_runtime_shape_paths_for_function( + reshape_qkv_2, + reshape_qkv_1, + reshape_q_2, + reshape_k_2, + reshape_v_2, + reshape_v_1, + add_qk, + matmul_q.input[0], + ): + logger.debug("fuse_rotary_attention: failed to verify runtime shape paths") + return + root_output = reshape_qkv_2.output[0] + + elif qkv_nodes == qkv_nodes_2: + if not self.check_runtime_shape_paths_for_nodes( + reshape_qkv, + reshape_q, + reshape_k, + reshape_v, + matmul_q.input[0], + ): + logger.debug("fuse_rotary_attention: failed to verify runtime shape paths") + return + root_output = reshape_qkv.output[0] + + # Rename inputs of rotary_q/k so it connects with output of matmul_q/k + # Before: MatMul --> Reshape --> Transpose --> RotaryEmbedding + # After: MatMul --> RotaryEmbedding + rotary_q.input[0] = matmul_q.output[0] + rotary_k.input[0] = matmul_k.output[0] + + # Rename current output of rotary_k (present_key) so it doesn't match output of MHA (present_key) + rotary_k.output[0] = rotary_k.name + "_output_0" + + new_node = self.create_mha_node( + matmul_q.input[0], + root_output, + rotary_q, + rotary_k, + matmul_v, + attn_mask, + add_qk_str, + past_k, + past_v, + present_k, + present_v, + ) + if new_node is None: + logger.debug("fuse_rotary_attention: failed to create multi-head attention with rotary embeddings") + return + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend(qkv_nodes[1:]) + self.nodes_to_remove.extend(v_nodes[:-1]) + self.nodes_to_remove.extend(qk_nodes) + + if k_nodes == k_nodes_1: + self.nodes_to_remove.extend(k_nodes[:-2]) + elif k_nodes == k_nodes_2: + self.nodes_to_remove.append(k_nodes[0]) + self.nodes_to_remove.append(k_nodes[2]) + self.nodes_to_remove.append(k_nodes[3]) + elif k_nodes == k_nodes_3: + self.nodes_to_remove.append(k_nodes[0]) + self.nodes_to_remove.append(k_nodes[1]) + self.nodes_to_remove.append(k_nodes[3]) + self.nodes_to_remove.append(k_nodes[4]) + + if q_nodes == q_nodes_1: + self.nodes_to_remove.extend(q_nodes[:-2]) + elif q_nodes == q_nodes_2: + self.nodes_to_remove.append(q_nodes[1]) + self.nodes_to_remove.append(q_nodes[2]) + + self.prune_graph = True + + +class FusionRotaryEmbeddings(Fusion): + def __init__(self, model: OnnxModel): + self.base_name = "RotaryEmbedding" + super().__init__(model, self.base_name, [self.base_name, self.base_name + ".1", "Add"]) + + # The RotaryEmbedding function can have multiple extraneous constant outputs even though the function is supposed to produce only one output. + # This is a byproduct of a potential CSE bug when using `export_modules_as_functions` in the TorchScript exporter. + # To work around this issue, we set the extraneous constant values from the RotaryEmbedding function as initializers in the locations where they are actually used. + def reassign_extra_outputs(self, rot_emb_node: NodeProto, function: FunctionProto): + # Find extra outputs and Constant nodes attached to those outputs + extra_constants, extra_outputs = [], [] + for fn_node in function.node: + if fn_node.op_type == "Constant" and fn_node.input == [] and fn_node.output[0] in function.output: + extra_constants.append(fn_node) + output_index = list(function.output).index(fn_node.output[0]) + extra_outputs.append(rot_emb_node.output[output_index]) + + # Set extra Constant node outputs as initializers + extra_initializers = [] + for extra_constant in extra_constants: + constant_tensorproto = extra_constant.attribute[0].t + constant_tensorproto.name = self.model.create_node_name("Constant") + self.model.add_initializer(constant_tensorproto) + extra_initializers.append(constant_tensorproto.name) + + # Update references of Constant node outputs to initializer references + for extra_output, extra_initializer in zip(extra_outputs, extra_initializers): + nodes_to_update = list(filter(lambda entry: extra_output in entry.input, self.model.model.graph.node)) + for node_to_update in nodes_to_update: + OnnxModel.replace_node_input(node_to_update, extra_output, extra_initializer) + + return extra_outputs + + def create_rotary_embeddings_from_function(self, node: NodeProto): + rotary_emb_node_name = self.model.create_node_name(self.base_name) + + matmul_path = self.model.match_parent_path( + node, + ["Reshape", "MatMul"], + [0, 0], + ) + if matmul_path is not None: + reshape_node, matmul_node = matmul_path + else: + logger.debug("fuse_rotary_embeddings: failed to match MatMul") + return + + rotary_emb_inputs = [ + matmul_node.output[0], # x is of shape (B,S,D) instead of (B,S,N,H) + node.input[1], # position_ids + ] + + # Convert cos_cache and sin_cache from node attributes to model initializers + cos_cache_node = list(filter(lambda constant: constant.output[0] == node.input[2], self.model.model.graph.node)) + sin_cache_node = list(filter(lambda constant: constant.output[0] == node.input[3], self.model.model.graph.node)) + cos_cache_name, sin_cache_name = "cos_cache", "sin_cache" + + if ( + len(cos_cache_node) == 1 + and len(sin_cache_node) == 1 + and self.model.get_initializer(cos_cache_name) is None + and self.model.get_initializer(sin_cache_name) is None + ): + cos_cache = numpy_helper.to_array(cos_cache_node[0].attribute[0].t).squeeze() + sin_cache = numpy_helper.to_array(sin_cache_node[0].attribute[0].t).squeeze() + + cos_cache_tensor = helper.make_tensor( + name=cos_cache_name, + data_type=TensorProto.FLOAT, + dims=list(cos_cache.shape), + vals=cos_cache.flatten().tolist(), + ) + self.model.add_initializer(cos_cache_tensor, self.this_graph_name) + sin_cache_tensor = helper.make_tensor( + name=sin_cache_name, + data_type=TensorProto.FLOAT, + dims=list(sin_cache.shape), + vals=sin_cache.flatten().tolist(), + ) + self.model.add_initializer(sin_cache_tensor, self.this_graph_name) + + self.nodes_to_remove.extend([cos_cache_node[0], sin_cache_node[0]]) + + rotary_emb_inputs.extend([cos_cache_name, sin_cache_name]) + + rotary_emb_outputs = node.output + if len(rotary_emb_outputs) > 1: + # Re-assign extraneous constant outputs in RotaryEmbedding functions as initializers + func = list(filter(lambda fn: fn.name == node.op_type, self.model.model.functions)) + assert len(func) == 1 + extra_outputs = self.reassign_extra_outputs(node, func[0]) + rotary_emb_outputs = list(filter(lambda output_name: output_name not in extra_outputs, rotary_emb_outputs)) + assert len(rotary_emb_outputs) == 1 + + rotary_emb_node = helper.make_node( + self.base_name, + inputs=rotary_emb_inputs, + outputs=rotary_emb_outputs, + name=rotary_emb_node_name, + interleaved=1, + ) + rotary_emb_node.domain = "com.microsoft" + + self.nodes_to_remove.append(reshape_node) + + return rotary_emb_node + + def create_rotary_embeddings_from_nodes( + self, + root_input: str, + position_ids: str, + cos_slice: str, + sin_slice: str, + output: str, + ): + rotary_emb_node_name = self.model.create_node_name(self.base_name) + + # Convert cos_cache and sin_cache from node attributes to model initializers + cos_cache_node = list(filter(lambda constant: constant.output[0] == cos_slice, self.model.model.graph.node)) + sin_cache_node = list(filter(lambda constant: constant.output[0] == sin_slice, self.model.model.graph.node)) + cos_cache_name, sin_cache_name = "cos_cache", "sin_cache" + + if ( + len(cos_cache_node) == 1 + and len(sin_cache_node) == 1 + and self.model.get_initializer(cos_cache_name) is None + and self.model.get_initializer(sin_cache_name) is None + ): + cos_cache = numpy_helper.to_array(cos_cache_node[0].attribute[0].t).squeeze() + sin_cache = numpy_helper.to_array(sin_cache_node[0].attribute[0].t).squeeze() + + # Reshape cos/sin cache from (M, H) to (M, H/2) + head_size = cos_cache.shape[1] + cos_cache = cos_cache[:, : (head_size // 2)] + sin_cache = sin_cache[:, : (head_size // 2)] + + cos_cache_tensor = helper.make_tensor( + name=cos_cache_name, + data_type=TensorProto.FLOAT, + dims=list(cos_cache.shape), + vals=cos_cache.flatten().tolist(), + ) + self.model.add_initializer(cos_cache_tensor, self.this_graph_name) + sin_cache_tensor = helper.make_tensor( + name=sin_cache_name, + data_type=TensorProto.FLOAT, + dims=list(sin_cache.shape), + vals=sin_cache.flatten().tolist(), + ) + self.model.add_initializer(sin_cache_tensor, self.this_graph_name) + + self.nodes_to_remove.extend([cos_cache_node[0], sin_cache_node[0]]) + + rotary_emb_node = helper.make_node( + self.base_name, + inputs=[root_input, position_ids, cos_cache_name, sin_cache_name], + outputs=[output], + name=rotary_emb_node_name, + interleaved=0, + ) + rotary_emb_node.domain = "com.microsoft" + return rotary_emb_node + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + # Node is either RotaryEmbedding function or Add + if self.base_name not in node.op_type and node.op_type != "Add": + return + + # Check if node is "RotaryEmbedding nn.Module" exported as a function + # (e.g. export_modules_as_functions={RotaryEmbedding} in torch.onnx.export) + rotary_emb_node = None + if node.op_type != "Add": + # Verify that function has the correct inputs + if len(node.input) not in {4, 5} or node.input[1] not in { + "pos", + "pos_id", + "position_id", + "pos_ids", + "position_ids", + }: + logger.debug("fuse_rotary_embeddings: failed to verify inputs for RotaryEmbedding function") + return + + rotary_emb_node = self.create_rotary_embeddings_from_function(node) + if rotary_emb_node is None: + logger.debug("fuse_rotary_embeddings: failed to create RotaryEmbedding node") + return + + # Remove RotaryEmbedding function + self.nodes_to_remove.append(node) + + # Remove RotaryEmbedding function's shape inference stored in value_info + # The new shape will be calculated during symbolic shape inference + old_shape_infer = list( + filter(lambda node: node.name == rotary_emb_node.output[0], self.model.model.graph.value_info) + ) + assert len(old_shape_infer) == 1 + self.model.model.graph.value_info.remove(old_shape_infer[0]) + + else: + # Rotary embeddings are defined using the below functions: + # + # def rotate_half(x): + # """Rotates half the hidden dims of the input.""" + # x1 = x[..., : x.shape[-1] // 2] + # x2 = x[..., x.shape[-1] // 2 :] + # return torch.cat((-x2, x1), dim=-1) + # + # def apply_rope(x, cos, sin, position_ids): + # cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + # sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + # cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + # sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + # x_embed = (x * cos) + (rotate_half(x) * sin) + # return x_embed + + # Check paths for rotate_half(x) + rotate_half_x2_path_1 = self.model.match_parent_path( + node, + ["Mul", "Concat", "Neg", "Slice", "Transpose"], + [1, 0, 0, 0, 0], + ) + rotate_half_x2_path_2 = self.model.match_parent_path( + node, + ["Mul", "Concat", "Neg", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"], + [1, 0, 0, 0, 1, 0, 0, 0, 0], + ) + if rotate_half_x2_path_1 is None or rotate_half_x2_path_2 is None: + logger.debug("fuse_rotary_embeddings: failed to match x2 in rotate_half") + return + + rotate_half_x1_path_1 = self.model.match_parent_path( + node, + ["Mul", "Concat", "Slice", "Transpose"], + [1, 0, 1, 0], + ) + rotate_half_x1_path_2 = self.model.match_parent_path( + node, + ["Mul", "Concat", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"], + [1, 0, 1, 2, 0, 0, 0, 0], + ) + if rotate_half_x1_path_1 is None or rotate_half_x1_path_2 is None: + logger.debug("fuse_rotary_embeddings: failed to match x1 in rotate_half") + return + + if ( + rotate_half_x1_path_1[-1].name != rotate_half_x1_path_2[-1].name + or rotate_half_x2_path_1[-1].name != rotate_half_x2_path_2[-1].name + or rotate_half_x1_path_1[-1].name != rotate_half_x2_path_1[-1].name + or rotate_half_x1_path_2[-1].name != rotate_half_x2_path_2[-1].name + ): + logger.debug("fuse_rotary_embeddings: failed to match common input in rotate_half") + return + + # Check path for x + x_path = self.model.match_parent_path( + node, + ["Mul", "Transpose"], + [0, 0], + ) + if x_path is None: + logger.debug("fuse_rotary_embeddings: failed to match x in rotate_half") + return + + # Check path for sin + sin_path, sin_cache, position_ids = None, "", "" + sin_path_1 = self.model.match_parent_path( + node, + ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Gather", "Shape"], + [1, 1, 0, 0, 0, 0, 2, 0, 0], + ) + sin_path_2 = self.model.match_parent_path( + node, + ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Add"], + [1, 1, 0, 0, 0, 0, 2, 0], + ) + sin_path_3 = self.model.match_parent_path( + node, + ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Gather", "Shape"], + [1, 1, 0, 0, 2, 0, 0], + ) + sin_path_4 = self.model.match_parent_path( + node, + ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Add"], + [1, 1, 0, 0, 2, 0], + ) + if sin_path_1 is not None: + sin_path = sin_path_1 + sin_cache = sin_path[-4].input[0] + elif sin_path_2 is not None: + sin_path = sin_path_2 + sin_cache = sin_path[-3].input[0] + elif sin_path_3 is not None: + sin_path = sin_path_3 + sin_cache = sin_path[-4].input[0] + position_ids = sin_path[2].input[1] + elif sin_path_4 is not None: + sin_path = sin_path_4 + sin_cache = sin_path[-3].input[0] + position_ids = sin_path[2].input[1] + else: + logger.debug("fuse_rotary_embeddings: failed to match sin path in apply_rope") + return + + # Check path for cos + cos_path, cos_cache = None, "" + cos_path_1 = self.model.match_parent_path( + node, + ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Gather", "Shape"], + [0, 1, 0, 0, 0, 0, 2, 0, 0], + ) + cos_path_2 = self.model.match_parent_path( + node, + ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Add"], + [0, 1, 0, 0, 0, 0, 2, 0], + ) + cos_path_3 = self.model.match_parent_path( + node, + ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Gather", "Shape"], + [0, 1, 0, 0, 2, 0, 0], + ) + cos_path_4 = self.model.match_parent_path( + node, + ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Add"], + [0, 1, 0, 0, 2, 0], + ) + if cos_path_1 is not None: + cos_path = cos_path_1 + cos_cache = cos_path[-4].input[0] + elif cos_path_2 is not None: + cos_path = cos_path_2 + cos_cache = cos_path[-3].input[0] + elif cos_path_3 is not None: + cos_path = cos_path_3 + cos_cache = cos_path[-4].input[0] + position_ids = cos_path[2].input[1] + elif cos_path_4 is not None: + cos_path = cos_path_4 + cos_cache = cos_path[-3].input[0] + position_ids = cos_path[2].input[1] + else: + logger.debug("fuse_rotary_embeddings: failed to match sin path in apply_rope") + return + + # Check path for position ids + if position_ids == "": + position_ids_from_sin_path = self.model.match_parent_path( + sin_path[2], + ["Reshape"], + [1], + ) + position_ids_from_cos_path = self.model.match_parent_path( + cos_path[2], + ["Reshape"], + [1], + ) + if ( + position_ids_from_sin_path is None + or position_ids_from_cos_path is None + or position_ids_from_sin_path[0].name != position_ids_from_cos_path[0].name + ): + logger.debug("fuse_rotary_embeddings: failed to match position ids path in apply_rope") + return + position_ids = position_ids_from_cos_path[0].input[0] + else: + position_ids_from_sin_path = [] + position_ids_from_cos_path = [] + + past_seq_len_path, curr_seq_len_path = None, None + if (sin_path == sin_path_1 and cos_path == cos_path_1) or ( + sin_path == sin_path_3 and cos_path == cos_path_3 + ): + if sin_path[-2].name != cos_path[-2].name or sin_path[-1].name != cos_path[-1].name: + logger.debug( + "fuse_rotary_embeddings: failed to match common Gather node and Shape node in sin cache and cos cache" + ) + return + elif (sin_path == sin_path_2 and cos_path == cos_path_2) or ( + sin_path == sin_path_4 and cos_path == cos_path_4 + ): + if sin_path[-1].name != cos_path[-1].name: + logger.debug("fuse_rotary_embeddings: failed to match common Add node in sin cache and cos cache") + return + # Match past sequence length path: past_key --> Shape --> Gather --> Add + past_seq_len_path = self.model.match_parent_path( + sin_path[-1], + ["Gather", "Shape"], + [1, 0], + ) + # Match current sequence length path: transpose_k --> Shape --> Gather --> Add + curr_seq_len_path = self.model.match_parent_path( + sin_path[-1], + ["Gather", "Shape", "Transpose"], + [0, 0, 0], + ) + if ( + past_seq_len_path is None + or curr_seq_len_path is None + or self.model.find_graph_input(past_seq_len_path[-1].input[0]) is None + or curr_seq_len_path[-1].op_type != "Transpose" + ): + logger.debug("fuse_rotary_embeddings: failed to match past_seq_len and curr_seq_len paths") + return + else: + logger.debug("fuse_rotary_embeddings: failed to match common cache paths") + + rotary_emb_node = self.create_rotary_embeddings_from_nodes( + rotate_half_x1_path_1[-1].output[0], + position_ids, + cos_cache, + sin_cache, + node.output[0], + ) + if rotary_emb_node is None: + logger.debug("fuse_rotary_embeddings: failed to create RotaryEmbedding node") + return + + # Remove rotary embedding nodes + self.add_nodes_to_remove([node]) + self.add_nodes_to_remove(rotate_half_x1_path_1[:-1]) + self.add_nodes_to_remove(rotate_half_x1_path_2[:-1]) + self.add_nodes_to_remove(rotate_half_x2_path_1[:-1]) + self.add_nodes_to_remove(rotate_half_x2_path_2[:-1]) + self.add_nodes_to_remove(x_path[:-1]) + self.add_nodes_to_remove(sin_path) + self.add_nodes_to_remove(cos_path) + self.add_nodes_to_remove(position_ids_from_sin_path[:-1]) + self.add_nodes_to_remove(position_ids_from_cos_path[:-1]) + + if past_seq_len_path is not None and len(self.model.get_children(past_seq_len_path[0])) == 1: + # In merged HF model, output of Gather in past_seq_len_path is used twice + # for past_key_values.0.key and once for other past_key_values + self.add_nodes_to_remove(past_seq_len_path) + if curr_seq_len_path is not None: + self.add_nodes_to_remove(curr_seq_len_path[:-1]) + + self.increase_counter(self.base_name) + self.node_name_to_graph_name[rotary_emb_node.name] = self.this_graph_name + self.nodes_to_add.append(rotary_emb_node) + self.prune_graph = True diff --git a/onnxruntime/python/tools/transformers/fusion_shape.py b/onnxruntime/python/tools/transformers/fusion_shape.py index 11d6b7a8d3cf4..bc32d78eda66c 100644 --- a/onnxruntime/python/tools/transformers/fusion_shape.py +++ b/onnxruntime/python/tools/transformers/fusion_shape.py @@ -48,22 +48,22 @@ def fuse( input_name_to_nodes: Dict[str, List[NodeProto]], output_name_to_node: Dict[str, NodeProto], ): - """ - Smplify subgraph like - - (2d_input) - / \ - Shape shape - / \ - Gather(indices=0) Gather(indices=1) - | | - Unsqueeze(axes=0) Unsqueeze(axes=0) - \\ / - Concat - | - - into (2d_input) --> Shape --> - """ + # + # Simplify subgraph like + # + # (2d_input) + # / \ + # Shape shape + # / \ + # Gather(indices=0) Gather(indices=1) + # | | + # Unsqueeze(axes=0) Unsqueeze(axes=0) + # \ / + # Concat + # | + # + # into (2d_input) --> Shape --> + # opset_version = self.model.get_opset_version() inputs = len(concat_node.input) diff --git a/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py b/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py new file mode 100644 index 0000000000000..6f35fa5617a39 --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py @@ -0,0 +1,141 @@ +import logging +from typing import Dict + +from fusion_base import Fusion +from fusion_skiplayernorm import FusionSkipLayerNormalization +from onnx import helper +from onnx_model import OnnxModel + +logger = logging.getLogger(__name__) + + +class FusionSimplifiedLayerNormalization(Fusion): + def __init__(self, model: OnnxModel): + super().__init__(model, "SimplifiedLayerNormalization", "Mul") + + def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): + if node.op_type != "Mul": + return + + sim_ln_nodes = None + # SimplifiedLayerNorm calculation (notation from https://onnx.ai/onnx/operators/onnx__LayerNormalization.html#summary): + # DD = Pow(D, 2) + # Var = ReduceMean(DD) + # VarEps = Add(Var, epsilon) + # StdDev = Sqrt(VarEps) + # InvStdDev = Div(1, StdDev) + # Normalized = Mul(D, InvStdDev) + # NormalizedScaled = Mul(Normalized, Scale) + + # SimplifiedLayerNorm + # +-------------------------------------------------------+ + # | | + # Add --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul + # | + # node + sim_ln_nodes_1 = self.model.match_parent_path( + node, + ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Add"], + [1, 1, 1, 0, 0, 0, 0], + ) + # SimplifiedLayerNorm + # +-------------------------------------------------------+ + # | | + # Gather --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul + # | + # node + sim_ln_nodes_2 = self.model.match_parent_path( + node, + ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Gather"], + [1, 1, 1, 0, 0, 0, 0], + ) + + # For LLaMA from Microsoft custom export: + # sim_ln_nodes_3 uses a different start parent index than sim_ln_nodes_1 + # + # SimplifiedLayerNorm + # +-------------------------------------------------------+ + # | | + # Add --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul + # | + # node + sim_ln_nodes_3 = self.model.match_parent_path( + node, + ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Add"], + [0, 1, 1, 0, 0, 0, 0], + ) + + # sim_ln_nodes_4 starts with a graph input instead of an Add node like sim_ln_nodes_3 + # + # SimplifiedLayerNorm + # +-----------------------------------------------+ + # | | + # graph_input --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul + # | + # node + sim_ln_nodes_4 = self.model.match_parent_path( + node, + ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow"], + [0, 1, 1, 0, 0, 0], + ) + + add_node, pow_node = None, None + if sim_ln_nodes_1 is not None: + sim_ln_nodes = sim_ln_nodes_1 + add_node = sim_ln_nodes[3] + pow_node = sim_ln_nodes[-2] + elif sim_ln_nodes_2 is not None: + sim_ln_nodes = sim_ln_nodes_2 + add_node = sim_ln_nodes[3] + pow_node = sim_ln_nodes[-2] + elif sim_ln_nodes_3 is not None: + sim_ln_nodes = sim_ln_nodes_3 + add_node = sim_ln_nodes[3] + pow_node = sim_ln_nodes[-2] + elif sim_ln_nodes_4 is not None: + sim_ln_nodes = sim_ln_nodes_4 + add_node = sim_ln_nodes[3] + pow_node = sim_ln_nodes[-1] + # Verify that parent input to Pow node is graph_input + if pow_node.input[0] not in self.model.get_graphs_input_names(): + return + else: + return + + layernorm_weight_index = 1 if sim_ln_nodes in (sim_ln_nodes_3, sim_ln_nodes_4) else 0 + starts_with_graph_input = sim_ln_nodes == sim_ln_nodes_4 + + if self.model.find_constant_input(pow_node, 2.0) != 1: + return + + root_input = pow_node.input[0] + if root_input != sim_ln_nodes[0].input[0]: + return + + i, add_weight = self.model.get_constant_input(add_node) + if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4: + logger.warning(f"epsilon value is not expected: {add_weight}") + return + + self.nodes_to_remove.extend(sim_ln_nodes[:-1] if not starts_with_graph_input else sim_ln_nodes) + self.nodes_to_remove.append(node) + + normalize_node = helper.make_node( + "SimplifiedLayerNormalization", + inputs=[root_input, node.input[layernorm_weight_index]], + outputs=[node.output[0]], + name=self.model.create_node_name("SimplifiedLayerNormalization", name_prefix="LayerNorm"), + ) + normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))]) + normalize_node.attribute.extend([helper.make_attribute("axis", -1)]) + normalize_node.attribute.extend([helper.make_attribute("stash_type", 1)]) + self.nodes_to_add.append(normalize_node) + self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name + + +class FusionSkipSimplifiedLayerNormalization(FusionSkipLayerNormalization): + def __init__(self, model: OnnxModel): + super().__init__(model, "SkipSimplifiedLayerNormalization", "SimplifiedLayerNormalization") + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + super().fuse(node, input_name_to_nodes, output_name_to_node) diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md index b4461a2eadb8c..6057b46667fe6 100644 --- a/onnxruntime/python/tools/transformers/models/llama/README.md +++ b/onnxruntime/python/tools/transformers/models/llama/README.md @@ -17,12 +17,31 @@ $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama To make this option compatible with [Hugging Face's Optimum](https://github.com/huggingface/optimum), you will need to create `config.json` and `generation_config.json` for your model and store them in the same directory as your ONNX models. For example, you can find those JSON files for LLaMA-2 7B on Hugging Face [here](https://huggingface.co/meta-llama/Llama-2-7b-hf). +As indicated in `requirements.txt`, you will also need to install Optimum from source. Once installed, you will need to modify `ORTModelForCausalLM.forward` in `optimum/optimum/onnxruntime/modeling_decoder.py` as follows: + +``` +# Before +if self.use_cache: + if past_key_values is not None: + input_ids = input_ids[:, -1:] + # Flatten the past_key_values (no need to flatten for models using multi-query attn) + + +# After +if self.use_cache: + if past_key_values is not None: + input_ids = input_ids[:, -1:] if past_key_values[0][0].shape[2] != 0 else input_ids + # Flatten the past_key_values (no need to flatten for models using multi-query attn) +``` + ### Option 2: from [Microsoft's custom export](https://github.com/microsoft/Llama-2-Onnx) Please follow the [README instructions](https://github.com/microsoft/Llama-2-Onnx#before-you-start) in the custom export of LLaMA-2. ### Option 3: from [Hugging Face Optimum](https://github.com/huggingface/optimum) +Note that this will produce two ONNX models whereas the above two options produce one ONNX model. + First, log into the Hugging Face CLI in your terminal: ``` @@ -56,38 +75,81 @@ $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./ $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b ``` -Export for FP16 +Export for FP32 CUDA +``` +# From source: +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-gpu --precision fp32 --execution_provider cuda + +# From wheel: +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32 --precision fp32 --execution_provider cuda +``` + +Export for FP32 CPU ``` # From source: -$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-cpu --precision fp32 --execution_provider cpu # From wheel: -$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32 --precision fp32 --execution_provider cpu ``` -Export for INT8 +Export for FP16 CUDA ``` # From source: -$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda # From wheel: -$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda +``` + +Export for INT8 CPU (SmoothQuant) +``` +# From source: +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant --execution_provider cpu + +# From wheel: +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant --execution_provider cpu ``` Note: [Intel's Neural Compressor](https://github.com/intel/neural-compressor) takes time to run the SmoothQuant quantization algorithm on LLMs. On an [Azure Standard_NC24s_v3 VM](https://learn.microsoft.com/en-us/azure/virtual-machines/ncv3-series), it takes about ~30-45 min for each of the exported ONNX models. +Export for INT8 CPU (DynamicQuant) +``` +# From source: +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method quantize_dynamic --execution_provider cpu + +# From wheel: +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method quantize_dynamic --execution_provider cpu +``` + +Export for INT4 CUDA +``` +# From source: +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-gpu --precision int4 --quantization_method blockwise --execution_provider cuda + +# From wheel: +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4 --precision int4 --quantization_method blockwise --execution_provider cuda +``` + +Export for INT4 CPU +``` +# From source: +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-cpu --precision int4 --quantization_method blockwise --execution_provider cpu + +# From wheel: +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4 --precision int4 --quantization_method blockwise --execution_provider cpu +``` + ## Benchmark LLaMA-2 Here are some examples of how you can benchmark LLaMA-2. -Note: In the below examples, `PyTorch` refers to running in PyTorch without `torch.compile` and `PyTorch 2.0` refers to running in PyTorch with `torch.compile`. - ### Variants -1. PyTorch (without `torch.compile`), FP32 +1. PyTorch without `torch.compile`, FP32 ``` python3 -m models.llama.benchmark \ - --benchmark-type hf-pt \ + --benchmark-type hf-pt-eager \ --model-name meta-llama/Llama-2-7b-hf \ --precision fp32 \ --batch-sizes "1 2" \ @@ -96,10 +158,10 @@ python3 -m models.llama.benchmark \ --auth ``` -2. PyTorch 2.0 (with `torch.compile`), FP16 +2. PyTorch with `torch.compile`, FP16 ``` python3 -m models.llama.benchmark \ - --benchmark-type hf-pt2 \ + --benchmark-type hf-pt-compile \ --model-name meta-llama/Llama-2-7b-hf \ --precision fp16 \ --batch-sizes "1 2" \ @@ -112,7 +174,7 @@ python3 -m models.llama.benchmark \ ``` python3 -m models.llama.benchmark \ --benchmark-type hf-ort \ - --hf-ort-model-path ./Llama-2-7b-hf-onnx/ \ + --hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \ --model-name meta-llama/Llama-2-7b-hf \ --precision fp32 \ --batch-sizes "1 2" \ @@ -125,7 +187,7 @@ python3 -m models.llama.benchmark \ ``` python3 -m models.llama.benchmark \ --benchmark-type hf-ort \ - --hf-ort-model-path ./llama2-7b-fp16/ \ + --hf-ort-dir-path ./llama2-7b-fp16/ \ --model-name meta-llama/Llama-2-7b-hf \ --precision fp16 \ --batch-sizes "1 2" \ @@ -134,24 +196,35 @@ python3 -m models.llama.benchmark \ --auth ``` -5. Optimum + ONNX Runtime, INT8, export via convert_to_onnx +5. ONNX Runtime, FP32, Microsoft custom export ``` python3 -m models.llama.benchmark \ - --benchmark-type hf-ort \ - --hf-ort-model-path ./llama2-7b-int8/ \ + --benchmark-type ort-msft \ + --ort-model-path ./llama-2-onnx/7B_float32/ONNX/LlamaV2_7B_float32.onnx \ --model-name meta-llama/Llama-2-7b-hf \ - --precision int8 \ + --precision fp32 \ --batch-sizes "1 2" \ --sequence-lengths "8 16" \ - --device cpu \ - --auth + --device cpu ``` -6. ONNX Runtime, FP32, Microsoft custom export +6. ONNX Runtime, FP16, Microsoft custom export ``` python3 -m models.llama.benchmark \ - --benchmark-type ort \ - --ort-model-path llama-2-onnx/7B_float32/ONNX/LlamaV2_7B_float32.onnx \ + --benchmark-type ort-msft \ + --ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \ + --model-name meta-llama/Llama-2-7b-hf \ + --precision fp16 \ + --batch-sizes "1 2" \ + --sequence-lengths "8 16" \ + --device cuda +``` + +7. ONNX Runtime, FP32, convert_to_onnx +``` +python3 -m models.llama.benchmark \ + --benchmark-type ort-convert-to-onnx \ + --ort-model-path ./llama2-7b/Llama-2-7b-hf_decoder_merged_model_fp32.onnx \ --model-name meta-llama/Llama-2-7b-hf \ --precision fp32 \ --batch-sizes "1 2" \ @@ -159,11 +232,11 @@ python3 -m models.llama.benchmark \ --device cpu ``` -7. ONNX Runtime, FP16, Microsoft custom export +8. ONNX Runtime, FP16, convert_to_onnx ``` python3 -m models.llama.benchmark \ - --benchmark-type ort \ - --ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \ + --benchmark-type ort-convert-to-onnx \ + --ort-model-path ./llama2-7b/Llama-2-7b-hf_decoder_merged_model_fp16.onnx \ --model-name meta-llama/Llama-2-7b-hf \ --precision fp16 \ --batch-sizes "1 2" \ @@ -174,11 +247,14 @@ python3 -m models.llama.benchmark \ You can profile a variant by adding the `--profile` flag and providing one batch size and sequence length combination. ### Benchmark All -You can use `benchmark_all.py` to benchmark across various platforms and automatically store the results in a CSV file. Here is an example. +You can use `benchmark_all.py` to benchmark across various options and automatically store the results in a CSV file. Here is an example. ``` python3 -m models.llama.benchmark_all \ - --hf-ort-model-path ./llama2-7b-fp16/ \ - --ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \ + --hf-pt-eager \ + --hf-pt-compile \ + --hf-ort-dir-path ./llama2-7b-fp16/ \ + --ort-convert-to-onnx-model-path ./llama2-7b-fp16/Llama-2-7b-hf_decoder_merged_model_fp16.onnx \ + --ort-msft-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \ --model-name meta-llama/Llama-2-7b-hf \ --precision fp16 \ --batch-sizes "1 2" \ diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark.py b/onnxruntime/python/tools/transformers/models/llama/benchmark.py index d19ed5cc28fed..976de2abc7c57 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark.py @@ -8,10 +8,17 @@ import time import numpy as np +import onnx import psutil import torch from benchmark_helper import setup_logger -from llama_inputs import get_msft_sample_inputs, get_sample_inputs, get_sample_with_past_kv_inputs +from llama_inputs import ( + convert_inputs_for_ort, + get_merged_sample_with_past_kv_inputs, + get_msft_sample_inputs, + get_sample_inputs, + get_sample_with_past_kv_inputs, +) from optimum.onnxruntime import ORTModelForCausalLM from torch.profiler import ProfilerActivity, profile, record_function from tqdm import trange @@ -23,8 +30,29 @@ logger = logging.getLogger(__name__) -def get_inputs(args: argparse.Namespace): - if args.benchmark_type in {"hf-pt", "hf-pt2", "hf-ort"}: +# For determining whether the ONNX model can do both prompt generation and token generation or only one of the two +def get_ort_model_inputs_len(args, model): + if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: + return 0 + if args.benchmark_type == "hf-ort": + try: + # New Optimum export (https://github.com/huggingface/optimum/blob/888332364c2e0091da1fc974737c7e277af168bf/optimum/onnxruntime/modeling_ort.py#L268) + return len(model.inputs_names) + except Exception: + # Old Optimum export (https://github.com/huggingface/optimum/blob/c5ad7f971cb0a494eac03dc0909f146725f999c5/optimum/onnxruntime/base.py#L54) + return len(model.decoder.input_names) + return len(model.get_inputs()) + + +def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): + init_inputs, iter_inputs = None, None + + # For past_present_share_buffer: + # Set max_seq_len to 2048 for Hugging Face model since that is the default value + # Set max_seq_len to 2048 for Microsoft model since that is the max value currently supported + max_seq_len = 2048 + + if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: init_inputs = get_sample_inputs( args.config, args.target_device, @@ -41,14 +69,95 @@ def get_inputs(args: argparse.Namespace): return_dict=True, ) - elif args.benchmark_type == "ort": + elif args.benchmark_type == "hf-ort": + if ort_model_inputs_len == 3: # [input_ids, attention_mask, position_ids] + # Using split models in Optimum (e.g. created by Optimum export) + init_inputs = get_sample_inputs( + args.config, + args.target_device, + args.batch_size, + args.sequence_length, + return_dict=True, + ) + iter_inputs = get_sample_with_past_kv_inputs( + args.config, + args.target_device, + args.batch_size, + args.sequence_length, + use_fp16=args.use_fp16, + return_dict=True, + ) + else: + # Using merged model in Optimum (e.g. created by convert_to_onnx export) + init_inputs = get_merged_sample_with_past_kv_inputs( + args.config, + args.target_device, + args.batch_size, + seq_len=args.sequence_length, + past_seq_len=0, + use_fp16=args.use_fp16, + return_dict=True, + ) + iter_inputs = get_merged_sample_with_past_kv_inputs( + args.config, + args.target_device, + args.batch_size, + seq_len=1, + past_seq_len=args.sequence_length, + use_fp16=args.use_fp16, + return_dict=True, + ) + + elif args.benchmark_type == "ort-convert-to-onnx": + # Microsoft export from convert_to_onnx + init_inputs = get_merged_sample_with_past_kv_inputs( + args.config, + args.target_device, + args.batch_size, + seq_len=args.sequence_length, + past_seq_len=0, + use_fp16=args.use_fp16, + return_dict=True, + ) + iter_inputs = get_merged_sample_with_past_kv_inputs( + args.config, + args.target_device, + args.batch_size, + seq_len=1, + past_seq_len=args.sequence_length, + use_fp16=args.use_fp16, + return_dict=True, + ) + init_inputs = convert_inputs_for_ort( + init_inputs, + use_fp16=args.use_fp16, + use_buffer_share=args.past_present_share_buffer, + past_seq_len=0, + max_seq_len=max_seq_len, + device=args.device, + device_id=args.device_id, + ) + iter_inputs = convert_inputs_for_ort( + iter_inputs, + use_fp16=args.use_fp16, + use_buffer_share=args.past_present_share_buffer, + past_seq_len=args.sequence_length, + max_seq_len=max_seq_len, + device=args.device, + device_id=args.device_id, + ) + + elif args.benchmark_type == "ort-msft": # Microsoft export from https://github.com/microsoft/Llama-2-Onnx + split_kv = ort_model_inputs_len > 5 # original inputs: [x, attn_mask, k_cache, v_cache, pos] + init_inputs = get_msft_sample_inputs( args.config, args.batch_size, past_seq_len=0, seq_len=args.sequence_length, use_fp16=args.use_fp16, + split_kv=split_kv, ) iter_inputs = get_msft_sample_inputs( args.config, @@ -56,6 +165,25 @@ def get_inputs(args: argparse.Namespace): past_seq_len=args.sequence_length, seq_len=1, use_fp16=args.use_fp16, + split_kv=split_kv, + ) + init_inputs = convert_inputs_for_ort( + init_inputs, + use_fp16=args.use_fp16, + use_buffer_share=args.past_present_share_buffer, + past_seq_len=0, + max_seq_len=max_seq_len, + device=args.device, + device_id=args.device_id, + ) + iter_inputs = convert_inputs_for_ort( + iter_inputs, + use_fp16=args.use_fp16, + use_buffer_share=args.past_present_share_buffer, + past_seq_len=args.sequence_length, + max_seq_len=max_seq_len, + device=args.device, + device_id=args.device_id, ) else: @@ -69,12 +197,14 @@ def get_model(args: argparse.Namespace): start_time, end_time = None, None # There are multiple sources that the model could come from: - # 1) Benchmark LLaMA from unofficial source on Hugging Face - # 2) Benchmark LLaMA from official source on Hugging Face, which requires an authentication token - # 3) Benchmark LLaMA from local download of model - - if args.benchmark_type in {"hf-pt", "hf-pt2"}: - source = args.hf_pt_model_path if args.hf_pt_model_path else args.model_name + # 1) Benchmark LLaMA-2 from unofficial source on Hugging Face + # 2) Benchmark LLaMA-2 from official source on Hugging Face, which requires an authentication token + # 3) Benchmark LLaMA-2 from local download of model + # 4) Benchmark LLaMA-2 from Microsoft (already optimized, available at https://github.com/microsoft/Llama-2-Onnx) + # 5) Benchmark LLaMA-2 from convert_to_onnx + + if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: + source = args.hf_pt_dir_path if args.hf_pt_dir_path else args.model_name start_time = time.time() model = LlamaForCausalLM.from_pretrained( source, @@ -84,10 +214,10 @@ def get_model(args: argparse.Namespace): ).to(args.target_device) end_time = time.time() - if args.benchmark_type == "hf-pt2": + if args.benchmark_type == "hf-pt-compile": model = torch.compile(model) - elif args.benchmark_type in {"hf-ort", "ort"}: + elif args.benchmark_type in {"hf-ort", "ort-msft", "ort-convert-to-onnx"}: sess_options = ort.SessionOptions() sess_options.enable_profiling = args.profile if args.verbose: @@ -104,32 +234,33 @@ def get_model(args: argparse.Namespace): decoder_file_name = None decoder_with_past_file_name = None - for filename in os.listdir(args.hf_ort_model_path): + for filename in os.listdir(args.hf_ort_dir_path): if ".onnx" not in filename or ".onnx_data" in filename or ".onnx.data" in filename: continue - if "decoder_model.onnx" in filename or f"decoder_model_{args.precision}.onnx" in filename: + if "decoder_model" in filename or filename == "model.onnx": + decoder_file_name = filename + if "decoder_with_past_model" in filename: + decoder_with_past_file_name = filename + if "decoder_merged_model" in filename: decoder_file_name = filename - if ( - "decoder_with_past_model.onnx" in filename - or f"decoder_with_past_model_{args.precision}.onnx" in filename - ): decoder_with_past_file_name = filename start_time = time.time() model = ORTModelForCausalLM.from_pretrained( - args.hf_ort_model_path, + args.hf_ort_dir_path, decoder_file_name=decoder_file_name, decoder_with_past_file_name=decoder_with_past_file_name, use_auth_token=args.auth, use_io_binding=(args.device != "cpu"), + use_merged=(True if decoder_file_name == "model.onnx" else None), provider=provider, provider_options=provider_options, session_options=sess_options, ) end_time = time.time() - if args.benchmark_type == "ort": - # Microsoft export from https://github.com/microsoft/Llama-2-Onnx + if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}: + # Ex: Microsoft export from https://github.com/microsoft/Llama-2-Onnx logger.info(f"Loading model from {args.ort_model_path}") start_time = time.time() model = ort.InferenceSession( @@ -140,7 +271,6 @@ def get_model(args: argparse.Namespace): end_time = time.time() logger.info(f"Loaded model in {end_time - start_time} s") - return model @@ -148,7 +278,7 @@ def time_fn(args, fn, inputs): # Warm up warmup_range = ( range(args.warmup_runs) - if args.benchmark_type == "ort" + if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} else trange(args.warmup_runs, file=sys.stdout, desc="Warm up") ) @@ -166,7 +296,7 @@ def time_fn(args, fn, inputs): bench_range = ( range(args.num_runs) - if args.benchmark_type == "ort" + if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} else trange(args.num_runs, file=sys.stdout, desc="Benchmark") ) for _ in bench_range: @@ -177,7 +307,7 @@ def time_fn(args, fn, inputs): end_time = time.time() # Newline print after trange in order to print metrics on new lines without progress bar on same line - if args.benchmark_type != "ort": + if args.benchmark_type not in {"ort-msft", "ort-convert-to-onnx"}: logger.info("") latency = (end_time - start_time) / args.num_runs @@ -186,7 +316,7 @@ def time_fn(args, fn, inputs): logger.info(f"Batch Size: {args.batch_size}") logger.info(f"Sequence Length: {args.sequence_length}") logger.info(f"Latency: {latency} s") - logger.info(f"Throughput: {throughput} qps") + logger.info(f"Throughput: {throughput} tps") return @@ -196,7 +326,7 @@ def profile_fn(args, fn, inputs, inputs_type): prefix = f"b{args.batch_size}_s{args.sequence_length}_{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}" filename = None - if args.benchmark_type in {"hf-pt", "hf-pt2"}: + if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: # Profile PyTorch kernels with profile( # noqa: SIM117 activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True @@ -267,7 +397,7 @@ def get_logits(inputs): generate_fn = get_logits - if args.benchmark_type == "hf-pt2": + if args.benchmark_type == "hf-pt-compile": # Run forward pass once with each set of inputs to process through Dynamo generate_fn(init_inputs) generate_fn(iter_inputs) @@ -280,7 +410,7 @@ def get_logits(inputs): logger.warning(f"Renaming {old_logname} to {new_logname}") os.rename(old_logname, os.path.join(args.log_folder, new_logname)) - new_logname = profile_fn(args, generate_fn, iter_inputs, "per-token") + new_logname = profile_fn(args, generate_fn, iter_inputs, "token") if args.benchmark_type == "hf-ort": # Turn profiling off to stop appending to log old_logname = model.decoder_with_past.session.end_profiling() @@ -319,10 +449,24 @@ def prepare_ort_inputs(inputs): # Add IO bindings for non-CPU execution providers if args.device != "cpu": io_binding = model.io_binding() + for k, v in inputs.items(): - io_binding.bind_cpu_input(k, v) + if args.past_present_share_buffer: + # Bind all OrtValue inputs to device + io_binding.bind_ortvalue_input(k, v) + else: + io_binding.bind_cpu_input(k, v) + for output in model.get_outputs(): - io_binding.bind_output(output.name) + name = output.name + if args.past_present_share_buffer and ("out" in name or "present" in name): + # Bind present KV cache outputs to OrtValue with buffer sharing + io_binding.bind_ortvalue_output( + name, inputs[name.replace("out", "cache").replace("present", "past_key_values")] + ) + else: + io_binding.bind_output(name, device_type=args.device, device_id=args.device_id) + return io_binding return inputs @@ -350,7 +494,7 @@ def without_io_binding(inputs): # Re-initialize model for new log file instead of appending to old log file model = get_model(args) ort_iter_inputs = prepare_ort_inputs(iter_inputs) - new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "per-token") + new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "token") # Turn profiling off to stop appending to log old_logname = model.end_profiling() @@ -371,9 +515,9 @@ def without_io_binding(inputs): def run_inference(args, init_inputs, iter_inputs, model): - if args.benchmark_type in {"hf-pt", "hf-pt2", "hf-ort"}: + if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}: run_hf_inference(args, init_inputs, iter_inputs, model) - elif args.benchmark_type == "ort": + elif args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}: run_ort_inference(args, init_inputs, iter_inputs, model) else: raise Exception(f"Cannot recognize {args.benchmark_type}") @@ -382,7 +526,11 @@ def run_inference(args, init_inputs, iter_inputs, model): def get_args(): parser = argparse.ArgumentParser() parser.add_argument( - "-bt", "--benchmark-type", type=str, required=True, choices=["hf-pt", "hf-pt2", "hf-ort", "ort"] + "-bt", + "--benchmark-type", + type=str, + required=True, + choices=["hf-pt-eager", "hf-pt-compile", "hf-ort", "ort-msft", "ort-convert-to-onnx"], ) parser.add_argument( "-m", @@ -402,20 +550,20 @@ def get_args(): required=True, type=str, default="fp32", - choices=["int8", "fp16", "fp32"], + choices=["int4", "int8", "fp16", "fp32"], help="Precision for model. For ONNX models, the model's precision should be set before running this script.", ) parser.add_argument( - "--hf-pt-model-path", + "--hf-pt-dir-path", type=str, default="", help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)", ) parser.add_argument( - "--hf-ort-model-path", + "--hf-ort-dir-path", type=str, default="", - help="Path to directory containing all ONNX files (e.g. tokenizer, encoder, decoder, decoder_with_past)", + help="Path to directory containing all ONNX files (e.g. tokenizer, decoder_merged, decoder, decoder_with_past)", ) parser.add_argument( "--ort-model-path", @@ -475,15 +623,20 @@ def get_args(): args.execution_provider = (args.execution_provider, {"device_id": args.device_id}) args.device = "cuda" - # Check that model paths have been specified for any benchmarking with ORT + # Check that paths have been specified for any benchmarking with ORT if args.benchmark_type == "hf-ort": - assert args.hf_ort_model_path, "Please specify a path to `--hf-ort-model-path`" - if args.benchmark_type == "ort": + assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`" + if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}: assert args.ort_model_path, "Please specify a path to `--ort-model-path`" args.batch_sizes = args.batch_sizes.split(" ") args.sequence_lengths = args.sequence_lengths.split(" ") + # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models + args.precision = ( + "fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16" + ) + # Check that only one (batch_size, sequence_length) combination is set for profiling if args.profile: assert ( @@ -509,14 +662,27 @@ def main(): setattr(args, "target_device", target_device) # noqa: B010 setattr(args, "use_fp16", use_fp16) # noqa: B010 - # Measure prompt cost (init_inputs) and generated token cost (iter_inputs) + # Get model and model info model = get_model(args) + ort_model_inputs_len = get_ort_model_inputs_len(args, model) + + # Check if past_present_share_buffer can be enabled (only for FP16 models with GQA) + if args.benchmark_type in {"ort-convert-to-onnx", "ort-msft"}: + onnx_model = onnx.load_model(args.ort_model_path, load_external_data=False) + gqa_nodes = list(filter(lambda node: node.op_type == "GroupQueryAttention", onnx_model.graph.node)) + + use_buffer_share = use_fp16 and len(gqa_nodes) > 0 and args.device != "cpu" + setattr(args, "past_present_share_buffer", use_buffer_share) # noqa: B010 + else: + setattr(args, "past_present_share_buffer", False) # noqa: B010 + + # Measure prompt cost (init_inputs) and generated token cost (iter_inputs) for batch_size, sequence_length in itertools.product(args.batch_sizes, args.sequence_lengths): logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...") setattr(args, "batch_size", int(batch_size)) # noqa: B010 setattr(args, "sequence_length", int(sequence_length)) # noqa: B010 - init_inputs, iter_inputs = get_inputs(args) + init_inputs, iter_inputs = get_inputs(args, ort_model_inputs_len) run_inference(args, init_inputs, iter_inputs, model) diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py index 7199c945fe6ba..951b2549368f7 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py @@ -43,15 +43,38 @@ def get_args(): ) parser.add_argument( - "--hf-ort-model-path", + "--hf-pt-eager", + default=False, + action="store_true", + help="Benchmark in PyTorch without `torch.compile`", + ) + + parser.add_argument( + "--hf-pt-compile", + default=False, + action="store_true", + help="Benchmark in PyTorch with `torch.compile`", + ) + + parser.add_argument( + "--hf-ort-dir-path", type=str, + default="", help="Path to folder containing ONNX models for Optimum + ORT benchmarking", ) parser.add_argument( - "--ort-model-path", + "--ort-msft-model-path", + type=str, + default="", + help="Path to ONNX model from https://github.com/microsoft/Llama-2-Onnx", + ) + + parser.add_argument( + "--ort-convert-to-onnx-model-path", type=str, - help="Path to ONNX model for ORT benchmarking", + default="", + help="Path to ONNX model from convert_to_onnx", ) parser.add_argument( @@ -65,7 +88,7 @@ def get_args(): "--precision", type=str, required=True, - choices=["int8", "fp16", "fp32"], + choices=["int4", "int8", "fp16", "fp32"], help="Precision to run model", ) @@ -138,8 +161,6 @@ def process_log_file(device_id, log_file, base_results): step = "per-token" elif latency_pattern in line: latency_s = float(line[len(latency_pattern) : line.rfind(" ")]) - if step == "prompt": - latency_s /= sequence_length latency_ms = latency_s * 1000 elif throughput_pattern in line: throughput = float(line[len(throughput_pattern) : line.rfind(" ")]) @@ -184,7 +205,7 @@ def save_results(results, filename): "Step", "Latency (s)", "Latency (ms)", - "Throughput (qps)", + "Throughput (tps)", "Memory (GB)", ], ) @@ -194,7 +215,7 @@ def save_results(results, filename): df["Sequence Length"] = df["Sequence Length"].astype("int") df["Latency (s)"] = df["Latency (s)"].astype("float") df["Latency (ms)"] = df["Latency (ms)"].astype("float") - df["Throughput (qps)"] = df["Throughput (qps)"].astype("float") + df["Throughput (tps)"] = df["Throughput (tps)"].astype("float") df["Memory (GB)"] = df["Memory (GB)"].astype("float") df.to_csv(filename, index=False) @@ -226,75 +247,81 @@ def main(): torch.backends.cudnn.benchmark = True all_results = [] + # Benchmark PyTorch without torch.compile - benchmark_cmd = [ - "python3", - "benchmark.py", - "--benchmark-type", - "hf-pt", - "--model-name", - args.model_name, - "--precision", - args.precision, - "--batch-sizes", - args.batch_sizes, - "--sequence-lengths", - args.sequence_lengths, - "--device", - args.device, - "--device-id", - str(args.device_id), - "--warmup-runs", - str(args.warmup_runs), - "--num-runs", - str(args.num_runs), - "--log-folder", - args.log_folder, - "--auth", - ] - logger.info("Benchmark PyTorch without torch.compile") - results = benchmark(args, benchmark_cmd, "pytorch") - all_results.extend(results) + if args.hf_pt_eager: + benchmark_cmd = [ + "python", + "-m", + "models.llama.benchmark", + "--benchmark-type", + "hf-pt-eager", + "--model-name", + args.model_name, + "--precision", + args.precision, + "--batch-sizes", + args.batch_sizes, + "--sequence-lengths", + args.sequence_lengths, + "--device", + args.device, + "--device-id", + str(args.device_id), + "--warmup-runs", + str(args.warmup_runs), + "--num-runs", + str(args.num_runs), + "--log-folder", + args.log_folder, + "--auth", + ] + logger.info("Benchmark PyTorch without torch.compile") + results = benchmark(args, benchmark_cmd, "pytorch-eager") + all_results.extend(results) # Benchmark PyTorch with torch.compile - benchmark_cmd = [ - "python3", - "benchmark.py", - "--benchmark-type", - "hf-pt2", - "--model-name", - args.model_name, - "--precision", - args.precision, - "--batch-sizes", - args.batch_sizes, - "--sequence-lengths", - args.sequence_lengths, - "--device", - args.device, - "--device-id", - str(args.device_id), - "--warmup-runs", - str(args.warmup_runs), - "--num-runs", - str(args.num_runs), - "--log-folder", - args.log_folder, - "--auth", - ] - logger.info("Benchmark PyTorch with torch.compile") - results = benchmark(args, benchmark_cmd, "pytorch-2") - all_results.extend(results) + if args.hf_pt_compile: + benchmark_cmd = [ + "python", + "-m", + "models.llama.benchmark", + "--benchmark-type", + "hf-pt-compile", + "--model-name", + args.model_name, + "--precision", + args.precision, + "--batch-sizes", + args.batch_sizes, + "--sequence-lengths", + args.sequence_lengths, + "--device", + args.device, + "--device-id", + str(args.device_id), + "--warmup-runs", + str(args.warmup_runs), + "--num-runs", + str(args.num_runs), + "--log-folder", + args.log_folder, + "--auth", + ] + logger.info("Benchmark PyTorch with torch.compile") + results = benchmark(args, benchmark_cmd, "pytorch-compile") + all_results.extend(results) # Benchmark Optimum + ONNX Runtime - if args.hf_ort_model_path: + if args.hf_ort_dir_path: benchmark_cmd = [ - "python3", - "benchmark.py", + "python", + "-m", + "models.llama.benchmark", "--benchmark-type", "hf-ort", - "--hf-ort-model-path", - args.hf_ort_model_path, + "--hf-ort-dir-path", + args.hf_ort_dir_path, "--model-name", args.model_name, "--precision", @@ -316,18 +343,52 @@ def main(): "--auth", ] logger.info("Benchmark Optimum + ONNX Runtime") - results = benchmark(args, benchmark_cmd, "pytorch-ort") + results = benchmark(args, benchmark_cmd, "optimum-ort") + all_results.extend(results) + + # Benchmark Microsoft model in ONNX Runtime + if args.ort_msft_model_path: + benchmark_cmd = [ + "python", + "-m", + "models.llama.benchmark", + "--benchmark-type", + "ort-msft", + "--ort-model-path", + args.ort_msft_model_path, + "--model-name", + args.model_name, + "--precision", + args.precision, + "--batch-sizes", + args.batch_sizes, + "--sequence-lengths", + args.sequence_lengths, + "--device", + args.device, + "--device-id", + str(args.device_id), + "--warmup-runs", + str(args.warmup_runs), + "--num-runs", + str(args.num_runs), + "--log-folder", + args.log_folder, + ] + logger.info("Benchmark Microsoft model in ONNX Runtime") + results = benchmark(args, benchmark_cmd, "ort-msft") all_results.extend(results) - # Benchmark ONNX Runtime - if args.ort_model_path: + # Benchmark convert_to_onnx model in ONNX Runtime + if args.ort_convert_to_onnx_model_path: benchmark_cmd = [ - "python3", - "benchmark.py", + "python", + "-m", + "models.llama.benchmark", "--benchmark-type", - "ort", + "ort-convert-to-onnx", "--ort-model-path", - args.ort_model_path, + args.ort_convert_to_onnx_model_path, "--model-name", args.model_name, "--precision", @@ -347,7 +408,7 @@ def main(): "--log-folder", args.log_folder, ] - logger.info("Benchmark ONNX Runtime") + logger.info("Benchmark convert_to_onnx model in ONNX Runtime") results = benchmark(args, benchmark_cmd, "onnxruntime") all_results.extend(results) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index f96347ba67aa6..61d71bc38f4e9 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -8,12 +8,16 @@ import onnx import torch from benchmark_helper import Precision, prepare_environment, setup_logger -from llama_inputs import get_sample_inputs, get_sample_with_past_kv_inputs +from convert_generation import replace_mha_with_gqa +from llama_inputs import get_merged_sample_with_past_kv_inputs, get_sample_inputs, get_sample_with_past_kv_inputs from llama_parity import main as parity_check from onnx_model import OnnxModel +from optimizer import optimize_model +from packaging import version from transformers import LlamaConfig, LlamaForCausalLM from onnxruntime import quantization as ort_quantization +from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer logger = logging.getLogger("") @@ -58,6 +62,33 @@ def get_model_with_past_kv_dynamic_axes(input_names: List[str], output_names: Li return dynamic_axes +def get_merged_model_dynamic_axes(input_names: List[str], output_names: List[str]): + dynamic_axes = {} + for name in input_names + output_names: + if name in {"input_ids", "position_ids"}: + # shape is (batch_size, sequence_length) + dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"} + elif name == "attention_mask": + # shape is (batch_size, past_sequence_length + sequence_length) = (batch_size, total_sequence_length) + # for prompt generation, past_sequence_length = 0 + # for token generation, sequence_length = 1 + dynamic_axes[name] = {0: "batch_size", 1: "total_sequence_length"} + elif "past" in name: + # shape is (batch_size, num_heads, past_sequence_length, head_size) + dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length"} + elif name == "logits": + # shape is (batch_size, sequence_length, vocab_size) + dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"} + elif "present" in name: + # shape is (batch_size, num_heads, past_sequence_length + sequence_length, head_size) = (batch_size, num_heads, total_sequence_length, head_size) + # for prompt generation, past_sequence_length = 0 + # for token generation, sequence_length = 1 + dynamic_axes[name] = {0: "batch_size", 2: "total_sequence_length"} + else: + raise Exception("Unknown input or output name found") + return dynamic_axes + + def save_onnx_model(onnx_model: onnx.ModelProto, output_path: str, data_path: str): onnx.save( onnx_model, @@ -152,7 +183,7 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll logger.info(f"The {args.model_name} ONNX model has been successfully created with the Dynamo exporter!") -def run_torchscript_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM): +def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM): # Dummy values for export batch_size, sequence_length = 2, 8 device = torch.device("cpu") @@ -248,12 +279,206 @@ def run_torchscript_export(args: argparse.Namespace, l_config: LlamaConfig, llam logger.info(f"The {args.model_name} ONNX model has been successfully created with the TorchScript exporter!") +def run_torchscript_merged_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM): + # Dummy values for export + batch_size, sequence_length, past_sequence_length = 2, 8, 0 + device = torch.device("cpu") + + # Export decoder_merged_model.onnx + decoder_merged_inputs = get_merged_sample_with_past_kv_inputs( + l_config, device, batch_size, sequence_length, past_sequence_length + ) + input_names = [ + "input_ids", + "attention_mask", + "position_ids", + *list( + chain.from_iterable( + (f"past_key_values.{i}.key", f"past_key_values.{i}.value") for i in range(l_config.num_hidden_layers) + ) + ), + ] + output_names = [ + "logits", + *list( + chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(l_config.num_hidden_layers)) + ), + ] + dynamic_axes = get_merged_model_dynamic_axes(input_names, output_names) + temp_dir = tempfile.TemporaryDirectory() + temp_path = os.path.join(temp_dir.name, "temp.onnx") + torch.onnx.export( + llama, + args=decoder_merged_inputs, + f=temp_path, + export_params=True, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + opset_version=13, + do_constant_folding=True, + verbose=args.verbose, + ) + + # Check decoder_merged_model.onnx and save all external data to one file + onnx.checker.check_model(temp_path) + onnx.shape_inference.infer_shapes_path(temp_path) + + output_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp32.onnx") + onnx_model = onnx.load_model(temp_path, load_external_data=True) + save_onnx_model( + onnx_model, + output_path, + f"{args.model_name}_decoder_merged_model_fp32.onnx.data", + ) + del onnx_model + temp_dir.cleanup() + + logger.info(f"The {args.model_name} ONNX model has been successfully created with the TorchScript exporter!") + + +# Optimize the model as FP32 +def optimize_export(config: LlamaConfig, input_path: str, output_path: str): + from fusion_options import FusionOptions + + optimization_options = FusionOptions("gpt2") + + model_opt = optimize_model( + input_path, + model_type="gpt2", + num_heads=config.num_attention_heads, + hidden_size=config.hidden_size, + opt_level=0, + optimization_options=optimization_options, + only_onnxruntime=False, + ) + model_opt.save_model_to_file(output_path, use_external_data_format=True) + logger.info(f"The ONNX model at {input_path} has been successfully optimized and saved at {output_path}!") + remove_existing_model(input_path) + + +def convert_to_float16(args: argparse.Namespace, config: LlamaConfig, old_paths: List[str]): + decoder_model_fp16_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp16.onnx") + decoder_with_past_model_fp16_path = os.path.join( + args.output, f"{args.model_name}_decoder_with_past_model_fp16.onnx" + ) + decoder_merged_model_fp16_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp16.onnx") + new_paths = [decoder_model_fp16_path, decoder_with_past_model_fp16_path, decoder_merged_model_fp16_path] + + logger.info("Converting to float16...") + for fp32_path, fp16_path in zip(old_paths, new_paths): + if os.path.exists(fp32_path): + model = OnnxModel(onnx.load_model(fp32_path, load_external_data=True)) + model.convert_float_to_float16(keep_io_types=False) + model = use_group_query_attention(config, model) + model.save_model_to_file(fp16_path, use_external_data_format=True) + del model + logger.info(f"The ONNX model at {fp32_path} has been converted to float16 and saved at {fp16_path}!") + remove_existing_model(fp32_path) + + logger.info(f"The {args.model_name} ONNX model has been successfully converted to float16!") + return new_paths + + +def use_group_query_attention(config: LlamaConfig, fp16_model_opt: OnnxModel): + # Replace MultiHeadAttention with GroupQueryAttention and remove attention mask nodes + fp16_model_opt = replace_mha_with_gqa(fp16_model_opt, "past_sequence_length", config.num_key_value_heads) + fp16_model_opt.prune_graph() + fp16_model_opt.update_graph(allow_remove_graph_inputs=True) + return fp16_model_opt + + +def smooth_quant( + args: argparse.Namespace, + decoder_model_fp32_path: str, + decoder_with_past_model_fp32_path: str, + decoder_model_int8_path: str, + decoder_with_past_model_int8_path: str, +): + from neural_compressor import PostTrainingQuantConfig + from neural_compressor import quantization as intel_quantization + from neural_compressor import set_workspace + from onnx.external_data_helper import load_external_data_for_model + from quant_kv_dataloader import QuantKVDataLoader + + set_workspace(args.nc_workspace) + quantization_config = PostTrainingQuantConfig( + calibration_sampling_size=[args.calibration_sampling_size], + recipes={ + "optypes_to_exclude_output_quant": ["MatMul"], + "smooth_quant": args.smooth_quant, + "smooth_quant_args": {"alpha": args.smooth_quant_alpha}, + }, + op_type_dict={ + "^((?!(MatMul|Gather|Conv)).)*$": { + "weight": {"dtype": ["fp32"]}, + "activation": {"dtype": ["fp32"]}, + } + }, + ) + + # Convert decoder_model.onnx to INT8 + decoder_model_int8 = intel_quantization.fit( + decoder_model_fp32_path, + quantization_config, + calib_dataloader=QuantKVDataLoader(args), + ) + load_external_data_for_model( + decoder_model_int8._model, + os.path.split(decoder_model_int8._model_path)[0], + ) + save_onnx_model( + decoder_model_int8._model, + decoder_model_int8_path, + f"{args.model_name}_decoder_model_int8.onnx.data", + ) + del decoder_model_int8 + logger.info( + f"The ONNX model at {decoder_model_fp32_path} has been quantized to int8 and saved at {decoder_model_int8_path}!" + ) + remove_existing_model(decoder_model_fp32_path) + + # Convert decoder_with_past_model.onnx to INT8 + decoder_with_past_model_int8 = intel_quantization.fit( + decoder_with_past_model_fp32_path, + quantization_config, + calib_dataloader=QuantKVDataLoader(args, onnx_model_path=decoder_model_fp32_path), + ) + load_external_data_for_model( + decoder_with_past_model_int8._model, + os.path.split(decoder_with_past_model_int8._model_path)[0], + ) + save_onnx_model( + decoder_with_past_model_int8._model, + decoder_with_past_model_int8_path, + f"{args.model_name}_decoder_with_past_model_int8.onnx.data", + ) + del decoder_with_past_model_int8 + logger.info( + f"The ONNX model at {decoder_with_past_model_fp32_path} has been quantized to int8 and saved at {decoder_with_past_model_int8_path}!" + ) + remove_existing_model(decoder_with_past_model_fp32_path) + + logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!") + + logger.info(f"Removing {args.nc_workspace}") + os.system(f"rm -R {args.nc_workspace}") + + +def remove_existing_model(model_path: str): + # Remove ONNX model and its external data + data_path = os.path.join(model_path + ".data") + os.remove(model_path) + os.remove(data_path) + logger.warning(f"Removed {model_path} and {data_path}") + + def remove_existing_files(output_path: str): for filename in os.listdir(output_path): filepath = os.path.join(output_path, filename) if ".onnx" in filename or ".onnx.data" in filename: os.remove(filepath) - logger.warning(f"Removing {filepath}") + logger.warning(f"Removed {filepath}") def get_args(): @@ -288,7 +513,7 @@ def get_args(): required=False, type=Precision, default=Precision.FLOAT32, - choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8], + choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8, Precision.INT4], help="Precision to export model in", ) @@ -301,15 +526,51 @@ def get_args(): help="Execution provider to verify parity with", ) + parser.add_argument( + "-id", + "--device-id", + required=False, + type=str, + default="0", + help="Device ID for GPUs", + ) + + parser.add_argument( + "-r", + "--reexport", + required=False, + action="store_true", + help="Re-export models and overwrite existing models in output folder", + ) + parser.set_defaults(reexport=False) + + parser.add_argument( + "--no_merged", + required=False, + action="store_true", + help="Export models into 2 ONNX files instead of 1. Deprecated in favor of exporting into 1 ONNX file.", + ) + parser.set_defaults(no_merged=False) + parser.add_argument( "-q", "--quantization_method", default="", - choices=["smooth_quant", "quantize_dynamic"], - help="Run a specific quantization algorithm. Need to install extra packages in `requirements-quant.txt` for SmoothQuant.", + choices=["blockwise", "smooth_quant", "quantize_dynamic"], + help="Run a specific quantization algorithm (blockwise for int4, smooth_quant for int8, quantize_dynamic for int8). Blockwise is recommended. Need to install extra packages in `requirements-quant.txt` for SmoothQuant.", ) - smooth_quant_group = parser.add_argument_group("smooth_quant") + blockwise_group = parser.add_argument_group("4-bit quantization") + + blockwise_group.add_argument( + "--block_size", + required=False, + default=32, + type=int, + help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py for details.", + ) + + smooth_quant_group = parser.add_argument_group("smooth_quant (8-bit quantization)") smooth_quant_group.add_argument( "--smooth_quant_alpha", @@ -352,7 +613,7 @@ def get_args(): help="Workspace to save intermediate files generated by Intel's Neural Compressor package.", ) - quantize_dynamic_group = parser.add_argument_group("quantize_dynamic") + quantize_dynamic_group = parser.add_argument_group("quantize_dynamic (8-bit quantization)") quantize_dynamic_group.add_argument( "--quantize_embedding_layer", @@ -399,177 +660,193 @@ def get_args(): def main(): + if version.parse(torch.__version__) < version.parse("2.2.0") and "2.2.0.dev" not in torch.__version__: + # Second predicate is for comparing nightly (ex: 2.2.0.dev20230920 vs 2.2.0) since first predicate is false + # in that scenario. It can be removed when torch v2.2.0 is released in stable. + logger.error(f"Detected PyTorch version {torch.__version__}. Please upgrade and use v2.2.0 or newer.") + return + args = get_args() setup_logger(args.verbose) prepare_environment(args.input, args.output, args.execution_provider != "cpu") - remove_existing_files(args.output) + if args.reexport: + remove_existing_files(args.output) logger.info(f"Arguments: {args}") # Load model and config use_auth_token = args.input == os.path.join(".") setattr(args, "use_auth_token", use_auth_token) # noqa: B010 - l_config = LlamaConfig.from_pretrained( - args.model_name if use_auth_token else args.input, use_auth_token=use_auth_token - ) - llama = LlamaForCausalLM.from_pretrained( - args.model_name if use_auth_token else args.input, use_auth_token=use_auth_token, use_cache=True - ) + + location = args.model_name if use_auth_token else args.input + l_config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token) + llama = LlamaForCausalLM.from_pretrained(location, use_auth_token=use_auth_token, use_cache=True) original_model_name = args.model_name setattr(args, "original_model_name", original_model_name) # noqa: B010 args.model_name = args.model_name.split("/")[-1] - # Export to ONNX - if args.use_dynamo_export: - logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.") - logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/") - logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/") - logger.warning( - "Step 3 - ONNX Script from source: https://github.com/microsoft/onnxscript#installing-onnx-script" - ) - logger.warning( - "Note: After you install ONNX weekly, omit `onnx` when running the first line for installing ONNX Script. This is because you already installed `onnx-weekly` in the previous step." - ) - run_dynamo_export(args, l_config, llama) - else: - run_torchscript_export(args, l_config, llama) - - # Change precision of exported models if not FP32 + # Set model paths for FP32 model decoder_model_fp32_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx") decoder_with_past_model_fp32_path = os.path.join( args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx" ) + decoder_merged_model_fp32_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp32.onnx") + old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] + + missing_separate_exports = ( + args.no_merged + and not os.path.exists(decoder_model_fp32_path) + and not os.path.exists(decoder_with_past_model_fp32_path) + ) + missing_merged_export = not args.no_merged and not os.path.exists(decoder_merged_model_fp32_path) + + # Export to ONNX + if missing_separate_exports or missing_merged_export: + if args.use_dynamo_export and missing_separate_exports: + logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.") + logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/") + logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/") + logger.warning( + "Step 3 - ONNX Script from source: https://github.com/microsoft/onnxscript#installing-onnx-script" + ) + logger.warning( + "Note: After you install ONNX weekly, omit `onnx` when running the first line for installing ONNX Script. This is because you already installed `onnx-weekly` in the previous step." + ) + run_dynamo_export(args, l_config, llama) + elif args.no_merged: + run_torchscript_separate_export(args, l_config, llama) + else: + run_torchscript_merged_export(args, l_config, llama) + + # Set model paths to store FP32 optimized model + decoder_model_fp32_opt_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32_opt.onnx") + decoder_with_past_model_fp32_opt_path = os.path.join( + args.output, f"{args.model_name}_decoder_with_past_model_fp32_opt.onnx" + ) + decoder_merged_model_fp32_opt_path = os.path.join( + args.output, f"{args.model_name}_decoder_merged_model_fp32_opt.onnx" + ) + new_paths = [decoder_model_fp32_opt_path, decoder_with_past_model_fp32_opt_path, decoder_merged_model_fp32_opt_path] + + # Run the optimizer script + logger.info("Optimizing models...") + for orig_path, opt_path in zip(old_paths, new_paths): + if os.path.exists(orig_path): + optimize_export(l_config, input_path=orig_path, output_path=opt_path) + + # Re-assign default FP32 model paths as their optimized versions + decoder_model_fp32_path = decoder_model_fp32_opt_path + decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path + decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path + old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] + + logger.info( + f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!" + ) + # Change precision of exported models from FP32 if args.precision == Precision.FLOAT16: - # Convert decoder_model.onnx to FP16 - decoder_model_fp16_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp16.onnx") - model = OnnxModel(onnx.load_model(decoder_model_fp32_path, load_external_data=True)) - model.convert_float_to_float16(keep_io_types=False, op_block_list=["If"]) - model.save_model_to_file(decoder_model_fp16_path, use_external_data_format=True, all_tensors_to_one_file=True) - del model - - # Convert decoder_with_past_model.onnx to FP16 - decoder_with_past_model_fp16_path = os.path.join( - args.output, f"{args.model_name}_decoder_with_past_model_fp16.onnx" - ) - model = OnnxModel(onnx.load_model(decoder_with_past_model_fp32_path, load_external_data=True)) - model.convert_float_to_float16(keep_io_types=False, op_block_list=["If"]) - model.save_model_to_file( - decoder_with_past_model_fp16_path, use_external_data_format=True, all_tensors_to_one_file=True - ) - del model + new_paths = convert_to_float16(args, l_config, old_paths) elif args.precision == Precision.INT8: decoder_model_int8_path = os.path.join(args.output, f"{args.model_name}_decoder_model_int8.onnx") decoder_with_past_model_int8_path = os.path.join( args.output, f"{args.model_name}_decoder_with_past_model_int8.onnx" ) + decoder_merged_model_int8_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_int8.onnx") + new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path] if args.quantization_method == "smooth_quant": - from neural_compressor import PostTrainingQuantConfig - from neural_compressor import quantization as intel_quantization - from neural_compressor import set_workspace - from onnx.external_data_helper import load_external_data_for_model - from quant_kv_dataloader import QuantKVDataLoader - - set_workspace(args.nc_workspace) - quantization_config = PostTrainingQuantConfig( - calibration_sampling_size=[args.calibration_sampling_size], - recipes={ - "optypes_to_exclude_output_quant": ["MatMul"], - "smooth_quant": args.smooth_quant, - "smooth_quant_args": {"alpha": args.smooth_quant_alpha}, - }, - op_type_dict={ - "^((?!(MatMul|Gather|Conv)).)*$": { - "weight": {"dtype": ["fp32"]}, - "activation": {"dtype": ["fp32"]}, - } - }, - ) - - # Convert decoder_model.onnx to INT8 - decoder_model_int8 = intel_quantization.fit( - decoder_model_fp32_path, - quantization_config, - calib_dataloader=QuantKVDataLoader(args), - ) - load_external_data_for_model( - decoder_model_int8._model, - os.path.split(decoder_model_int8._model_path)[0], - ) - save_onnx_model( - decoder_model_int8._model, - decoder_model_int8_path, - f"{args.model_name}_decoder_model_int8.onnx.data", - ) - del decoder_model_int8 - - # Convert decoder_with_past_model.onnx to INT8 - decoder_with_past_model_int8 = intel_quantization.fit( - decoder_with_past_model_fp32_path, - quantization_config, - calib_dataloader=QuantKVDataLoader(args, onnx_model_path=decoder_model_fp32_path), - ) - load_external_data_for_model( - decoder_with_past_model_int8._model, - os.path.split(decoder_with_past_model_int8._model_path)[0], - ) - save_onnx_model( - decoder_with_past_model_int8._model, - decoder_with_past_model_int8_path, - f"{args.model_name}_decoder_with_past_model_int8.onnx.data", - ) - del decoder_with_past_model_int8 - - logger.info(f"Removing {args.nc_workspace}") - os.system(f"rm -R {args.nc_workspace}") + if not args.no_merged: + logger.error("SmoothQuant must be used on separately exported models") + else: + logger.info(f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8") + smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1]) elif args.quantization_method == "quantize_dynamic": logger.warning( "The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`." ) - # Convert decoder_model.onnx to INT8 - ort_quantization.quantize_dynamic( - decoder_model_fp32_path, - decoder_model_int8_path, - op_types_to_quantize=["MatMul", "Gemm", "Gather"] - if args.quantize_embedding_layer - else ["MatMul", "Gemm"], - per_channel=args.quantize_per_channel, - reduce_range=args.quantize_reduce_range, - use_external_data_format=True, - extra_options={"MatMulConstBOnly": True}, - ) - - # Convert decoder_with_past_model.onnx to INT8 - ort_quantization.quantize_dynamic( - decoder_with_past_model_fp32_path, - decoder_with_past_model_int8_path, - op_types_to_quantize=["MatMul", "Gemm", "Gather"] - if args.quantize_embedding_layer - else ["MatMul", "Gemm"], - per_channel=args.quantize_per_channel, - reduce_range=args.quantize_reduce_range, - use_external_data_format=True, - extra_options={"MatMulConstBOnly": True}, - ) + logger.info("Quantizing to int8...") + for fp32_path, int8_path in zip(old_paths, new_paths): + if os.path.exists(fp32_path): + ort_quantization.quantize_dynamic( + fp32_path, + int8_path, + op_types_to_quantize=["MatMul", "Gemm", "Gather"] + if args.quantize_embedding_layer + else ["MatMul", "Gemm"], + per_channel=args.quantize_per_channel, + reduce_range=args.quantize_reduce_range, + use_external_data_format=True, + extra_options={"MatMulConstBOnly": True}, + ) + logger.info(f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!") + remove_existing_model(decoder_model_fp32_path) + + logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!") else: raise Exception(f"Could not recognize {args.quantization_method} as a quantization method") - # Verify parity on all saved ONNX models + elif args.precision == Precision.INT4: + if args.execution_provider != "cpu": + old_paths = convert_to_float16(args, l_config, old_paths) + + decoder_model_int4_path = os.path.join(args.output, f"{args.model_name}_decoder_model_int4.onnx") + decoder_with_past_model_int4_path = os.path.join( + args.output, f"{args.model_name}_decoder_with_past_model_int4.onnx" + ) + decoder_merged_model_int4_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_int4.onnx") + new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path] + + for fp_path, int4_path in zip(old_paths, new_paths): + if os.path.exists(fp_path): + model = onnx.load_model(fp_path, load_external_data=True) + quant = MatMul4BitsQuantizer(model, args.block_size, is_symmetric=True, nodes_to_exclude=[]) + quant.process() + quant.model.save_model_to_file(int4_path, use_external_data_format=True) + del model + del quant + logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!") + remove_existing_model(fp_path) + del llama # Delete LLaMA model from memory since it will be loaded again during parity check logger.info("Verifying parity on all ONNX models created") + + # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models + args.precision = ( + "fp32" + if args.precision in {"int8", "fp32"} or (args.precision == Precision.INT4 and args.execution_provider == "cpu") + else "fp16" + ) + + # Verify parity on all saved ONNX models for filename in os.listdir(args.output): if ".data" in filename or ".onnx" not in filename: continue - precision = filename[filename.rfind("_") + 1 : filename.find(".onnx")] - parity_cmd = ["-m", f"{original_model_name}", "-o", f"{os.path.join(args.output, filename)}", "-fp", precision] + parity_cmd = [ + "-m", + original_model_name, + "-o", + os.path.join(args.output, filename), + "-ep", + args.execution_provider, + "-id", + args.device_id, + "-fp", + args.precision, + ] if "with_past" in filename: parity_cmd.append("--use_past_kv") - parity_check(parity_cmd) + if "merged" in filename: + parity_cmd.append("--merged") + + try: + parity_check(parity_cmd) + except Exception as e: + logger.warning(f"An error occurred while verifying parity: {e}", exc_info=True) if __name__ == "__main__": diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py index 6a28498a9ffc9..2652e9f0ca64e 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py @@ -4,10 +4,13 @@ import torch from transformers import LlamaConfig +from onnxruntime import OrtValue + # Get position_ids from attention_mask def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool): position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) if use_past_kv: position_ids = position_ids[:, -1].unsqueeze(-1) return position_ids @@ -62,11 +65,41 @@ def get_sample_with_past_kv_inputs( return inputs +# Inputs for all passes with past_key_values +def get_merged_sample_with_past_kv_inputs( + config: LlamaConfig, + device: torch.device, + batch_size: int, + seq_len: int, + past_seq_len: int, + use_fp16: bool = False, + return_dict: bool = False, +): + input_ids = torch.randint( + low=0, high=config.vocab_size, size=(batch_size, seq_len), device=device, dtype=torch.int64 + ) + attention_mask = torch.ones(batch_size, past_seq_len + seq_len, device=device, dtype=torch.int64) + # position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation + position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0)) + past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16) + + if not return_dict: + return (input_ids, attention_mask, position_ids, past_kv) + + inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + "past_key_values": past_kv, + } + return inputs + + # Create past_key_values def get_sample_past_kv_inputs( config: LlamaConfig, device: torch.device, batch_size: int, past_seq_len: int, use_fp16: bool ): - num_heads, head_size = config.num_attention_heads, config.hidden_size // config.num_attention_heads + num_heads, head_size = config.num_key_value_heads, config.hidden_size // config.num_key_value_heads torch_dtype = torch.float16 if use_fp16 else torch.float32 past_kv = [ ( @@ -89,31 +122,83 @@ def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tenso # Format PyTorch inputs to ONNX Runtime inputs -def convert_inputs_for_ort(pt_inputs: dict, use_fp16: bool): +def convert_inputs_for_ort( + pt_inputs: dict, + use_fp16: bool, + use_buffer_share: bool = False, + past_seq_len: int = 0, + max_seq_len: int = 2048, + device: str = "", + device_id: int = -1, +): ort_inputs = {} for k, v in pt_inputs.items(): - if k == "past_key_values": + if isinstance(v, np.ndarray): + ort_inputs[k] = v + elif k == "past_key_values": ort_inputs.update(flatten_past_kv_inputs(v, use_fp16)) + elif k == "attention_mask" and use_fp16 and use_buffer_share: + # Skip because FP16 model has GroupQueryAttention, uses buffer sharing, + # and GQA supports a causal mask by default + + # Instead, add the past sequence length input for GQA + ort_inputs["past_sequence_length"] = np.array([past_seq_len], dtype=np.int64) else: ort_inputs[k] = v.detach().cpu().numpy() + + # Enable past-present-share-buffer by using device memory directly + if use_buffer_share and device != "" and device != "cpu" and device_id > -1: + for k, v in ort_inputs.items(): + new_v = v + # Allocate new buffers with max_sequence_length for GQA + if "cache" in k or "past_key_values" in k: + # Copy v (BxSxPxH) into new_v (BxSxMxH) + batch_size, num_heads, _, head_size = v.shape + new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype) + new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v + ort_inputs[k] = OrtValue.ortvalue_from_numpy(new_v, device_type=device, device_id=device_id) + return ort_inputs # Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx -def get_msft_sample_inputs(config: LlamaConfig, batch_size: int, past_seq_len: int, seq_len: int, use_fp16: bool): +def get_msft_sample_inputs( + config: LlamaConfig, batch_size: int, past_seq_len: int, seq_len: int, use_fp16: bool, split_kv: bool +): np_dtype = np.float16 if use_fp16 else np.float32 head_size = config.hidden_size // config.num_attention_heads max_seq_len = 2048 - ort_inputs = { - "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype), - "attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype), - "k_cache": np.random.rand( - batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size - ).astype(np_dtype), - "v_cache": np.random.rand( - batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size - ).astype(np_dtype), - "pos": np.array(past_seq_len, dtype=np.int64), - } + if not split_kv: + ort_inputs = { + "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype), + "attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype), + "k_cache": np.random.rand( + batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size + ).astype(np_dtype), + "v_cache": np.random.rand( + batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size + ).astype(np_dtype), + "pos": np.array(past_seq_len, dtype=np.int64), + } + else: + ort_inputs = { + "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype), + "attn_mask": (np.triu(np.ones((batch_size, max_seq_len, max_seq_len), dtype=np.int32), k=1) - 1).astype( + np.int32 + ), + "pos": np.array(past_seq_len, dtype=np.int64), + } + for i in range(config.num_hidden_layers): + ort_inputs.update( + { + f"k_{i}_cache": np.random.rand( + batch_size, config.num_attention_heads, past_seq_len, head_size + ).astype(np_dtype), + f"v_{i}_cache": np.random.rand( + batch_size, config.num_attention_heads, past_seq_len, head_size + ).astype(np_dtype), + } + ) + return ort_inputs diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index dadf394440c9a..6bfcb9b4f290d 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -1,44 +1,143 @@ import argparse import logging import os +import time from typing import List import numpy as np import torch -from benchmark_helper import create_onnxruntime_session, setup_logger -from llama_inputs import convert_inputs_for_ort, get_sample_inputs, get_sample_with_past_kv_inputs +from benchmark_helper import setup_logger +from llama_inputs import ( + convert_inputs_for_ort, + get_merged_sample_with_past_kv_inputs, + get_sample_inputs, + get_sample_with_past_kv_inputs, +) from transformers import LlamaConfig, LlamaForCausalLM +import onnxruntime as ort + logger = logging.getLogger("") -def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM): +def get_sequence_lengths(args: argparse.Namespace): + past_sequence_length, curr_sequence_length = (8, 1) if args.use_past_kv else (0, 8) + max_sequence_length = 2048 + return past_sequence_length, curr_sequence_length, max_sequence_length + + +def get_inputs(args: argparse.Namespace, config: LlamaConfig): # Dummy values for parity - batch_size, sequence_length = 2, 8 - device = torch.device("cpu") + batch_size = 2 + past_sequence_length, sequence_length, _ = get_sequence_lengths(args) - # Run inference with PyTorch - inputs = ( - get_sample_inputs(config, device, batch_size, sequence_length, return_dict=True) - if not args.use_past_kv - else get_sample_with_past_kv_inputs( - config, device, batch_size, sequence_length, use_fp16=(args.precision == "fp16"), return_dict=True + if args.merged: + inputs = get_merged_sample_with_past_kv_inputs( + config, + args.device, + batch_size, + sequence_length, + past_sequence_length, + use_fp16=args.use_fp16, + return_dict=True, ) - ) + elif args.use_past_kv: + inputs = get_sample_with_past_kv_inputs( + config, args.device, batch_size, sequence_length, use_fp16=args.use_fp16, return_dict=True + ) + else: + inputs = get_sample_inputs(config, args.device, batch_size, sequence_length, return_dict=True) + + return inputs + + +def add_io_bindings(args: argparse.Namespace, model: ort.InferenceSession, inputs: dict): + # Add IO bindings for non-CPU execution providers + io_binding = model.io_binding() + + for k, v in inputs.items(): + if args.use_fp16: + # Bind all OrtValue inputs to device + io_binding.bind_ortvalue_input(k, v) + else: + io_binding.bind_cpu_input(k, v) + + for output in model.get_outputs(): + name = output.name + if args.use_fp16 and ("out" in name or "present" in name): + # Bind present KV cache outputs to OrtValue with buffer sharing + io_binding.bind_ortvalue_output( + name, inputs[name.replace("out", "cache").replace("present", "past_key_values")] + ) + else: + io_binding.bind_output(name, device_type=args.execution_provider, device_id=int(args.device_id)) + + return io_binding + + +def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM): + inputs = get_inputs(args, config) + + # Run inference with PyTorch + if args.execution_provider != "cpu": + torch.cuda.synchronize() + start_time = time.time() pt_outputs = pt_model(**inputs).logits.detach().cpu().numpy() + if args.execution_provider != "cpu": + torch.cuda.synchronize() + end_time = time.time() + logger.info(f"PyTorch took {end_time - start_time} s") # Run inference with ORT - inputs = convert_inputs_for_ort(inputs, use_fp16=(args.precision == "fp16")) - ort_model = create_onnxruntime_session( + past_sequence_length, _, max_sequence_length = get_sequence_lengths(args) + inputs = convert_inputs_for_ort( + inputs, + use_fp16=args.use_fp16, + use_buffer_share=args.use_fp16, + past_seq_len=past_sequence_length, + max_seq_len=max_sequence_length, + device=args.execution_provider, + device_id=int(args.device_id), + ) + + ep = f"{args.execution_provider.upper()}ExecutionProvider" + if ep == "CUDAExecutionProvider": + ep = (ep, {"device_id": args.device_id}) + ort_model = ort.InferenceSession( args.onnx_model_path, - args.execution_provider != "cpu", # use_gpu - provider=args.execution_provider, - verbose=args.verbose, + sess_options=ort.SessionOptions(), + providers=[ep], ) - ort_outputs = ort_model.run(None, inputs)[0] + + # Add IO bindings for non-CPU execution providers + if args.execution_provider != "cpu": + io_binding = add_io_bindings(args, ort_model, inputs) + + torch.cuda.synchronize() + start_time = time.time() + ort_model.run_with_iobinding(io_binding) + torch.cuda.synchronize() + end_time = time.time() + + ort_outputs = io_binding.copy_outputs_to_cpu()[0] # Get logits + + else: + start_time = time.time() + ort_outputs = ort_model.run(None, inputs) + end_time = time.time() + + ort_outputs = ort_outputs[0] # Get logits + + logger.info(f"ONNX Runtime took {end_time - start_time} s") # Compare PyTorch and ONNX Runtime accuracy - tol = 1e-3 if args.precision == "fp32" else 1e-2 if args.precision == "fp16" else 1e2 + tol = ( + 2e1 + if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path + else 1e-3 + if args.precision == "fp32" + else 5e-1 + ) parity = np.allclose(pt_outputs, ort_outputs, rtol=tol, atol=tol) logger.warning(f"Are PyTorch and ONNX Runtime results close? {parity}") if not parity: @@ -80,6 +179,15 @@ def get_args(argv: List[str]): help="Execution provider to verify parity with", ) + parser.add_argument( + "-id", + "--device-id", + required=False, + type=str, + default="0", + help="Device ID for GPUs", + ) + parser.add_argument( "-v", "--verbose", @@ -96,15 +204,29 @@ def get_args(argv: List[str]): ) parser.set_defaults(use_past_kv=False) + parser.add_argument( + "--merged", + action="store_true", + help="Use merged model (i.e. decoder_merged_model.onnx).", + ) + parser.set_defaults(merged=False) + parser.add_argument( "-fp", "--precision", required=True, - choices=["int8", "fp16", "fp32"], + choices=["int4", "int8", "fp16", "fp32"], help="Precision of model", ) args = parser.parse_args() if argv == [] else parser.parse_args(argv) + + # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models + args.precision = ( + "fp32" + if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.execution_provider == "cpu") + else "fp16" + ) return args @@ -114,19 +236,34 @@ def main(argv: List[str] = []): # noqa: B006 logger.info(f"Arguments: {args}") # Load model and config + setattr(args, "use_fp16", args.precision == "fp16") # noqa: B010 + setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{args.device_id}") # noqa: B010 + setattr(args, "device", torch.device(args.device_name)) # noqa: B010 use_auth_token = args.torch_model_directory == os.path.join(".") location = args.model_name if use_auth_token else args.torch_model_directory config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token) llama = LlamaForCausalLM.from_pretrained( location, - torch_dtype=(torch.float16 if args.precision == "fp16" else torch.float32), + torch_dtype=(torch.float16 if args.use_fp16 else torch.float32), use_auth_token=use_auth_token, use_cache=True, - ) + ).to(args.device) + + if not args.merged: + verify_parity(args, config, llama) + else: + # Verify prompt generation in merged model (decoder_model.onnx) + args.use_past_kv = False + verify_parity(args, config, llama) - verify_parity(args, config, llama) + # Verify token generation in merged model (decoder_with_past_model.onnx) + args.use_past_kv = True + verify_parity(args, config, llama) if __name__ == "__main__": + seed = 2 + np.random.seed(seed) + torch.manual_seed(seed) main() diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt index e9ad937cf14e7..e06c3ada834b0 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt @@ -1,3 +1,2 @@ -r requirements.txt -torch>=2.0.1 -onnxruntime>=1.16.0 \ No newline at end of file +onnxruntime>=1.17.0 \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt index 5544abcaa1228..773680937bd21 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt @@ -1,4 +1,4 @@ -r requirements.txt -# Please manually install torch>=2.0.1 with CUDA enabled for the CUDA version installed in your system. +# Please manually install torch>=2.2.0.dev20230920 with CUDA enabled for the CUDA version installed in your system. # Instructions can be found here: https://pytorch.org/get-started/locally/ -onnxruntime-gpu>=1.16.0 \ No newline at end of file +onnxruntime-gpu>=1.17.0 \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements.txt b/onnxruntime/python/tools/transformers/models/llama/requirements.txt index f843ef4dc5568..4210f36982aef 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements.txt @@ -1,5 +1,6 @@ -git+https://github.com/kunal-vaishnavi/optimum.git@kvaishnavi/llama-add-position-ids -transformers>=4.28.1 +git+https://github.com/huggingface/optimum.git +transformers>=4.33.2 +torch>=2.2.0.dev20230920 onnx>=1.14.0 datasets>=2.8.0 protobuf==3.20.2 \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/whisper/README.md b/onnxruntime/python/tools/transformers/models/whisper/README.md index e9365becd2cd1..8ff5c8a6e1de0 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/README.md +++ b/onnxruntime/python/tools/transformers/models/whisper/README.md @@ -79,24 +79,22 @@ $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/w Here are some examples of how you can benchmark Whisper across various end-to-end (E2E) implementations. -Note: In the below examples, `PyTorch` refers to running in PyTorch without `torch.compile` and `PyTorch 2.0` refers to running in PyTorch with `torch.compile`. - ### Variants -1. PyTorch (without `torch.compile`), FP32 +1. PyTorch without `torch.compile`, FP32 ``` python3 -m models.whisper.benchmark \ - --benchmark-type hf-pt \ + --benchmark-type hf-pt-eager \ --audio-path 1272-141231-0002.mp3 \ --model-name openai/whisper-large-v2 \ --precision fp32 \ --device cpu ``` -2. PyTorch 2.0 (with `torch.compile`), FP16 +2. PyTorch with `torch.compile`, FP16 ``` python3 -m models.whisper.benchmark \ - --benchmark-type hf-pt2 \ + --benchmark-type hf-pt-compile \ --audio-path 1272-141231-0002.mp3 \ --model-name openai/whisper-large-v2 \ --precision fp16 \ @@ -109,7 +107,7 @@ python3 -m models.whisper.benchmark \ --benchmark-type hf-ort \ --audio-path 1272-141231-0002.mp3 \ --model-name openai/whisper-large-v2 \ - --hf-ort-model-path ./whisper-large-v2-onnx/ \ + --hf-ort-dir-path ./whisper-large-v2-onnx/ \ --precision fp32 \ --device cpu ``` @@ -156,7 +154,9 @@ You can use `benchmark_all.py` to benchmark across various platforms and automat ``` python3 -m models.whisper.benchmark_all \ --audio-path ./whisper-test-audios/ \ - --hf-ort-model-path ./whisper-large-v2-onnx/ \ + --hf-pt-eager \ + --hf-pt-compile \ + --hf-ort-dir-path ./whisper-large-v2-onnx/ \ --ort-model-path ./wlarge-fp32/whisper-large-v2_all.onnx \ --model-name openai/whisper-large-v2 \ --precision fp32 \ @@ -169,28 +169,28 @@ Here is a benchmark for an MP3 file with 20.7s of audio. #### FP16 -| Engine | Size | Per-Token Latency | Real-Time Factor | -| ------------- | -------- | ----------------- | ---------------- | -| PyTorch | Tiny | 4.697 ms/token | 0.004697 | -| PyTorch 2.0 | Tiny | 3.406 ms/token | 0.003406 | -| ONNX Runtime | Tiny | 0.746 ms/token | 0.000746 | -| PyTorch | Medium | 17.837 ms/token | 0.017387 | -| PyTorch 2.0 | Medium | 18.124 ms/token | 0.018124 | -| ONNX Runtime | Medium | 3.894 ms/token | 0.003894 | -| PyTorch | Large v2 | 23.470 ms/token | 0.023470 | -| PyTorch 2.0 | Large v2 | 23.146 ms/token | 0.023146 | -| ONNX Runtime | Large v2 | 6.262 ms/token | 0.006262 | +| Engine | Size | Per-Token Latency | Real-Time Factor | +| --------------- | -------- | ----------------- | ---------------- | +| PyTorch eager | Tiny | 4.697 ms/token | 0.004697 | +| PyTorch compile | Tiny | 3.406 ms/token | 0.003406 | +| ONNX Runtime | Tiny | 0.746 ms/token | 0.000746 | +| PyTorch eager | Medium | 17.837 ms/token | 0.017387 | +| PyTorch compile | Medium | 18.124 ms/token | 0.018124 | +| ONNX Runtime | Medium | 3.894 ms/token | 0.003894 | +| PyTorch eager | Large v2 | 23.470 ms/token | 0.023470 | +| PyTorch compile | Large v2 | 23.146 ms/token | 0.023146 | +| ONNX Runtime | Large v2 | 6.262 ms/token | 0.006262 | #### FP32 -| Engine | Size | Per-Token Latency | Real-Time Factor | -| ------------- | -------- | ----------------- | ---------------- | -| PyTorch | Tiny | 6.220 ms/token | 0.006220 | -| PyTorch 2.0 | Tiny | 3.944 ms/token | 0.003944 | -| ONNX Runtime | Tiny | 1.545 ms/token | 0.001545 | -| PyTorch | Medium | 19.093 ms/token | 0.019093 | -| PyTorch 2.0 | Medium | 20.459 ms/token | 0.020459 | -| ONNX Runtime | Medium | 9.440 ms/token | 0.009440 | -| PyTorch | Large v2 | 25.844 ms/token | 0.025844 | -| PyTorch 2.0 | Large v2 | 26.397 ms/token | 0.026397 | -| ONNX Runtime | Large v2 | 7.492 ms/token | 0.007492 | +| Engine | Size | Per-Token Latency | Real-Time Factor | +| --------------- | -------- | ----------------- | ---------------- | +| PyTorch eager | Tiny | 6.220 ms/token | 0.006220 | +| PyTorch compile | Tiny | 3.944 ms/token | 0.003944 | +| ONNX Runtime | Tiny | 1.545 ms/token | 0.001545 | +| PyTorch eager | Medium | 19.093 ms/token | 0.019093 | +| PyTorch compile | Medium | 20.459 ms/token | 0.020459 | +| ONNX Runtime | Medium | 9.440 ms/token | 0.009440 | +| PyTorch eager | Large v2 | 25.844 ms/token | 0.025844 | +| PyTorch compile | Large v2 | 26.397 ms/token | 0.026397 | +| ONNX Runtime | Large v2 | 7.492 ms/token | 0.007492 | diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py index 283528bea7465..759ae6d14f184 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py @@ -24,7 +24,7 @@ def get_inputs(args: argparse.Namespace): - if args.benchmark_type not in {"hf-pt", "hf-pt2", "hf-ort", "ort"}: + if args.benchmark_type not in {"hf-pt-eager", "hf-pt-compile", "hf-ort", "ort"}: raise Exception("Unable to auto-detect inputs for provided model") def load_via_ffmpeg(): @@ -102,7 +102,7 @@ def get_model(args: argparse.Namespace): # 2) Benchmark Whisper ONNX model from Optimum export (without pre/post processing) # 3) Benchmark Whisper ONNX E2E model from Olive (with pre/post processing) - if args.benchmark_type in {"hf-pt", "hf-pt2"}: + if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: source = args.hf_pt_model_path if args.hf_pt_model_path else args.model_name start_time = time.time() model = AutoModelForSpeechSeq2Seq.from_pretrained( @@ -112,7 +112,7 @@ def get_model(args: argparse.Namespace): ).to(args.target_device) end_time = time.time() - if args.benchmark_type == "hf-pt2": + if args.benchmark_type == "hf-pt-compile": model = torch.compile(model) elif args.benchmark_type in {"hf-ort", "ort"}: @@ -136,7 +136,7 @@ def get_model(args: argparse.Namespace): start_time = time.time() model = ORTModelForSpeechSeq2Seq.from_pretrained( - args.hf_ort_model_path, + args.hf_ort_dir_path, use_io_binding=(args.device != "cpu"), provider=provider, provider_options=provider_options, @@ -214,7 +214,7 @@ def profile_fn(args, fn, inputs, inputs_type): prefix = f"{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}" filename = None - if args.benchmark_type in {"hf-pt", "hf-pt2"}: + if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: # Profile PyTorch kernels with profile( # noqa: SIM117 activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True @@ -280,7 +280,7 @@ def gen_and_dec(inputs): generate_fn = gen_and_dec - if args.benchmark_type == "hf-pt2": + if args.benchmark_type == "hf-pt-compile": # Run forward pass once with each set of inputs to process through Dynamo generate_fn(inputs) @@ -345,7 +345,7 @@ def prepare_ort_inputs(inputs, warmup=False): for k, v in inputs.items(): io_binding.bind_cpu_input(k, v) for output in model.get_outputs(): - io_binding.bind_output(output.name) + io_binding.bind_output(output.name, device_type=args.device, device_id=args.device_id) return io_binding return inputs @@ -407,7 +407,7 @@ def handle_output(output): def run_inference(args, inputs, model): - if args.benchmark_type in {"hf-pt", "hf-pt2", "hf-ort"}: + if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}: run_hf_inference(args, inputs, model) elif args.benchmark_type == "ort": run_ort_inference(args, inputs, model) @@ -419,8 +419,13 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( - "-bt", "--benchmark-type", type=str, required=True, choices=["hf-pt", "hf-pt2", "hf-ort", "ort"] + "-bt", + "--benchmark-type", + type=str, + required=True, + choices=["hf-pt-eager", "hf-pt-compile", "hf-ort", "ort"], ) + parser.add_argument( "-m", "--model-name", @@ -445,7 +450,7 @@ def parse_args(): help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)", ) parser.add_argument( - "--hf-ort-model-path", + "--hf-ort-dir-path", type=str, default="", help="Path to directory containing all ONNX files (e.g. tokenizer, encoder, decoder, decoder_with_past)", @@ -538,7 +543,7 @@ def parse_args(): # Check that model paths have been specified for any benchmarking with ORT if args.benchmark_type == "hf-ort": - assert args.hf_ort_model_path, "Please specify a path to `--hf-ort-model-path`" + assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`" if args.benchmark_type == "ort": assert args.ort_model_path, "Please specify a path to `--ort-model-path`" diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py index 08d7befec3cfd..071b539ac1899 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py @@ -54,7 +54,21 @@ def get_args(): ) parser.add_argument( - "--hf-ort-model-path", + "--hf-pt-eager", + default=False, + action="store_true", + help="Benchmark in PyTorch without `torch.compile`", + ) + + parser.add_argument( + "--hf-pt-compile", + default=False, + action="store_true", + help="Benchmark in PyTorch with `torch.compile`", + ) + + parser.add_argument( + "--hf-ort-dir-path", type=str, help="Path to folder containing ONNX models for Optimum + ORT benchmarking", ) @@ -136,7 +150,7 @@ def process_log_file(device_id, log_file, base_results): load_audio_latency_s, load_audio_throughput_s = None, None feat_ext_latency_s, feat_ext_throughput_s = None, None - latency_s, per_token_latency_s, per_token_latency_ms = None, None, None + token_length, latency_s, per_token_latency_s, per_token_latency_ms = None, None, None, None throughput, memory = None, None # Detect metrics @@ -310,73 +324,75 @@ def main(): logger.info(f"Testing {audio_path}...") # Benchmark PyTorch without torch.compile - benchmark_cmd = [ # noqa: RUF005 - "python3", - "-m", - "models.whisper.benchmark", - "--audio-path", - audio_path, - "--benchmark-type", - "hf-pt", - "--model-name", - args.model_name, - "--precision", - args.precision, - "--device", - args.device, - "--device-id", - str(args.device_id), - "--warmup-runs", - str(args.warmup_runs), - "--num-runs", - str(args.num_runs), - "--log-folder", - args.log_folder, - ] + hf_decoder_input_ids_cmd - logger.info("Benchmark PyTorch without torch.compile") - results = benchmark(args, benchmark_cmd, "pytorch", audio_file, duration) - all_results.extend(results) + if args.hf_pt_eager: + benchmark_cmd = [ # noqa: RUF005 + "python", + "-m", + "models.whisper.benchmark", + "--audio-path", + audio_path, + "--benchmark-type", + "hf-pt-eager", + "--model-name", + args.model_name, + "--precision", + args.precision, + "--device", + args.device, + "--device-id", + str(args.device_id), + "--warmup-runs", + str(args.warmup_runs), + "--num-runs", + str(args.num_runs), + "--log-folder", + args.log_folder, + ] + hf_decoder_input_ids_cmd + logger.info("Benchmark PyTorch without torch.compile") + results = benchmark(args, benchmark_cmd, "pytorch-eager", audio_file, duration) + all_results.extend(results) # Benchmark PyTorch with torch.compile - benchmark_cmd = [ # noqa: RUF005 - "python3", - "-m", - "models.whisper.benchmark", - "--audio-path", - audio_path, - "--benchmark-type", - "hf-pt2", - "--model-name", - args.model_name, - "--precision", - args.precision, - "--device", - args.device, - "--device-id", - str(args.device_id), - "--warmup-runs", - str(args.warmup_runs), - "--num-runs", - str(args.num_runs), - "--log-folder", - args.log_folder, - ] + hf_decoder_input_ids_cmd - logger.info("Benchmark PyTorch with torch.compile") - results = benchmark(args, benchmark_cmd, "pytorch-2", audio_file, duration) - all_results.extend(results) + if args.hf_pt_compile: + benchmark_cmd = [ # noqa: RUF005 + "python", + "-m", + "models.whisper.benchmark", + "--audio-path", + audio_path, + "--benchmark-type", + "hf-pt-compile", + "--model-name", + args.model_name, + "--precision", + args.precision, + "--device", + args.device, + "--device-id", + str(args.device_id), + "--warmup-runs", + str(args.warmup_runs), + "--num-runs", + str(args.num_runs), + "--log-folder", + args.log_folder, + ] + hf_decoder_input_ids_cmd + logger.info("Benchmark PyTorch with torch.compile") + results = benchmark(args, benchmark_cmd, "pytorch-compile", audio_file, duration) + all_results.extend(results) # Benchmark Optimum + ONNX Runtime - if args.hf_ort_model_path: + if args.hf_ort_dir_path: benchmark_cmd = [ # noqa: RUF005 - "python3", + "python", "-m", "models.whisper.benchmark", "--audio-path", audio_path, "--benchmark-type", "hf-ort", - "--hf-ort-model-path", - args.hf_ort_model_path, + "--hf-ort-dir-path", + args.hf_ort_dir_path, "--model-name", args.model_name, "--precision", @@ -393,14 +409,14 @@ def main(): args.log_folder, ] + hf_decoder_input_ids_cmd logger.info("Benchmark Optimum + ONNX Runtime") - results = benchmark(args, benchmark_cmd, "pytorch-ort", audio_file, duration) + results = benchmark(args, benchmark_cmd, "optimum-ort", audio_file, duration) all_results.extend(results) # Benchmark ONNX Runtime if args.ort_model_path: benchmark_cmd = ( [ # noqa: RUF005 - "python3", + "python", "-m", "models.whisper.benchmark", "--audio-path", diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index 995f8c6541b4c..7a69922e67072 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -22,7 +22,9 @@ from fusion_qordered_layernorm import FusionQOrderedLayerNormalization from fusion_qordered_matmul import FusionQOrderedMatMul from fusion_reshape import FusionReshape +from fusion_rotary_attention import FusionRotaryEmbeddings from fusion_shape import FusionShape +from fusion_simplified_layernorm import FusionSimplifiedLayerNormalization, FusionSkipSimplifiedLayerNormalization from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization from fusion_utils import FusionUtils from onnx import GraphProto, ModelProto, TensorProto, ValueInfoProto, helper @@ -106,10 +108,36 @@ def fuse_layer_norm(self): fusion = FusionQOrderedLayerNormalization(self) fusion.apply() + def fuse_simplified_layer_norm(self): + fusion = FusionSimplifiedLayerNormalization(self) + fusion.apply() + def fuse_skip_layer_norm(self): fusion = FusionSkipLayerNormalization(self) fusion.apply() + def fuse_skip_simplified_layer_norm(self): + fusion = FusionSkipSimplifiedLayerNormalization(self) + fusion.apply() + + def fuse_rotary_embeddings(self): + fusion = FusionRotaryEmbeddings(self) + fusion.apply() + # Remove non-MS domain functions + rot_emb_nodes = list( + filter( + lambda node: node.op_type == "RotaryEmbedding" and node.domain != "com.microsoft", self.model.graph.node + ) + ) + non_ms_domains_to_keep = set(map(lambda node: node.domain, rot_emb_nodes)) + i = 0 + while i < len(self.model.functions): + fn = self.model.functions[i] + if "RotaryEmbedding" in fn.name and fn.domain not in non_ms_domains_to_keep: + self.model.functions.remove(fn) + else: + i += 1 + # Only relevant in models with Q-DQ nodes def fuse_qordered_mamtul(self): fusion = FusionQOrderedMatMul(self) @@ -367,6 +395,7 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo if (options is None) or options.enable_layer_norm: self.fuse_layer_norm() + self.fuse_simplified_layer_norm() if (options is None) or options.enable_gelu: self.fuse_gelu() @@ -377,6 +406,10 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo if (options is None) or options.enable_skip_layer_norm: self.fuse_skip_layer_norm() + self.fuse_skip_simplified_layer_norm() + + if (options is None) or options.enable_rotary_embeddings: + self.fuse_rotary_embeddings() if options is not None: self.attention_mask.set_mask_format(options.attention_mask_format) @@ -442,14 +475,17 @@ def get_fused_operator_statistics(self): "BiasGelu", "GemmFastGelu", "LayerNormalization", + "SimplifiedLayerNormalization", "SkipLayerNormalization", + "SkipSimplifiedLayerNormalization", + "RotaryEmbedding", ] q_ops = ["QOrderedAttention", "QOrderedGelu", "QOrderedLayerNormalization", "QOrderedMatMul"] for op in ops + q_ops: nodes = self.get_nodes_by_op_type(op) op_count[op] = len(nodes) - logger.info(f"Optimized operators:{op_count}") + logger.info(f"Optimized operators: {op_count}") return op_count def is_fully_optimized(self): @@ -461,11 +497,20 @@ def is_fully_optimized(self): attention = op_count["Attention"] + op_count["MultiHeadAttention"] + op_count["QOrderedAttention"] gelu = op_count["Gelu"] + op_count["BiasGelu"] + op_count["FastGelu"] layer_norm = op_count["LayerNormalization"] + op_count["SkipLayerNormalization"] - is_perfect = (embed > 0) and (attention > 0) and (attention == gelu) and (layer_norm >= 2 * attention) + simple_layer_norm = op_count["SimplifiedLayerNormalization"] + op_count["SkipSimplifiedLayerNormalization"] + is_perfect = ( + (embed > 0) + and (attention > 0) + and (attention == gelu) + and ((layer_norm >= 2 * attention) or (simple_layer_norm >= 2 * attention)) + ) if layer_norm == 0: logger.debug("Layer Normalization not fused") + if simple_layer_norm == 0: + logger.debug("Simple Layer Normalization not fused") + if gelu == 0: logger.debug("Gelu/FastGelu not fused") diff --git a/onnxruntime/python/tools/transformers/onnx_model_gpt2.py b/onnxruntime/python/tools/transformers/onnx_model_gpt2.py index 263857ffbc130..6545bb08cdd5e 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_gpt2.py +++ b/onnxruntime/python/tools/transformers/onnx_model_gpt2.py @@ -8,6 +8,7 @@ from fusion_gpt_attention import FusionGptAttention from fusion_gpt_attention_megatron import FusionGptAttentionMegatron from fusion_gpt_attention_no_past import FusionGptAttentionNoPast +from fusion_rotary_attention import FusionRotaryAttention from onnx_model_bert import BertOnnxModel logger = logging.getLogger(__name__) @@ -27,6 +28,9 @@ def fuse_attention(self): fusion = FusionGptAttentionMegatron(self, self.num_heads) fusion.apply() + fusion = FusionRotaryAttention(self, self.hidden_size, self.num_heads) + fusion.apply() + def postprocess(self): """ Remove extra reshape nodes. @@ -94,4 +98,4 @@ def postprocess(self): reshape_count += 2 self.prune_graph() - logger.info(f"postprocess: remove Reshape count:{reshape_count}") + logger.info(f"postprocess: remove Reshape count: {reshape_count}") diff --git a/onnxruntime/python/tools/transformers/onnx_model_t5.py b/onnxruntime/python/tools/transformers/onnx_model_t5.py index e9f98e956b760..95f40af3fd746 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_t5.py +++ b/onnxruntime/python/tools/transformers/onnx_model_t5.py @@ -3,12 +3,12 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import logging -from typing import Dict, Optional, Union +from typing import Optional, Union import numpy as np from fusion_attention import AttentionMask, FusionAttention from fusion_base import Fusion -from fusion_skiplayernorm import FusionSkipLayerNormalization +from fusion_simplified_layernorm import FusionSimplifiedLayerNormalization, FusionSkipSimplifiedLayerNormalization from fusion_utils import NumpyHelper from onnx import NodeProto, TensorProto, helper from onnx_model import OnnxModel @@ -56,8 +56,8 @@ def create_attention_node( Args: mask_index (str): mask input q_matmul (NodeProto): MatMul node in fully connection for Q - k_matmul (NodeProto): MatMul node in fully connection for K - v_matmul (NodeProto): MatMul node in fully connection for V + k_matmul (NodeProto): MatMul node in fully connection for K + v_matmul (NodeProto): MatMul node in fully connection for V num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning. hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning. input (str): input name @@ -687,67 +687,6 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): self.node_name_to_graph_name[rpb_node.name] = self.this_graph_name -class FusionSimplifiedLayerNormalization(Fusion): - def __init__(self, model: OnnxModel): - super().__init__(model, "SimplifiedLayerNormalization", "Mul") - - def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): - if node.op_type != "Mul": - return - - sim_ln_nodes = self.model.match_parent_path( - node, - ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Add"], - [1, 1, 1, 0, 0, 0, 0], - ) - if sim_ln_nodes is None: - sim_ln_nodes = self.model.match_parent_path( - node, - ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Gather"], - [1, 1, 1, 0, 0, 0, 0], - ) - if sim_ln_nodes is None: - return - - pow_node = sim_ln_nodes[-2] - if self.model.find_constant_input(pow_node, 2.0) != 1: - return - - root_input = pow_node.input[0] - - mul_node_1 = sim_ln_nodes[0] - if root_input != mul_node_1.input[0]: - return - - second_add_node = sim_ln_nodes[3] - i, add_weight = self.model.get_constant_input(second_add_node) - if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4: - logger.warning(f"epsilon value is not expeced: {add_weight}") - return - - self.nodes_to_remove.extend(sim_ln_nodes[:-1]) - - normalize_node = helper.make_node( - "SimplifiedLayerNormalization", - inputs=[root_input, node.input[0]], - outputs=[node.output[0]], - name=self.model.create_node_name("SimplifiedLayerNormalization", name_prefix="LayerNorm"), - ) - normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))]) - normalize_node.attribute.extend([helper.make_attribute("axis", int(-1))]) - normalize_node.attribute.extend([helper.make_attribute("stash_type", 1)]) - self.nodes_to_add.append(normalize_node) - self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name - - -class FusionSkipSimplifiedLayerNormalization(FusionSkipLayerNormalization): - def __init__(self, model: OnnxModel): - super().__init__(model, "SkipSimplifiedLayerNormalization", "SimplifiedLayerNormalization") - - def fuse(self, node, input_name_to_nodes, output_name_to_node): - super().fuse(node, input_name_to_nodes, output_name_to_node) - - class T5OnnxModel(BertOnnxModel): def __init__(self, model, num_heads, hidden_size): super().__init__(model, num_heads, hidden_size) diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index 5ded027b36f74..00b26c019d4b5 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -103,7 +103,7 @@ def optimize_by_onnxruntime( logger.error("There is no gpu for onnxruntime to do optimization.") return onnx_model_path - model = OnnxModel(load_model(onnx_model_path, format=None, load_external_data=False)) + model = OnnxModel(load_model(onnx_model_path, load_external_data=False)) if model.use_float16() and not use_gpu: logger.warning( "This model uses float16 in the graph, use_gpu=False might cause extra Cast nodes. " @@ -546,7 +546,7 @@ def main(): if args.input_int32: optimizer.change_graph_inputs_to_int32() - if args.model_type in ["bert", "gpt2"]: + if args.model_type in set(MODEL_TYPES.keys()): if optimizer.is_fully_optimized(): logger.info("The model has been fully optimized.") else: diff --git a/onnxruntime/python/tools/transformers/shape_infer_helper.py b/onnxruntime/python/tools/transformers/shape_infer_helper.py index f8a5464d8af78..f1fc0c952e8e4 100644 --- a/onnxruntime/python/tools/transformers/shape_infer_helper.py +++ b/onnxruntime/python/tools/transformers/shape_infer_helper.py @@ -28,12 +28,12 @@ def __init__(self, model, verbose=0, int_max=2**31 - 1, auto_merge=True, guess_o self.is_inferred_: bool = False self.dynamic_axis_mapping_: Dict[str, int] = {} - def infer(self, dynamic_axis_mapping: Dict[str, int], max_runs: int = 128): + def infer(self, dynamic_axis_mapping: Dict[str, int], max_runs: int = 200): """Run shape inference, and try replace dynamic axis from string to integer when mapping is provided. Args: dynamic_axis_mapping (_type_): a dictionary with name of dynamic axis as key, like {"batch_size" : 4} - max_runs (int, optional): limit maximum number of runs to avoid infinite loop. Defaults to 32. + max_runs (int, optional): limit maximum number of runs to avoid infinite loop. Defaults to 200. Returns: bool: whether all shapes has been inferred or not. diff --git a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc new file mode 100644 index 0000000000000..29d8219c162a5 --- /dev/null +++ b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc @@ -0,0 +1,632 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "gtest/gtest.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +static void RunTest( + const std::vector& input_data, + const std::vector& position_ids, + const std::vector& cos_cache, + const std::vector& sin_cache, + const std::vector& output_data, + int batch_size, + int sequence_length, + int head_size, + int num_heads, + int max_sequence_length, + int64_t interleaved, + bool use_float16, + bool disable_cpu, + bool disable_cuda) { + // input : (batch_size, sequence_length, hidden_size) + // position ids : (1) or (batch_size, sequence_length) + // cos cache : (max_sequence_length, head_size / 2) + // sin cache : (max_sequence_length, head_size / 2) + // interleaved : 0 = false, 1 = true + + int hidden_size = num_heads * head_size; + std::vector input_dims = {batch_size, sequence_length, hidden_size}; + std::vector pos_dims; + std::vector cache_dims = {max_sequence_length, head_size / 2}; + + assert(hidden_size != 0 && head_size != 0 && num_heads != 0 && max_sequence_length != 0); + assert(max_sequence_length >= sequence_length); + if (position_ids.size() == 1) { + pos_dims = {1}; + } else { + pos_dims = {batch_size, sequence_length}; + } + + std::string op_type = "RotaryEmbedding"; + std::vector> execution_providers; + + int min_cuda_architecture = use_float16 ? 530 : 0; + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + if (enable_cuda && !disable_cuda) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + if (!use_float16 && !disable_cpu) { + execution_providers.push_back(DefaultCpuExecutionProvider()); + } + if (execution_providers.size() == 0) { + // Return early if CI pipeline does not support EP (e.g. CUDA EP for CPU CI pipeline) + return; + } + + OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); + test.AddAttribute("interleaved", interleaved); + + if (!use_float16) { + test.AddInput("input", input_dims, input_data); + test.AddInput("position_ids", pos_dims, position_ids); + test.AddInput("cos_cache", cache_dims, cos_cache); + test.AddInput("sin_cache", cache_dims, sin_cache); + test.AddOutput("output", input_dims, output_data); + } else { + test.AddInput("input", input_dims, ToFloat16(input_data)); + test.AddInput("position_ids", pos_dims, position_ids); + test.AddInput("cos_cache", cache_dims, ToFloat16(cos_cache)); + test.AddInput("sin_cache", cache_dims, ToFloat16(sin_cache)); + test.AddOutput("output", input_dims, ToFloat16(output_data)); + } + test.SetOutputAbsErr("output", 0.002f); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +static void RunTests(const std::vector& input_data, + const std::vector& position_ids, + const std::vector& cos_cache, + const std::vector& sin_cache, + const std::vector& output_data, + int batch_size, + int sequence_length, + int head_size = 0, + int num_heads = 0, + int max_sequence_length = 0, + int64_t interleaved = 0, + bool use_float16 = true) { + // FP32 test for CPU + RunTest(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved, + false, /* use_fp16 */ + false, /* disable_cpu */ + true /* disable_cuda */); + + // FP32 test for CUDA + RunTest(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved, + false, /* use_fp16 */ + false, /* disable_cpu */ + false /* disable_cuda */); + + // FP16 test for CUDA + if (use_float16) { + RunTest(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved, + true, /* use_fp16 */ + true, /* disable_cpu */ + false /* disable_cuda*/); + } +} + +// Interleaved = true, pos ids shape = (1) +TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_SmallData_LlamaMSFT) { + int batch_size = 1; + int sequence_length = 3; + int num_heads = 2; + int head_size = 4; + int max_sequence_length = 8; + int64_t interleaved = 1; // true + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -0.1320f, -0.2751f, -0.2350f, 0.0937f, + -1.2188f, 1.1676f, -1.0574f, -0.1188f, -0.7396f, -1.2425f, -0.1752f, 0.6990f, + -0.8110f, 0.6737f, -1.1233f, -0.0919f, -0.6861f, 0.7202f, 0.1963f, 0.6142f}; + + std::vector position_ids = {0}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 0.5403f, 0.9999f, -0.4161f, 0.9998f, -0.9900f, 0.9996f, + -0.6536f, 0.9992f, 0.2837f, 0.9988f, 0.9602f, 0.9982f, 0.7539f, 0.9976f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.8415f, 0.0100f, 0.9093f, 0.0200f, 0.1411f, 0.0300f, + -0.7568f, 0.0400f, -0.9589f, 0.0500f, -0.2794f, 0.0600f, 0.6570f, 0.0699f}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -0.1320f, -0.2751f, -0.2350f, 0.0937f, + -1.6411f, -0.3948f, -1.0561f, -0.1294f, 0.6460f, -1.2937f, -0.1822f, 0.6972f, + -0.2751f, -1.0178f, -1.1212f, -0.1143f, -0.3694f, -0.9235f, 0.1840f, 0.6180f}; + + RunTests(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved); +} + +// Interleaved = true, pos ids shape = (1) +TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_LargeData_LlamaMSFT) { + int batch_size = 2; + int sequence_length = 8; + int num_heads = 4; + int head_size = 6; + int max_sequence_length = 16; + int64_t interleaved = 1; // true + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, + 1.1676f, -1.0190f, 0.3157f, -1.6036f, 1.8493f, + 0.0447f, 1.5853f, 0.1036f, -0.3514f, 0.2421f, + 0.6463f, 0.8730f, -0.9276f, 1.0311f, -1.9557f, + -0.1482f, 1.7376f, 2.2039f, -0.6589f, -1.0574f, + -0.1188f, -0.9078f, 0.3452f, -0.5713f, -0.2351f, + -0.5912f, 1.1312f, 0.7562f, -1.2023f, -0.5833f, + -0.4407f, 0.1766f, 1.0224f, -0.4826f, -0.5421f, + -0.5342f, -0.6413f, 1.3314f, -0.4498f, 0.5493f, + 0.0539f, 0.2601f, 0.8570f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f, -1.9791f, + 0.7787f, -0.7749f, -0.1398f, 1.1414f, -0.6354f, + 0.0352f, -0.4765f, -0.0409f, 1.1993f, 0.5374f, + -0.1930f, 2.5211f, -0.0452f, -0.3105f, -0.9407f, + -0.0034f, 1.5199f, -0.8480f, 0.5266f, 0.0299f, + -0.0498f, 1.0651f, 0.8860f, -1.4702f, -0.2134f, + -0.8707f, 1.6159f, -0.2356f, 0.9444f, 0.5937f, + 0.7203f, 0.5061f, 1.5192f, -0.4897f, 0.9231f, + 0.2654f, -0.1441f, 0.5407f, -1.5476f, 0.6455f, + -1.1382f, 0.4640f, -0.4986f, 0.1289f, 2.7631f, + 0.1405f, 1.1191f, 2.1134f, -0.9754f, 0.1757f, + -0.1319f, -0.2735f, 0.3355f, -0.6008f, -1.1164f, + 0.2577f, -0.7226f, -0.9244f, 1.8737f, 0.6052f, + 1.1904f, 1.2195f, -0.0470f, -1.0914f, 1.0223f, + 0.3152f, 1.7528f, -0.7650f, 1.8299f, -0.2784f, + -0.2719f, 0.1885f, 2.1432f, 0.8527f, 0.0965f, + -0.0625f, 0.8269f, 1.0122f, -1.4482f, -0.0644f, + 0.3215f, 0.5908f, -1.4197f, 0.2113f, 0.0306f, + 0.3604f, 0.3166f, -0.8975f, -0.6393f, -1.2944f, + -0.0243f, -0.2354f, -0.7087f, 1.1566f, 0.4296f, + 0.5599f, -0.7776f, 0.3339f, 0.1759f, 2.1108f, + 1.0702f, 0.8279f, -0.2969f, 0.7120f, -0.2068f, + -0.1548f, 0.1553f, 0.6207f, -0.1690f, -0.5816f, + 1.2632f, 0.0695f, 1.1862f, -1.1874f, -0.7468f, + -0.9320f, -0.8579f, -0.9647f, -0.0991f, 0.0195f, + 1.1213f, -1.4873f, -0.2043f, -1.0466f, -1.5772f, + -0.0489f, 0.3430f, 0.1264f, 0.1519f, -1.3639f, + -1.6593f, 1.8127f, -1.4459f, -0.2158f, -0.9792f, + -1.4392f, 0.6508f, 0.8964f, 0.5717f, -0.2390f, + 0.6983f, -1.3416f, 0.2715f, -0.2852f, 0.6051f, + 0.2167f, -0.2181f, -1.6306f, 1.4788f, 0.2754f, + -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f, + 1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f, + 0.1527f, -0.5996f, -1.0962f, 1.6327f, 1.3951f, + 0.8784f, 0.3389f, 1.2907f, 0.3124f, 0.7299f, + 1.4220f, 0.3375f, 0.0438f, 1.8698f, -0.2635f, + -2.0799f, -0.6313f, 0.4090f, -1.1458f, 0.0784f, + -1.8848f, -1.6165f, 0.6179f, 0.9905f, -0.0729f, + 0.5054f, -0.6681f, -1.4382f, 1.7547f, -0.9605f, + -0.4558f, -1.6105f, 0.2979f, 1.1537f, -1.5604f, + 1.2779f, -1.2514f, 0.6056f, 0.5763f, -3.3558f, + 0.2836f, 0.6909f, -0.7631f, 2.4451f, -0.3500f, + 1.3289f, -0.6494f, 0.3478f, 1.0038f, -0.2937f, + 0.9238f, -1.2185f, 0.4138f, 0.5033f, 0.9174f, + 1.8131f, 1.4436f, -0.4207f, 0.0220f, -0.6807f, + -1.3306f, 1.5646f, 0.3338f, 0.7105f, 0.4683f, + -0.6179f, 0.0818f, -0.0488f, -0.9810f, -1.3632f, + 0.0929f, -1.7926f, -0.2921f, -0.4792f, 0.6756f, + -0.3413f, -0.2242f, -0.2111f, 0.6282f, 0.1667f, + -1.4055f, 1.5895f, 1.0838f, -0.9077f, -0.8060f, + 0.7967f, -2.9351f, 2.4179f, -0.4026f, 0.6451f, + 1.6845f, -0.0901f, 0.6106f, 2.3603f, 1.3908f, + -0.7917f, -0.6734f, -0.1213f, -1.1116f, -0.7401f, + -0.7879f, 0.0606f, -2.3337f, -1.2603f, -1.7245f, + -0.3533f, -0.9421f, -0.1776f, 0.3992f, -1.7142f, + -0.5319f, -0.8848f, 0.6513f, 1.0002f, -1.4699f, + -1.4254f, 0.7013f, 0.2414f, 0.2551f, -0.7457f, + 0.3133f, -1.0941f, -0.3682f, -0.0163f, -0.0645f, + -0.8101f, 0.1415f, 0.0551f, 0.5873f, -0.5887f, + -1.4733f, -0.8565f, 0.7400f, -0.5033f, 0.0553f, + 0.9265f, -0.8652f, -0.0288f, -0.2209f, 0.0610f, + 0.6776f, 0.4361f, -0.8052f, 0.3955f, 0.8988f, + 0.8238f, 0.2262f, 1.2912f, 0.6488f, 1.2114f, + 1.3569f, 0.2983f, 0.4718f, -1.1936f, 0.7928f, + -0.8665f, 0.9468f, 1.1629f, 0.0616f, -1.3136f, + -0.2764f, 0.0277f, -0.1126f, 0.2342f, -0.5866f, + -1.8219f, 1.1079f, 0.5795f, -1.4249f}; + + std::vector position_ids = {0}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 1.0000f, 0.5403f, 0.9989f, 1.0000f, -0.4161f, 0.9957f, + 1.0000f, -0.9900f, 0.9903f, 1.0000f, -0.6536f, 0.9828f, 1.0000f, 0.2837f, + 0.9732f, 0.9999f, 0.9602f, 0.9615f, 0.9999f, 0.7539f, 0.9477f, 0.9999f, + -0.1455f, 0.9318f, 0.9999f, -0.9111f, 0.9140f, 0.9998f, -0.8391f, 0.8942f, + 0.9998f, 0.0044f, 0.8725f, 0.9997f, 0.8439f, 0.8488f, 0.9997f, 0.9074f, + 0.8234f, 0.9996f, 0.1367f, 0.7962f, 0.9995f, -0.7597f, 0.7673f, 0.9995f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.0000f, 0.8415f, 0.0464f, 0.0022f, 0.9093f, 0.0927f, + 0.0043f, 0.1411f, 0.1388f, 0.0065f, -0.7568f, 0.1846f, 0.0086f, -0.9589f, + 0.2300f, 0.0108f, -0.2794f, 0.2749f, 0.0129f, 0.6570f, 0.3192f, 0.0151f, + 0.9894f, 0.3629f, 0.0172f, 0.4121f, 0.4057f, 0.0194f, -0.5440f, 0.4477f, + 0.0215f, -1.0000f, 0.4887f, 0.0237f, -0.5366f, 0.5286f, 0.0259f, 0.4202f, + 0.5675f, 0.0280f, 0.9906f, 0.6050f, 0.0302f, 0.6503f, 0.6413f, 0.0323f}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, + 1.1676f, -1.0190f, 0.3157f, -1.6036f, 1.8493f, + 0.0447f, 1.5853f, 0.1036f, -0.3514f, 0.2421f, + 0.6463f, 0.8730f, -0.9276f, 1.0311f, -1.9557f, + -0.1482f, 1.7376f, 2.2039f, -0.6589f, -0.4713f, + -0.9540f, -0.9229f, 0.3027f, -0.5708f, -0.2363f, + -1.2713f, 0.1137f, 0.8112f, -1.1659f, -0.5824f, + -0.4419f, -0.7649f, 0.7011f, -0.4569f, -0.5639f, + -0.5328f, -0.6424f, 1.0979f, 0.8773f, 0.5462f, + 0.0793f, 0.2582f, 0.8576f, 0.2653f, 1.2295f, + -0.1839f, -0.4517f, -1.5052f, -0.4651f, 0.1155f, + -2.1237f, -0.7586f, -0.2110f, 1.1441f, -0.6304f, + 0.4186f, 0.2303f, -0.1519f, 1.1903f, 0.5382f, + -0.1906f, -1.0080f, 2.3112f, -0.2220f, -0.9655f, + -0.0099f, 1.5198f, 0.7652f, -0.6410f, 0.0365f, + -0.0452f, 1.0593f, 0.8929f, 1.4856f, 0.0038f, + -1.0865f, 1.4794f, -0.2417f, 0.9428f, -0.6894f, + -0.6293f, 0.2904f, 1.5747f, -0.4956f, 0.9199f, + -0.2424f, 0.1801f, 0.7503f, -1.4576f, 0.6529f, + -1.1340f, -0.6807f, -0.0252f, -0.3834f, 2.7394f, + 0.1308f, 1.1203f, -2.1196f, -0.9618f, 0.1970f, + -0.0972f, -0.2764f, 0.3332f, -0.4522f, 1.1844f, + 0.3867f, -0.6626f, -0.9405f, 1.8656f, 0.5053f, + -1.2361f, 1.2072f, 0.1789f, -1.1002f, 1.0129f, + 1.7702f, 0.1949f, -1.1653f, 1.6049f, -0.2755f, + -0.2749f, 2.1087f, 0.4272f, 0.8076f, 0.2900f, + -0.0714f, 0.8261f, -1.1016f, -1.3814f, -0.1366f, + 0.2981f, 0.6060f, -1.4132f, 0.0893f, -0.1939f, + 0.2779f, 0.3910f, -0.8906f, -0.6489f, -1.2496f, + 0.3383f, -0.0315f, -0.7461f, 1.1510f, 0.4445f, + 0.3203f, -0.9031f, 0.2727f, 0.2609f, 2.0968f, + 1.0974f, 0.7120f, -0.5164f, 0.7415f, -0.0031f, + -0.1568f, 0.1533f, 0.5487f, -0.3357f, -0.9064f, + 1.0546f, 0.0542f, 1.1870f, -0.4045f, -1.3431f, + -0.6094f, -1.1105f, -0.9631f, -0.1137f, -0.7219f, + 0.8582f, -1.3443f, -0.6684f, -1.0227f, -1.5929f, + -0.2622f, 0.2264f, 0.0713f, 0.1843f, -1.3387f, + -1.6797f, 2.3165f, 0.1009f, 0.1081f, -0.9969f, + -1.4488f, 0.6291f, 0.8964f, 0.5717f, -0.2390f, + 0.6983f, -1.3416f, 0.2715f, -0.2852f, 0.6051f, + 0.2167f, -0.2181f, -1.6306f, 1.4788f, 0.2754f, + -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f, + 1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f, + 0.1527f, 0.5985f, -1.0968f, 1.5662f, 1.4693f, + 0.8776f, 0.3408f, 0.4345f, 1.2549f, 0.6631f, + 1.4543f, 0.3374f, 0.0445f, 1.2320f, 1.4311f, + -2.0483f, -0.7272f, 0.4114f, -1.1449f, 1.6283f, + -0.9524f, -1.6435f, 0.5422f, 0.9907f, -0.0708f, + 0.3972f, 0.7376f, -1.5947f, 1.6138f, -0.9586f, + -0.4600f, 0.3993f, -1.5884f, 1.2934f, -1.4467f, + 1.2833f, -1.2459f, -0.7760f, 0.3108f, -3.3677f, + -0.0287f, 0.6942f, -0.7601f, -0.6993f, 2.3690f, + 1.3834f, -0.5234f, 0.3435f, 1.0053f, 0.1604f, + -0.9560f, -1.2641f, 0.2406f, 0.4973f, 0.9206f, + -1.9987f, -1.1733f, -0.4197f, -0.0366f, -0.6720f, + -1.3350f, -1.5960f, -0.1097f, 0.6386f, 0.5624f, + -0.6184f, 0.0778f, 0.1867f, 0.9643f, -1.3629f, + -0.0972f, -1.7907f, -0.3037f, 0.8245f, -0.0789f, + -0.2940f, -0.2833f, -0.2165f, 0.6264f, -1.1726f, + 0.7926f, 1.3621f, 1.3586f, -0.9007f, -0.8138f, + -2.7421f, 1.3155f, 2.4507f, 0.0507f, 0.6305f, + 1.6900f, 0.5210f, -0.3309f, 2.0630f, 1.8026f, + -0.7859f, -0.6802f, -1.1003f, -0.1990f, -0.5391f, + -0.9370f, 0.0857f, -2.3330f, -2.0112f, 0.7193f, + -0.1272f, -0.9981f, -0.1818f, 0.3973f, -0.9963f, + 1.4929f, -1.0109f, 0.4304f, 1.0160f, -1.4590f, + 0.2682f, 1.5658f, 0.1762f, 0.3038f, -0.7491f, + 0.3052f, -1.1534f, -0.0478f, 0.0021f, -0.0665f, + -0.8118f, 0.1310f, 0.2171f, 0.5485f, -0.1610f, + -1.5784f, -0.8660f, 0.7289f, -0.4678f, 0.1937f, + 1.1287f, -0.5772f, -0.0259f, -0.2212f, 0.2479f, + 0.6336f, 0.6407f, -0.6543f, 0.3838f, 0.9039f, + 0.4724f, 0.7117f, 1.0165f, 1.0270f, 1.1908f, + 1.3750f, -0.0850f, 0.5517f, -1.3842f, 0.3703f, + -0.8806f, 0.9336f, 0.8362f, 0.8105f, -1.1566f, + -0.6813f, 0.0294f, -0.1122f, 0.5620f, -0.2884f, + -2.0803f, 0.4684f, 0.6009f, -1.4160f}; + + RunTests(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved); +} + +// Interleaved = false, pos ids shape = (1) +TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_LargeData_LlamaMSFT) { + int batch_size = 2; + int sequence_length = 8; + int num_heads = 4; + int head_size = 6; + int max_sequence_length = 16; + int64_t interleaved = 0; // false + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, + 1.1676f, -1.0190f, 0.3157f, -1.6036f, 1.8493f, + 0.0447f, 1.5853f, 0.1036f, -0.3514f, 0.2421f, + 0.6463f, 0.8730f, -0.9276f, 1.0311f, -1.9557f, + -0.1482f, 1.7376f, 2.2039f, -0.6589f, -1.0574f, + -0.1188f, -0.9078f, 0.3452f, -0.5713f, -0.2351f, + -0.5912f, 1.1312f, 0.7562f, -1.2023f, -0.5833f, + -0.4407f, 0.1766f, 1.0224f, -0.4826f, -0.5421f, + -0.5342f, -0.6413f, 1.3314f, -0.4498f, 0.5493f, + 0.0539f, 0.2601f, 0.8570f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f, -1.9791f, + 0.7787f, -0.7749f, -0.1398f, 1.1414f, -0.6354f, + 0.0352f, -0.4765f, -0.0409f, 1.1993f, 0.5374f, + -0.1930f, 2.5211f, -0.0452f, -0.3105f, -0.9407f, + -0.0034f, 1.5199f, -0.8480f, 0.5266f, 0.0299f, + -0.0498f, 1.0651f, 0.8860f, -1.4702f, -0.2134f, + -0.8707f, 1.6159f, -0.2356f, 0.9444f, 0.5937f, + 0.7203f, 0.5061f, 1.5192f, -0.4897f, 0.9231f, + 0.2654f, -0.1441f, 0.5407f, -1.5476f, 0.6455f, + -1.1382f, 0.4640f, -0.4986f, 0.1289f, 2.7631f, + 0.1405f, 1.1191f, 2.1134f, -0.9754f, 0.1757f, + -0.1319f, -0.2735f, 0.3355f, -0.6008f, -1.1164f, + 0.2577f, -0.7226f, -0.9244f, 1.8737f, 0.6052f, + 1.1904f, 1.2195f, -0.0470f, -1.0914f, 1.0223f, + 0.3152f, 1.7528f, -0.7650f, 1.8299f, -0.2784f, + -0.2719f, 0.1885f, 2.1432f, 0.8527f, 0.0965f, + -0.0625f, 0.8269f, 1.0122f, -1.4482f, -0.0644f, + 0.3215f, 0.5908f, -1.4197f, 0.2113f, 0.0306f, + 0.3604f, 0.3166f, -0.8975f, -0.6393f, -1.2944f, + -0.0243f, -0.2354f, -0.7087f, 1.1566f, 0.4296f, + 0.5599f, -0.7776f, 0.3339f, 0.1759f, 2.1108f, + 1.0702f, 0.8279f, -0.2969f, 0.7120f, -0.2068f, + -0.1548f, 0.1553f, 0.6207f, -0.1690f, -0.5816f, + 1.2632f, 0.0695f, 1.1862f, -1.1874f, -0.7468f, + -0.9320f, -0.8579f, -0.9647f, -0.0991f, 0.0195f, + 1.1213f, -1.4873f, -0.2043f, -1.0466f, -1.5772f, + -0.0489f, 0.3430f, 0.1264f, 0.1519f, -1.3639f, + -1.6593f, 1.8127f, -1.4459f, -0.2158f, -0.9792f, + -1.4392f, 0.6508f, 0.8964f, 0.5717f, -0.2390f, + 0.6983f, -1.3416f, 0.2715f, -0.2852f, 0.6051f, + 0.2167f, -0.2181f, -1.6306f, 1.4788f, 0.2754f, + -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f, + 1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f, + 0.1527f, -0.5996f, -1.0962f, 1.6327f, 1.3951f, + 0.8784f, 0.3389f, 1.2907f, 0.3124f, 0.7299f, + 1.4220f, 0.3375f, 0.0438f, 1.8698f, -0.2635f, + -2.0799f, -0.6313f, 0.4090f, -1.1458f, 0.0784f, + -1.8848f, -1.6165f, 0.6179f, 0.9905f, -0.0729f, + 0.5054f, -0.6681f, -1.4382f, 1.7547f, -0.9605f, + -0.4558f, -1.6105f, 0.2979f, 1.1537f, -1.5604f, + 1.2779f, -1.2514f, 0.6056f, 0.5763f, -3.3558f, + 0.2836f, 0.6909f, -0.7631f, 2.4451f, -0.3500f, + 1.3289f, -0.6494f, 0.3478f, 1.0038f, -0.2937f, + 0.9238f, -1.2185f, 0.4138f, 0.5033f, 0.9174f, + 1.8131f, 1.4436f, -0.4207f, 0.0220f, -0.6807f, + -1.3306f, 1.5646f, 0.3338f, 0.7105f, 0.4683f, + -0.6179f, 0.0818f, -0.0488f, -0.9810f, -1.3632f, + 0.0929f, -1.7926f, -0.2921f, -0.4792f, 0.6756f, + -0.3413f, -0.2242f, -0.2111f, 0.6282f, 0.1667f, + -1.4055f, 1.5895f, 1.0838f, -0.9077f, -0.8060f, + 0.7967f, -2.9351f, 2.4179f, -0.4026f, 0.6451f, + 1.6845f, -0.0901f, 0.6106f, 2.3603f, 1.3908f, + -0.7917f, -0.6734f, -0.1213f, -1.1116f, -0.7401f, + -0.7879f, 0.0606f, -2.3337f, -1.2603f, -1.7245f, + -0.3533f, -0.9421f, -0.1776f, 0.3992f, -1.7142f, + -0.5319f, -0.8848f, 0.6513f, 1.0002f, -1.4699f, + -1.4254f, 0.7013f, 0.2414f, 0.2551f, -0.7457f, + 0.3133f, -1.0941f, -0.3682f, -0.0163f, -0.0645f, + -0.8101f, 0.1415f, 0.0551f, 0.5873f, -0.5887f, + -1.4733f, -0.8565f, 0.7400f, -0.5033f, 0.0553f, + 0.9265f, -0.8652f, -0.0288f, -0.2209f, 0.0610f, + 0.6776f, 0.4361f, -0.8052f, 0.3955f, 0.8988f, + 0.8238f, 0.2262f, 1.2912f, 0.6488f, 1.2114f, + 1.3569f, 0.2983f, 0.4718f, -1.1936f, 0.7928f, + -0.8665f, 0.9468f, 1.1629f, 0.0616f, -1.3136f, + -0.2764f, 0.0277f, -0.1126f, 0.2342f, -0.5866f, + -1.8219f, 1.1079f, 0.5795f, -1.4249f}; + + std::vector position_ids = {0}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 1.0000f, 0.5403f, 0.9989f, 1.0000f, -0.4161f, 0.9957f, + 1.0000f, -0.9900f, 0.9903f, 1.0000f, -0.6536f, 0.9828f, 1.0000f, 0.2837f, + 0.9732f, 0.9999f, 0.9602f, 0.9615f, 0.9999f, 0.7539f, 0.9477f, 0.9999f, + -0.1455f, 0.9318f, 0.9999f, -0.9111f, 0.9140f, 0.9998f, -0.8391f, 0.8942f, + 0.9998f, 0.0044f, 0.8725f, 0.9997f, 0.8439f, 0.8488f, 0.9997f, 0.9074f, + 0.8234f, 0.9996f, 0.1367f, 0.7962f, 0.9995f, -0.7597f, 0.7673f, 0.9995f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.0000f, 0.8415f, 0.0464f, 0.0022f, 0.9093f, 0.0927f, + 0.0043f, 0.1411f, 0.1388f, 0.0065f, -0.7568f, 0.1846f, 0.0086f, -0.9589f, + 0.2300f, 0.0108f, -0.2794f, 0.2749f, 0.0129f, 0.6570f, 0.3192f, 0.0151f, + 0.9894f, 0.3629f, 0.0172f, 0.4121f, 0.4057f, 0.0194f, -0.5440f, 0.4477f, + 0.0215f, -1.0000f, 0.4887f, 0.0237f, -0.5366f, 0.5286f, 0.0259f, 0.4202f, + 0.5675f, 0.0280f, 0.9906f, 0.6050f, 0.0302f, 0.6503f, 0.6413f, 0.0323f}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, + 1.1676f, -1.0190f, 0.3157f, -1.6036f, 1.8493f, + 0.0447f, 1.5853f, 0.1036f, -0.3514f, 0.2421f, + 0.6463f, 0.8730f, -0.9276f, 1.0311f, -1.9557f, + -0.1482f, 1.7376f, 2.2039f, -0.6589f, -0.8618f, + -0.0922f, -0.9073f, -0.7032f, -0.5762f, -0.2371f, + 0.6923f, 1.1571f, 0.7572f, -1.1471f, -0.5302f, + -0.4391f, 0.5516f, 1.0461f, -0.4812f, -0.1443f, + -0.4862f, -0.6423f, 0.6740f, -0.4614f, 0.5475f, + 1.1495f, 0.2389f, 0.8582f, -0.0259f, -0.6099f, + -0.2230f, 1.0963f, -1.5704f, -0.4595f, 0.9507f, + 0.6696f, -0.7721f, -1.7415f, 1.2087f, -0.6387f, + -1.1052f, -0.5243f, -0.0400f, -0.4671f, 0.4909f, + -0.1931f, -0.1937f, -0.0447f, -0.3171f, 2.6839f, + -0.0076f, 1.5185f, 0.8465f, 0.3737f, 0.0242f, + -0.0703f, 1.1279f, 0.8862f, 1.2275f, -0.1786f, + -0.8767f, -1.8072f, -0.2630f, 0.9387f, -0.8021f, + 0.7813f, 0.5001f, -1.4202f, -0.3850f, 0.9263f, + -0.0443f, -0.2323f, 0.5480f, 1.5696f, 0.6193f, + -1.1346f, 1.7878f, -0.5160f, 0.1192f, -2.1572f, + 0.0460f, 1.1202f, -1.4812f, -0.9082f, 0.1728f, + -1.5132f, -0.4489f, 0.3370f, -0.1541f, -0.9266f, + 0.2416f, 0.9270f, -1.1146f, 1.8758f, -0.4312f, + 1.3714f, 1.2106f, -0.4272f, -0.8529f, 1.0328f, + 1.8441f, 1.7698f, -0.7620f, 0.2168f, 0.1322f, + -0.2802f, 0.1460f, 2.1002f, 0.8437f, -0.1534f, + 0.4321f, 0.8360f, 0.5955f, -1.5452f, -0.0491f, + -0.8794f, 0.2418f, -1.4203f, 0.3635f, 0.2362f, + 0.3672f, -0.1128f, -0.8664f, -0.6354f, -1.4409f, + -0.3413f, -0.2409f, -0.3188f, 1.1054f, 0.4265f, + 0.5867f, -1.3279f, 0.3201f, 0.0125f, 1.8157f, + 1.0745f, 0.7372f, -0.2429f, 0.7100f, -0.4299f, + -0.2304f, 0.1645f, 0.9489f, -0.1816f, -0.5968f, + 1.0394f, 0.0204f, 1.1786f, -0.3315f, -0.3997f, + -0.9304f, -1.4268f, -1.1526f, -0.1132f, 0.1490f, + 1.3967f, -1.4634f, -0.1412f, -0.6339f, -1.5995f, + -0.1366f, 0.7604f, 0.1514f, 0.0824f, -1.1830f, + -1.6572f, 2.0099f, -0.9108f, -0.2256f, 0.4527f, + -1.8254f, 0.6475f, 0.8964f, 0.5717f, -0.2390f, + 0.6983f, -1.3416f, 0.2715f, -0.2852f, 0.6051f, + 0.2167f, -0.2181f, -1.6306f, 1.4788f, 0.2754f, + -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f, + 1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f, + 0.1527f, -1.4979f, -1.1358f, 1.6320f, 0.2493f, + 0.8266f, 0.3424f, -0.4992f, 0.2964f, 0.7298f, + 1.8544f, 0.3516f, 0.0454f, 1.5415f, -0.2822f, + -2.0774f, 1.2323f, 0.3963f, -1.1503f, -0.4775f, + -1.9287f, -1.6164f, 0.3998f, 0.9020f, -0.0764f, + -1.8059f, -0.5762f, -1.4362f, -0.2706f, -1.0183f, + -0.4620f, 2.0891f, 0.1782f, 1.1591f, -0.8151f, + 1.3000f, -1.2464f, -0.5099f, 0.5098f, -3.3525f, + 0.4326f, 0.7414f, -0.7775f, -0.4271f, -0.3807f, + 1.3245f, 2.4936f, 0.3139f, 1.0095f, 0.2323f, + 0.8450f, -1.2244f, -0.4511f, 0.6266f, 0.9095f, + -1.7981f, 1.5241f, -0.4121f, 0.2341f, -0.4737f, + -1.3333f, -1.6150f, 0.4164f, 0.7100f, -0.2429f, + -0.5656f, 0.0863f, 0.0352f, -0.7227f, -1.3613f, + -0.0988f, -1.9114f, -0.3009f, 0.1435f, 0.7029f, + -0.3467f, 0.5092f, -0.0828f, 0.6253f, 0.7113f, + -1.2138f, 1.5964f, -0.8346f, -1.1515f, -0.7923f, + -0.8254f, -3.0038f, 2.4033f, -0.3398f, 0.0922f, + 1.7053f, 1.1114f, 0.7462f, 2.3660f, -0.8409f, + -0.6654f, -0.6530f, -0.7899f, -1.0957f, -0.7149f, + -0.1072f, -0.1967f, -2.3416f, -1.2609f, -1.6375f, + -0.3576f, 0.9413f, -0.5694f, 0.3954f, 0.1383f, + -0.7477f, -0.8689f, 1.8286f, 0.8510f, -1.4793f, + -0.1597f, 0.8541f, 0.2380f, 1.4392f, -0.5644f, + 0.3158f, -1.0686f, -0.1313f, -0.0181f, 0.2438f, + -0.8801f, 0.1413f, -0.3587f, 0.8002f, -0.5982f, + -1.4301f, -0.6620f, 0.7324f, -0.7250f, 0.0610f, + 0.9293f, -0.6902f, -0.0125f, -0.2089f, -0.1664f, + 0.5428f, 0.4245f, -0.7901f, 0.5665f, 0.9044f, + 0.1948f, -0.1723f, 1.2705f, 1.0303f, 1.2202f, + 1.3762f, -0.2959f, 0.7237f, -1.2077f, 0.7937f, + -0.6705f, 0.9287f, 1.0583f, 0.0496f, -1.3118f, + 0.5556f, 0.0459f, -0.1324f, -0.5513f, -0.7409f, + -1.8002f, 0.9892f, 0.3619f, -1.4522f}; + + RunTests(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved); +} + +// Interleaved = false, pos ids shape = (batch_size, sequence_length) +TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_SmallData_LlamaMSFT) { + int batch_size = 1; + int sequence_length = 2; + int num_heads = 3; + int head_size = 6; + int max_sequence_length = 4; + int64_t interleaved = 0; // false + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f, -0.8663f, -0.2656f, 0.1665f, 0.7911f, + -0.9320f, -0.8579f, -1.0574f, -0.1188f, -0.9078f, 0.3452f, -0.5713f, -0.2351f, + -0.8480f, 0.5266f, -1.2944f, -0.0243f, -0.2354f, -0.7087f, -0.9647f, -0.0991f, + -0.2994f, -0.0650f, -1.5720f, -1.3211f}; + + std::vector position_ids = {0, 1}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 1.0000f, 0.5403f, 0.9989f, 1.0000f, -0.4161f, 0.9957f, + 1.0000f, -0.9900f, 0.9903f, 1.0000f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.0000f, 0.8415f, 0.0464f, 0.0022f, 0.9093f, 0.0927f, 0.0043f, + 0.1411f, 0.1388f, 0.0065f}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f, -0.8663f, -0.2656f, 0.1665f, 0.7911f, + -0.9320f, -0.8579f, -0.8618f, -0.0922f, -0.9073f, -0.7032f, -0.5762f, -0.2371f, + -0.4377f, 0.5370f, -1.2929f, -0.7267f, -0.2107f, -0.7115f, -0.4666f, -0.0261f, + -0.2965f, -0.8469f, -1.5749f, -1.3217f}; + + RunTests(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py b/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py new file mode 100644 index 0000000000000..b17ae5f69aff5 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py @@ -0,0 +1,450 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + + +# Notes +# 1) The test cases in this file are for the following LLaMA-2 scenarios: +# - Microsoft rotary embeddings with interleaved = True +# - Prompt generation +# - Token generation +# - Hugging Face rotary embeddings (equal to Microsoft rotary embeddings with interleaved = False) +# - Prompt generation +# - Token generation +# +# 2) Shapes of position ids in ORT and `interleaved` for LLaMA-2 scenarios: +# - Microsoft model: When shape of position ids == (1), interleaved = True +# - Hugging Face model: When shape of position ids == (batch_size, sequence_length), interleaved = False + + +import unittest +from copy import deepcopy + +import numpy as np +import torch +import torch.nn as nn +from onnx import TensorProto, helper + +import onnxruntime as ort + + +class SampleInputConfig: + def __init__( + self, + batch_size=2, + sequence_length=8, + num_heads=4, + head_size=6, + max_sequence_length=16, + ): + self.batch_size = batch_size + self.sequence_length = sequence_length + self.num_heads = num_heads + self.head_size = head_size + self.hidden_size = self.num_heads * self.head_size + self.max_sequence_length = max_sequence_length + + +# LLaMA Hugging Face model +class LlamaHFRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device="cpu"): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + def get_cos_sin_cache(self, seq_len=None, device=torch.device("cpu"), dtype=torch.float32): # noqa: B008 + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=device, dtype=dtype) + + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=dtype), + ) + + def rotate_half(self, x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope_bnsh(self, x, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + x_embed = (x * cos) + (self.rotate_half(x) * sin) + return x_embed + + def apply_rope_bsnh(self, x, cos, sin, position_ids): + # Two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze() # [seq_len, dim] + sin = sin.squeeze() # [seq_len, dim] + cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + x_embed = (x * cos) + (self.rotate_half(x) * sin) + return x_embed + + def forward(self, x, cos, sin, pos_ids, x_format="bnsh"): + if x_format == "bnsh": + return self.apply_rope_bnsh(x, cos, sin, pos_ids) + return self.apply_rope_bsnh(x, cos, sin, pos_ids) + + +# LLaMA Microsoft model +class LlamaMSRotaryEmbedding(nn.Module): + def __init__(self, hidden_size, num_heads, max_sequence_length): + super().__init__() + + self.hidden_size = hidden_size + self.num_heads = num_heads + self.max_sequence_length = max_sequence_length + + def get_cos_sin_cache(self, theta=10000.0, head_scale=1.0, device="cpu", dtype=torch.float32): + hidden_size = self.hidden_size + n_heads = self.num_heads + max_seq_len = self.max_sequence_length + + # Precalculate rotary matrices for the sequence + # According to "Attention Is All You Need", theta_i = 10000 ^ (2 * (i - 1)/dim), i in [1, 2, ..., dim//2] + head_dim = head_scale * hidden_size / n_heads + + pos = torch.arange(0, 2 * (head_dim // 2), step=2, device=device, dtype=dtype) + freqs = 1.0 / (theta ** (pos / head_dim)) + + idx = torch.arange(max_seq_len, device=freqs.device) + freqs = torch.outer(idx, freqs) + + cos = torch.reshape(torch.cos(freqs), [1, max_seq_len, 1, -1]) + sin = torch.reshape(torch.sin(freqs), [1, max_seq_len, 1, -1]) + dtype = torch.get_default_dtype() + + return cos.to(dtype), sin.to(dtype) + + def rotate_tensor( + self, + x: torch.Tensor, # BxSxNxH + cos: torch.Tensor, # 1xSx1x(H/2) + sin: torch.Tensor, # 1xSx1x(H/2) + pos: int, + interleaved: bool, + ): + # Dimension of x is [batch_size, seq_len, n_heads, head_dim] + rot_dim = 2 * cos.shape[3] + + # Dolly requires partial rotation + x_rot = x[:, :, :, :rot_dim] + + if interleaved: + x1 = x_rot[:, :, :, 0::2] + x2 = x_rot[:, :, :, 1::2] + else: + half = x_rot.shape[-1] // 2 + x1 = x[:, :, :, 0:half] + x2 = x[:, :, :, half : 2 * half] + + seq_len = x.shape[1] + cos_x = cos[:, pos : pos + seq_len, :, :] + sin_x = sin[:, pos : pos + seq_len, :, :] + + # cos_x: (1, S, 1, H/2) + # sin_x: (1, S, 1, H/2) + # x1: (B, S, N, H/2) + # x2: (B, S, N, H/2) + real = cos_x * x1 - sin_x * x2 + imag = sin_x * x1 + cos_x * x2 + + if interleaved: + x_rot[:, :, :, 0::2] = real + x_rot[:, :, :, 1::2] = imag + else: + x_rot = torch.cat((real, imag), dim=-1) + + return torch.cat((x_rot, x[:, :, :, rot_dim:]), dim=-1) + + def forward(self, x, cos, sin, pos, interleaved): + return self.rotate_tensor(x, cos, sin, pos, interleaved) + + +class TestLlamaRotaryEmbedding(unittest.TestCase): + def setUp(self): + self.config = SampleInputConfig() + self.llama_hf = LlamaHFRotaryEmbedding(self.config.head_size, self.config.max_sequence_length) + self.llama_ms = LlamaMSRotaryEmbedding( + self.config.hidden_size, self.config.num_heads, self.config.max_sequence_length + ) + + seed = 2 + np.random.seed(seed) + torch.manual_seed(seed) + torch.set_printoptions(sci_mode=False) + + def create_onnx_graph(self, x_shape, pos_shape, cos, sin, interleaved): + inputs = [ + helper.make_tensor_value_info( + name="input", + elem_type=TensorProto.FLOAT, + shape=list(x_shape), + ), + helper.make_tensor_value_info( + name="position_ids", + elem_type=TensorProto.INT64, + shape=list(pos_shape), + ), + ] + outputs = [ + helper.make_tensor_value_info( + name="output", + elem_type=TensorProto.FLOAT, + shape=list(x_shape), + ), + ] + + initializers = [ + helper.make_tensor( + name="cos_cache", + data_type=TensorProto.FLOAT, + dims=list(torch.squeeze(cos).shape), + vals=cos.flatten().tolist(), + ), + helper.make_tensor( + name="sin_cache", + data_type=TensorProto.FLOAT, + dims=list(torch.squeeze(sin).shape), + vals=sin.flatten().tolist(), + ), + ] + nodes = [ + helper.make_node( + op_type="RotaryEmbedding", + inputs=["input", "position_ids", "cos_cache", "sin_cache"], + outputs=["output"], + interleaved=interleaved, + name="RotaryEmbedding_0", + domain="com.microsoft", + ), + ] + + graph = helper.make_graph( + nodes=nodes, + name="RotaryEmbedding_Graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + opset_import = helper.make_opsetid(domain="com.microsoft", version=1) + model = helper.make_model(graph, opset_imports=[opset_import]) + return model.SerializeToString() + + def get_eps(self): + eps = ["CPUExecutionProvider", "CUDAExecutionProvider"] + return list(filter(lambda ep: ep in ort.get_available_providers(), eps)) + + def run_ort_ep_tests(self, onnx_graph, inputs_ort, expected_output_bsnh): + eps = self.get_eps() + for ep in eps: + sess = ort.InferenceSession(onnx_graph, providers=[ep]) + output_ort = sess.run(None, inputs_ort)[0] + output_ort = output_ort.reshape( + (self.config.batch_size, inputs_ort["input"].shape[1], self.config.num_heads, self.config.head_size) + ) + + # Compare outputs as BxSxNxH + self.assertTrue(np.allclose(expected_output_bsnh, output_ort)) + + # apply_rope(x_bnsh) == apply_rope(x_bsnh).transpose(1,2) + def test_hf_bnsh_and_hf_bsnh(self): + x_bnsh = torch.randn( + self.config.batch_size, self.config.num_heads, self.config.sequence_length, self.config.head_size + ) + cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length) + pos_hf = torch.stack([torch.arange(0, self.config.sequence_length) for _ in range(self.config.batch_size)]) + + x_bnsh_after_rope = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_hf) # output is BxNxSxH + x_bsnh_after_rope = self.llama_hf( + x_bnsh.transpose(1, 2), cos_hf.transpose(1, 2), sin_hf.transpose(1, 2), pos_hf, "bsnh" + ) # output is BxSxNxH + + self.assertTrue(torch.allclose(x_bnsh_after_rope, x_bsnh_after_rope.transpose(1, 2))) + + # HF rotary == MSFT rotary non-interleaved + def test_hf_rotary_and_msft_rotary_noninterleaved(self): + x_bnsh = torch.randn( + self.config.batch_size, self.config.num_heads, self.config.sequence_length, self.config.head_size + ) + cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length) + pos_hf = torch.stack([torch.arange(0, self.config.sequence_length) for _ in range(self.config.batch_size)]) + output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_hf) # output is BxNxSxH + + x_bsnh = x_bnsh.transpose(1, 2) + x_bsd = deepcopy(x_bsnh) # deepcopy to avoid changes made by self.llama_ms forward pass + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + pos_ms = 0 + output_ms = ( + self.llama_ms(x_bsd, cos_ms, sin_ms, pos_ms, interleaved=False).detach().cpu().numpy() # output is BxSxNxH + ) + + # Compare caches as Mx(H/2) + self.assertTrue( + torch.allclose(self.llama_hf.cos_cached.squeeze()[:, : (self.config.head_size // 2)], cos_ms.squeeze()) + ) + self.assertTrue( + torch.allclose(self.llama_hf.sin_cached.squeeze()[:, : (self.config.head_size // 2)], sin_ms.squeeze()) + ) + + # Compare outputs as BxSxNxH + self.assertTrue(np.allclose(output_hf.transpose(1, 2).detach().cpu().numpy(), output_ms)) + + # Prompt step, interleaved = true, pos ids shape = (1) + def test_msft_prompt_rotary_interleaved(self): + # Calculated this way to match the data in rotary_embedding_op_test.cc + x_bnsh = torch.randn( + self.config.batch_size, self.config.num_heads, self.config.sequence_length, self.config.head_size + ) + x_bsnh = x_bnsh.transpose(1, 2) + x_bsd = deepcopy(x_bsnh) # deepcopy to avoid changes made by self.llama_ms forward pass + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + pos_ms = 0 + output_ms = self.llama_ms(deepcopy(x_bsnh), cos_ms, sin_ms, pos_ms, interleaved=True).detach().cpu().numpy() + + x_bsd = x_bsd.reshape(self.config.batch_size, self.config.sequence_length, self.config.hidden_size) + pos_ms = torch.tensor([pos_ms]) + onnx_graph = self.create_onnx_graph(x_bsd.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=True) + inputs_ort = { + "input": x_bsd.detach().cpu().numpy(), + "position_ids": pos_ms.detach().cpu().numpy(), + } + + # Compare inputs/outputs as BxSxNxH + self.assertTrue(np.allclose(x_bsnh.flatten(), x_bsd.flatten())) + self.run_ort_ep_tests(onnx_graph, inputs_ort, output_ms) + + # Token generation step, interleaved = true, pos ids shape = (1) + def test_msft_token_rotary_interleaved(self): + # Calculated this way to match the data in rotary_embedding_op_test.cc + x_bnsh = torch.randn( + self.config.batch_size, self.config.num_heads, self.config.sequence_length, self.config.head_size + ) + x_bsnh = x_bnsh.transpose(1, 2) + x_bsd = deepcopy(x_bsnh) # deepcopy to avoid changes made by self.llama_ms forward pass + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + pos_ms = 2 + output_ms = self.llama_ms(deepcopy(x_bsnh), cos_ms, sin_ms, pos_ms, interleaved=True).detach().cpu().numpy() + + x_bsd = x_bsd.reshape(self.config.batch_size, self.config.sequence_length, self.config.hidden_size) + pos_ms = torch.tensor([pos_ms]) + onnx_graph = self.create_onnx_graph(x_bsd.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=True) + inputs_ort = { + "input": x_bsd.detach().cpu().numpy(), + "position_ids": pos_ms.detach().cpu().numpy(), + } + + # Compare inputs/outputs as BxSxNxH + self.assertTrue(np.allclose(x_bsnh.flatten(), x_bsd.flatten())) + self.run_ort_ep_tests(onnx_graph, inputs_ort, output_ms) + + # Prompt step, interleaved = false, pos ids shape = (batch_size, sequence_length) + def test_hf_prompt_rotary_batched_pos_ids(self): + x_bnsh = torch.randn( + self.config.batch_size, self.config.num_heads, self.config.sequence_length, self.config.head_size + ) + cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length) + pos_ids = torch.stack([torch.arange(0, self.config.sequence_length) for _ in range(self.config.batch_size)]) + output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_ids) # output is BxNxSxH + + x_bsnh = x_bnsh.transpose(1, 2) + x_bsd = x_bsnh.reshape(self.config.batch_size, self.config.sequence_length, self.config.hidden_size) + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + onnx_graph = self.create_onnx_graph(x_bsd.shape, pos_ids.shape, cos_ms, sin_ms, interleaved=False) + inputs_ort = { + "input": x_bsd.detach().cpu().numpy(), + "position_ids": pos_ids.detach().cpu().numpy(), + } + + self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.transpose(1, 2).detach().cpu().numpy()) + + # Token generation step, interleaved = false, pos ids shape = (batch_size, sequence_length) + def test_hf_token_rotary_batched_pos_ids(self): + x_bnsh = torch.randn(self.config.batch_size, self.config.num_heads, 1, self.config.head_size) + cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length) + pos_ids = torch.stack([torch.tensor([2]) for _ in range(self.config.batch_size)]) + output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_ids) # output is BxNxSxH + + x_bsnh = x_bnsh.transpose(1, 2) + x_bsd = x_bsnh.reshape(self.config.batch_size, 1, self.config.hidden_size) + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + onnx_graph = self.create_onnx_graph(x_bsd.shape, pos_ids.shape, cos_ms, sin_ms, interleaved=False) + inputs_ort = { + "input": x_bsd.detach().cpu().numpy(), + "position_ids": pos_ids.detach().cpu().numpy(), + } + + # Compare outputs as BxSxNxH + self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.transpose(1, 2).detach().cpu().numpy()) + + # Bonus test: Prompt step, interleaved = false, pos ids shape = (1) + def test_hf_prompt_rotary_one_pos_id(self): + x_bnsh = torch.randn( + self.config.batch_size, self.config.num_heads, self.config.sequence_length, self.config.head_size + ) + cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length) + pos_hf = torch.stack([torch.arange(0, self.config.sequence_length) for _ in range(self.config.batch_size)]) + output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_hf) # output is BxNxSxH + + x_bsnh = x_bnsh.transpose(1, 2) + x_bsd = x_bsnh.reshape(self.config.batch_size, self.config.sequence_length, self.config.hidden_size) + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + pos_ms = torch.tensor([0]) + onnx_graph = self.create_onnx_graph(x_bsd.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=False) + inputs_ort = { + "input": x_bsd.detach().cpu().numpy(), + "position_ids": pos_ms.detach().cpu().numpy(), + } + + # Compare outputs as BxSxNxH + self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.transpose(1, 2).detach().cpu().numpy()) + + # Bonus test: Token generation step, interleaved = false, pos ids shape = (1) + def test_hf_token_rotary_one_pos_id(self): + x_bnsh = torch.randn(self.config.batch_size, self.config.num_heads, 1, self.config.head_size) + cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length) + pos_ids = torch.stack([torch.tensor([2]) for _ in range(self.config.batch_size)]) + output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_ids) # output is BxNxSxH + + x_bsnh = x_bnsh.transpose(1, 2) + x_bsd = x_bsnh.reshape(self.config.batch_size, 1, self.config.hidden_size) + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + pos_ms = torch.tensor([2]) + onnx_graph = self.create_onnx_graph(x_bsd.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=False) + inputs_ort = { + "input": x_bsd.detach().cpu().numpy(), + "position_ids": pos_ms.detach().cpu().numpy(), + } + + # Compare outputs as BxSxNxH + self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.transpose(1, 2).detach().cpu().numpy()) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_rotary_embedding_fusion.py b/onnxruntime/test/python/transformers/test_rotary_embedding_fusion.py new file mode 100644 index 0000000000000..7bca48c29019e --- /dev/null +++ b/onnxruntime/test/python/transformers/test_rotary_embedding_fusion.py @@ -0,0 +1,447 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import sys +import unittest +from typing import List + +import numpy as np +import onnx +from onnx import TensorProto, helper +from parity_utilities import find_transformers_source + +if find_transformers_source(): + from fusion_options import FusionOptions + from onnx_model import OnnxModel + from optimizer import optimize_model +else: + from onnxruntime.transformers.fusion_options import FusionOptions + from onnxruntime.transformers.onnx_model import OnnxModel + from onnxruntime.transformers.optimizer import optimize_model + + +def float_tensor(name: str, shape: List[int], random=False): + low = 0.0 + high = 1.0 + total_elements = 1 + for x in shape: + total_elements *= x + weights = [np.random.uniform(low, high) for _ in range(total_elements)] if random else [1.0] * total_elements + return helper.make_tensor(name, TensorProto.FLOAT, shape, weights) + + +class TestRotaryEmbeddingFusion(unittest.TestCase): + def setUp(self): + self.batch_size = 2 + self.sequence_length = 8 + self.num_heads = 4 + self.head_size = 6 + self.hidden_size = self.num_heads * self.head_size + + self.past_sequence_length = 2 + self.max_sequence_length = 12 + + def verify_fusion(self, expected_model_path, original_model_path): + expected_model = OnnxModel(onnx.load(expected_model_path)) + expected_model.topological_sort(is_deterministic=True) + + options = FusionOptions("gpt2") + optimized_model = optimize_model(original_model_path, optimization_options=options, opt_level=0) + optimized_model.topological_sort(is_deterministic=True) + + self.assertTrue(str(expected_model.model.graph), str(optimized_model.model.graph)) + + def create_initializers(self): + initializers = [ + float_tensor("cos_cache", [self.max_sequence_length, self.head_size]), + float_tensor("sin_cache", [self.max_sequence_length, self.head_size]), + helper.make_tensor( + "pos_ids_new_shape", + TensorProto.FLOAT, + [2], + np.array([self.batch_size, self.sequence_length], dtype=np.int64), + ), + helper.make_tensor("zero", TensorProto.FLOAT, [1], np.array([0], dtype=np.int64)), + helper.make_tensor("one", TensorProto.FLOAT, [1], np.array([1], dtype=np.int64)), + helper.make_tensor("two", TensorProto.FLOAT, [1], np.array([2], dtype=np.int64)), + helper.make_tensor("three", TensorProto.FLOAT, [1], np.array([3], dtype=np.int64)), + helper.make_tensor("int_max", TensorProto.FLOAT, [1], np.array([sys.maxsize], dtype=np.int64)), + ] + return initializers + + def create_inputs_and_outputs(self, model_type: str = ""): + inputs = [ + helper.make_tensor_value_info( + "input_0", + TensorProto.FLOAT, + [self.batch_size, self.sequence_length, self.num_heads, self.head_size], + ), + helper.make_tensor_value_info("position_ids", TensorProto.INT64, [self.batch_size, self.sequence_length]), + ] + if model_type in {"past", "merged"}: + # Input will be removed in fused model since it's not used in RotaryEmbedding. + # We create this input so that we can check the `past_seq_len` path during + # RotaryEmbedding fusion. + inputs.append( + helper.make_tensor_value_info( + "past_key", + TensorProto.FLOAT, + [self.batch_size, self.num_heads, self.past_sequence_length, self.head_size], + ) + ) + # Dummy input to test nodes for `curr_seq_len` path + if model_type != "": + inputs.append( + helper.make_tensor_value_info( + "curr_key", + TensorProto.FLOAT, + [self.batch_size, self.sequence_length, self.num_heads, self.head_size], + ) + ) + outputs = [ + helper.make_tensor_value_info( + "output_0", + TensorProto.FLOAT, + [self.batch_size, self.num_heads, self.sequence_length, self.head_size], + ) + ] + if model_type in {"merged"}: + # Dummy output to test that nodes for `past_seq_len` path are not removed for merged model + outputs.append(helper.make_tensor_value_info("past_seq_len_plus_zero", TensorProto.FLOAT, [1])) + return inputs, outputs + + def create_fused_model(self, interleaved: bool, initializers: List[TensorProto]): + inputs, outputs = self.create_inputs_and_outputs() + + rope_node = helper.make_node( + "RotaryEmbedding", + inputs=[inputs[0].name, inputs[1].name, initializers[0].name, initializers[1].name], + outputs=[outputs[0].name], + name="RotaryEmbedding_0", + interleaved=int(interleaved), + ) + + graph = helper.make_graph( + nodes=[rope_node], + name="RotaryEmbedding_Graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + opset_import = helper.make_opsetid(domain="com.microsoft", version=1) + model = helper.make_model(graph, opset_imports=[opset_import]) + return model + + def create_cache_path(self, model_type: str, use_redundant_squeeze_ops: bool): + # Create position ids path + reshape_node = helper.make_node( + "Reshape", + inputs=["position_ids", "pos_ids_new_shape"], + outputs=["pos_ids_reshaped"], + name="Reshape_0", + ) + pos_ids_nodes = [reshape_node] + + # Create cos path + cos_init_unsqueeze_node = helper.make_node( + "Unsqueeze", + inputs=["new_seq_len", "zero"], + outputs=["cos_unsqueeze"], + name="Unsqueeze_2", + ) + cos_slice_node = helper.make_node( + "Slice", + inputs=["cos_cache", "zero", "cos_unsqueeze", "two", "one"], + outputs=["cos_sliced"], + name="Slice_2", + ) + cos_nodes = [cos_init_unsqueeze_node, cos_slice_node] + + if use_redundant_squeeze_ops: + # These two nodes are eliminated by this transformers PR: https://github.com/huggingface/transformers/pull/26162 + cos_squeeze_1_node = helper.make_node( + "Squeeze", + inputs=["cos_sliced", "zero"], + outputs=["cos_squeeze_1"], + name="Squeeze_0", + ) + cos_squeeze_2_node = helper.make_node( + "Squeeze", + inputs=["cos_squeeze_1", "zero"], + outputs=["cos_squeeze_2"], + name="Squeeze_1", + ) + cos_nodes.extend([cos_squeeze_1_node, cos_squeeze_2_node]) + + cos_gather_node = helper.make_node( + "Gather", + inputs=["cos_squeeze_2" if use_redundant_squeeze_ops else "cos_sliced", "pos_ids_reshaped"], + outputs=["cos_indexed"], + name="Gather_1", + ) + cos_end_unsqueeze_node = helper.make_node( + "Unsqueeze", + inputs=["cos_indexed", "one"], + outputs=["cos"], + name="Unsqueeze_3", + ) + cos_nodes.extend([cos_gather_node, cos_end_unsqueeze_node]) + + # Create sin path + sin_init_unsqueeze_node = helper.make_node( + "Unsqueeze", + inputs=["new_seq_len", "zero"], + outputs=["sin_unsqueeze"], + name="Unsqueeze_4", + ) + sin_slice_node = helper.make_node( + "Slice", + inputs=["sin_cache", "zero", "sin_unsqueeze", "two", "one"], + outputs=["sin_sliced"], + name="Slice_3", + ) + sin_nodes = [sin_init_unsqueeze_node, sin_slice_node] + + if use_redundant_squeeze_ops: + sin_squeeze_1_node = helper.make_node( + "Squeeze", + inputs=["sin_sliced", "zero"], + outputs=["sin_squeeze_1"], + name="Squeeze_2", + ) + sin_squeeze_2_node = helper.make_node( + "Squeeze", + inputs=["sin_squeeze_1", "zero"], + outputs=["sin_squeeze_2"], + name="Squeeze_3", + ) + sin_nodes.extend([sin_squeeze_1_node, sin_squeeze_2_node]) + + sin_gather_node = helper.make_node( + "Gather", + inputs=["sin_squeeze_2" if use_redundant_squeeze_ops else "sin_sliced", "pos_ids_reshaped"], + outputs=["sin_indexed"], + name="Gather_2", + ) + sin_end_unsqueeze_node = helper.make_node( + "Unsqueeze", + inputs=["sin_indexed", "one"], + outputs=["sin"], + name="Unsqueeze_5", + ) + sin_nodes.extend([sin_gather_node, sin_end_unsqueeze_node]) + + # Create beginning nodes before cos and sin paths + + # Create curr seq len path + curr_transpose_node = helper.make_node( + "Transpose", + inputs=["curr_key"], + outputs=["curr_key_transposed"], + name="Transpose_curr", + perm=[0, 2, 1, 3], + ) + curr_shape_node = helper.make_node( + "Shape", + inputs=["curr_key_transposed"], + outputs=["curr_shape"], + name="Shape_curr", + ) + curr_gather_node = helper.make_node( + "Gather", + inputs=["curr_shape", "two"], + outputs=["curr_seq_len" if model_type in {"past", "merged"} else "new_seq_len"], + name="Gather_curr", + ) + beginning_nodes = [curr_transpose_node, curr_shape_node, curr_gather_node] + + if model_type in {"past", "merged"}: + # Create past seq len path + past_shape_node = helper.make_node( + "Shape", + inputs=["past_key"], + outputs=["past_shape"], + name="Shape_past", + ) + past_gather_node = helper.make_node( + "Gather", + inputs=["past_shape", "two"], + outputs=["past_seq_len"], + name="Gather_past", + ) + add_node = helper.make_node( + "Add", + inputs=["curr_seq_len", "past_seq_len"], + outputs=["new_seq_len"], + name="Add_1", + ) + beginning_nodes.extend([past_shape_node, past_gather_node, add_node]) + + if model_type == "merged": + dummy_node = helper.make_node( + "Add", + inputs=["past_seq_len", "zero"], + outputs=["past_seq_len_plus_zero"], + name="Add_dummy_node", + ) + beginning_nodes.append(dummy_node) + + return pos_ids_nodes + cos_nodes + sin_nodes + beginning_nodes + + def create_apply_rope_path(self): + start_node = helper.make_node( + "Transpose", + inputs=["input_0"], + outputs=["x"], + name="Transpose_0", + perm=[0, 2, 1, 3], + ) + + # Calculate x_half_shape + shape_node = helper.make_node( + "Shape", + inputs=["x"], + outputs=["x_shape"], + name="Shape_0", + ) + gather_node = helper.make_node( + "Gather", + inputs=["x_shape", "three"], + outputs=["x_last_idx_shape"], + name="Gather_0", + axis=0, + ) + div_node = helper.make_node( + "Div", + inputs=["x_last_idx_shape", "two"], + outputs=["x_half_shape"], + name="Div_0", + ) + unsqueeze_0_node = helper.make_node( + "Unsqueeze", + inputs=["x_half_shape", "zero"], + outputs=["x_half_shape_0"], + name="Unsqueeze_0", + ) + unsqueeze_1_node = helper.make_node( + "Unsqueeze", + inputs=["x_half_shape", "zero"], + outputs=["x_half_shape_1"], + name="Unsqueeze_1", + ) + x_half_shape_nodes = [shape_node, gather_node, div_node, unsqueeze_0_node, unsqueeze_1_node] + + # Calculate rotate_half + x1_node = helper.make_node( + "Slice", + inputs=["x", "zero", "x_half_shape_0", "three", "one"], + outputs=["x1"], + name="Slice_0", + ) + x2_node = helper.make_node( + "Slice", + inputs=["x", "x_half_shape_1", "int_max", "three", "one"], + outputs=["x2"], + name="Slice_1", + ) + neg_node = helper.make_node( + "Neg", + inputs=["x2"], + outputs=["x2_neg"], + name="Neg_0", + ) + x_rotate_half_node = helper.make_node( + "Concat", + inputs=["x2_neg", "x1"], + outputs=["x_rotate_half"], + name="Concat_0", + axis=-1, + ) + rotate_half_nodes = [x1_node, x2_node, neg_node, x_rotate_half_node] + + # Calculate x_embed + x_cos_node = helper.make_node( + "Mul", + inputs=["x", "cos"], + outputs=["x_cos"], + name="Mul_0", + ) + x_sin_node = helper.make_node( + "Mul", + inputs=["x_rotate_half", "sin"], + outputs=["x_rotate_half_sin"], + name="Mul_1", + ) + end_node = helper.make_node( + "Add", + inputs=["x_cos", "x_rotate_half_sin"], + outputs=["output_0"], + name="Add_0", + ) + x_embed_nodes = [start_node, x_cos_node, x_sin_node, end_node] + + return x_half_shape_nodes + rotate_half_nodes + x_embed_nodes + + def create_test_model(self, model_type: str, use_redundant_squeeze_ops: bool, initializers: List[TensorProto]): + apply_rope_nodes = self.create_apply_rope_path() + cache_nodes = self.create_cache_path(model_type, use_redundant_squeeze_ops) + inputs, outputs = self.create_inputs_and_outputs(model_type) + + graph = helper.make_graph( + nodes=apply_rope_nodes + cache_nodes, + name="RotaryEmbedding_Graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + opset_import = helper.make_opsetid(domain="ai.onnx", version=13) + model = helper.make_model(graph, opset_imports=[opset_import]) + return model + + def check_models(self, interleaved: bool, model_type: str): + initializers = self.create_initializers() + + expected_model_filename = "expected_model.onnx" + expected_model = self.create_fused_model(interleaved, initializers) + onnx.save(expected_model, expected_model_filename) + + original_model_filename = "original_model.onnx" + use_redundant_squeeze_ops = True + original_model = self.create_test_model(model_type, use_redundant_squeeze_ops, initializers) + onnx.save(original_model, original_model_filename) + + self.verify_fusion(expected_model_filename, original_model_filename) + os.remove(original_model_filename) + + use_redundant_squeeze_ops = False + original_model = self.create_test_model(model_type, use_redundant_squeeze_ops, initializers) + onnx.save(original_model, original_model_filename) + + self.verify_fusion(expected_model_filename, original_model_filename) + os.remove(expected_model_filename) + os.remove(original_model_filename) + + # Hugging Face's `decoder_model.onnx` + def test_hf_decoder_model(self): + interleaved = False # HF model does not use interleaving + model_type = "no_past" + self.check_models(interleaved, model_type) + + # Hugging Face's `decoder_with_past_model.onnx` + def test_hf_decoder_with_past_model(self): + interleaved = False # HF model does not use interleaving + model_type = "past" + self.check_models(interleaved, model_type) + + # Hugging Face's `decoder_merged.onnx` + def test_hf_decoder_merged_model(self): + interleaved = False # HF model does not use interleaving + model_type = "merged" + self.check_models(interleaved, model_type) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py b/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py new file mode 100644 index 0000000000000..fedba2a25dfc2 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py @@ -0,0 +1,795 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import sys +import unittest +from typing import List + +import numpy as np +import onnx +from onnx import NodeProto, TensorProto, helper +from parity_utilities import find_transformers_source + +if find_transformers_source(): + from fusion_options import FusionOptions + from onnx_model import OnnxModel + from optimizer import optimize_model +else: + from onnxruntime.transformers.fusion_options import FusionOptions + from onnxruntime.transformers.onnx_model import OnnxModel + from onnxruntime.transformers.optimizer import optimize_model + + +def float_tensor(name: str, shape: List[int], random=False): + low = 0.0 + high = 1.0 + total_elements = 1 + for x in shape: + total_elements *= x + weights = [np.random.uniform(low, high) for _ in range(total_elements)] if random else [1.0] * total_elements + return helper.make_tensor(name, TensorProto.FLOAT, shape, weights) + + +class TestRotaryAttentionFusion(unittest.TestCase): + def setUp(self): + self.batch_size = 2 + self.sequence_length = 8 + self.num_heads = 4 + self.head_size = 6 + self.hidden_size = self.num_heads * self.head_size + + self.past_sequence_length = 2 + self.max_sequence_length = 12 + + def verify_fusion(self, expected_model_path, original_model_path): + expected_model = OnnxModel(onnx.load(expected_model_path)) + expected_model.topological_sort(is_deterministic=True) + + model_type = "gpt2" + options = FusionOptions(model_type) + optimized_model = optimize_model( + original_model_path, + model_type, + self.num_heads, + self.hidden_size, + optimization_options=options, + opt_level=0, + ) + optimized_model.topological_sort(is_deterministic=True) + + self.assertTrue(str(expected_model.model.graph), str(optimized_model.model.graph)) + + def create_initializers(self, fused_model: bool = False): + initializers = [ + float_tensor("cos_cache", [self.max_sequence_length, self.head_size // 2]), + float_tensor("sin_cache", [self.max_sequence_length, self.head_size // 2]), + float_tensor("q_weight", [self.hidden_size, self.hidden_size]), + float_tensor("k_weight", [self.hidden_size, self.hidden_size]), + float_tensor("v_weight", [self.hidden_size, self.hidden_size]), + float_tensor("o_weight", [self.hidden_size, self.hidden_size]), + helper.make_tensor( + "sqrt_head_size", TensorProto.FLOAT, [1], np.array([np.sqrt(self.head_size)], dtype=np.float32) + ), + helper.make_tensor("neg_int_max", TensorProto.FLOAT, [1], np.array([-sys.maxsize - 1], dtype=np.int64)), + helper.make_tensor("num_heads", TensorProto.FLOAT, [1], np.array([self.num_heads], dtype=np.float32)), + helper.make_tensor("head_size", TensorProto.FLOAT, [1], np.array([self.head_size], dtype=np.float32)), + helper.make_tensor("hidden_size", TensorProto.FLOAT, [1], np.array([self.hidden_size], dtype=np.float32)), + helper.make_tensor("zero", TensorProto.FLOAT, [1], np.array([0], dtype=np.int64)), + helper.make_tensor("one", TensorProto.FLOAT, [1], np.array([1], dtype=np.int64)), + helper.make_tensor("two", TensorProto.FLOAT, [1], np.array([2], dtype=np.int64)), + helper.make_tensor("three", TensorProto.FLOAT, [1], np.array([3], dtype=np.int64)), + ] + return initializers + + def create_inputs_and_outputs(self, model_type: str): + attn_mask_size = [self.batch_size, self.sequence_length] + if model_type == "llama2_msft": + attn_mask_size.append(self.sequence_length) + + inputs = [ + helper.make_tensor_value_info( + "input_0", TensorProto.FLOAT, [self.batch_size, self.sequence_length, self.hidden_size] + ), + helper.make_tensor_value_info("position_ids", TensorProto.INT64, [self.batch_size, self.sequence_length]), + helper.make_tensor_value_info("attn_mask", TensorProto.INT64, attn_mask_size), + ] + if model_type in {"past", "merged", "llama2_msft"}: + inputs.extend( + [ + helper.make_tensor_value_info( + "past_key", + TensorProto.FLOAT, + [self.batch_size, self.num_heads, self.past_sequence_length, self.head_size], + ), + helper.make_tensor_value_info( + "past_value", + TensorProto.FLOAT, + [self.batch_size, self.num_heads, self.past_sequence_length, self.head_size], + ), + ] + ) + outputs = [ + helper.make_tensor_value_info( + "output_0", TensorProto.FLOAT, [self.batch_size, self.sequence_length, self.hidden_size] + ), + helper.make_tensor_value_info( + "present_key", + TensorProto.FLOAT, + [self.batch_size, self.num_heads, self.past_sequence_length + 1, self.head_size], + ), + helper.make_tensor_value_info( + "present_value", + TensorProto.FLOAT, + [self.batch_size, self.num_heads, self.past_sequence_length + 1, self.head_size], + ), + ] + return inputs, outputs + + def create_matmul_nodes(self, is_fused: bool, model_type: str): + q_matmul_node = helper.make_node( + "MatMul", + inputs=["input_0", "q_weight"], + outputs=["q_out" if is_fused or model_type == "llama2_msft" else "q_matmul_out"], + name="Q_MatMul", + ) + + k_matmul_node = helper.make_node( + "MatMul", + inputs=["input_0", "k_weight"], + outputs=["k_out" if is_fused or model_type == "llama2_msft" else "k_matmul_out"], + name="K_MatMul", + ) + + v_matmul_node = helper.make_node( + "MatMul", + inputs=["input_0", "v_weight"], + outputs=["v_out"], + name="V_MatMul", + ) + + return [q_matmul_node, k_matmul_node, v_matmul_node] + + def create_rotary_embeddings( + self, + is_fused: bool, + model_type: str, + interleaved: bool, + inputs: List[TensorProto], + initializers: List[TensorProto], + ): + def get_first_rope_input(node_type: str): + if is_fused or model_type == "llama2_msft": + # q_out/k_out + return f"{node_type}_out" + if model_type in {"no_past", "past", "merged"}: + if node_type == "k": + return "k_before_rope" + return "q_before_rope" + return "" + + def get_first_rope_output(node_type: str): + if is_fused or model_type in {"llama2_msft", "past", "merged"}: + if node_type == "q": + return "q_rope" + return "k_rope" + if model_type in {"no_past"}: + if node_type == "k": + return "present_key" + return "q_rope" + return "" + + q_rope_node = helper.make_node( + "RotaryEmbedding", + inputs=[get_first_rope_input("q"), inputs[1].name, initializers[0].name, initializers[1].name], + outputs=[get_first_rope_output("q")], + name="Q_RotaryEmbedding", + interleaved=int(interleaved), + ) + + k_rope_node = helper.make_node( + "RotaryEmbedding", + inputs=[get_first_rope_input("k"), inputs[1].name, initializers[0].name, initializers[1].name], + outputs=[get_first_rope_output("k")], + name="K_RotaryEmbedding", + interleaved=int(interleaved), + ) + + return [q_rope_node, k_rope_node] + + def create_q_path(self, model_type: str): + if model_type == "llama2_msft": + transpose_q_node = helper.make_node( + "Transpose", + inputs=["q_rope"], + outputs=["q_transposed"], + name="Transpose_q", + perm=[0, 2, 1, 3], + ) + reshape_q_node = helper.make_node( + "Reshape", + inputs=["q_transposed", "concat_q_extra_out"], + outputs=["q"], + name="Reshape_q", + ) + return [transpose_q_node, reshape_q_node] + + reshape_q_node = helper.make_node( + "Reshape", + inputs=["q_matmul_out", "concat_q_extra_out"], + outputs=["q_reshaped"], + name="Reshape_q", + ) + transpose_q_node = helper.make_node( + "Transpose", + inputs=["q_reshaped"], + outputs=["q_before_rope"], + name="Transpose_q", + ) + return [reshape_q_node, transpose_q_node] + + def create_k_path_llama2_msft(self): + # Create k cache slicing path + k_cache_unsqueeze_node = helper.make_node( + "Unsqueeze", + inputs=["position_ids", "zero"], + outputs=["k_pos_id"], + ) + k_cache_slice_node = helper.make_node( + "Slice", + inputs=["past_key", "zero", "k_pos_id", "two", "one"], + outputs=["k_cache_sliced"], + ) + # Create k path + transpose_k_1_node = helper.make_node( + "Transpose", + inputs=["k_rope"], + outputs=["k_rope_transposed"], + name="Transpose_k_1", + perm=[0, 2, 1, 3], + ) + concat_k_node = helper.make_node( + "Concat", + inputs=["k_cache_sliced", "k_rope_transposed"], + outputs=["present_key"], + name="Concat_k", + axis=2, + ) + transpose_k_2_node = helper.make_node( + "Transpose", + inputs=["present_key"], + outputs=["present_key_transposed"], + name="Transpose_k_2", + perm=[0, 2, 3, 1], + ) + reshape_k_node = helper.make_node( + "Reshape", + inputs=["present_key_transposed", "concat_k_extra_out"], + outputs=["k"], + name="Reshape_k", + ) + return [ + k_cache_unsqueeze_node, + k_cache_slice_node, + transpose_k_1_node, + concat_k_node, + transpose_k_2_node, + reshape_k_node, + ] + + def create_k_path_hf(self, model_type: str): + reshape_k_node = helper.make_node( + "Reshape", + inputs=["k_matmul_out", "concat_k_extra_out"], + outputs=["k_reshaped"], + name="Reshape_k", + ) + transpose_k_1_node = helper.make_node( + "Transpose", + inputs=["k_reshaped"], + outputs=["k_before_rope"], + name="Transpose_k_1", + perm=[0, 2, 1, 3], + ) + k_nodes = [reshape_k_node, transpose_k_1_node] + + if model_type in {"past", "merged"}: + concat_k_node = helper.make_node( + "Concat", + inputs=["past_key", "k_rope"], + outputs=["present_key"], + axis=2, + ) + k_nodes.append(concat_k_node) + + transpose_k_2_node = helper.make_node( + "Transpose", + inputs=["present_key"], + outputs=["k"], + name="Transpose_k_2", + perm=[0, 1, 3, 2], + ) + return k_nodes + [transpose_k_2_node] # noqa: RUF005 + + def create_k_path(self, model_type: str): + if model_type == "llama2_msft": + return self.create_k_path_llama2_msft() + return self.create_k_path_hf(model_type) + + def create_attn_mask_path_llama2_msft(self): + x_shape_node = helper.make_node( + "Shape", + inputs=["input_0"], + outputs=["input_0_shape"], + name="Shape_input", + ) + x_get_seq_len_node = helper.make_node( + "Gather", + inputs=["input_0_shape", "one"], + outputs=["input_0_seq_len"], + name="Gather_input", + axis=0, + ) + x_new_seq_len_node = helper.make_node( + "Add", + inputs=["position_ids", "input_0_seq_len"], + outputs=["new_seq_len"], + name="Add_mask", + ) + unsqueeze_0_node = helper.make_node( + "Unsqueeze", + inputs=["position_ids", "zero"], + outputs=["unsqueeze_mask_0_out"], + name="Unsqueeze_mask_0", + ) + unsqueeze_1_node = helper.make_node( + "Unsqueeze", + inputs=["new_seq_len", "zero"], + outputs=["unsqueeze_mask_1_out"], + name="Unsqueeze_mask_1", + ) + unsqueeze_2_node = helper.make_node( + "Unsqueeze", + inputs=["new_seq_len", "zero"], + outputs=["unsqueeze_mask_2_out"], + name="Unsqueeze_mask_2", + ) + slice_mask_1_node = helper.make_node( + "Slice", + inputs=["attn_mask", "unsqueeze_mask_0_out", "unsqueeze_mask_1_out", "one", "one"], + outputs=["slice_mask_1_out"], + name="Slice_mask_1", + ) + slice_mask_2_node = helper.make_node( + "Slice", + inputs=["slice_mask_1_out", "zero", "unsqueeze_mask_2_out", "two", "one"], + outputs=["slice_mask_2_out"], + name="Slice_mask_2", + ) + concat_mask_node = helper.make_node( + "Concat", + inputs=["slice_mask_2_out" for _ in range(self.num_heads)], + outputs=["attn_mask_out"], + name="Concat_mask", + axis=0, + ) + return [ + x_shape_node, + x_get_seq_len_node, + x_new_seq_len_node, + unsqueeze_0_node, + unsqueeze_1_node, + unsqueeze_2_node, + slice_mask_1_node, + slice_mask_2_node, + concat_mask_node, + ] + + def create_attn_mask_path_hf(self, model_type: str): + unsqueeze_1_node = helper.make_node( + "Unsqueeze", + inputs=["attn_mask", "one"], + outputs=["unsqueeze_1_mask_out"], + name="Unsqueeze_1_mask", + ) + unsqueeze_2_node = helper.make_node( + "Unsqueeze", + inputs=["unsqueeze_1_mask_out", "two"], + outputs=["unsqueeze_2_mask_out"], + name="Unsqueeze_2_mask", + ) + expand_node = helper.make_node( + "Expand", + inputs=["unsqueeze_2_mask_out", "zero"], + outputs=["expand_out"], + name="Expand_mask", + ) + cast_node = helper.make_node( + "Cast", + inputs=["expand_out"], + outputs=["cast_out"], + name="Cast_mask", + to=TensorProto.FLOAT, + ) + sub_node = helper.make_node( + "Sub", + inputs=["one", "cast_out"], + outputs=["sub_out"], + name="Sub_mask", + ) + where_node = helper.make_node( + "Where", + inputs=["zero", "neg_int_max", "sub_out"], + outputs=["where_out" if model_type != "past" else "attn_mask_out"], + name="Where_mask", + ) + attn_mask_nodes = [unsqueeze_1_node, unsqueeze_2_node, expand_node, cast_node, sub_node, where_node] + + if model_type == "past": + return attn_mask_nodes + + add_node = helper.make_node( + "Add", + inputs=["where_out", "zero"], + outputs=["attn_mask_out"], + name="Add_mask", + ) + return attn_mask_nodes + [add_node] # noqa: RUF005 + + def create_attn_mask_path(self, is_fused: bool, model_type: str): + if model_type == "llama2_msft": + attn_mask_nodes = self.create_attn_mask_path_llama2_msft() + if is_fused: + attn_mask_nodes.pop() + attn_mask_nodes[-1].output[0] = "attn_mask_out" + return attn_mask_nodes + + attn_mask_nodes = self.create_attn_mask_path_hf(model_type) + if is_fused: + new_output_name = "attn_mask_out_mask" + attn_mask_nodes[-1].output[0] = new_output_name + concat_mask_node = helper.make_node( + "Concat", + inputs=[new_output_name for _ in range(self.num_heads)], + outputs=["attn_mask_out"], + name="Concat_mask", + axis=0, + ) + attn_mask_nodes.append(concat_mask_node) + return attn_mask_nodes + + def create_qk_path(self, model_type: str): + matmul_qk_node = helper.make_node( + "MatMul", + inputs=["q" if model_type == "llama2_msft" else "q_rope", "k"], + outputs=["qk"], + name="MatMul_q_k", + ) + div_node = helper.make_node( + "Div", + inputs=["qk", "sqrt_head_size"], + outputs=["qk_div"], + name="Div_0", + ) + add_node = helper.make_node( + "Add", + inputs=["qk_div", "attn_mask_out"], + outputs=["qk_plus_mask"], + name="Add_0", + ) + softmax_node = helper.make_node( + "Softmax", + inputs=["qk_plus_mask"], + outputs=["softmax_out"], + name="Softmax_0", + ) + return [matmul_qk_node, div_node, add_node, softmax_node] + + def create_v_path(self, model_type: str): + reshape_v_1_node = helper.make_node( + "Reshape", + inputs=["v_out", "concat_v_1_extra_out"], + outputs=["reshape_v_1_out"], + name="Reshape_v_1", + ) + transpose_v_1_node = helper.make_node( + "Transpose", + inputs=["reshape_v_1_out"], + outputs=["transpose_v_1_out" if model_type != "no_past" else "present_value"], + name="Transpose_v_1", + ) + v_nodes = [reshape_v_1_node, transpose_v_1_node] + + if model_type == "no_past": + return v_nodes + + if model_type in {"past", "merged"}: + concat_v_node = helper.make_node( + "Concat", + inputs=["past_value", "transpose_v_1_out"], + outputs=["present_value"], + name="Concat_v", + axis=2, + ) + return v_nodes + [concat_v_node] # noqa: RUF005 + + # Create extra nodes for `position_ids` + unsqueeze_v_node = helper.make_node( + "Unsqueeze", + inputs=["position_ids", "zero"], + outputs=["unsqueeze_v_out"], + name="Unsqueeze_v", + ) + slice_v_node = helper.make_node( + "Slice", + inputs=["past_value", "zero", "unsqueeze_v_out", "two", "one"], + outputs=["v_cache_sliced_out"], + name="Slice_v", + ) + concat_v_node = helper.make_node( + "Concat", + inputs=["v_cache_sliced_out", "transpose_v_1_out"], + outputs=["present_value"], + name="Concat_v", + axis=2, + ) + v_nodes.extend([unsqueeze_v_node, slice_v_node, concat_v_node]) + + # Create remaining nodes for v path + transpose_v_2_node = helper.make_node( + "Transpose", + inputs=["present_value"], + outputs=["transpose_v_2_out"], + name="Transpose_v_2", + ) + reshape_v_2_node = helper.make_node( + "Reshape", + inputs=["transpose_v_2_out", "concat_v_2_extra_out"], + outputs=["v"], + name="Reshape_v_2", + ) + return v_nodes + [transpose_v_2_node, reshape_v_2_node] # noqa: RUF005 + + def create_qkv_path(self, model_type: str): + matmul_qkv_node = helper.make_node( + "MatMul", + inputs=["softmax_out", "v" if model_type == "llama2_msft" else "present_value"], + outputs=["softmax_v_out"], + name="MatMul_softmax_v", + ) + qkv_nodes = [matmul_qkv_node] + + if model_type == "llama2_msft": + reshape_qkv_1_node = helper.make_node( + "Reshape", + inputs=["softmax_v_out", "concat_qkv_1_extra_out"], + outputs=["reshape_qkv_1_out"], + name="Reshape_qkv_1", + ) + qkv_nodes.append(reshape_qkv_1_node) + + transpose_qkv_node = helper.make_node( + "Transpose", + inputs=["reshape_qkv_1_out" if model_type == "llama2_msft" else "softmax_v_out"], + outputs=["transpose_qkv_out"], + name="Transpose_qkv", + ) + reshape_qkv_2_node = helper.make_node( + "Reshape", + inputs=["transpose_qkv_out", "concat_qkv_2_extra_out"], + outputs=["attn_output"], + name="Reshape_qkv_2", + ) + + return qkv_nodes + [transpose_qkv_node, reshape_qkv_2_node] # noqa: RUF005 + + def create_concat_unsqueeze_paths(self, model_type: str, reshape_nodes: List[NodeProto]): + # Create initial shape paths + shape_0_node = helper.make_node( + "Shape", + inputs=["input_0"], + outputs=["input_0_shape_0"], + name="Shape_0", + ) + gather_0_node = helper.make_node( + "Gather", + inputs=["input_0_shape_0", "zero"], + outputs=["input_0_shape_0_indexed"], + name="Gather_0", + axis=0, + ) + shape_1_node = helper.make_node( + "Shape", + inputs=["input_0"], + outputs=["input_0_shape_1"], + name="Shape_1", + ) + gather_1_node = helper.make_node( + "Gather", + inputs=["input_0_shape_1", "one"], + outputs=["input_0_shape_1_indexed"], + name="Gather_1", + axis=0, + ) + extra_nodes = [shape_0_node, gather_0_node, shape_1_node, gather_1_node] + + if model_type == "llama2_msft": + mul_node = helper.make_node( + "Mul", + inputs=[gather_0_node.output[0], "num_heads"], + outputs=["mul_extra_out"], + name="Mul_extra_0", + ) + add_node = helper.make_node( + "Add", + inputs=[gather_1_node.output[0], "position_ids"], + outputs=["add_extra_out"], + name="Add_extra_0", + ) + extra_nodes.extend([mul_node, add_node]) + + for i, reshape_node in enumerate(reshape_nodes): + use_mul_and_add_nodes_0 = model_type == "llama2_msft" and reshape_node.output[0] in {"q", "k", "v"} + use_mul_and_add_nodes_1 = model_type == "llama2_msft" and reshape_node.output[0] in {"k", "v"} + + unsqueeze_0_node = helper.make_node( + "Unsqueeze", + inputs=[gather_0_node.output[0] if not use_mul_and_add_nodes_0 else "mul_extra_out", "zero"], + outputs=[f"unsqueeze_extra_{2*i}"], + name=f"Unsqueeze_extra_{2*i}", + ) + unsqueeze_1_node = helper.make_node( + "Unsqueeze", + inputs=[gather_1_node.output[0] if not use_mul_and_add_nodes_1 else "add_extra_out", "zero"], + outputs=[f"unsqueeze_extra_{2*i + 1}"], + name=f"Unsqueeze_extra_{2*i + 1}", + ) + + reshape_name = reshape_node.name + if reshape_name == "Reshape_qkv_2": + concat_node_inputs = [unsqueeze_0_node.output[0], unsqueeze_1_node.output[0], "hidden_size"] + elif reshape_name == "Reshape_qkv_1": + concat_node_inputs = [unsqueeze_0_node.output[0], "num_heads", unsqueeze_1_node.output[0], "head_size"] + elif reshape_name == "Reshape_v_2": + concat_node_inputs = [unsqueeze_0_node.output[0], unsqueeze_1_node.output[0], "head_size"] + elif reshape_name == "Reshape_v_1": + concat_node_inputs = [unsqueeze_0_node.output[0], unsqueeze_1_node.output[0], "num_heads", "head_size"] + elif reshape_name == "Reshape_k": + concat_node_inputs = [unsqueeze_0_node.output[0], "head_size", unsqueeze_1_node.output[0]] + elif reshape_name == "Reshape_q": + concat_node_inputs = [unsqueeze_0_node.output[0], unsqueeze_1_node.output[0], "head_size"] + + concat_node = helper.make_node( + "Concat", + inputs=concat_node_inputs, + outputs=[reshape_nodes[i].input[1]], + name=f"Concat_extra_{i}", + axis=0, + ) + extra_nodes.extend([unsqueeze_0_node, unsqueeze_1_node, concat_node]) + + return extra_nodes + + def create_end_nodes(self): + matmul_o_node = helper.make_node( + "MatMul", + inputs=["attn_output", "o_weight"], + outputs=["output_proj"], + name="MatMul_o_proj", + ) + end_node = helper.make_node( + "Add", + inputs=["zero", "output_proj"], + outputs=["output_0"], + name="Add_normalize_node", + ) + return [matmul_o_node, end_node] + + def create_fused_model(self, model_type: str, interleaved: bool, initializers: List[TensorProto]): + inputs, outputs = self.create_inputs_and_outputs(model_type) + matmul_nodes = self.create_matmul_nodes(True, model_type=model_type) + rope_nodes = self.create_rotary_embeddings(True, model_type, interleaved, inputs, initializers) + attn_mask_nodes = self.create_attn_mask_path(True, model_type) + + mha_inputs = [ + rope_nodes[0].output[0], # q + rope_nodes[1].output[0], # k + matmul_nodes[-1].output[0], # v + "", # bias + "attn_mask_out" if model_type == "llama2_msft" else "", # attn_mask + "attn_mask_out" if model_type != "llama2_msft" else "", # add_qk + "past_key" if model_type != "no_past" else "", # past_key + "past_value" if model_type != "no_past" else "", # past_value + ] + mha_node = helper.make_node( + "MultiHeadAttention", + inputs=mha_inputs, + outputs=["attn_output", "present_key", "present_value"], + name="MultiHeadAttention_0", + num_heads=self.num_heads, + ) + + end_nodes = self.create_end_nodes() + + graph = helper.make_graph( + nodes=matmul_nodes + rope_nodes + attn_mask_nodes + [mha_node] + end_nodes, + name="RotaryAttention_Graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + opset_import = helper.make_opsetid(domain="com.microsoft", version=1) + model = helper.make_model(graph, opset_imports=[opset_import]) + return model + + def create_test_model(self, model_type: str, interleaved: bool, initializers: List[TensorProto]): + inputs, outputs = self.create_inputs_and_outputs(model_type) + matmul_nodes = self.create_matmul_nodes(False, model_type) + rope_nodes = self.create_rotary_embeddings(False, model_type, interleaved, inputs, initializers) + + # Create main paths + q_nodes = self.create_q_path(model_type) + k_nodes = self.create_k_path(model_type) + attn_mask_nodes = self.create_attn_mask_path(False, model_type) + qk_nodes = self.create_qk_path(model_type) + v_nodes = self.create_v_path(model_type) + qkv_nodes = self.create_qkv_path(model_type) + + reshape_nodes = list(filter(lambda node: node.op_type == "Reshape", q_nodes + k_nodes + v_nodes + qkv_nodes)) + extra_nodes = self.create_concat_unsqueeze_paths(model_type, reshape_nodes) + + end_nodes = self.create_end_nodes() + + first_set_of_nodes = matmul_nodes + rope_nodes + q_nodes + k_nodes + attn_mask_nodes + second_set_of_nodes = qk_nodes + v_nodes + qkv_nodes + extra_nodes + end_nodes + graph = helper.make_graph( + nodes=first_set_of_nodes + second_set_of_nodes, + name="RotaryAttention_Graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + opset_import = helper.make_opsetid(domain="ai.onnx", version=17) + model = helper.make_model(graph, opset_imports=[opset_import]) + return model + + def check_models(self, model_type: str, interleaved: bool): + initializers = self.create_initializers() + + expected_model_filename = "expected_model.onnx" + expected_model = self.create_fused_model(model_type, interleaved, initializers) + onnx.save(expected_model, expected_model_filename) + + original_model_filename = "original_model.onnx" + original_model = self.create_test_model(model_type, interleaved, initializers) + onnx.save(original_model, original_model_filename) + + self.verify_fusion(expected_model_filename, original_model_filename) + os.remove(expected_model_filename) + os.remove(original_model_filename) + + def test_llama2_msft_model(self): + model_type = "llama2_msft" + interleaved = True + self.check_models(model_type, interleaved) + + def test_hf_decoder_model(self): + model_type = "no_past" + interleaved = False + self.check_models(model_type, interleaved) + + def test_hf_decoder_with_past_model(self): + model_type = "past" + interleaved = False + self.check_models(model_type, interleaved) + + def test_hf_decoder_merged_model(self): + model_type = "merged" + interleaved = False + self.check_models(model_type, interleaved) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_simplified_layernorm_fusion.py b/onnxruntime/test/python/transformers/test_simplified_layernorm_fusion.py new file mode 100644 index 0000000000000..e86bdda7baffb --- /dev/null +++ b/onnxruntime/test/python/transformers/test_simplified_layernorm_fusion.py @@ -0,0 +1,243 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import unittest +from typing import List + +import numpy as np +import onnx +from onnx import TensorProto, helper +from parity_utilities import find_transformers_source + +if find_transformers_source(): + from fusion_options import FusionOptions + from onnx_model import OnnxModel + from optimizer import optimize_model +else: + from onnxruntime.transformers.fusion_options import FusionOptions + from onnxruntime.transformers.onnx_model import OnnxModel + from onnxruntime.transformers.optimizer import optimize_model + + +def float_tensor(name: str, shape: List[int], random=False): + low = 0.0 + high = 1.0 + total_elements = 1 + for x in shape: + total_elements *= x + weights = [np.random.uniform(low, high) for _ in range(total_elements)] if random else [1.0] * total_elements + return helper.make_tensor(name, TensorProto.FLOAT, shape, weights) + + +class TestSimplifiedLayerNormFusion(unittest.TestCase): + def setUp(self): + self.vocab_size = 5 + self.batch_size = 2 + self.sequence_length = 8 + self.hidden_size = 16 + self.epsilon = 0.000009999999747378752 + + def verify_fusion(self, expected_model_path, original_model_path): + expected_model = OnnxModel(onnx.load(expected_model_path)) + expected_model.topological_sort(is_deterministic=True) + + options = FusionOptions("gpt2") + optimized_model = optimize_model(original_model_path, optimization_options=options) + optimized_model.topological_sort(is_deterministic=True) + + self.assertTrue(str(expected_model.model.graph), str(optimized_model.model.graph)) + + def create_initializers(self, use_embed_weight: bool = False): + initializers = [ + helper.make_tensor("Two", TensorProto.FLOAT, [1], np.array([2], dtype=np.float32)), + helper.make_tensor("epsilon", TensorProto.FLOAT, [1], np.array([self.epsilon], dtype=np.float32)), + helper.make_tensor("One", TensorProto.FLOAT, [1], np.array([1], dtype=np.float32)), + float_tensor("scale", [self.hidden_size]), + ] + if use_embed_weight: + initializers = [ # noqa: RUF005 + float_tensor("embed_weight", [self.vocab_size, self.hidden_size]) + ] + initializers + return initializers + + def create_inputs_and_outputs(self, start_node_type: str): + inputs, start_node = None, None + if start_node_type == "Add": + start_node = helper.make_node( + "Add", + inputs=["input_0", "input_1"], + outputs=["D"], + name="Add_0", + ) + input_0 = helper.make_tensor_value_info( + "input_0", + TensorProto.FLOAT, + [self.batch_size, self.sequence_length, self.hidden_size], + ) + input_1 = helper.make_tensor_value_info( + "input_1", + TensorProto.FLOAT, + [self.batch_size, self.sequence_length, self.hidden_size], + ) + inputs = [input_0, input_1] + elif start_node_type == "Gather": + start_node = helper.make_node( + "Gather", + inputs=["embed_weight", "input_0"], + outputs=["D"], + name="Gather_0", + ) + input_0 = helper.make_tensor_value_info( + "input_0", + TensorProto.INT64, + [self.batch_size, self.sequence_length], + ) + inputs = [input_0] + else: + # start_node_type is a graph input + assert start_node_type == "GraphInput" + input_0 = helper.make_tensor_value_info( + "D", + TensorProto.FLOAT, + [self.batch_size, self.sequence_length, self.hidden_size], + ) + inputs = [input_0] + + outputs = [ + helper.make_tensor_value_info( + "output_0", + TensorProto.FLOAT, + [self.batch_size, self.sequence_length, self.hidden_size], + ) + ] + return inputs, outputs, start_node + + def create_fused_model(self, start_node_type: str, initializers: List[TensorProto]): + inputs, outputs, start_node = self.create_inputs_and_outputs(start_node_type) + + sln_node = helper.make_node( + "SimplifiedLayerNormalization", + inputs=[start_node.output[0] if start_node is not None else "D", initializers[0].name], + outputs=[outputs[0].name], + axis=-1, + epsilon=initializers[2].float_data[0], + stash_type=1, + ) + + graph = helper.make_graph( + nodes=[sln_node] + ([] if start_node is None else [start_node]), + name="SimplifiedLayerNorm_Graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + opset_import = helper.make_opsetid(domain="com.microsoft", version=1) + model = helper.make_model(graph, opset_imports=[opset_import]) + return model + + # Notation follows https://onnx.ai/onnx/operators/onnx__LayerNormalization.html#summary + def create_test_model(self, start_node_type: str, first_parent_idx: int, initializers: List[TensorProto]): + end_node = helper.make_node( + "Mul", + inputs=["scale", "Normalized"] if first_parent_idx == 1 else ["Normalized", "scale"], + outputs=["output_0"], + name="Mul_1", + ) + mul_node = helper.make_node( + "Mul", + inputs=["D", "InvStdDev"], + outputs=["Normalized"], + name="Mul_0", + ) + div_node = helper.make_node( + "Div", + inputs=["One", "StdDev"], + outputs=["InvStdDev"], + name="Div_0", + ) + sqrt_node = helper.make_node( + "Sqrt", + inputs=["VarEps"], + outputs=["StdDev"], + name="Sqrt_0", + ) + add_node = helper.make_node( + "Add", + inputs=["Var", "epsilon"], + outputs=["VarEps"], + name="Add_1", + ) + reducemean_node = helper.make_node( + "ReduceMean", + inputs=["DD"], + outputs=["Var"], + name="ReduceMean_0", + ) + pow_node = helper.make_node( + "Pow", + inputs=["D", "Two"], + outputs=["DD"], + name="Pow_0", + ) + + inputs, outputs, start_node = self.create_inputs_and_outputs(start_node_type) + + main_nodes = [pow_node, reducemean_node, add_node, sqrt_node, div_node, mul_node, end_node] + graph = helper.make_graph( + nodes=main_nodes + ([] if start_node is None else [start_node]), + name="SimplifiedLayerNorm_Graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + opset_import = helper.make_opsetid(domain="com.microsoft", version=1) + model = helper.make_model(graph, opset_imports=[opset_import]) + return model + + def check_models(self, start_node_type: str, first_parent_idx: int, initializers: List[TensorProto]): + expected_model_filename = "expected_model.onnx" + expected_model = self.create_fused_model(start_node_type, initializers) + onnx.save(expected_model, expected_model_filename) + + original_model_filename = "original_model.onnx" + original_model = self.create_test_model(start_node_type, first_parent_idx, initializers) + onnx.save(original_model, original_model_filename) + + self.verify_fusion(expected_model_filename, original_model_filename) + os.remove(expected_model_filename) + os.remove(original_model_filename) + + # sim_ln_nodes_1 + def test_simplified_layernorm_add_idx1(self): + start_node_type = "Add" + first_parent_idx = 1 + initializers = self.create_initializers() + self.check_models(start_node_type, first_parent_idx, initializers) + + # sim_ln_nodes_2 + def test_simplified_layernorm_gather_idx1(self): + start_node_type = "Gather" + first_parent_idx = 1 + initializers = self.create_initializers(use_embed_weight=True) + self.check_models(start_node_type, first_parent_idx, initializers) + + # sim_ln_nodes_3 + def test_simplified_layernorm_add_idx0(self): + start_node_type = "Add" + first_parent_idx = 0 + initializers = self.create_initializers() + self.check_models(start_node_type, first_parent_idx, initializers) + + # sim_ln_nodes_4 + def test_simplified_layernorm_gather_graph_input(self): + start_node_type = "GraphInput" + first_parent_idx = 0 + initializers = self.create_initializers() + self.check_models(start_node_type, first_parent_idx, initializers) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_whisper.py b/onnxruntime/test/python/transformers/test_whisper.py index ebda0bccaadcf..ceda5a88c3925 100644 --- a/onnxruntime/test/python/transformers/test_whisper.py +++ b/onnxruntime/test/python/transformers/test_whisper.py @@ -50,7 +50,7 @@ def verify_fusion(self, optimized_model, expected_model_filename): ) ) - # Attention type #1 in onnx_model_bart.py + # Attention type #1 in fusion_bart_attention.py def test_encoder_attention_fusion_with_skiplayernorm(self): num_heads = 4 hidden_size = 64 @@ -67,7 +67,7 @@ def test_encoder_attention_fusion_with_skiplayernorm(self): os.remove(model_path) self.verify_fusion(optimized_model, "encoder_attention_with_sln_fused.onnx") - # Attention type #2 in onnx_model_bart.py + # Attention type #2 in fusion_bart_attention.py def test_decoder_attention_fusion_with_skiplayernorm(self): num_heads = 4 hidden_size = 64 @@ -84,7 +84,7 @@ def test_decoder_attention_fusion_with_skiplayernorm(self): os.remove(model_path) self.verify_fusion(optimized_model, "decoder_attention_with_sln_fused.onnx") - # Attention type #4 in onnx_model_bart.py + # Attention type #4 in fusion_bart_attention.py def test_decoder_multihead_attention_fusion(self): num_heads = 4 hidden_size = 64 @@ -100,7 +100,7 @@ def test_decoder_multihead_attention_fusion(self): os.remove(model_path) self.verify_fusion(optimized_model, "decoder_mha_fused.onnx") - # Attention type #3 in onnx_model_bart.py + # Attention type #3 in fusion_bart_attention.py def test_decoder_with_past_multihead_self_attention_fusion_with_skiplayernorm(self): num_heads = 4 hidden_size = 64 @@ -118,7 +118,7 @@ def test_decoder_with_past_multihead_self_attention_fusion_with_skiplayernorm(se os.remove(model_path) self.verify_fusion(optimized_model, "decoder_with_past_self_mha_fused.onnx") - # Attention type #5 in onnx_model_bart.py + # Attention type #5 in fusion_bart_attention.py def test_decoder_with_past_multihead_cross_attention_fusion(self): num_heads = 4 hidden_size = 64 @@ -134,7 +134,7 @@ def test_decoder_with_past_multihead_cross_attention_fusion(self): os.remove(model_path) self.verify_fusion(optimized_model, "decoder_with_past_cross_mha_fused.onnx") - # Attention type #4 in onnx_model_bart.py + # Attention type #4 in fusion_bart_attention.py def test_decoder_multihead_attention_split_bias_fusion(self): num_heads = 4 hidden_size = 64 @@ -151,7 +151,7 @@ def test_decoder_multihead_attention_split_bias_fusion(self): os.remove(model_path) self.verify_fusion(optimized_model, "decoder_mha_split_bias_fused.onnx") - # Attention type #3 in onnx_model_bart.py + # Attention type #3 in fusion_bart_attention.py def test_decoder_with_past_multihead_self_attention_split_bias_fusion_with_skiplayernorm(self): num_heads = 4 hidden_size = 64 @@ -171,7 +171,7 @@ def test_decoder_with_past_multihead_self_attention_split_bias_fusion_with_skipl os.remove(model_path) self.verify_fusion(optimized_model, "decoder_with_past_self_mha_split_bias_fused.onnx") - # Attention type #5 in onnx_model_bart.py + # Attention type #5 in fusion_bart_attention.py def test_decoder_with_past_multihead_cross_attention_split_bias_fusion(self): num_heads = 4 hidden_size = 64