Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion onnxruntime/python/tools/symbolic_shape_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,6 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
"upsample_nearest1d": self._infer_aten_upsample,
"upsample_nearest2d": self._infer_aten_upsample,
"upsample_nearest3d": self._infer_aten_upsample,
"upsample_bilinear2d": self._infer_aten_upsample,
}
self.run_ = True
self.suggested_merge_ = {}
Expand Down
8 changes: 8 additions & 0 deletions orttraining/orttraining/core/graph/gradient_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2147,5 +2147,13 @@ IMPLEMENT_GRADIENT_BUILDER(GetScaledSumGradient) {
ORT_THROW("ScaledSum gradient builder does not support ", input_count, " inputs");
}

IMPLEMENT_GRADIENT_BUILDER(GetResizeGradient) {
return std::vector<NodeDef>{
NodeDef(OpDef{"ResizeGrad", kMSDomain, 1},
{GO(0), I(0), I(1), I(2)},
{GI(0)},
SrcNodeAttributes())};
}

} // namespace training
} // namespace onnxruntime
1 change: 1 addition & 0 deletions orttraining/orttraining/core/graph/gradient_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ DECLARE_GRADIENT_BUILDER(GetGRUGradient)
DECLARE_GRADIENT_BUILDER(GetReciprocalGradient)
DECLARE_GRADIENT_BUILDER(GetLeakyReluGradient)
DECLARE_GRADIENT_BUILDER(GetConvTransposeGradient)
DECLARE_GRADIENT_BUILDER(GetResizeGradient)

DECLARE_GRADIENT_BUILDER(GetExternalGradient)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("Reciprocal", GetReciprocalGradient);
REGISTER_GRADIENT_BUILDER("LeakyRelu", GetLeakyReluGradient);
REGISTER_GRADIENT_BUILDER("ConvTranspose", GetConvTransposeGradient);
REGISTER_GRADIENT_BUILDER("Resize", GetResizeGradient);

REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient);
};
Expand Down
20 changes: 20 additions & 0 deletions orttraining/orttraining/core/graph/training_op_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5001,6 +5001,26 @@ Return true if all elements are true and false otherwise.
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input and output types to float tensors.");

ONNX_CONTRIB_OPERATOR_SCHEMA(ResizeGrad)
.SetDomain(kMSDomain)
.SinceVersion(1)
.Input(0, "dY", "Gradient of output Y.", "T")
.Input(1, "X", "Input tensor to the Resize operator.", "T")
.Input(2, "roi", "The roi input to the Resize operator.", "T", OpSchema::Optional)
.Input(3, "scales", "The scales input to the Resize operator.", "tensor(float)", OpSchema::Optional)
Comment thread
baijumeswani marked this conversation as resolved.
.Output(0, "dX", "Gradient of the input X.", "T")
.AllowUncheckedAttributes()
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 1, 0);
if (hasInputShape(ctx, 1)) {
propagateShapeFromInputToOutput(ctx, 1, 0);
}
});
}

} // namespace training
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,3 @@ def upsample_nearest2d_gradient():
@register_gradient("org.pytorch.aten", "ATen", "upsample_nearest3d", "vec")
def upsample_nearest3d_gradient():
return _upsample_gradient("upsample_nearest3d_backward", 3)


@register_gradient("org.pytorch.aten", "ATen", "upsample_bilinear2d", "vec")
def upsample_bilinear2d_gradient():
return _upsample_gradient("upsample_bilinear2d_backward", 2)
Original file line number Diff line number Diff line change
Expand Up @@ -808,16 +808,3 @@ def upsample_nearest2d(g, input, output_size, scale_factors):
@register_symbolic("upsample_nearest3d")
def upsample_nearest3d(g, input, output_size, scale_factors):
return _upsample_nearest(g, input, output_size, scale_factors, "upsample_nearest3d")


@register_symbolic("upsample_bilinear2d")
Comment thread
baijumeswani marked this conversation as resolved.
def upsample_bilinear2d(g, input, output_size, align_corners, scale_factors):
return g.op(
"org.pytorch.aten::ATen",
Comment thread
baijumeswani marked this conversation as resolved.
input,
output_size,
align_corners,
scale_factors,
operator_s="upsample_bilinear2d",
overload_name_s="vec",
)
35 changes: 35 additions & 0 deletions orttraining/orttraining/test/gradient/gradient_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3298,6 +3298,41 @@ TEST(GradientCheckerTest, ConvTransposeGrad) {
execution_providers.push_back(DefaultCudaExecutionProvider());
ConvTransposeGradientCheckerTest(&execution_providers);
}

