diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 7e67ec6d0c94e..5805333a0868c 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -90,6 +90,7 @@ Do not modify directly.*
* com.microsoft.RemovePadding
* com.microsoft.RestorePadding
* com.microsoft.Rfft
+ * com.microsoft.RotaryEmbedding
* com.microsoft.SampleOp
* com.microsoft.Sampling
* com.microsoft.SkipLayerNormalization
@@ -2834,7 +2835,7 @@ This version of the operator has been available since version 1 of the 'com.micr
bias (optional) : T
Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection
key_padding_mask (optional) : M
-Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)
+Key padding mask with shape (batch_size), (3 * batch_size + 2), (batch_size, kv_sequence_length), (batch_size, total_sequence_length), or (batch_size, sequence_length, total_sequence_length)
relative_position_bias (optional) : T
relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)
past_key (optional) : T
@@ -4796,6 +4797,54 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.RotaryEmbedding**
+
+ RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices
+ that are multiplied to query and key before the inner product of query and key is taken.
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+
+- interleaved : int
+- Rotate using interleaved pattern. Default value is 0 (False).
+- scale : float
+- Custom scale will be used if specified. Default value is 1.0
+
+
+#### Inputs
+
+
+- input : T
+- 3D tensor with shape (batch_size, sequence_length, hidden_size)
+- position_ids : M
+- 1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)
+- cos_cache : T
+- 2D tensor with shape (max_sequence_length, head_size / 2).
+- sin_cache : T
+- 2D tensor with shape (max_sequence_length, head_size / 2).
+
+
+#### Outputs
+
+
+- output : T
+- 3D tensor with shape (batch_size, sequence_length, hidden_size)
+
+
+#### Type Constraints
+
+
+- T : tensor(float), tensor(float16)
+- Constrain input and output types to float tensors.
+- M : tensor(int64)
+- Constrain input and output types to integer tensors
+
+
+
### **com.microsoft.SampleOp**
Sample echo operator.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index e2d500006b05f..dea71d81f8df5 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -477,9 +477,11 @@ Do not modify directly.*
|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)|
|QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
+|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float)|
|SampleOp|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float)|
|SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
+|SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
|SparseToDenseMatMul|*in* A:**T**
*in* B:**T1**
*out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)
**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Tokenizer|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(string)|
|TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
@@ -866,6 +868,7 @@ Do not modify directly.*
|RemovePadding|*in* input:**T**
*in* sequence_token_count:**M**
*out* output:**T**
*out* token_offset:**M**
*out* cumulated_seq_len:**M**
*out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
|RestorePadding|*in* input:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)|
|Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
index 0b55cb7804c61..694c40bf3eda6 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -16,7 +16,6 @@
#include
#include
-#include
using onnxruntime::concurrency::ThreadPool;
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
index 73b83057bdbe9..00e82c9844b3d 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
@@ -206,6 +206,7 @@ Status CheckInputs(const T* query,
}
}
+ int total_sequence_length = past_sequence_length + kv_sequence_length;
AttentionMaskType mask_type = AttentionMaskType::MASK_NONE;
if (key_padding_mask != nullptr) {
mask_type = AttentionMaskType::MASK_UNKNOWN;
@@ -216,13 +217,21 @@ Status CheckInputs(const T* query,
} else if (mask_dims[0] == static_cast(3) * static_cast(batch_size) + static_cast(2)) {
mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START;
}
- } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && mask_dims[1] == static_cast(kv_sequence_length)) {
+ } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) &&
+ mask_dims[1] == static_cast(kv_sequence_length)) {
+ mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
+ } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) &&
+ mask_dims[1] == static_cast(total_sequence_length)) {
mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
+ } else if (mask_dims.size() == 3 && mask_dims[0] == static_cast(batch_size) &&
+ mask_dims[1] == static_cast(sequence_length) &&
+ mask_dims[2] == static_cast(total_sequence_length)) {
+ mask_type = AttentionMaskType::MASK_3D_ATTENTION;
}
if (mask_type == AttentionMaskType::MASK_UNKNOWN) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'key_padding_mask' shape shall be (batch_size) or (batch_size, kv_sequence_length)");
+ "Input 'key_padding_mask' shape shall be 1D, 2D, or 3D");
}
}
@@ -257,7 +266,6 @@ Status CheckInputs(const T* query,
}
}
- int total_sequence_length = past_sequence_length + kv_sequence_length;
bool broadcast_res_pos_bias = false;
if (relative_position_bias != nullptr) {
const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims();
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
new file mode 100644
index 0000000000000..4a266af789250
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
@@ -0,0 +1,115 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "contrib_ops/cpu/bert/rotary_embedding.h"
+#include "contrib_ops/cpu/bert/rotary_embedding_helper.h"
+
+#include "core/platform/threadpool.h"
+
+using onnxruntime::concurrency::ThreadPool;
+using namespace onnxruntime::contrib::rotary_embedding_helper;
+
+namespace onnxruntime {
+namespace contrib {
+
+// These ops are internal-only, so register outside of onnx
+ONNX_OPERATOR_TYPED_KERNEL_EX(
+    RotaryEmbedding,
+    kMSDomain,
+    1,
+    float,  // the CPU kernel is registered for float only
+    kCpuExecutionProvider,
+    KernelDefBuilder()
+        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())     // input, cos_cache, sin_cache, output
+        .TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()),  // position_ids
+    RotaryEmbedding<float>);
+
+template <typename T>
+RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) {
+  scale = info.GetAttrOrDefault<float>("scale", 1.0f);  // optional attribute; 1.0 when absent
+  interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);  // only the value 1 enables it
+}
+
+template <typename T>
+Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
+  // Inputs: 0 = input, 1 = position_ids, 2 = cos_cache, 3 = sin_cache. Output 0 has the shape of input.
+  const Tensor* input = context->Input<Tensor>(0);
+  const Tensor* position_ids = context->Input<Tensor>(1);
+  const Tensor* cos_cache = context->Input<Tensor>(2);
+  const Tensor* sin_cache = context->Input<Tensor>(3);
+
+  // Validate the shapes and derive the problem dimensions from them.
+  RotaryParameters parameters = {};
+  ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs<Tensor>(input, position_ids, cos_cache, sin_cache,
+                                                                   &parameters));
+
+  Tensor* output = context->Output(0, input->Shape());
+
+  if (parameters.sequence_length > parameters.max_sequence_length) {
+    // Launch update_cos_sin_cache kernel with scale
+    ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported");
+  }
+
+  const T* input_src = input->Data<T>();
+  const int64_t* pos_ids_data = position_ids->Data<int64_t>();
+  const T* cos_cache_data = cos_cache->Data<T>();
+  const T* sin_cache_data = sin_cache->Data<T>();
+  T* output_dest = output->MutableData<T>();
+
+  const int batch_size = parameters.batch_size;
+  const int sequence_length = parameters.sequence_length;
+  const int num_heads = parameters.num_heads;
+  const int head_size = parameters.head_size;
+  const int position_ids_format = parameters.position_ids_format;
+  const int half_head_size = head_size / 2;
+
+  // Every position must select a row of the (max_sequence_length, head_size / 2) caches. Reject bad ids here
+  // rather than reading out of bounds in the loop below. Format 0 stores one start id that is advanced by s.
+  const int64_t num_tokens = static_cast<int64_t>(batch_size) * sequence_length;
+  const int64_t ids_to_check = (position_ids_format == 0) ? (num_tokens > 0 ? 1 : 0) : num_tokens;
+  const int64_t last_valid_id = parameters.max_sequence_length - ((position_ids_format == 0) ? sequence_length : 1);
+  for (int64_t k = 0; k < ids_to_check; ++k) {
+    if (pos_ids_data[k] < 0 || pos_ids_data[k] > last_valid_id) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' value ", pos_ids_data[k],
+                             " is out of range for cos/sin caches with ", parameters.max_sequence_length, " rows");
+    }
+  }
+
+  auto* tp = context->GetOperatorThreadPool();
+
+  // One work item per (batch, position, head). Offsets use 64-bit types so large tensors cannot overflow int.
+  const std::ptrdiff_t loop_len = static_cast<std::ptrdiff_t>(num_tokens) * num_heads;
+  const double cost = static_cast<double>(head_size);
+  ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+    for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
+      const std::ptrdiff_t b = (ptr / num_heads) / sequence_length;
+      const std::ptrdiff_t s = (ptr / num_heads) % sequence_length;
+
+      // ptr == (b * sequence_length + s) * num_heads + n, so this head's data starts at ptr * head_size.
+      const std::ptrdiff_t data_offset = ptr * head_size;
+      const T* input_data = input_src + data_offset;
+      T* output_data = output_dest + data_offset;
+
+      // Cache is (M, H/2)
+      const int64_t position_id = (position_ids_format == 0) ? pos_ids_data[0] + s
+                                                             : pos_ids_data[b * sequence_length + s];
+      const int64_t cache_offset = position_id * half_head_size;
+      const T* cos_data = cos_cache_data + cache_offset;
+      const T* sin_data = sin_cache_data + cache_offset;
+
+      for (int i = 0; i < head_size; i++) {
+        // Partner elements are (2k, 2k + 1) when interleaved, otherwise (k, k + half_head_size).
+        const bool first_of_pair = interleaved ? (i % 2 == 0) : (i < half_head_size);
+        const int cache_idx = interleaved ? (i / 2) % half_head_size : i % half_head_size;
+        const int j = interleaved ? (first_of_pair ? i + 1 : i - 1) : (i + half_head_size) % head_size;
+        const T sign = first_of_pair ? static_cast<T>(-1) : static_cast<T>(1);
+        output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
+      }
+    }
+  });
+
+  return Status::OK();
+}
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h
new file mode 100644
index 0000000000000..be834a66cdc69
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h
@@ -0,0 +1,23 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/framework/op_kernel.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+template <typename T>
+class RotaryEmbedding final : public OpKernel {
+ public:
+  RotaryEmbedding(const OpKernelInfo& info);  // reads the "scale" and "interleaved" attributes
+  Status Compute(OpKernelContext* context) const override;
+
+ protected:
+  float scale;       // value of the "scale" attribute
+  bool interleaved;  // true when the "interleaved" attribute equals 1
+};
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
new file mode 100644
index 0000000000000..cf8080800e072
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
@@ -0,0 +1,121 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/providers/common.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace rotary_embedding_helper {
+
+// Parameters deduced from node attributes and inputs/outputs.
+struct RotaryParameters {
+  int batch_size;           // Dimension 0 of input
+  int sequence_length;      // Dimension 1 of input
+  int hidden_size;          // Dimension 2 of input
+  int head_size;            // 2 * dimension 1 of the cos/sin cache
+  int num_heads;            // num_heads = hidden_size / head_size
+  int max_sequence_length;  // Dimension 0 of the cos/sin cache
+  int position_ids_format;  // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length)
+};
+
+// Validates the shapes of the RotaryEmbedding inputs and derives the problem dimensions from them.
+// T is a tensor type exposing Shape(). When `parameters` is non-null it must point to a RotaryParameters,
+// which is filled in on success. Returns INVALID_ARGUMENT describing the first violated constraint.
+template <typename T>
+Status CheckInputs(const T* input,
+                   const T* position_ids,
+                   const T* cos_cache,
+                   const T* sin_cache,
+                   void* parameters) {
+  //    input        : (batch_size, sequence_length, hidden_size)
+  //    position ids : (1) or (batch_size, sequence_length)
+  //    cos cache    : (max_sequence_length, head_size / 2)
+  //    sin cache    : (max_sequence_length, head_size / 2)
+
+  // Check input
+  const auto& input_dims = input->Shape().GetDims();
+  if (input_dims.size() != 3) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input' is expected to have 3 dimensions, got ",
+                           input_dims.size());
+  }
+
+  // Check position_ids
+  const auto& position_ids_dims = position_ids->Shape().GetDims();
+  const bool single_position_id = onnxruntime::IsScalarOr1ElementVector(position_ids);
+  if (!single_position_id && position_ids_dims.size() != 2) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' is expected to have 0, 1, or 2 ",
+                           "dimensions, got ", position_ids_dims.size());
+  }
+
+  // Check cos_cache and sin_cache
+  const auto& cos_cache_dims = cos_cache->Shape().GetDims();
+  if (cos_cache_dims.size() != 2) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' is expected to have 2 dimensions, got ",
+                           cos_cache_dims.size());
+  }
+  const auto& sin_cache_dims = sin_cache->Shape().GetDims();
+  if (sin_cache_dims.size() != 2) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' is expected to have 2 dimensions, got ",
+                           sin_cache_dims.size());
+  }
+  if (cos_cache_dims[0] != sin_cache_dims[0] || cos_cache_dims[1] != sin_cache_dims[1]) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'cos_cache' and 'sin_cache' are expected to have ",
+                           "the same shape");
+  }
+
+  // Get attributes from inputs
+  int batch_size = static_cast<int>(input_dims[0]);
+  int sequence_length = static_cast<int>(input_dims[1]);
+  int hidden_size = static_cast<int>(input_dims[2]);
+  int max_sequence_length = static_cast<int>(cos_cache_dims[0]);
+  int head_size = static_cast<int>(cos_cache_dims[1]) * 2;
+
+  // head_size comes from the cache width, so an empty cache would make the division below divide by zero, and
+  // a hidden_size that is not a whole number of heads would leave the tail of every token unrotated.
+  if (head_size <= 0) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'cos_cache' and 'sin_cache' dimension 1 should ",
+                           "be at least 1, got ", cos_cache_dims[1]);
+  }
+  if (hidden_size % head_size != 0) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input' dimension 2 should be a multiple of ",
+                           "head_size (", head_size, "), got ", hidden_size);
+  }
+  int num_heads = hidden_size / head_size;
+
+  // Check position_ids input shapes
+  int position_ids_format = 0;  // 0: single start position, 1: one id per (batch, sequence) token
+  if (!single_position_id) {
+    if (batch_size != static_cast<int>(position_ids_dims[0])) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 0 should be of size ",
+                             "batch_size, got ", position_ids_dims[0]);
+    }
+    if (sequence_length != static_cast<int>(position_ids_dims[1])) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 1 should be of size ",
+                             "sequence_length, got ", position_ids_dims[1]);
+    }
+    position_ids_format = 1;
+  }
+
+  // The cos/sin cache dimensions need no further checks: max_sequence_length and head_size are derived from
+  // cos_cache above, and sin_cache was already required to have the same shape.
+
+  // Set rotary parameters
+  if (parameters != nullptr) {
+    RotaryParameters* output_parameters = reinterpret_cast<RotaryParameters*>(parameters);
+    output_parameters->batch_size = batch_size;
+    output_parameters->sequence_length = sequence_length;
+    output_parameters->hidden_size = hidden_size;
+    output_parameters->head_size = head_size;
+    output_parameters->num_heads = num_heads;
+    output_parameters->max_sequence_length = max_sequence_length;
+    output_parameters->position_ids_format = position_ids_format;
+  }
+
+  return Status::OK();
+}
+
+} // namespace rotary_embedding_helper
+} // namespace contrib
+} // namespace onnxruntime
\ No newline at end of file
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index b4c51ab290eb7..f77e403f26dde 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -20,6 +20,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer);
@@ -124,6 +125,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, SimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipLayerNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipSimplifiedLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu);
@@ -253,6 +256,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -299,6 +303,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
index e86a12d9fb873..4e103c2556a7a 100644
--- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
+++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
@@ -20,20 +20,29 @@ namespace contrib {
kCpuExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType()), \
- SkipLayerNorm);
+ SkipLayerNorm); \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ SkipSimplifiedLayerNormalization, \
+ kMSDomain, \
+ 1, \
+ T, \
+ kCpuExecutionProvider, \
+ KernelDefBuilder() \
+ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \
+ SkipLayerNorm);
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(double)
-template
-SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info)
+template
+SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info)
: OpKernel(op_kernel_info) {
ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK());
ORT_ENFORCE(epsilon_ >= 0);
}
-template
-Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const {
+template
+Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const {
const Tensor* input = p_ctx->Input(0);
const Tensor* skip = p_ctx->Input(1);
const Tensor* gamma = p_ctx->Input(2);
@@ -102,10 +111,16 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const {
}
mean = mean / hidden_size;
- mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_);
+ if (simplified) {
+ mean_square = sqrt(mean_square / hidden_size + epsilon_);
+ } else {
+ mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_);
+ }
for (int64_t h = 0; h < hidden_size; h++) {
- if (nullptr == beta_data) {
+ if (simplified) {
+ p_output[h] = p_output[h] / mean_square * gamma_data[h];
+ } else if (nullptr == beta_data) {
p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h];
} else {
p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h] + beta_data[h];
diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h
index 7723541cb6b18..69edf4609e340 100644
--- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h
+++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h
@@ -10,7 +10,7 @@
namespace onnxruntime {
namespace contrib {
-template
+template
class SkipLayerNorm final : public OpKernel {
public:
SkipLayerNorm(const OpKernelInfo& op_kernel_info);
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
new file mode 100644
index 0000000000000..b4b5dac1fbe19
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
@@ -0,0 +1,84 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cuda/cuda_common.h"
+#include "contrib_ops/cpu/bert/rotary_embedding_helper.h"
+#include "contrib_ops/cuda/bert/rotary_embedding.h"
+#include "contrib_ops/cuda/bert/rotary_embedding_impl.h"
+
+using namespace onnxruntime::cuda;
+using namespace ::onnxruntime::common;
+using namespace ONNX_NAMESPACE;
+using namespace onnxruntime::contrib::rotary_embedding_helper;
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+#define REGISTER_KERNEL_TYPED(T)                                        \
+  ONNX_OPERATOR_TYPED_KERNEL_EX(                                        \
+      RotaryEmbedding,                                                  \
+      kMSDomain,                                                        \
+      1,                                                                \
+      T,                                                                \
+      kCudaExecutionProvider,                                           \
+      (*KernelDefBuilder::Create())                                     \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>())        \
+          .TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()), \
+      RotaryEmbedding<T>);
+
+REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(MLFloat16)
+
+template <typename T>
+RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) {
+  scale = info.GetAttrOrDefault<float>("scale", 1.0f);  // optional attribute; 1.0 when absent
+  interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);  // only the value 1 enables it
+}
+
+template <typename T>
+Status RotaryEmbedding<T>::ComputeInternal(OpKernelContext* context) const {
+  const Tensor* input = context->Input<Tensor>(0);
+  const Tensor* position_ids = context->Input<Tensor>(1);
+  const Tensor* cos_cache = context->Input<Tensor>(2);
+  const Tensor* sin_cache = context->Input<Tensor>(3);
+
+  // Validate the shapes and derive the problem dimensions from them.
+  RotaryParameters parameters = {};
+  ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs<Tensor>(input,
+                                                                   position_ids,
+                                                                   cos_cache,
+                                                                   sin_cache,
+                                                                   &parameters));
+
+  Tensor* output = context->Output(0, input->Shape());
+
+  if (parameters.sequence_length > parameters.max_sequence_length) {
+    // Launch update_cos_sin_cache kernel with scale
+    ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported");
+  }
+
+  // Launch rotary embedding kernel.
+  // The launcher's status is this function's result, so nothing may follow the call.
+  using CudaT = typename ToCudaType<T>::MappedType;
+  auto& device_prop = GetDeviceProp();
+  return LaunchRotaryEmbeddingKernel<CudaT>(
+      Stream(context),
+      reinterpret_cast<CudaT*>(output->template MutableData<T>()),
+      reinterpret_cast<const CudaT*>(input->template Data<T>()),
+      position_ids->Data<int64_t>(),
+      reinterpret_cast<const CudaT*>(cos_cache->template Data<T>()),
+      reinterpret_cast<const CudaT*>(sin_cache->template Data<T>()),
+      parameters.batch_size,
+      parameters.sequence_length,
+      parameters.num_heads,
+      parameters.head_size,
+      parameters.max_sequence_length,
+      parameters.position_ids_format,
+      interleaved,
+      device_prop.maxThreadsPerBlock);
+}
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h
new file mode 100644
index 0000000000000..6dab2ad56749e
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h
@@ -0,0 +1,27 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/providers/cuda/cuda_kernel.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+using namespace onnxruntime::cuda;
+
+template <typename T>
+class RotaryEmbedding final : public CudaKernel {
+ public:
+  RotaryEmbedding(const OpKernelInfo& info);  // reads the "scale" and "interleaved" attributes
+  Status ComputeInternal(OpKernelContext* context) const override;
+
+ protected:
+  float scale;       // value of the "scale" attribute
+  bool interleaved;  // true when the "interleaved" attribute equals 1
+};
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
new file mode 100644
index 0000000000000..c54e72dcfce13
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
@@ -0,0 +1,141 @@
+/*
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT License.
+*/
+
+/*
+Kernel implementation for rotary embeddings.
+*/
+
+#include
+#include "core/providers/cuda/cu_inc/common.cuh"
+#include "contrib_ops/cuda/bert/rotary_embedding_impl.h"
+
+using namespace onnxruntime::cuda;
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+template <typename T>
+__global__ void RotaryEmbeddingBSNH(T* output,                    // BxSxNxH
+                                    const T* input,               // BxSxNxH
+                                    const T* cos_cache,           // Mx(H/2)
+                                    const T* sin_cache,           // Mx(H/2)
+                                    const int64_t* position_ids,  // (1) or BxS
+                                    const int sequence_length,
+                                    const int num_heads,
+                                    const int head_size,
+                                    const int position_ids_format,
+                                    const bool interleaved) {
+  // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length
+  // Use .x in innermost loop to access global memory efficiently
+
+  const int b = blockIdx.z;
+  const int s = blockIdx.y;
+  const int n = blockIdx.x;
+
+  const int i = threadIdx.x;  // element index within the head handled by this block
+
+  // 64-bit arithmetic: B * S * N * H can exceed the range of int for large inputs.
+  const int64_t block_offset = (static_cast<int64_t>(b) * sequence_length + s) * num_heads + n;
+
+  const T* input_data = input + block_offset * head_size;
+  T* output_data = output + block_offset * head_size;
+
+  // Cache is (M, H/2)
+  const int half_head_size = head_size / 2;
+  const int64_t position_id = (position_ids_format == 0)
+                                  ? position_ids[0] + s
+                                  : position_ids[static_cast<int64_t>(b) * sequence_length + s];
+  const int64_t cache_offset = position_id * half_head_size;
+  const T* cos_data = cos_cache + cache_offset;
+  const T* sin_data = sin_cache + cache_offset;
+
+  int cache_idx = 0;
+  T sign = 0;
+  int j = 0;
+  if (interleaved) {  // partner elements are (2k, 2k + 1)
+    cache_idx = (i / 2) % half_head_size;
+    sign = (i % 2 == 0) ? -1 : 1;
+    j = (i % 2 == 0) ? i + 1 : i - 1;  // i - sign
+  } else {  // partner elements are (k, k + half_head_size)
+    cache_idx = i % half_head_size;
+    sign = (i < half_head_size) ? -1 : 1;
+    j = (i + half_head_size) % head_size;
+  }
+  output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
+}
+
+
+template <typename T>
+Status LaunchRotaryEmbeddingKernel(
+    cudaStream_t stream,
+    T* output,
+    const T* input,
+    const int64_t* position_ids,
+    const T* cos_cache,
+    const T* sin_cache,
+    const int batch_size,
+    const int sequence_length,
+    const int num_heads,
+    const int head_size,
+    const int max_sequence_length,
+    const int position_ids_format,
+    const bool interleaved,
+    const int max_threads_per_block) {
+  if (batch_size == 0 || sequence_length == 0) {
+    return Status::OK();  // empty input: nothing to rotate, and a zero-sized grid cannot be launched
+  }
+
+  // Note: the kernel uses one thread per element of a head, so head_size must fit in a single block. For smaller
+  // head_size and num_heads values a `block(num_heads, head_size, 1)` layout could be used with kernel changes.
+  if (head_size > max_threads_per_block) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "RotaryEmbedding: head_size (", head_size,
+                           ") must not exceed the maximum number of threads per block (", max_threads_per_block, ")");
+  }
+
+  constexpr int smem_size = 0;
+  const dim3 grid(num_heads, sequence_length, batch_size);
+  const dim3 block(head_size, 1, 1);
+  RotaryEmbeddingBSNH<<<grid, block, smem_size, stream>>>(output, input, cos_cache, sin_cache, position_ids,
+                                                          sequence_length, num_heads, head_size, position_ids_format, interleaved);
+  return CUDA_CALL(cudaGetLastError());
+}
+
+template Status LaunchRotaryEmbeddingKernel<float>(
+    cudaStream_t stream,
+    float* output,
+    const float* input,
+    const int64_t* position_ids,
+    const float* cos_cache,
+    const float* sin_cache,
+    const int batch_size,
+    const int sequence_length,
+    const int num_heads,
+    const int head_size,
+    const int max_sequence_length,
+    const int position_ids_format,
+    const bool interleaved,
+    const int max_threads_per_block);
+
+template Status LaunchRotaryEmbeddingKernel<half>(
+    cudaStream_t stream,
+    half* output,
+    const half* input,
+    const int64_t* position_ids,
+    const half* cos_cache,
+    const half* sin_cache,
+    const int batch_size,
+    const int sequence_length,
+    const int num_heads,
+    const int head_size,
+    const int max_sequence_length,
+    const int position_ids_format,
+    const bool interleaved,
+    const int max_threads_per_block);
+
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
new file mode 100644
index 0000000000000..29ff48a8ad0fb
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
@@ -0,0 +1,31 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/providers/cuda/shared_inc/cuda_utils.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+// Launches the rotary embedding kernel on `stream` and returns the Status of the launch.
+// The implementation in this patch launches a (num_heads, sequence_length, batch_size) grid
+// of head_size-thread blocks and asserts head_size <= max_threads_per_block.
+// T is the element type of `output`, `input`, `cos_cache` and `sin_cache`
+// (instantiated for float and half in the .cu file).
+template <typename T>
+Status LaunchRotaryEmbeddingKernel(
+    cudaStream_t stream,
+    T* output,
+    const T* input,
+    const int64_t* position_ids,
+    const T* cos_cache,
+    const T* sin_cache,
+    const int batch_size,
+    const int sequence_length,
+    const int num_heads,
+    const int head_size,
+    const int max_sequence_length,
+    const int position_ids_format,
+    const bool interleaved,
+    const int max_threads_per_block);
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 52ff285539360..c52f869d6a9d2 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -91,6 +91,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh);
@@ -250,6 +252,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index 3a75b29ffe3c7..76c3f8716ff09 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -946,7 +946,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
OpSchema::Optional)
.Input(4,
"key_padding_mask",
- "Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)",
+ "Key padding mask with shape (batch_size), (3 * batch_size + 2), (batch_size, kv_sequence_length), (batch_size, total_sequence_length), "
+ "or (batch_size, sequence_length, total_sequence_length)",
"M",
OpSchema::Optional)
.Input(5,
@@ -1129,6 +1130,49 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
DecoderAttentionTypeAndShapeInference(ctx);
}));
+// Documentation string and schema for version 1 of the com.microsoft RotaryEmbedding contrib op.
+constexpr const char* RotaryEmbedding_ver1_doc = R"DOC(
+RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices
+that are multiplied to query and key before the inner product of query and key is taken.
+)DOC";
+ONNX_MS_OPERATOR_SET_SCHEMA(
+ RotaryEmbedding, 1,
+ OpSchema()
+ .SetDoc(RotaryEmbedding_ver1_doc)
+ // Both attributes are registered as optional; the documented defaults are stated in the descriptions.
+ .Attr("scale",
+ "Custom scale will be used if specified. Default value is 1.0",
+ AttributeProto::FLOAT,
+ OPTIONAL_VALUE)
+ .Attr("interleaved",
+ "Rotate using interleaved pattern. Default value is 0 (False).",
+ AttributeProto::INT,
+ OPTIONAL_VALUE)
+ .Input(0,
+ "input",
+ "3D tensor with shape (batch_size, sequence_length, hidden_size)",
+ "T")
+ .Input(1,
+ "position_ids",
+ "1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)",
+ "M")
+ .Input(2,
+ "cos_cache",
+ "2D tensor with shape (max_sequence_length, head_size / 2).",
+ "T")
+ .Input(3,
+ "sin_cache",
+ "2D tensor with shape (max_sequence_length, head_size / 2).",
+ "T")
+ .Output(0,
+ "output",
+ "3D tensor with shape (batch_size, sequence_length, hidden_size)",
+ "T")
+ .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
+ .TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors")
+ // Output 0 takes its element type and shape directly from input 0.
+ .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+ propagateElemTypeFromInputToOutput(ctx, 0, 0);
+ propagateShapeFromInputToOutput(ctx, 0, 0);
+ }));
+
constexpr const char* EmbedLayerNormalization_ver1_doc = R"DOC(
EmbedLayerNormalization is the fusion of embedding layer in BERT model, with optional mask processing.
The embedding layer takes input_ids (word IDs) and segment_ids (sentence IDs) to look up word_embedding, position_embedding,
@@ -1500,4 +1544,4 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
}));
} // namespace contrib
-} // namespace onnxruntime
+} // namespace onnxruntime
\ No newline at end of file
diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h
index afa5d101bbd8d..afaa380d6ac79 100644
--- a/onnxruntime/core/graph/contrib_ops/ms_opset.h
+++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h
@@ -95,6 +95,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GatedRelativePositionBia
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RotaryEmbedding);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization);
@@ -200,6 +201,7 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
+ fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py
index 67e9f1b55e9ae..272727a9f5375 100755
--- a/onnxruntime/python/tools/symbolic_shape_infer.py
+++ b/onnxruntime/python/tools/symbolic_shape_infer.py
@@ -206,9 +206,11 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
"PackedAttention": self._infer_PackedAttention,
"PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention,
"PythonOp": self._infer_PythonOp,
+ "QuickGelu": self._infer_FastGelu,
"RelativePositionBias": self._infer_RelativePositionBias,
"RemovePadding": self._infer_RemovePadding,
"RestorePadding": self._infer_RestorePadding,
+ "RotaryEmbedding": self._infer_RotaryEmbedding,
"SimplifiedLayerNormalization": self._infer_LayerNormalization,
"SkipLayerNormalization": self._infer_SkipLayerNormalization,
"SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization,
@@ -462,6 +464,8 @@ def _onnx_infer_single_node(self, node):
"BiasSplitGelu",
"BiasAdd",
"NhwcConv",
+ "QuickGelu",
+ "RotaryEmbedding",
]
if not skip_infer:
@@ -2307,6 +2311,9 @@ def _infer_FastGelu(self, node): # noqa: N802
def _infer_Gelu(self, node): # noqa: N802
self._propagate_shape_and_type(node)
+ def _infer_QuickGelu(self, node): # noqa: N802
+ # Shape inference for QuickGelu: delegates to `_propagate_shape_and_type` for `node`.
+ # NOTE(review): the dispatcher entry added for "QuickGelu" in this same patch points at
+ # `_infer_FastGelu`, not at this method -- confirm which handler is intended.
+ self._propagate_shape_and_type(node)
+
def _infer_GemmFastGelu(self, node): # noqa: N802
self._compute_matmul_shape(node)
@@ -2378,6 +2385,19 @@ def _infer_BiasSplitGelu(self, node): # noqa: N802
def _infer_BiasAdd(self, node): # noqa: N802
self._propagate_shape_and_type(node)
+ def _infer_RotaryEmbedding(self, node): # noqa: N802
+ if len(node.output) == 1:
+ self._propagate_shape_and_type(node)
+ elif len(node.output) == 2:
+ # Extraneous constant nodes outputted by RotaryEmbedding function made with `export_modules_as_functions`
+ self._propagate_shape_and_type(node, input_index=1, output_index=0)
+ self._propagate_shape_and_type(node, input_index=0, output_index=1) # true output
+ elif len(node.output) == 3:
+ # Extraneous constant nodes outputted by RotaryEmbedding function made with `export_modules_as_functions`
+ self._propagate_shape_and_type(node, input_index=1, output_index=0)
+ self._propagate_shape_and_type(node, input_index=1, output_index=1)
+ self._propagate_shape_and_type(node, input_index=0, output_index=2) # true output
+
def _infer_PythonOp(self, node): # noqa: N802
output_tensor_types = get_attribute(node, "output_tensor_types")
assert output_tensor_types
@@ -2583,12 +2603,19 @@ def get_prereq(node):
self._check_merged_dims(in_dims, allow_broadcast=True)
for i_o in range(len(node.output)):
- # Special case: We do not care about the training related
- # outputs of SkipLayerNormalization
+ # Special cases:
+ # 1) We do not care about the training related outputs of SkipLayerNormalization
+ # 2) We do not care about the extraneous constant outputs in RotaryEmbedding because
+ # the RotaryEmbedding op created during export can be replaced by the RotaryEmbedding
+ # contrib op
if (
node.op_type == "SkipLayerNormalization" or node.op_type == "SkipSimplifiedLayerNormalization"
) and i_o in [1, 2]:
continue
+ if node.op_type == "RotaryEmbedding" and len(node.output) > 1:
+ # Skip symbolic shape inference for RotaryEmbedding functions that have extraneous outputs
+ # generated by `export_modules_as_functions`
+ continue
vi = self.known_vi_[node.output[i_o]]
out_type = vi.type
@@ -2750,13 +2777,13 @@ def get_prereq(node):
if i in self.known_vi_:
logger.debug(self.known_vi_[i])
else:
- logger.debug(f"not in knwon_vi_ for {i}")
+ logger.debug(f"not in known_vi_ for {i}")
logger.debug("node outputs:")
for o in node.output:
if o in self.known_vi_:
logger.debug(self.known_vi_[o])
else:
- logger.debug(f"not in knwon_vi_ for {o}")
+ logger.debug(f"not in known_vi_ for {o}")
if self.auto_merge_ and not out_type_undefined:
logger.debug("Merging: " + str(self.suggested_merge_))
return False
diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py
index 4f898245d01bd..b6f7a44450c62 100644
--- a/onnxruntime/python/tools/transformers/benchmark_helper.py
+++ b/onnxruntime/python/tools/transformers/benchmark_helper.py
@@ -33,6 +33,7 @@ class Precision(Enum):
FLOAT32 = "fp32"
FLOAT16 = "fp16"
INT8 = "int8"
+ INT4 = "int4"
def __str__(self):
return self.value
@@ -610,7 +611,7 @@ def measure_memory(is_gpu, func, monitor_type="cuda", start_memory=None):
return memory_before_test
with ThreadPoolExecutor() as executor:
- monitor = MemoryMonitor()
+ monitor = memory_monitor_type()
mem_thread = executor.submit(monitor.measure_cpu_usage)
try:
fn_thread = executor.submit(func)
diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py
index c1c709d6d759b..4228c892d03ae 100644
--- a/onnxruntime/python/tools/transformers/convert_generation.py
+++ b/onnxruntime/python/tools/transformers/convert_generation.py
@@ -1272,6 +1272,38 @@ def find_past_seq_len_usage(subg: GraphProto):
return tensor_names_to_rename, nodes_to_remove
+def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0):
+ past_seq_len = past_seq_len_input
+ if past_seq_len not in model.get_graphs_input_names():
+ # Replace model input for past sequence length
+ new_input = onnx.helper.make_tensor_value_info(past_seq_len, onnx.TensorProto.INT64, shape=[1])
+ model.model.graph.input.append(new_input)
+
+ # Replace MultiHeadAttention with GroupQueryAttention
+ for node in model.model.graph.node:
+ if node.op_type == "MultiHeadAttention":
+ gqa_node = onnx.helper.make_node(
+ "GroupQueryAttention",
+ inputs=[
+ node.input[0], # query
+ node.input[1], # key
+ node.input[2], # value
+ node.input[6], # past_key
+ node.input[7], # past_value
+ past_seq_len, # past_sequence_length
+ ],
+ outputs=node.output,
+ name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"),
+ domain="com.microsoft",
+ num_heads=node.attribute[0].i,
+ kv_num_heads=node.attribute[0].i if kv_num_heads == 0 else kv_num_heads,
+ is_past_bsnh=0,
+ )
+ model.model.graph.node.remove(node)
+ model.model.graph.node.extend([gqa_node])
+ return model
+
+
def update_decoder_subgraph_output_cross_attention(subg: GraphProto):
input_self_past_0 = 1
# w/wo attention mask, w/wo hidden_state
diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py
index 1dbdf39613cdd..c1b241aa1a5ec 100644
--- a/onnxruntime/python/tools/transformers/fusion_attention.py
+++ b/onnxruntime/python/tools/transformers/fusion_attention.py
@@ -111,7 +111,7 @@ def __init__(
model: OnnxModel,
hidden_size: int,
num_heads: int,
- attention_mask: AttentionMask,
+ attention_mask: Optional[AttentionMask] = None,
use_multi_head_attention: bool = False,
disable_multi_head_attention_bias: bool = False,
search_op_types: List[str] = ["SkipLayerNormalization", "LayerNormalization"], # noqa: B006
@@ -120,7 +120,7 @@ def __init__(
super().__init__(model, attention_op_name, search_op_types)
self.hidden_size = hidden_size
self.num_heads = num_heads
- self.attention_mask = attention_mask
+ self.attention_mask = attention_mask if attention_mask else AttentionMask(model)
self.use_multi_head_attention = use_multi_head_attention
self.disable_multi_head_attention_bias = disable_multi_head_attention_bias
self.mask_filter_value = None
@@ -219,6 +219,31 @@ def get_add_qk_str(self, add_qk: NodeProto):
return add_qk.input[1]
+ def reshape_add_qk(self, add_qk: str):
+ # Convert 4D mask from (B,1,S,T) to (B,N,S,T)
+ # B = batch size, N = num heads, S = source sequence length, T = target sequence length
+ mask_output_name = add_qk + "_mask"
+
+ # Check if concat node for (B,1,S,T) --> (B,N,S,T) already exists
+ concat_node = list(filter(lambda node: node.output[0] == mask_output_name, self.nodes_to_add))
+ if len(concat_node) == 1:
+ return mask_output_name
+
+ assert len(concat_node) == 0
+ concat_node_name = self.model.create_node_name("Concat")
+ concat_add_qk_fp32 = helper.make_node(
+ "Concat",
+ inputs=[add_qk for _ in range(self.num_heads)],
+ outputs=[mask_output_name],
+ name=concat_node_name,
+ axis=1,
+ )
+ # Add new node to graph
+ self.nodes_to_add.append(concat_add_qk_fp32)
+ self.node_name_to_graph_name[concat_node_name] = self.this_graph_name
+
+ return mask_output_name
+
def concat_kv(self, past_k: str, past_v: str) -> str:
"""Concatenate past_k and past_v inputs to create past_kv input.
@@ -875,21 +900,8 @@ def create_attention_node(
past_kv = self.concat_kv(past_k, past_v)
attention_inputs.append(past_kv)
- if add_qk_str:
- # Convert 4d mask from (B,1,M,M) to (B,N,M,M)
- # B = batch size, M = max sequence length, N = num heads
- concat_node_name = self.model.create_node_name("Concat")
- mask_output_name = add_qk_str + "_mask"
- concat_add_qk_fp32 = helper.make_node(
- "Concat",
- inputs=[add_qk_str for _ in range(num_heads)],
- outputs=[mask_output_name],
- name=concat_node_name,
- axis=1,
- )
- # Add new nodes to graph
- self.nodes_to_add.append(concat_add_qk_fp32)
- self.node_name_to_graph_name[concat_node_name] = self.this_graph_name
+ if add_qk_str is not None:
+ mask_output_name = self.reshape_add_qk(add_qk_str)
# Add attention mask to attention node
if not past_exists:
diff --git a/onnxruntime/python/tools/transformers/fusion_base.py b/onnxruntime/python/tools/transformers/fusion_base.py
index 117468be412fa..c5d7bc16d64f7 100644
--- a/onnxruntime/python/tools/transformers/fusion_base.py
+++ b/onnxruntime/python/tools/transformers/fusion_base.py
@@ -113,3 +113,20 @@ def add_initializer(self, name: str, data_type: int, dims: Sequence[int], vals:
self.model.add_initializer(tensor, self.this_graph_name)
return tensor
+
+ def add_nodes_to_remove(self, nodes: List[NodeProto]):
+ # Some nodes are shared between paths (e.g. rotary embedding nodes in the Q and K paths).
+ # When path A is fused, its shared nodes are added to `self.nodes_to_remove`. But when path B
+ # is fused, its shared nodes are also added to `self.nodes_to_remove`. When the nodes are
+ # iteratively removed from `self.nodes_to_remove`, path A's shared nodes are removed first.
+ # Since path A's shared nodes are removed, path B's shared nodes are not removed because they
+ # were previously removed for path A. This causes an error to print in remove_node that a node
+ # has failed to be removed.
+ #
+ # To avoid this error, we pre-emptively check if the shared nodes are already in `self.nodes_to_remove`.
+ # We could alternatively convert `self.nodes_to_remove` to a set to avoid this issue, but there could
+ # be scenarios where the nodes need to be removed in a specific order and converting to a set would
+ # lose this order.
+ for node in nodes:
+ if node not in self.nodes_to_remove:
+ self.nodes_to_remove.append(node)
diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py
index 69b5cd26f4525..8c80fcad0ab49 100644
--- a/onnxruntime/python/tools/transformers/fusion_options.py
+++ b/onnxruntime/python/tools/transformers/fusion_options.py
@@ -26,6 +26,7 @@ def __init__(self, model_type):
self.enable_gelu = True
self.enable_layer_norm = True
self.enable_attention = True
+ self.enable_rotary_embeddings = True
# Use MultiHeadAttention instead of Attention operator. The difference:
# (1) Attention has merged weights for Q/K/V projection, which might be faster in some cases since 3 MatMul is
@@ -81,6 +82,8 @@ def parse(args):
options.enable_gelu = False
if args.disable_layer_norm:
options.enable_layer_norm = False
+ if args.disable_rotary_embeddings:
+ options.enable_rotary_embeddings = False
if args.disable_attention:
options.enable_attention = False
if args.use_multi_head_attention:
@@ -294,3 +297,10 @@ def add_arguments(parser: ArgumentParser):
help="Use channels_first (NCHW) instead of channels_last (NHWC) for GroupNorm. Only works for model_type=unet or vae",
)
parser.set_defaults(use_group_norm_channels_first=False)
+
+ parser.add_argument(
+ "--disable_rotary_embeddings",
+ required=False,
+ action="store_true",
+ help="Do not fuse rotary embeddings into RotaryEmbedding op",
+ )
diff --git a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py
new file mode 100644
index 0000000000000..3c5029ac5752f
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py
@@ -0,0 +1,1044 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import logging
+from typing import Optional, Union
+
+from fusion_attention import FusionAttention
+from fusion_base import Fusion
+from onnx import FunctionProto, NodeProto, TensorProto, helper, numpy_helper
+from onnx_model import OnnxModel
+
+logger = logging.getLogger(__name__)
+
+
+class FusionRotaryAttention(FusionAttention):
+ """
+ Fuse Attention subgraph with rotary positional embeddings into one MultiHeadAttention node.
+ """
+
+ def __init__(
+ self,
+ model: OnnxModel,
+ hidden_size: int,
+ num_heads: int,
+ ):
+ # Configure the base FusionAttention to use MultiHeadAttention and to search from
+ # SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization and Add nodes.
+ # No attention_mask is passed here; the base-class change in this patch then builds
+ # its own AttentionMask(model).
+ super().__init__(
+ model,
+ hidden_size,
+ num_heads,
+ use_multi_head_attention=True,
+ search_op_types=["SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization", "Add"],
+ )
+
+ def create_mha_node(
+ self,
+ input: str,
+ output: str,
+ q_rotary: NodeProto,
+ k_rotary: NodeProto,
+ v_matmul: NodeProto,
+ attn_mask: str = "",
+ add_qk: str = "",
+ past_k: str = "",
+ past_v: str = "",
+ present_k: str = "",
+ present_v: str = "",
+ scale: Optional[float] = None,
+ ) -> Union[NodeProto, None]:
+ assert self.num_heads > 0
+
+ if self.hidden_size > 0 and (self.hidden_size % self.num_heads) != 0:
+ logger.debug(
+ f"fuse_rotary_attention: input hidden size {self.hidden_size} is not a multiple of num of heads {self.num_heads}"
+ )
+ return None
+
+ mha_node_name = self.model.create_node_name("MultiHeadAttention")
+ mha_inputs = [
+ q_rotary.output[0],
+ k_rotary.output[0],
+ v_matmul.output[0],
+ "", # bias
+ attn_mask, # key_padding_mask
+ add_qk, # relative_position_bias
+ past_k,
+ past_v,
+ ]
+
+ mha_outputs = [output]
+ if present_k and present_v:
+ mha_outputs.extend([present_k, present_v])
+
+ mha_node = helper.make_node(
+ "MultiHeadAttention",
+ inputs=mha_inputs,
+ outputs=mha_outputs,
+ name=mha_node_name,
+ )
+
+ mha_node.domain = "com.microsoft"
+ mha_node.attribute.extend([helper.make_attribute("num_heads", self.num_heads)])
+ if scale is not None:
+ mha_node.attribute.extend([helper.make_attribute("scale", scale)])
+ if self.mask_filter_value is not None:
+ mha_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
+
+ self.increase_counter("MultiHeadAttention")
+ return mha_node
+
+ def check_runtime_shape_paths_for_function(
+ self,
+ reshape_qkv_2, # Reshape after Transpose
+ reshape_qkv_1, # Reshape before Transpose
+ reshape_q_2, # Reshape after RotaryEmbedding
+ reshape_k_2, # Reshape after RotaryEmbedding
+ reshape_v_2, # Reshape after Transpose
+ reshape_v_1, # Reshape before Transpose
+ add_qk, # Add before Softmax
+ root_input, # Root input to attention subgraph
+ ):
+ # Return True only when the shape-computation paths feeding the given Reshape nodes and the
+ # attention-mask path feeding `add_qk` all match the patterns checked below; return False
+ # as soon as any pattern fails to match. The two Gather nodes found in check #1 serve as the
+ # reference that the q, k and v paths are compared against (by node name).
+ #
+ # Check #1: check paths for qkv nodes
+ concat_qkv_2_path = self.model.match_parent_path(reshape_qkv_2, ["Concat"], [1])
+ concat_qkv_1_path = self.model.match_parent_path(reshape_qkv_1, ["Concat"], [1])
+ if concat_qkv_2_path is None or concat_qkv_1_path is None:
+ return False
+ concat_qkv_2, concat_qkv_1 = concat_qkv_2_path[0], concat_qkv_1_path[0]
+
+ reshape_qkv_2_path_1 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
+ reshape_qkv_2_path_2 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
+ reshape_qkv_1_path_1 = self.model.match_parent_path(concat_qkv_1, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
+ reshape_qkv_1_path_2 = self.model.match_parent_path(concat_qkv_1, ["Unsqueeze", "Gather", "Shape"], [2, 0, 0])
+ if (
+ reshape_qkv_2_path_1 is None
+ or reshape_qkv_2_path_2 is None
+ or reshape_qkv_1_path_1 is None
+ or reshape_qkv_1_path_2 is None
+ ):
+ return False
+
+ # Reference Gather/Shape nodes used by every later comparison
+ _, gather_1, shape_1 = reshape_qkv_2_path_1
+ _, gather_2, shape_2 = reshape_qkv_2_path_2
+
+ # Check root_input --> Shape --> Gather connection
+ if shape_1.input[0] != root_input or shape_2.input[0] != root_input:
+ return False
+
+ # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_qkv_1_path_1 and reshape_qkv_1_path_2
+ if reshape_qkv_1_path_1[1].name != gather_1.name or reshape_qkv_1_path_2[1].name != gather_2.name:
+ return False
+
+ # Check #2: check paths for v nodes
+ concat_v_2_path = self.model.match_parent_path(reshape_v_2, ["Concat"], [1])
+ concat_v_1_path = self.model.match_parent_path(reshape_v_1, ["Concat"], [1])
+ if concat_v_2_path is None or concat_v_1_path is None:
+ return False
+ concat_v_2, concat_v_1 = concat_v_2_path[0], concat_v_1_path[0]
+
+ reshape_v_2_path_1 = self.model.match_parent_path(
+ concat_v_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0]
+ )
+ reshape_v_2_path_2 = self.model.match_parent_path(
+ concat_v_2, ["Unsqueeze", "Add", "Gather", "Shape"], [1, 0, 0, 0]
+ )
+ reshape_v_1_path_1 = self.model.match_parent_path(concat_v_1, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
+ reshape_v_1_path_2 = self.model.match_parent_path(concat_v_1, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
+ if (
+ reshape_v_2_path_1 is None
+ or reshape_v_2_path_2 is None
+ or reshape_v_1_path_1 is None
+ or reshape_v_1_path_2 is None
+ ):
+ return False
+
+ # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_v_2_path_1
+ # Check Gather --> Add --> Unsqueeze --> Concat --> Reshape connection for reshape_v_2_path_2
+ # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_v_1_path_1 and reshape_v_1_path_2
+ if (
+ reshape_v_2_path_1[2].name != gather_1.name
+ or reshape_v_2_path_2[2].name != gather_2.name
+ or reshape_v_1_path_1[1].name != gather_1.name
+ or reshape_v_1_path_2[1].name != gather_2.name
+ ):
+ return False
+
+ # Check #3: check paths for k nodes
+ concat_k_2_path = self.model.match_parent_path(reshape_k_2, ["Concat"], [1])
+ if concat_k_2_path is None:
+ return False
+ concat_k_2 = concat_k_2_path[0]
+
+ reshape_k_2_path_1 = self.model.match_parent_path(
+ concat_k_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0]
+ )
+ reshape_k_2_path_2 = self.model.match_parent_path(
+ concat_k_2, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 0, 0]
+ )
+ if reshape_k_2_path_1 is None or reshape_k_2_path_2 is None:
+ return False
+
+ # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_k_2_path_1
+ # Check Gather --> Add --> Unsqueeze --> Concat --> Reshape connection for reshape_k_2_path_2
+ if reshape_k_2_path_1[2].name != gather_1.name or reshape_k_2_path_2[2].name != gather_2.name:
+ return False
+
+ # Check #4: check paths for q nodes
+ concat_q_2_path = self.model.match_parent_path(reshape_q_2, ["Concat"], [1])
+ if concat_q_2_path is None:
+ return False
+ concat_q_2 = concat_q_2_path[0]
+
+ reshape_q_2_path_1 = self.model.match_parent_path(
+ concat_q_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0]
+ )
+ reshape_q_2_path_2 = self.model.match_parent_path(concat_q_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
+ if reshape_q_2_path_1 is None or reshape_q_2_path_2 is None:
+ return False
+
+ # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_q_2_path_1
+ # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_q_2_path_2
+ if reshape_q_2_path_1[2].name != gather_1.name or reshape_q_2_path_2[1].name != gather_2.name:
+ return False
+
+ # Check #5: check Mul nodes are the same for q, k, v
+ # (each Mul must take the output of the reference gather_1 as its first input)
+ mul_q = reshape_q_2_path_1[1]
+ mul_k = reshape_k_2_path_1[1]
+ mul_v = reshape_v_2_path_1[1]
+ gather_1_out = gather_1.output[0]
+ if mul_q.input[0] != gather_1_out or mul_k.input[0] != gather_1_out or mul_v.input[0] != gather_1_out:
+ return False
+
+ # Check #6: check paths for attention mask nodes
+ # Two variants are accepted: with or without a Cast directly before the Add.
+ attn_mask_path_1 = self.model.match_parent_path(add_qk, ["Concat", "Slice", "Slice"], [1, 0, 0])
+ attn_mask_path_2 = self.model.match_parent_path(add_qk, ["Cast", "Concat", "Slice", "Slice"], [1, 0, 0, 0])
+ if attn_mask_path_1 is not None:
+ _, slice_qk_2, slice_qk_1 = attn_mask_path_1
+ elif attn_mask_path_2 is not None:
+ _, _, slice_qk_2, slice_qk_1 = attn_mask_path_2
+ else:
+ return False
+ # Check first input to Slice #1 is 3D attention mask of shape (B,S,T)
+ # NOTE(review): this is a match on two hard-coded tensor names, not on the tensor's shape.
+ if slice_qk_1.input[0] not in {"attn_mask", "attention_mask"}:
+ return False
+
+ slice_qk_2_path = self.model.match_parent_path(
+ slice_qk_2, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 1, 0]
+ )
+ slice_qk_1_path_1 = self.model.match_parent_path(
+ slice_qk_1, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 1, 0]
+ )
+ slice_qk_1_path_2 = self.model.match_parent_path(slice_qk_1, ["Unsqueeze"], [1])
+ if slice_qk_2_path is None or slice_qk_1_path_1 is None or slice_qk_1_path_2 is None:
+ return False
+
+ # Check Gather --> Add --> Unsqueeze #3 --> Slice #2 connection for slice_qk_2_path
+ # Check Gather --> Add --> Unsqueeze #2 --> Slice #1 connection for slice_qk_1_path_1
+ if slice_qk_2_path[1].name != slice_qk_1_path_1[1].name or slice_qk_2_path[2].name != slice_qk_1_path_1[2].name:
+ return False
+
+ # Check Unsqueeze #1 --> Slice #1 connection for slice_qk_1_path_2
+ # Check if first input to Add and Unsqueeze #1 is position ids
+ if slice_qk_1_path_1[1].input[0] != slice_qk_1_path_2[0].input[0]:
+ return False
+
+ return True
+
+ def check_runtime_shape_paths_for_nodes(
+ self,
+ reshape_qkv, # Final reshape before o_proj MatMul
+ reshape_q, # Reshape before q_proj MatMul
+ reshape_k, # Reshape before k_proj MatMul
+ reshape_v, # Reshape before v_proj MatMul
+ root_input, # Root input to attention subgraph
+ ):
+ # Check #1: check paths for qkv nodes
+ concat_qkv_path = self.model.match_parent_path(reshape_qkv, ["Concat"], [1])
+ if concat_qkv_path is None:
+ return False
+ concat_qkv = concat_qkv_path[0]
+
+ reshape_qkv_path_1 = self.model.match_parent_path(concat_qkv, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
+ reshape_qkv_path_2 = self.model.match_parent_path(concat_qkv, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
+ if reshape_qkv_path_1 is None or reshape_qkv_path_2 is None:
+ return False
+
+ _, gather_1, shape_1 = reshape_qkv_path_1
+ _, gather_2, shape_2 = reshape_qkv_path_2
+
+ # Check root_input --> Shape --> Gather connection
+ if shape_1.input[0] != root_input or shape_2.input[0] != root_input:
+ return False
+
+ # Check #2: check paths for v nodes
+ concat_v_path = self.model.match_parent_path(reshape_v, ["Concat"], [1])
+ if concat_v_path is None:
+ return False
+ concat_v = concat_v_path[0]
+
+ reshape_v_path_1 = self.model.match_parent_path(concat_v, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
+ reshape_v_path_2 = self.model.match_parent_path(concat_v, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
+ if reshape_v_path_1 is None or reshape_v_path_2 is None:
+ return False
+
+ # Check Gather --> Unsqueeze --> Concat --> Reshape connection
+ if reshape_v_path_1[1].name != gather_1.name or reshape_v_path_2[1].name != gather_2.name:
+ return False
+
+ # Check #3: check paths for k nodes
+ concat_k_path = self.model.match_parent_path(reshape_k, ["Concat"], [1])
+ if concat_k_path is None:
+ return False
+ concat_k = concat_k_path[0]
+
+ reshape_k_path_1 = self.model.match_parent_path(concat_k, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
+ reshape_k_path_2 = self.model.match_parent_path(concat_k, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
+ if reshape_k_path_1 is None or reshape_k_path_2 is None:
+ return False
+
+ # Check Gather --> Unsqueeze --> Concat --> Reshape connection
+ if reshape_k_path_1[1].name != gather_1.name or reshape_k_path_2[1].name != gather_2.name:
+ return False
+
+ # Check #4: check paths for q nodes
+ concat_q_path = self.model.match_parent_path(reshape_q, ["Concat"], [1])
+ if concat_q_path is None:
+ return False
+ concat_q = concat_q_path[0]
+
+ reshape_q_path_1 = self.model.match_parent_path(concat_q, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
+ reshape_q_path_2 = self.model.match_parent_path(concat_q, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
+ if reshape_q_path_1 is None or reshape_q_path_2 is None:
+ return False
+
+ # Check Gather --> Unsqueeze --> Concat --> Reshape connection
+ if reshape_q_path_1[1].name != gather_1.name or reshape_q_path_2[1].name != gather_2.name:
+ return False
+
+ return True
+
+    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
+        """Try to fuse the attention subgraph that feeds `normalize_node` into one fused node.
+
+        Starting at `normalize_node`, parent paths are matched for the output projection
+        (qkv), the v path, the QxK' path, the attention mask, the k path and the q path.
+        Each of them has one pattern per supported LLaMA-2 export. When all of them match
+        and are mutually consistent, the two RotaryEmbedding nodes are rewired to read the
+        q/k MatMul outputs directly, the node returned by `self.create_mha_node` is queued
+        in `self.nodes_to_add`, and the matched intermediate nodes are queued in
+        `self.nodes_to_remove`. Any failed match returns early without queuing anything.
+
+        Args:
+            normalize_node: candidate anchor node; only "SkipSimplifiedLayerNormalization"
+                and "Add" nodes are considered.
+            input_name_to_nodes: not used by this method.
+            output_name_to_node: not used by this method.
+        """
+        if normalize_node.op_type != "SkipSimplifiedLayerNormalization" and normalize_node.op_type != "Add":
+            return
+
+        # qkv_nodes_1 is for LLaMA-2 Microsoft
+        # qkv_nodes_2 is for LLaMA-2 Hugging Face
+        qkv_nodes = None
+        qkv_nodes_1 = self.model.match_parent_path(
+            normalize_node,
+            ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
+            [1, 0, 0, 0, 0],
+        )
+        qkv_nodes_2 = self.model.match_parent_path(
+            normalize_node,
+            ["MatMul", "Reshape", "Transpose", "MatMul"],
+            [1, 0, 0, 0],
+        )
+        if qkv_nodes_1 is not None:
+            _, reshape_qkv_2, _, reshape_qkv_1, matmul_qkv = qkv_nodes_1
+            qkv_nodes = qkv_nodes_1
+        elif qkv_nodes_2 is not None:
+            _, reshape_qkv, _, matmul_qkv = qkv_nodes_2
+            qkv_nodes = qkv_nodes_2
+        else:
+            logger.debug("fuse_rotary_attention: failed to match qkv nodes")
+            return
+
+        # v_nodes_1 is for LLaMA-2 Microsoft
+        # v_nodes_3 is for LLaMA-2 Hugging Face
+        past_v, present_v, past_seq_len = "", "", ""
+        v_nodes = None
+        v_nodes_1 = self.model.match_parent_path(
+            matmul_qkv,
+            ["Reshape", "Transpose", "Concat", "Transpose", "Reshape", "MatMul"],
+            [1, 0, 0, 1, 0, 0],
+        )
+        v_nodes_2 = self.model.match_parent_path(
+            matmul_qkv,
+            ["Concat", "Transpose", "Reshape", "MatMul"],
+            [1, 1, 0, 0],
+        )
+        v_nodes_3 = self.model.match_parent_path(
+            matmul_qkv,
+            ["Transpose", "Reshape", "MatMul"],
+            [1, 0, 0],
+        )
+        if v_nodes_1 is not None:
+            reshape_v_2, _, concat_v, _, reshape_v_1, matmul_v = v_nodes_1
+            v_nodes = v_nodes_1
+
+            concat_v_path = self.model.match_parent_path(
+                concat_v,
+                ["Slice", "Unsqueeze"],
+                [0, 2],
+            )
+            if concat_v_path is None:
+                logger.debug("fuse_rotary_attention: failed to match past/present concat in v path")
+                return
+
+            past_v = concat_v_path[0].input[0]
+            past_seq_len = concat_v_path[-1].input[0]
+            present_v = concat_v.output[0]
+        elif v_nodes_2 is not None:
+            concat_v, transpose_v, reshape_v, matmul_v = v_nodes_2
+            v_nodes = v_nodes_2
+            past_v = concat_v.input[0]
+            present_v = concat_v.output[0]
+        elif v_nodes_3 is not None:
+            transpose_v, reshape_v, matmul_v = v_nodes_3
+            v_nodes = v_nodes_3
+            present_v = transpose_v.output[0]
+        else:
+            logger.debug("fuse_rotary_attention: failed to match v path")
+            return
+
+        qk_nodes = self.model.match_parent_path(
+            matmul_qkv,
+            ["Softmax", "Add", "Div", "MatMul"],
+            [0, 0, 0, 0],
+        )
+        add_qk, matmul_qk = None, None
+        if qk_nodes is not None:
+            _, add_qk, _, matmul_qk = qk_nodes
+        else:
+            logger.debug("fuse_rotary_attention: failed to match qk nodes")
+            return
+
+        # attn_mask_nodes_1, attn_mask_nodes_2 are for LLaMA-2 Microsoft's 3D attention mask
+        # attn_mask_nodes_3, attn_mask_nodes_4 are for LLaMA-2 Hugging Face's 2D attention mask
+        attn_mask, add_qk_str = "", ""
+        attn_mask_nodes_1 = self.model.match_parent_path(
+            add_qk,
+            ["Concat", "Slice", "Slice"],
+            [1, 0, 0],
+        )
+        attn_mask_nodes_2 = self.model.match_parent_path(
+            add_qk,
+            ["Cast", "Concat", "Slice", "Slice"],
+            [1, 0, 0, 0],
+        )
+        attn_mask_nodes_3 = self.model.match_parent_path(
+            add_qk,
+            ["Add", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
+            [1, 0, 2, 1, 0, 0, 0],
+        )
+        attn_mask_nodes_4 = self.model.match_parent_path(
+            add_qk,
+            ["Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
+            [1, 2, 1, 0, 0, 0],
+        )
+        if attn_mask_nodes_1 is not None:
+            _, slice_mask_1, slice_mask_2 = attn_mask_nodes_1
+            attn_mask = slice_mask_1.output[0]
+        elif attn_mask_nodes_2 is not None:
+            _, _, slice_mask_1, slice_mask_2 = attn_mask_nodes_2
+            attn_mask = slice_mask_1.output[0]
+        elif attn_mask_nodes_3 is not None:
+            # Reshape from (B,1,S,T) to (B,N,S,T)
+            add_qk_str = self.reshape_add_qk(attn_mask_nodes_3[0].output[0])
+        elif attn_mask_nodes_4 is not None:
+            # Reshape from (B,1,S,T) to (B,N,S,T)
+            add_qk_str = self.reshape_add_qk(attn_mask_nodes_4[0].output[0])
+        else:
+            logger.debug("fuse_rotary_attention: failed to match attention mask nodes")
+            return
+
+        # k_nodes_1 is for LLaMA-2 Microsoft
+        # k_nodes_2 is for LLaMA-2 Hugging Face
+        past_k, present_k = "", ""
+        k_nodes = None
+        k_nodes_1 = self.model.match_parent_path(
+            matmul_qk,
+            ["Reshape", "Transpose", "Concat", "Transpose", "RotaryEmbedding", "MatMul"],
+            [1, 0, 0, 1, 0, 0],
+        )
+        k_nodes_2 = self.model.match_parent_path(
+            matmul_qk,
+            ["Transpose", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
+            [1, 0, 0, 0, 0],
+        )
+        k_nodes_3 = self.model.match_parent_path(
+            matmul_qk,
+            ["Transpose", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
+            [1, 0, 1, 0, 0, 0],
+        )
+        if k_nodes_1 is not None:
+            reshape_k_2, _, concat_k, _, rotary_k, matmul_k = k_nodes_1
+            k_nodes = k_nodes_1
+
+            concat_k_path = self.model.match_parent_path(
+                concat_k,
+                ["Slice", "Unsqueeze"],
+                [0, 2],
+            )
+            if concat_k_path is None:
+                logger.debug("fuse_rotary_attention: failed to match past/present concat in k path")
+                return
+
+            past_k = concat_k_path[0].input[0]
+            shared_past_seq_len = concat_k_path[-1].input[0]
+            present_k = concat_k.output[0]
+
+            # The k and v caches must be sliced with the same past sequence length input.
+            # Skip the fusion instead of asserting: asserts are stripped under `python -O`
+            # and every other mismatch in this method is a silent early return.
+            if past_seq_len != shared_past_seq_len:
+                logger.debug("fuse_rotary_attention: failed to match past sequence length in k and v paths")
+                return
+        elif k_nodes_2 is not None:
+            _, rotary_k, _, reshape_k, matmul_k = k_nodes_2
+            k_nodes = k_nodes_2
+            present_k = rotary_k.output[0]
+        elif k_nodes_3 is not None:
+            _, concat_k, rotary_k, _, reshape_k, matmul_k = k_nodes_3
+            k_nodes = k_nodes_3
+            past_k = concat_k.input[0]
+            present_k = concat_k.output[0]
+        else:
+            logger.debug("fuse_rotary_attention: failed to match k nodes")
+            return
+
+        # q_nodes_1 is for LLaMA-2 Microsoft
+        # q_nodes_2 is for LLaMA-2 Hugging Face
+        q_nodes = None
+        q_nodes_1 = self.model.match_parent_path(
+            matmul_qk,
+            ["Reshape", "Transpose", "RotaryEmbedding", "MatMul"],
+            [0, 0, 0, 0],
+        )
+        q_nodes_2 = self.model.match_parent_path(
+            matmul_qk,
+            ["RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
+            [0, 0, 0, 0],
+        )
+        if q_nodes_1 is not None:
+            reshape_q_2, _, rotary_q, matmul_q = q_nodes_1
+            q_nodes = q_nodes_1
+        elif q_nodes_2 is not None:
+            rotary_q, _, reshape_q, matmul_q = q_nodes_2
+            q_nodes = q_nodes_2
+        else:
+            logger.debug("fuse_rotary_attention: failed to match q nodes")
+            return
+
+        # All three projections must read the same tensor, so a mismatch in either pair is a failure.
+        if matmul_q.input[0] != matmul_k.input[0] or matmul_k.input[0] != matmul_v.input[0]:
+            logger.debug("fuse_rotary_attention: failed to find the same root_input for q, k, v paths")
+            return
+
+        # NOTE(review): reshape_q_2 / reshape_k_2 / reshape_v_2 / reshape_v_1 are only bound when the
+        # *_nodes_1 patterns matched, and reshape_q / reshape_k / reshape_v only for the other patterns.
+        # A graph mixing the two export styles would raise UnboundLocalError below - confirm that
+        # such a mix cannot occur.
+        root_output = ""
+        if qkv_nodes == qkv_nodes_1:
+            if not self.check_runtime_shape_paths_for_function(
+                reshape_qkv_2,
+                reshape_qkv_1,
+                reshape_q_2,
+                reshape_k_2,
+                reshape_v_2,
+                reshape_v_1,
+                add_qk,
+                matmul_q.input[0],
+            ):
+                logger.debug("fuse_rotary_attention: failed to verify runtime shape paths")
+                return
+            root_output = reshape_qkv_2.output[0]
+
+        elif qkv_nodes == qkv_nodes_2:
+            if not self.check_runtime_shape_paths_for_nodes(
+                reshape_qkv,
+                reshape_q,
+                reshape_k,
+                reshape_v,
+                matmul_q.input[0],
+            ):
+                logger.debug("fuse_rotary_attention: failed to verify runtime shape paths")
+                return
+            root_output = reshape_qkv.output[0]
+
+        # Rename inputs of rotary_q/k so it connects with output of matmul_q/k
+        # Before: MatMul --> Reshape --> Transpose --> RotaryEmbedding
+        # After: MatMul --> RotaryEmbedding
+        rotary_q.input[0] = matmul_q.output[0]
+        rotary_k.input[0] = matmul_k.output[0]
+
+        # Rename current output of rotary_k (present_key) so it doesn't match output of MHA (present_key)
+        rotary_k.output[0] = rotary_k.name + "_output_0"
+
+        new_node = self.create_mha_node(
+            matmul_q.input[0],
+            root_output,
+            rotary_q,
+            rotary_k,
+            matmul_v,
+            attn_mask,
+            add_qk_str,
+            past_k,
+            past_v,
+            present_k,
+            present_v,
+        )
+        if new_node is None:
+            logger.debug("fuse_rotary_attention: failed to create multi-head attention with rotary embeddings")
+            return
+
+        self.nodes_to_add.append(new_node)
+        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+
+        # Keep the output-projection MatMul (first qkv node) and the v MatMul (last v node).
+        self.nodes_to_remove.extend(qkv_nodes[1:])
+        self.nodes_to_remove.extend(v_nodes[:-1])
+        self.nodes_to_remove.extend(qk_nodes)
+
+        # In the k and q paths the RotaryEmbedding and MatMul nodes are kept; only the
+        # Reshape/Transpose/Concat nodes around them are removed.
+        if k_nodes == k_nodes_1:
+            self.nodes_to_remove.extend(k_nodes[:-2])
+        elif k_nodes == k_nodes_2:
+            self.nodes_to_remove.append(k_nodes[0])
+            self.nodes_to_remove.append(k_nodes[2])
+            self.nodes_to_remove.append(k_nodes[3])
+        elif k_nodes == k_nodes_3:
+            self.nodes_to_remove.append(k_nodes[0])
+            self.nodes_to_remove.append(k_nodes[1])
+            self.nodes_to_remove.append(k_nodes[3])
+            self.nodes_to_remove.append(k_nodes[4])
+
+        if q_nodes == q_nodes_1:
+            self.nodes_to_remove.extend(q_nodes[:-2])
+        elif q_nodes == q_nodes_2:
+            self.nodes_to_remove.append(q_nodes[1])
+            self.nodes_to_remove.append(q_nodes[2])
+
+        self.prune_graph = True
+
+
+class FusionRotaryEmbeddings(Fusion):
+    """Replace rotary-embedding subgraphs with a single `com.microsoft` RotaryEmbedding node.
+
+    Two source forms are handled by `fuse`:
+      * a node whose op type contains "RotaryEmbedding" (a module exported as an ONNX function),
+        rebuilt with `interleaved=1`;
+      * the decomposed `x * cos + rotate_half(x) * sin` graph ending in an Add node,
+        rebuilt with `interleaved=0`.
+    In both forms the cos/sin caches are stored as the initializers "cos_cache" and "sin_cache".
+    """
+
+    def __init__(self, model: OnnxModel):
+        self.base_name = "RotaryEmbedding"
+        super().__init__(model, self.base_name, [self.base_name, self.base_name + ".1", "Add"])
+
+    # The RotaryEmbedding function can have multiple extraneous constant outputs even though the function is supposed to produce only one output.
+    # This is a byproduct of a potential CSE bug when using `export_modules_as_functions` in the TorchScript exporter.
+    # To work around this issue, we set the extraneous constant values from the RotaryEmbedding function as initializers in the locations where they are actually used.
+    def reassign_extra_outputs(self, rot_emb_node: NodeProto, function: FunctionProto):
+        """Turn constant outputs of a RotaryEmbedding function into graph initializers.
+
+        Args:
+            rot_emb_node: the graph node that calls `function`.
+            function: the function definition whose body is scanned for input-less
+                Constant nodes that are exposed as function outputs.
+
+        Returns:
+            The output names of `rot_emb_node` that correspond to those constants.
+        """
+        # Find extra outputs and Constant nodes attached to those outputs
+        extra_constants, extra_outputs = [], []
+        for fn_node in function.node:
+            if fn_node.op_type == "Constant" and fn_node.input == [] and fn_node.output[0] in function.output:
+                extra_constants.append(fn_node)
+                # Function outputs and call-site outputs are positional, so the index maps one to the other
+                output_index = list(function.output).index(fn_node.output[0])
+                extra_outputs.append(rot_emb_node.output[output_index])
+
+        # Set extra Constant node outputs as initializers
+        extra_initializers = []
+        for extra_constant in extra_constants:
+            # The tensor stored in the function body is renamed in place before being registered
+            constant_tensorproto = extra_constant.attribute[0].t
+            constant_tensorproto.name = self.model.create_node_name("Constant")
+            self.model.add_initializer(constant_tensorproto)
+            extra_initializers.append(constant_tensorproto.name)
+
+        # Update references of Constant node outputs to initializer references
+        for extra_output, extra_initializer in zip(extra_outputs, extra_initializers):
+            nodes_to_update = list(filter(lambda entry: extra_output in entry.input, self.model.model.graph.node))
+            for node_to_update in nodes_to_update:
+                OnnxModel.replace_node_input(node_to_update, extra_output, extra_initializer)
+
+        return extra_outputs
+
+    def create_rotary_embeddings_from_function(self, node: NodeProto):
+        """Build a RotaryEmbedding node (interleaved=1) from a RotaryEmbedding function call.
+
+        Args:
+            node: the function-call node; input 1 is used as position ids and inputs 2 and 3
+                are looked up as the producers of the cos and sin caches.
+
+        Returns:
+            The new node, or None when input 0 of `node` is not produced by MatMul --> Reshape.
+        """
+        rotary_emb_node_name = self.model.create_node_name(self.base_name)
+
+        matmul_path = self.model.match_parent_path(
+            node,
+            ["Reshape", "MatMul"],
+            [0, 0],
+        )
+        if matmul_path is not None:
+            reshape_node, matmul_node = matmul_path
+        else:
+            logger.debug("fuse_rotary_embeddings: failed to match MatMul")
+            return
+
+        rotary_emb_inputs = [
+            matmul_node.output[0],  # x is of shape (B,S,D) instead of (B,S,N,H)
+            node.input[1],  # position_ids
+        ]
+
+        # Convert cos_cache and sin_cache from node attributes to model initializers
+        cos_cache_node = list(filter(lambda constant: constant.output[0] == node.input[2], self.model.model.graph.node))
+        sin_cache_node = list(filter(lambda constant: constant.output[0] == node.input[3], self.model.model.graph.node))
+        cos_cache_name, sin_cache_name = "cos_cache", "sin_cache"
+
+        # The initializers are created once; later calls find them by name and skip this block
+        if (
+            len(cos_cache_node) == 1
+            and len(sin_cache_node) == 1
+            and self.model.get_initializer(cos_cache_name) is None
+            and self.model.get_initializer(sin_cache_name) is None
+        ):
+            cos_cache = numpy_helper.to_array(cos_cache_node[0].attribute[0].t).squeeze()
+            sin_cache = numpy_helper.to_array(sin_cache_node[0].attribute[0].t).squeeze()
+
+            cos_cache_tensor = helper.make_tensor(
+                name=cos_cache_name,
+                data_type=TensorProto.FLOAT,
+                dims=list(cos_cache.shape),
+                vals=cos_cache.flatten().tolist(),
+            )
+            self.model.add_initializer(cos_cache_tensor, self.this_graph_name)
+            sin_cache_tensor = helper.make_tensor(
+                name=sin_cache_name,
+                data_type=TensorProto.FLOAT,
+                dims=list(sin_cache.shape),
+                vals=sin_cache.flatten().tolist(),
+            )
+            self.model.add_initializer(sin_cache_tensor, self.this_graph_name)
+
+            self.nodes_to_remove.extend([cos_cache_node[0], sin_cache_node[0]])
+
+        # NOTE(review): the names are referenced even when the block above was skipped; this
+        # presumably relies on the initializers already existing from an earlier fusion - confirm.
+        rotary_emb_inputs.extend([cos_cache_name, sin_cache_name])
+
+        rotary_emb_outputs = node.output
+        if len(rotary_emb_outputs) > 1:
+            # Re-assign extraneous constant outputs in RotaryEmbedding functions as initializers
+            func = list(filter(lambda fn: fn.name == node.op_type, self.model.model.functions))
+            assert len(func) == 1
+            extra_outputs = self.reassign_extra_outputs(node, func[0])
+            rotary_emb_outputs = list(filter(lambda output_name: output_name not in extra_outputs, rotary_emb_outputs))
+            assert len(rotary_emb_outputs) == 1
+
+        rotary_emb_node = helper.make_node(
+            self.base_name,
+            inputs=rotary_emb_inputs,
+            outputs=rotary_emb_outputs,
+            name=rotary_emb_node_name,
+            interleaved=1,
+        )
+        rotary_emb_node.domain = "com.microsoft"
+
+        # The new node reads the MatMul output directly, so the Reshape in between is dropped
+        self.nodes_to_remove.append(reshape_node)
+
+        return rotary_emb_node
+
+    def create_rotary_embeddings_from_nodes(
+        self,
+        root_input: str,
+        position_ids: str,
+        cos_slice: str,
+        sin_slice: str,
+        output: str,
+    ):
+        """Build a RotaryEmbedding node (interleaved=0) for a decomposed rotary subgraph.
+
+        Args:
+            root_input: name of the tensor the rotary embedding is applied to.
+            position_ids: name of the position ids tensor.
+            cos_slice: output name of the node holding the cos cache.
+            sin_slice: output name of the node holding the sin cache.
+            output: output name to give the new node.
+
+        Returns:
+            The new RotaryEmbedding node.
+        """
+        rotary_emb_node_name = self.model.create_node_name(self.base_name)
+
+        # Convert cos_cache and sin_cache from node attributes to model initializers
+        cos_cache_node = list(filter(lambda constant: constant.output[0] == cos_slice, self.model.model.graph.node))
+        sin_cache_node = list(filter(lambda constant: constant.output[0] == sin_slice, self.model.model.graph.node))
+        cos_cache_name, sin_cache_name = "cos_cache", "sin_cache"
+
+        # The initializers are created once; later calls find them by name and skip this block
+        if (
+            len(cos_cache_node) == 1
+            and len(sin_cache_node) == 1
+            and self.model.get_initializer(cos_cache_name) is None
+            and self.model.get_initializer(sin_cache_name) is None
+        ):
+            cos_cache = numpy_helper.to_array(cos_cache_node[0].attribute[0].t).squeeze()
+            sin_cache = numpy_helper.to_array(sin_cache_node[0].attribute[0].t).squeeze()
+
+            # Reshape cos/sin cache from (M, H) to (M, H/2)
+            head_size = cos_cache.shape[1]
+            cos_cache = cos_cache[:, : (head_size // 2)]
+            sin_cache = sin_cache[:, : (head_size // 2)]
+
+            cos_cache_tensor = helper.make_tensor(
+                name=cos_cache_name,
+                data_type=TensorProto.FLOAT,
+                dims=list(cos_cache.shape),
+                vals=cos_cache.flatten().tolist(),
+            )
+            self.model.add_initializer(cos_cache_tensor, self.this_graph_name)
+            sin_cache_tensor = helper.make_tensor(
+                name=sin_cache_name,
+                data_type=TensorProto.FLOAT,
+                dims=list(sin_cache.shape),
+                vals=sin_cache.flatten().tolist(),
+            )
+            self.model.add_initializer(sin_cache_tensor, self.this_graph_name)
+
+            self.nodes_to_remove.extend([cos_cache_node[0], sin_cache_node[0]])
+
+        rotary_emb_node = helper.make_node(
+            self.base_name,
+            inputs=[root_input, position_ids, cos_cache_name, sin_cache_name],
+            outputs=[output],
+            name=rotary_emb_node_name,
+            interleaved=0,
+        )
+        rotary_emb_node.domain = "com.microsoft"
+        return rotary_emb_node
+
+    def fuse(self, node, input_name_to_nodes, output_name_to_node):
+        """Fuse the rotary embedding computation that ends at `node`, if one is recognized.
+
+        Args:
+            node: either a RotaryEmbedding function call or the final Add of a decomposed
+                rotary embedding.
+            input_name_to_nodes: not used by this method.
+            output_name_to_node: not used by this method.
+        """
+        # Node is either RotaryEmbedding function or Add
+        if self.base_name not in node.op_type and node.op_type != "Add":
+            return
+
+        # Check if node is "RotaryEmbedding nn.Module" exported as a function
+        # (e.g. export_modules_as_functions={RotaryEmbedding} in torch.onnx.export)
+        rotary_emb_node = None
+        if node.op_type != "Add":
+            # Verify that function has the correct inputs
+            if len(node.input) not in {4, 5} or node.input[1] not in {
+                "pos",
+                "pos_id",
+                "position_id",
+                "pos_ids",
+                "position_ids",
+            }:
+                logger.debug("fuse_rotary_embeddings: failed to verify inputs for RotaryEmbedding function")
+                return
+
+            rotary_emb_node = self.create_rotary_embeddings_from_function(node)
+            if rotary_emb_node is None:
+                logger.debug("fuse_rotary_embeddings: failed to create RotaryEmbedding node")
+                return
+
+            # Remove RotaryEmbedding function
+            self.nodes_to_remove.append(node)
+
+            # Remove RotaryEmbedding function's shape inference stored in value_info
+            # The new shape will be calculated during symbolic shape inference
+            old_shape_infer = list(
+                filter(
+                    lambda value_info: value_info.name == rotary_emb_node.output[0],
+                    self.model.model.graph.value_info,
+                )
+            )
+            assert len(old_shape_infer) == 1
+            self.model.model.graph.value_info.remove(old_shape_infer[0])
+
+        else:
+            # Rotary embeddings are defined using the below functions:
+            #
+            # def rotate_half(x):
+            #     """Rotates half the hidden dims of the input."""
+            #     x1 = x[..., : x.shape[-1] // 2]
+            #     x2 = x[..., x.shape[-1] // 2 :]
+            #     return torch.cat((-x2, x1), dim=-1)
+            #
+            # def apply_rope(x, cos, sin, position_ids):
+            #     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
+            #     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
+            #     cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+            #     sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+            #     x_embed = (x * cos) + (rotate_half(x) * sin)
+            #     return x_embed
+
+            # Check paths for rotate_half(x)
+            rotate_half_x2_path_1 = self.model.match_parent_path(
+                node,
+                ["Mul", "Concat", "Neg", "Slice", "Transpose"],
+                [1, 0, 0, 0, 0],
+            )
+            rotate_half_x2_path_2 = self.model.match_parent_path(
+                node,
+                ["Mul", "Concat", "Neg", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"],
+                [1, 0, 0, 0, 1, 0, 0, 0, 0],
+            )
+            if rotate_half_x2_path_1 is None or rotate_half_x2_path_2 is None:
+                logger.debug("fuse_rotary_embeddings: failed to match x2 in rotate_half")
+                return
+
+            rotate_half_x1_path_1 = self.model.match_parent_path(
+                node,
+                ["Mul", "Concat", "Slice", "Transpose"],
+                [1, 0, 1, 0],
+            )
+            rotate_half_x1_path_2 = self.model.match_parent_path(
+                node,
+                ["Mul", "Concat", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"],
+                [1, 0, 1, 2, 0, 0, 0, 0],
+            )
+            if rotate_half_x1_path_1 is None or rotate_half_x1_path_2 is None:
+                logger.debug("fuse_rotary_embeddings: failed to match x1 in rotate_half")
+                return
+
+            # All four rotate_half paths must end at the same Transpose node
+            if (
+                rotate_half_x1_path_1[-1].name != rotate_half_x1_path_2[-1].name
+                or rotate_half_x2_path_1[-1].name != rotate_half_x2_path_2[-1].name
+                or rotate_half_x1_path_1[-1].name != rotate_half_x2_path_1[-1].name
+                or rotate_half_x1_path_2[-1].name != rotate_half_x2_path_2[-1].name
+            ):
+                logger.debug("fuse_rotary_embeddings: failed to match common input in rotate_half")
+                return
+
+            # Check path for x
+            x_path = self.model.match_parent_path(
+                node,
+                ["Mul", "Transpose"],
+                [0, 0],
+            )
+            if x_path is None:
+                logger.debug("fuse_rotary_embeddings: failed to match x in rotate_half")
+                return
+
+            # Check path for sin
+            sin_path, sin_cache, position_ids = None, "", ""
+            sin_path_1 = self.model.match_parent_path(
+                node,
+                ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Gather", "Shape"],
+                [1, 1, 0, 0, 0, 0, 2, 0, 0],
+            )
+            sin_path_2 = self.model.match_parent_path(
+                node,
+                ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Add"],
+                [1, 1, 0, 0, 0, 0, 2, 0],
+            )
+            sin_path_3 = self.model.match_parent_path(
+                node,
+                ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Gather", "Shape"],
+                [1, 1, 0, 0, 2, 0, 0],
+            )
+            sin_path_4 = self.model.match_parent_path(
+                node,
+                ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Add"],
+                [1, 1, 0, 0, 2, 0],
+            )
+            # In every variant the cache is input 0 of the Slice node; paths 3 and 4 (no Squeeze
+            # nodes) also take position ids from input 1 of the Gather at index 2
+            if sin_path_1 is not None:
+                sin_path = sin_path_1
+                sin_cache = sin_path[-4].input[0]
+            elif sin_path_2 is not None:
+                sin_path = sin_path_2
+                sin_cache = sin_path[-3].input[0]
+            elif sin_path_3 is not None:
+                sin_path = sin_path_3
+                sin_cache = sin_path[-4].input[0]
+                position_ids = sin_path[2].input[1]
+            elif sin_path_4 is not None:
+                sin_path = sin_path_4
+                sin_cache = sin_path[-3].input[0]
+                position_ids = sin_path[2].input[1]
+            else:
+                logger.debug("fuse_rotary_embeddings: failed to match sin path in apply_rope")
+                return
+
+            # Check path for cos
+            cos_path, cos_cache = None, ""
+            cos_path_1 = self.model.match_parent_path(
+                node,
+                ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Gather", "Shape"],
+                [0, 1, 0, 0, 0, 0, 2, 0, 0],
+            )
+            cos_path_2 = self.model.match_parent_path(
+                node,
+                ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Add"],
+                [0, 1, 0, 0, 0, 0, 2, 0],
+            )
+            cos_path_3 = self.model.match_parent_path(
+                node,
+                ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Gather", "Shape"],
+                [0, 1, 0, 0, 2, 0, 0],
+            )
+            cos_path_4 = self.model.match_parent_path(
+                node,
+                ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Add"],
+                [0, 1, 0, 0, 2, 0],
+            )
+            if cos_path_1 is not None:
+                cos_path = cos_path_1
+                cos_cache = cos_path[-4].input[0]
+            elif cos_path_2 is not None:
+                cos_path = cos_path_2
+                cos_cache = cos_path[-3].input[0]
+            elif cos_path_3 is not None:
+                cos_path = cos_path_3
+                cos_cache = cos_path[-4].input[0]
+                position_ids = cos_path[2].input[1]
+            elif cos_path_4 is not None:
+                cos_path = cos_path_4
+                cos_cache = cos_path[-3].input[0]
+                position_ids = cos_path[2].input[1]
+            else:
+                logger.debug("fuse_rotary_embeddings: failed to match cos path in apply_rope")
+                return
+
+            # Check path for position ids
+            if position_ids == "":
+                position_ids_from_sin_path = self.model.match_parent_path(
+                    sin_path[2],
+                    ["Reshape"],
+                    [1],
+                )
+                position_ids_from_cos_path = self.model.match_parent_path(
+                    cos_path[2],
+                    ["Reshape"],
+                    [1],
+                )
+                if (
+                    position_ids_from_sin_path is None
+                    or position_ids_from_cos_path is None
+                    or position_ids_from_sin_path[0].name != position_ids_from_cos_path[0].name
+                ):
+                    logger.debug("fuse_rotary_embeddings: failed to match position ids path in apply_rope")
+                    return
+                position_ids = position_ids_from_cos_path[0].input[0]
+            else:
+                # Empty lists keep the removal calls at the end of this branch uniform
+                position_ids_from_sin_path = []
+                position_ids_from_cos_path = []
+
+            past_seq_len_path, curr_seq_len_path = None, None
+            if (sin_path == sin_path_1 and cos_path == cos_path_1) or (
+                sin_path == sin_path_3 and cos_path == cos_path_3
+            ):
+                if sin_path[-2].name != cos_path[-2].name or sin_path[-1].name != cos_path[-1].name:
+                    logger.debug(
+                        "fuse_rotary_embeddings: failed to match common Gather node and Shape node in sin cache and cos cache"
+                    )
+                    return
+            elif (sin_path == sin_path_2 and cos_path == cos_path_2) or (
+                sin_path == sin_path_4 and cos_path == cos_path_4
+            ):
+                if sin_path[-1].name != cos_path[-1].name:
+                    logger.debug("fuse_rotary_embeddings: failed to match common Add node in sin cache and cos cache")
+                    return
+                # Match past sequence length path: past_key --> Shape --> Gather --> Add
+                past_seq_len_path = self.model.match_parent_path(
+                    sin_path[-1],
+                    ["Gather", "Shape"],
+                    [1, 0],
+                )
+                # Match current sequence length path: transpose_k --> Shape --> Gather --> Add
+                curr_seq_len_path = self.model.match_parent_path(
+                    sin_path[-1],
+                    ["Gather", "Shape", "Transpose"],
+                    [0, 0, 0],
+                )
+                if (
+                    past_seq_len_path is None
+                    or curr_seq_len_path is None
+                    or self.model.find_graph_input(past_seq_len_path[-1].input[0]) is None
+                    or curr_seq_len_path[-1].op_type != "Transpose"
+                ):
+                    logger.debug("fuse_rotary_embeddings: failed to match past_seq_len and curr_seq_len paths")
+                    return
+            else:
+                # The sin and cos paths matched different variants: do not fuse
+                logger.debug("fuse_rotary_embeddings: failed to match common cache paths")
+                return
+
+            rotary_emb_node = self.create_rotary_embeddings_from_nodes(
+                rotate_half_x1_path_1[-1].output[0],
+                position_ids,
+                cos_cache,
+                sin_cache,
+                node.output[0],
+            )
+            if rotary_emb_node is None:
+                logger.debug("fuse_rotary_embeddings: failed to create RotaryEmbedding node")
+                return
+
+            # Remove rotary embedding nodes
+            self.add_nodes_to_remove([node])
+            self.add_nodes_to_remove(rotate_half_x1_path_1[:-1])
+            self.add_nodes_to_remove(rotate_half_x1_path_2[:-1])
+            self.add_nodes_to_remove(rotate_half_x2_path_1[:-1])
+            self.add_nodes_to_remove(rotate_half_x2_path_2[:-1])
+            self.add_nodes_to_remove(x_path[:-1])
+            self.add_nodes_to_remove(sin_path)
+            self.add_nodes_to_remove(cos_path)
+            self.add_nodes_to_remove(position_ids_from_sin_path[:-1])
+            self.add_nodes_to_remove(position_ids_from_cos_path[:-1])
+
+            if past_seq_len_path is not None and len(self.model.get_children(past_seq_len_path[0])) == 1:
+                # In merged HF model, output of Gather in past_seq_len_path is used twice
+                # for past_key_values.0.key and once for other past_key_values
+                self.add_nodes_to_remove(past_seq_len_path)
+            if curr_seq_len_path is not None:
+                self.add_nodes_to_remove(curr_seq_len_path[:-1])
+
+        self.increase_counter(self.base_name)
+        self.node_name_to_graph_name[rotary_emb_node.name] = self.this_graph_name
+        self.nodes_to_add.append(rotary_emb_node)
+        self.prune_graph = True
diff --git a/onnxruntime/python/tools/transformers/fusion_shape.py b/onnxruntime/python/tools/transformers/fusion_shape.py
index 11d6b7a8d3cf4..bc32d78eda66c 100644
--- a/onnxruntime/python/tools/transformers/fusion_shape.py
+++ b/onnxruntime/python/tools/transformers/fusion_shape.py
@@ -48,22 +48,22 @@ def fuse(
input_name_to_nodes: Dict[str, List[NodeProto]],
output_name_to_node: Dict[str, NodeProto],
):
- """
- Smplify subgraph like
-
- (2d_input)
- / \
- Shape shape
- / \
- Gather(indices=0) Gather(indices=1)
- | |
- Unsqueeze(axes=0) Unsqueeze(axes=0)
- \\ /
- Concat
- |
-
- into (2d_input) --> Shape -->
- """
+ #
+ # Simplify subgraph like
+ #
+ # (2d_input)
+ # / \
+ # Shape Shape
+ # / \
+ # Gather(indices=0) Gather(indices=1)
+ # | |
+ # Unsqueeze(axes=0) Unsqueeze(axes=0)
+ # \ /
+ # Concat
+ # |
+ #
+ # into (2d_input) --> Shape -->
+ #
opset_version = self.model.get_opset_version()
inputs = len(concat_node.input)
diff --git a/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py b/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py
new file mode 100644
index 0000000000000..6f35fa5617a39
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py
@@ -0,0 +1,141 @@
+import logging
+from typing import Dict
+
+from fusion_base import Fusion
+from fusion_skiplayernorm import FusionSkipLayerNormalization
+from onnx import helper
+from onnx_model import OnnxModel
+
+logger = logging.getLogger(__name__)
+
+
+class FusionSimplifiedLayerNormalization(Fusion):
+    """Collapse a decomposed simplified layer normalization into one SimplifiedLayerNormalization node."""
+
+    def __init__(self, model: OnnxModel):
+        super().__init__(model, "SimplifiedLayerNormalization", "Mul")
+
+    def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
+        """Fuse the chain ending at the scale `Mul` node when it matches a known pattern.
+
+        The computation being matched is (notation from
+        https://onnx.ai/onnx/operators/onnx__LayerNormalization.html#summary):
+
+            DD = Pow(D, 2)
+            Var = ReduceMean(DD)
+            VarEps = Add(Var, epsilon)
+            StdDev = Sqrt(VarEps)
+            InvStdDev = Div(1, StdDev)
+            Normalized = Mul(D, InvStdDev)
+            NormalizedScaled = Mul(Normalized, Scale)
+
+        `node` is the last Mul. On success the matched nodes are queued for removal and a
+        SimplifiedLayerNormalization node producing the same output is queued for insertion.
+        """
+        if node.op_type != "Mul":
+            return
+
+        find_path = self.model.match_parent_path
+        # Ops shared by every pattern, walking upwards from `node`
+        chain = ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow"]
+
+        # Pattern 1: normalized value is input 1 of `node`; the chain is rooted at an Add
+        path_from_add = find_path(node, chain + ["Add"], [1, 1, 1, 0, 0, 0, 0])
+        # Pattern 2: same as pattern 1 but rooted at a Gather
+        path_from_gather = find_path(node, chain + ["Gather"], [1, 1, 1, 0, 0, 0, 0])
+        # Pattern 3 (LLaMA from Microsoft custom export): same as pattern 1 but the
+        # normalized value is input 0 of `node`
+        path_from_add_swapped = find_path(node, chain + ["Add"], [0, 1, 1, 0, 0, 0, 0])
+        # Pattern 4: same as pattern 3 but the path stops at Pow, whose input must be a graph input
+        path_from_graph_input = find_path(node, list(chain), [0, 1, 1, 0, 0, 0])
+
+        # The first pattern that matched wins, in the order listed above
+        sim_ln_nodes = next(
+            (
+                path
+                for path in (path_from_add, path_from_gather, path_from_add_swapped, path_from_graph_input)
+                if path is not None
+            ),
+            None,
+        )
+        if sim_ln_nodes is None:
+            return
+
+        only_pattern_4 = path_from_add is None and path_from_gather is None and path_from_add_swapped is None
+        add_node = sim_ln_nodes[3]
+        # Pow is the last node of pattern 4 and the second-to-last node otherwise
+        pow_node = sim_ln_nodes[-1] if only_pattern_4 else sim_ln_nodes[-2]
+        # Verify that parent input to Pow node is graph_input
+        if only_pattern_4 and pow_node.input[0] not in self.model.get_graphs_input_names():
+            return
+
+        if sim_ln_nodes in (path_from_add_swapped, path_from_graph_input):
+            layernorm_weight_index = 1
+        else:
+            layernorm_weight_index = 0
+        starts_with_graph_input = sim_ln_nodes == path_from_graph_input
+
+        # The exponent of Pow must be the constant 2.0, supplied as input 1
+        if self.model.find_constant_input(pow_node, 2.0) != 1:
+            return
+
+        # The tensor that is squared must also be the one that gets scaled by the inverse std-dev
+        root_input = pow_node.input[0]
+        if sim_ln_nodes[0].input[0] != root_input:
+            return
+
+        _, add_weight = self.model.get_constant_input(add_node)
+        if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4:
+            logger.warning(f"epsilon value is not expected: {add_weight}")
+            return
+
+        # For patterns 1-3 the last matched node is kept (it is not part of the
+        # normalization); pattern 4 has no such node
+        if starts_with_graph_input:
+            self.nodes_to_remove.extend(sim_ln_nodes)
+        else:
+            self.nodes_to_remove.extend(sim_ln_nodes[:-1])
+        self.nodes_to_remove.append(node)
+
+        normalize_node = helper.make_node(
+            "SimplifiedLayerNormalization",
+            inputs=[root_input, node.input[layernorm_weight_index]],
+            outputs=[node.output[0]],
+            name=self.model.create_node_name("SimplifiedLayerNormalization", name_prefix="LayerNorm"),
+        )
+        normalize_node.attribute.extend(
+            [
+                helper.make_attribute("epsilon", float(add_weight)),
+                helper.make_attribute("axis", -1),
+                helper.make_attribute("stash_type", 1),
+            ]
+        )
+        self.nodes_to_add.append(normalize_node)
+        self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
+
+
+class FusionSkipSimplifiedLayerNormalization(FusionSkipLayerNormalization):
+    """FusionSkipLayerNormalization configured with the simplified layer-norm op names."""
+
+    def __init__(self, model: OnnxModel):
+        # NOTE(review): the two strings are presumably the fused op type and the op type to search
+        # for, as in the Fusion constructors above - confirm against FusionSkipLayerNormalization.
+        super().__init__(model, "SkipSimplifiedLayerNormalization", "SimplifiedLayerNormalization")
+
+    def fuse(self, node, input_name_to_nodes, output_name_to_node):
+        """Delegate to the parent implementation unchanged."""
+        super().fuse(node, input_name_to_nodes, output_name_to_node)
diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md
index b4461a2eadb8c..6057b46667fe6 100644
--- a/onnxruntime/python/tools/transformers/models/llama/README.md
+++ b/onnxruntime/python/tools/transformers/models/llama/README.md
@@ -17,12 +17,31 @@ $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama
To make this option compatible with [Hugging Face's Optimum](https://github.com/huggingface/optimum), you will need to create `config.json` and `generation_config.json` for your model and store them in the same directory as your ONNX models. For example, you can find those JSON files for LLaMA-2 7B on Hugging Face [here](https://huggingface.co/meta-llama/Llama-2-7b-hf).
+As indicated in `requirements.txt`, you will also need to install Optimum from source. Once installed, you will need to modify `ORTModelForCausalLM.forward` in `optimum/optimum/onnxruntime/modeling_decoder.py` as follows:
+
+```
+# Before
+if self.use_cache:
+ if past_key_values is not None:
+ input_ids = input_ids[:, -1:]
+ # Flatten the past_key_values (no need to flatten for models using multi-query attn)
+
+
+# After
+if self.use_cache:
+ if past_key_values is not None:
+ input_ids = input_ids[:, -1:] if past_key_values[0][0].shape[2] != 0 else input_ids
+ # Flatten the past_key_values (no need to flatten for models using multi-query attn)
+```
+
### Option 2: from [Microsoft's custom export](https://github.com/microsoft/Llama-2-Onnx)
Please follow the [README instructions](https://github.com/microsoft/Llama-2-Onnx#before-you-start) in the custom export of LLaMA-2.
### Option 3: from [Hugging Face Optimum](https://github.com/huggingface/optimum)
+Note that this will produce two ONNX models whereas the above two options produce one ONNX model.
+
First, log into the Hugging Face CLI in your terminal:
```
@@ -56,38 +75,81 @@ $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b
```
-Export for FP16
+Export for FP32 CUDA
+```
+# From source:
+$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-gpu --precision fp32 --execution_provider cuda
+
+# From wheel:
+$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-gpu --precision fp32 --execution_provider cuda
+```
+
+Export for FP32 CPU
```
# From source:
-$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16
+$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-cpu --precision fp32 --execution_provider cpu
# From wheel:
-$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16
+$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-cpu --precision fp32 --execution_provider cpu
```
-Export for INT8
+Export for FP16 CUDA
```
# From source:
-$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant
+$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda
# From wheel:
-$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant
+$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda
+```
+
+Export for INT8 CPU (SmoothQuant)
+```
+# From source:
+$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant --execution_provider cpu
+
+# From wheel:
+$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant --execution_provider cpu
```
Note: [Intel's Neural Compressor](https://github.com/intel/neural-compressor) takes time to run the SmoothQuant quantization algorithm on LLMs. On an [Azure Standard_NC24s_v3 VM](https://learn.microsoft.com/en-us/azure/virtual-machines/ncv3-series), it takes about ~30-45 min for each of the exported ONNX models.
+Export for INT8 CPU (DynamicQuant)
+```
+# From source:
+$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method quantize_dynamic --execution_provider cpu
+
+# From wheel:
+$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method quantize_dynamic --execution_provider cpu
+```
+
+Export for INT4 CUDA
+```
+# From source:
+$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-gpu --precision int4 --quantization_method blockwise --execution_provider cuda
+
+# From wheel:
+$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-gpu --precision int4 --quantization_method blockwise --execution_provider cuda
+```
+
+Export for INT4 CPU
+```
+# From source:
+$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-cpu --precision int4 --quantization_method blockwise --execution_provider cpu
+
+# From wheel:
+$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-cpu --precision int4 --quantization_method blockwise --execution_provider cpu
+```
+
## Benchmark LLaMA-2
Here are some examples of how you can benchmark LLaMA-2.
-Note: In the below examples, `PyTorch` refers to running in PyTorch without `torch.compile` and `PyTorch 2.0` refers to running in PyTorch with `torch.compile`.
-
### Variants
-1. PyTorch (without `torch.compile`), FP32
+1. PyTorch without `torch.compile`, FP32
```
python3 -m models.llama.benchmark \
- --benchmark-type hf-pt \
+ --benchmark-type hf-pt-eager \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp32 \
--batch-sizes "1 2" \
@@ -96,10 +158,10 @@ python3 -m models.llama.benchmark \
--auth
```
-2. PyTorch 2.0 (with `torch.compile`), FP16
+2. PyTorch with `torch.compile`, FP16
```
python3 -m models.llama.benchmark \
- --benchmark-type hf-pt2 \
+ --benchmark-type hf-pt-compile \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp16 \
--batch-sizes "1 2" \
@@ -112,7 +174,7 @@ python3 -m models.llama.benchmark \
```
python3 -m models.llama.benchmark \
--benchmark-type hf-ort \
- --hf-ort-model-path ./Llama-2-7b-hf-onnx/ \
+ --hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp32 \
--batch-sizes "1 2" \
@@ -125,7 +187,7 @@ python3 -m models.llama.benchmark \
```
python3 -m models.llama.benchmark \
--benchmark-type hf-ort \
- --hf-ort-model-path ./llama2-7b-fp16/ \
+ --hf-ort-dir-path ./llama2-7b-fp16/ \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp16 \
--batch-sizes "1 2" \
@@ -134,24 +196,35 @@ python3 -m models.llama.benchmark \
--auth
```
-5. Optimum + ONNX Runtime, INT8, export via convert_to_onnx
+5. ONNX Runtime, FP32, Microsoft custom export
```
python3 -m models.llama.benchmark \
- --benchmark-type hf-ort \
- --hf-ort-model-path ./llama2-7b-int8/ \
+ --benchmark-type ort-msft \
+ --ort-model-path ./llama-2-onnx/7B_float32/ONNX/LlamaV2_7B_float32.onnx \
--model-name meta-llama/Llama-2-7b-hf \
- --precision int8 \
+ --precision fp32 \
--batch-sizes "1 2" \
--sequence-lengths "8 16" \
- --device cpu \
- --auth
+ --device cpu
```
-6. ONNX Runtime, FP32, Microsoft custom export
+6. ONNX Runtime, FP16, Microsoft custom export
```
python3 -m models.llama.benchmark \
- --benchmark-type ort \
- --ort-model-path llama-2-onnx/7B_float32/ONNX/LlamaV2_7B_float32.onnx \
+ --benchmark-type ort-msft \
+ --ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \
+ --model-name meta-llama/Llama-2-7b-hf \
+ --precision fp16 \
+ --batch-sizes "1 2" \
+ --sequence-lengths "8 16" \
+ --device cuda
+```
+
+7. ONNX Runtime, FP32, convert_to_onnx
+```
+python3 -m models.llama.benchmark \
+ --benchmark-type ort-convert-to-onnx \
+ --ort-model-path ./llama2-7b/Llama-2-7b-hf_decoder_merged_model_fp32.onnx \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp32 \
--batch-sizes "1 2" \
@@ -159,11 +232,11 @@ python3 -m models.llama.benchmark \
--device cpu
```
-7. ONNX Runtime, FP16, Microsoft custom export
+8. ONNX Runtime, FP16, convert_to_onnx
```
python3 -m models.llama.benchmark \
- --benchmark-type ort \
- --ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \
+ --benchmark-type ort-convert-to-onnx \
+ --ort-model-path ./llama2-7b/Llama-2-7b-hf_decoder_merged_model_fp16.onnx \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp16 \
--batch-sizes "1 2" \
@@ -174,11 +247,14 @@ python3 -m models.llama.benchmark \
You can profile a variant by adding the `--profile` flag and providing one batch size and sequence length combination.
### Benchmark All
-You can use `benchmark_all.py` to benchmark across various platforms and automatically store the results in a CSV file. Here is an example.
+You can use `benchmark_all.py` to benchmark across various options and automatically store the results in a CSV file. Here is an example.
```
python3 -m models.llama.benchmark_all \
- --hf-ort-model-path ./llama2-7b-fp16/ \
- --ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \
+ --hf-pt-eager \
+ --hf-pt-compile \
+ --hf-ort-dir-path ./llama2-7b-fp16/ \
+ --ort-convert-to-onnx-model-path ./llama2-7b-fp16/Llama-2-7b-hf_decoder_merged_model_fp16.onnx \
+ --ort-msft-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp16 \
--batch-sizes "1 2" \
diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark.py b/onnxruntime/python/tools/transformers/models/llama/benchmark.py
index d19ed5cc28fed..976de2abc7c57 100644
--- a/onnxruntime/python/tools/transformers/models/llama/benchmark.py
+++ b/onnxruntime/python/tools/transformers/models/llama/benchmark.py
@@ -8,10 +8,17 @@
import time
import numpy as np
+import onnx
import psutil
import torch
from benchmark_helper import setup_logger
-from llama_inputs import get_msft_sample_inputs, get_sample_inputs, get_sample_with_past_kv_inputs
+from llama_inputs import (
+ convert_inputs_for_ort,
+ get_merged_sample_with_past_kv_inputs,
+ get_msft_sample_inputs,
+ get_sample_inputs,
+ get_sample_with_past_kv_inputs,
+)
from optimum.onnxruntime import ORTModelForCausalLM
from torch.profiler import ProfilerActivity, profile, record_function
from tqdm import trange
@@ -23,8 +30,29 @@
logger = logging.getLogger(__name__)
-def get_inputs(args: argparse.Namespace):
- if args.benchmark_type in {"hf-pt", "hf-pt2", "hf-ort"}:
+# Returns the number of inputs of the loaded model (0 for the PyTorch benchmark types); used to determine whether the ONNX model can do both prompt generation and token generation or only one of the two
+def get_ort_model_inputs_len(args, model):
+ if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
+ return 0
+ if args.benchmark_type == "hf-ort":
+ try:
+ # New Optimum export (https://github.com/huggingface/optimum/blob/888332364c2e0091da1fc974737c7e277af168bf/optimum/onnxruntime/modeling_ort.py#L268)
+ return len(model.inputs_names)
+ except Exception:
+ # Old Optimum export (https://github.com/huggingface/optimum/blob/c5ad7f971cb0a494eac03dc0909f146725f999c5/optimum/onnxruntime/base.py#L54)
+ return len(model.decoder.input_names)
+ return len(model.get_inputs())
+
+
+def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
+ init_inputs, iter_inputs = None, None
+
+ # For past_present_share_buffer:
+ # Set max_seq_len to 2048 for Hugging Face model since that is the default value
+ # Set max_seq_len to 2048 for Microsoft model since that is the max value currently supported
+ max_seq_len = 2048
+
+ if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
init_inputs = get_sample_inputs(
args.config,
args.target_device,
@@ -41,14 +69,95 @@ def get_inputs(args: argparse.Namespace):
return_dict=True,
)
- elif args.benchmark_type == "ort":
+ elif args.benchmark_type == "hf-ort":
+ if ort_model_inputs_len == 3: # [input_ids, attention_mask, position_ids]
+ # Using split models in Optimum (e.g. created by Optimum export)
+ init_inputs = get_sample_inputs(
+ args.config,
+ args.target_device,
+ args.batch_size,
+ args.sequence_length,
+ return_dict=True,
+ )
+ iter_inputs = get_sample_with_past_kv_inputs(
+ args.config,
+ args.target_device,
+ args.batch_size,
+ args.sequence_length,
+ use_fp16=args.use_fp16,
+ return_dict=True,
+ )
+ else:
+ # Using merged model in Optimum (e.g. created by convert_to_onnx export)
+ init_inputs = get_merged_sample_with_past_kv_inputs(
+ args.config,
+ args.target_device,
+ args.batch_size,
+ seq_len=args.sequence_length,
+ past_seq_len=0,
+ use_fp16=args.use_fp16,
+ return_dict=True,
+ )
+ iter_inputs = get_merged_sample_with_past_kv_inputs(
+ args.config,
+ args.target_device,
+ args.batch_size,
+ seq_len=1,
+ past_seq_len=args.sequence_length,
+ use_fp16=args.use_fp16,
+ return_dict=True,
+ )
+
+ elif args.benchmark_type == "ort-convert-to-onnx":
+ # Microsoft export from convert_to_onnx
+ init_inputs = get_merged_sample_with_past_kv_inputs(
+ args.config,
+ args.target_device,
+ args.batch_size,
+ seq_len=args.sequence_length,
+ past_seq_len=0,
+ use_fp16=args.use_fp16,
+ return_dict=True,
+ )
+ iter_inputs = get_merged_sample_with_past_kv_inputs(
+ args.config,
+ args.target_device,
+ args.batch_size,
+ seq_len=1,
+ past_seq_len=args.sequence_length,
+ use_fp16=args.use_fp16,
+ return_dict=True,
+ )
+ init_inputs = convert_inputs_for_ort(
+ init_inputs,
+ use_fp16=args.use_fp16,
+ use_buffer_share=args.past_present_share_buffer,
+ past_seq_len=0,
+ max_seq_len=max_seq_len,
+ device=args.device,
+ device_id=args.device_id,
+ )
+ iter_inputs = convert_inputs_for_ort(
+ iter_inputs,
+ use_fp16=args.use_fp16,
+ use_buffer_share=args.past_present_share_buffer,
+ past_seq_len=args.sequence_length,
+ max_seq_len=max_seq_len,
+ device=args.device,
+ device_id=args.device_id,
+ )
+
+ elif args.benchmark_type == "ort-msft":
# Microsoft export from https://github.com/microsoft/Llama-2-Onnx
+ split_kv = ort_model_inputs_len > 5 # original inputs: [x, attn_mask, k_cache, v_cache, pos]
+
init_inputs = get_msft_sample_inputs(
args.config,
args.batch_size,
past_seq_len=0,
seq_len=args.sequence_length,
use_fp16=args.use_fp16,
+ split_kv=split_kv,
)
iter_inputs = get_msft_sample_inputs(
args.config,
@@ -56,6 +165,25 @@ def get_inputs(args: argparse.Namespace):
past_seq_len=args.sequence_length,
seq_len=1,
use_fp16=args.use_fp16,
+ split_kv=split_kv,
+ )
+ init_inputs = convert_inputs_for_ort(
+ init_inputs,
+ use_fp16=args.use_fp16,
+ use_buffer_share=args.past_present_share_buffer,
+ past_seq_len=0,
+ max_seq_len=max_seq_len,
+ device=args.device,
+ device_id=args.device_id,
+ )
+ iter_inputs = convert_inputs_for_ort(
+ iter_inputs,
+ use_fp16=args.use_fp16,
+ use_buffer_share=args.past_present_share_buffer,
+ past_seq_len=args.sequence_length,
+ max_seq_len=max_seq_len,
+ device=args.device,
+ device_id=args.device_id,
)
else:
@@ -69,12 +197,14 @@ def get_model(args: argparse.Namespace):
start_time, end_time = None, None
# There are multiple sources that the model could come from:
- # 1) Benchmark LLaMA from unofficial source on Hugging Face
- # 2) Benchmark LLaMA from official source on Hugging Face, which requires an authentication token
- # 3) Benchmark LLaMA from local download of model
-
- if args.benchmark_type in {"hf-pt", "hf-pt2"}:
- source = args.hf_pt_model_path if args.hf_pt_model_path else args.model_name
+ # 1) Benchmark LLaMA-2 from unofficial source on Hugging Face
+ # 2) Benchmark LLaMA-2 from official source on Hugging Face, which requires an authentication token
+ # 3) Benchmark LLaMA-2 from local download of model
+ # 4) Benchmark LLaMA-2 from Microsoft (already optimized, available at https://github.com/microsoft/Llama-2-Onnx)
+ # 5) Benchmark LLaMA-2 from convert_to_onnx
+
+ if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
+ source = args.hf_pt_dir_path if args.hf_pt_dir_path else args.model_name
start_time = time.time()
model = LlamaForCausalLM.from_pretrained(
source,
@@ -84,10 +214,10 @@ def get_model(args: argparse.Namespace):
).to(args.target_device)
end_time = time.time()
- if args.benchmark_type == "hf-pt2":
+ if args.benchmark_type == "hf-pt-compile":
model = torch.compile(model)
- elif args.benchmark_type in {"hf-ort", "ort"}:
+ elif args.benchmark_type in {"hf-ort", "ort-msft", "ort-convert-to-onnx"}:
sess_options = ort.SessionOptions()
sess_options.enable_profiling = args.profile
if args.verbose:
@@ -104,32 +234,33 @@ def get_model(args: argparse.Namespace):
decoder_file_name = None
decoder_with_past_file_name = None
- for filename in os.listdir(args.hf_ort_model_path):
+ for filename in os.listdir(args.hf_ort_dir_path):
if ".onnx" not in filename or ".onnx_data" in filename or ".onnx.data" in filename:
continue
- if "decoder_model.onnx" in filename or f"decoder_model_{args.precision}.onnx" in filename:
+ if "decoder_model" in filename or filename == "model.onnx":
+ decoder_file_name = filename
+ if "decoder_with_past_model" in filename:
+ decoder_with_past_file_name = filename
+ if "decoder_merged_model" in filename:
decoder_file_name = filename
- if (
- "decoder_with_past_model.onnx" in filename
- or f"decoder_with_past_model_{args.precision}.onnx" in filename
- ):
decoder_with_past_file_name = filename
start_time = time.time()
model = ORTModelForCausalLM.from_pretrained(
- args.hf_ort_model_path,
+ args.hf_ort_dir_path,
decoder_file_name=decoder_file_name,
decoder_with_past_file_name=decoder_with_past_file_name,
use_auth_token=args.auth,
use_io_binding=(args.device != "cpu"),
+ use_merged=(True if decoder_file_name == "model.onnx" else None),
provider=provider,
provider_options=provider_options,
session_options=sess_options,
)
end_time = time.time()
- if args.benchmark_type == "ort":
- # Microsoft export from https://github.com/microsoft/Llama-2-Onnx
+ if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
+ # Ex: Microsoft export from https://github.com/microsoft/Llama-2-Onnx
logger.info(f"Loading model from {args.ort_model_path}")
start_time = time.time()
model = ort.InferenceSession(
@@ -140,7 +271,6 @@ def get_model(args: argparse.Namespace):
end_time = time.time()
logger.info(f"Loaded model in {end_time - start_time} s")
-
return model
@@ -148,7 +278,7 @@ def time_fn(args, fn, inputs):
# Warm up
warmup_range = (
range(args.warmup_runs)
- if args.benchmark_type == "ort"
+ if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}
else trange(args.warmup_runs, file=sys.stdout, desc="Warm up")
)
@@ -166,7 +296,7 @@ def time_fn(args, fn, inputs):
bench_range = (
range(args.num_runs)
- if args.benchmark_type == "ort"
+ if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}
else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
)
for _ in bench_range:
@@ -177,7 +307,7 @@ def time_fn(args, fn, inputs):
end_time = time.time()
# Newline print after trange in order to print metrics on new lines without progress bar on same line
- if args.benchmark_type != "ort":
+ if args.benchmark_type not in {"ort-msft", "ort-convert-to-onnx"}:
logger.info("")
latency = (end_time - start_time) / args.num_runs
@@ -186,7 +316,7 @@ def time_fn(args, fn, inputs):
logger.info(f"Batch Size: {args.batch_size}")
logger.info(f"Sequence Length: {args.sequence_length}")
logger.info(f"Latency: {latency} s")
- logger.info(f"Throughput: {throughput} qps")
+ logger.info(f"Throughput: {throughput} tps")
return
@@ -196,7 +326,7 @@ def profile_fn(args, fn, inputs, inputs_type):
prefix = f"b{args.batch_size}_s{args.sequence_length}_{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}"
filename = None
- if args.benchmark_type in {"hf-pt", "hf-pt2"}:
+ if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
# Profile PyTorch kernels
with profile( # noqa: SIM117
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
@@ -267,7 +397,7 @@ def get_logits(inputs):
generate_fn = get_logits
- if args.benchmark_type == "hf-pt2":
+ if args.benchmark_type == "hf-pt-compile":
# Run forward pass once with each set of inputs to process through Dynamo
generate_fn(init_inputs)
generate_fn(iter_inputs)
@@ -280,7 +410,7 @@ def get_logits(inputs):
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
- new_logname = profile_fn(args, generate_fn, iter_inputs, "per-token")
+ new_logname = profile_fn(args, generate_fn, iter_inputs, "token")
if args.benchmark_type == "hf-ort":
# Turn profiling off to stop appending to log
old_logname = model.decoder_with_past.session.end_profiling()
@@ -319,10 +449,24 @@ def prepare_ort_inputs(inputs):
# Add IO bindings for non-CPU execution providers
if args.device != "cpu":
io_binding = model.io_binding()
+
for k, v in inputs.items():
- io_binding.bind_cpu_input(k, v)
+ if args.past_present_share_buffer:
+ # Bind all OrtValue inputs to device
+ io_binding.bind_ortvalue_input(k, v)
+ else:
+ io_binding.bind_cpu_input(k, v)
+
for output in model.get_outputs():
- io_binding.bind_output(output.name)
+ name = output.name
+ if args.past_present_share_buffer and ("out" in name or "present" in name):
+ # Bind present KV cache outputs to OrtValue with buffer sharing
+ io_binding.bind_ortvalue_output(
+ name, inputs[name.replace("out", "cache").replace("present", "past_key_values")]
+ )
+ else:
+ io_binding.bind_output(name, device_type=args.device, device_id=args.device_id)
+
return io_binding
return inputs
@@ -350,7 +494,7 @@ def without_io_binding(inputs):
# Re-initialize model for new log file instead of appending to old log file
model = get_model(args)
ort_iter_inputs = prepare_ort_inputs(iter_inputs)
- new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "per-token")
+ new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "token")
# Turn profiling off to stop appending to log
old_logname = model.end_profiling()
@@ -371,9 +515,9 @@ def without_io_binding(inputs):
def run_inference(args, init_inputs, iter_inputs, model):
- if args.benchmark_type in {"hf-pt", "hf-pt2", "hf-ort"}:
+ if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}:
run_hf_inference(args, init_inputs, iter_inputs, model)
- elif args.benchmark_type == "ort":
+ elif args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
run_ort_inference(args, init_inputs, iter_inputs, model)
else:
raise Exception(f"Cannot recognize {args.benchmark_type}")
@@ -382,7 +526,11 @@ def run_inference(args, init_inputs, iter_inputs, model):
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
- "-bt", "--benchmark-type", type=str, required=True, choices=["hf-pt", "hf-pt2", "hf-ort", "ort"]
+ "-bt",
+ "--benchmark-type",
+ type=str,
+ required=True,
+ choices=["hf-pt-eager", "hf-pt-compile", "hf-ort", "ort-msft", "ort-convert-to-onnx"],
)
parser.add_argument(
"-m",
@@ -402,20 +550,20 @@ def get_args():
required=True,
type=str,
default="fp32",
- choices=["int8", "fp16", "fp32"],
+ choices=["int4", "int8", "fp16", "fp32"],
help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
)
parser.add_argument(
- "--hf-pt-model-path",
+ "--hf-pt-dir-path",
type=str,
default="",
help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)",
)
parser.add_argument(
- "--hf-ort-model-path",
+ "--hf-ort-dir-path",
type=str,
default="",
- help="Path to directory containing all ONNX files (e.g. tokenizer, encoder, decoder, decoder_with_past)",
+ help="Path to directory containing all ONNX files (e.g. tokenizer, decoder_merged, decoder, decoder_with_past)",
)
parser.add_argument(
"--ort-model-path",
@@ -475,15 +623,20 @@ def get_args():
args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
args.device = "cuda"
- # Check that model paths have been specified for any benchmarking with ORT
+ # Check that paths have been specified for any benchmarking with ORT
if args.benchmark_type == "hf-ort":
- assert args.hf_ort_model_path, "Please specify a path to `--hf-ort-model-path`"
- if args.benchmark_type == "ort":
+ assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`"
+ if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
assert args.ort_model_path, "Please specify a path to `--ort-model-path`"
args.batch_sizes = args.batch_sizes.split(" ")
args.sequence_lengths = args.sequence_lengths.split(" ")
+ # Use FP32 precision for FP32, INT8, and INT4 CPU models; use FP16 precision for FP16 and INT4 GPU models
+ args.precision = (
+ "fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16"
+ )
+
# Check that only one (batch_size, sequence_length) combination is set for profiling
if args.profile:
assert (
@@ -509,14 +662,27 @@ def main():
setattr(args, "target_device", target_device) # noqa: B010
setattr(args, "use_fp16", use_fp16) # noqa: B010
- # Measure prompt cost (init_inputs) and generated token cost (iter_inputs)
+ # Get model and model info
model = get_model(args)
+ ort_model_inputs_len = get_ort_model_inputs_len(args, model)
+
+ # Check if past_present_share_buffer can be enabled (only for FP16 models with GQA running on a non-CPU device)
+ if args.benchmark_type in {"ort-convert-to-onnx", "ort-msft"}:
+ onnx_model = onnx.load_model(args.ort_model_path, load_external_data=False)
+ gqa_nodes = list(filter(lambda node: node.op_type == "GroupQueryAttention", onnx_model.graph.node))
+
+ use_buffer_share = use_fp16 and len(gqa_nodes) > 0 and args.device != "cpu"
+ setattr(args, "past_present_share_buffer", use_buffer_share) # noqa: B010
+ else:
+ setattr(args, "past_present_share_buffer", False) # noqa: B010
+
+ # Measure prompt cost (init_inputs) and generated token cost (iter_inputs)
for batch_size, sequence_length in itertools.product(args.batch_sizes, args.sequence_lengths):
logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...")
setattr(args, "batch_size", int(batch_size)) # noqa: B010
setattr(args, "sequence_length", int(sequence_length)) # noqa: B010
- init_inputs, iter_inputs = get_inputs(args)
+ init_inputs, iter_inputs = get_inputs(args, ort_model_inputs_len)
run_inference(args, init_inputs, iter_inputs, model)
diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py
index 7199c945fe6ba..951b2549368f7 100644
--- a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py
+++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py
@@ -43,15 +43,38 @@ def get_args():
)
parser.add_argument(
- "--hf-ort-model-path",
+ "--hf-pt-eager",
+ default=False,
+ action="store_true",
+ help="Benchmark in PyTorch without `torch.compile`",
+ )
+
+ parser.add_argument(
+ "--hf-pt-compile",
+ default=False,
+ action="store_true",
+ help="Benchmark in PyTorch with `torch.compile`",
+ )
+
+ parser.add_argument(
+ "--hf-ort-dir-path",
type=str,
+ default="",
help="Path to folder containing ONNX models for Optimum + ORT benchmarking",
)
parser.add_argument(
- "--ort-model-path",
+ "--ort-msft-model-path",
+ type=str,
+ default="",
+ help="Path to ONNX model from https://github.com/microsoft/Llama-2-Onnx",
+ )
+
+ parser.add_argument(
+ "--ort-convert-to-onnx-model-path",
type=str,
- help="Path to ONNX model for ORT benchmarking",
+ default="",
+ help="Path to ONNX model from convert_to_onnx",
)
parser.add_argument(
@@ -65,7 +88,7 @@ def get_args():
"--precision",
type=str,
required=True,
- choices=["int8", "fp16", "fp32"],
+ choices=["int4", "int8", "fp16", "fp32"],
help="Precision to run model",
)
@@ -138,8 +161,6 @@ def process_log_file(device_id, log_file, base_results):
step = "per-token"
elif latency_pattern in line:
latency_s = float(line[len(latency_pattern) : line.rfind(" ")])
- if step == "prompt":
- latency_s /= sequence_length
latency_ms = latency_s * 1000
elif throughput_pattern in line:
throughput = float(line[len(throughput_pattern) : line.rfind(" ")])
@@ -184,7 +205,7 @@ def save_results(results, filename):
"Step",
"Latency (s)",
"Latency (ms)",
- "Throughput (qps)",
+ "Throughput (tps)",
"Memory (GB)",
],
)
@@ -194,7 +215,7 @@ def save_results(results, filename):
df["Sequence Length"] = df["Sequence Length"].astype("int")
df["Latency (s)"] = df["Latency (s)"].astype("float")
df["Latency (ms)"] = df["Latency (ms)"].astype("float")
- df["Throughput (qps)"] = df["Throughput (qps)"].astype("float")
+ df["Throughput (tps)"] = df["Throughput (tps)"].astype("float")
df["Memory (GB)"] = df["Memory (GB)"].astype("float")
df.to_csv(filename, index=False)
@@ -226,75 +247,81 @@ def main():
torch.backends.cudnn.benchmark = True
all_results = []
+
# Benchmark PyTorch without torch.compile
- benchmark_cmd = [
- "python3",
- "benchmark.py",
- "--benchmark-type",
- "hf-pt",
- "--model-name",
- args.model_name,
- "--precision",
- args.precision,
- "--batch-sizes",
- args.batch_sizes,
- "--sequence-lengths",
- args.sequence_lengths,
- "--device",
- args.device,
- "--device-id",
- str(args.device_id),
- "--warmup-runs",
- str(args.warmup_runs),
- "--num-runs",
- str(args.num_runs),
- "--log-folder",
- args.log_folder,
- "--auth",
- ]
- logger.info("Benchmark PyTorch without torch.compile")
- results = benchmark(args, benchmark_cmd, "pytorch")
- all_results.extend(results)
+ if args.hf_pt_eager:
+ benchmark_cmd = [
+ "python",
+ "-m",
+ "models.llama.benchmark",
+ "--benchmark-type",
+ "hf-pt-eager",
+ "--model-name",
+ args.model_name,
+ "--precision",
+ args.precision,
+ "--batch-sizes",
+ args.batch_sizes,
+ "--sequence-lengths",
+ args.sequence_lengths,
+ "--device",
+ args.device,
+ "--device-id",
+ str(args.device_id),
+ "--warmup-runs",
+ str(args.warmup_runs),
+ "--num-runs",
+ str(args.num_runs),
+ "--log-folder",
+ args.log_folder,
+ "--auth",
+ ]
+ logger.info("Benchmark PyTorch without torch.compile")
+ results = benchmark(args, benchmark_cmd, "pytorch-eager")
+ all_results.extend(results)
# Benchmark PyTorch with torch.compile
- benchmark_cmd = [
- "python3",
- "benchmark.py",
- "--benchmark-type",
- "hf-pt2",
- "--model-name",
- args.model_name,
- "--precision",
- args.precision,
- "--batch-sizes",
- args.batch_sizes,
- "--sequence-lengths",
- args.sequence_lengths,
- "--device",
- args.device,
- "--device-id",
- str(args.device_id),
- "--warmup-runs",
- str(args.warmup_runs),
- "--num-runs",
- str(args.num_runs),
- "--log-folder",
- args.log_folder,
- "--auth",
- ]
- logger.info("Benchmark PyTorch with torch.compile")
- results = benchmark(args, benchmark_cmd, "pytorch-2")
- all_results.extend(results)
+ if args.hf_pt_compile:
+ benchmark_cmd = [
+ "python",
+ "-m",
+ "models.llama.benchmark",
+ "--benchmark-type",
+ "hf-pt-compile",
+ "--model-name",
+ args.model_name,
+ "--precision",
+ args.precision,
+ "--batch-sizes",
+ args.batch_sizes,
+ "--sequence-lengths",
+ args.sequence_lengths,
+ "--device",
+ args.device,
+ "--device-id",
+ str(args.device_id),
+ "--warmup-runs",
+ str(args.warmup_runs),
+ "--num-runs",
+ str(args.num_runs),
+ "--log-folder",
+ args.log_folder,
+ "--auth",
+ ]
+ logger.info("Benchmark PyTorch with torch.compile")
+ results = benchmark(args, benchmark_cmd, "pytorch-compile")
+ all_results.extend(results)
# Benchmark Optimum + ONNX Runtime
- if args.hf_ort_model_path:
+ if args.hf_ort_dir_path:
benchmark_cmd = [
- "python3",
- "benchmark.py",
+ "python",
+ "-m",
+ "models.llama.benchmark",
"--benchmark-type",
"hf-ort",
- "--hf-ort-model-path",
- args.hf_ort_model_path,
+ "--hf-ort-dir-path",
+ args.hf_ort_dir_path,
"--model-name",
args.model_name,
"--precision",
@@ -316,18 +343,52 @@ def main():
"--auth",
]
logger.info("Benchmark Optimum + ONNX Runtime")
- results = benchmark(args, benchmark_cmd, "pytorch-ort")
+ results = benchmark(args, benchmark_cmd, "optimum-ort")
+ all_results.extend(results)
+
+ # Benchmark Microsoft model in ONNX Runtime
+ if args.ort_msft_model_path:
+ benchmark_cmd = [
+ "python",
+ "-m",
+ "models.llama.benchmark",
+ "--benchmark-type",
+ "ort-msft",
+ "--ort-model-path",
+ args.ort_msft_model_path,
+ "--model-name",
+ args.model_name,
+ "--precision",
+ args.precision,
+ "--batch-sizes",
+ args.batch_sizes,
+ "--sequence-lengths",
+ args.sequence_lengths,
+ "--device",
+ args.device,
+ "--device-id",
+ str(args.device_id),
+ "--warmup-runs",
+ str(args.warmup_runs),
+ "--num-runs",
+ str(args.num_runs),
+ "--log-folder",
+ args.log_folder,
+ ]
+ logger.info("Benchmark Microsoft model in ONNX Runtime")
+ results = benchmark(args, benchmark_cmd, "ort-msft")
all_results.extend(results)
- # Benchmark ONNX Runtime
- if args.ort_model_path:
+ # Benchmark convert_to_onnx model in ONNX Runtime
+ if args.ort_convert_to_onnx_model_path:
benchmark_cmd = [
- "python3",
- "benchmark.py",
+ "python",
+ "-m",
+ "models.llama.benchmark",
"--benchmark-type",
- "ort",
+ "ort-convert-to-onnx",
"--ort-model-path",
- args.ort_model_path,
+ args.ort_convert_to_onnx_model_path,
"--model-name",
args.model_name,
"--precision",
@@ -347,7 +408,7 @@ def main():
"--log-folder",
args.log_folder,
]
- logger.info("Benchmark ONNX Runtime")
+ logger.info("Benchmark convert_to_onnx model in ONNX Runtime")
results = benchmark(args, benchmark_cmd, "onnxruntime")
all_results.extend(results)
diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
index f96347ba67aa6..61d71bc38f4e9 100644
--- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
@@ -8,12 +8,16 @@
import onnx
import torch
from benchmark_helper import Precision, prepare_environment, setup_logger
-from llama_inputs import get_sample_inputs, get_sample_with_past_kv_inputs
+from convert_generation import replace_mha_with_gqa
+from llama_inputs import get_merged_sample_with_past_kv_inputs, get_sample_inputs, get_sample_with_past_kv_inputs
from llama_parity import main as parity_check
from onnx_model import OnnxModel
+from optimizer import optimize_model
+from packaging import version
from transformers import LlamaConfig, LlamaForCausalLM
from onnxruntime import quantization as ort_quantization
+from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer
logger = logging.getLogger("")
@@ -58,6 +62,33 @@ def get_model_with_past_kv_dynamic_axes(input_names: List[str], output_names: Li
return dynamic_axes
+def get_merged_model_dynamic_axes(input_names: List[str], output_names: List[str]):
+    """Build the ``dynamic_axes`` mapping used when exporting the merged decoder model.
+
+    Args:
+        input_names: Names of the ONNX graph inputs.
+        output_names: Names of the ONNX graph outputs.
+
+    Returns:
+        Dict mapping every name to ``{axis_index: symbolic_dimension_name}``.
+
+    Raises:
+        ValueError: If a name matches none of the known input/output patterns.
+    """
+    dynamic_axes = {}
+    for name in input_names + output_names:
+        if name in {"input_ids", "position_ids"}:
+            # shape is (batch_size, sequence_length)
+            dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
+        elif name == "attention_mask":
+            # shape is (batch_size, past_sequence_length + sequence_length) = (batch_size, total_sequence_length)
+            # for prompt generation, past_sequence_length = 0
+            # for token generation, sequence_length = 1
+            dynamic_axes[name] = {0: "batch_size", 1: "total_sequence_length"}
+        elif "past" in name:
+            # shape is (batch_size, num_heads, past_sequence_length, head_size)
+            dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length"}
+        elif name == "logits":
+            # shape is (batch_size, sequence_length, vocab_size)
+            dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
+        elif "present" in name:
+            # shape is (batch_size, num_heads, past_sequence_length + sequence_length, head_size) = (batch_size, num_heads, total_sequence_length, head_size)
+            # for prompt generation, past_sequence_length = 0
+            # for token generation, sequence_length = 1
+            dynamic_axes[name] = {0: "batch_size", 2: "total_sequence_length"}
+        else:
+            # Name the offender so a typo in the export name lists is easy to locate
+            raise ValueError(f"Unknown input or output name found: {name}")
+    return dynamic_axes
+
+
def save_onnx_model(onnx_model: onnx.ModelProto, output_path: str, data_path: str):
onnx.save(
onnx_model,
@@ -152,7 +183,7 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll
logger.info(f"The {args.model_name} ONNX model has been successfully created with the Dynamo exporter!")
-def run_torchscript_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM):
+def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM):
# Dummy values for export
batch_size, sequence_length = 2, 8
device = torch.device("cpu")
@@ -248,12 +279,206 @@ def run_torchscript_export(args: argparse.Namespace, l_config: LlamaConfig, llam
logger.info(f"The {args.model_name} ONNX model has been successfully created with the TorchScript exporter!")
+def run_torchscript_merged_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM):
+    """Export the merged decoder model to ONNX with the TorchScript exporter.
+
+    The model is exported into a temporary directory, checked and shape-inferred there, and then
+    re-saved under ``args.output`` with its external data written to one ``.onnx.data`` file.
+
+    Args:
+        args: Parsed CLI arguments; ``output``, ``model_name`` and ``verbose`` are read here.
+        l_config: Model config; ``num_hidden_layers`` sets how many past/present key-value pairs are named.
+        llama: The PyTorch model to export.
+    """
+    # Dummy values for export
+    batch_size, sequence_length, past_sequence_length = 2, 8, 0
+    device = torch.device("cpu")
+
+    # Export decoder_merged_model.onnx
+    decoder_merged_inputs = get_merged_sample_with_past_kv_inputs(
+        l_config, device, batch_size, sequence_length, past_sequence_length
+    )
+    num_layers = l_config.num_hidden_layers
+    input_names = [
+        "input_ids",
+        "attention_mask",
+        "position_ids",
+        *chain.from_iterable(
+            (f"past_key_values.{i}.key", f"past_key_values.{i}.value") for i in range(num_layers)
+        ),
+    ]
+    output_names = [
+        "logits",
+        *chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(num_layers)),
+    ]
+    dynamic_axes = get_merged_model_dynamic_axes(input_names, output_names)
+    output_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp32.onnx")
+
+    # The context manager deletes the temporary export even when export or validation raises
+    with tempfile.TemporaryDirectory() as temp_dir:
+        temp_path = os.path.join(temp_dir, "temp.onnx")
+        torch.onnx.export(
+            llama,
+            args=decoder_merged_inputs,
+            f=temp_path,
+            export_params=True,
+            input_names=input_names,
+            output_names=output_names,
+            dynamic_axes=dynamic_axes,
+            opset_version=13,
+            do_constant_folding=True,
+            verbose=args.verbose,
+        )
+
+        # Check decoder_merged_model.onnx and save all external data to one file
+        onnx.checker.check_model(temp_path)
+        onnx.shape_inference.infer_shapes_path(temp_path)
+
+        onnx_model = onnx.load_model(temp_path, load_external_data=True)
+        save_onnx_model(
+            onnx_model,
+            output_path,
+            f"{args.model_name}_decoder_merged_model_fp32.onnx.data",
+        )
+        del onnx_model
+
+    logger.info(f"The {args.model_name} ONNX model has been successfully created with the TorchScript exporter!")
+
+
+# Run the ORT transformer optimizer on an exported FP32 model
+def optimize_export(config: LlamaConfig, input_path: str, output_path: str):
+    """Optimize the ONNX model at ``input_path`` and save the result to ``output_path``.
+
+    The optimizer is run with the "gpt2" model type and fusion options at ``opt_level=0``.
+    The optimized model is saved with external data, after which the input model is removed.
+
+    Args:
+        config: Model config supplying ``num_attention_heads`` and ``hidden_size``.
+        input_path: Path of the ONNX model to optimize; removed once the optimized model is saved.
+        output_path: Destination path of the optimized ONNX model.
+    """
+    from fusion_options import FusionOptions
+
+    optimized = optimize_model(
+        input_path,
+        model_type="gpt2",
+        num_heads=config.num_attention_heads,
+        hidden_size=config.hidden_size,
+        opt_level=0,
+        optimization_options=FusionOptions("gpt2"),
+        only_onnxruntime=False,
+    )
+    optimized.save_model_to_file(output_path, use_external_data_format=True)
+    logger.info(f"The ONNX model at {input_path} has been successfully optimized and saved at {output_path}!")
+    remove_existing_model(input_path)
+
+
+def convert_to_float16(args: argparse.Namespace, config: LlamaConfig, old_paths: List[str]):
+    """Convert every existing FP32 model in ``old_paths`` to FP16 and return the FP16 paths.
+
+    ``old_paths`` is paired positionally with the decoder, decoder-with-past and merged-decoder
+    FP16 output paths. Paths that do not exist on disk are skipped. Each converted model also goes
+    through ``use_group_query_attention`` before it is saved, and its FP32 source is then removed.
+
+    Args:
+        args: Parsed CLI arguments; ``output`` and ``model_name`` are read here.
+        config: Model config forwarded to ``use_group_query_attention``.
+        old_paths: FP32 model paths, in decoder / decoder-with-past / merged order.
+
+    Returns:
+        The three FP16 output paths (whether or not each file was actually produced).
+    """
+    model_kinds = ("decoder_model", "decoder_with_past_model", "decoder_merged_model")
+    new_paths = [os.path.join(args.output, f"{args.model_name}_{kind}_fp16.onnx") for kind in model_kinds]
+
+    logger.info("Converting to float16...")
+    for fp32_path, fp16_path in zip(old_paths, new_paths):
+        if not os.path.exists(fp32_path):
+            continue
+        fp16_model = OnnxModel(onnx.load_model(fp32_path, load_external_data=True))
+        fp16_model.convert_float_to_float16(keep_io_types=False)
+        fp16_model = use_group_query_attention(config, fp16_model)
+        fp16_model.save_model_to_file(fp16_path, use_external_data_format=True)
+        del fp16_model
+        logger.info(f"The ONNX model at {fp32_path} has been converted to float16 and saved at {fp16_path}!")
+        remove_existing_model(fp32_path)
+
+    logger.info(f"The {args.model_name} ONNX model has been successfully converted to float16!")
+    return new_paths
+
+
+def use_group_query_attention(config: LlamaConfig, fp16_model_opt: OnnxModel):
+    """Swap MultiHeadAttention for GroupQueryAttention in an FP16 model and clean up the graph.
+
+    Args:
+        config: Model config supplying ``num_key_value_heads``.
+        fp16_model_opt: The FP16 model to rewrite.
+
+    Returns:
+        The model returned by ``replace_mha_with_gqa`` after pruning and graph update.
+    """
+    gqa_model = replace_mha_with_gqa(fp16_model_opt, "past_sequence_length", config.num_key_value_heads)
+    # Drop what the replacement left unused (e.g. attention mask nodes); graph inputs may be removed too
+    gqa_model.prune_graph()
+    gqa_model.update_graph(allow_remove_graph_inputs=True)
+    return gqa_model
+
+
+def smooth_quant(
+    args: argparse.Namespace,
+    decoder_model_fp32_path: str,
+    decoder_with_past_model_fp32_path: str,
+    decoder_model_int8_path: str,
+    decoder_with_past_model_int8_path: str,
+):
+    """Quantize the separately exported decoder models to INT8 with Intel Neural Compressor.
+
+    Both FP32 models are quantized with the same post-training config, saved with their external
+    data, and the FP32 sources are removed afterwards. The Neural Compressor workspace is deleted
+    at the end.
+
+    Args:
+        args: Parsed CLI arguments; ``nc_workspace``, ``calibration_sampling_size``, ``smooth_quant``,
+            ``smooth_quant_alpha`` and ``model_name`` are read here.
+        decoder_model_fp32_path: FP32 decoder model to quantize.
+        decoder_with_past_model_fp32_path: FP32 decoder-with-past model to quantize.
+        decoder_model_int8_path: Output path of the INT8 decoder model.
+        decoder_with_past_model_int8_path: Output path of the INT8 decoder-with-past model.
+    """
+    import shutil
+
+    from neural_compressor import PostTrainingQuantConfig
+    from neural_compressor import quantization as intel_quantization
+    from neural_compressor import set_workspace
+    from onnx.external_data_helper import load_external_data_for_model
+    from quant_kv_dataloader import QuantKVDataLoader
+
+    set_workspace(args.nc_workspace)
+    quantization_config = PostTrainingQuantConfig(
+        calibration_sampling_size=[args.calibration_sampling_size],
+        recipes={
+            "optypes_to_exclude_output_quant": ["MatMul"],
+            "smooth_quant": args.smooth_quant,
+            "smooth_quant_args": {"alpha": args.smooth_quant_alpha},
+        },
+        op_type_dict={
+            "^((?!(MatMul|Gather|Conv)).)*$": {
+                "weight": {"dtype": ["fp32"]},
+                "activation": {"dtype": ["fp32"]},
+            }
+        },
+    )
+
+    # Convert decoder_model.onnx to INT8
+    decoder_model_int8 = intel_quantization.fit(
+        decoder_model_fp32_path,
+        quantization_config,
+        calib_dataloader=QuantKVDataLoader(args),
+    )
+    load_external_data_for_model(
+        decoder_model_int8._model,
+        os.path.split(decoder_model_int8._model_path)[0],
+    )
+    save_onnx_model(
+        decoder_model_int8._model,
+        decoder_model_int8_path,
+        f"{args.model_name}_decoder_model_int8.onnx.data",
+    )
+    del decoder_model_int8
+    logger.info(
+        f"The ONNX model at {decoder_model_fp32_path} has been quantized to int8 and saved at {decoder_model_int8_path}!"
+    )
+    # NOTE: the FP32 decoder model must stay on disk for now. Its path is handed to the calibration
+    # data loader of the with-past model below, so it is removed only after that quantization is done.
+
+    # Convert decoder_with_past_model.onnx to INT8
+    decoder_with_past_model_int8 = intel_quantization.fit(
+        decoder_with_past_model_fp32_path,
+        quantization_config,
+        calib_dataloader=QuantKVDataLoader(args, onnx_model_path=decoder_model_fp32_path),
+    )
+    load_external_data_for_model(
+        decoder_with_past_model_int8._model,
+        os.path.split(decoder_with_past_model_int8._model_path)[0],
+    )
+    save_onnx_model(
+        decoder_with_past_model_int8._model,
+        decoder_with_past_model_int8_path,
+        f"{args.model_name}_decoder_with_past_model_int8.onnx.data",
+    )
+    del decoder_with_past_model_int8
+    logger.info(
+        f"The ONNX model at {decoder_with_past_model_fp32_path} has been quantized to int8 and saved at {decoder_with_past_model_int8_path}!"
+    )
+
+    # Both INT8 models are saved; the FP32 sources are no longer needed
+    remove_existing_model(decoder_model_fp32_path)
+    remove_existing_model(decoder_with_past_model_fp32_path)
+
+    logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!")
+
+    # Best-effort, portable cleanup (no shell, works on Windows, safe for paths with spaces)
+    logger.info(f"Removing {args.nc_workspace}")
+    shutil.rmtree(args.nc_workspace, ignore_errors=True)
+
+
+def remove_existing_model(model_path: str):
+    """Remove an ONNX model file and its external data file ``<model_path>.data``.
+
+    A file that is already absent is skipped instead of raising, so a model saved without
+    external data (or one already cleaned up) does not abort the export pipeline.
+
+    Args:
+        model_path: Path of the ONNX model to delete.
+    """
+    data_path = model_path + ".data"
+    removed = []
+    for path in (model_path, data_path):
+        try:
+            os.remove(path)
+        except FileNotFoundError:
+            continue
+        removed.append(path)
+    if removed:
+        logger.warning(f"Removed {' and '.join(removed)}")
+
+
def remove_existing_files(output_path: str):
+ # Delete every file directly inside `output_path` whose name contains ".onnx"
+ # (that substring test also matches the ".onnx.data" external data files)
for filename in os.listdir(output_path):
filepath = os.path.join(output_path, filename)
if ".onnx" in filename or ".onnx.data" in filename:
os.remove(filepath)
- logger.warning(f"Removing {filepath}")
+ logger.warning(f"Removed {filepath}")
def get_args():
@@ -288,7 +513,7 @@ def get_args():
required=False,
type=Precision,
default=Precision.FLOAT32,
- choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8],
+ choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8, Precision.INT4],
help="Precision to export model in",
)
@@ -301,15 +526,51 @@ def get_args():
help="Execution provider to verify parity with",
)
+ parser.add_argument(
+ "-id",
+ "--device-id",
+ required=False,
+ type=str,
+ default="0",
+ help="Device ID for GPUs",
+ )
+
+ parser.add_argument(
+ "-r",
+ "--reexport",
+ required=False,
+ action="store_true",
+ help="Re-export models and overwrite existing models in output folder",
+ )
+ parser.set_defaults(reexport=False)
+
+ parser.add_argument(
+ "--no_merged",
+ required=False,
+ action="store_true",
+ help="Export models into 2 ONNX files instead of 1. Deprecated in favor of exporting into 1 ONNX file.",
+ )
+ parser.set_defaults(no_merged=False)
+
parser.add_argument(
"-q",
"--quantization_method",
default="",
- choices=["smooth_quant", "quantize_dynamic"],
- help="Run a specific quantization algorithm. Need to install extra packages in `requirements-quant.txt` for SmoothQuant.",
+ choices=["blockwise", "smooth_quant", "quantize_dynamic"],
+ help="Run a specific quantization algorithm (blockwise for int4, smooth_quant for int8, quantize_dynamic for int8). Blockwise is recommended. Need to install extra packages in `requirements-quant.txt` for SmoothQuant.",
)
- smooth_quant_group = parser.add_argument_group("smooth_quant")
+ blockwise_group = parser.add_argument_group("4-bit quantization")
+
+ blockwise_group.add_argument(
+ "--block_size",
+ required=False,
+ default=32,
+ type=int,
+ help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py for details.",
+ )
+
+ smooth_quant_group = parser.add_argument_group("smooth_quant (8-bit quantization)")
smooth_quant_group.add_argument(
"--smooth_quant_alpha",
@@ -352,7 +613,7 @@ def get_args():
help="Workspace to save intermediate files generated by Intel's Neural Compressor package.",
)
- quantize_dynamic_group = parser.add_argument_group("quantize_dynamic")
+ quantize_dynamic_group = parser.add_argument_group("quantize_dynamic (8-bit quantization)")
quantize_dynamic_group.add_argument(
"--quantize_embedding_layer",
@@ -399,177 +660,193 @@ def get_args():
def main():
+ if version.parse(torch.__version__) < version.parse("2.2.0") and "2.2.0.dev" not in torch.__version__:
+ # Second predicate is for comparing nightly (ex: 2.2.0.dev20230920 vs 2.2.0) since first predicate is false
+ # in that scenario. It can be removed when torch v2.2.0 is released in stable.
+ logger.error(f"Detected PyTorch version {torch.__version__}. Please upgrade and use v2.2.0 or newer.")
+ return
+
args = get_args()
setup_logger(args.verbose)
prepare_environment(args.input, args.output, args.execution_provider != "cpu")
- remove_existing_files(args.output)
+ if args.reexport:
+ remove_existing_files(args.output)
logger.info(f"Arguments: {args}")
# Load model and config
use_auth_token = args.input == os.path.join(".")
setattr(args, "use_auth_token", use_auth_token) # noqa: B010
- l_config = LlamaConfig.from_pretrained(
- args.model_name if use_auth_token else args.input, use_auth_token=use_auth_token
- )
- llama = LlamaForCausalLM.from_pretrained(
- args.model_name if use_auth_token else args.input, use_auth_token=use_auth_token, use_cache=True
- )
+
+ location = args.model_name if use_auth_token else args.input
+ l_config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token)
+ llama = LlamaForCausalLM.from_pretrained(location, use_auth_token=use_auth_token, use_cache=True)
original_model_name = args.model_name
setattr(args, "original_model_name", original_model_name) # noqa: B010
args.model_name = args.model_name.split("/")[-1]
- # Export to ONNX
- if args.use_dynamo_export:
- logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.")
- logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/")
- logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/")
- logger.warning(
- "Step 3 - ONNX Script from source: https://github.com/microsoft/onnxscript#installing-onnx-script"
- )
- logger.warning(
- "Note: After you install ONNX weekly, omit `onnx` when running the first line for installing ONNX Script. This is because you already installed `onnx-weekly` in the previous step."
- )
- run_dynamo_export(args, l_config, llama)
- else:
- run_torchscript_export(args, l_config, llama)
-
- # Change precision of exported models if not FP32
+ # Set model paths for FP32 model
decoder_model_fp32_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx")
decoder_with_past_model_fp32_path = os.path.join(
args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx"
)
+ decoder_merged_model_fp32_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp32.onnx")
+ old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path]
+
+ missing_separate_exports = (
+ args.no_merged
+ and not os.path.exists(decoder_model_fp32_path)
+ and not os.path.exists(decoder_with_past_model_fp32_path)
+ )
+ missing_merged_export = not args.no_merged and not os.path.exists(decoder_merged_model_fp32_path)
+
+ # Export to ONNX
+ if missing_separate_exports or missing_merged_export:
+ if args.use_dynamo_export and missing_separate_exports:
+ logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.")
+ logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/")
+ logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/")
+ logger.warning(
+ "Step 3 - ONNX Script from source: https://github.com/microsoft/onnxscript#installing-onnx-script"
+ )
+ logger.warning(
+ "Note: After you install ONNX weekly, omit `onnx` when running the first line for installing ONNX Script. This is because you already installed `onnx-weekly` in the previous step."
+ )
+ run_dynamo_export(args, l_config, llama)
+ elif args.no_merged:
+ run_torchscript_separate_export(args, l_config, llama)
+ else:
+ run_torchscript_merged_export(args, l_config, llama)
+
+ # Set model paths to store FP32 optimized model
+ decoder_model_fp32_opt_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32_opt.onnx")
+ decoder_with_past_model_fp32_opt_path = os.path.join(
+ args.output, f"{args.model_name}_decoder_with_past_model_fp32_opt.onnx"
+ )
+ decoder_merged_model_fp32_opt_path = os.path.join(
+ args.output, f"{args.model_name}_decoder_merged_model_fp32_opt.onnx"
+ )
+ new_paths = [decoder_model_fp32_opt_path, decoder_with_past_model_fp32_opt_path, decoder_merged_model_fp32_opt_path]
+
+ # Run the optimizer script
+ logger.info("Optimizing models...")
+ for orig_path, opt_path in zip(old_paths, new_paths):
+ if os.path.exists(orig_path):
+ optimize_export(l_config, input_path=orig_path, output_path=opt_path)
+
+ # Re-assign default FP32 model paths as their optimized versions
+ decoder_model_fp32_path = decoder_model_fp32_opt_path
+ decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path
+ decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path
+ old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path]
+
+ logger.info(
+ f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!"
+ )
+ # Change precision of exported models from FP32
if args.precision == Precision.FLOAT16:
- # Convert decoder_model.onnx to FP16
- decoder_model_fp16_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp16.onnx")
- model = OnnxModel(onnx.load_model(decoder_model_fp32_path, load_external_data=True))
- model.convert_float_to_float16(keep_io_types=False, op_block_list=["If"])
- model.save_model_to_file(decoder_model_fp16_path, use_external_data_format=True, all_tensors_to_one_file=True)
- del model
-
- # Convert decoder_with_past_model.onnx to FP16
- decoder_with_past_model_fp16_path = os.path.join(
- args.output, f"{args.model_name}_decoder_with_past_model_fp16.onnx"
- )
- model = OnnxModel(onnx.load_model(decoder_with_past_model_fp32_path, load_external_data=True))
- model.convert_float_to_float16(keep_io_types=False, op_block_list=["If"])
- model.save_model_to_file(
- decoder_with_past_model_fp16_path, use_external_data_format=True, all_tensors_to_one_file=True
- )
- del model
+ new_paths = convert_to_float16(args, l_config, old_paths)
elif args.precision == Precision.INT8:
decoder_model_int8_path = os.path.join(args.output, f"{args.model_name}_decoder_model_int8.onnx")
decoder_with_past_model_int8_path = os.path.join(
args.output, f"{args.model_name}_decoder_with_past_model_int8.onnx"
)
+ decoder_merged_model_int8_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_int8.onnx")
+ new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path]
if args.quantization_method == "smooth_quant":
- from neural_compressor import PostTrainingQuantConfig
- from neural_compressor import quantization as intel_quantization
- from neural_compressor import set_workspace
- from onnx.external_data_helper import load_external_data_for_model
- from quant_kv_dataloader import QuantKVDataLoader
-
- set_workspace(args.nc_workspace)
- quantization_config = PostTrainingQuantConfig(
- calibration_sampling_size=[args.calibration_sampling_size],
- recipes={
- "optypes_to_exclude_output_quant": ["MatMul"],
- "smooth_quant": args.smooth_quant,
- "smooth_quant_args": {"alpha": args.smooth_quant_alpha},
- },
- op_type_dict={
- "^((?!(MatMul|Gather|Conv)).)*$": {
- "weight": {"dtype": ["fp32"]},
- "activation": {"dtype": ["fp32"]},
- }
- },
- )
-
- # Convert decoder_model.onnx to INT8
- decoder_model_int8 = intel_quantization.fit(
- decoder_model_fp32_path,
- quantization_config,
- calib_dataloader=QuantKVDataLoader(args),
- )
- load_external_data_for_model(
- decoder_model_int8._model,
- os.path.split(decoder_model_int8._model_path)[0],
- )
- save_onnx_model(
- decoder_model_int8._model,
- decoder_model_int8_path,
- f"{args.model_name}_decoder_model_int8.onnx.data",
- )
- del decoder_model_int8
-
- # Convert decoder_with_past_model.onnx to INT8
- decoder_with_past_model_int8 = intel_quantization.fit(
- decoder_with_past_model_fp32_path,
- quantization_config,
- calib_dataloader=QuantKVDataLoader(args, onnx_model_path=decoder_model_fp32_path),
- )
- load_external_data_for_model(
- decoder_with_past_model_int8._model,
- os.path.split(decoder_with_past_model_int8._model_path)[0],
- )
- save_onnx_model(
- decoder_with_past_model_int8._model,
- decoder_with_past_model_int8_path,
- f"{args.model_name}_decoder_with_past_model_int8.onnx.data",
- )
- del decoder_with_past_model_int8
-
- logger.info(f"Removing {args.nc_workspace}")
- os.system(f"rm -R {args.nc_workspace}")
+ if not args.no_merged:
+ logger.error("SmoothQuant must be used on separately exported models")
+ else:
+ logger.info(f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8")
+ smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1])
elif args.quantization_method == "quantize_dynamic":
logger.warning(
"The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`."
)
- # Convert decoder_model.onnx to INT8
- ort_quantization.quantize_dynamic(
- decoder_model_fp32_path,
- decoder_model_int8_path,
- op_types_to_quantize=["MatMul", "Gemm", "Gather"]
- if args.quantize_embedding_layer
- else ["MatMul", "Gemm"],
- per_channel=args.quantize_per_channel,
- reduce_range=args.quantize_reduce_range,
- use_external_data_format=True,
- extra_options={"MatMulConstBOnly": True},
- )
-
- # Convert decoder_with_past_model.onnx to INT8
- ort_quantization.quantize_dynamic(
- decoder_with_past_model_fp32_path,
- decoder_with_past_model_int8_path,
- op_types_to_quantize=["MatMul", "Gemm", "Gather"]
- if args.quantize_embedding_layer
- else ["MatMul", "Gemm"],
- per_channel=args.quantize_per_channel,
- reduce_range=args.quantize_reduce_range,
- use_external_data_format=True,
- extra_options={"MatMulConstBOnly": True},
- )
+ logger.info("Quantizing to int8...")
+ for fp32_path, int8_path in zip(old_paths, new_paths):
+ if os.path.exists(fp32_path):
+ ort_quantization.quantize_dynamic(
+ fp32_path,
+ int8_path,
+ op_types_to_quantize=["MatMul", "Gemm", "Gather"]
+ if args.quantize_embedding_layer
+ else ["MatMul", "Gemm"],
+ per_channel=args.quantize_per_channel,
+ reduce_range=args.quantize_reduce_range,
+ use_external_data_format=True,
+ extra_options={"MatMulConstBOnly": True},
+ )
+ logger.info(f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!")
+ # Remove the FP32 model quantized in this iteration (not always the first decoder model)
+ remove_existing_model(fp32_path)
+
+ logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!")
else:
raise Exception(f"Could not recognize {args.quantization_method} as a quantization method")
- # Verify parity on all saved ONNX models
+ elif args.precision == Precision.INT4:
+ if args.execution_provider != "cpu":
+ old_paths = convert_to_float16(args, l_config, old_paths)
+
+ decoder_model_int4_path = os.path.join(args.output, f"{args.model_name}_decoder_model_int4.onnx")
+ decoder_with_past_model_int4_path = os.path.join(
+ args.output, f"{args.model_name}_decoder_with_past_model_int4.onnx"
+ )
+ decoder_merged_model_int4_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_int4.onnx")
+ new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path]
+
+ for fp_path, int4_path in zip(old_paths, new_paths):
+ if os.path.exists(fp_path):
+ model = onnx.load_model(fp_path, load_external_data=True)
+ quant = MatMul4BitsQuantizer(model, args.block_size, is_symmetric=True, nodes_to_exclude=[])
+ quant.process()
+ quant.model.save_model_to_file(int4_path, use_external_data_format=True)
+ del model
+ del quant
+ logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!")
+ remove_existing_model(fp_path)
+
del llama # Delete LLaMA model from memory since it will be loaded again during parity check
logger.info("Verifying parity on all ONNX models created")
+
+ # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
+ # args.precision is parsed with type=Precision, so compare against enum members rather than raw strings
+ args.precision = (
+ "fp32"
+ if args.precision in {Precision.INT8, Precision.FLOAT32}
+ or (args.precision == Precision.INT4 and args.execution_provider == "cpu")
+ else "fp16"
+ )
+
+ # Verify parity on all saved ONNX models
for filename in os.listdir(args.output):
if ".data" in filename or ".onnx" not in filename:
continue
- precision = filename[filename.rfind("_") + 1 : filename.find(".onnx")]
- parity_cmd = ["-m", f"{original_model_name}", "-o", f"{os.path.join(args.output, filename)}", "-fp", precision]
+ parity_cmd = [
+ "-m",
+ original_model_name,
+ "-o",
+ os.path.join(args.output, filename),
+ "-ep",
+ args.execution_provider,
+ "-id",
+ args.device_id,
+ "-fp",
+ args.precision,
+ ]
if "with_past" in filename:
parity_cmd.append("--use_past_kv")
- parity_check(parity_cmd)
+ if "merged" in filename:
+ parity_cmd.append("--merged")
+
+ try:
+ parity_check(parity_cmd)
+ except Exception as e:
+ logger.warning(f"An error occurred while verifying parity: {e}", exc_info=True)
if __name__ == "__main__":
diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
index 6a28498a9ffc9..2652e9f0ca64e 100644
--- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
@@ -4,10 +4,13 @@
import torch
from transformers import LlamaConfig
+from onnxruntime import OrtValue
+
# Get position_ids from attention_mask
def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool):
+ # Each position is the running sum of the mask along the last axis, minus one
position_ids = attention_mask.long().cumsum(-1) - 1
+ # Slots where the mask is 0 are overwritten in place with the value 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
if use_past_kv:
+ # Keep only the last column and restore a trailing axis of size 1
position_ids = position_ids[:, -1].unsqueeze(-1)
return position_ids
@@ -62,11 +65,41 @@ def get_sample_with_past_kv_inputs(
return inputs
+# Inputs for all passes with past_key_values
+def get_merged_sample_with_past_kv_inputs(
+    config: LlamaConfig,
+    device: torch.device,
+    batch_size: int,
+    seq_len: int,
+    past_seq_len: int,
+    use_fp16: bool = False,
+    return_dict: bool = False,
+):
+    """Create sample inputs for the merged decoder model.
+
+    Args:
+        config: Model config; ``vocab_size`` bounds the random token ids.
+        device: Device on which the tensors are created.
+        batch_size: Number of sequences.
+        seq_len: Number of new tokens per sequence.
+        past_seq_len: Number of already-cached tokens per sequence (0 for prompt generation).
+        use_fp16: Forwarded to ``get_sample_past_kv_inputs``.
+        return_dict: Return a dict keyed by input name instead of a tuple.
+
+    Returns:
+        ``(input_ids, attention_mask, position_ids, past_kv)`` or the same values as a dict with
+        keys ``input_ids``, ``attention_mask``, ``position_ids`` and ``past_key_values``.
+    """
+    total_seq_len = past_seq_len + seq_len
+    input_ids = torch.randint(
+        low=0, high=config.vocab_size, size=(batch_size, seq_len), device=device, dtype=torch.int64
+    )
+    attention_mask = torch.ones(batch_size, total_seq_len, device=device, dtype=torch.int64)
+    # position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation
+    has_past = past_seq_len != 0
+    position_ids = get_position_ids(attention_mask, use_past_kv=has_past)
+    past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16)
+
+    if return_dict:
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "position_ids": position_ids,
+            "past_key_values": past_kv,
+        }
+    return (input_ids, attention_mask, position_ids, past_kv)
+
+
# Create past_key_values
def get_sample_past_kv_inputs(
config: LlamaConfig, device: torch.device, batch_size: int, past_seq_len: int, use_fp16: bool
):
- num_heads, head_size = config.num_attention_heads, config.hidden_size // config.num_attention_heads
+ num_heads, head_size = config.num_key_value_heads, config.hidden_size // config.num_key_value_heads
torch_dtype = torch.float16 if use_fp16 else torch.float32
past_kv = [
(
@@ -89,31 +122,83 @@ def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tenso
# Format PyTorch inputs to ONNX Runtime inputs
-def convert_inputs_for_ort(pt_inputs: dict, use_fp16: bool):
+def convert_inputs_for_ort(
+ pt_inputs: dict,
+ use_fp16: bool,
+ use_buffer_share: bool = False,
+ past_seq_len: int = 0,
+ max_seq_len: int = 2048,
+ device: str = "",
+ device_id: int = -1,
+):
+ # Builds the ONNX Runtime feed dict from `pt_inputs`:
+ # - values that are already NumPy arrays are passed through unchanged
+ # - "past_key_values" is expanded with flatten_past_kv_inputs
+ # - "attention_mask" is replaced by a "past_sequence_length" input when both use_fp16 and use_buffer_share are set
+ # - every other value is detached, moved to CPU and converted to NumPy
+ # With use_buffer_share on a non-CPU device (device_id > -1), inputs are placed on that device as
+ # OrtValue objects, and KV-cache inputs are first copied into buffers of length max_seq_len.
ort_inputs = {}
for k, v in pt_inputs.items():
- if k == "past_key_values":
+ if isinstance(v, np.ndarray):
+ ort_inputs[k] = v
+ elif k == "past_key_values":
ort_inputs.update(flatten_past_kv_inputs(v, use_fp16))
+ elif k == "attention_mask" and use_fp16 and use_buffer_share:
+ # Skip because FP16 model has GroupQueryAttention, uses buffer sharing,
+ # and GQA supports a causal mask by default
+
+ # Instead, add the past sequence length input for GQA
+ ort_inputs["past_sequence_length"] = np.array([past_seq_len], dtype=np.int64)
else:
ort_inputs[k] = v.detach().cpu().numpy()
+
+ # Enable past-present-share-buffer by using device memory directly
+ if use_buffer_share and device != "" and device != "cpu" and device_id > -1:
+ for k, v in ort_inputs.items():
+ new_v = v
+ # Allocate new buffers with max_sequence_length for GQA
+ if "cache" in k or "past_key_values" in k:
+ # Copy v (BxNxPxH) into new_v (BxNxMxH), where N = num_heads, P = past_seq_len, M = max_seq_len
+ batch_size, num_heads, _, head_size = v.shape
+ new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype)
+ new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v
+ ort_inputs[k] = OrtValue.ortvalue_from_numpy(new_v, device_type=device, device_id=device_id)
+
return ort_inputs
# Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx
-def get_msft_sample_inputs(config: LlamaConfig, batch_size: int, past_seq_len: int, seq_len: int, use_fp16: bool):
+def get_msft_sample_inputs(
+ config: LlamaConfig, batch_size: int, past_seq_len: int, seq_len: int, use_fp16: bool, split_kv: bool
+):
+ # Returns a dict of random NumPy inputs ("x", "attn_mask", "pos" plus KV caches).
+ # split_kv=False: one "k_cache"/"v_cache" pair shaped (batch, layers, past_seq_len, heads, head_size).
+ # split_kv=True: one "k_{i}_cache"/"v_{i}_cache" pair per layer shaped (batch, heads, past_seq_len, head_size),
+ # and "attn_mask" is an int32 array instead of a float array.
np_dtype = np.float16 if use_fp16 else np.float32
head_size = config.hidden_size // config.num_attention_heads
max_seq_len = 2048
- ort_inputs = {
- "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
- "attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype),
- "k_cache": np.random.rand(
- batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
- ).astype(np_dtype),
- "v_cache": np.random.rand(
- batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
- ).astype(np_dtype),
- "pos": np.array(past_seq_len, dtype=np.int64),
- }
+ if not split_kv:
+ ort_inputs = {
+ "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
+ "attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype),
+ "k_cache": np.random.rand(
+ batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
+ ).astype(np_dtype),
+ "v_cache": np.random.rand(
+ batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
+ ).astype(np_dtype),
+ "pos": np.array(past_seq_len, dtype=np.int64),
+ }
+ else:
+ ort_inputs = {
+ "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
+ "attn_mask": (np.triu(np.ones((batch_size, max_seq_len, max_seq_len), dtype=np.int32), k=1) - 1).astype(
+ np.int32
+ ),
+ "pos": np.array(past_seq_len, dtype=np.int64),
+ }
+ # Add one key cache and one value cache input per hidden layer
+ for i in range(config.num_hidden_layers):
+ ort_inputs.update(
+ {
+ f"k_{i}_cache": np.random.rand(
+ batch_size, config.num_attention_heads, past_seq_len, head_size
+ ).astype(np_dtype),
+ f"v_{i}_cache": np.random.rand(
+ batch_size, config.num_attention_heads, past_seq_len, head_size
+ ).astype(np_dtype),
+ }
+ )
+
return ort_inputs
diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
index dadf394440c9a..6bfcb9b4f290d 100644
--- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
@@ -1,44 +1,143 @@
import argparse
import logging
import os
+import time
from typing import List
import numpy as np
import torch
-from benchmark_helper import create_onnxruntime_session, setup_logger
-from llama_inputs import convert_inputs_for_ort, get_sample_inputs, get_sample_with_past_kv_inputs
+from benchmark_helper import setup_logger
+from llama_inputs import (
+ convert_inputs_for_ort,
+ get_merged_sample_with_past_kv_inputs,
+ get_sample_inputs,
+ get_sample_with_past_kv_inputs,
+)
from transformers import LlamaConfig, LlamaForCausalLM
+import onnxruntime as ort
+
logger = logging.getLogger("")
-def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM):
+def get_sequence_lengths(args: argparse.Namespace):
+ past_sequence_length, curr_sequence_length = (8, 1) if args.use_past_kv else (0, 8)
+ max_sequence_length = 2048
+ return past_sequence_length, curr_sequence_length, max_sequence_length
+
+
+def get_inputs(args: argparse.Namespace, config: LlamaConfig):
# Dummy values for parity
- batch_size, sequence_length = 2, 8
- device = torch.device("cpu")
+ batch_size = 2
+ past_sequence_length, sequence_length, _ = get_sequence_lengths(args)
- # Run inference with PyTorch
- inputs = (
- get_sample_inputs(config, device, batch_size, sequence_length, return_dict=True)
- if not args.use_past_kv
- else get_sample_with_past_kv_inputs(
- config, device, batch_size, sequence_length, use_fp16=(args.precision == "fp16"), return_dict=True
+ if args.merged:
+ inputs = get_merged_sample_with_past_kv_inputs(
+ config,
+ args.device,
+ batch_size,
+ sequence_length,
+ past_sequence_length,
+ use_fp16=args.use_fp16,
+ return_dict=True,
)
- )
+ elif args.use_past_kv:
+ inputs = get_sample_with_past_kv_inputs(
+ config, args.device, batch_size, sequence_length, use_fp16=args.use_fp16, return_dict=True
+ )
+ else:
+ inputs = get_sample_inputs(config, args.device, batch_size, sequence_length, return_dict=True)
+
+ return inputs
+
+
+def add_io_bindings(args: argparse.Namespace, model: ort.InferenceSession, inputs: dict):
+ # Add IO bindings for non-CPU execution providers
+ io_binding = model.io_binding()
+
+ for k, v in inputs.items():
+ if args.use_fp16:
+ # Bind all OrtValue inputs to device
+ io_binding.bind_ortvalue_input(k, v)
+ else:
+ io_binding.bind_cpu_input(k, v)
+
+ for output in model.get_outputs():
+ name = output.name
+ if args.use_fp16 and ("out" in name or "present" in name):
+ # Bind present KV cache outputs to OrtValue with buffer sharing
+ io_binding.bind_ortvalue_output(
+ name, inputs[name.replace("out", "cache").replace("present", "past_key_values")]
+ )
+ else:
+ io_binding.bind_output(name, device_type=args.execution_provider, device_id=int(args.device_id))
+
+ return io_binding
+
+
+def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM):
+ inputs = get_inputs(args, config)
+
+ # Run inference with PyTorch
+ if args.execution_provider != "cpu":
+ torch.cuda.synchronize()
+ start_time = time.time()
pt_outputs = pt_model(**inputs).logits.detach().cpu().numpy()
+ if args.execution_provider != "cpu":
+ torch.cuda.synchronize()
+ end_time = time.time()
+ logger.info(f"PyTorch took {end_time - start_time} s")
# Run inference with ORT
- inputs = convert_inputs_for_ort(inputs, use_fp16=(args.precision == "fp16"))
- ort_model = create_onnxruntime_session(
+ past_sequence_length, _, max_sequence_length = get_sequence_lengths(args)
+ inputs = convert_inputs_for_ort(
+ inputs,
+ use_fp16=args.use_fp16,
+ use_buffer_share=args.use_fp16,
+ past_seq_len=past_sequence_length,
+ max_seq_len=max_sequence_length,
+ device=args.execution_provider,
+ device_id=int(args.device_id),
+ )
+
+ ep = f"{args.execution_provider.upper()}ExecutionProvider"
+ if ep == "CUDAExecutionProvider":
+ ep = (ep, {"device_id": args.device_id})
+ ort_model = ort.InferenceSession(
args.onnx_model_path,
- args.execution_provider != "cpu", # use_gpu
- provider=args.execution_provider,
- verbose=args.verbose,
+ sess_options=ort.SessionOptions(),
+ providers=[ep],
)
- ort_outputs = ort_model.run(None, inputs)[0]
+
+ # Add IO bindings for non-CPU execution providers
+ if args.execution_provider != "cpu":
+ io_binding = add_io_bindings(args, ort_model, inputs)
+
+ torch.cuda.synchronize()
+ start_time = time.time()
+ ort_model.run_with_iobinding(io_binding)
+ torch.cuda.synchronize()
+ end_time = time.time()
+
+ ort_outputs = io_binding.copy_outputs_to_cpu()[0] # Get logits
+
+ else:
+ start_time = time.time()
+ ort_outputs = ort_model.run(None, inputs)
+ end_time = time.time()
+
+ ort_outputs = ort_outputs[0] # Get logits
+
+ logger.info(f"ONNX Runtime took {end_time - start_time} s")
# Compare PyTorch and ONNX Runtime accuracy
- tol = 1e-3 if args.precision == "fp32" else 1e-2 if args.precision == "fp16" else 1e2
+ tol = (
+ 2e1
+ if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path
+ else 1e-3
+ if args.precision == "fp32"
+ else 5e-1
+ )
parity = np.allclose(pt_outputs, ort_outputs, rtol=tol, atol=tol)
logger.warning(f"Are PyTorch and ONNX Runtime results close? {parity}")
if not parity:
@@ -80,6 +179,15 @@ def get_args(argv: List[str]):
help="Execution provider to verify parity with",
)
+ parser.add_argument(
+ "-id",
+ "--device-id",
+ required=False,
+ type=str,
+ default="0",
+ help="Device ID for GPUs",
+ )
+
parser.add_argument(
"-v",
"--verbose",
@@ -96,15 +204,29 @@ def get_args(argv: List[str]):
)
parser.set_defaults(use_past_kv=False)
+ parser.add_argument(
+ "--merged",
+ action="store_true",
+ help="Use merged model (i.e. decoder_merged_model.onnx).",
+ )
+ parser.set_defaults(merged=False)
+
parser.add_argument(
"-fp",
"--precision",
required=True,
- choices=["int8", "fp16", "fp32"],
+ choices=["int4", "int8", "fp16", "fp32"],
help="Precision of model",
)
args = parser.parse_args() if argv == [] else parser.parse_args(argv)
+
+ # Use FP32 precision for FP32, INT8, and CPU INT4 models; use FP16 precision for FP16 and non-CPU (GPU) INT4 models
+ args.precision = (
+ "fp32"
+ if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.execution_provider == "cpu")
+ else "fp16"
+ )
return args
@@ -114,19 +236,34 @@ def main(argv: List[str] = []): # noqa: B006
logger.info(f"Arguments: {args}")
# Load model and config
+ setattr(args, "use_fp16", args.precision == "fp16") # noqa: B010
+ setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{args.device_id}") # noqa: B010
+ setattr(args, "device", torch.device(args.device_name)) # noqa: B010
use_auth_token = args.torch_model_directory == os.path.join(".")
location = args.model_name if use_auth_token else args.torch_model_directory
config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token)
llama = LlamaForCausalLM.from_pretrained(
location,
- torch_dtype=(torch.float16 if args.precision == "fp16" else torch.float32),
+ torch_dtype=(torch.float16 if args.use_fp16 else torch.float32),
use_auth_token=use_auth_token,
use_cache=True,
- )
+ ).to(args.device)
+
+ if not args.merged:
+ verify_parity(args, config, llama)
+ else:
+ # Verify prompt generation in merged model (decoder_model.onnx)
+ args.use_past_kv = False
+ verify_parity(args, config, llama)
- verify_parity(args, config, llama)
+ # Verify token generation in merged model (decoder_with_past_model.onnx)
+ args.use_past_kv = True
+ verify_parity(args, config, llama)
if __name__ == "__main__":
+ seed = 2
+ np.random.seed(seed)
+ torch.manual_seed(seed)
main()
diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt
index e9ad937cf14e7..e06c3ada834b0 100644
--- a/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt
+++ b/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt
@@ -1,3 +1,2 @@
-r requirements.txt
-torch>=2.0.1
-onnxruntime>=1.16.0
\ No newline at end of file
+onnxruntime>=1.17.0
\ No newline at end of file
diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt
index 5544abcaa1228..773680937bd21 100644
--- a/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt
+++ b/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt
@@ -1,4 +1,4 @@
-r requirements.txt
-# Please manually install torch>=2.0.1 with CUDA enabled for the CUDA version installed in your system.
+# Please manually install torch>=2.2.0.dev20230920 with CUDA enabled for the CUDA version installed in your system.
# Instructions can be found here: https://pytorch.org/get-started/locally/
-onnxruntime-gpu>=1.16.0
\ No newline at end of file
+onnxruntime-gpu>=1.17.0
\ No newline at end of file
diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements.txt b/onnxruntime/python/tools/transformers/models/llama/requirements.txt
index f843ef4dc5568..4210f36982aef 100644
--- a/onnxruntime/python/tools/transformers/models/llama/requirements.txt
+++ b/onnxruntime/python/tools/transformers/models/llama/requirements.txt
@@ -1,5 +1,6 @@
-git+https://github.com/kunal-vaishnavi/optimum.git@kvaishnavi/llama-add-position-ids
-transformers>=4.28.1
+git+https://github.com/huggingface/optimum.git
+transformers>=4.33.2
+torch>=2.2.0.dev20230920
onnx>=1.14.0
datasets>=2.8.0
protobuf==3.20.2
\ No newline at end of file
diff --git a/onnxruntime/python/tools/transformers/models/whisper/README.md b/onnxruntime/python/tools/transformers/models/whisper/README.md
index e9365becd2cd1..8ff5c8a6e1de0 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/README.md
+++ b/onnxruntime/python/tools/transformers/models/whisper/README.md
@@ -79,24 +79,22 @@ $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/w
Here are some examples of how you can benchmark Whisper across various end-to-end (E2E) implementations.
-Note: In the below examples, `PyTorch` refers to running in PyTorch without `torch.compile` and `PyTorch 2.0` refers to running in PyTorch with `torch.compile`.
-
### Variants
-1. PyTorch (without `torch.compile`), FP32
+1. PyTorch without `torch.compile`, FP32
```
python3 -m models.whisper.benchmark \
- --benchmark-type hf-pt \
+ --benchmark-type hf-pt-eager \
--audio-path 1272-141231-0002.mp3 \
--model-name openai/whisper-large-v2 \
--precision fp32 \
--device cpu
```
-2. PyTorch 2.0 (with `torch.compile`), FP16
+2. PyTorch with `torch.compile`, FP16
```
python3 -m models.whisper.benchmark \
- --benchmark-type hf-pt2 \
+ --benchmark-type hf-pt-compile \
--audio-path 1272-141231-0002.mp3 \
--model-name openai/whisper-large-v2 \
--precision fp16 \
@@ -109,7 +107,7 @@ python3 -m models.whisper.benchmark \
--benchmark-type hf-ort \
--audio-path 1272-141231-0002.mp3 \
--model-name openai/whisper-large-v2 \
- --hf-ort-model-path ./whisper-large-v2-onnx/ \
+ --hf-ort-dir-path ./whisper-large-v2-onnx/ \
--precision fp32 \
--device cpu
```
@@ -156,7 +154,9 @@ You can use `benchmark_all.py` to benchmark across various platforms and automat
```
python3 -m models.whisper.benchmark_all \
--audio-path ./whisper-test-audios/ \
- --hf-ort-model-path ./whisper-large-v2-onnx/ \
+ --hf-pt-eager \
+ --hf-pt-compile \
+ --hf-ort-dir-path ./whisper-large-v2-onnx/ \
--ort-model-path ./wlarge-fp32/whisper-large-v2_all.onnx \
--model-name openai/whisper-large-v2 \
--precision fp32 \
@@ -169,28 +169,28 @@ Here is a benchmark for an MP3 file with 20.7s of audio.
#### FP16
-| Engine | Size | Per-Token Latency | Real-Time Factor |
-| ------------- | -------- | ----------------- | ---------------- |
-| PyTorch | Tiny | 4.697 ms/token | 0.004697 |
-| PyTorch 2.0 | Tiny | 3.406 ms/token | 0.003406 |
-| ONNX Runtime | Tiny | 0.746 ms/token | 0.000746 |
-| PyTorch | Medium | 17.837 ms/token | 0.017387 |
-| PyTorch 2.0 | Medium | 18.124 ms/token | 0.018124 |
-| ONNX Runtime | Medium | 3.894 ms/token | 0.003894 |
-| PyTorch | Large v2 | 23.470 ms/token | 0.023470 |
-| PyTorch 2.0 | Large v2 | 23.146 ms/token | 0.023146 |
-| ONNX Runtime | Large v2 | 6.262 ms/token | 0.006262 |
+| Engine | Size | Per-Token Latency | Real-Time Factor |
+| --------------- | -------- | ----------------- | ---------------- |
+| PyTorch eager | Tiny | 4.697 ms/token | 0.004697 |
+| PyTorch compile | Tiny | 3.406 ms/token | 0.003406 |
+| ONNX Runtime | Tiny | 0.746 ms/token | 0.000746 |
+| PyTorch eager | Medium | 17.837 ms/token | 0.017387 |
+| PyTorch compile | Medium | 18.124 ms/token | 0.018124 |
+| ONNX Runtime | Medium | 3.894 ms/token | 0.003894 |
+| PyTorch eager | Large v2 | 23.470 ms/token | 0.023470 |
+| PyTorch compile | Large v2 | 23.146 ms/token | 0.023146 |
+| ONNX Runtime | Large v2 | 6.262 ms/token | 0.006262 |
#### FP32
-| Engine | Size | Per-Token Latency | Real-Time Factor |
-| ------------- | -------- | ----------------- | ---------------- |
-| PyTorch | Tiny | 6.220 ms/token | 0.006220 |
-| PyTorch 2.0 | Tiny | 3.944 ms/token | 0.003944 |
-| ONNX Runtime | Tiny | 1.545 ms/token | 0.001545 |
-| PyTorch | Medium | 19.093 ms/token | 0.019093 |
-| PyTorch 2.0 | Medium | 20.459 ms/token | 0.020459 |
-| ONNX Runtime | Medium | 9.440 ms/token | 0.009440 |
-| PyTorch | Large v2 | 25.844 ms/token | 0.025844 |
-| PyTorch 2.0 | Large v2 | 26.397 ms/token | 0.026397 |
-| ONNX Runtime | Large v2 | 7.492 ms/token | 0.007492 |
+| Engine | Size | Per-Token Latency | Real-Time Factor |
+| --------------- | -------- | ----------------- | ---------------- |
+| PyTorch eager | Tiny | 6.220 ms/token | 0.006220 |
+| PyTorch compile | Tiny | 3.944 ms/token | 0.003944 |
+| ONNX Runtime | Tiny | 1.545 ms/token | 0.001545 |
+| PyTorch eager | Medium | 19.093 ms/token | 0.019093 |
+| PyTorch compile | Medium | 20.459 ms/token | 0.020459 |
+| ONNX Runtime | Medium | 9.440 ms/token | 0.009440 |
+| PyTorch eager | Large v2 | 25.844 ms/token | 0.025844 |
+| PyTorch compile | Large v2 | 26.397 ms/token | 0.026397 |
+| ONNX Runtime | Large v2 | 7.492 ms/token | 0.007492 |
diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py
index 283528bea7465..759ae6d14f184 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py
@@ -24,7 +24,7 @@
def get_inputs(args: argparse.Namespace):
- if args.benchmark_type not in {"hf-pt", "hf-pt2", "hf-ort", "ort"}:
+ if args.benchmark_type not in {"hf-pt-eager", "hf-pt-compile", "hf-ort", "ort"}:
raise Exception("Unable to auto-detect inputs for provided model")
def load_via_ffmpeg():
@@ -102,7 +102,7 @@ def get_model(args: argparse.Namespace):
# 2) Benchmark Whisper ONNX model from Optimum export (without pre/post processing)
# 3) Benchmark Whisper ONNX E2E model from Olive (with pre/post processing)
- if args.benchmark_type in {"hf-pt", "hf-pt2"}:
+ if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
source = args.hf_pt_model_path if args.hf_pt_model_path else args.model_name
start_time = time.time()
model = AutoModelForSpeechSeq2Seq.from_pretrained(
@@ -112,7 +112,7 @@ def get_model(args: argparse.Namespace):
).to(args.target_device)
end_time = time.time()
- if args.benchmark_type == "hf-pt2":
+ if args.benchmark_type == "hf-pt-compile":
model = torch.compile(model)
elif args.benchmark_type in {"hf-ort", "ort"}:
@@ -136,7 +136,7 @@ def get_model(args: argparse.Namespace):
start_time = time.time()
model = ORTModelForSpeechSeq2Seq.from_pretrained(
- args.hf_ort_model_path,
+ args.hf_ort_dir_path,
use_io_binding=(args.device != "cpu"),
provider=provider,
provider_options=provider_options,
@@ -214,7 +214,7 @@ def profile_fn(args, fn, inputs, inputs_type):
prefix = f"{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}"
filename = None
- if args.benchmark_type in {"hf-pt", "hf-pt2"}:
+ if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
# Profile PyTorch kernels
with profile( # noqa: SIM117
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
@@ -280,7 +280,7 @@ def gen_and_dec(inputs):
generate_fn = gen_and_dec
- if args.benchmark_type == "hf-pt2":
+ if args.benchmark_type == "hf-pt-compile":
# Run forward pass once with each set of inputs to process through Dynamo
generate_fn(inputs)
@@ -345,7 +345,7 @@ def prepare_ort_inputs(inputs, warmup=False):
for k, v in inputs.items():
io_binding.bind_cpu_input(k, v)
for output in model.get_outputs():
- io_binding.bind_output(output.name)
+ io_binding.bind_output(output.name, device_type=args.device, device_id=args.device_id)
return io_binding
return inputs
@@ -407,7 +407,7 @@ def handle_output(output):
def run_inference(args, inputs, model):
- if args.benchmark_type in {"hf-pt", "hf-pt2", "hf-ort"}:
+ if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}:
run_hf_inference(args, inputs, model)
elif args.benchmark_type == "ort":
run_ort_inference(args, inputs, model)
@@ -419,8 +419,13 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
- "-bt", "--benchmark-type", type=str, required=True, choices=["hf-pt", "hf-pt2", "hf-ort", "ort"]
+ "-bt",
+ "--benchmark-type",
+ type=str,
+ required=True,
+ choices=["hf-pt-eager", "hf-pt-compile", "hf-ort", "ort"],
)
+
parser.add_argument(
"-m",
"--model-name",
@@ -445,7 +450,7 @@ def parse_args():
help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)",
)
parser.add_argument(
- "--hf-ort-model-path",
+ "--hf-ort-dir-path",
type=str,
default="",
help="Path to directory containing all ONNX files (e.g. tokenizer, encoder, decoder, decoder_with_past)",
@@ -538,7 +543,7 @@ def parse_args():
# Check that model paths have been specified for any benchmarking with ORT
if args.benchmark_type == "hf-ort":
- assert args.hf_ort_model_path, "Please specify a path to `--hf-ort-model-path`"
+ assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`"
if args.benchmark_type == "ort":
assert args.ort_model_path, "Please specify a path to `--ort-model-path`"
diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py
index 08d7befec3cfd..071b539ac1899 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py
@@ -54,7 +54,21 @@ def get_args():
)
parser.add_argument(
- "--hf-ort-model-path",
+ "--hf-pt-eager",
+ default=False,
+ action="store_true",
+ help="Benchmark in PyTorch without `torch.compile`",
+ )
+
+ parser.add_argument(
+ "--hf-pt-compile",
+ default=False,
+ action="store_true",
+ help="Benchmark in PyTorch with `torch.compile`",
+ )
+
+ parser.add_argument(
+ "--hf-ort-dir-path",
type=str,
help="Path to folder containing ONNX models for Optimum + ORT benchmarking",
)
@@ -136,7 +150,7 @@ def process_log_file(device_id, log_file, base_results):
load_audio_latency_s, load_audio_throughput_s = None, None
feat_ext_latency_s, feat_ext_throughput_s = None, None
- latency_s, per_token_latency_s, per_token_latency_ms = None, None, None
+ token_length, latency_s, per_token_latency_s, per_token_latency_ms = None, None, None, None
throughput, memory = None, None
# Detect metrics
@@ -310,73 +324,75 @@ def main():
logger.info(f"Testing {audio_path}...")
# Benchmark PyTorch without torch.compile
- benchmark_cmd = [ # noqa: RUF005
- "python3",
- "-m",
- "models.whisper.benchmark",
- "--audio-path",
- audio_path,
- "--benchmark-type",
- "hf-pt",
- "--model-name",
- args.model_name,
- "--precision",
- args.precision,
- "--device",
- args.device,
- "--device-id",
- str(args.device_id),
- "--warmup-runs",
- str(args.warmup_runs),
- "--num-runs",
- str(args.num_runs),
- "--log-folder",
- args.log_folder,
- ] + hf_decoder_input_ids_cmd
- logger.info("Benchmark PyTorch without torch.compile")
- results = benchmark(args, benchmark_cmd, "pytorch", audio_file, duration)
- all_results.extend(results)
+ if args.hf_pt_eager:
+ benchmark_cmd = [ # noqa: RUF005
+ "python",
+ "-m",
+ "models.whisper.benchmark",
+ "--audio-path",
+ audio_path,
+ "--benchmark-type",
+ "hf-pt-eager",
+ "--model-name",
+ args.model_name,
+ "--precision",
+ args.precision,
+ "--device",
+ args.device,
+ "--device-id",
+ str(args.device_id),
+ "--warmup-runs",
+ str(args.warmup_runs),
+ "--num-runs",
+ str(args.num_runs),
+ "--log-folder",
+ args.log_folder,
+ ] + hf_decoder_input_ids_cmd
+ logger.info("Benchmark PyTorch without torch.compile")
+ results = benchmark(args, benchmark_cmd, "pytorch-eager", audio_file, duration)
+ all_results.extend(results)
# Benchmark PyTorch with torch.compile
- benchmark_cmd = [ # noqa: RUF005
- "python3",
- "-m",
- "models.whisper.benchmark",
- "--audio-path",
- audio_path,
- "--benchmark-type",
- "hf-pt2",
- "--model-name",
- args.model_name,
- "--precision",
- args.precision,
- "--device",
- args.device,
- "--device-id",
- str(args.device_id),
- "--warmup-runs",
- str(args.warmup_runs),
- "--num-runs",
- str(args.num_runs),
- "--log-folder",
- args.log_folder,
- ] + hf_decoder_input_ids_cmd
- logger.info("Benchmark PyTorch with torch.compile")
- results = benchmark(args, benchmark_cmd, "pytorch-2", audio_file, duration)
- all_results.extend(results)
+ if args.hf_pt_compile:
+ benchmark_cmd = [ # noqa: RUF005
+ "python",
+ "-m",
+ "models.whisper.benchmark",
+ "--audio-path",
+ audio_path,
+ "--benchmark-type",
+ "hf-pt-compile",
+ "--model-name",
+ args.model_name,
+ "--precision",
+ args.precision,
+ "--device",
+ args.device,
+ "--device-id",
+ str(args.device_id),
+ "--warmup-runs",
+ str(args.warmup_runs),
+ "--num-runs",
+ str(args.num_runs),
+ "--log-folder",
+ args.log_folder,
+ ] + hf_decoder_input_ids_cmd
+ logger.info("Benchmark PyTorch with torch.compile")
+ results = benchmark(args, benchmark_cmd, "pytorch-compile", audio_file, duration)
+ all_results.extend(results)
# Benchmark Optimum + ONNX Runtime
- if args.hf_ort_model_path:
+ if args.hf_ort_dir_path:
benchmark_cmd = [ # noqa: RUF005
- "python3",
+ "python",
"-m",
"models.whisper.benchmark",
"--audio-path",
audio_path,
"--benchmark-type",
"hf-ort",
- "--hf-ort-model-path",
- args.hf_ort_model_path,
+ "--hf-ort-dir-path",
+ args.hf_ort_dir_path,
"--model-name",
args.model_name,
"--precision",
@@ -393,14 +409,14 @@ def main():
args.log_folder,
] + hf_decoder_input_ids_cmd
logger.info("Benchmark Optimum + ONNX Runtime")
- results = benchmark(args, benchmark_cmd, "pytorch-ort", audio_file, duration)
+ results = benchmark(args, benchmark_cmd, "optimum-ort", audio_file, duration)
all_results.extend(results)
# Benchmark ONNX Runtime
if args.ort_model_path:
benchmark_cmd = (
[ # noqa: RUF005
- "python3",
+ "python",
"-m",
"models.whisper.benchmark",
"--audio-path",
diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py
index 995f8c6541b4c..7a69922e67072 100644
--- a/onnxruntime/python/tools/transformers/onnx_model_bert.py
+++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py
@@ -22,7 +22,9 @@
from fusion_qordered_layernorm import FusionQOrderedLayerNormalization
from fusion_qordered_matmul import FusionQOrderedMatMul
from fusion_reshape import FusionReshape
+from fusion_rotary_attention import FusionRotaryEmbeddings
from fusion_shape import FusionShape
+from fusion_simplified_layernorm import FusionSimplifiedLayerNormalization, FusionSkipSimplifiedLayerNormalization
from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization
from fusion_utils import FusionUtils
from onnx import GraphProto, ModelProto, TensorProto, ValueInfoProto, helper
@@ -106,10 +108,36 @@ def fuse_layer_norm(self):
fusion = FusionQOrderedLayerNormalization(self)
fusion.apply()
+ def fuse_simplified_layer_norm(self):
+ fusion = FusionSimplifiedLayerNormalization(self)
+ fusion.apply()
+
def fuse_skip_layer_norm(self):
fusion = FusionSkipLayerNormalization(self)
fusion.apply()
+ def fuse_skip_simplified_layer_norm(self):
+ fusion = FusionSkipSimplifiedLayerNormalization(self)
+ fusion.apply()
+
+ def fuse_rotary_embeddings(self):
+ fusion = FusionRotaryEmbeddings(self)
+ fusion.apply()
+ # Remove RotaryEmbedding functions whose domain is not used by any remaining non-com.microsoft RotaryEmbedding node
+ rot_emb_nodes = list(
+ filter(
+ lambda node: node.op_type == "RotaryEmbedding" and node.domain != "com.microsoft", self.model.graph.node
+ )
+ )
+ non_ms_domains_to_keep = set(map(lambda node: node.domain, rot_emb_nodes))
+ i = 0
+ while i < len(self.model.functions):
+ fn = self.model.functions[i]
+ if "RotaryEmbedding" in fn.name and fn.domain not in non_ms_domains_to_keep:
+ self.model.functions.remove(fn)
+ else:
+ i += 1
+
# Only relevant in models with Q-DQ nodes
def fuse_qordered_mamtul(self):
fusion = FusionQOrderedMatMul(self)
@@ -367,6 +395,7 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo
if (options is None) or options.enable_layer_norm:
self.fuse_layer_norm()
+ self.fuse_simplified_layer_norm()
if (options is None) or options.enable_gelu:
self.fuse_gelu()
@@ -377,6 +406,10 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo
if (options is None) or options.enable_skip_layer_norm:
self.fuse_skip_layer_norm()
+ self.fuse_skip_simplified_layer_norm()
+
+ if (options is None) or options.enable_rotary_embeddings:
+ self.fuse_rotary_embeddings()
if options is not None:
self.attention_mask.set_mask_format(options.attention_mask_format)
@@ -442,14 +475,17 @@ def get_fused_operator_statistics(self):
"BiasGelu",
"GemmFastGelu",
"LayerNormalization",
+ "SimplifiedLayerNormalization",
"SkipLayerNormalization",
+ "SkipSimplifiedLayerNormalization",
+ "RotaryEmbedding",
]
q_ops = ["QOrderedAttention", "QOrderedGelu", "QOrderedLayerNormalization", "QOrderedMatMul"]
for op in ops + q_ops:
nodes = self.get_nodes_by_op_type(op)
op_count[op] = len(nodes)
- logger.info(f"Optimized operators:{op_count}")
+ logger.info(f"Optimized operators: {op_count}")
return op_count
def is_fully_optimized(self):
@@ -461,11 +497,20 @@ def is_fully_optimized(self):
attention = op_count["Attention"] + op_count["MultiHeadAttention"] + op_count["QOrderedAttention"]
gelu = op_count["Gelu"] + op_count["BiasGelu"] + op_count["FastGelu"]
layer_norm = op_count["LayerNormalization"] + op_count["SkipLayerNormalization"]
- is_perfect = (embed > 0) and (attention > 0) and (attention == gelu) and (layer_norm >= 2 * attention)
+ simple_layer_norm = op_count["SimplifiedLayerNormalization"] + op_count["SkipSimplifiedLayerNormalization"]
+ is_perfect = (
+ (embed > 0)
+ and (attention > 0)
+ and (attention == gelu)
+ and ((layer_norm >= 2 * attention) or (simple_layer_norm >= 2 * attention))
+ )
if layer_norm == 0:
logger.debug("Layer Normalization not fused")
+ if simple_layer_norm == 0:
+ logger.debug("Simple Layer Normalization not fused")
+
if gelu == 0:
logger.debug("Gelu/FastGelu not fused")
diff --git a/onnxruntime/python/tools/transformers/onnx_model_gpt2.py b/onnxruntime/python/tools/transformers/onnx_model_gpt2.py
index 263857ffbc130..6545bb08cdd5e 100644
--- a/onnxruntime/python/tools/transformers/onnx_model_gpt2.py
+++ b/onnxruntime/python/tools/transformers/onnx_model_gpt2.py
@@ -8,6 +8,7 @@
from fusion_gpt_attention import FusionGptAttention
from fusion_gpt_attention_megatron import FusionGptAttentionMegatron
from fusion_gpt_attention_no_past import FusionGptAttentionNoPast
+from fusion_rotary_attention import FusionRotaryAttention
from onnx_model_bert import BertOnnxModel
logger = logging.getLogger(__name__)
@@ -27,6 +28,9 @@ def fuse_attention(self):
fusion = FusionGptAttentionMegatron(self, self.num_heads)
fusion.apply()
+ fusion = FusionRotaryAttention(self, self.hidden_size, self.num_heads)
+ fusion.apply()
+
def postprocess(self):
"""
Remove extra reshape nodes.
@@ -94,4 +98,4 @@ def postprocess(self):
reshape_count += 2
self.prune_graph()
- logger.info(f"postprocess: remove Reshape count:{reshape_count}")
+ logger.info(f"postprocess: remove Reshape count: {reshape_count}")
diff --git a/onnxruntime/python/tools/transformers/onnx_model_t5.py b/onnxruntime/python/tools/transformers/onnx_model_t5.py
index e9f98e956b760..95f40af3fd746 100644
--- a/onnxruntime/python/tools/transformers/onnx_model_t5.py
+++ b/onnxruntime/python/tools/transformers/onnx_model_t5.py
@@ -3,12 +3,12 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
-from typing import Dict, Optional, Union
+from typing import Optional, Union
import numpy as np
from fusion_attention import AttentionMask, FusionAttention
from fusion_base import Fusion
-from fusion_skiplayernorm import FusionSkipLayerNormalization
+from fusion_simplified_layernorm import FusionSimplifiedLayerNormalization, FusionSkipSimplifiedLayerNormalization
from fusion_utils import NumpyHelper
from onnx import NodeProto, TensorProto, helper
from onnx_model import OnnxModel
@@ -56,8 +56,8 @@ def create_attention_node(
Args:
mask_index (str): mask input
q_matmul (NodeProto): MatMul node in fully connection for Q
- k_matmul (NodeProto): MatMul node in fully connection for K
- v_matmul (NodeProto): MatMul node in fully connection for V
+ k_matmul (NodeProto): MatMul node in fully connection for K
+ v_matmul (NodeProto): MatMul node in fully connection for V
num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
input (str): input name
@@ -687,67 +687,6 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
self.node_name_to_graph_name[rpb_node.name] = self.this_graph_name
-class FusionSimplifiedLayerNormalization(Fusion):
- def __init__(self, model: OnnxModel):
- super().__init__(model, "SimplifiedLayerNormalization", "Mul")
-
- def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
- if node.op_type != "Mul":
- return
-
- sim_ln_nodes = self.model.match_parent_path(
- node,
- ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Add"],
- [1, 1, 1, 0, 0, 0, 0],
- )
- if sim_ln_nodes is None:
- sim_ln_nodes = self.model.match_parent_path(
- node,
- ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Gather"],
- [1, 1, 1, 0, 0, 0, 0],
- )
- if sim_ln_nodes is None:
- return
-
- pow_node = sim_ln_nodes[-2]
- if self.model.find_constant_input(pow_node, 2.0) != 1:
- return
-
- root_input = pow_node.input[0]
-
- mul_node_1 = sim_ln_nodes[0]
- if root_input != mul_node_1.input[0]:
- return
-
- second_add_node = sim_ln_nodes[3]
- i, add_weight = self.model.get_constant_input(second_add_node)
- if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4:
- logger.warning(f"epsilon value is not expeced: {add_weight}")
- return
-
- self.nodes_to_remove.extend(sim_ln_nodes[:-1])
-
- normalize_node = helper.make_node(
- "SimplifiedLayerNormalization",
- inputs=[root_input, node.input[0]],
- outputs=[node.output[0]],
- name=self.model.create_node_name("SimplifiedLayerNormalization", name_prefix="LayerNorm"),
- )
- normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))])
- normalize_node.attribute.extend([helper.make_attribute("axis", int(-1))])
- normalize_node.attribute.extend([helper.make_attribute("stash_type", 1)])
- self.nodes_to_add.append(normalize_node)
- self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
-
-
-class FusionSkipSimplifiedLayerNormalization(FusionSkipLayerNormalization):
- def __init__(self, model: OnnxModel):
- super().__init__(model, "SkipSimplifiedLayerNormalization", "SimplifiedLayerNormalization")
-
- def fuse(self, node, input_name_to_nodes, output_name_to_node):
- super().fuse(node, input_name_to_nodes, output_name_to_node)
-
-
class T5OnnxModel(BertOnnxModel):
def __init__(self, model, num_heads, hidden_size):
super().__init__(model, num_heads, hidden_size)
diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py
index 5ded027b36f74..00b26c019d4b5 100644
--- a/onnxruntime/python/tools/transformers/optimizer.py
+++ b/onnxruntime/python/tools/transformers/optimizer.py
@@ -103,7 +103,7 @@ def optimize_by_onnxruntime(
logger.error("There is no gpu for onnxruntime to do optimization.")
return onnx_model_path
- model = OnnxModel(load_model(onnx_model_path, format=None, load_external_data=False))
+ model = OnnxModel(load_model(onnx_model_path, load_external_data=False))
if model.use_float16() and not use_gpu:
logger.warning(
"This model uses float16 in the graph, use_gpu=False might cause extra Cast nodes. "
@@ -546,7 +546,7 @@ def main():
if args.input_int32:
optimizer.change_graph_inputs_to_int32()
- if args.model_type in ["bert", "gpt2"]:
+ if args.model_type in set(MODEL_TYPES.keys()):
if optimizer.is_fully_optimized():
logger.info("The model has been fully optimized.")
else:
diff --git a/onnxruntime/python/tools/transformers/shape_infer_helper.py b/onnxruntime/python/tools/transformers/shape_infer_helper.py
index f8a5464d8af78..f1fc0c952e8e4 100644
--- a/onnxruntime/python/tools/transformers/shape_infer_helper.py
+++ b/onnxruntime/python/tools/transformers/shape_infer_helper.py
@@ -28,12 +28,12 @@ def __init__(self, model, verbose=0, int_max=2**31 - 1, auto_merge=True, guess_o
self.is_inferred_: bool = False
self.dynamic_axis_mapping_: Dict[str, int] = {}
- def infer(self, dynamic_axis_mapping: Dict[str, int], max_runs: int = 128):
+ def infer(self, dynamic_axis_mapping: Dict[str, int], max_runs: int = 200):
"""Run shape inference, and try replace dynamic axis from string to integer when mapping is provided.
Args:
dynamic_axis_mapping (_type_): a dictionary with name of dynamic axis as key, like {"batch_size" : 4}
- max_runs (int, optional): limit maximum number of runs to avoid infinite loop. Defaults to 32.
+ max_runs (int, optional): limit maximum number of runs to avoid infinite loop. Defaults to 200.
Returns:
bool: whether all shapes has been inferred or not.
diff --git a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc
new file mode 100644
index 0000000000000..29d8219c162a5
--- /dev/null
+++ b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc
@@ -0,0 +1,632 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <vector>
+#include "gtest/gtest.h"
+#include "core/session/onnxruntime_cxx_api.h"
+#include "test/common/tensor_op_test_utils.h"
+#include "test/common/cuda_op_test_utils.h"
+#include "test/providers/provider_test_utils.h"
+
+namespace onnxruntime {
+namespace test {
+
+static void RunTest(
+    const std::vector<float>& input_data,
+    const std::vector<int64_t>& position_ids,
+    const std::vector<float>& cos_cache,
+    const std::vector<float>& sin_cache,
+    const std::vector<float>& output_data,
+    int batch_size,
+    int sequence_length,
+    int head_size,
+    int num_heads,
+    int max_sequence_length,
+    int64_t interleaved,
+    bool use_float16,
+    bool disable_cpu,
+    bool disable_cuda) {
+  // Runs RotaryEmbedding once on the selected EPs and checks output_data (abs err 0.002). Input shapes:
+  // input : (batch_size, sequence_length, hidden_size); position ids : (1) or (batch_size, sequence_length)
+  // cos cache : (max_sequence_length, head_size / 2)
+  // sin cache : (max_sequence_length, head_size / 2)
+  // interleaved : 0 = false, 1 = true; use_float16 converts float tensors with ToFloat16 and skips the CPU EP
+
+  int hidden_size = num_heads * head_size;
+  std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
+  std::vector<int64_t> pos_dims;
+  std::vector<int64_t> cache_dims = {max_sequence_length, head_size / 2};
+
+  assert(hidden_size != 0 && head_size != 0 && num_heads != 0 && max_sequence_length != 0);
+  assert(max_sequence_length >= sequence_length);
+  if (position_ids.size() == 1) {
+    pos_dims = {1};
+  } else {
+    pos_dims = {batch_size, sequence_length};
+  }
+
+  std::string op_type = "RotaryEmbedding";
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+
+  int min_cuda_architecture = use_float16 ? 530 : 0;
+  bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
+  if (enable_cuda && !disable_cuda) {
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+  }
+  if (!use_float16 && !disable_cpu) {
+    execution_providers.push_back(DefaultCpuExecutionProvider());
+  }
+  if (execution_providers.empty()) {
+    // Return early if CI pipeline does not support EP (e.g. CUDA EP for CPU CI pipeline)
+    return;
+  }
+
+  OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain);
+  test.AddAttribute<int64_t>("interleaved", interleaved);
+
+  if (!use_float16) {
+    test.AddInput<float>("input", input_dims, input_data);
+    test.AddInput<int64_t>("position_ids", pos_dims, position_ids);
+    test.AddInput<float>("cos_cache", cache_dims, cos_cache);
+    test.AddInput<float>("sin_cache", cache_dims, sin_cache);
+    test.AddOutput<float>("output", input_dims, output_data);
+  } else {
+    test.AddInput<MLFloat16>("input", input_dims, ToFloat16(input_data));
+    test.AddInput<int64_t>("position_ids", pos_dims, position_ids);
+    test.AddInput<MLFloat16>("cos_cache", cache_dims, ToFloat16(cos_cache));
+    test.AddInput<MLFloat16>("sin_cache", cache_dims, ToFloat16(sin_cache));
+    test.AddOutput<MLFloat16>("output", input_dims, ToFloat16(output_data));
+  }
+  test.SetOutputAbsErr("output", 0.002f);
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+static void RunTests(const std::vector<float>& input_data,
+                     const std::vector<int64_t>& position_ids,
+                     const std::vector<float>& cos_cache,
+                     const std::vector<float>& sin_cache,
+                     const std::vector<float>& output_data,
+                     int batch_size,
+                     int sequence_length,
+                     int head_size = 0,
+                     int num_heads = 0,
+                     int max_sequence_length = 0,
+                     int64_t interleaved = 0,
+                     bool use_float16 = true) {
+  // Runs the same case up to three times via RunTest. First run: FP32 with only the CPU EP.
+  RunTest(input_data,
+          position_ids,
+          cos_cache,
+          sin_cache,
+          output_data,
+          batch_size,
+          sequence_length,
+          head_size,
+          num_heads,
+          max_sequence_length,
+          interleaved,
+          false, /* use_float16 */
+          false, /* disable_cpu */
+          true /* disable_cuda */);
+
+  // FP32 test with CUDA enabled; the CPU EP is added too, so CPU-only machines repeat the CPU run
+  RunTest(input_data,
+          position_ids,
+          cos_cache,
+          sin_cache,
+          output_data,
+          batch_size,
+          sequence_length,
+          head_size,
+          num_heads,
+          max_sequence_length,
+          interleaved,
+          false, /* use_float16 */
+          false, /* disable_cpu */
+          false /* disable_cuda */);
+
+  // FP16 test for CUDA (RunTest returns early when no suitable CUDA environment is found)
+  if (use_float16) {
+    RunTest(input_data,
+            position_ids,
+            cos_cache,
+            sin_cache,
+            output_data,
+            batch_size,
+            sequence_length,
+            head_size,
+            num_heads,
+            max_sequence_length,
+            interleaved,
+            true, /* use_float16 */
+            true, /* disable_cpu */
+            false /* disable_cuda */);
+  }
+}
+
+// Interleaved = true, pos ids shape = (1); 1 batch x 3 tokens x (2 heads * 4) hidden, 8 cached positions
+TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_SmallData_LlamaMSFT) {
+  int batch_size = 1;
+  int sequence_length = 3;
+  int num_heads = 2;
+  int head_size = 4;
+  int max_sequence_length = 8;
+  int64_t interleaved = 1;  // true
+
+  std::vector<float> input_data = {
+      -1.0408f, 0.9166f, -1.3042f, -1.1097f, -0.1320f, -0.2751f, -0.2350f, 0.0937f,
+      -1.2188f, 1.1676f, -1.0574f, -0.1188f, -0.7396f, -1.2425f, -0.1752f, 0.6990f,
+      -0.8110f, 0.6737f, -1.1233f, -0.0919f, -0.6861f, 0.7202f, 0.1963f, 0.6142f};
+
+  std::vector<int64_t> position_ids = {0};
+
+  std::vector<float> cos_cache = {
+      1.0000f, 1.0000f, 0.5403f, 0.9999f, -0.4161f, 0.9998f, -0.9900f, 0.9996f,
+      -0.6536f, 0.9992f, 0.2837f, 0.9988f, 0.9602f, 0.9982f, 0.7539f, 0.9976f};
+
+  std::vector<float> sin_cache = {
+      0.0000f, 0.0000f, 0.8415f, 0.0100f, 0.9093f, 0.0200f, 0.1411f, 0.0300f,
+      -0.7568f, 0.0400f, -0.9589f, 0.0500f, -0.2794f, 0.0600f, 0.6570f, 0.0699f};
+
+  std::vector<float> output_data = {
+      -1.0408f, 0.9166f, -1.3042f, -1.1097f, -0.1320f, -0.2751f, -0.2350f, 0.0937f,
+      -1.6411f, -0.3948f, -1.0561f, -0.1294f, 0.6460f, -1.2937f, -0.1822f, 0.6972f,
+      -0.2751f, -1.0178f, -1.1212f, -0.1143f, -0.3694f, -0.9235f, 0.1840f, 0.6180f};
+
+  RunTests(input_data,
+           position_ids,
+           cos_cache,
+           sin_cache,
+           output_data,
+           batch_size,
+           sequence_length,
+           head_size,
+           num_heads,
+           max_sequence_length,
+           interleaved);
+}
+
+// Interleaved = true, pos ids shape = (1); 2 batches x 8 tokens x (4 heads * 6) hidden, 16 cached positions
+TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_LargeData_LlamaMSFT) {
+  int batch_size = 2;
+  int sequence_length = 8;
+  int num_heads = 4;
+  int head_size = 6;
+  int max_sequence_length = 16;
+  int64_t interleaved = 1;  // true
+
+  std::vector<float> input_data = {
+      -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f,
+      1.1676f, -1.0190f, 0.3157f, -1.6036f, 1.8493f,
+      0.0447f, 1.5853f, 0.1036f, -0.3514f, 0.2421f,
+      0.6463f, 0.8730f, -0.9276f, 1.0311f, -1.9557f,
+      -0.1482f, 1.7376f, 2.2039f, -0.6589f, -1.0574f,
+      -0.1188f, -0.9078f, 0.3452f, -0.5713f, -0.2351f,
+      -0.5912f, 1.1312f, 0.7562f, -1.2023f, -0.5833f,
+      -0.4407f, 0.1766f, 1.0224f, -0.4826f, -0.5421f,
+      -0.5342f, -0.6413f, 1.3314f, -0.4498f, 0.5493f,
+      0.0539f, 0.2601f, 0.8570f, 1.0076f, -0.7529f,
+      -0.2250f, -0.4327f, -1.5071f, -0.4586f, -1.9791f,
+      0.7787f, -0.7749f, -0.1398f, 1.1414f, -0.6354f,
+      0.0352f, -0.4765f, -0.0409f, 1.1993f, 0.5374f,
+      -0.1930f, 2.5211f, -0.0452f, -0.3105f, -0.9407f,
+      -0.0034f, 1.5199f, -0.8480f, 0.5266f, 0.0299f,
+      -0.0498f, 1.0651f, 0.8860f, -1.4702f, -0.2134f,
+      -0.8707f, 1.6159f, -0.2356f, 0.9444f, 0.5937f,
+      0.7203f, 0.5061f, 1.5192f, -0.4897f, 0.9231f,
+      0.2654f, -0.1441f, 0.5407f, -1.5476f, 0.6455f,
+      -1.1382f, 0.4640f, -0.4986f, 0.1289f, 2.7631f,
+      0.1405f, 1.1191f, 2.1134f, -0.9754f, 0.1757f,
+      -0.1319f, -0.2735f, 0.3355f, -0.6008f, -1.1164f,
+      0.2577f, -0.7226f, -0.9244f, 1.8737f, 0.6052f,
+      1.1904f, 1.2195f, -0.0470f, -1.0914f, 1.0223f,
+      0.3152f, 1.7528f, -0.7650f, 1.8299f, -0.2784f,
+      -0.2719f, 0.1885f, 2.1432f, 0.8527f, 0.0965f,
+      -0.0625f, 0.8269f, 1.0122f, -1.4482f, -0.0644f,
+      0.3215f, 0.5908f, -1.4197f, 0.2113f, 0.0306f,
+      0.3604f, 0.3166f, -0.8975f, -0.6393f, -1.2944f,
+      -0.0243f, -0.2354f, -0.7087f, 1.1566f, 0.4296f,
+      0.5599f, -0.7776f, 0.3339f, 0.1759f, 2.1108f,
+      1.0702f, 0.8279f, -0.2969f, 0.7120f, -0.2068f,
+      -0.1548f, 0.1553f, 0.6207f, -0.1690f, -0.5816f,
+      1.2632f, 0.0695f, 1.1862f, -1.1874f, -0.7468f,
+      -0.9320f, -0.8579f, -0.9647f, -0.0991f, 0.0195f,
+      1.1213f, -1.4873f, -0.2043f, -1.0466f, -1.5772f,
+      -0.0489f, 0.3430f, 0.1264f, 0.1519f, -1.3639f,
+      -1.6593f, 1.8127f, -1.4459f, -0.2158f, -0.9792f,
+      -1.4392f, 0.6508f, 0.8964f, 0.5717f, -0.2390f,
+      0.6983f, -1.3416f, 0.2715f, -0.2852f, 0.6051f,
+      0.2167f, -0.2181f, -1.6306f, 1.4788f, 0.2754f,
+      -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f,
+      1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f,
+      0.1527f, -0.5996f, -1.0962f, 1.6327f, 1.3951f,
+      0.8784f, 0.3389f, 1.2907f, 0.3124f, 0.7299f,
+      1.4220f, 0.3375f, 0.0438f, 1.8698f, -0.2635f,
+      -2.0799f, -0.6313f, 0.4090f, -1.1458f, 0.0784f,
+      -1.8848f, -1.6165f, 0.6179f, 0.9905f, -0.0729f,
+      0.5054f, -0.6681f, -1.4382f, 1.7547f, -0.9605f,
+      -0.4558f, -1.6105f, 0.2979f, 1.1537f, -1.5604f,
+      1.2779f, -1.2514f, 0.6056f, 0.5763f, -3.3558f,
+      0.2836f, 0.6909f, -0.7631f, 2.4451f, -0.3500f,
+      1.3289f, -0.6494f, 0.3478f, 1.0038f, -0.2937f,
+      0.9238f, -1.2185f, 0.4138f, 0.5033f, 0.9174f,
+      1.8131f, 1.4436f, -0.4207f, 0.0220f, -0.6807f,
+      -1.3306f, 1.5646f, 0.3338f, 0.7105f, 0.4683f,
+      -0.6179f, 0.0818f, -0.0488f, -0.9810f, -1.3632f,
+      0.0929f, -1.7926f, -0.2921f, -0.4792f, 0.6756f,
+      -0.3413f, -0.2242f, -0.2111f, 0.6282f, 0.1667f,
+      -1.4055f, 1.5895f, 1.0838f, -0.9077f, -0.8060f,
+      0.7967f, -2.9351f, 2.4179f, -0.4026f, 0.6451f,
+      1.6845f, -0.0901f, 0.6106f, 2.3603f, 1.3908f,
+      -0.7917f, -0.6734f, -0.1213f, -1.1116f, -0.7401f,
+      -0.7879f, 0.0606f, -2.3337f, -1.2603f, -1.7245f,
+      -0.3533f, -0.9421f, -0.1776f, 0.3992f, -1.7142f,
+      -0.5319f, -0.8848f, 0.6513f, 1.0002f, -1.4699f,
+      -1.4254f, 0.7013f, 0.2414f, 0.2551f, -0.7457f,
+      0.3133f, -1.0941f, -0.3682f, -0.0163f, -0.0645f,
+      -0.8101f, 0.1415f, 0.0551f, 0.5873f, -0.5887f,
+      -1.4733f, -0.8565f, 0.7400f, -0.5033f, 0.0553f,
+      0.9265f, -0.8652f, -0.0288f, -0.2209f, 0.0610f,
+      0.6776f, 0.4361f, -0.8052f, 0.3955f, 0.8988f,
+      0.8238f, 0.2262f, 1.2912f, 0.6488f, 1.2114f,
+      1.3569f, 0.2983f, 0.4718f, -1.1936f, 0.7928f,
+      -0.8665f, 0.9468f, 1.1629f, 0.0616f, -1.3136f,
+      -0.2764f, 0.0277f, -0.1126f, 0.2342f, -0.5866f,
+      -1.8219f, 1.1079f, 0.5795f, -1.4249f};
+
+  std::vector<int64_t> position_ids = {0};
+
+  std::vector<float> cos_cache = {
+      1.0000f, 1.0000f, 1.0000f, 0.5403f, 0.9989f, 1.0000f, -0.4161f, 0.9957f,
+      1.0000f, -0.9900f, 0.9903f, 1.0000f, -0.6536f, 0.9828f, 1.0000f, 0.2837f,
+      0.9732f, 0.9999f, 0.9602f, 0.9615f, 0.9999f, 0.7539f, 0.9477f, 0.9999f,
+      -0.1455f, 0.9318f, 0.9999f, -0.9111f, 0.9140f, 0.9998f, -0.8391f, 0.8942f,
+      0.9998f, 0.0044f, 0.8725f, 0.9997f, 0.8439f, 0.8488f, 0.9997f, 0.9074f,
+      0.8234f, 0.9996f, 0.1367f, 0.7962f, 0.9995f, -0.7597f, 0.7673f, 0.9995f};
+
+  std::vector<float> sin_cache = {
+      0.0000f, 0.0000f, 0.0000f, 0.8415f, 0.0464f, 0.0022f, 0.9093f, 0.0927f,
+      0.0043f, 0.1411f, 0.1388f, 0.0065f, -0.7568f, 0.1846f, 0.0086f, -0.9589f,
+      0.2300f, 0.0108f, -0.2794f, 0.2749f, 0.0129f, 0.6570f, 0.3192f, 0.0151f,
+      0.9894f, 0.3629f, 0.0172f, 0.4121f, 0.4057f, 0.0194f, -0.5440f, 0.4477f,
+      0.0215f, -1.0000f, 0.4887f, 0.0237f, -0.5366f, 0.5286f, 0.0259f, 0.4202f,
+      0.5675f, 0.0280f, 0.9906f, 0.6050f, 0.0302f, 0.6503f, 0.6413f, 0.0323f};
+
+  std::vector<float> output_data = {
+      -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f,
+      1.1676f, -1.0190f, 0.3157f, -1.6036f, 1.8493f,
+      0.0447f, 1.5853f, 0.1036f, -0.3514f, 0.2421f,
+      0.6463f, 0.8730f, -0.9276f, 1.0311f, -1.9557f,
+      -0.1482f, 1.7376f, 2.2039f, -0.6589f, -0.4713f,
+      -0.9540f, -0.9229f, 0.3027f, -0.5708f, -0.2363f,
+      -1.2713f, 0.1137f, 0.8112f, -1.1659f, -0.5824f,
+      -0.4419f, -0.7649f, 0.7011f, -0.4569f, -0.5639f,
+      -0.5328f, -0.6424f, 1.0979f, 0.8773f, 0.5462f,
+      0.0793f, 0.2582f, 0.8576f, 0.2653f, 1.2295f,
+      -0.1839f, -0.4517f, -1.5052f, -0.4651f, 0.1155f,
+      -2.1237f, -0.7586f, -0.2110f, 1.1441f, -0.6304f,
+      0.4186f, 0.2303f, -0.1519f, 1.1903f, 0.5382f,
+      -0.1906f, -1.0080f, 2.3112f, -0.2220f, -0.9655f,
+      -0.0099f, 1.5198f, 0.7652f, -0.6410f, 0.0365f,
+      -0.0452f, 1.0593f, 0.8929f, 1.4856f, 0.0038f,
+      -1.0865f, 1.4794f, -0.2417f, 0.9428f, -0.6894f,
+      -0.6293f, 0.2904f, 1.5747f, -0.4956f, 0.9199f,
+      -0.2424f, 0.1801f, 0.7503f, -1.4576f, 0.6529f,
+      -1.1340f, -0.6807f, -0.0252f, -0.3834f, 2.7394f,
+      0.1308f, 1.1203f, -2.1196f, -0.9618f, 0.1970f,
+      -0.0972f, -0.2764f, 0.3332f, -0.4522f, 1.1844f,
+      0.3867f, -0.6626f, -0.9405f, 1.8656f, 0.5053f,
+      -1.2361f, 1.2072f, 0.1789f, -1.1002f, 1.0129f,
+      1.7702f, 0.1949f, -1.1653f, 1.6049f, -0.2755f,
+      -0.2749f, 2.1087f, 0.4272f, 0.8076f, 0.2900f,
+      -0.0714f, 0.8261f, -1.1016f, -1.3814f, -0.1366f,
+      0.2981f, 0.6060f, -1.4132f, 0.0893f, -0.1939f,
+      0.2779f, 0.3910f, -0.8906f, -0.6489f, -1.2496f,
+      0.3383f, -0.0315f, -0.7461f, 1.1510f, 0.4445f,
+      0.3203f, -0.9031f, 0.2727f, 0.2609f, 2.0968f,
+      1.0974f, 0.7120f, -0.5164f, 0.7415f, -0.0031f,
+      -0.1568f, 0.1533f, 0.5487f, -0.3357f, -0.9064f,
+      1.0546f, 0.0542f, 1.1870f, -0.4045f, -1.3431f,
+      -0.6094f, -1.1105f, -0.9631f, -0.1137f, -0.7219f,
+      0.8582f, -1.3443f, -0.6684f, -1.0227f, -1.5929f,
+      -0.2622f, 0.2264f, 0.0713f, 0.1843f, -1.3387f,
+      -1.6797f, 2.3165f, 0.1009f, 0.1081f, -0.9969f,
+      -1.4488f, 0.6291f, 0.8964f, 0.5717f, -0.2390f,
+      0.6983f, -1.3416f, 0.2715f, -0.2852f, 0.6051f,
+      0.2167f, -0.2181f, -1.6306f, 1.4788f, 0.2754f,
+      -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f,
+      1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f,
+      0.1527f, 0.5985f, -1.0968f, 1.5662f, 1.4693f,
+      0.8776f, 0.3408f, 0.4345f, 1.2549f, 0.6631f,
+      1.4543f, 0.3374f, 0.0445f, 1.2320f, 1.4311f,
+      -2.0483f, -0.7272f, 0.4114f, -1.1449f, 1.6283f,
+      -0.9524f, -1.6435f, 0.5422f, 0.9907f, -0.0708f,
+      0.3972f, 0.7376f, -1.5947f, 1.6138f, -0.9586f,
+      -0.4600f, 0.3993f, -1.5884f, 1.2934f, -1.4467f,
+      1.2833f, -1.2459f, -0.7760f, 0.3108f, -3.3677f,
+      -0.0287f, 0.6942f, -0.7601f, -0.6993f, 2.3690f,
+      1.3834f, -0.5234f, 0.3435f, 1.0053f, 0.1604f,
+      -0.9560f, -1.2641f, 0.2406f, 0.4973f, 0.9206f,
+      -1.9987f, -1.1733f, -0.4197f, -0.0366f, -0.6720f,
+      -1.3350f, -1.5960f, -0.1097f, 0.6386f, 0.5624f,
+      -0.6184f, 0.0778f, 0.1867f, 0.9643f, -1.3629f,
+      -0.0972f, -1.7907f, -0.3037f, 0.8245f, -0.0789f,
+      -0.2940f, -0.2833f, -0.2165f, 0.6264f, -1.1726f,
+      0.7926f, 1.3621f, 1.3586f, -0.9007f, -0.8138f,
+      -2.7421f, 1.3155f, 2.4507f, 0.0507f, 0.6305f,
+      1.6900f, 0.5210f, -0.3309f, 2.0630f, 1.8026f,
+      -0.7859f, -0.6802f, -1.1003f, -0.1990f, -0.5391f,
+      -0.9370f, 0.0857f, -2.3330f, -2.0112f, 0.7193f,
+      -0.1272f, -0.9981f, -0.1818f, 0.3973f, -0.9963f,
+      1.4929f, -1.0109f, 0.4304f, 1.0160f, -1.4590f,
+      0.2682f, 1.5658f, 0.1762f, 0.3038f, -0.7491f,
+      0.3052f, -1.1534f, -0.0478f, 0.0021f, -0.0665f,
+      -0.8118f, 0.1310f, 0.2171f, 0.5485f, -0.1610f,
+      -1.5784f, -0.8660f, 0.7289f, -0.4678f, 0.1937f,
+      1.1287f, -0.5772f, -0.0259f, -0.2212f, 0.2479f,
+      0.6336f, 0.6407f, -0.6543f, 0.3838f, 0.9039f,
+      0.4724f, 0.7117f, 1.0165f, 1.0270f, 1.1908f,
+      1.3750f, -0.0850f, 0.5517f, -1.3842f, 0.3703f,
+      -0.8806f, 0.9336f, 0.8362f, 0.8105f, -1.1566f,
+      -0.6813f, 0.0294f, -0.1122f, 0.5620f, -0.2884f,
+      -2.0803f, 0.4684f, 0.6009f, -1.4160f};
+
+  RunTests(input_data,
+           position_ids,
+           cos_cache,
+           sin_cache,
+           output_data,
+           batch_size,
+           sequence_length,
+           head_size,
+           num_heads,
+           max_sequence_length,
+           interleaved);
+}
+
+// Interleaved = false, pos ids shape = (1); 2 batches x 8 tokens x (4 heads * 6) hidden, 16 cached positions
+TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_LargeData_LlamaMSFT) {
+  int batch_size = 2;
+  int sequence_length = 8;
+  int num_heads = 4;
+  int head_size = 6;
+  int max_sequence_length = 16;
+  int64_t interleaved = 0;  // false
+
+  std::vector<float> input_data = {
+      -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f,
+      1.1676f, -1.0190f, 0.3157f, -1.6036f, 1.8493f,
+      0.0447f, 1.5853f, 0.1036f, -0.3514f, 0.2421f,
+      0.6463f, 0.8730f, -0.9276f, 1.0311f, -1.9557f,
+      -0.1482f, 1.7376f, 2.2039f, -0.6589f, -1.0574f,
+      -0.1188f, -0.9078f, 0.3452f, -0.5713f, -0.2351f,
+      -0.5912f, 1.1312f, 0.7562f, -1.2023f, -0.5833f,
+      -0.4407f, 0.1766f, 1.0224f, -0.4826f, -0.5421f,
+      -0.5342f, -0.6413f, 1.3314f, -0.4498f, 0.5493f,
+      0.0539f, 0.2601f, 0.8570f, 1.0076f, -0.7529f,
+      -0.2250f, -0.4327f, -1.5071f, -0.4586f, -1.9791f,
+      0.7787f, -0.7749f, -0.1398f, 1.1414f, -0.6354f,
+      0.0352f, -0.4765f, -0.0409f, 1.1993f, 0.5374f,
+      -0.1930f, 2.5211f, -0.0452f, -0.3105f, -0.9407f,
+      -0.0034f, 1.5199f, -0.8480f, 0.5266f, 0.0299f,
+      -0.0498f, 1.0651f, 0.8860f, -1.4702f, -0.2134f,
+      -0.8707f, 1.6159f, -0.2356f, 0.9444f, 0.5937f,
+      0.7203f, 0.5061f, 1.5192f, -0.4897f, 0.9231f,
+      0.2654f, -0.1441f, 0.5407f, -1.5476f, 0.6455f,
+      -1.1382f, 0.4640f, -0.4986f, 0.1289f, 2.7631f,
+      0.1405f, 1.1191f, 2.1134f, -0.9754f, 0.1757f,
+      -0.1319f, -0.2735f, 0.3355f, -0.6008f, -1.1164f,
+      0.2577f, -0.7226f, -0.9244f, 1.8737f, 0.6052f,
+      1.1904f, 1.2195f, -0.0470f, -1.0914f, 1.0223f,
+      0.3152f, 1.7528f, -0.7650f, 1.8299f, -0.2784f,
+      -0.2719f, 0.1885f, 2.1432f, 0.8527f, 0.0965f,
+      -0.0625f, 0.8269f, 1.0122f, -1.4482f, -0.0644f,
+      0.3215f, 0.5908f, -1.4197f, 0.2113f, 0.0306f,
+      0.3604f, 0.3166f, -0.8975f, -0.6393f, -1.2944f,
+      -0.0243f, -0.2354f, -0.7087f, 1.1566f, 0.4296f,
+      0.5599f, -0.7776f, 0.3339f, 0.1759f, 2.1108f,
+      1.0702f, 0.8279f, -0.2969f, 0.7120f, -0.2068f,
+      -0.1548f, 0.1553f, 0.6207f, -0.1690f, -0.5816f,
+      1.2632f, 0.0695f, 1.1862f, -1.1874f, -0.7468f,
+      -0.9320f, -0.8579f, -0.9647f, -0.0991f, 0.0195f,
+      1.1213f, -1.4873f, -0.2043f, -1.0466f, -1.5772f,
+      -0.0489f, 0.3430f, 0.1264f, 0.1519f, -1.3639f,
+      -1.6593f, 1.8127f, -1.4459f, -0.2158f, -0.9792f,
+      -1.4392f, 0.6508f, 0.8964f, 0.5717f, -0.2390f,
+      0.6983f, -1.3416f, 0.2715f, -0.2852f, 0.6051f,
+      0.2167f, -0.2181f, -1.6306f, 1.4788f, 0.2754f,
+      -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f,
+      1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f,
+      0.1527f, -0.5996f, -1.0962f, 1.6327f, 1.3951f,
+      0.8784f, 0.3389f, 1.2907f, 0.3124f, 0.7299f,
+      1.4220f, 0.3375f, 0.0438f, 1.8698f, -0.2635f,
+      -2.0799f, -0.6313f, 0.4090f, -1.1458f, 0.0784f,
+      -1.8848f, -1.6165f, 0.6179f, 0.9905f, -0.0729f,
+      0.5054f, -0.6681f, -1.4382f, 1.7547f, -0.9605f,
+      -0.4558f, -1.6105f, 0.2979f, 1.1537f, -1.5604f,
+      1.2779f, -1.2514f, 0.6056f, 0.5763f, -3.3558f,
+      0.2836f, 0.6909f, -0.7631f, 2.4451f, -0.3500f,
+      1.3289f, -0.6494f, 0.3478f, 1.0038f, -0.2937f,
+      0.9238f, -1.2185f, 0.4138f, 0.5033f, 0.9174f,
+      1.8131f, 1.4436f, -0.4207f, 0.0220f, -0.6807f,
+      -1.3306f, 1.5646f, 0.3338f, 0.7105f, 0.4683f,
+      -0.6179f, 0.0818f, -0.0488f, -0.9810f, -1.3632f,
+      0.0929f, -1.7926f, -0.2921f, -0.4792f, 0.6756f,
+      -0.3413f, -0.2242f, -0.2111f, 0.6282f, 0.1667f,
+      -1.4055f, 1.5895f, 1.0838f, -0.9077f, -0.8060f,
+      0.7967f, -2.9351f, 2.4179f, -0.4026f, 0.6451f,
+      1.6845f, -0.0901f, 0.6106f, 2.3603f, 1.3908f,
+      -0.7917f, -0.6734f, -0.1213f, -1.1116f, -0.7401f,
+      -0.7879f, 0.0606f, -2.3337f, -1.2603f, -1.7245f,
+      -0.3533f, -0.9421f, -0.1776f, 0.3992f, -1.7142f,
+      -0.5319f, -0.8848f, 0.6513f, 1.0002f, -1.4699f,
+      -1.4254f, 0.7013f, 0.2414f, 0.2551f, -0.7457f,
+      0.3133f, -1.0941f, -0.3682f, -0.0163f, -0.0645f,
+      -0.8101f, 0.1415f, 0.0551f, 0.5873f, -0.5887f,
+      -1.4733f, -0.8565f, 0.7400f, -0.5033f, 0.0553f,
+      0.9265f, -0.8652f, -0.0288f, -0.2209f, 0.0610f,
+      0.6776f, 0.4361f, -0.8052f, 0.3955f, 0.8988f,
+      0.8238f, 0.2262f, 1.2912f, 0.6488f, 1.2114f,
+      1.3569f, 0.2983f, 0.4718f, -1.1936f, 0.7928f,
+      -0.8665f, 0.9468f, 1.1629f, 0.0616f, -1.3136f,
+      -0.2764f, 0.0277f, -0.1126f, 0.2342f, -0.5866f,
+      -1.8219f, 1.1079f, 0.5795f, -1.4249f};
+
+  std::vector<int64_t> position_ids = {0};
+
+  std::vector<float> cos_cache = {
+      1.0000f, 1.0000f, 1.0000f, 0.5403f, 0.9989f, 1.0000f, -0.4161f, 0.9957f,
+      1.0000f, -0.9900f, 0.9903f, 1.0000f, -0.6536f, 0.9828f, 1.0000f, 0.2837f,
+      0.9732f, 0.9999f, 0.9602f, 0.9615f, 0.9999f, 0.7539f, 0.9477f, 0.9999f,
+      -0.1455f, 0.9318f, 0.9999f, -0.9111f, 0.9140f, 0.9998f, -0.8391f, 0.8942f,
+      0.9998f, 0.0044f, 0.8725f, 0.9997f, 0.8439f, 0.8488f, 0.9997f, 0.9074f,
+      0.8234f, 0.9996f, 0.1367f, 0.7962f, 0.9995f, -0.7597f, 0.7673f, 0.9995f};
+
+  std::vector<float> sin_cache = {
+      0.0000f, 0.0000f, 0.0000f, 0.8415f, 0.0464f, 0.0022f, 0.9093f, 0.0927f,
+      0.0043f, 0.1411f, 0.1388f, 0.0065f, -0.7568f, 0.1846f, 0.0086f, -0.9589f,
+      0.2300f, 0.0108f, -0.2794f, 0.2749f, 0.0129f, 0.6570f, 0.3192f, 0.0151f,
+      0.9894f, 0.3629f, 0.0172f, 0.4121f, 0.4057f, 0.0194f, -0.5440f, 0.4477f,
+      0.0215f, -1.0000f, 0.4887f, 0.0237f, -0.5366f, 0.5286f, 0.0259f, 0.4202f,
+      0.5675f, 0.0280f, 0.9906f, 0.6050f, 0.0302f, 0.6503f, 0.6413f, 0.0323f};
+
+  std::vector<float> output_data = {
+      -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f,
+      1.1676f, -1.0190f, 0.3157f, -1.6036f, 1.8493f,
+      0.0447f, 1.5853f, 0.1036f, -0.3514f, 0.2421f,
+      0.6463f, 0.8730f, -0.9276f, 1.0311f, -1.9557f,
+      -0.1482f, 1.7376f, 2.2039f, -0.6589f, -0.8618f,
+      -0.0922f, -0.9073f, -0.7032f, -0.5762f, -0.2371f,
+      0.6923f, 1.1571f, 0.7572f, -1.1471f, -0.5302f,
+      -0.4391f, 0.5516f, 1.0461f, -0.4812f, -0.1443f,
+      -0.4862f, -0.6423f, 0.6740f, -0.4614f, 0.5475f,
+      1.1495f, 0.2389f, 0.8582f, -0.0259f, -0.6099f,
+      -0.2230f, 1.0963f, -1.5704f, -0.4595f, 0.9507f,
+      0.6696f, -0.7721f, -1.7415f, 1.2087f, -0.6387f,
+      -1.1052f, -0.5243f, -0.0400f, -0.4671f, 0.4909f,
+      -0.1931f, -0.1937f, -0.0447f, -0.3171f, 2.6839f,
+      -0.0076f, 1.5185f, 0.8465f, 0.3737f, 0.0242f,
+      -0.0703f, 1.1279f, 0.8862f, 1.2275f, -0.1786f,
+      -0.8767f, -1.8072f, -0.2630f, 0.9387f, -0.8021f,
+      0.7813f, 0.5001f, -1.4202f, -0.3850f, 0.9263f,
+      -0.0443f, -0.2323f, 0.5480f, 1.5696f, 0.6193f,
+      -1.1346f, 1.7878f, -0.5160f, 0.1192f, -2.1572f,
+      0.0460f, 1.1202f, -1.4812f, -0.9082f, 0.1728f,
+      -1.5132f, -0.4489f, 0.3370f, -0.1541f, -0.9266f,
+      0.2416f, 0.9270f, -1.1146f, 1.8758f, -0.4312f,
+      1.3714f, 1.2106f, -0.4272f, -0.8529f, 1.0328f,
+      1.8441f, 1.7698f, -0.7620f, 0.2168f, 0.1322f,
+      -0.2802f, 0.1460f, 2.1002f, 0.8437f, -0.1534f,
+      0.4321f, 0.8360f, 0.5955f, -1.5452f, -0.0491f,
+      -0.8794f, 0.2418f, -1.4203f, 0.3635f, 0.2362f,
+      0.3672f, -0.1128f, -0.8664f, -0.6354f, -1.4409f,
+      -0.3413f, -0.2409f, -0.3188f, 1.1054f, 0.4265f,
+      0.5867f, -1.3279f, 0.3201f, 0.0125f, 1.8157f,
+      1.0745f, 0.7372f, -0.2429f, 0.7100f, -0.4299f,
+      -0.2304f, 0.1645f, 0.9489f, -0.1816f, -0.5968f,
+      1.0394f, 0.0204f, 1.1786f, -0.3315f, -0.3997f,
+      -0.9304f, -1.4268f, -1.1526f, -0.1132f, 0.1490f,
+      1.3967f, -1.4634f, -0.1412f, -0.6339f, -1.5995f,
+      -0.1366f, 0.7604f, 0.1514f, 0.0824f, -1.1830f,
+      -1.6572f, 2.0099f, -0.9108f, -0.2256f, 0.4527f,
+      -1.8254f, 0.6475f, 0.8964f, 0.5717f, -0.2390f,
+      0.6983f, -1.3416f, 0.2715f, -0.2852f, 0.6051f,
+      0.2167f, -0.2181f, -1.6306f, 1.4788f, 0.2754f,
+      -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f,
+      1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f,
+      0.1527f, -1.4979f, -1.1358f, 1.6320f, 0.2493f,
+      0.8266f, 0.3424f, -0.4992f, 0.2964f, 0.7298f,
+      1.8544f, 0.3516f, 0.0454f, 1.5415f, -0.2822f,
+      -2.0774f, 1.2323f, 0.3963f, -1.1503f, -0.4775f,
+      -1.9287f, -1.6164f, 0.3998f, 0.9020f, -0.0764f,
+      -1.8059f, -0.5762f, -1.4362f, -0.2706f, -1.0183f,
+      -0.4620f, 2.0891f, 0.1782f, 1.1591f, -0.8151f,
+      1.3000f, -1.2464f, -0.5099f, 0.5098f, -3.3525f,
+      0.4326f, 0.7414f, -0.7775f, -0.4271f, -0.3807f,
+      1.3245f, 2.4936f, 0.3139f, 1.0095f, 0.2323f,
+      0.8450f, -1.2244f, -0.4511f, 0.6266f, 0.9095f,
+      -1.7981f, 1.5241f, -0.4121f, 0.2341f, -0.4737f,
+      -1.3333f, -1.6150f, 0.4164f, 0.7100f, -0.2429f,
+      -0.5656f, 0.0863f, 0.0352f, -0.7227f, -1.3613f,
+      -0.0988f, -1.9114f, -0.3009f, 0.1435f, 0.7029f,
+      -0.3467f, 0.5092f, -0.0828f, 0.6253f, 0.7113f,
+      -1.2138f, 1.5964f, -0.8346f, -1.1515f, -0.7923f,
+      -0.8254f, -3.0038f, 2.4033f, -0.3398f, 0.0922f,
+      1.7053f, 1.1114f, 0.7462f, 2.3660f, -0.8409f,
+      -0.6654f, -0.6530f, -0.7899f, -1.0957f, -0.7149f,
+      -0.1072f, -0.1967f, -2.3416f, -1.2609f, -1.6375f,
+      -0.3576f, 0.9413f, -0.5694f, 0.3954f, 0.1383f,
+      -0.7477f, -0.8689f, 1.8286f, 0.8510f, -1.4793f,
+      -0.1597f, 0.8541f, 0.2380f, 1.4392f, -0.5644f,
+      0.3158f, -1.0686f, -0.1313f, -0.0181f, 0.2438f,
+      -0.8801f, 0.1413f, -0.3587f, 0.8002f, -0.5982f,
+      -1.4301f, -0.6620f, 0.7324f, -0.7250f, 0.0610f,
+      0.9293f, -0.6902f, -0.0125f, -0.2089f, -0.1664f,
+      0.5428f, 0.4245f, -0.7901f, 0.5665f, 0.9044f,
+      0.1948f, -0.1723f, 1.2705f, 1.0303f, 1.2202f,
+      1.3762f, -0.2959f, 0.7237f, -1.2077f, 0.7937f,
+      -0.6705f, 0.9287f, 1.0583f, 0.0496f, -1.3118f,
+      0.5556f, 0.0459f, -0.1324f, -0.5513f, -0.7409f,
+      -1.8002f, 0.9892f, 0.3619f, -1.4522f};
+
+  RunTests(input_data,
+           position_ids,
+           cos_cache,
+           sin_cache,
+           output_data,
+           batch_size,
+           sequence_length,
+           head_size,
+           num_heads,
+           max_sequence_length,
+           interleaved);
+}
+
+// Interleaved = false, pos ids shape = (batch_size, sequence_length); 1 batch x 2 tokens x (3 heads * 6) hidden
+TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_SmallData_LlamaMSFT) {
+  int batch_size = 1;
+  int sequence_length = 2;
+  int num_heads = 3;
+  int head_size = 6;
+  int max_sequence_length = 4;
+  int64_t interleaved = 0;  // false
+
+  std::vector<float> input_data = {
+      -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.7529f,
+      -0.2250f, -0.4327f, -1.5071f, -0.4586f, -0.8663f, -0.2656f, 0.1665f, 0.7911f,
+      -0.9320f, -0.8579f, -1.0574f, -0.1188f, -0.9078f, 0.3452f, -0.5713f, -0.2351f,
+      -0.8480f, 0.5266f, -1.2944f, -0.0243f, -0.2354f, -0.7087f, -0.9647f, -0.0991f,
+      -0.2994f, -0.0650f, -1.5720f, -1.3211f};
+
+  std::vector<int64_t> position_ids = {0, 1};
+
+  std::vector<float> cos_cache = {
+      1.0000f, 1.0000f, 1.0000f, 0.5403f, 0.9989f, 1.0000f, -0.4161f, 0.9957f,
+      1.0000f, -0.9900f, 0.9903f, 1.0000f};
+
+  std::vector<float> sin_cache = {
+      0.0000f, 0.0000f, 0.0000f, 0.8415f, 0.0464f, 0.0022f, 0.9093f, 0.0927f, 0.0043f,
+      0.1411f, 0.1388f, 0.0065f};
+
+  std::vector<float> output_data = {
+      -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.7529f,
+      -0.2250f, -0.4327f, -1.5071f, -0.4586f, -0.8663f, -0.2656f, 0.1665f, 0.7911f,
+      -0.9320f, -0.8579f, -0.8618f, -0.0922f, -0.9073f, -0.7032f, -0.5762f, -0.2371f,
+      -0.4377f, 0.5370f, -1.2929f, -0.7267f, -0.2107f, -0.7115f, -0.4666f, -0.0261f,
+      -0.2965f, -0.8469f, -1.5749f, -1.3217f};
+
+  RunTests(input_data,
+           position_ids,
+           cos_cache,
+           sin_cache,
+           output_data,
+           batch_size,
+           sequence_length,
+           head_size,
+           num_heads,
+           max_sequence_length,
+           interleaved);
+}
+
+} // namespace test
+} // namespace onnxruntime
diff --git a/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py b/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py
new file mode 100644
index 0000000000000..b17ae5f69aff5
--- /dev/null
+++ b/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py
@@ -0,0 +1,450 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+
+# Notes
+# 1) The test cases in this file are for the following LLaMA-2 scenarios:
+# - Microsoft rotary embeddings with interleaved = True
+# - Prompt generation
+# - Token generation
+# - Hugging Face rotary embeddings (equal to Microsoft rotary embeddings with interleaved = False)
+# - Prompt generation
+# - Token generation
+#
+# 2) Shapes of position ids in ORT and `interleaved` for LLaMA-2 scenarios:
+# - Microsoft model: When shape of position ids == (1), interleaved = True
+# - Hugging Face model: When shape of position ids == (batch_size, sequence_length), interleaved = False
+
+
+import unittest
+from copy import deepcopy
+
+import numpy as np
+import torch
+import torch.nn as nn
+from onnx import TensorProto, helper
+
+import onnxruntime as ort
+
+
+class SampleInputConfig:
+    """Shape parameters shared by the rotary embedding parity tests."""
+
+    def __init__(
+        self,
+        batch_size=2,
+        sequence_length=8,
+        num_heads=4,
+        head_size=6,
+        max_sequence_length=16,
+    ):
+        self.batch_size, self.sequence_length = batch_size, sequence_length
+        self.num_heads, self.head_size = num_heads, head_size
+        # Width of the packed last dimension: all heads laid side by side.
+        self.hidden_size = num_heads * head_size
+        self.max_sequence_length = max_sequence_length
+
+
+# LLaMA Hugging Face model
+class LlamaHFRotaryEmbedding(nn.Module):
+    """Reference rotary embedding in the Hugging Face LLaMA style (rotate-half, non-interleaved).
+
+    The cos/sin tables are built once in the constructor and stored as non-persistent
+    buffers of shape [1, 1, seq_len, dim].
+    """
+
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device="cpu"):
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        # Build here to make `torch.jit.trace` work.
+        self._set_cos_sin_cache(
+            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+        )
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        """(Re)build the cached cos/sin tables for positions 0..seq_len-1."""
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+
+    def get_cos_sin_cache(self, seq_len=None, device=torch.device("cpu"), dtype=torch.float32):  # noqa: B008
+        """Return the (cos, sin) tables truncated to `seq_len` positions, each [1, 1, seq_len, dim].
+
+        `seq_len=None` returns everything currently cached. A `seq_len` larger than the
+        cache triggers a rebuild first.
+        """
+        if seq_len is None:
+            # The declared default used to hit `None > int` and raise TypeError.
+            seq_len = self.max_seq_len_cached
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=device, dtype=dtype)
+
+        return (
+            self.cos_cached[:, :, :seq_len, ...].to(dtype=dtype),
+            self.sin_cached[:, :, :seq_len, ...].to(dtype=dtype),
+        )
+
+    def rotate_half(self, x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rope_bnsh(self, x, cos, sin, position_ids):
+        """Apply RoPE to `x` laid out as [bs, num_heads, seq_len, head_size]."""
+        # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
+        sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
+        cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        x_embed = (x * cos) + (self.rotate_half(x) * sin)
+        return x_embed
+
+    def apply_rope_bsnh(self, x, cos, sin, position_ids):
+        """Apply RoPE to `x` laid out as [bs, seq_len, num_heads, head_size]."""
+        # Two dimensions of cos and sin are always 1. Collapse every leading axis explicitly
+        # instead of a bare `squeeze()`, which would also drop the sequence axis when seq_len == 1.
+        cos = cos.reshape(-1, cos.shape[-1])  # [seq_len, dim]
+        sin = sin.reshape(-1, sin.shape[-1])  # [seq_len, dim]
+        cos = cos[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
+        sin = sin[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
+        x_embed = (x * cos) + (self.rotate_half(x) * sin)
+        return x_embed
+
+    def forward(self, x, cos, sin, pos_ids, x_format="bnsh"):
+        """Dispatch on `x_format`: "bnsh" uses the BxNxSxH path, anything else the BxSxNxH path."""
+        if x_format == "bnsh":
+            return self.apply_rope_bnsh(x, cos, sin, pos_ids)
+        return self.apply_rope_bsnh(x, cos, sin, pos_ids)
+
+
+# LLaMA Microsoft model
+class LlamaMSRotaryEmbedding(nn.Module):
+    """Reference rotary embedding in the Microsoft LLaMA style, with optional interleaving."""
+
+    def __init__(self, hidden_size, num_heads, max_sequence_length):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.max_sequence_length = max_sequence_length
+
+    def get_cos_sin_cache(self, theta=10000.0, head_scale=1.0, device="cpu", dtype=torch.float32):
+        """Return (cos, sin) tables, each reshaped to [1, max_sequence_length, 1, -1].
+
+        `dtype` is used for the frequency computation; the returned tensors are cast to
+        torch's default dtype.
+        """
+        hidden_size = self.hidden_size
+        n_heads = self.num_heads
+        max_seq_len = self.max_sequence_length
+
+        # Precalculate rotary matrices for the sequence
+        # According to "Attention Is All You Need", theta_i = 10000 ^ (2 * (i - 1)/dim), i in [1, 2, ..., dim//2]
+        head_dim = head_scale * hidden_size / n_heads
+
+        pos = torch.arange(0, 2 * (head_dim // 2), step=2, device=device, dtype=dtype)
+        freqs = 1.0 / (theta ** (pos / head_dim))
+
+        idx = torch.arange(max_seq_len, device=freqs.device)
+        freqs = torch.outer(idx, freqs)
+
+        cos = torch.reshape(torch.cos(freqs), [1, max_seq_len, 1, -1])
+        sin = torch.reshape(torch.sin(freqs), [1, max_seq_len, 1, -1])
+        dtype = torch.get_default_dtype()
+
+        return cos.to(dtype), sin.to(dtype)
+
+    def rotate_tensor(
+        self,
+        x: torch.Tensor,  # BxSxNxH
+        cos: torch.Tensor,  # 1xSx1x(H/2)
+        sin: torch.Tensor,  # 1xSx1x(H/2)
+        pos: int,
+        interleaved: bool,
+    ):
+        """Rotate the first `2 * cos.shape[3]` channels of `x`, starting at sequence offset `pos`.
+
+        Channels beyond the rotated width are passed through unchanged. `x` is never
+        modified; a new tensor is returned.
+        """
+        # Dimension of x is [batch_size, seq_len, n_heads, head_dim]
+        rot_dim = 2 * cos.shape[3]
+
+        # Dolly requires partial rotation
+        x_rot = x[:, :, :, :rot_dim]
+
+        if interleaved:
+            x1 = x_rot[:, :, :, 0::2]
+            x2 = x_rot[:, :, :, 1::2]
+        else:
+            half = x_rot.shape[-1] // 2
+            x1 = x[:, :, :, 0:half]
+            x2 = x[:, :, :, half : 2 * half]
+
+        seq_len = x.shape[1]
+        cos_x = cos[:, pos : pos + seq_len, :, :]
+        sin_x = sin[:, pos : pos + seq_len, :, :]
+
+        # cos_x: (1, S, 1, H/2)
+        # sin_x: (1, S, 1, H/2)
+        # x1: (B, S, N, H/2)
+        # x2: (B, S, N, H/2)
+        real = cos_x * x1 - sin_x * x2
+        imag = sin_x * x1 + cos_x * x2
+
+        if interleaved:
+            # x_rot is a view of x, so write into a copy; otherwise the strided
+            # assignments below would silently overwrite the caller's input.
+            x_rot = x_rot.clone()
+            x_rot[:, :, :, 0::2] = real
+            x_rot[:, :, :, 1::2] = imag
+        else:
+            x_rot = torch.cat((real, imag), dim=-1)
+
+        return torch.cat((x_rot, x[:, :, :, rot_dim:]), dim=-1)
+
+    def forward(self, x, cos, sin, pos, interleaved):
+        """Thin wrapper around `rotate_tensor`."""
+        return self.rotate_tensor(x, cos, sin, pos, interleaved)
+
+
+class TestLlamaRotaryEmbedding(unittest.TestCase):
+ def setUp(self):
+     """Create the default shape config and both reference modules, then seed numpy and torch."""
+     cfg = SampleInputConfig()
+     self.config = cfg
+     self.llama_hf = LlamaHFRotaryEmbedding(cfg.head_size, cfg.max_sequence_length)
+     self.llama_ms = LlamaMSRotaryEmbedding(cfg.hidden_size, cfg.num_heads, cfg.max_sequence_length)
+
+     # Fixed seed so the random inputs are reproducible from run to run.
+     seed = 2
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.set_printoptions(sci_mode=False)
+
+ def create_onnx_graph(self, x_shape, pos_shape, cos, sin, interleaved):
+     """Serialize a single-node com.microsoft RotaryEmbedding model.
+
+     `input` and `position_ids` are graph inputs with the given shapes; the cos/sin
+     tables are embedded as float initializers with their size-1 axes squeezed away.
+     Returns the model as bytes.
+     """
+
+     def cache_initializer(name, cache):
+         return helper.make_tensor(
+             name=name,
+             data_type=TensorProto.FLOAT,
+             dims=list(torch.squeeze(cache).shape),
+             vals=cache.flatten().tolist(),
+         )
+
+     graph_inputs = [
+         helper.make_tensor_value_info(name="input", elem_type=TensorProto.FLOAT, shape=list(x_shape)),
+         helper.make_tensor_value_info(name="position_ids", elem_type=TensorProto.INT64, shape=list(pos_shape)),
+     ]
+     graph_outputs = [
+         helper.make_tensor_value_info(name="output", elem_type=TensorProto.FLOAT, shape=list(x_shape)),
+     ]
+
+     rope_node = helper.make_node(
+         op_type="RotaryEmbedding",
+         inputs=["input", "position_ids", "cos_cache", "sin_cache"],
+         outputs=["output"],
+         interleaved=interleaved,
+         name="RotaryEmbedding_0",
+         domain="com.microsoft",
+     )
+
+     graph = helper.make_graph(
+         nodes=[rope_node],
+         name="RotaryEmbedding_Graph",
+         inputs=graph_inputs,
+         outputs=graph_outputs,
+         initializer=[cache_initializer("cos_cache", cos), cache_initializer("sin_cache", sin)],
+     )
+     ms_opset = helper.make_opsetid(domain="com.microsoft", version=1)
+     return helper.make_model(graph, opset_imports=[ms_opset]).SerializeToString()
+
+ def get_eps(self):
+     """Return the CPU and CUDA execution providers that this onnxruntime build offers, in that order."""
+     candidates = ["CPUExecutionProvider", "CUDAExecutionProvider"]
+     return [ep for ep in candidates if ep in ort.get_available_providers()]
+
+ def run_ort_ep_tests(self, onnx_graph, inputs_ort, expected_output_bsnh):
+     """Run the serialized model on every available execution provider and compare to the reference.
+
+     Args:
+         onnx_graph: serialized model bytes as produced by `create_onnx_graph`.
+         inputs_ort: feed dict; `inputs_ort["input"].shape[1]` supplies the sequence length.
+         expected_output_bsnh: reference output laid out as BxSxNxH.
+     """
+     bsnh_shape = (
+         self.config.batch_size,
+         inputs_ort["input"].shape[1],
+         self.config.num_heads,
+         self.config.head_size,
+     )
+     for ep in self.get_eps():
+         # One sub-test per provider: a failure names the provider and the
+         # remaining providers are still exercised.
+         with self.subTest(execution_provider=ep):
+             sess = ort.InferenceSession(onnx_graph, providers=[ep])
+             output_ort = sess.run(None, inputs_ort)[0].reshape(bsnh_shape)
+
+             # Compare outputs as BxSxNxH
+             self.assertTrue(
+                 np.allclose(expected_output_bsnh, output_ort),
+                 msg=f"RotaryEmbedding output does not match the reference on {ep}",
+             )
+
+ # apply_rope(x_bnsh) == apply_rope(x_bsnh).transpose(1,2)
+ def test_hf_bnsh_and_hf_bsnh(self):
+     """The HF reference must give the same result for BxNxSxH and BxSxNxH inputs."""
+     cfg = self.config
+     x_bnsh = torch.randn(cfg.batch_size, cfg.num_heads, cfg.sequence_length, cfg.head_size)
+     cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(cfg.sequence_length)
+     # Every batch row uses positions 0..sequence_length-1.
+     pos_hf = torch.arange(cfg.sequence_length).repeat(cfg.batch_size, 1)
+
+     rope_bnsh = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_hf)  # output is BxNxSxH
+     rope_bsnh = self.llama_hf(
+         x_bnsh.transpose(1, 2),
+         cos_hf.transpose(1, 2),
+         sin_hf.transpose(1, 2),
+         pos_hf,
+         "bsnh",
+     )  # output is BxSxNxH
+
+     self.assertTrue(torch.allclose(rope_bnsh, rope_bsnh.transpose(1, 2)))
+
+ # HF rotary == MSFT rotary non-interleaved
+ def test_hf_rotary_and_msft_rotary_noninterleaved(self):
+     """The HF reference and the non-interleaved MSFT reference must agree on caches and outputs."""
+     cfg = self.config
+     half_head = cfg.head_size // 2
+
+     x_bnsh = torch.randn(cfg.batch_size, cfg.num_heads, cfg.sequence_length, cfg.head_size)
+     cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(cfg.sequence_length)
+     pos_hf = torch.arange(cfg.sequence_length).repeat(cfg.batch_size, 1)
+     output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_hf)  # output is BxNxSxH
+
+     # Hand the MSFT reference its own copy of the BxSxNxH input.
+     x_bsd = deepcopy(x_bnsh.transpose(1, 2))
+     cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache()
+     output_ms = self.llama_ms(x_bsd, cos_ms, sin_ms, 0, interleaved=False)  # output is BxSxNxH
+     output_ms = output_ms.detach().cpu().numpy()
+
+     # Compare caches as Mx(H/2)
+     for hf_cache, ms_cache in ((self.llama_hf.cos_cached, cos_ms), (self.llama_hf.sin_cached, sin_ms)):
+         self.assertTrue(torch.allclose(hf_cache.squeeze()[:, :half_head], ms_cache.squeeze()))
+
+     # Compare outputs as BxSxNxH
+     output_hf_bsnh = output_hf.transpose(1, 2).detach().cpu().numpy()
+     self.assertTrue(np.allclose(output_hf_bsnh, output_ms))
+
+ # Prompt step, interleaved = true, pos ids shape = (1)
+ def test_msft_prompt_rotary_interleaved(self):
+     """ORT must match the interleaved MSFT reference for a full prompt starting at position 0."""
+     cfg = self.config
+     start_pos = 0
+
+     # Calculated this way to match the data in rotary_embedding_op_test.cc
+     x_bsnh = torch.randn(cfg.batch_size, cfg.num_heads, cfg.sequence_length, cfg.head_size).transpose(1, 2)
+     cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache()
+     # Pass a private copy so the reference forward pass cannot alter x_bsnh.
+     reference = self.llama_ms(deepcopy(x_bsnh), cos_ms, sin_ms, start_pos, interleaved=True)
+     output_ms = reference.detach().cpu().numpy()
+
+     x_bsd = deepcopy(x_bsnh).reshape(cfg.batch_size, cfg.sequence_length, cfg.hidden_size)
+     pos_ms = torch.tensor([start_pos])
+     onnx_graph = self.create_onnx_graph(x_bsd.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=True)
+     inputs_ort = {
+         "input": x_bsd.detach().cpu().numpy(),
+         "position_ids": pos_ms.detach().cpu().numpy(),
+     }
+
+     # Compare inputs/outputs as BxSxNxH
+     self.assertTrue(np.allclose(x_bsnh.flatten(), x_bsd.flatten()))
+     self.run_ort_ep_tests(onnx_graph, inputs_ort, output_ms)
+
+ # Token generation step, interleaved = true, pos ids shape = (1)
+ def test_msft_token_rotary_interleaved(self):
+     """ORT must match the interleaved MSFT reference when the sequence starts at position 2."""
+     cfg = self.config
+     start_pos = 2
+
+     # Calculated this way to match the data in rotary_embedding_op_test.cc
+     x_bsnh = torch.randn(cfg.batch_size, cfg.num_heads, cfg.sequence_length, cfg.head_size).transpose(1, 2)
+     cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache()
+     # Pass a private copy so the reference forward pass cannot alter x_bsnh.
+     reference = self.llama_ms(deepcopy(x_bsnh), cos_ms, sin_ms, start_pos, interleaved=True)
+     output_ms = reference.detach().cpu().numpy()
+
+     x_bsd = deepcopy(x_bsnh).reshape(cfg.batch_size, cfg.sequence_length, cfg.hidden_size)
+     pos_ms = torch.tensor([start_pos])
+     onnx_graph = self.create_onnx_graph(x_bsd.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=True)
+     inputs_ort = {
+         "input": x_bsd.detach().cpu().numpy(),
+         "position_ids": pos_ms.detach().cpu().numpy(),
+     }
+
+     # Compare inputs/outputs as BxSxNxH
+     self.assertTrue(np.allclose(x_bsnh.flatten(), x_bsd.flatten()))
+     self.run_ort_ep_tests(onnx_graph, inputs_ort, output_ms)
+
+ # Prompt step, interleaved = false, pos ids shape = (batch_size, sequence_length)
+ def test_hf_prompt_rotary_batched_pos_ids(self):
+     """ORT (non-interleaved, per-token position ids) must match the HF reference for a prompt."""
+     cfg = self.config
+     x_bnsh = torch.randn(cfg.batch_size, cfg.num_heads, cfg.sequence_length, cfg.head_size)
+     cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(cfg.sequence_length)
+     pos_ids = torch.arange(cfg.sequence_length).repeat(cfg.batch_size, 1)
+     output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_ids)  # output is BxNxSxH
+
+     # ORT consumes the input as BxSxD.
+     x_bsd = x_bnsh.transpose(1, 2).reshape(cfg.batch_size, cfg.sequence_length, cfg.hidden_size)
+     cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache()
+     onnx_graph = self.create_onnx_graph(x_bsd.shape, pos_ids.shape, cos_ms, sin_ms, interleaved=False)
+     inputs_ort = {
+         "input": x_bsd.detach().cpu().numpy(),
+         "position_ids": pos_ids.detach().cpu().numpy(),
+     }
+
+     # Compare outputs as BxSxNxH
+     expected_bsnh = output_hf.transpose(1, 2).detach().cpu().numpy()
+     self.run_ort_ep_tests(onnx_graph, inputs_ort, expected_bsnh)
+
+ # Token generation step, interleaved = false, pos ids shape = (batch_size, sequence_length)
+ def test_hf_token_rotary_batched_pos_ids(self):
+     """ORT (non-interleaved, per-token position ids) must match the HF reference for one token at position 2."""
+     cfg = self.config
+     x_bnsh = torch.randn(cfg.batch_size, cfg.num_heads, 1, cfg.head_size)
+     cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(cfg.sequence_length)
+     # One position id (2) per batch row, shape (batch_size, 1).
+     pos_ids = torch.full((cfg.batch_size, 1), 2, dtype=torch.int64)
+     output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_ids)  # output is BxNxSxH
+
+     x_bsd = x_bnsh.transpose(1, 2).reshape(cfg.batch_size, 1, cfg.hidden_size)
+     cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache()
+     onnx_graph = self.create_onnx_graph(x_bsd.shape, pos_ids.shape, cos_ms, sin_ms, interleaved=False)
+     inputs_ort = {
+         "input": x_bsd.detach().cpu().numpy(),
+         "position_ids": pos_ids.detach().cpu().numpy(),
+     }
+
+     # Compare outputs as BxSxNxH
+     expected_bsnh = output_hf.transpose(1, 2).detach().cpu().numpy()
+     self.run_ort_ep_tests(onnx_graph, inputs_ort, expected_bsnh)
+
+ # Bonus test: Prompt step, interleaved = false, pos ids shape = (1)
+ def test_hf_prompt_rotary_one_pos_id(self):
+     """ORT (non-interleaved, single start position 0) must match the HF reference for a prompt."""
+     cfg = self.config
+     x_bnsh = torch.randn(cfg.batch_size, cfg.num_heads, cfg.sequence_length, cfg.head_size)
+     cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(cfg.sequence_length)
+     pos_hf = torch.arange(cfg.sequence_length).repeat(cfg.batch_size, 1)
+     output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_hf)  # output is BxNxSxH
+
+     x_bsd = x_bnsh.transpose(1, 2).reshape(cfg.batch_size, cfg.sequence_length, cfg.hidden_size)
+     cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache()
+     # ORT receives a single position id of shape (1) instead of the full grid.
+     pos_ms = torch.tensor([0])
+     onnx_graph = self.create_onnx_graph(x_bsd.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=False)
+     inputs_ort = {
+         "input": x_bsd.detach().cpu().numpy(),
+         "position_ids": pos_ms.detach().cpu().numpy(),
+     }
+
+     # Compare outputs as BxSxNxH
+     expected_bsnh = output_hf.transpose(1, 2).detach().cpu().numpy()
+     self.run_ort_ep_tests(onnx_graph, inputs_ort, expected_bsnh)
+
+ # Bonus test: Token generation step, interleaved = false, pos ids shape = (1)
+ def test_hf_token_rotary_one_pos_id(self):
+     """ORT (non-interleaved, single position id 2) must match the HF reference for one token."""
+     cfg = self.config
+     x_bnsh = torch.randn(cfg.batch_size, cfg.num_heads, 1, cfg.head_size)
+     cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(cfg.sequence_length)
+     pos_ids = torch.full((cfg.batch_size, 1), 2, dtype=torch.int64)
+     output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_ids)  # output is BxNxSxH
+
+     x_bsd = x_bnsh.transpose(1, 2).reshape(cfg.batch_size, 1, cfg.hidden_size)
+     cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache()
+     # ORT receives a single position id of shape (1).
+     pos_ms = torch.tensor([2])
+     onnx_graph = self.create_onnx_graph(x_bsd.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=False)
+     inputs_ort = {
+         "input": x_bsd.detach().cpu().numpy(),
+         "position_ids": pos_ms.detach().cpu().numpy(),
+     }
+
+     # Compare outputs as BxSxNxH
+     expected_bsnh = output_hf.transpose(1, 2).detach().cpu().numpy()
+     self.run_ort_ep_tests(onnx_graph, inputs_ort, expected_bsnh)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/onnxruntime/test/python/transformers/test_rotary_embedding_fusion.py b/onnxruntime/test/python/transformers/test_rotary_embedding_fusion.py
new file mode 100644
index 0000000000000..7bca48c29019e
--- /dev/null
+++ b/onnxruntime/test/python/transformers/test_rotary_embedding_fusion.py
@@ -0,0 +1,447 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import os
+import sys
+import unittest
+from typing import List
+
+import numpy as np
+import onnx
+from onnx import TensorProto, helper
+from parity_utilities import find_transformers_source
+
+if find_transformers_source():
+ from fusion_options import FusionOptions
+ from onnx_model import OnnxModel
+ from optimizer import optimize_model
+else:
+ from onnxruntime.transformers.fusion_options import FusionOptions
+ from onnxruntime.transformers.onnx_model import OnnxModel
+ from onnxruntime.transformers.optimizer import optimize_model
+
+
+def float_tensor(name: str, shape: List[int], random=False):
+    """Build a FLOAT TensorProto of the given shape, filled with ones or with uniform [0, 1) samples."""
+    element_count = int(np.prod(shape))
+    if random:
+        low, high = 0.0, 1.0
+        values = [np.random.uniform(low, high) for _ in range(element_count)]
+    else:
+        values = [1.0] * element_count
+    return helper.make_tensor(name, TensorProto.FLOAT, shape, values)
+
+
+class TestRotaryEmbeddingFusion(unittest.TestCase):
+ def setUp(self):
+     """Fix the tensor dimensions used by every model built in this test case."""
+     self.batch_size, self.sequence_length = 2, 8
+     self.num_heads, self.head_size = 4, 6
+     self.hidden_size = self.num_heads * self.head_size
+
+     self.past_sequence_length, self.max_sequence_length = 2, 12
+
+ def verify_fusion(self, expected_model_path, original_model_path):
+     """Optimize the original model and require its graph to equal the expected fused graph.
+
+     Both graphs are topologically sorted deterministically before their text forms
+     are compared.
+     """
+     expected_model = OnnxModel(onnx.load(expected_model_path))
+     expected_model.topological_sort(is_deterministic=True)
+
+     options = FusionOptions("gpt2")
+     optimized_model = optimize_model(original_model_path, optimization_options=options, opt_level=0)
+     optimized_model.topological_sort(is_deterministic=True)
+
+     # assertTrue(a, b) only checks that `a` is truthy and uses `b` as the failure
+     # message, so it passed for any non-empty graph. Compare the two graphs instead.
+     self.assertEqual(str(expected_model.model.graph), str(optimized_model.model.graph))
+
+ def create_initializers(self):
+     """Return the initializers shared by the original and the fused test models.
+
+     The first two entries are the cos/sin caches (callers index them by position);
+     the rest are integer constants used as shape, axis and index inputs.
+     """
+
+     def int64_constant(name, values):
+         # These feed Reshape/Slice/Gather/Unsqueeze index inputs and hold integer data,
+         # so they are stored as INT64 rather than FLOAT.
+         return helper.make_tensor(name, TensorProto.INT64, [len(values)], values)
+
+     initializers = [
+         float_tensor("cos_cache", [self.max_sequence_length, self.head_size]),
+         float_tensor("sin_cache", [self.max_sequence_length, self.head_size]),
+         int64_constant("pos_ids_new_shape", [self.batch_size, self.sequence_length]),
+         int64_constant("zero", [0]),
+         int64_constant("one", [1]),
+         int64_constant("two", [2]),
+         int64_constant("three", [3]),
+         int64_constant("int_max", [sys.maxsize]),
+     ]
+     return initializers
+
+ def create_inputs_and_outputs(self, model_type: str = ""):
+     """Return (inputs, outputs) value infos for the requested model flavour.
+
+     `model_type` "" gives only `input_0` and `position_ids`; any non-empty value adds
+     `curr_key`; "past" and "merged" also add `past_key`; "merged" adds a second output.
+     """
+     bsnh_shape = [self.batch_size, self.sequence_length, self.num_heads, self.head_size]
+     bnsh_shape = [self.batch_size, self.num_heads, self.sequence_length, self.head_size]
+     has_past = model_type in ("past", "merged")
+
+     inputs = [
+         helper.make_tensor_value_info("input_0", TensorProto.FLOAT, bsnh_shape),
+         helper.make_tensor_value_info("position_ids", TensorProto.INT64, [self.batch_size, self.sequence_length]),
+     ]
+     if has_past:
+         # Input will be removed in fused model since it's not used in RotaryEmbedding.
+         # We create this input so that we can check the `past_seq_len` path during
+         # RotaryEmbedding fusion.
+         past_shape = [self.batch_size, self.num_heads, self.past_sequence_length, self.head_size]
+         inputs.append(helper.make_tensor_value_info("past_key", TensorProto.FLOAT, past_shape))
+     if model_type != "":
+         # Dummy input to test nodes for `curr_seq_len` path
+         inputs.append(helper.make_tensor_value_info("curr_key", TensorProto.FLOAT, bsnh_shape))
+
+     outputs = [helper.make_tensor_value_info("output_0", TensorProto.FLOAT, bnsh_shape)]
+     if model_type == "merged":
+         # Dummy output to test that nodes for `past_seq_len` path are not removed for merged model
+         # NOTE(review): this value is derived from a Shape -> Gather chain in create_cache_path,
+         # which presumably yields int64; confirm whether FLOAT is intended here.
+         outputs.append(helper.make_tensor_value_info("past_seq_len_plus_zero", TensorProto.FLOAT, [1]))
+     return inputs, outputs
+
+ def create_fused_model(self, interleaved: bool, initializers: List[TensorProto]):
+     """Build the expected post-fusion model: a single com.microsoft RotaryEmbedding node.
+
+     Args:
+         interleaved: value for the node's `interleaved` attribute (stored as 0/1).
+         initializers: shared initializer list; entries 0 and 1 are used as the
+             cos and sin cache inputs of the node.
+     """
+     inputs, outputs = self.create_inputs_and_outputs()
+
+     rope_node = helper.make_node(
+         "RotaryEmbedding",
+         inputs=[inputs[0].name, inputs[1].name, initializers[0].name, initializers[1].name],
+         outputs=[outputs[0].name],
+         name="RotaryEmbedding_0",
+         # RotaryEmbedding is a contrib op: without an explicit domain the node lands in the
+         # default ONNX domain and does not match the opset imported below.
+         domain="com.microsoft",
+         interleaved=int(interleaved),
+     )
+
+     graph = helper.make_graph(
+         nodes=[rope_node],
+         name="RotaryEmbedding_Graph",
+         inputs=inputs,
+         outputs=outputs,
+         initializer=initializers,
+     )
+     opset_import = helper.make_opsetid(domain="com.microsoft", version=1)
+     model = helper.make_model(graph, opset_imports=[opset_import])
+     return model
+
+ def create_cache_path(self, model_type: str, use_redundant_squeeze_ops: bool):
+     """Build the nodes that slice and index the cos/sin caches, plus the sequence-length subgraph.
+
+     Args:
+         model_type: "past" and "merged" add the `past_seq_len` path; "merged" also adds a
+             dummy Add node that consumes `past_seq_len`.
+         use_redundant_squeeze_ops: insert the two extra Squeeze nodes per cache.
+
+     Returns:
+         Position-id nodes, cos nodes, sin nodes and sequence-length nodes, in that order.
+     """
+     has_past = model_type in {"past", "merged"}
+
+     # Create position ids path
+     pos_ids_nodes = [
+         helper.make_node(
+             "Reshape",
+             inputs=["position_ids", "pos_ids_new_shape"],
+             outputs=["pos_ids_reshaped"],
+             name="Reshape_0",
+         )
+     ]
+
+     # Create cos and sin paths; they differ only in tensor prefix and node names.
+     cos_nodes = self._create_cache_branch(
+         "cos",
+         ("Unsqueeze_2", "Slice_2", "Squeeze_0", "Squeeze_1", "Gather_1", "Unsqueeze_3"),
+         use_redundant_squeeze_ops,
+     )
+     sin_nodes = self._create_cache_branch(
+         "sin",
+         ("Unsqueeze_4", "Slice_3", "Squeeze_2", "Squeeze_3", "Gather_2", "Unsqueeze_5"),
+         use_redundant_squeeze_ops,
+     )
+
+     # Create beginning nodes before cos and sin paths
+
+     # Create curr seq len path
+     beginning_nodes = [
+         helper.make_node(
+             "Transpose",
+             inputs=["curr_key"],
+             outputs=["curr_key_transposed"],
+             name="Transpose_curr",
+             perm=[0, 2, 1, 3],
+         ),
+         helper.make_node(
+             "Shape",
+             inputs=["curr_key_transposed"],
+             outputs=["curr_shape"],
+             name="Shape_curr",
+         ),
+         helper.make_node(
+             "Gather",
+             inputs=["curr_shape", "two"],
+             # Without a past, the current length is already the new sequence length.
+             outputs=["curr_seq_len" if has_past else "new_seq_len"],
+             name="Gather_curr",
+         ),
+     ]
+
+     if has_past:
+         # Create past seq len path
+         beginning_nodes.extend(
+             [
+                 helper.make_node(
+                     "Shape",
+                     inputs=["past_key"],
+                     outputs=["past_shape"],
+                     name="Shape_past",
+                 ),
+                 helper.make_node(
+                     "Gather",
+                     inputs=["past_shape", "two"],
+                     outputs=["past_seq_len"],
+                     name="Gather_past",
+                 ),
+                 helper.make_node(
+                     "Add",
+                     inputs=["curr_seq_len", "past_seq_len"],
+                     outputs=["new_seq_len"],
+                     name="Add_1",
+                 ),
+             ]
+         )
+
+     if model_type == "merged":
+         beginning_nodes.append(
+             helper.make_node(
+                 "Add",
+                 inputs=["past_seq_len", "zero"],
+                 outputs=["past_seq_len_plus_zero"],
+                 name="Add_dummy_node",
+             )
+         )
+
+     return pos_ids_nodes + cos_nodes + sin_nodes + beginning_nodes
+
+ def _create_cache_branch(self, prefix: str, node_names, use_redundant_squeeze_ops: bool):
+     """Build Unsqueeze -> Slice -> [Squeeze, Squeeze] -> Gather -> Unsqueeze for one cache.
+
+     Args:
+         prefix: "cos" or "sin"; used for the cache initializer name and every intermediate
+             tensor name, and is also the name of the final output tensor.
+         node_names: six node names in chain order (initial Unsqueeze, Slice, first Squeeze,
+             second Squeeze, Gather, final Unsqueeze).
+         use_redundant_squeeze_ops: include the two Squeeze nodes between Slice and Gather.
+     """
+     init_unsqueeze_name, slice_name, squeeze_1_name, squeeze_2_name, gather_name, end_unsqueeze_name = node_names
+
+     nodes = [
+         helper.make_node(
+             "Unsqueeze",
+             inputs=["new_seq_len", "zero"],
+             outputs=[f"{prefix}_unsqueeze"],
+             name=init_unsqueeze_name,
+         ),
+         helper.make_node(
+             "Slice",
+             inputs=[f"{prefix}_cache", "zero", f"{prefix}_unsqueeze", "two", "one"],
+             outputs=[f"{prefix}_sliced"],
+             name=slice_name,
+         ),
+     ]
+
+     gather_source = f"{prefix}_sliced"
+     if use_redundant_squeeze_ops:
+         # These two nodes are eliminated by this transformers PR: https://github.com/huggingface/transformers/pull/26162
+         nodes.extend(
+             [
+                 helper.make_node(
+                     "Squeeze",
+                     inputs=[f"{prefix}_sliced", "zero"],
+                     outputs=[f"{prefix}_squeeze_1"],
+                     name=squeeze_1_name,
+                 ),
+                 helper.make_node(
+                     "Squeeze",
+                     inputs=[f"{prefix}_squeeze_1", "zero"],
+                     outputs=[f"{prefix}_squeeze_2"],
+                     name=squeeze_2_name,
+                 ),
+             ]
+         )
+         gather_source = f"{prefix}_squeeze_2"
+
+     nodes.extend(
+         [
+             helper.make_node(
+                 "Gather",
+                 inputs=[gather_source, "pos_ids_reshaped"],
+                 outputs=[f"{prefix}_indexed"],
+                 name=gather_name,
+             ),
+             helper.make_node(
+                 "Unsqueeze",
+                 inputs=[f"{prefix}_indexed", "one"],
+                 outputs=[prefix],
+                 name=end_unsqueeze_name,
+             ),
+         ]
+     )
+     return nodes
+
+ def create_apply_rope_path(self):
+     """Build the unfused rotate-half RoPE subgraph from `input_0` to `output_0`.
+
+     Returns the half-shape nodes, then the rotate-half nodes, then the
+     Transpose/Mul/Mul/Add nodes that produce the embedded output.
+     """
+     make = helper.make_node
+
+     # Calculate x_half_shape
+     x_half_shape_nodes = [
+         make("Shape", inputs=["x"], outputs=["x_shape"], name="Shape_0"),
+         make("Gather", inputs=["x_shape", "three"], outputs=["x_last_idx_shape"], name="Gather_0", axis=0),
+         make("Div", inputs=["x_last_idx_shape", "two"], outputs=["x_half_shape"], name="Div_0"),
+         make("Unsqueeze", inputs=["x_half_shape", "zero"], outputs=["x_half_shape_0"], name="Unsqueeze_0"),
+         make("Unsqueeze", inputs=["x_half_shape", "zero"], outputs=["x_half_shape_1"], name="Unsqueeze_1"),
+     ]
+
+     # Calculate rotate_half
+     rotate_half_nodes = [
+         make("Slice", inputs=["x", "zero", "x_half_shape_0", "three", "one"], outputs=["x1"], name="Slice_0"),
+         make("Slice", inputs=["x", "x_half_shape_1", "int_max", "three", "one"], outputs=["x2"], name="Slice_1"),
+         make("Neg", inputs=["x2"], outputs=["x2_neg"], name="Neg_0"),
+         make("Concat", inputs=["x2_neg", "x1"], outputs=["x_rotate_half"], name="Concat_0", axis=-1),
+     ]
+
+     # Calculate x_embed
+     x_embed_nodes = [
+         make("Transpose", inputs=["input_0"], outputs=["x"], name="Transpose_0", perm=[0, 2, 1, 3]),
+         make("Mul", inputs=["x", "cos"], outputs=["x_cos"], name="Mul_0"),
+         make("Mul", inputs=["x_rotate_half", "sin"], outputs=["x_rotate_half_sin"], name="Mul_1"),
+         make("Add", inputs=["x_cos", "x_rotate_half_sin"], outputs=["output_0"], name="Add_0"),
+     ]
+
+     return x_half_shape_nodes + rotate_half_nodes + x_embed_nodes
+
+ def create_test_model(self, model_type: str, use_redundant_squeeze_ops: bool, initializers: List[TensorProto]):
+     """Assemble the unfused model: apply-RoPE nodes followed by cache nodes, opset ai.onnx 13."""
+     rope_nodes = self.create_apply_rope_path()
+     cache_nodes = self.create_cache_path(model_type, use_redundant_squeeze_ops)
+     graph_inputs, graph_outputs = self.create_inputs_and_outputs(model_type)
+
+     graph = helper.make_graph(
+         nodes=rope_nodes + cache_nodes,
+         name="RotaryEmbedding_Graph",
+         inputs=graph_inputs,
+         outputs=graph_outputs,
+         initializer=initializers,
+     )
+     onnx_opset = helper.make_opsetid(domain="ai.onnx", version=13)
+     return helper.make_model(graph, opset_imports=[onnx_opset])
+
+ def check_models(self, interleaved: bool, model_type: str):
+     """Verify fusion for `model_type`, once with and once without the redundant Squeeze nodes.
+
+     The expected and original models are written to the working directory and are
+     removed again even if saving or verification raises.
+     """
+     initializers = self.create_initializers()
+     expected_model_filename = "expected_model.onnx"
+     original_model_filename = "original_model.onnx"
+
+     try:
+         expected_model = self.create_fused_model(interleaved, initializers)
+         onnx.save(expected_model, expected_model_filename)
+
+         for use_redundant_squeeze_ops in (True, False):
+             original_model = self.create_test_model(model_type, use_redundant_squeeze_ops, initializers)
+             onnx.save(original_model, original_model_filename)
+             self.verify_fusion(expected_model_filename, original_model_filename)
+     finally:
+         # Do not leave .onnx files behind when an assertion or the optimizer fails.
+         for filename in (expected_model_filename, original_model_filename):
+             if os.path.exists(filename):
+                 os.remove(filename)
+
+ # Hugging Face's `decoder_model.onnx`
+ def test_hf_decoder_model(self):
+     """Fusion check for the model flavour without past inputs."""
+     # HF model does not use interleaving
+     self.check_models(interleaved=False, model_type="no_past")
+
+ # Hugging Face's `decoder_with_past_model.onnx`
+ def test_hf_decoder_with_past_model(self):
+     """Fusion check for the model flavour with a `past_key` input."""
+     # HF model does not use interleaving
+     self.check_models(interleaved=False, model_type="past")
+
+ # Hugging Face's `decoder_merged.onnx`
+ def test_hf_decoder_merged_model(self):
+     """Fusion check for the merged model flavour (past input plus the dummy past-length output)."""
+     # HF model does not use interleaving
+     self.check_models(interleaved=False, model_type="merged")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py b/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py
new file mode 100644
index 0000000000000..fedba2a25dfc2
--- /dev/null
+++ b/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py
@@ -0,0 +1,795 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import os
+import sys
+import unittest
+from typing import List
+
+import numpy as np
+import onnx
+from onnx import NodeProto, TensorProto, helper
+from parity_utilities import find_transformers_source
+
+if find_transformers_source():
+ from fusion_options import FusionOptions
+ from onnx_model import OnnxModel
+ from optimizer import optimize_model
+else:
+ from onnxruntime.transformers.fusion_options import FusionOptions
+ from onnxruntime.transformers.onnx_model import OnnxModel
+ from onnxruntime.transformers.optimizer import optimize_model
+
+
+def float_tensor(name: str, shape: List[int], random=False):
+    """Build a FLOAT tensor named `name`: all ones, or uniform samples from [0, 1) if `random`."""
+    num_values = int(np.prod(shape))
+    if random:
+        values = [np.random.uniform(0.0, 1.0) for _ in range(num_values)]
+    else:
+        values = [1.0] * num_values
+    return helper.make_tensor(name, TensorProto.FLOAT, shape, values)
+
+
+class TestRotaryAttentionFusion(unittest.TestCase):
+    def setUp(self):
+        """Define the small model dimensions shared by every test graph."""
+        self.batch_size, self.sequence_length = 2, 8
+        self.num_heads, self.head_size = 4, 6
+        self.hidden_size = self.num_heads * self.head_size
+
+        # Lengths used for the past key/value inputs and the cos/sin caches
+        self.past_sequence_length = 2
+        self.max_sequence_length = 12
+
+    def verify_fusion(self, expected_model_path, original_model_path):
+        """Optimize the model at `original_model_path` and assert its graph equals the expected one."""
+        expected_model = OnnxModel(onnx.load(expected_model_path))
+        expected_model.topological_sort(is_deterministic=True)
+
+        model_type = "gpt2"
+        optimized_model = optimize_model(
+            original_model_path,
+            model_type,
+            self.num_heads,
+            self.hidden_size,
+            optimization_options=FusionOptions(model_type),
+            opt_level=0,
+        )
+        optimized_model.topological_sort(is_deterministic=True)
+        # assertTrue(a, b) treats `b` as the failure message and passes for any non-empty `a`.
+        self.assertEqual(str(expected_model.model.graph), str(optimized_model.model.graph))
+
+    def create_initializers(self, fused_model: bool = False):
+        """Return the initializers shared by the test graphs; `fused_model` is currently unused."""
+        def constant(name, value, dtype):  # NOTE(review): declared FLOAT even for int64 values - confirm intended
+            return helper.make_tensor(name, TensorProto.FLOAT, [1], np.array([value], dtype=dtype))
+
+        cache_shape = [self.max_sequence_length, self.head_size // 2]
+        weight_shape = [self.hidden_size, self.hidden_size]
+        caches = [float_tensor(name, cache_shape) for name in ("cos_cache", "sin_cache")]
+        weights = [float_tensor(name, weight_shape) for name in ("q_weight", "k_weight", "v_weight", "o_weight")]
+        constants = [
+            ("sqrt_head_size", np.sqrt(self.head_size), np.float32),
+            ("neg_int_max", -sys.maxsize - 1, np.int64),
+            ("num_heads", self.num_heads, np.float32),
+            ("head_size", self.head_size, np.float32),
+            ("hidden_size", self.hidden_size, np.float32),
+            ("zero", 0, np.int64),
+            ("one", 1, np.int64),
+            ("two", 2, np.int64),
+            ("three", 3, np.int64),
+        ]
+        return caches + weights + [constant(*spec) for spec in constants]
+
+ def create_inputs_and_outputs(self, model_type: str):  # Build the graph input/output value infos for `model_type`.
+ attn_mask_size = [self.batch_size, self.sequence_length]
+ if model_type == "llama2_msft":
+ attn_mask_size.append(self.sequence_length)  # llama2_msft mask gets a third dimension of sequence_length
+
+ inputs = [
+ helper.make_tensor_value_info(
+ "input_0", TensorProto.FLOAT, [self.batch_size, self.sequence_length, self.hidden_size]
+ ),
+ helper.make_tensor_value_info("position_ids", TensorProto.INT64, [self.batch_size, self.sequence_length]),
+ helper.make_tensor_value_info("attn_mask", TensorProto.INT64, attn_mask_size),
+ ]
+ if model_type in {"past", "merged", "llama2_msft"}:  # these variants also declare past key/value inputs
+ inputs.extend(
+ [
+ helper.make_tensor_value_info(
+ "past_key",
+ TensorProto.FLOAT,
+ [self.batch_size, self.num_heads, self.past_sequence_length, self.head_size],
+ ),
+ helper.make_tensor_value_info(
+ "past_value",
+ TensorProto.FLOAT,
+ [self.batch_size, self.num_heads, self.past_sequence_length, self.head_size],
+ ),
+ ]
+ )
+ outputs = [  # the same three outputs are declared for every model_type
+ helper.make_tensor_value_info(
+ "output_0", TensorProto.FLOAT, [self.batch_size, self.sequence_length, self.hidden_size]
+ ),
+ helper.make_tensor_value_info(
+ "present_key",
+ TensorProto.FLOAT,
+ [self.batch_size, self.num_heads, self.past_sequence_length + 1, self.head_size],  # NOTE(review): past + 1, not past + sequence_length - confirm
+ ),
+ helper.make_tensor_value_info(
+ "present_value",
+ TensorProto.FLOAT,
+ [self.batch_size, self.num_heads, self.past_sequence_length + 1, self.head_size],
+ ),
+ ]
+ return inputs, outputs
+
+    def create_matmul_nodes(self, is_fused: bool, model_type: str):
+        """Return the Q, K and V projection MatMul nodes, each multiplying `input_0` by its weight.
+
+        Q and K write to `q_out`/`k_out` when `is_fused` is set or `model_type` is "llama2_msft",
+        and to `q_matmul_out`/`k_matmul_out` otherwise. V always writes to `v_out`.
+        """
+        writes_final_names = is_fused or model_type == "llama2_msft"
+
+        # (node name, weight initializer, output name)
+        projections = [
+            ("Q_MatMul", "q_weight", "q_out" if writes_final_names else "q_matmul_out"),
+            ("K_MatMul", "k_weight", "k_out" if writes_final_names else "k_matmul_out"),
+            ("V_MatMul", "v_weight", "v_out"),
+        ]
+        return [
+            helper.make_node(
+                "MatMul",
+                inputs=["input_0", weight_name],
+                outputs=[output_name],
+                name=node_name,
+            )
+            for node_name, weight_name, output_name in projections
+        ]
+
+ def create_rotary_embeddings(  # Build the Q and K RotaryEmbedding nodes and return them as [q_node, k_node].
+ self,
+ is_fused: bool,
+ model_type: str,
+ interleaved: bool,
+ inputs: List[TensorProto],
+ initializers: List[TensorProto],
+ ):
+ def get_first_rope_input(node_type: str):  # name of the tensor fed to the RoPE node for "q" or "k"
+ if is_fused or model_type == "llama2_msft":
+ # q_out/k_out
+ return f"{node_type}_out"
+ if model_type in {"no_past", "past", "merged"}:
+ if node_type == "k":
+ return "k_before_rope"
+ return "q_before_rope"
+ return ""  # any other model_type
+
+ def get_first_rope_output(node_type: str):  # name of the tensor the RoPE node writes for "q" or "k"
+ if is_fused or model_type in {"llama2_msft", "past", "merged"}:
+ if node_type == "q":
+ return "q_rope"
+ return "k_rope"
+ if model_type in {"no_past"}:
+ if node_type == "k":
+ return "present_key"  # unfused no_past: rotated K is written straight to "present_key"
+ return "q_rope"
+ return ""  # any other model_type
+
+ q_rope_node = helper.make_node(
+ "RotaryEmbedding",
+ inputs=[get_first_rope_input("q"), inputs[1].name, initializers[0].name, initializers[1].name],  # 2nd graph input and first two initializers, taken by position
+ outputs=[get_first_rope_output("q")],
+ name="Q_RotaryEmbedding",
+ interleaved=int(interleaved),
+ )
+
+ k_rope_node = helper.make_node(
+ "RotaryEmbedding",
+ inputs=[get_first_rope_input("k"), inputs[1].name, initializers[0].name, initializers[1].name],
+ outputs=[get_first_rope_output("k")],
+ name="K_RotaryEmbedding",
+ interleaved=int(interleaved),
+ )
+
+ return [q_rope_node, k_rope_node]
+
+ def create_q_path(self, model_type: str):  # Build the Reshape/Transpose pair on the Q branch of the unfused graph.
+ if model_type == "llama2_msft":  # llama2_msft: "q_rope" is transposed, then reshaped into "q"
+ transpose_q_node = helper.make_node(
+ "Transpose",
+ inputs=["q_rope"],
+ outputs=["q_transposed"],
+ name="Transpose_q",
+ perm=[0, 2, 1, 3],
+ )
+ reshape_q_node = helper.make_node(
+ "Reshape",
+ inputs=["q_transposed", "concat_q_extra_out"],
+ outputs=["q"],
+ name="Reshape_q",
+ )
+ return [transpose_q_node, reshape_q_node]
+
+ reshape_q_node = helper.make_node(  # other model types: "q_matmul_out" is reshaped, then transposed into "q_before_rope"
+ "Reshape",
+ inputs=["q_matmul_out", "concat_q_extra_out"],
+ outputs=["q_reshaped"],
+ name="Reshape_q",
+ )
+ transpose_q_node = helper.make_node(
+ "Transpose",
+ inputs=["q_reshaped"],
+ outputs=["q_before_rope"],
+ name="Transpose_q",  # NOTE(review): no perm attribute here, unlike Transpose_k_1 - confirm intended
+ )
+ return [reshape_q_node, transpose_q_node]
+
+ def create_k_path_llama2_msft(self):  # Build the llama2_msft K branch: slice past_key, append "k_rope", reshape into "k".
+ # Create k cache slicing path
+ k_cache_unsqueeze_node = helper.make_node(
+ "Unsqueeze",
+ inputs=["position_ids", "zero"],
+ outputs=["k_pos_id"],
+ )
+ k_cache_slice_node = helper.make_node(
+ "Slice",
+ inputs=["past_key", "zero", "k_pos_id", "two", "one"],
+ outputs=["k_cache_sliced"],
+ )
+ # Create k path
+ transpose_k_1_node = helper.make_node(
+ "Transpose",
+ inputs=["k_rope"],
+ outputs=["k_rope_transposed"],
+ name="Transpose_k_1",
+ perm=[0, 2, 1, 3],
+ )
+ concat_k_node = helper.make_node(
+ "Concat",
+ inputs=["k_cache_sliced", "k_rope_transposed"],
+ outputs=["present_key"],
+ name="Concat_k",
+ axis=2,
+ )
+ transpose_k_2_node = helper.make_node(
+ "Transpose",
+ inputs=["present_key"],
+ outputs=["present_key_transposed"],
+ name="Transpose_k_2",
+ perm=[0, 2, 3, 1],
+ )
+ reshape_k_node = helper.make_node(
+ "Reshape",
+ inputs=["present_key_transposed", "concat_k_extra_out"],
+ outputs=["k"],
+ name="Reshape_k",
+ )
+ return [  # the two cache-slicing nodes above are created without a name
+ k_cache_unsqueeze_node,
+ k_cache_slice_node,
+ transpose_k_1_node,
+ concat_k_node,
+ transpose_k_2_node,
+ reshape_k_node,
+ ]
+
+ def create_k_path_hf(self, model_type: str):  # Build the HF K branch: Reshape + Transpose before RoPE, then a final Transpose into "k".
+ reshape_k_node = helper.make_node(
+ "Reshape",
+ inputs=["k_matmul_out", "concat_k_extra_out"],
+ outputs=["k_reshaped"],
+ name="Reshape_k",
+ )
+ transpose_k_1_node = helper.make_node(
+ "Transpose",
+ inputs=["k_reshaped"],
+ outputs=["k_before_rope"],
+ name="Transpose_k_1",
+ perm=[0, 2, 1, 3],
+ )
+ k_nodes = [reshape_k_node, transpose_k_1_node]
+
+ if model_type in {"past", "merged"}:  # with a past, "k_rope" is concatenated after "past_key" along axis 2
+ concat_k_node = helper.make_node(
+ "Concat",
+ inputs=["past_key", "k_rope"],
+ outputs=["present_key"],
+ axis=2,
+ )
+ k_nodes.append(concat_k_node)
+
+ transpose_k_2_node = helper.make_node(
+ "Transpose",
+ inputs=["present_key"],
+ outputs=["k"],
+ name="Transpose_k_2",
+ perm=[0, 1, 3, 2],  # swaps the last two axes of "present_key"
+ )
+ return k_nodes + [transpose_k_2_node] # noqa: RUF005
+
+    def create_k_path(self, model_type: str):
+        """Dispatch to the K-path builder matching `model_type`."""
+        is_llama2_msft = model_type == "llama2_msft"
+        return self.create_k_path_llama2_msft() if is_llama2_msft else self.create_k_path_hf(model_type)
+
+ def create_attn_mask_path_llama2_msft(self):  # Build the llama2_msft mask branch: slice "attn_mask" twice, then Concat num_heads copies.
+ x_shape_node = helper.make_node(
+ "Shape",
+ inputs=["input_0"],
+ outputs=["input_0_shape"],
+ name="Shape_input",
+ )
+ x_get_seq_len_node = helper.make_node(
+ "Gather",
+ inputs=["input_0_shape", "one"],
+ outputs=["input_0_seq_len"],
+ name="Gather_input",
+ axis=0,
+ )
+ x_new_seq_len_node = helper.make_node(
+ "Add",
+ inputs=["position_ids", "input_0_seq_len"],
+ outputs=["new_seq_len"],
+ name="Add_mask",
+ )
+ unsqueeze_0_node = helper.make_node(
+ "Unsqueeze",
+ inputs=["position_ids", "zero"],
+ outputs=["unsqueeze_mask_0_out"],
+ name="Unsqueeze_mask_0",
+ )
+ unsqueeze_1_node = helper.make_node(
+ "Unsqueeze",
+ inputs=["new_seq_len", "zero"],
+ outputs=["unsqueeze_mask_1_out"],
+ name="Unsqueeze_mask_1",
+ )
+ unsqueeze_2_node = helper.make_node(
+ "Unsqueeze",
+ inputs=["new_seq_len", "zero"],
+ outputs=["unsqueeze_mask_2_out"],
+ name="Unsqueeze_mask_2",
+ )
+ slice_mask_1_node = helper.make_node(
+ "Slice",
+ inputs=["attn_mask", "unsqueeze_mask_0_out", "unsqueeze_mask_1_out", "one", "one"],
+ outputs=["slice_mask_1_out"],
+ name="Slice_mask_1",
+ )
+ slice_mask_2_node = helper.make_node(
+ "Slice",
+ inputs=["slice_mask_1_out", "zero", "unsqueeze_mask_2_out", "two", "one"],
+ outputs=["slice_mask_2_out"],
+ name="Slice_mask_2",
+ )
+ concat_mask_node = helper.make_node(
+ "Concat",
+ inputs=["slice_mask_2_out" for _ in range(self.num_heads)],  # the same tensor name repeated num_heads times
+ outputs=["attn_mask_out"],
+ name="Concat_mask",
+ axis=0,
+ )
+ return [  # Concat_mask is kept last; create_attn_mask_path pops the last node for the fused graph
+ x_shape_node,
+ x_get_seq_len_node,
+ x_new_seq_len_node,
+ unsqueeze_0_node,
+ unsqueeze_1_node,
+ unsqueeze_2_node,
+ slice_mask_1_node,
+ slice_mask_2_node,
+ concat_mask_node,
+ ]
+
+ def create_attn_mask_path_hf(self, model_type: str):  # Build the HF mask branch: Unsqueeze x2 -> Expand -> Cast -> Sub -> Where (-> Add).
+ unsqueeze_1_node = helper.make_node(
+ "Unsqueeze",
+ inputs=["attn_mask", "one"],
+ outputs=["unsqueeze_1_mask_out"],
+ name="Unsqueeze_1_mask",
+ )
+ unsqueeze_2_node = helper.make_node(
+ "Unsqueeze",
+ inputs=["unsqueeze_1_mask_out", "two"],
+ outputs=["unsqueeze_2_mask_out"],
+ name="Unsqueeze_2_mask",
+ )
+ expand_node = helper.make_node(
+ "Expand",
+ inputs=["unsqueeze_2_mask_out", "zero"],
+ outputs=["expand_out"],
+ name="Expand_mask",
+ )
+ cast_node = helper.make_node(
+ "Cast",
+ inputs=["expand_out"],
+ outputs=["cast_out"],
+ name="Cast_mask",
+ to=TensorProto.FLOAT,
+ )
+ sub_node = helper.make_node(
+ "Sub",
+ inputs=["one", "cast_out"],
+ outputs=["sub_out"],
+ name="Sub_mask",
+ )
+ where_node = helper.make_node(
+ "Where",
+ inputs=["zero", "neg_int_max", "sub_out"],
+ outputs=["where_out" if model_type != "past" else "attn_mask_out"],  # for "past", Where is the last node and writes the mask directly
+ name="Where_mask",
+ )
+ attn_mask_nodes = [unsqueeze_1_node, unsqueeze_2_node, expand_node, cast_node, sub_node, where_node]
+
+ if model_type == "past":
+ return attn_mask_nodes
+
+ add_node = helper.make_node(  # every other model_type gets a trailing Add that produces "attn_mask_out"
+ "Add",
+ inputs=["where_out", "zero"],
+ outputs=["attn_mask_out"],
+ name="Add_mask",
+ )
+ return attn_mask_nodes + [add_node] # noqa: RUF005
+
+    def create_attn_mask_path(self, is_fused: bool, model_type: str):
+        """Return the mask nodes for `model_type`, adjusted for the fused graph when `is_fused` is set."""
+        if model_type == "llama2_msft":
+            mask_nodes = self.create_attn_mask_path_llama2_msft()
+            if is_fused:
+                del mask_nodes[-1]  # drop the trailing Concat_mask node
+                mask_nodes[-1].output[0] = "attn_mask_out"
+            return mask_nodes
+        mask_nodes = self.create_attn_mask_path_hf(model_type)
+        if not is_fused:
+            return mask_nodes
+        renamed_output = "attn_mask_out_mask"
+        mask_nodes[-1].output[0] = renamed_output
+        concat_mask_node = helper.make_node(
+            "Concat",
+            inputs=[renamed_output] * self.num_heads,
+            outputs=["attn_mask_out"],
+            name="Concat_mask",
+            axis=0,
+        )
+        return mask_nodes + [concat_mask_node]
+
+ def create_qk_path(self, model_type: str):  # Build MatMul(q, k) -> Div(sqrt_head_size) -> Add(mask) -> Softmax.
+ matmul_qk_node = helper.make_node(
+ "MatMul",
+ inputs=["q" if model_type == "llama2_msft" else "q_rope", "k"],  # llama2_msft reads the reshaped "q"; others read "q_rope"
+ outputs=["qk"],
+ name="MatMul_q_k",
+ )
+ div_node = helper.make_node(
+ "Div",
+ inputs=["qk", "sqrt_head_size"],
+ outputs=["qk_div"],
+ name="Div_0",
+ )
+ add_node = helper.make_node(
+ "Add",
+ inputs=["qk_div", "attn_mask_out"],
+ outputs=["qk_plus_mask"],
+ name="Add_0",
+ )
+ softmax_node = helper.make_node(
+ "Softmax",
+ inputs=["qk_plus_mask"],
+ outputs=["softmax_out"],
+ name="Softmax_0",
+ )
+ return [matmul_qk_node, div_node, add_node, softmax_node]
+
+ def create_v_path(self, model_type: str):  # Build the V branch; the node list returned depends on `model_type`.
+ reshape_v_1_node = helper.make_node(
+ "Reshape",
+ inputs=["v_out", "concat_v_1_extra_out"],
+ outputs=["reshape_v_1_out"],
+ name="Reshape_v_1",
+ )
+ transpose_v_1_node = helper.make_node(
+ "Transpose",
+ inputs=["reshape_v_1_out"],
+ outputs=["transpose_v_1_out" if model_type != "no_past" else "present_value"],  # no_past: the Transpose writes "present_value" directly
+ name="Transpose_v_1",
+ )
+ v_nodes = [reshape_v_1_node, transpose_v_1_node]
+
+ if model_type == "no_past":
+ return v_nodes
+
+ if model_type in {"past", "merged"}:  # HF with past: append the new values after "past_value" along axis 2
+ concat_v_node = helper.make_node(
+ "Concat",
+ inputs=["past_value", "transpose_v_1_out"],
+ outputs=["present_value"],
+ name="Concat_v",
+ axis=2,
+ )
+ return v_nodes + [concat_v_node] # noqa: RUF005
+
+ # Create extra nodes for `position_ids`
+ unsqueeze_v_node = helper.make_node(  # reached only for model types not handled above
+ "Unsqueeze",
+ inputs=["position_ids", "zero"],
+ outputs=["unsqueeze_v_out"],
+ name="Unsqueeze_v",
+ )
+ slice_v_node = helper.make_node(
+ "Slice",
+ inputs=["past_value", "zero", "unsqueeze_v_out", "two", "one"],
+ outputs=["v_cache_sliced_out"],
+ name="Slice_v",
+ )
+ concat_v_node = helper.make_node(
+ "Concat",
+ inputs=["v_cache_sliced_out", "transpose_v_1_out"],
+ outputs=["present_value"],
+ name="Concat_v",
+ axis=2,
+ )
+ v_nodes.extend([unsqueeze_v_node, slice_v_node, concat_v_node])
+
+ # Create remaining nodes for v path
+ transpose_v_2_node = helper.make_node(
+ "Transpose",
+ inputs=["present_value"],
+ outputs=["transpose_v_2_out"],
+ name="Transpose_v_2",
+ )
+ reshape_v_2_node = helper.make_node(
+ "Reshape",
+ inputs=["transpose_v_2_out", "concat_v_2_extra_out"],
+ outputs=["v"],
+ name="Reshape_v_2",
+ )
+ return v_nodes + [transpose_v_2_node, reshape_v_2_node] # noqa: RUF005
+
+ def create_qkv_path(self, model_type: str):  # Build MatMul(softmax, v) and the Transpose/Reshape nodes that produce "attn_output".
+ matmul_qkv_node = helper.make_node(
+ "MatMul",
+ inputs=["softmax_out", "v" if model_type == "llama2_msft" else "present_value"],  # llama2_msft reads the reshaped "v"; others read "present_value"
+ outputs=["softmax_v_out"],
+ name="MatMul_softmax_v",
+ )
+ qkv_nodes = [matmul_qkv_node]
+
+ if model_type == "llama2_msft":  # llama2_msft inserts an extra Reshape before the Transpose
+ reshape_qkv_1_node = helper.make_node(
+ "Reshape",
+ inputs=["softmax_v_out", "concat_qkv_1_extra_out"],
+ outputs=["reshape_qkv_1_out"],
+ name="Reshape_qkv_1",
+ )
+ qkv_nodes.append(reshape_qkv_1_node)
+
+ transpose_qkv_node = helper.make_node(
+ "Transpose",
+ inputs=["reshape_qkv_1_out" if model_type == "llama2_msft" else "softmax_v_out"],
+ outputs=["transpose_qkv_out"],
+ name="Transpose_qkv",
+ )
+ reshape_qkv_2_node = helper.make_node(
+ "Reshape",
+ inputs=["transpose_qkv_out", "concat_qkv_2_extra_out"],
+ outputs=["attn_output"],
+ name="Reshape_qkv_2",
+ )
+
+ return qkv_nodes + [transpose_qkv_node, reshape_qkv_2_node] # noqa: RUF005
+
+ def create_concat_unsqueeze_paths(self, model_type: str, reshape_nodes: List[NodeProto]):  # Build the Shape/Gather/Unsqueeze/Concat nodes that feed each Reshape's second input.
+ # Create initial shape paths
+ shape_0_node = helper.make_node(
+ "Shape",
+ inputs=["input_0"],
+ outputs=["input_0_shape_0"],
+ name="Shape_0",
+ )
+ gather_0_node = helper.make_node(
+ "Gather",
+ inputs=["input_0_shape_0", "zero"],
+ outputs=["input_0_shape_0_indexed"],
+ name="Gather_0",
+ axis=0,
+ )
+ shape_1_node = helper.make_node(
+ "Shape",
+ inputs=["input_0"],
+ outputs=["input_0_shape_1"],
+ name="Shape_1",
+ )
+ gather_1_node = helper.make_node(
+ "Gather",
+ inputs=["input_0_shape_1", "one"],
+ outputs=["input_0_shape_1_indexed"],
+ name="Gather_1",
+ axis=0,
+ )
+ extra_nodes = [shape_0_node, gather_0_node, shape_1_node, gather_1_node]
+
+ if model_type == "llama2_msft":  # llama2_msft adds Mul(Gather_0, num_heads) and Add(Gather_1, position_ids)
+ mul_node = helper.make_node(
+ "Mul",
+ inputs=[gather_0_node.output[0], "num_heads"],
+ outputs=["mul_extra_out"],
+ name="Mul_extra_0",
+ )
+ add_node = helper.make_node(
+ "Add",
+ inputs=[gather_1_node.output[0], "position_ids"],
+ outputs=["add_extra_out"],
+ name="Add_extra_0",
+ )
+ extra_nodes.extend([mul_node, add_node])
+
+ for i, reshape_node in enumerate(reshape_nodes):  # one Unsqueeze pair plus one Concat per Reshape node
+ use_mul_and_add_nodes_0 = model_type == "llama2_msft" and reshape_node.output[0] in {"q", "k", "v"}
+ use_mul_and_add_nodes_1 = model_type == "llama2_msft" and reshape_node.output[0] in {"k", "v"}
+
+ unsqueeze_0_node = helper.make_node(
+ "Unsqueeze",
+ inputs=[gather_0_node.output[0] if not use_mul_and_add_nodes_0 else "mul_extra_out", "zero"],
+ outputs=[f"unsqueeze_extra_{2*i}"],
+ name=f"Unsqueeze_extra_{2*i}",
+ )
+ unsqueeze_1_node = helper.make_node(
+ "Unsqueeze",
+ inputs=[gather_1_node.output[0] if not use_mul_and_add_nodes_1 else "add_extra_out", "zero"],
+ outputs=[f"unsqueeze_extra_{2*i + 1}"],
+ name=f"Unsqueeze_extra_{2*i + 1}",
+ )
+
+ reshape_name = reshape_node.name  # the Concat inputs are chosen from the Reshape node's name
+ if reshape_name == "Reshape_qkv_2":
+ concat_node_inputs = [unsqueeze_0_node.output[0], unsqueeze_1_node.output[0], "hidden_size"]
+ elif reshape_name == "Reshape_qkv_1":
+ concat_node_inputs = [unsqueeze_0_node.output[0], "num_heads", unsqueeze_1_node.output[0], "head_size"]
+ elif reshape_name == "Reshape_v_2":
+ concat_node_inputs = [unsqueeze_0_node.output[0], unsqueeze_1_node.output[0], "head_size"]
+ elif reshape_name == "Reshape_v_1":
+ concat_node_inputs = [unsqueeze_0_node.output[0], unsqueeze_1_node.output[0], "num_heads", "head_size"]
+ elif reshape_name == "Reshape_k":
+ concat_node_inputs = [unsqueeze_0_node.output[0], "head_size", unsqueeze_1_node.output[0]]
+ elif reshape_name == "Reshape_q":  # NOTE(review): no else branch - any other name leaves concat_node_inputs unset or stale
+ concat_node_inputs = [unsqueeze_0_node.output[0], unsqueeze_1_node.output[0], "head_size"]
+
+ concat_node = helper.make_node(
+ "Concat",
+ inputs=concat_node_inputs,
+ outputs=[reshape_nodes[i].input[1]],  # writes the tensor name the Reshape already reads as its second input
+ name=f"Concat_extra_{i}",
+ axis=0,
+ )
+ extra_nodes.extend([unsqueeze_0_node, unsqueeze_1_node, concat_node])
+
+ return extra_nodes
+
+ def create_end_nodes(self):  # Build the output projection MatMul and the final Add that writes "output_0".
+ matmul_o_node = helper.make_node(
+ "MatMul",
+ inputs=["attn_output", "o_weight"],
+ outputs=["output_proj"],
+ name="MatMul_o_proj",
+ )
+ end_node = helper.make_node(
+ "Add",
+ inputs=["zero", "output_proj"],  # adds the "zero" initializer to the projection
+ outputs=["output_0"],
+ name="Add_normalize_node",
+ )
+ return [matmul_o_node, end_node]
+
+ def create_fused_model(self, model_type: str, interleaved: bool, initializers: List[TensorProto]):  # Build the expected model: MatMuls, RoPE, mask path, one MultiHeadAttention, end nodes.
+ inputs, outputs = self.create_inputs_and_outputs(model_type)
+ matmul_nodes = self.create_matmul_nodes(True, model_type=model_type)
+ rope_nodes = self.create_rotary_embeddings(True, model_type, interleaved, inputs, initializers)
+ attn_mask_nodes = self.create_attn_mask_path(True, model_type)
+
+ mha_inputs = [  # empty strings mark inputs that are not supplied
+ rope_nodes[0].output[0], # q
+ rope_nodes[1].output[0], # k
+ matmul_nodes[-1].output[0], # v
+ "", # bias
+ "attn_mask_out" if model_type == "llama2_msft" else "", # attn_mask
+ "attn_mask_out" if model_type != "llama2_msft" else "", # add_qk
+ "past_key" if model_type != "no_past" else "", # past_key
+ "past_value" if model_type != "no_past" else "", # past_value
+ ]
+ mha_node = helper.make_node(
+ "MultiHeadAttention",
+ inputs=mha_inputs,
+ outputs=["attn_output", "present_key", "present_value"],
+ name="MultiHeadAttention_0",
+ num_heads=self.num_heads,
+ )
+
+ end_nodes = self.create_end_nodes()
+
+ graph = helper.make_graph(
+ nodes=matmul_nodes + rope_nodes + attn_mask_nodes + [mha_node] + end_nodes,
+ name="RotaryAttention_Graph",
+ inputs=inputs,
+ outputs=outputs,
+ initializer=initializers,
+ )
+ opset_import = helper.make_opsetid(domain="com.microsoft", version=1)  # only the com.microsoft opset is declared
+ model = helper.make_model(graph, opset_imports=[opset_import])
+ return model
+
+ def create_test_model(self, model_type: str, interleaved: bool, initializers: List[TensorProto]):  # Build the unfused model that is handed to the optimizer.
+ inputs, outputs = self.create_inputs_and_outputs(model_type)
+ matmul_nodes = self.create_matmul_nodes(False, model_type)
+ rope_nodes = self.create_rotary_embeddings(False, model_type, interleaved, inputs, initializers)
+
+ # Create main paths
+ q_nodes = self.create_q_path(model_type)
+ k_nodes = self.create_k_path(model_type)
+ attn_mask_nodes = self.create_attn_mask_path(False, model_type)
+ qk_nodes = self.create_qk_path(model_type)
+ v_nodes = self.create_v_path(model_type)
+ qkv_nodes = self.create_qkv_path(model_type)
+
+ reshape_nodes = list(filter(lambda node: node.op_type == "Reshape", q_nodes + k_nodes + v_nodes + qkv_nodes))  # every Reshape gets its shape input built below
+ extra_nodes = self.create_concat_unsqueeze_paths(model_type, reshape_nodes)
+
+ end_nodes = self.create_end_nodes()
+
+ first_set_of_nodes = matmul_nodes + rope_nodes + q_nodes + k_nodes + attn_mask_nodes
+ second_set_of_nodes = qk_nodes + v_nodes + qkv_nodes + extra_nodes + end_nodes
+ graph = helper.make_graph(
+ nodes=first_set_of_nodes + second_set_of_nodes,  # not in topological order; verify_fusion sorts the optimized model
+ name="RotaryAttention_Graph",
+ inputs=inputs,
+ outputs=outputs,
+ initializer=initializers,
+ )
+ opset_import = helper.make_opsetid(domain="ai.onnx", version=17)
+ model = helper.make_model(graph, opset_imports=[opset_import])
+ return model
+
+    def check_models(self, model_type: str, interleaved: bool):
+        """Save the expected and original models for `model_type`, then verify the fusion result."""
+        initializers = self.create_initializers()
+        expected_model_filename = "expected_model.onnx"
+        original_model_filename = "original_model.onnx"
+        try:
+            onnx.save(self.create_fused_model(model_type, interleaved, initializers), expected_model_filename)
+            onnx.save(self.create_test_model(model_type, interleaved, initializers), original_model_filename)
+            self.verify_fusion(expected_model_filename, original_model_filename)
+        finally:
+            # Remove the temporary files even when model creation or verification fails.
+            for filename in (expected_model_filename, original_model_filename):
+                if os.path.exists(filename):
+                    os.remove(filename)
+
+    def test_llama2_msft_model(self):
+        """Check fusion for the "llama2_msft" model type with interleaving enabled."""
+        case = {"model_type": "llama2_msft", "interleaved": True}
+        self.check_models(**case)
+
+    def test_hf_decoder_model(self):
+        """Check fusion for the "no_past" model type with interleaving disabled."""
+        case = {"model_type": "no_past", "interleaved": False}
+        self.check_models(**case)
+
+    def test_hf_decoder_with_past_model(self):
+        """Check fusion for the "past" model type with interleaving disabled."""
+        case = {"model_type": "past", "interleaved": False}
+        self.check_models(**case)
+
+    def test_hf_decoder_merged_model(self):
+        """Check fusion for the "merged" model type with interleaving disabled."""
+        case = {"model_type": "merged", "interleaved": False}
+        self.check_models(**case)
+
+
+if __name__ == "__main__":  # run the test cases when this file is executed as a script
+ unittest.main()
diff --git a/onnxruntime/test/python/transformers/test_simplified_layernorm_fusion.py b/onnxruntime/test/python/transformers/test_simplified_layernorm_fusion.py
new file mode 100644
index 0000000000000..e86bdda7baffb
--- /dev/null
+++ b/onnxruntime/test/python/transformers/test_simplified_layernorm_fusion.py
@@ -0,0 +1,243 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import os
+import unittest
+from typing import List
+
+import numpy as np
+import onnx
+from onnx import TensorProto, helper
+from parity_utilities import find_transformers_source
+
+if find_transformers_source():
+ from fusion_options import FusionOptions
+ from onnx_model import OnnxModel
+ from optimizer import optimize_model
+else:
+ from onnxruntime.transformers.fusion_options import FusionOptions
+ from onnxruntime.transformers.onnx_model import OnnxModel
+ from onnxruntime.transformers.optimizer import optimize_model
+
+
+def float_tensor(name: str, shape: List[int], random=False):
+    """Build a FLOAT tensor named `name`: all ones, or uniform samples from [0, 1) if `random`."""
+    num_values = int(np.prod(shape))
+    if random:
+        values = [np.random.uniform(0.0, 1.0) for _ in range(num_values)]
+    else:
+        values = [1.0] * num_values
+    return helper.make_tensor(name, TensorProto.FLOAT, shape, values)
+
+
+class TestSimplifiedLayerNormFusion(unittest.TestCase):
+    def setUp(self):
+        """Set the embedding-table size, tensor dimensions and epsilon used by every test."""
+        self.vocab_size = 5
+        self.batch_size, self.sequence_length = 2, 8
+        self.hidden_size = 16
+        self.epsilon = 0.000009999999747378752
+
+    def verify_fusion(self, expected_model_path, original_model_path):
+        """Optimize the original model and assert that its graph equals the expected fused graph."""
+        expected_model = OnnxModel(onnx.load(expected_model_path))
+        expected_model.topological_sort(is_deterministic=True)
+        optimized_model = optimize_model(original_model_path, optimization_options=FusionOptions("gpt2"))
+        optimized_model.topological_sort(is_deterministic=True)
+
+        # assertTrue(a, b) treats `b` as the failure message and passes for any non-empty `a`.
+        self.assertEqual(str(expected_model.model.graph), str(optimized_model.model.graph))
+
+    def create_initializers(self, use_embed_weight: bool = False):
+        """Return the graph constants, with `embed_weight` placed first when `use_embed_weight` is set."""
+        embed_tensors = []
+        if use_embed_weight:
+            embed_tensors.append(float_tensor("embed_weight", [self.vocab_size, self.hidden_size]))
+        return [
+            *embed_tensors,
+            helper.make_tensor("Two", TensorProto.FLOAT, [1], np.array([2], dtype=np.float32)),
+            helper.make_tensor("epsilon", TensorProto.FLOAT, [1], np.array([self.epsilon], dtype=np.float32)),
+            helper.make_tensor("One", TensorProto.FLOAT, [1], np.array([1], dtype=np.float32)),
+            float_tensor("scale", [self.hidden_size]),
+        ]
+
+ def create_inputs_and_outputs(self, start_node_type: str):  # Return (inputs, outputs, start_node); start_node stays None for "GraphInput".
+ inputs, start_node = None, None
+ if start_node_type == "Add":  # "D" is produced by adding two float graph inputs
+ start_node = helper.make_node(
+ "Add",
+ inputs=["input_0", "input_1"],
+ outputs=["D"],
+ name="Add_0",
+ )
+ input_0 = helper.make_tensor_value_info(
+ "input_0",
+ TensorProto.FLOAT,
+ [self.batch_size, self.sequence_length, self.hidden_size],
+ )
+ input_1 = helper.make_tensor_value_info(
+ "input_1",
+ TensorProto.FLOAT,
+ [self.batch_size, self.sequence_length, self.hidden_size],
+ )
+ inputs = [input_0, input_1]
+ elif start_node_type == "Gather":  # "D" is gathered from "embed_weight" using an INT64 graph input
+ start_node = helper.make_node(
+ "Gather",
+ inputs=["embed_weight", "input_0"],
+ outputs=["D"],
+ name="Gather_0",
+ )
+ input_0 = helper.make_tensor_value_info(
+ "input_0",
+ TensorProto.INT64,
+ [self.batch_size, self.sequence_length],
+ )
+ inputs = [input_0]
+ else:
+ # start_node_type is a graph input
+ assert start_node_type == "GraphInput"
+ input_0 = helper.make_tensor_value_info(
+ "D",  # the graph input itself is named "D", so no start node is needed
+ TensorProto.FLOAT,
+ [self.batch_size, self.sequence_length, self.hidden_size],
+ )
+ inputs = [input_0]
+
+ outputs = [
+ helper.make_tensor_value_info(
+ "output_0",
+ TensorProto.FLOAT,
+ [self.batch_size, self.sequence_length, self.hidden_size],
+ )
+ ]
+ return inputs, outputs, start_node
+
+    def create_fused_model(self, start_node_type: str, initializers: List[TensorProto]):
+        """Build the expected model: the optional start node feeding one SimplifiedLayerNormalization."""
+        inputs, outputs, start_node = self.create_inputs_and_outputs(start_node_type)
+        # Look tensors up by name: positions shift when `embed_weight` is prepended, so the
+        # positional indices used before picked "Two"/"embed_weight" as scale and "One" as epsilon.
+        tensors_by_name = {tensor.name: tensor for tensor in initializers}
+        sln_node = helper.make_node(
+            "SimplifiedLayerNormalization",
+            inputs=[start_node.output[0] if start_node is not None else "D", tensors_by_name["scale"].name],
+            outputs=[outputs[0].name],
+            axis=-1,
+            epsilon=tensors_by_name["epsilon"].float_data[0],
+            stash_type=1,
+        )
+        graph = helper.make_graph(
+            nodes=[sln_node] + ([] if start_node is None else [start_node]),
+            name="SimplifiedLayerNorm_Graph",
+            inputs=inputs,
+            outputs=outputs,
+            initializer=initializers,
+        )
+        return helper.make_model(graph, opset_imports=[helper.make_opsetid(domain="com.microsoft", version=1)])
+
+ # Notation follows https://onnx.ai/onnx/operators/onnx__LayerNormalization.html#summary
+ def create_test_model(self, start_node_type: str, first_parent_idx: int, initializers: List[TensorProto]):  # Build the unfused graph: Pow -> ReduceMean -> Add -> Sqrt -> Div -> Mul -> Mul.
+ end_node = helper.make_node(
+ "Mul",
+ inputs=["scale", "Normalized"] if first_parent_idx == 1 else ["Normalized", "scale"],  # first_parent_idx == 1 puts "scale" first
+ outputs=["output_0"],
+ name="Mul_1",
+ )
+ mul_node = helper.make_node(
+ "Mul",
+ inputs=["D", "InvStdDev"],
+ outputs=["Normalized"],
+ name="Mul_0",
+ )
+ div_node = helper.make_node(
+ "Div",
+ inputs=["One", "StdDev"],
+ outputs=["InvStdDev"],
+ name="Div_0",
+ )
+ sqrt_node = helper.make_node(
+ "Sqrt",
+ inputs=["VarEps"],
+ outputs=["StdDev"],
+ name="Sqrt_0",
+ )
+ add_node = helper.make_node(
+ "Add",
+ inputs=["Var", "epsilon"],
+ outputs=["VarEps"],
+ name="Add_1",
+ )
+ reducemean_node = helper.make_node(
+ "ReduceMean",
+ inputs=["DD"],
+ outputs=["Var"],
+ name="ReduceMean_0",  # no axes or keepdims attribute is set on this node
+ )
+ pow_node = helper.make_node(
+ "Pow",
+ inputs=["D", "Two"],
+ outputs=["DD"],
+ name="Pow_0",
+ )
+
+ inputs, outputs, start_node = self.create_inputs_and_outputs(start_node_type)
+
+ main_nodes = [pow_node, reducemean_node, add_node, sqrt_node, div_node, mul_node, end_node]
+ graph = helper.make_graph(
+ nodes=main_nodes + ([] if start_node is None else [start_node]),  # the start node, if any, is listed last
+ name="SimplifiedLayerNorm_Graph",
+ inputs=inputs,
+ outputs=outputs,
+ initializer=initializers,
+ )
+ opset_import = helper.make_opsetid(domain="com.microsoft", version=1)  # NOTE(review): only com.microsoft is imported although the nodes use the default domain - confirm
+ model = helper.make_model(graph, opset_imports=[opset_import])
+ return model
+
+    def check_models(self, start_node_type: str, first_parent_idx: int, initializers: List[TensorProto]):
+        """Save the expected and original models, verify the fusion result, and remove both files."""
+        expected_model_filename, original_model_filename = "expected_model.onnx", "original_model.onnx"
+        try:
+            onnx.save(self.create_fused_model(start_node_type, initializers), expected_model_filename)
+            original_model = self.create_test_model(start_node_type, first_parent_idx, initializers)
+            onnx.save(original_model, original_model_filename)
+            self.verify_fusion(expected_model_filename, original_model_filename)
+        finally:  # clean up even when model creation or verification fails
+            for filename in (expected_model_filename, original_model_filename):
+                if os.path.exists(filename):
+                    os.remove(filename)
+
+ # sim_ln_nodes_1
+    def test_simplified_layernorm_add_idx1(self):
+        """Add start node; the final Mul receives `scale` as its first input."""
+        self.check_models(
+            start_node_type="Add", first_parent_idx=1, initializers=self.create_initializers()
+        )
+
+ # sim_ln_nodes_2
+    def test_simplified_layernorm_gather_idx1(self):
+        """Gather start node with `embed_weight`; the final Mul receives `scale` as its first input."""
+        self.check_models(
+            start_node_type="Gather", first_parent_idx=1, initializers=self.create_initializers(use_embed_weight=True)
+        )
+
+ # sim_ln_nodes_3
+    def test_simplified_layernorm_add_idx0(self):
+        """Add start node; the final Mul receives `Normalized` as its first input."""
+        self.check_models(
+            start_node_type="Add", first_parent_idx=0, initializers=self.create_initializers()
+        )
+
+ # sim_ln_nodes_4
+    def test_simplified_layernorm_gather_graph_input(self):
+        """No start node: `D` is a graph input and the final Mul receives `Normalized` first."""
+        self.check_models(
+            start_node_type="GraphInput", first_parent_idx=0, initializers=self.create_initializers()
+        )
+
+
+if __name__ == "__main__":  # run the test cases when this file is executed as a script
+ unittest.main()
diff --git a/onnxruntime/test/python/transformers/test_whisper.py b/onnxruntime/test/python/transformers/test_whisper.py
index ebda0bccaadcf..ceda5a88c3925 100644
--- a/onnxruntime/test/python/transformers/test_whisper.py
+++ b/onnxruntime/test/python/transformers/test_whisper.py
@@ -50,7 +50,7 @@ def verify_fusion(self, optimized_model, expected_model_filename):
)
)
- # Attention type #1 in onnx_model_bart.py
+ # Attention type #1 in fusion_bart_attention.py
def test_encoder_attention_fusion_with_skiplayernorm(self):
num_heads = 4
hidden_size = 64
@@ -67,7 +67,7 @@ def test_encoder_attention_fusion_with_skiplayernorm(self):
os.remove(model_path)
self.verify_fusion(optimized_model, "encoder_attention_with_sln_fused.onnx")
- # Attention type #2 in onnx_model_bart.py
+ # Attention type #2 in fusion_bart_attention.py
def test_decoder_attention_fusion_with_skiplayernorm(self):
num_heads = 4
hidden_size = 64
@@ -84,7 +84,7 @@ def test_decoder_attention_fusion_with_skiplayernorm(self):
os.remove(model_path)
self.verify_fusion(optimized_model, "decoder_attention_with_sln_fused.onnx")
- # Attention type #4 in onnx_model_bart.py
+ # Attention type #4 in fusion_bart_attention.py
def test_decoder_multihead_attention_fusion(self):
num_heads = 4
hidden_size = 64
@@ -100,7 +100,7 @@ def test_decoder_multihead_attention_fusion(self):
os.remove(model_path)
self.verify_fusion(optimized_model, "decoder_mha_fused.onnx")
- # Attention type #3 in onnx_model_bart.py
+ # Attention type #3 in fusion_bart_attention.py
def test_decoder_with_past_multihead_self_attention_fusion_with_skiplayernorm(self):
num_heads = 4
hidden_size = 64
@@ -118,7 +118,7 @@ def test_decoder_with_past_multihead_self_attention_fusion_with_skiplayernorm(se
os.remove(model_path)
self.verify_fusion(optimized_model, "decoder_with_past_self_mha_fused.onnx")
- # Attention type #5 in onnx_model_bart.py
+ # Attention type #5 in fusion_bart_attention.py
def test_decoder_with_past_multihead_cross_attention_fusion(self):
num_heads = 4
hidden_size = 64
@@ -134,7 +134,7 @@ def test_decoder_with_past_multihead_cross_attention_fusion(self):
os.remove(model_path)
self.verify_fusion(optimized_model, "decoder_with_past_cross_mha_fused.onnx")
- # Attention type #4 in onnx_model_bart.py
+ # Attention type #4 in fusion_bart_attention.py
def test_decoder_multihead_attention_split_bias_fusion(self):
num_heads = 4
hidden_size = 64
@@ -151,7 +151,7 @@ def test_decoder_multihead_attention_split_bias_fusion(self):
os.remove(model_path)
self.verify_fusion(optimized_model, "decoder_mha_split_bias_fused.onnx")
- # Attention type #3 in onnx_model_bart.py
+ # Attention type #3 in fusion_bart_attention.py
def test_decoder_with_past_multihead_self_attention_split_bias_fusion_with_skiplayernorm(self):
num_heads = 4
hidden_size = 64
@@ -171,7 +171,7 @@ def test_decoder_with_past_multihead_self_attention_split_bias_fusion_with_skipl
os.remove(model_path)
self.verify_fusion(optimized_model, "decoder_with_past_self_mha_split_bias_fused.onnx")
- # Attention type #5 in onnx_model_bart.py
+ # Attention type #5 in fusion_bart_attention.py
def test_decoder_with_past_multihead_cross_attention_split_bias_fusion(self):
num_heads = 4
hidden_size = 64