|
| 1 | +// Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +// Licensed under the MIT License. |
| 3 | + |
| 4 | +#include "contrib_ops/cpu/bert/rotary_embedding.h" |
| 5 | +#include "contrib_ops/cpu/bert/rotary_embedding_helper.h" |
| 6 | + |
| 7 | +#include "core/platform/threadpool.h" |
| 8 | + |
| 9 | +using onnxruntime::concurrency::ThreadPool; |
| 10 | +using namespace onnxruntime::contrib::rotary_embedding_helper; |
| 11 | + |
| 12 | +namespace onnxruntime { |
| 13 | +namespace contrib { |
| 14 | + |
| 15 | +// These ops are internal-only, so register outside of onnx |
| 16 | +ONNX_OPERATOR_TYPED_KERNEL_EX( |
| 17 | + RotaryEmbedding, |
| 18 | + kMSDomain, |
| 19 | + 1, |
| 20 | + float, |
| 21 | + kCpuExecutionProvider, |
| 22 | + KernelDefBuilder() |
| 23 | + .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()) |
| 24 | + .TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()), |
| 25 | + RotaryEmbedding<float>); |
| 26 | + |
| 27 | +template <typename T> |
| 28 | +RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) { |
| 29 | + scale = info.GetAttrOrDefault<float>("scale", 1.0); |
| 30 | + interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1); |
| 31 | +} |
| 32 | + |
| 33 | +template <typename T> |
| 34 | +Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const { |
| 35 | + const Tensor* input = context->Input<Tensor>(0); |
| 36 | + const Tensor* position_ids = context->Input<Tensor>(1); |
| 37 | + const Tensor* cos_cache = context->Input<Tensor>(2); |
| 38 | + const Tensor* sin_cache = context->Input<Tensor>(3); |
| 39 | + |
| 40 | + RotaryParameters parameters = {}; |
| 41 | + ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs<Tensor>(input, |
| 42 | + position_ids, |
| 43 | + cos_cache, |
| 44 | + sin_cache, |
| 45 | + ¶meters)); |
| 46 | + |
| 47 | + Tensor* output = context->Output(0, input->Shape()); |
| 48 | + |
| 49 | + if (parameters.sequence_length > parameters.max_sequence_length) { |
| 50 | + // Launch update_cos_sin_cache kernel with scale |
| 51 | + ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported"); |
| 52 | + } |
| 53 | + |
| 54 | + const T* input_src = input->Data<T>(); |
| 55 | + const int64_t* pos_ids_data = position_ids->Data<int64_t>(); |
| 56 | + const T* cos_cache_data = cos_cache->Data<T>(); |
| 57 | + const T* sin_cache_data = sin_cache->Data<T>(); |
| 58 | + T* output_dest = output->MutableData<T>(); |
| 59 | + |
| 60 | + const int batch_size = parameters.batch_size; |
| 61 | + const int sequence_length = parameters.sequence_length; |
| 62 | + const int num_heads = parameters.num_heads; |
| 63 | + const int head_size = parameters.head_size; |
| 64 | + const int position_ids_format = parameters.position_ids_format; |
| 65 | + const int half_head_size = head_size / 2; |
| 66 | + |
| 67 | + AllocatorPtr allocator; |
| 68 | + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); |
| 69 | + auto* tp = context->GetOperatorThreadPool(); |
| 70 | + |
| 71 | + const int loop_len = batch_size * sequence_length * num_heads; |
| 72 | + const double cost = static_cast<double>(head_size); |
| 73 | + ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { |
| 74 | + for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) { |
| 75 | + const int b = static_cast<int>((ptr / num_heads) / sequence_length); |
| 76 | + const int s = static_cast<int>((ptr / num_heads) % sequence_length); |
| 77 | + const int n = static_cast<int>(ptr % num_heads); |
| 78 | + |
| 79 | + const int block_offset = b * sequence_length * num_heads + s * num_heads + n; |
| 80 | + const int data_offset = block_offset * head_size; |
| 81 | + |
| 82 | + const T* input_data = input_src + data_offset; |
| 83 | + T* output_data = output_dest + data_offset; |
| 84 | + |
| 85 | + // Cache is (M, H/2) |
| 86 | + const int position_id = (position_ids_format == 0) |
| 87 | + ? static_cast<int>(pos_ids_data[0]) + s |
| 88 | + : static_cast<int>(pos_ids_data[b * sequence_length + s]); |
| 89 | + const int cache_offset = position_id * half_head_size; |
| 90 | + const T* cos_data = cos_cache_data + cache_offset; |
| 91 | + const T* sin_data = sin_cache_data + cache_offset; |
| 92 | + |
| 93 | + int cache_idx = 0; |
| 94 | + T sign = 0; |
| 95 | + int j = 0; |
| 96 | + for (int i = 0; i < head_size; i++) { |
| 97 | + if (interleaved) { |
| 98 | + cache_idx = (i / 2) % half_head_size; |
| 99 | + sign = (i % 2 == 0) ? static_cast<T>(-1) : static_cast<T>(1); |
| 100 | + j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign |
| 101 | + } else { |
| 102 | + cache_idx = i % half_head_size; |
| 103 | + sign = (i < half_head_size) ? static_cast<T>(-1) : static_cast<T>(1); |
| 104 | + j = (i + half_head_size) % head_size; |
| 105 | + } |
| 106 | + output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; |
| 107 | + } |
| 108 | + } |
| 109 | + }); |
| 110 | + |
| 111 | + return Status::OK(); |
| 112 | +} |
| 113 | + |
| 114 | +} // namespace contrib |
| 115 | +} // namespace onnxruntime |
0 commit comments