-
Notifications
You must be signed in to change notification settings - Fork 3.9k
LLaMA Model Optimization #18021
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
kunal-vaishnavi
merged 39 commits into
microsoft:main
from
kunal-vaishnavi:kvaishnavi/llama
Oct 23, 2023
Merged
LLaMA Model Optimization #18021
Changes from 30 commits
Commits
Show all changes
39 commits
Select commit
Hold shift + click to select a range
e74b899
Initial fusions and kernel changes for LLaMA
kunal-vaishnavi 228de8c
Add rotary embeddings for LLaMA
kunal-vaishnavi dc16e16
Change input shapes and types for fused model
kunal-vaishnavi 816f7e9
Add present kv to multi-head attention
kunal-vaishnavi 5ce8e5a
Merge branch 'main' into kvaishnavi/llama
kunal-vaishnavi 6669899
Update benchmark scripts
kunal-vaishnavi ed61ae4
Update inputs for optimized model
kunal-vaishnavi cdbd466
Merge branch 'main' into kvaishnavi/llama
kunal-vaishnavi becbd30
Add interleaved and non-interleaved rotary embeddings
kunal-vaishnavi eece5e8
Update rotary embeddings and export scripts
kunal-vaishnavi 55d0554
Fix attention mask for HF version
kunal-vaishnavi 37e6b5f
Modify rotary embeddings fusion for merged HF model
kunal-vaishnavi 909f8e7
Add optimization passes after conversion
kunal-vaishnavi 43f459b
Fix adding GQA to optimized model
kunal-vaishnavi 4e2bf41
Add CPU implementation for rotary embeddings
kunal-vaishnavi 2210c47
Add test cases
kunal-vaishnavi 6f154e3
Clean up test cases
kunal-vaishnavi 822c2e6
Fix initializer data in test case
kunal-vaishnavi cdf5536
Add merged export
kunal-vaishnavi 52f5994
Remove logger warning
kunal-vaishnavi 0d17656
Update docs
kunal-vaishnavi bcb5a32
Enable buffer sharing and int4 quantization
kunal-vaishnavi 8ae9188
Fix inputs for buffer sharing
kunal-vaishnavi 143d805
Remove extra print
kunal-vaishnavi f2b4644
Clean up code
kunal-vaishnavi d7bb72c
Merge branch 'main' into kvaishnavi/llama
kunal-vaishnavi 8968bb3
Address PR feedback
kunal-vaishnavi 84f7cc0
Add changes suggested by linters
kunal-vaishnavi 99ec341
Fix min CUDA architecture
kunal-vaishnavi b76e2c2
Add graph input for GQA
kunal-vaishnavi edafef5
Fix GQA parity issue
kunal-vaishnavi 7b82912
Add changes suggested by linter
kunal-vaishnavi a891398
Remove unreferenced parameter
kunal-vaishnavi 716b725
Change rotary embedding test threshold
kunal-vaishnavi 6b8698d
Add int4 CPU support
kunal-vaishnavi cc0199b
Add changes suggested by linters
kunal-vaishnavi e38ecb3
Merge branch 'main' into kvaishnavi/llama
kunal-vaishnavi e69c23b
Fix linter issue
kunal-vaishnavi d14d5bd
Fix CodeQL error
kunal-vaishnavi File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,113 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #include "rotary_embedding.h" | ||
| #include "rotary_embedding_helper.h" | ||
|
|
||
| #include "core/platform/threadpool.h" | ||
|
|
||
| using onnxruntime::concurrency::ThreadPool; | ||
| using namespace onnxruntime::contrib::rotary_embedding_helper; | ||
|
|
||
| namespace onnxruntime { | ||
| namespace contrib { | ||
|
|
||
| // These ops are internal-only, so register outside of onnx | ||
| ONNX_OPERATOR_TYPED_KERNEL_EX( | ||
| RotaryEmbedding, | ||
| kMSDomain, | ||
| 1, | ||
| float, | ||
| kCpuExecutionProvider, | ||
| KernelDefBuilder() | ||
| .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()) | ||
| .TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()), | ||
| RotaryEmbedding<float>); | ||
|
|
||
| template <typename T> | ||
| RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) { | ||
| scale = info.GetAttrOrDefault<float>("scale", 1.0); | ||
| interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1); | ||
| } | ||
|
|
||
| template <typename T> | ||
| Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const { | ||
| const Tensor* input = context->Input<Tensor>(0); | ||
| const Tensor* position_ids = context->Input<Tensor>(1); | ||
| const Tensor* cos_cache = context->Input<Tensor>(2); | ||
| const Tensor* sin_cache = context->Input<Tensor>(3); | ||
|
|
||
| RotaryParameters parameters = {}; | ||
| ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs<Tensor>(input, | ||
| position_ids, | ||
| cos_cache, | ||
| sin_cache, | ||
| ¶meters)); | ||
|
|
||
| Tensor* output = context->Output(0, input->Shape()); | ||
|
|
||
| if (parameters.sequence_length > parameters.max_sequence_length) { | ||
| // Launch update_cos_sin_cache kernel with scale | ||
| ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported"); | ||
| } | ||
|
|
||
| const T* input_src = input->Data<T>(); | ||
| const int64_t* pos_ids_data = position_ids->Data<int64_t>(); | ||
| const T* cos_cache_data = cos_cache->Data<T>(); | ||
| const T* sin_cache_data = sin_cache->Data<T>(); | ||
| T* output_dest = output->MutableData<T>(); | ||
|
|
||
| const int batch_size = parameters.batch_size; | ||
| const int sequence_length = parameters.sequence_length; | ||
| const int num_heads = parameters.num_heads; | ||
| const int head_size = parameters.head_size; | ||
| const int position_ids_format = parameters.position_ids_format; | ||
| const int half_head_size = head_size / 2; | ||
|
|
||
| AllocatorPtr allocator; | ||
| ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); | ||
| auto* tp = context->GetOperatorThreadPool(); | ||
|
|
||
| const int loop_len = batch_size * sequence_length * num_heads; | ||
| const double cost = static_cast<double>(head_size); | ||
| ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { | ||
| for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) { | ||
| const int b = static_cast<int>((ptr / num_heads) / sequence_length); | ||
| const int s = static_cast<int>((ptr / num_heads) % sequence_length); | ||
| const int n = static_cast<int>(ptr % num_heads); | ||
|
|
||
| const int block_offset = b * sequence_length * num_heads + s * num_heads + n; | ||
| const int data_offset = block_offset * head_size; | ||
|
|
||
| const T* input_data = input_src + data_offset; | ||
| T* output_data = output_dest + data_offset; | ||
|
|
||
| // Cache is (M, H/2) | ||
| const int position_id = (position_ids_format == 0) ? static_cast<int>(pos_ids_data[0]) : static_cast<int>(pos_ids_data[b * sequence_length + s]); | ||
| const int cache_offset = (position_ids_format == 0) ? (position_id + s) * half_head_size : position_id * half_head_size; | ||
| const T* cos_data = cos_cache_data + cache_offset; | ||
| const T* sin_data = sin_cache_data + cache_offset; | ||
|
|
||
| int cache_idx = 0; | ||
| T sign = 0; | ||
| int j = 0; | ||
| for (int i = 0; i < head_size; i++) { | ||
| if (interleaved) { | ||
| cache_idx = (i / 2) % half_head_size; | ||
| sign = (i % 2 == 0) ? static_cast<T>(-1) : static_cast<T>(1); | ||
| j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign | ||
| } else { | ||
| cache_idx = i % half_head_size; | ||
| sign = (i < half_head_size) ? static_cast<T>(-1) : static_cast<T>(1); | ||
| j = (i + half_head_size) % head_size; | ||
| } | ||
| output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; | ||
| } | ||
| } | ||
| }); | ||
|
|
||
| return Status::OK(); | ||
| } | ||
|
|
||
| } // namespace contrib | ||
| } // namespace onnxruntime | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
| #include "core/common/common.h" | ||
| #include "core/framework/op_kernel.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace contrib { | ||
|
|
||
| template <typename T> | ||
| class RotaryEmbedding final : public OpKernel { | ||
| public: | ||
| RotaryEmbedding(const OpKernelInfo& info); | ||
| Status Compute(OpKernelContext* context) const override; | ||
|
|
||
| protected: | ||
| float scale; | ||
| bool interleaved; | ||
| }; | ||
|
|
||
| } // namespace contrib | ||
| } // namespace onnxruntime |
120 changes: 120 additions & 0 deletions
120
onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,120 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
| #include "core/common/common.h" | ||
| #include "core/providers/common.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace contrib { | ||
| namespace rotary_embedding_helper { | ||
|
|
||
| // Parameters deduced from node attributes and inputs/outputs. | ||
| struct RotaryParameters { | ||
| int batch_size; // Batch size used by input | ||
| int sequence_length; // Sequence length used by input | ||
| int hidden_size; // Hidden size used by input | ||
| int head_size; // Head size used by cos/sin cache * 2 | ||
| int num_heads; // num_heads = hidden_size / head_size | ||
| int max_sequence_length; // Sequence length used by cos/sin cache | ||
| int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length) | ||
| }; | ||
|
|
||
| template <typename T> | ||
| Status CheckInputs(const T* input, | ||
| const T* position_ids, | ||
| const T* cos_cache, | ||
| const T* sin_cache, | ||
| void* parameters) { | ||
| // input : (batch_size, sequence_length, hidden_size) | ||
| // position ids : (1) or (batch_size, sequence_length) | ||
| // cos cache : (max_sequence_length, head_size / 2) | ||
| // sin cache : (max_sequence_length, head_size / 2) | ||
|
|
||
| // Check input | ||
| const auto& input_dims = input->Shape().GetDims(); | ||
| if (input_dims.size() != 3) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 dimensions, got ", | ||
| input_dims.size()); | ||
| } | ||
| // Check position_ids | ||
| const auto& position_ids_dims = position_ids->Shape().GetDims(); | ||
| if (!onnxruntime::IsScalarOr1ElementVector(position_ids) && position_ids_dims.size() != 2) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' is expected to have 0, 1, or 2 dimensions, got ", | ||
| position_ids_dims.size()); | ||
| } | ||
| // Check cos_cache and sin_cache | ||
| const auto& cos_cache_dims = cos_cache->Shape().GetDims(); | ||
| if (cos_cache_dims.size() != 2) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' is expected to have 2 dimensions, got ", | ||
| cos_cache_dims.size()); | ||
| } | ||
| const auto& sin_cache_dims = sin_cache->Shape().GetDims(); | ||
| if (sin_cache_dims.size() != 2) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' is expected to have 2 dimensions, got ", | ||
| sin_cache_dims.size()); | ||
| } | ||
| if (cos_cache_dims[0] != sin_cache_dims[0] || cos_cache_dims[1] != sin_cache_dims[1]) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'cos_cache' and 'sin_cache' are expected to have the same shape"); | ||
| } | ||
|
|
||
| // Get attributes from inputs | ||
| int batch_size = static_cast<int>(input_dims[0]); | ||
| int sequence_length = static_cast<int>(input_dims[1]); | ||
| int hidden_size = static_cast<int>(input_dims[2]); | ||
| int max_sequence_length = static_cast<int>(cos_cache_dims[0]); | ||
| int head_size = static_cast<int>(cos_cache_dims[1]) * 2; | ||
| int num_heads = hidden_size / head_size; | ||
| int position_ids_format = -1; | ||
|
|
||
| // Check position_ids input shapes | ||
| if (!onnxruntime::IsScalarOr1ElementVector(position_ids)) { | ||
| if (batch_size != static_cast<int>(position_ids_dims[0])) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 0 should be of size batch_size, got ", | ||
| position_ids_dims[0]); | ||
| } | ||
| if (sequence_length != static_cast<int>(position_ids_dims[1])) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 1 should be of size sequence_length, got ", | ||
| position_ids_dims[1]); | ||
| } | ||
| position_ids_format = 1; | ||
| } else { | ||
| position_ids_format = 0; | ||
| } | ||
| // Check cos_cache input shapes | ||
| if (max_sequence_length != static_cast<int>(cos_cache_dims[0])) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as max_sequence_length, got ", | ||
| cos_cache_dims[0]); | ||
| } | ||
| if ((head_size / 2) != static_cast<int>(cos_cache_dims[1])) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as head_size / 2, got ", | ||
| cos_cache_dims[1]); | ||
| } | ||
| // Check sin_cache input shapes | ||
| if (max_sequence_length != static_cast<int>(sin_cache_dims[0])) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 0 should be same as max_sequence_length, got ", | ||
| sin_cache_dims[0]); | ||
| } | ||
| if ((head_size / 2) != static_cast<int>(sin_cache_dims[1])) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 1 should be same as head_size / 2, got ", | ||
| sin_cache_dims[1]); | ||
| } | ||
|
|
||
| // Set rotary parameters | ||
| if (parameters != nullptr) { | ||
| RotaryParameters* output_parameters = reinterpret_cast<RotaryParameters*>(parameters); | ||
| output_parameters->batch_size = batch_size; | ||
| output_parameters->sequence_length = sequence_length; | ||
| output_parameters->hidden_size = hidden_size; | ||
| output_parameters->head_size = head_size; | ||
| output_parameters->num_heads = num_heads; | ||
| output_parameters->max_sequence_length = max_sequence_length; | ||
| output_parameters->position_ids_format = position_ids_format; | ||
| } | ||
|
|
||
| return Status::OK(); | ||
| } | ||
|
|
||
| } // namespace rotary_embedding_helper | ||
| } // namespace contrib | ||
| } // namespace onnxruntime |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.