Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 65 additions & 15 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1695,7 +1695,15 @@ class ConvertAtenConvolutionBackwardOp
auto weightDTy = cast<RankedTensorType>(weight.getType()).getElementType();
if (!isa<mlir::FloatType>(gradOutputDTy) ||
!isa<mlir::FloatType>(inputDTy) || !isa<mlir::FloatType>(weightDTy))
return op.emitError("unimplemented: only fp convolution bwd supported");
return rewriter.notifyMatchFailure(
op, "unimplemented: only fp convolution bwd supported");

// TODO: support this.
if (!llvm::all_equal({inputDTy, weightDTy, gradOutputDTy}))
return rewriter.notifyMatchFailure(
op, "unimplemented: mixed-precision fp types.");

auto accumulatorDTy = getDefaultAccType(rewriter, inputDTy);

size_t gradRank = cast<RankedTensorType>(gradOutput.getType()).getRank();
size_t numSpatialDims = gradRank - 2;
Expand Down Expand Up @@ -1833,6 +1841,30 @@ class ConvertAtenConvolutionBackwardOp
return createZeroInitTensor(rewriter, loc, expandedSizes, type);
};

auto convertFloatAccDtype = [&](Value accumulator, Type targetDTy) {
auto accDTy =
cast<RankedTensorType>(accumulator.getType()).getElementType();
auto floatAccDTy = dyn_cast<mlir::FloatType>(accDTy);
auto floatTargetDTy = dyn_cast<mlir::FloatType>(targetDTy);

assert(floatAccDTy && "Dtype conversion expects float dtypes only.");
assert(floatTargetDTy && "Dtype conversion expects float dtypes only.");

if (floatAccDTy == floatTargetDTy)
return accumulator;

Value generic = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, {accumulator}, targetDTy,
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
// We don't need to pass src/dst dtypes for floats.
// Update this when supporting integer types.
auto result =
convertScalarToDtype(b, loc, payloadArgs[0], targetDTy);
linalg::YieldOp::create(b, loc, result);
});
return generic;
};

SmallVector<Value> newResults(op->getNumResults());

// Computing Backward-Input Convolution.
Expand Down Expand Up @@ -1945,11 +1977,11 @@ class ConvertAtenConvolutionBackwardOp
// [N, G, C/G, D*] tensor and collapse back to the original input shape.
SmallVector<ReassociationIndices> gradInputCollapseIndices;
Value gradInputInit =
isGroupedConvBwd
? createZeroInitExpandedGroupsTensor(rewriter, loc,
gradInputSizes, inputDTy, 1,
gradInputCollapseIndices)
: createZeroInitTensor(rewriter, loc, gradInputSizes, inputDTy);
isGroupedConvBwd ? createZeroInitExpandedGroupsTensor(
rewriter, loc, gradInputSizes, accumulatorDTy,
1, gradInputCollapseIndices)
: createZeroInitTensor(rewriter, loc, gradInputSizes,
accumulatorDTy);

// Create convolution for data gradient
auto convRes = createConvInputGradient(rewriter, loc, context,
Expand All @@ -1958,11 +1990,16 @@ class ConvertAtenConvolutionBackwardOp
weightExpanded, gradInputInit)
.getResult(0);

auto returnTensorTy = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
auto returnDTy = returnTensorTy.getElementType();
convRes = convertFloatAccDtype(convRes, returnDTy);

// Collapse [N, G, C/G, D] to [N, C, D] the result of the conv
// if it is grouped.
if (isGroupedConvBwd) {
convRes = tensor::CollapseShapeOp::create(
rewriter, loc, input.getType(), convRes, gradInputCollapseIndices);
rewriter, loc, returnTensorTy, convRes, gradInputCollapseIndices);
}

// Cast to the final result type expected by the type converter.
Expand Down Expand Up @@ -1998,10 +2035,11 @@ class ConvertAtenConvolutionBackwardOp
SmallVector<ReassociationIndices> gradWeightCollapseIndices;
Value gradWeightInit =
isGroupedConvBwd
? createZeroInitExpandedGroupsTensor(rewriter, loc,
gradWeightSizes, weightDTy,
0, gradWeightCollapseIndices)
: createZeroInitTensor(rewriter, loc, gradWeightSizes, weightDTy);
? createZeroInitExpandedGroupsTensor(
rewriter, loc, gradWeightSizes, accumulatorDTy, 0,
gradWeightCollapseIndices)
: createZeroInitTensor(rewriter, loc, gradWeightSizes,
accumulatorDTy);

