Commit 07bfbbb

kunal-vaishnavi authored and kleiti committed
LLaMA Model Optimization (microsoft#18021)
### Description

This PR contains fusion-level and kernel-level optimizations for [Meta's LLaMA-2](https://blogs.microsoft.com/blog/2023/07/18/microsoft-and-meta-expand-their-ai-partnership-with-llama-2-on-azure-and-windows/).

Some of the added optimizations include:

- SimplifiedLayerNorm changes
  - Fusions for multiple variants
- SkipSimplifiedLayerNorm changes
  - Kernel support for CPU
- Rotary embeddings (previously did not exist)
  - Fusions for multiple variants
  - CPU and CUDA kernels
  - Supports interleaving and non-interleaving in the same kernels
  - Optimized cache that requires half of its originally exported size (see the sketch after this section)
    - Reduced from `(max_sequence_length, head_size)` to `(max_sequence_length, head_size / 2)`
- Multi-head attention
  - Support for 2D and 3D attention masks
- Group query attention (for FP16 CUDA and INT4 CUDA)
  - Integration with flash attention v2 and past-present buffer sharing
  - Removes the need for the `attention_mask` input, as it is supported in the kernel
- 4-bit quantization
  - `block_size` parameter is available for customization
  - Supports the new changes in the [Microsoft version](https://github.com/microsoft/Llama-2-Onnx)
  - Supports combinations of the below variants (e.g. export the ORT version and run it with Optimum)

Supported variants of LLaMA-2 include:

- [ORT version](https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/llama)
  - Produces one ONNX file that is already optimized (and quantized if requested)
  - Integrates with Optimum (see the runnable sketch at the end of this description)
- [Another Microsoft version](https://github.com/microsoft/Llama-2-Onnx)
  - Already exported and available off-the-shelf
  - Faster versions of those models will be uploaded there soon
- [Hugging Face version](https://huggingface.co/meta-llama)
  - Models that end with `-hf`
  - Some older and current versions of [`transformers`](https://github.com/huggingface/transformers) and [`optimum`](https://github.com/huggingface/optimum) export the model to ONNX differently
    - Note that while some older versions are supported, it is recommended to use the latest package versions.

### Usage

To use the optimizations, please see `README.md` for details. Please note the various `requirements.txt` files for the package versions recommended in order to use these changes.
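As a rough illustration of the halved rotary cache mentioned above, the cos/sin tables can be precomputed once at the reduced `(max_sequence_length, head_size / 2)` shape. This is a minimal sketch using the standard RoPE frequencies; the helper name and `base` value are illustrative, not part of this PR:

```python
import numpy as np

def build_rope_caches(max_sequence_length: int, head_size: int, base: float = 10000.0):
    """Sketch: cos/sin caches at the reduced (max_sequence_length, head_size // 2) shape."""
    # One frequency per pair of rotated elements, so only head_size // 2 columns are needed.
    inv_freq = 1.0 / (base ** (np.arange(0, head_size, 2) / head_size))  # (head_size // 2,)
    angles = np.outer(np.arange(max_sequence_length), inv_freq)          # (max_seq, head_size // 2)
    return np.cos(angles).astype(np.float32), np.sin(angles).astype(np.float32)
```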
To run the ORT transformer optimizer separately, run the script as follows:

```
$ cd onnxruntime/onnxruntime/python/tools/transformers/
$ python3 optimizer.py --input <filename>.onnx --output <filename>.onnx --model_type gpt2 \
    --num_heads <number of attention heads> --hidden_size <attention hidden size> \
    --use_external_data_format --opt_level 0
```

### Motivation and Context

This PR helps with the following issues:

- microsoft#14997
- microsoft#16254
- microsoft#17681
- microsoft#17925
- microsoft/onnxruntime-inference-examples#320

This PR uses changes from the following PRs:

- pytorch/pytorch#104468
- pytorch/pytorch#109759
- microsoft#17020
- microsoft#17674
- microsoft#17890
- microsoft#17920
- huggingface/transformers#26162
- huggingface/optimum#1257
- huggingface/optimum#1289
- huggingface/optimum#1462

### New TorchDynamo Exporter (experimental stage)

This PR uses changes from the following issues and PRs to begin supporting the [new TorchDynamo exporter](https://pytorch.org/docs/stable/onnx.html#torchdynamo-based-onnx-exporter):

- huggingface/transformers#26307
- pytorch/pytorch#104903
- pytorch/pytorch#105040
- microsoft/onnxscript#847
- microsoft/onnxscript#862
- microsoft/onnxscript#493
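For the "export the ORT version and run it with Optimum" combination mentioned in the description, usage could look roughly like the following. This is a hedged sketch: the local model directory and generation arguments are illustrative, and exact flags may differ across `optimum` versions.

```python
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

# Illustrative paths; the ONNX directory is whatever the exporter produced.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = ORTModelForCausalLM.from_pretrained("./llama2-7b-onnx")

inputs = tokenizer("ONNX Runtime is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```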
1 parent e68bb7b commit 07bfbbb

49 files changed: 5,897 additions and 563 deletions

docs/ContribOperators.md (50 additions, 1 deletion)

```diff
@@ -90,6 +90,7 @@ Do not modify directly.*
 * <a href="#com.microsoft.RemovePadding">com.microsoft.RemovePadding</a>
 * <a href="#com.microsoft.RestorePadding">com.microsoft.RestorePadding</a>
 * <a href="#com.microsoft.Rfft">com.microsoft.Rfft</a>
+* <a href="#com.microsoft.RotaryEmbedding">com.microsoft.RotaryEmbedding</a>
 * <a href="#com.microsoft.SampleOp">com.microsoft.SampleOp</a>
 * <a href="#com.microsoft.Sampling">com.microsoft.Sampling</a>
 * <a href="#com.microsoft.SkipLayerNormalization">com.microsoft.SkipLayerNormalization</a>
@@ -2834,7 +2835,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dt><tt>bias</tt> (optional) : T</dt>
 <dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection</dd>
 <dt><tt>key_padding_mask</tt> (optional) : M</dt>
-<dd>Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)</dd>
+<dd>Key padding mask with shape (batch_size), (3 * batch_size + 2), (batch_size, kv_sequence_length), (batch_size, total_sequence_length), or (batch_size, sequence_length, total_sequence_length)</dd>
 <dt><tt>relative_position_bias</tt> (optional) : T</dt>
 <dd>relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)</dd>
 <dt><tt>past_key</tt> (optional) : T</dt>
@@ -4796,6 +4797,54 @@ This version of the operator has been available since version 1 of the 'com.micr
 </dl>
 
 
+### <a name="com.microsoft.RotaryEmbedding"></a><a name="com.microsoft.rotaryembedding">**com.microsoft.RotaryEmbedding**</a>
+
+  RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices
+  that are multiplied to query and key before the inner product of query and key is taken.
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+<dl>
+<dt><tt>interleaved</tt> : int</dt>
+<dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
+<dt><tt>scale</tt> : float</dt>
+<dd>Custom scale will be used if specified. Default value is 1.0</dd>
+</dl>
+
+#### Inputs
+
+<dl>
+<dt><tt>input</tt> : T</dt>
+<dd>3D tensor with shape (batch_size, sequence_length, hidden_size)</dd>
+<dt><tt>position_ids</tt> : M</dt>
+<dd>1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)</dd>
+<dt><tt>cos_cache</tt> : T</dt>
+<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
+<dt><tt>sin_cache</tt> : T</dt>
+<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
+</dl>
+
+#### Outputs
+
+<dl>
+<dt><tt>output</tt> : T</dt>
+<dd>3D tensor with shape (batch_size, sequence_length, hidden_size)</dd>
+</dl>
+
+#### Type Constraints
+
+<dl>
+<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
+<dd>Constrain input and output types to float tensors.</dd>
+<dt><tt>M</tt> : tensor(int64)</dt>
+<dd>Constrain input and output types to integer tensors</dd>
+</dl>
+
+
 ### <a name="com.microsoft.SampleOp"></a><a name="com.microsoft.sampleop">**com.microsoft.SampleOp**</a>
 
   Sample echo operator.
```
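For readers cross-checking the RotaryEmbedding spec above, here is a small NumPy sketch of the default (non-interleaved) semantics. It is a reference illustration under the spec's shapes, not the registered kernel; the function name is illustrative.

```python
import numpy as np

def rotary_embedding_ref(x, position_ids, cos_cache, sin_cache, num_heads):
    """x: (batch, seq, hidden); caches: (max_seq, head_size // 2); interleaved = 0."""
    batch, seq, hidden = x.shape
    head = hidden // num_heads
    half = head // 2
    x = x.reshape(batch, seq, num_heads, head)
    if position_ids.size == 1:  # shape (1): a start offset; position = offset + step
        pos = np.broadcast_to(int(position_ids.reshape(-1)[0]) + np.arange(seq), (batch, seq))
    else:                       # shape (batch_size, sequence_length)
        pos = position_ids
    cos = cos_cache[pos][:, :, None, :]  # (batch, seq, 1, head // 2)
    sin = sin_cache[pos][:, :, None, :]
    # Non-interleaved: rotate the first half of each head against the second half.
    x1, x2 = x[..., :half], x[..., half:]
    out = np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    return out.reshape(batch, seq, hidden)
```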

docs/OperatorKernels.md (3 additions, 0 deletions)

```diff
@@ -477,9 +477,11 @@ Do not modify directly.*
 |QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)|
 |QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
+|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float)|
 |SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
 |SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
+|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
 |SparseToDenseMatMul|*in* A:**T**<br> *in* B:**T1**<br> *out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
 |Tokenizer|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(string)|
 |TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
@@ -866,6 +868,7 @@ Do not modify directly.*
 |RemovePadding|*in* input:**T**<br> *in* sequence_token_count:**M**<br> *out* output:**T**<br> *out* token_offset:**M**<br> *out* cumulated_seq_len:**M**<br> *out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
 |RestorePadding|*in* input:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
 |Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
 |Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
 |SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
 |SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
```
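As a quick way to exercise the newly listed CPU kernel, one can build a one-node `com.microsoft` graph and run it through onnxruntime. This is a sketch: shapes and values are arbitrary but follow the spec above, and the opset versions are illustrative.

```python
import numpy as np
from onnx import TensorProto, helper
import onnxruntime as ort

batch, seq, num_heads, head, max_seq = 1, 4, 2, 8, 16
hidden, half = num_heads * head, head // 2

node = helper.make_node(
    "RotaryEmbedding",
    ["input", "position_ids", "cos_cache", "sin_cache"], ["output"],
    domain="com.microsoft", interleaved=0)
graph = helper.make_graph(
    [node], "rope_smoke_test",
    [helper.make_tensor_value_info("input", TensorProto.FLOAT, [batch, seq, hidden]),
     helper.make_tensor_value_info("position_ids", TensorProto.INT64, [batch, seq]),
     helper.make_tensor_value_info("cos_cache", TensorProto.FLOAT, [max_seq, half]),
     helper.make_tensor_value_info("sin_cache", TensorProto.FLOAT, [max_seq, half])],
    [helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch, seq, hidden])])
model = helper.make_model(graph, opset_imports=[
    helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)])

# Standard RoPE tables at the halved (max_seq, head // 2) cache shape.
inv_freq = 1.0 / (10000.0 ** (np.arange(0, head, 2) / head))
angles = np.outer(np.arange(max_seq), inv_freq).astype(np.float32)

sess = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
out = sess.run(None, {
    "input": np.random.rand(batch, seq, hidden).astype(np.float32),
    "position_ids": np.tile(np.arange(seq, dtype=np.int64), (batch, 1)),
    "cos_cache": np.cos(angles), "sin_cache": np.sin(angles)})[0]
print(out.shape)  # (1, 4, 16)
```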

onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc (0 additions, 1 deletion)

```diff
@@ -16,7 +16,6 @@
 
 #include <unsupported/Eigen/SpecialFunctions>
 #include <vector>
-#include <iostream>
 
 using onnxruntime::concurrency::ThreadPool;
 
```

onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h (11 additions, 3 deletions)

```diff
@@ -206,6 +206,7 @@ Status CheckInputs(const T* query,
     }
   }
 
+  int total_sequence_length = past_sequence_length + kv_sequence_length;
   AttentionMaskType mask_type = AttentionMaskType::MASK_NONE;
   if (key_padding_mask != nullptr) {
     mask_type = AttentionMaskType::MASK_UNKNOWN;
@@ -216,13 +217,21 @@ Status CheckInputs(const T* query,
       } else if (mask_dims[0] == static_cast<int64_t>(3) * static_cast<int64_t>(batch_size) + static_cast<int64_t>(2)) {
         mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START;
       }
-    } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast<int64_t>(batch_size) && mask_dims[1] == static_cast<int64_t>(kv_sequence_length)) {
+    } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast<int64_t>(batch_size) &&
+               mask_dims[1] == static_cast<int64_t>(kv_sequence_length)) {
+      mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
+    } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast<int64_t>(batch_size) &&
+               mask_dims[1] == static_cast<int64_t>(total_sequence_length)) {
       mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
+    } else if (mask_dims.size() == 3 && mask_dims[0] == static_cast<int64_t>(batch_size) &&
+               mask_dims[1] == static_cast<int64_t>(sequence_length) &&
+               mask_dims[2] == static_cast<int64_t>(total_sequence_length)) {
+      mask_type = AttentionMaskType::MASK_3D_ATTENTION;
     }
 
     if (mask_type == AttentionMaskType::MASK_UNKNOWN) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                             "Input 'key_padding_mask' shape shall be (batch_size) or (batch_size, kv_sequence_length)");
+                             "Input 'key_padding_mask' shape shall be 1D, 2D, or 3D");
     }
   }
 
@@ -257,7 +266,6 @@ Status CheckInputs(const T* query,
     }
   }
 
-  int total_sequence_length = past_sequence_length + kv_sequence_length;
   bool broadcast_res_pos_bias = false;
   if (relative_position_bias != nullptr) {
     const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims();
```
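As a small illustration of the newly accepted 2D variant, a `(batch_size, total_sequence_length)` key padding mask covers both past and current keys in one tensor. This is a sketch; the 1 = attend / 0 = padding convention and the dtype are assumptions for illustration, not taken from this diff.

```python
import numpy as np

batch_size, past_sequence_length, kv_sequence_length = 2, 6, 2
total_sequence_length = past_sequence_length + kv_sequence_length  # as computed in CheckInputs

# One row per batch entry, one column per past + current key position.
key_padding_mask = np.ones((batch_size, total_sequence_length), dtype=np.int32)
key_padding_mask[1, :3] = 0  # e.g. the first 3 cached positions of batch entry 1 are padding
```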
onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc (new file, 115 additions)

```cpp
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/cpu/bert/rotary_embedding.h"
#include "contrib_ops/cpu/bert/rotary_embedding_helper.h"

#include "core/platform/threadpool.h"

using onnxruntime::concurrency::ThreadPool;
using namespace onnxruntime::contrib::rotary_embedding_helper;

namespace onnxruntime {
namespace contrib {

// These ops are internal-only, so register outside of onnx
ONNX_OPERATOR_TYPED_KERNEL_EX(
    RotaryEmbedding,
    kMSDomain,
    1,
    float,
    kCpuExecutionProvider,
    KernelDefBuilder()
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
        .TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()),
    RotaryEmbedding<float>);

template <typename T>
RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) {
  scale = info.GetAttrOrDefault<float>("scale", 1.0);
  interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
}

template <typename T>
Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
  const Tensor* input = context->Input<Tensor>(0);
  const Tensor* position_ids = context->Input<Tensor>(1);
  const Tensor* cos_cache = context->Input<Tensor>(2);
  const Tensor* sin_cache = context->Input<Tensor>(3);

  RotaryParameters parameters = {};
  ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs<Tensor>(input,
                                                                   position_ids,
                                                                   cos_cache,
                                                                   sin_cache,
                                                                   &parameters));

  Tensor* output = context->Output(0, input->Shape());

  if (parameters.sequence_length > parameters.max_sequence_length) {
    // Launch update_cos_sin_cache kernel with scale
    ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported");
  }

  const T* input_src = input->Data<T>();
  const int64_t* pos_ids_data = position_ids->Data<int64_t>();
  const T* cos_cache_data = cos_cache->Data<T>();
  const T* sin_cache_data = sin_cache->Data<T>();
  T* output_dest = output->MutableData<T>();

  const int batch_size = parameters.batch_size;
  const int sequence_length = parameters.sequence_length;
  const int num_heads = parameters.num_heads;
  const int head_size = parameters.head_size;
  const int position_ids_format = parameters.position_ids_format;
  const int half_head_size = head_size / 2;

  AllocatorPtr allocator;
  ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
  auto* tp = context->GetOperatorThreadPool();

  const int loop_len = batch_size * sequence_length * num_heads;
  const double cost = static_cast<double>(head_size);
  ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
    for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
      const int b = static_cast<int>((ptr / num_heads) / sequence_length);
      const int s = static_cast<int>((ptr / num_heads) % sequence_length);
      const int n = static_cast<int>(ptr % num_heads);

      const int block_offset = b * sequence_length * num_heads + s * num_heads + n;
      const int data_offset = block_offset * head_size;

      const T* input_data = input_src + data_offset;
      T* output_data = output_dest + data_offset;

      // Cache is (M, H/2)
      const int position_id = (position_ids_format == 0)
                                  ? static_cast<int>(pos_ids_data[0]) + s
                                  : static_cast<int>(pos_ids_data[b * sequence_length + s]);
      const int cache_offset = position_id * half_head_size;
      const T* cos_data = cos_cache_data + cache_offset;
      const T* sin_data = sin_cache_data + cache_offset;

      int cache_idx = 0;
      T sign = 0;
      int j = 0;
      for (int i = 0; i < head_size; i++) {
        if (interleaved) {
          cache_idx = (i / 2) % half_head_size;
          sign = (i % 2 == 0) ? static_cast<T>(-1) : static_cast<T>(1);
          j = (i % 2 == 0) ? i + 1 : i - 1;  // i - sign
        } else {
          cache_idx = i % half_head_size;
          sign = (i < half_head_size) ? static_cast<T>(-1) : static_cast<T>(1);
          j = (i + half_head_size) % head_size;
        }
        output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
      }
    }
  });

  return Status::OK();
}

}  // namespace contrib
}  // namespace onnxruntime
```
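The `interleaved` branch in the loop above pairs adjacent elements `(2k, 2k+1)` instead of splitting the head in half. An equivalent NumPy view of one head vector, for illustration only (the function name is not from this commit):

```python
import numpy as np

def rotate_interleaved(head_vec, cos_row, sin_row):
    """Mirror of the interleaved branch: head_vec (head_size,), cos_row/sin_row (head_size // 2,)."""
    out = np.empty_like(head_vec)
    x_even, x_odd = head_vec[0::2], head_vec[1::2]
    out[0::2] = x_even * cos_row - x_odd * sin_row  # i even: sign = -1, j = i + 1
    out[1::2] = x_odd * cos_row + x_even * sin_row  # i odd:  sign = +1, j = i - 1
    return out
```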
onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h (new file, 23 additions)

```cpp
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"

namespace onnxruntime {
namespace contrib {

template <typename T>
class RotaryEmbedding final : public OpKernel {
 public:
  RotaryEmbedding(const OpKernelInfo& info);
  Status Compute(OpKernelContext* context) const override;

 protected:
  float scale;
  bool interleaved;
};

}  // namespace contrib
}  // namespace onnxruntime
```
