61 changes: 2 additions & 59 deletions csrc/mmdeploy/codebase/mmaction/cpu/format_shape_impl.cpp
@@ -12,64 +12,9 @@ class FormatShapeImpl : public FormatShapeOp {
explicit FormatShapeImpl(std::string input_format) : FormatShapeOp(std::move(input_format)) {}

protected:
Result<void> apply(const std::vector<Tensor>& tensors, Tensor& output, int clip_len,
int num_clips) override {
auto N = static_cast<int64_t>(tensors.size());
auto H = tensors[0].shape(1);
auto W = tensors[0].shape(2);
auto C = tensors[0].shape(3);
Device host_{0, 0};

TensorDesc desc = {kHost, DataType::kFLOAT, {N, H, W, C}};
Tensor imgs(desc);
auto offset = 0UL;
auto n_item = H * W * C;
auto copy_size = n_item * sizeof(float);
for (int i = 0; i < N; i++) {
auto src_buffer = tensors[i].buffer();
auto dst_buffer = imgs.buffer();
OUTCOME_TRY(stream().Copy(src_buffer, dst_buffer, copy_size, 0, offset));
offset += copy_size;
}

OUTCOME_TRY(stream().Wait());

Tensor dst;
if (input_format_ == "NCHW") {
OUTCOME_TRY(dst, FormatNCHW(imgs, clip_len, num_clips));
}
if (input_format_ == "NCTHW") {
OUTCOME_TRY(dst, FormatNCTHW(imgs, clip_len, num_clips));
}
TensorShape expand_dim = dst.shape();
expand_dim.insert(expand_dim.begin(), 1);
dst.Reshape(expand_dim);
output = std::move(dst);

return success();
}

Result<Tensor> FormatNCHW(Tensor& src, int clip_len, int num_clips) {
auto N = src.shape(0);
auto H = src.shape(1);
auto W = src.shape(2);
auto C = src.shape(3);
return Transpose(src, {N, H, W, C}, {0, 3, 1, 2});
};

Result<Tensor> FormatNCTHW(Tensor& src, int clip_len, int num_clips) {
auto N = src.shape(0);
auto H = src.shape(1);
auto W = src.shape(2);
auto C = src.shape(3);
auto L = clip_len;
if (N % L != 0) {
return Status(eInvalidArgument);
}
int M = N / L;
src.Reshape({M, L, H, W, C});

return Transpose(src, {M, L, H, W, C}, {0, 4, 1, 2, 3});
};
const Device& GetDevice() { return host_; }

Result<Tensor> Transpose(Tensor& src, const TensorShape& src_dims,
const std::vector<int>& permutation) {
@@ -113,8 +58,6 @@ class FormatShapeImpl : public FormatShapeOp {
} while (i >= 0);
return dst;
}

constexpr static Device kHost{0, 0};
};

MMDEPLOY_REGISTER_FACTORY_FUNC(FormatShapeOp, (cpu, 0), [](std::string input_format) {
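Note: after this change the CPU backend keeps only GetDevice() and the element-wise Transpose() loop; the merge/format/expand logic moves to the shared base class (see format_shape.cpp below). A minimal host-side sketch of what such a permuting copy does — hypothetical reference code, not the exact mmdeploy implementation:

#include <cstdint>
#include <vector>

// Permuting copy over an N-d row-major tensor: for every destination element,
// decode its multi-index, then gather from src through `permutation`.
void TransposeRef(const float* src, const std::vector<int64_t>& src_dims,
                  const std::vector<int>& permutation, float* dst) {
  const int ndim = static_cast<int>(src_dims.size());
  std::vector<int64_t> src_strides(ndim, 1), dst_dims(ndim), dst_strides(ndim, 1);
  for (int i = ndim - 2; i >= 0; --i) src_strides[i] = src_strides[i + 1] * src_dims[i + 1];
  for (int i = 0; i < ndim; ++i) dst_dims[i] = src_dims[permutation[i]];
  for (int i = ndim - 2; i >= 0; --i) dst_strides[i] = dst_strides[i + 1] * dst_dims[i + 1];
  int64_t total = 1;
  for (auto d : src_dims) total *= d;
  for (int64_t u = 0; u < total; ++u) {
    int64_t rem = u, v = 0;
    for (int i = 0; i < ndim; ++i) {
      const int64_t p = rem / dst_strides[i];  // coordinate along dst axis i
      rem -= p * dst_strides[i];
      v += p * src_strides[permutation[i]];    // dst axis i is src axis permutation[i]
    }
    dst[u] = src[v];
  }
}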
8 changes: 3 additions & 5 deletions csrc/mmdeploy/codebase/mmaction/cuda/CMakeLists.txt
@@ -4,17 +4,15 @@ if (NOT "cuda" IN_LIST MMDEPLOY_TARGET_DEVICES)
endif ()

project(mmdeploy_mmaction_cuda_impl CXX)
include(${CMAKE_SOURCE_DIR}/cmake/modules/FindCUDNN.cmake)

add_library(${PROJECT_NAME} OBJECT format_shape_impl.cpp)
add_library(${PROJECT_NAME} OBJECT format_shape_impl.cpp transpose.cu)
set_target_properties(${PROJECT_NAME} PROPERTIES POSITION_INDEPENDENT_CODE 1)
if (NOT (MMDEPLOY_SHARED_LIBS OR MSVC))
target_compile_options(${PROJECT_NAME} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-fvisibility=hidden>)
endif ()
target_include_directories(${PROJECT_NAME} PRIVATE
${CUDA_INCLUDE_DIRS})
${CUDA_INCLUDE_DIRS})
target_link_libraries(${PROJECT_NAME} PRIVATE
mmdeploy::core
cudnn)
mmdeploy::core)
target_link_libraries(mmdeploy_mmaction PRIVATE ${PROJECT_NAME})
mmdeploy_export(${PROJECT_NAME})
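Note: the build no longer pulls in cuDNN — FindCUDNN.cmake and the cudnn link go away, and the new transpose.cu kernel is compiled into the object library instead. A minimal standalone sketch of the same target, assuming CUDA is enabled as a first-class CMake language (mmdeploy itself wires this up through its own cmake modules; project and target names here are hypothetical):

cmake_minimum_required(VERSION 3.18)
project(mmaction_cuda_demo LANGUAGES CXX CUDA)  # CUDA enabled so transpose.cu compiles

add_library(mmaction_cuda_impl OBJECT format_shape_impl.cpp transpose.cu)
set_target_properties(mmaction_cuda_impl PROPERTIES POSITION_INDEPENDENT_CODE ON)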
108 changes: 15 additions & 93 deletions csrc/mmdeploy/codebase/mmaction/cuda/format_shape_impl.cpp
@@ -1,92 +1,23 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "cudnn.h"
#include "cuda_runtime.h"
#include "mmdeploy/codebase/mmaction/format_shape.h"
#include "mmdeploy/core/utils/device_utils.h"

using namespace std;

namespace mmdeploy::mmaction::cuda {

#define CUDNN_CHECK(condition) \
do { \
if (condition != CUDNN_STATUS_SUCCESS) { \
MMDEPLOY_ERROR("cudnn error, msg = {}", cudnnGetErrorString(condition)); \
} \
} while (0);
template <typename T>
void Transpose(const T* src, const int* src_strides, T* dst, const int* dst_strides, int ndim,
int total, cudaStream_t stream);

class FormatShapeImpl : public FormatShapeOp {
public:
explicit FormatShapeImpl(std::string input_format) : FormatShapeOp(std::move(input_format)) {
CUDNN_CHECK(cudnnCreate(&handle_));
CUDNN_CHECK(cudnnSetStream(handle_, GetNative<cudaStream_t>(stream())));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&src_desc_));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&dst_desc_));
}

~FormatShapeImpl() override {
CUDNN_CHECK(cudnnDestroy(handle_));
CUDNN_CHECK(cudnnDestroyTensorDescriptor(src_desc_));
CUDNN_CHECK(cudnnDestroyTensorDescriptor(dst_desc_));
}
explicit FormatShapeImpl(std::string input_format) : FormatShapeOp(std::move(input_format)) {}

protected:
Result<void> apply(const std::vector<Tensor>& inputs, Tensor& output, int clip_len,
int num_clips) override {
auto N = static_cast<int64_t>(inputs.size());
auto H = inputs[0].shape(1);
auto W = inputs[0].shape(2);
auto C = inputs[0].shape(3);

auto t0 = std::chrono::high_resolution_clock::now();
TensorDesc desc = {device_, DataType::kFLOAT, {N, H, W, C}};
Tensor imgs(desc);
int offset = 0;
int n_item = H * W * C;
int copy_size = n_item * sizeof(float);
for (int i = 0; i < N; i++) {
auto src_buffer = inputs[i].buffer();
auto dst_buffer = imgs.buffer();
OUTCOME_TRY(stream().Copy(src_buffer, dst_buffer, copy_size, 0, offset));
offset += copy_size;
}

// Tensor dst;
if (input_format_ == "NCHW") {
OUTCOME_TRY(output, FormatNCHW(imgs, clip_len, num_clips));
}
if (input_format_ == "NCTHW") {
OUTCOME_TRY(output, FormatNCTHW(imgs, clip_len, num_clips));
}
TensorShape expand_dim = output.shape();
expand_dim.insert(expand_dim.begin(), 1);
output.Reshape(expand_dim);

return success();
}

Result<Tensor> FormatNCHW(Tensor& src, int clip_len, int num_clips) {
auto N = src.shape(0);
auto H = src.shape(1);
auto W = src.shape(2);
auto C = src.shape(3);
return Transpose(src, {N, H, W, C}, {0, 3, 1, 2});
};

Result<Tensor> FormatNCTHW(Tensor& src, int clip_len, int num_clips) {
auto N = src.shape(0);
auto H = src.shape(1);
auto W = src.shape(2);
auto C = src.shape(3);
int L = clip_len;
if (N % L != 0) {
return Status(eInvalidArgument);
}
int M = N / L;
src.Reshape({M, L, H, W, C});

return Transpose(src, {M, L, H, W, C}, {0, 4, 1, 2, 3});
};
const Device& GetDevice() { return device(); }

Result<Tensor> Transpose(Tensor& src, const TensorShape& src_dims,
const std::vector<int>& permutation) {
@@ -97,14 +28,6 @@ class FormatShapeImpl : public FormatShapeOp {
}
dst.Reshape(shape);

SetCudnnTensorDescriptor(src_dims, permutation);
CUDNN_CHECK(cudnnTransformTensor(handle_, &one_, src_desc_, src.data<float>(), &zero_,
dst_desc_, dst.data<float>()));

return dst;
}

void SetCudnnTensorDescriptor(const TensorShape& src_dims, const std::vector<int>& permutation) {
auto ndim = src_dims.size();
std::vector<int> dst_dims(ndim);
for (int i = 0; i < ndim; i++) {
@@ -124,17 +47,16 @@ class FormatShapeImpl : public FormatShapeOp {
src_strides[i] = buffer[permutation[i]];
}

CUDNN_CHECK(cudnnSetTensorNdDescriptor(src_desc_, CUDNN_DATA_FLOAT, ndim, dst_dims.data(),
src_strides.data()));
CUDNN_CHECK(cudnnSetTensorNdDescriptor(dst_desc_, CUDNN_DATA_FLOAT, ndim, dst_dims.data(),
dst_strides.data()));
}
Buffer _src_strides(Device("cuda"), sizeof(int) * ndim);
Buffer _dst_strides(Device("cuda"), sizeof(int) * ndim);
OUTCOME_TRY(stream().Copy(src_strides.data(), _src_strides));
OUTCOME_TRY(stream().Copy(dst_strides.data(), _dst_strides));

constexpr static float one_{1.0};
constexpr static float zero_{0.0};
cudnnHandle_t handle_{};
cudnnTensorDescriptor_t src_desc_{};
cudnnTensorDescriptor_t dst_desc_{};
::mmdeploy::mmaction::cuda::Transpose(src.data<float>(), GetNative<int*>(_src_strides),
dst.data<float>(), GetNative<int*>(_dst_strides), ndim,
src.size(), (cudaStream_t)stream().GetNative());
return dst;
}
};

MMDEPLOY_REGISTER_FACTORY_FUNC(FormatShapeOp, (cuda, 0), [](std::string input_format) {
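Note: the cuDNN handle, tensor descriptors and cudnnTransformTensor call are replaced by a plain stride computation plus a custom kernel launch: dst dims/strides are derived on the host, the src strides are permuted into dst axis order, both arrays are copied into device Buffers, and the kernel reads them per element. A worked example of that setup for the NCTHW permutation {0, 4, 1, 2, 3} — host-only sketch with assumed toy sizes:

#include <cstdio>
#include <vector>

int main() {
  const int ndim = 5;
  std::vector<int> src_dims = {2, 4, 2, 2, 3};    // {M, L, H, W, C}
  std::vector<int> permutation = {0, 4, 1, 2, 3}; // -> {M, C, L, H, W}

  std::vector<int> dst_dims(ndim);
  for (int i = 0; i < ndim; ++i) dst_dims[i] = src_dims[permutation[i]];

  // Row-major strides of src (buffer) and dst.
  std::vector<int> buffer(ndim, 1), dst_strides(ndim, 1);
  for (int i = ndim - 2; i >= 0; --i) {
    buffer[i] = buffer[i + 1] * src_dims[i + 1];
    dst_strides[i] = dst_strides[i + 1] * dst_dims[i + 1];
  }
  // The kernel receives src strides already permuted into dst axis order.
  std::vector<int> src_strides(ndim);
  for (int i = 0; i < ndim; ++i) src_strides[i] = buffer[permutation[i]];

  for (int i = 0; i < ndim; ++i)
    printf("axis %d: dst_dim=%d dst_stride=%d src_stride=%d\n",
           i, dst_dims[i], dst_strides[i], src_strides[i]);
  // prints src strides {48, 1, 12, 6, 3} against dst strides {48, 16, 4, 2, 1}
  return 0;
}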
38 changes: 38 additions & 0 deletions csrc/mmdeploy/codebase/mmaction/cuda/transpose.cu
@@ -0,0 +1,38 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include <stdint.h>
#include <stdio.h>

namespace mmdeploy::mmaction::cuda {

template <typename T>
__global__ void transpose(const T* src, const int* src_strides, T* dst, const int* dst_strides,
int ndim, int total) {
int u = blockIdx.x * blockDim.x + threadIdx.x;
if (u >= total) {
return;
}

int remaining = u;
int v = 0;
for (int i = 0; i < ndim; i++) {
int p = remaining / dst_strides[i];
remaining -= p * dst_strides[i];
v += p * src_strides[i];
}
dst[u] = src[v];
}

template <typename T>
void Transpose(const T* src, const int* src_strides, T* dst, const int* dst_strides, int ndim,
int total, cudaStream_t stream) {
int thread_num = 256;
int block_num = (total + thread_num - 1) / thread_num;
transpose<T>
<<<block_num, thread_num, 0, stream>>>(src, src_strides, dst, dst_strides, ndim, total);
}

template void Transpose<float>(const float* src, const int* src_strides, float* dst,
const int* dst_strides, int ndim, int total, cudaStream_t stream);

} // namespace mmdeploy::mmaction::cuda
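Note: each thread owns one destination element. It peels the destination linear index u into per-axis coordinates by integer division with the row-major dst_strides (strictly decreasing products, so each division isolates exactly one coordinate), then re-accumulates those coordinates against the pre-permuted src_strides to get the source offset v. A host-side restatement of that mapping, usable as a unit-test oracle (hypothetical helper, not part of the PR):

// Mirrors the kernel body for a single destination index u.
int MapIndex(int u, const int* src_strides, const int* dst_strides, int ndim) {
  int remaining = u, v = 0;
  for (int i = 0; i < ndim; ++i) {
    int p = remaining / dst_strides[i];  // coordinate along dst axis i
    remaining -= p * dst_strides[i];
    v += p * src_strides[i];             // src_strides is pre-permuted by the caller
  }
  return v;
}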
71 changes: 70 additions & 1 deletion csrc/mmdeploy/codebase/mmaction/format_shape.cpp
@@ -12,8 +12,76 @@ namespace mmdeploy::mmaction {
FormatShape::FormatShape(const Value& args) {
auto input_format = args.value("input_format", std::string(""));
if (input_format != "NCHW" && input_format != "NCTHW") {
throw std::domain_error("'input_format' should be 'NCHW' or 'NCTHW'");
MMDEPLOY_ERROR("'input_format' should be 'NCHW' or 'NCTHW'");
throw_exception(eInvalidArgument);
}
format_ = operation::Managed<mmdeploy::mmaction::FormatShapeOp>::Create(input_format);
}

Result<void> FormatShapeOp::apply(const std::vector<Tensor>& images, Tensor& output, int clip_len,
int num_clips) {
Tensor inputs;
OUTCOME_TRY(MergeInputs(images, inputs));
if (GetDevice().is_host()) {
OUTCOME_TRY(stream().Wait());
}

// Tensor dst;
if (input_format_ == "NCHW") {
OUTCOME_TRY(output, FormatNCHW(inputs, clip_len, num_clips));
}
if (input_format_ == "NCTHW") {
OUTCOME_TRY(output, FormatNCTHW(inputs, clip_len, num_clips));
}

TensorShape expand_dim = output.shape();
expand_dim.insert(expand_dim.begin(), 1);
output.Reshape(expand_dim);

return success();
}

Result<void> FormatShapeOp::MergeInputs(const std::vector<Tensor>& images, Tensor& inputs) {
auto N = static_cast<int64_t>(images.size());
auto H = images[0].shape(1);
auto W = images[0].shape(2);
auto C = images[0].shape(3);

TensorDesc desc = {GetDevice(), DataType::kFLOAT, {N, H, W, C}};
inputs = Tensor(desc);
auto offset = 0UL;
auto n_item = H * W * C;
auto copy_size = n_item * sizeof(float);
for (int i = 0; i < N; i++) {
auto src_buffer = images[i].buffer();
auto dst_buffer = inputs.buffer();
OUTCOME_TRY(stream().Copy(src_buffer, dst_buffer, copy_size, 0, offset));
offset += copy_size;
}
return success();
}

Result<Tensor> FormatShapeOp::FormatNCHW(Tensor& src, int clip_len, int num_clips) {
auto N = src.shape(0);
auto H = src.shape(1);
auto W = src.shape(2);
auto C = src.shape(3);
return Transpose(src, {N, H, W, C}, {0, 3, 1, 2});
}

Result<Tensor> FormatShapeOp::FormatNCTHW(Tensor& src, int clip_len, int num_clips) {
auto N = src.shape(0);
auto H = src.shape(1);
auto W = src.shape(2);
auto C = src.shape(3);
int L = clip_len;
if (N % L != 0) {
return Status(eInvalidArgument);
}
int M = N / L;
src.Reshape({M, L, H, W, C});

return Transpose(src, {M, L, H, W, C}, {0, 4, 1, 2, 3});
}

Result<void> FormatShape::Apply(Value& data) {
@@ -50,6 +118,7 @@ Result<void> FormatShape::Apply(Value& data) {
}

Tensor dst;
data = Value{};
OUTCOME_TRY(format_.Apply(images, dst, clip_len, num_clips));
data["img"] = std::move(dst);

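Note: with the merge/format/expand logic now shared, a concrete trace of the common path under assumed inputs: for num_clips = 2, clip_len = 4 and eight 224x224x3 frames, MergeInputs packs them into one {8, 224, 224, 3} tensor; FormatNCTHW checks 8 % 4 == 0, reshapes to {2, 4, 224, 224, 3} and transposes with {0, 4, 1, 2, 3} to {2, 3, 4, 224, 224}; apply() then prepends the batch axis, yielding {1, 2, 3, 4, 224, 224}. The stream Wait() after MergeInputs happens only on host, presumably because on CUDA the subsequent kernel is ordered on the same stream anyway.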
15 changes: 13 additions & 2 deletions csrc/mmdeploy/codebase/mmaction/format_shape.h
@@ -16,8 +16,19 @@ class FormatShapeOp : public operation::Operation {
public:
explicit FormatShapeOp(std::string input_format) : input_format_(std::move(input_format)){};

virtual Result<void> apply(const std::vector<Tensor>& inputs, Tensor& output, int clip_len,
int num_clips) = 0;
Result<void> apply(const std::vector<Tensor>& inputs, Tensor& output, int clip_len,
int num_clips);

virtual const Device& GetDevice() = 0;

virtual Result<Tensor> Transpose(Tensor& src, const TensorShape& src_dims,
const std::vector<int>& permutation) = 0;

Result<Tensor> FormatNCHW(Tensor& src, int clip_len, int num_clips);

Result<Tensor> FormatNCTHW(Tensor& src, int clip_len, int num_clips);

Result<void> MergeInputs(const std::vector<Tensor>& images, Tensor& inputs);

protected:
std::string input_format_;
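Note: this header change is the core of the refactor, in template-method style: apply(), FormatNCHW(), FormatNCTHW() and MergeInputs() become concrete shared members of FormatShapeOp, while each backend overrides only the two genuinely device-specific hooks, GetDevice() and Transpose().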
2 changes: 1 addition & 1 deletion csrc/mmdeploy/preprocess/transform/lift.cpp
@@ -9,7 +9,7 @@ namespace mmdeploy::transform {
class Lift : public Transform {
public:
explicit Lift(const Value& args) {
const char* type = "compose";
const char* type = "Compose";
if (auto creator = gRegistry<Transform>().Get(type)) {
compose_ = creator->Create(args);
} else {
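Note: the lift.cpp fix corrects a registry key — transform creators are looked up by their registered type name, and the composing transform is registered as "Compose", so the lowercase "compose" lookup would miss and Lift would fail to create its inner pipeline.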