// Create convolution for weight gradient
auto convResult = createConvWeightGradient(
Expand All @@ -2010,12 +2048,17 @@ class ConvertAtenConvolutionBackwardOp
paddedInput, gradOutputExpanded, gradWeightInit)
.getResult(0);

auto returnTensorTy = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(1).getType()));
auto returnDTy = returnTensorTy.getElementType();
convResult = convertFloatAccDtype(convResult, returnDTy);

// Collapse [G, F/G, C/G, D] to [F, C/G, D] the result of the conv
// if it is grouped.
if (isGroupedConvBwd) {
convResult = tensor::CollapseShapeOp::create(
rewriter, loc, weight.getType(), convResult,
gradWeightCollapseIndices);
convResult = tensor::CollapseShapeOp::create(rewriter, loc,
returnTensorTy, convResult,
gradWeightCollapseIndices);
}

// Cast to the final result type expected by the type converter.
Expand All @@ -2038,10 +2081,12 @@ class ConvertAtenConvolutionBackwardOp

// Zero init for the element type (arith.constant expects a scalar attr).
Value initSum = arith::ConstantOp::create(
rewriter, loc, rewriter.getZeroAttr(gradOutputDTy));
rewriter, loc, rewriter.getZeroAttr(accumulatorDTy));

auto reductionBody = [&](OpBuilder &b, Location loc, ValueRange args) {
Value x = args[0];
if (gradOutputDTy != accumulatorDTy)
x = arith::ExtFOp::create(b, loc, accumulatorDTy, x);
Value acc = args[1];
Value sum = arith::AddFOp::create(b, loc, x, acc);
linalg::YieldOp::create(b, loc, sum);
Expand All @@ -2050,6 +2095,11 @@ class ConvertAtenConvolutionBackwardOp
Value gradBias = torch_to_linalg::createReductionLinalgGeneric(
rewriter, loc, opInfo, initSum, reductionBody);

auto resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(2).getType()));
auto resultDTy = resultType.getElementType();
gradBias = convertFloatAccDtype(gradBias, resultDTy);

