diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 3edd7888d2f2..5d4a1fa5b65e 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -1695,7 +1695,15 @@ class ConvertAtenConvolutionBackwardOp
     auto weightDTy = cast<RankedTensorType>(weight.getType()).getElementType();
     if (!isa<mlir::FloatType>(gradOutputDTy) ||
         !isa<mlir::FloatType>(inputDTy) || !isa<mlir::FloatType>(weightDTy))
-      return op.emitError("unimplemented: only fp convolution bwd supported");
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only fp convolution bwd supported");
+
+    // TODO: support this.
+    if (!llvm::all_equal({inputDTy, weightDTy, gradOutputDTy}))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: mixed-precision fp types.");
+
+    auto accumulatorDTy = getDefaultAccType(rewriter, inputDTy);
 
     size_t gradRank = cast<RankedTensorType>(gradOutput.getType()).getRank();
     size_t numSpatialDims = gradRank - 2;
@@ -1833,6 +1841,22 @@ class ConvertAtenConvolutionBackwardOp
       return createZeroInitTensor(rewriter, loc, expandedSizes, type);
     };
 
+    auto convertFloatAccDtype = [&](Value accumulator, Type targetDTy) {
+      auto accDTy =
+          cast<RankedTensorType>(accumulator.getType()).getElementType();
+      auto floatAccDTy = dyn_cast<mlir::FloatType>(accDTy);
+      auto floatTargetDTy = dyn_cast<mlir::FloatType>(targetDTy);
+
+      assert(floatAccDTy && "Dtype conversion expects float dtypes only.");
+      assert(floatTargetDTy && "Dtype conversion expects float dtypes only.");
+
+      if (floatAccDTy == floatTargetDTy)
+        return accumulator;
+
+      return torch_to_linalg::convertTensorToElementType(
+          rewriter, loc, accumulator, targetDTy);
+    };
+
     SmallVector<Value> newResults(op->getNumResults());
 
     // Computing Backward-Input Convolution.
@@ -1945,11 +1969,11 @@ class ConvertAtenConvolutionBackwardOp
       // [N, G, C/G, D*] tensor and collapse back to the original input shape.
       SmallVector<ReassociationIndices> gradInputCollapseIndices;
       Value gradInputInit =
-          isGroupedConvBwd
-              ? createZeroInitExpandedGroupsTensor(rewriter, loc,
-                                                   gradInputSizes, inputDTy, 1,
-                                                   gradInputCollapseIndices)
-              : createZeroInitTensor(rewriter, loc, gradInputSizes, inputDTy);
+          isGroupedConvBwd ? createZeroInitExpandedGroupsTensor(
+                                 rewriter, loc, gradInputSizes, accumulatorDTy,
+                                 1, gradInputCollapseIndices)
+                           : createZeroInitTensor(rewriter, loc, gradInputSizes,
+                                                  accumulatorDTy);
 
       // Create convolution for data gradient
       auto convRes = createConvInputGradient(rewriter, loc, context,
                                              numSpatialDims, strideInts,
                                              dilationInts, gradOutputExpanded,
                                              weightExpanded, gradInputInit)
                         .getResult(0);
 
+      auto returnTensorTy = cast<RankedTensorType>(
+          getTypeConverter()->convertType(op->getResult(0).getType()));
+      auto returnDTy = returnTensorTy.getElementType();
+      convRes = convertFloatAccDtype(convRes, returnDTy);
+
       // Collapse [N, G, C/G, D] to [N, C, D] the result of the conv
       // if it is grouped.
       if (isGroupedConvBwd) {
         convRes = tensor::CollapseShapeOp::create(
-            rewriter, loc, input.getType(), convRes, gradInputCollapseIndices);
+            rewriter, loc, returnTensorTy, convRes, gradInputCollapseIndices);
       }
 
       // Cast to the final result type expected by the type converter.
@@ -1998,10 +2027,11 @@ class ConvertAtenConvolutionBackwardOp
       SmallVector<ReassociationIndices> gradWeightCollapseIndices;
       Value gradWeightInit =
           isGroupedConvBwd
-              ? createZeroInitExpandedGroupsTensor(rewriter, loc,
-                                                   gradWeightSizes, weightDTy,
-                                                   0, gradWeightCollapseIndices)
-              : createZeroInitTensor(rewriter, loc, gradWeightSizes, weightDTy);
+              ? createZeroInitExpandedGroupsTensor(
+                    rewriter, loc, gradWeightSizes, accumulatorDTy, 0,
+                    gradWeightCollapseIndices)
+              : createZeroInitTensor(rewriter, loc, gradWeightSizes,
+                                     accumulatorDTy);
 
       // Create convolution for weight gradient
       auto convResult = createConvWeightGradient(
@@ -2010,12 +2040,17 @@ class ConvertAtenConvolutionBackwardOp
           rewriter, loc, context, numSpatialDims, gradOutputViewSizes,
           weightViewSizes, strideInts, dilationInts,
           paddedInput, gradOutputExpanded, gradWeightInit)
           .getResult(0);
 
+      auto returnTensorTy = cast<RankedTensorType>(
+          getTypeConverter()->convertType(op->getResult(1).getType()));
+      auto returnDTy = returnTensorTy.getElementType();
+      convResult = convertFloatAccDtype(convResult, returnDTy);
+
       // Collapse [G, F/G, C/G, D] to [F, C/G, D] the result of the conv
       // if it is grouped.
       if (isGroupedConvBwd) {
-        convResult = tensor::CollapseShapeOp::create(
-            rewriter, loc, weight.getType(), convResult,
-            gradWeightCollapseIndices);
+        convResult = tensor::CollapseShapeOp::create(rewriter, loc,
+                                                     returnTensorTy, convResult,
+                                                     gradWeightCollapseIndices);
       }
 
       // Cast to the final result type expected by the type converter.
@@ -2038,10 +2073,12 @@ class ConvertAtenConvolutionBackwardOp
 
       // Zero init for the element type (arith.constant expects a scalar attr).
       Value initSum = arith::ConstantOp::create(
-          rewriter, loc, rewriter.getZeroAttr(gradOutputDTy));
+          rewriter, loc, rewriter.getZeroAttr(accumulatorDTy));
 
       auto reductionBody = [&](OpBuilder &b, Location loc, ValueRange args) {
         Value x = args[0];
+        if (gradOutputDTy != accumulatorDTy)
+          x = arith::ExtFOp::create(b, loc, accumulatorDTy, x);
         Value acc = args[1];
         Value sum = arith::AddFOp::create(b, loc, x, acc);
         linalg::YieldOp::create(b, loc, sum);
@@ -2050,6 +2087,11 @@ class ConvertAtenConvolutionBackwardOp
       Value gradBias = torch_to_linalg::createReductionLinalgGeneric(
           rewriter, loc, opInfo, initSum, reductionBody);
 
+      auto resultType = cast<RankedTensorType>(
+          getTypeConverter()->convertType(op->getResult(2).getType()));
+      auto resultDTy = resultType.getElementType();
+      gradBias = convertFloatAccDtype(gradBias, resultDTy);
+
       newResults[2] = tensor::CastOp::create(rewriter, loc,
                                              getTypeConverter()->convertType(
                                                  op->getResult(2).getType()),
diff --git a/test/Conversion/TorchToLinalg/convolution_bwd.mlir b/test/Conversion/TorchToLinalg/convolution_bwd.mlir
index 513e74f74ec8..63eee1f62f5b 100644
--- a/test/Conversion/TorchToLinalg/convolution_bwd.mlir
+++ b/test/Conversion/TorchToLinalg/convolution_bwd.mlir
@@ -415,3 +415,134 @@ func.func @convolution_backward_input_1x1x1s_1x0x1p_1x1x1d_1g(%arg0: !torch.vten
 }
 
 // -----
+
+// CHECK-LABEL: func.func @convolution_backward_weights_2x2s_2x2p_2x2d_4g_bf16(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,33,33],bf16>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],bf16>,
+// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,32,2,2],bf16>) -> (!torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>) {
+func.func @convolution_backward_weights_2x2s_2x2p_2x2d_4g_bf16(%arg0: !torch.vtensor<[2,16,33,33],bf16>, %arg1: !torch.vtensor<[2,128,64,64],bf16>, %arg2: !torch.vtensor<[16,32,2,2],bf16>) -> (!torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>) {
+  // CHECK-DAG: %[[CST_F32:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: %[[CST_BF16:.*]] = arith.constant 0.000000e+00 : bf16
+  // CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,128,64,64],bf16> -> tensor<2x128x64x64xbf16>
+  // CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,33,33],bf16> -> tensor<2x16x33x33xbf16>
+  // CHECK: %[[T0_EXP:.*]] = tensor.expand_shape %[[T0]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} output_shape [2, 4, 4, 33, 33] : tensor<2x16x33x33xbf16> into tensor<2x4x4x33x33xbf16>
+  // CHECK: %[[T1_EXP:.*]] = tensor.expand_shape %[[T1]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} output_shape [2, 4, 32, 64, 64] : tensor<2x128x64x64xbf16> into tensor<2x4x32x64x64xbf16>
+  // CHECK: %[[PAD:.*]] = tensor.pad %[[T1_EXP]] low[0, 0, 0, 2, 2] high[0, 0, 0, 2, 2]
+  // CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
+  // CHECK-NEXT: tensor.yield %[[CST_BF16]] : bf16
+  // CHECK-NEXT: } : tensor<2x4x32x64x64xbf16> to tensor<2x4x32x68x68xbf16>
+  // CHECK: %[[OUT0_EMPTY:.*]] = tensor.empty() : tensor<4x4x32x2x2xf32>
+  // CHECK: %[[OUT0_FILLED:.*]] = linalg.fill ins(%[[CST_F32]] : f32) outs(%[[OUT0_EMPTY]] : tensor<4x4x32x2x2xf32>) -> tensor<4x4x32x2x2xf32>
+  // CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d0, d2, d3 * 2 + d6 * 2, d4 * 2 + d7 * 2)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d0, d1, d6, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PAD]], %[[T0_EXP]] : tensor<2x4x32x68x68xbf16>, tensor<2x4x4x33x33xbf16>) outs(%[[OUT0_FILLED]] : tensor<4x4x32x2x2xf32>) {
+  // CHECK-NEXT: ^bb0(%[[IN:.*]]: bf16, %[[IN1:.*]]: bf16, %[[OUT:.*]]: f32):
+  // CHECK-NEXT: %[[EXT0:.*]] = arith.extf %[[IN]] : bf16 to f32
+  // CHECK-NEXT: %[[EXT1:.*]] = arith.extf %[[IN1]] : bf16 to f32
+  // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[EXT0]], %[[EXT1]] : f32
+  // CHECK-NEXT: %[[CONV_RES:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32
+  // CHECK-NEXT: linalg.yield %[[CONV_RES]] : f32
+  // CHECK-NEXT: } -> tensor<4x4x32x2x2xf32>
+  // CHECK: %[[DOWNCAST0:.*]] = linalg.generic
+  // CHECK-SAME: {indexing_maps = [
+  // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>,
+  // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>],
+  // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%[[CONV]] : tensor<4x4x32x2x2xf32>) outs(%[[ZERO_BF16_INIT:.*]] : tensor<4x4x32x2x2xbf16>) {
+  // CHECK-NEXT: ^bb0(%[[IN_BBARG:.*]]: f32, %[[OUT_BBARG:.*]]: bf16):
+  // CHECK-NEXT: %[[TRUNC:.*]] = arith.truncf %[[IN_BBARG]] : f32 to bf16
+  // CHECK-NEXT: linalg.yield %[[TRUNC]] : bf16
+  // CHECK-NEXT: } -> tensor<4x4x32x2x2xbf16>
+  // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[DOWNCAST0]] {{\[\[0, 1\], \[2\], \[3\], \[4\]\]}} : tensor<4x4x32x2x2xbf16> into tensor<16x32x2x2xbf16>
+  // CHECK: %[[WGRAD:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor<16x32x2x2xbf16> -> !torch.vtensor<[16,32,2,2],bf16>
+  // CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<16xf32>
+  // CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST_F32]] : f32) outs(%[[SUM_EMPTY]] : tensor<16xf32>) -> tensor<16xf32>
+  // CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x16x33x33xbf16>) outs(%[[SUM_FILLED]] : tensor<16xf32>) {
+  // CHECK-NEXT: ^bb0(%[[IN_B:.*]]: bf16, %[[ACC_B:.*]]: f32):
+  // CHECK-NEXT: %[[B_EXT:.*]] = arith.extf %[[IN_B]] : bf16 to f32
+  // CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[B_EXT]], %[[ACC_B]] : f32
+  // CHECK-NEXT: linalg.yield %[[B_RES]] : f32
+  // CHECK-NEXT: } -> tensor<16xf32>
+  // CHECK: %[[DOWNCAST1:.*]] = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+  // CHECK-SAME: iterator_types = ["parallel"]} ins(%[[SUM_GEN]] : tensor<16xf32>) outs(%[[ZERO_BF16_INIT_1:.*]] : tensor<16xbf16>) {
+  // CHECK-NEXT: ^bb0(%[[IN_BBARG:.*]]: f32, %[[OUT_BBARG:.*]]: bf16):
+  // CHECK-NEXT: %[[TRUNC:.*]] = arith.truncf %[[IN_BBARG]] : f32 to bf16
+  // CHECK-NEXT: linalg.yield %[[TRUNC]] : bf16
+  // CHECK-NEXT: } -> tensor<16xbf16>
+  // CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[DOWNCAST1]] : tensor<16xbf16> -> !torch.vtensor<[16],bf16>
+  // CHECK: return %[[WGRAD]], %[[BIAS]] : !torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>
+  %true = torch.constant.bool true
+  %int0 = torch.constant.int 0
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int4 = torch.constant.int 4
+  %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %false, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
+  %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %1, %1, %false, %2, %int4, %3 : !torch.vtensor<[2,16,33,33],bf16>, !torch.vtensor<[2,128,64,64],bf16>, !torch.vtensor<[16,32,2,2],bf16>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.list<bool> -> !torch.none, !torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>
+  return %result1, %result2 : !torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @convolution_backward_input_2x2s_2x2p_2x2d_4g_bf16(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,33,33],bf16>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],bf16>,
+// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,32,2,2],bf16>) -> !torch.vtensor<[2,128,64,64],bf16> {
+func.func @convolution_backward_input_2x2s_2x2p_2x2d_4g_bf16(%arg0: !torch.vtensor<[2,16,33,33],bf16>, %arg1: !torch.vtensor<[2,128,64,64],bf16>, %arg2: !torch.vtensor<[16,32,2,2],bf16>) -> !torch.vtensor<[2,128,64,64],bf16> {
+  // CHECK: %[[CST1:.*]] = arith.constant 1 : index
+  // CHECK: %[[CST0_F32:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[CST0_BF16:.*]] = arith.constant 0.000000e+00 : bf16
+  // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[16,32,2,2],bf16> -> tensor<16x32x2x2xbf16>
+  // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,33,33],bf16> -> tensor<2x16x33x33xbf16>
+  // CHECK: %[[T0_EXP:.*]] = tensor.expand_shape %[[T0]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} output_shape [2, 4, 4, 33, 33] : tensor<2x16x33x33xbf16> into tensor<2x4x4x33x33xbf16>
+  // CHECK: %[[W_EXP:.*]] = tensor.expand_shape %[[T1]] {{\[\[0, 1\], \[2\], \[3\], \[4\]\]}} output_shape [4, 4, 32, 2, 2] : tensor<16x32x2x2xbf16> into tensor<4x4x32x2x2xbf16>
+  // CHECK: %[[W_EMPTY:.*]] = tensor.empty() : tensor<4x4x32x2x2xbf16>
+  // CHECK: %[[W_FILLED:.*]] = linalg.fill ins(%[[CST0_BF16]] : bf16) outs(%[[W_EMPTY]] : tensor<4x4x32x2x2xbf16>) -> tensor<4x4x32x2x2xbf16>
+  // CHECK: %[[W_REV:.*]] = linalg.generic {{.*}} ins(%[[W_EXP]] : tensor<4x4x32x2x2xbf16>) outs(%[[W_FILLED]] : tensor<4x4x32x2x2xbf16>) {
+  // CHECK-NEXT: ^bb0(%[[IN_W:.*]]: bf16, %[[OUT_W:.*]]: bf16):
+  // CHECK-NEXT: %[[I0:.*]] = linalg.index 0 : index
+  // CHECK-NEXT: %[[I1:.*]] = linalg.index 1 : index
+  // CHECK-NEXT: %[[I2:.*]] = linalg.index 2 : index
+  // CHECK-NEXT: %[[I3:.*]] = linalg.index 3 : index
+  // CHECK-NEXT: %[[I4:.*]] = linalg.index 4 : index
+  // CHECK-NEXT: %[[R3:.*]] = arith.subi %[[CST1]], %[[I3]] : index
+  // CHECK-NEXT: %[[R4:.*]] = arith.subi %[[CST1]], %[[I4]] : index
+  // CHECK-NEXT: %[[EX:.*]] = tensor.extract %[[W_EXP]][%[[I0]], %[[I1]], %[[I2]], %[[R3]], %[[R4]]] : tensor<4x4x32x2x2xbf16>
+  // CHECK-NEXT: linalg.yield %[[EX]] : bf16
+  // CHECK-NEXT: } -> tensor<4x4x32x2x2xbf16>
+  // CHECK: %[[SLICE_EMPTY:.*]] = tensor.empty() : tensor<2x4x4x66x66xbf16>
+  // CHECK: %[[SLICE_FILLED:.*]] = linalg.fill ins(%[[CST0_BF16]] : bf16) outs(%[[SLICE_EMPTY]] : tensor<2x4x4x66x66xbf16>) -> tensor<2x4x4x66x66xbf16>
+  // CHECK: %[[SLICE:.*]] = tensor.insert_slice %[[T0_EXP]] into %[[SLICE_FILLED]][0, 0, 0, 0, 0] [2, 4, 4, 33, 33] [1, 1, 1, 2, 2] : tensor<2x4x4x33x33xbf16> into tensor<2x4x4x66x66xbf16>
+  // CHECK: %[[OUT_EMPTY:.*]] = tensor.empty() : tensor<2x4x32x64x64xf32>
+  // CHECK: %[[OUT_FILLED:.*]] = linalg.fill ins(%[[CST0_F32]] : f32) outs(%[[OUT_EMPTY]] : tensor<2x4x32x64x64xf32>) -> tensor<2x4x32x64x64xf32>
+  // CHECK: %[[CONV_F32:.*]] = linalg.generic {{.*}} ins(%[[SLICE]], %[[W_REV]] : tensor<2x4x4x66x66xbf16>, tensor<4x4x32x2x2xbf16>) outs(%[[OUT_FILLED]] : tensor<2x4x32x64x64xf32>) {
+  // CHECK-NEXT: ^bb0(%[[IN:.*]]: bf16, %[[IN1:.*]]: bf16, %[[OUT:.*]]: f32):
+  // CHECK-NEXT: %[[EXT:.*]] = arith.extf %[[IN]] : bf16 to f32
+  // CHECK-NEXT: %[[EXT1:.*]] = arith.extf %[[IN1]] : bf16 to f32
+  // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[EXT]], %[[EXT1]] : f32
+  // CHECK-NEXT: %[[ACC:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32
+  // CHECK-NEXT: linalg.yield %[[ACC]] : f32
+  // CHECK-NEXT: } -> tensor<2x4x32x64x64xf32>
+  // CHECK: %[[EMPTY_BF16:.*]] = tensor.empty() : tensor<2x4x32x64x64xbf16>
+  // CHECK: %[[CONV_BF16:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[CONV_F32]] : tensor<2x4x32x64x64xf32>) outs(%[[EMPTY_BF16]] : tensor<2x4x32x64x64xbf16>) {
+  // CHECK: ^bb0(%[[IN_F32:.*]]: f32, %[[OUT_BF16:.*]]: bf16):
+  // CHECK: %[[TRUNC_BF16:.*]] = arith.truncf %[[IN_F32]] : f32 to bf16
+  // CHECK: linalg.yield %[[TRUNC_BF16]] : bf16
+  // CHECK: } -> tensor<2x4x32x64x64xbf16>
+  // CHECK: %[[CONV_COLLAPSED:.*]] = tensor.collapse_shape %[[CONV_BF16]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} : tensor<2x4x32x64x64xbf16> into tensor<2x128x64x64xbf16>
+  // CHECK: %[[IGRAD:.*]] = torch_c.from_builtin_tensor %[[CONV_COLLAPSED]] : tensor<2x128x64x64xbf16> -> !torch.vtensor<[2,128,64,64],bf16>
+  // CHECK: return %[[IGRAD]] : !torch.vtensor<[2,128,64,64],bf16>
+  %true = torch.constant.bool true
+  %int0 = torch.constant.int 0
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int4 = torch.constant.int 4
+  %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %true, %false, %false : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
+  %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %1, %1, %false, %2, %int4, %3 : !torch.vtensor<[2,16,33,33],bf16>, !torch.vtensor<[2,128,64,64],bf16>, !torch.vtensor<[16,32,2,2],bf16>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.list<bool> -> !torch.vtensor<[2,128,64,64],bf16>, !torch.none, !torch.none
+  return %result0 : !torch.vtensor<[2,128,64,64],bf16>
+}
+
+// -----
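
Note on the lowering pattern (reviewer sketch, not part of the patch): the backward convolutions and the bias reduction now zero-initialize their outputs in accumulatorDTy, the type returned by getDefaultAccType(rewriter, inputDTy), which the new CHECK lines show resolving to f32 for bf16 inputs. Each bf16 operand is widened with arith.extf inside the linalg.generic body, and convertFloatAccDtype truncates the f32 result back to the result element type in a single trailing elementwise generic. Below is a minimal standalone MLIR sketch of that widen-accumulate-truncate shape, mirroring the bias-gradient reduction in the tests; the function name and shapes are hypothetical:

func.func @bias_sum_bf16(%grad: tensor<2x16xbf16>) -> tensor<16xbf16> {
  %zero = arith.constant 0.000000e+00 : f32

  // f32 accumulator, zero-initialized in the wide type.
  %acc_empty = tensor.empty() : tensor<16xf32>
  %acc_init = linalg.fill ins(%zero : f32) outs(%acc_empty : tensor<16xf32>) -> tensor<16xf32>

  // Reduce over the batch dim; widen each bf16 element before adding.
  %sum = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d1)>],
      iterator_types = ["reduction", "parallel"]}
      ins(%grad : tensor<2x16xbf16>) outs(%acc_init : tensor<16xf32>) {
  ^bb0(%in: bf16, %acc: f32):
    %ext = arith.extf %in : bf16 to f32
    %add = arith.addf %ext, %acc : f32
    linalg.yield %add : f32
  } -> tensor<16xf32>

  // Truncate back to bf16 once, after the whole reduction has run in f32.
  %out_empty = tensor.empty() : tensor<16xbf16>
  %res = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]}
      ins(%sum : tensor<16xf32>) outs(%out_empty : tensor<16xbf16>) {
  ^bb0(%v: f32, %out: bf16):
    %t = arith.truncf %v : f32 to bf16
    linalg.yield %t : bf16
  } -> tensor<16xbf16>
  return %res : tensor<16xbf16>
}

Truncating once at the end, rather than inside the reduction body, keeps bf16 rounding error from compounding on every partial sum, which is why the patch routes all init tensors through accumulatorDTy instead of accumulating in the narrow element type.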