// TODO: Enable test for ROCM
TEST(GradientCheckerTest, ResizeGrad) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
const std::vector<ONNX_NAMESPACE::AttributeProto> attributes = {
MakeAttribute("coordinate_transformation_mode", "half_pixel"),
MakeAttribute("cubic_coeff_a", -0.75f),
MakeAttribute("exclude_outside", static_cast<int64_t>(0)),
MakeAttribute("extrapolation_value", 0.0f),
MakeAttribute("mode", "linear"),
MakeAttribute("nearest_mode", "floor")};

float max_error;
GradientChecker<float, float, float> gradient_checker;
OpDef op_def{"Resize", kOnnxDomain, 18};

TensorInfo x_info({1, 2, 4, 4}, true);
TensorInfo roi_info({4}, false, nullptr, DataTypeImpl::GetTensorType<float>());
TensorInfo scales_info({4}, false, nullptr, DataTypeImpl::GetTensorType<float>());

TensorInfo y_info({1, 2, 8, 8}, true);

std::vector<std::vector<float>> x_datas = {{0.2f, 0.4f, 0.6f, 0.8f, 0.2f, 0.4f, 0.6f, 0.8f,
0.2f, 0.4f, 0.6f, 0.8f, 0.2f, 0.4f, 0.6f, 0.8f,
0.2f, 0.4f, 0.6f, 0.8f, 0.2f, 0.4f, 0.6f, 0.8f,
0.2f, 0.4f, 0.6f, 0.8f, 0.2f, 0.4f, 0.6f, 0.8f},
{1.0f, 1.0f, 1.0f, 1.0f},
{1.0f, 1.0f, 2.0f, 2.0f}};

ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info, roi_info, scales_info},
{y_info}, &max_error, x_datas, attributes, true, false, &execution_providers));
EXPECT_IS_TINY(max_error);
}

#endif // USE_CUDA

} // namespace test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1773,13 +1773,17 @@ def run_step(model, input):
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)


def test_aten_upsample_bilinear():
@pytest.mark.parametrize("interpolate_size_scale", ({"size": (8, 12)}, {"scale_factor": 4.7}))
@pytest.mark.parametrize("align_corners", (True, False))
def test_resize_grad_correctness_bilinear_2d(interpolate_size_scale, align_corners):
class _NeuralNetUpsampleBilinear(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
return torch.nn.functional.interpolate(input, size=(8, 12), mode="bilinear")
return torch.nn.functional.interpolate(
input, align_corners=align_corners, mode="bilinear", **interpolate_size_scale
)

device = "cuda"
pt_model = _NeuralNetUpsampleBilinear().to(device)
Expand Down
227 changes: 227 additions & 0 deletions orttraining/orttraining/test/training_ops/cuda/resize_grad_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "test/providers/compare_provider_test_utils.h"
#include "test/providers/provider_test_utils.h"
#include "test/util/include/default_providers.h"

namespace onnxruntime::test {

#if defined(USE_CUDA) || defined(USE_ROCM)

namespace {

void AddResizeGradAttributes(OpTester& test, const std::string& coordinate_transformation_mode) {
test.AddAttribute<std::string>("mode", "linear");
test.AddAttribute<std::string>("coordinate_transformation_mode", coordinate_transformation_mode);
}

} // namespace

TEST(ResizeGradTest, ResizeGradWithSizes) {
std::vector<std::unique_ptr<IExecutionProvider>> providers;
#ifdef USE_CUDA
providers.emplace_back(DefaultCudaExecutionProvider());
#elif USE_ROCM
providers.emplace_back(DefaultRocmExecutionProvider());
#endif

OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain);

AddResizeGradAttributes(test, "half_pixel");

std::vector<float> dY(128, 1.0f);
std::vector<int64_t> dY_shape = {1, 2, 8, 8};

std::vector<float> X(32, 1.0f);
std::vector<int64_t> X_shape = {1, 2, 4, 4};

std::vector<float> dX(32, 4.0f);
std::vector<int64_t> dX_shape = X_shape;

test.AddInput<float>("dY", dY_shape, dY);
test.AddInput<float>("X", X_shape, X);

test.AddOutput<float>("dX", dX_shape, dX);

test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers);
}