newResults[2] = tensor::CastOp::create(rewriter, loc,
getTypeConverter()->convertType(
op->getResult(2).getType()),
Expand Down
68 changes: 68 additions & 0 deletions test/Conversion/TorchToLinalg/convolution_bwd.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -415,3 +415,71 @@ func.func @convolution_backward_input_1x1x1s_1x0x1p_1x1x1d_1g(%arg0: !torch.vten
}

// -----

// CHECK-LABEL: func.func @convolution_backward_weights_2x2s_2x2p_2x2d_4g_bf16(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,33,33],bf16>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],bf16>,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,32,2,2],bf16>) -> (!torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>) {
func.func @convolution_backward_weights_2x2s_2x2p_2x2d_4g_bf16(%arg0: !torch.vtensor<[2,16,33,33],bf16>, %arg1: !torch.vtensor<[2,128,64,64],bf16>, %arg2: !torch.vtensor<[16,32,2,2],bf16>) -> (!torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>) {
// CHECK-DAG: %[[CST_F32:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[CST_BF16:.*]] = arith.constant 0.000000e+00 : bf16
// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,128,64,64],bf16> -> tensor<2x128x64x64xbf16>
// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,33,33],bf16> -> tensor<2x16x33x33xbf16>
// CHECK: %[[T0_EXP:.*]] = tensor.expand_shape %[[T0]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} output_shape [2, 4, 4, 33, 33] : tensor<2x16x33x33xbf16> into tensor<2x4x4x33x33xbf16>
// CHECK: %[[T1_EXP:.*]] = tensor.expand_shape %[[T1]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} output_shape [2, 4, 32, 64, 64] : tensor<2x128x64x64xbf16> into tensor<2x4x32x64x64xbf16>
// CHECK: %[[PAD:.*]] = tensor.pad %[[T1_EXP]] low[0, 0, 0, 2, 2] high[0, 0, 0, 2, 2]
// CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
// CHECK-NEXT: tensor.yield %[[CST_BF16]] : bf16
// CHECK-NEXT: } : tensor<2x4x32x64x64xbf16> to tensor<2x4x32x68x68xbf16>
// CHECK: %[[OUT0_EMPTY:.*]] = tensor.empty() : tensor<4x4x32x2x2xf32>
// CHECK: %[[OUT0_FILLED:.*]] = linalg.fill ins(%[[CST_F32]] : f32) outs(%[[OUT0_EMPTY]] : tensor<4x4x32x2x2xf32>) -> tensor<4x4x32x2x2xf32>
// CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d0, d2, d3 * 2 + d6 * 2, d4 * 2 + d7 * 2)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d0, d1, d6, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PAD]], %[[T0_EXP]] : tensor<2x4x32x68x68xbf16>, tensor<2x4x4x33x33xbf16>) outs(%[[OUT0_FILLED]] : tensor<4x4x32x2x2xf32>) {
// CHECK-NEXT: ^bb0(%[[IN:.*]]: bf16, %[[IN1:.*]]: bf16, %[[OUT:.*]]: f32):
// CHECK-NEXT: %[[EXT0:.*]] = arith.extf %[[IN]] : bf16 to f32
// CHECK-NEXT: %[[EXT1:.*]] = arith.extf %[[IN1]] : bf16 to f32
// CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[EXT0]], %[[EXT1]] : f32
// CHECK-NEXT: %[[CONV_RES:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32
// CHECK-NEXT: linalg.yield %[[CONV_RES]] : f32
// CHECK-NEXT: } -> tensor<4x4x32x2x2xf32>
// CHECK: %[[DOWNCAST0:.*]] = linalg.generic
// CHECK-SAME: {indexing_maps = [
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[CONV]] : tensor<4x4x32x2x2xf32>) outs(%[[ZERO_BF16_INIT:.*]] : tensor<4x4x32x2x2xbf16>) {
// CHECK-NEXT: ^bb0(%[[IN_BBARG:.*]]: f32, %[[OUT_BBARG:.*]]: bf16):
// CHECK-NEXT: %[[TRUNC:.*]] = arith.truncf %[[IN_BBARG]] : f32 to bf16
// CHECK-NEXT: linalg.yield %[[TRUNC]] : bf16
// CHECK-NEXT: } -> tensor<4x4x32x2x2xbf16>
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[DOWNCAST0]] {{\[\[0, 1\], \[2\], \[3\], \[4\]\]}} : tensor<4x4x32x2x2xbf16> into tensor<16x32x2x2xbf16>
// CHECK: %[[WGRAD:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor<16x32x2x2xbf16> -> !torch.vtensor<[16,32,2,2],bf16>
// CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<16xf32>
// CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST_F32]] : f32) outs(%[[SUM_EMPTY]] : tensor<16xf32>) -> tensor<16xf32>
// CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x16x33x33xbf16>) outs(%[[SUM_FILLED]] : tensor<16xf32>) {
// CHECK-NEXT: ^bb0(%[[IN_B:.*]]: bf16, %[[ACC_B:.*]]: f32):
// CHECK-NEXT: %[[B_EXT:.*]] = arith.extf %[[IN_B]] : bf16 to f32
// CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[B_EXT]], %[[ACC_B]] : f32
// CHECK-NEXT: linalg.yield %[[B_RES]] : f32
// CHECK-NEXT: } -> tensor<16xf32>
// CHECK: %[[DOWNCAST1:.*]] = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
// CHECK-SAME: iterator_types = ["parallel"]} ins(%[[SUM_GEN]] : tensor<16xf32>) outs(%[[ZERO_BF16_INIT_1:.*]] : tensor<16xbf16>) {
// CHECK-NEXT: ^bb0(%[[IN_BBARG:.*]]: f32, %[[OUT_BBARG:.*]]: bf16):
// CHECK-NEXT: %[[TRUNC:.*]] = arith.truncf %[[IN_BBARG]] : f32 to bf16
// CHECK-NEXT: linalg.yield %[[TRUNC]] : bf16
// CHECK-NEXT: } -> tensor<16xbf16>
// CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[DOWNCAST1]] : tensor<16xbf16> -> !torch.vtensor<[16],bf16>
// CHECK: return %[[WGRAD]], %[[BIAS]] : !torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>
%true = torch.constant.bool true
%int0 = torch.constant.int 0
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int4 = torch.constant.int 4
%0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.prim.ListConstruct %false, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
%result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %1, %1, %false, %2, %int4, %3 : !torch.vtensor<[2,16,33,33],bf16>, !torch.vtensor<[2,128,64,64],bf16>, !torch.vtensor<[16,32,2,2],bf16>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.list<bool> -> !torch.none, !torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>
return %result1, %result2 : !torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>
}

// -----
Loading