TEST(ResizeGradTest, ResizeGradWithSizesHalf) {
std::vector<std::unique_ptr<IExecutionProvider>> providers;
#ifdef USE_CUDA
providers.emplace_back(DefaultCudaExecutionProvider());
#elif USE_ROCM
providers.emplace_back(DefaultRocmExecutionProvider());
#endif

OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain);

AddResizeGradAttributes(test, "half_pixel");

std::vector<float> dY(128, 1.0f);
std::vector<MLFloat16> dY_half(dY.size());
ConvertFloatToMLFloat16(dY.data(), dY_half.data(), static_cast<int>(dY.size()));
std::vector<int64_t> dY_shape = {1, 2, 8, 8};

std::vector<float> X(32, 1.0f);
std::vector<MLFloat16> X_half(X.size());
ConvertFloatToMLFloat16(X.data(), X_half.data(), static_cast<int>(X.size()));
std::vector<int64_t> X_shape = {1, 2, 4, 4};

std::vector<float> dX(32, 4.0f);
std::vector<MLFloat16> dX_half(dX.size());
ConvertFloatToMLFloat16(dX.data(), dX_half.data(), static_cast<int>(dX.size()));
std::vector<int64_t> dX_shape = X_shape;

test.AddInput<MLFloat16>("dY", dY_shape, dY_half);
test.AddInput<MLFloat16>("X", X_shape, X_half);

test.AddOutput<MLFloat16>("dX", dX_shape, dX_half);

test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers);
}

TEST(ResizeGradTest, ResizeGradWithSizesAndAlignCorners) {
std::vector<std::unique_ptr<IExecutionProvider>> providers;
#ifdef USE_CUDA
providers.emplace_back(DefaultCudaExecutionProvider());
#elif USE_ROCM
providers.emplace_back(DefaultRocmExecutionProvider());
#endif

OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain);

AddResizeGradAttributes(test, "align_corners");

std::vector<float> dY(128, 1.0f);
std::vector<int64_t> dY_shape = {1, 2, 8, 8};

std::vector<float> X(32, 1.0f);
std::vector<int64_t> X_shape = {1, 2, 4, 4};

std::vector<float> dX({2.9388f, 3.9184f, 3.9184f, 2.9388f, 3.9184f, 5.2245f, 5.2245f, 3.9184f,
3.9184f, 5.2245f, 5.2245f, 3.9184f, 2.9388f, 3.9184f, 3.9184f, 2.9388f,
2.9388f, 3.9184f, 3.9184f, 2.9388f, 3.9184f, 5.2245f, 5.2245f, 3.9184f,
3.9184f, 5.2245f, 5.2245f, 3.9184f, 2.9388f, 3.9184f, 3.9184f, 2.9388f});
std::vector<int64_t> dX_shape = X_shape;

test.AddInput<float>("dY", dY_shape, dY);
test.AddInput<float>("X", X_shape, X);

test.AddOutput<float>("dX", dX_shape, dX);

test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers);
}

TEST(ResizeGradTest, ResizeGradWithScales) {
std::vector<std::unique_ptr<IExecutionProvider>> providers;
#ifdef USE_CUDA
providers.emplace_back(DefaultCudaExecutionProvider());
#elif USE_ROCM
providers.emplace_back(DefaultRocmExecutionProvider());
#endif

OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain);

AddResizeGradAttributes(test, "half_pixel");

std::vector<float> dY(72, 1.0f);
std::vector<int64_t> dY_shape = {1, 2, 6, 6};

std::vector<float> X(32, 1.0f);
std::vector<int64_t> X_shape = {1, 2, 4, 4};

std::vector<float> dX({2.7128f, 2.9550f, 2.7612f, 1.4533f, 2.9550f, 3.2189f, 3.0078f, 1.5830f,
2.7612f, 3.0078f, 2.8106f, 1.4792f, 1.4533f, 1.5830f, 1.4792f, 0.7785f,
2.7128f, 2.9550f, 2.7612f, 1.4533f, 2.9550f, 3.2189f, 3.0078f, 1.5830f,
2.7612f, 3.0078f, 2.8106f, 1.4792f, 1.4533f, 1.5830f, 1.4792f, 0.7785f});
std::vector<int64_t> dX_shape = X_shape;

test.AddInput<float>("dY", dY_shape, dY);
test.AddInput<float>("X", X_shape, X);
test.AddInput<float>("", {0}, {});
test.AddInput<float>("scales", {4}, {1.0f, 1.0f, 1.7f, 1.7f});

test.AddOutput<float>("dX", dX_shape, dX);

test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers);
}

TEST(ResizeGradTest, ResizeGradWithScalesHalf) {
std::vector<std::unique_ptr<IExecutionProvider>> providers;
#ifdef USE_CUDA
providers.emplace_back(DefaultCudaExecutionProvider());
#elif USE_ROCM
providers.emplace_back(DefaultRocmExecutionProvider());
#endif

OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain);

AddResizeGradAttributes(test, "half_pixel");

std::vector<float> dY(72, 1.0f);
std::vector<MLFloat16> dY_half(dY.size());
ConvertFloatToMLFloat16(dY.data(), dY_half.data(), static_cast<int>(dY.size()));
std::vector<int64_t> dY_shape = {1, 2, 6, 6};

std::vector<float> X(32, 1.0f);
std::vector<MLFloat16> X_half(X.size());
ConvertFloatToMLFloat16(X.data(), X_half.data(), static_cast<int>(X.size()));
std::vector<int64_t> X_shape = {1, 2, 4, 4};

std::vector<float> dX({2.7128f, 2.9550f, 2.7612f, 1.4533f, 2.9550f, 3.2189f, 3.0078f, 1.5830f,
2.7612f, 3.0078f, 2.8106f, 1.4792f, 1.4533f, 1.5830f, 1.4792f, 0.7785f,
2.7128f, 2.9550f, 2.7612f, 1.4533f, 2.9550f, 3.2189f, 3.0078f, 1.5830f,
2.7612f, 3.0078f, 2.8106f, 1.4792f, 1.4533f, 1.5830f, 1.4792f, 0.7785f});
std::vector<MLFloat16> dX_half(dX.size());
ConvertFloatToMLFloat16(dX.data(), dX_half.data(), static_cast<int>(dX.size()));
std::vector<int64_t> dX_shape = X_shape;

test.AddInput<MLFloat16>("dY", dY_shape, dY_half);
test.AddInput<MLFloat16>("X", X_shape, X_half);
test.AddInput<float>("", {0}, {});
test.AddInput<float>("scales", {4}, {1.0f, 1.0f, 1.7f, 1.7f});

test.AddOutput<MLFloat16>("dX", dX_shape, dX_half);

test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers);
}

TEST(ResizeGradTest, ResizeGradWithScalesAndAlignCorners) {
std::vector<std::unique_ptr<IExecutionProvider>> providers;
#ifdef USE_CUDA
providers.emplace_back(DefaultCudaExecutionProvider());
#elif USE_ROCM
providers.emplace_back(DefaultRocmExecutionProvider());
#endif

OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain);

AddResizeGradAttributes(test, "align_corners");

std::vector<float> dY(72, 1.0f);
std::vector<int64_t> dY_shape = {1, 2, 6, 6};

std::vector<float> X(32, 1.0f);
std::vector<int64_t> X_shape = {1, 2, 4, 4};

std::vector<float> dX({1.9600f, 2.2400f, 2.2400f, 1.9600f, 2.2400f, 2.5600f, 2.5600f, 2.2400f,
2.2400f, 2.5600f, 2.5600f, 2.2400f, 1.9600f, 2.2400f, 2.2400f, 1.9600f,
1.9600f, 2.2400f, 2.2400f, 1.9600f, 2.2400f, 2.5600f, 2.5600f, 2.2400f,
2.2400f, 2.5600f, 2.5600f, 2.2400f, 1.9600f, 2.2400f, 2.2400f, 1.9600f});
std::vector<int64_t> dX_shape = X_shape;

test.AddInput<float>("dY", dY_shape, dY);
test.AddInput<float>("X", X_shape, X);
test.AddInput<float>("", {0}, {});
test.AddInput<float>("scales", {4}, {1.0f, 1.0f, 1.7f, 1.7f});

test.AddOutput<float>("dX", dX_shape, dX);

test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers);
}

#endif // defined(USE_CUDA) || defined(USE_ROCM)

} // namespace onnxruntime::test
Loading