Skip to content

Commit 0f3860d

Browse files
authored
[CINN] Arange supports partially symbolic inputs (PaddlePaddle#74209)
* [CINN] Arange supports partially symbolic shape input * [CINN] Fixed arange symbolic input support * [CINN] Fixed arange CI approval
1 parent 8830a97 commit 0f3860d

File tree

16 files changed

+411
-99
lines changed

16 files changed

+411
-99
lines changed

paddle/cinn/backends/codegen_device_util.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ struct PredicatePrinter : public ir::IrPrinter {
134134
void Visit(const ir::Or *x) { PrintBinaryOp("OR", x); }
135135
void Visit(const ir::Max *x) { PrintBinaryOp("MAX", x); }
136136
void Visit(const ir::Min *x) { PrintBinaryOp("MIN", x); }
137+
void Visit(const ir::Call *x) { PrintCallOp(x); }
137138

138139
template <typename IRN>
139140
void PrintBinaryOp(const std::string &op, const ir::BinaryOpNode<IRN> *x) {
@@ -143,6 +144,27 @@ struct PredicatePrinter : public ir::IrPrinter {
143144
ir::IrPrinter::Visit(x->b());
144145
str_ += "_BPA_";
145146
}
147+
148+
void PrintCallOp(const ir::Call *x) {
149+
str_ += "_BCALL_";
150+
str_ += [&]() {
151+
std::string temp = x->name;
152+
std::transform(
153+
temp.begin(), temp.end(), temp.begin(), [](unsigned char c) {
154+
return std::toupper(c);
155+
});
156+
return temp;
157+
}();
158+
if (!x->read_args.empty()) {
159+
str_ += "_R_";
160+
for (const auto &v : x->read_args) ir::IrPrinter::Visit(v);
161+
}
162+
if (!x->write_args.empty()) {
163+
str_ += "_W_";
164+
for (const auto &v : x->write_args) ir::IrPrinter::Visit(v);
165+
}
166+
str_ += "_ECALL_";
167+
}
146168
};
147169

148170
std::string Predicate2String(ir::Expr predicate) {

paddle/cinn/backends/llvm/codegen_llvm.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -838,6 +838,14 @@ llvm::Value *CodeGenLLVM::Visit(const ir::_Dim_ *) {
838838
CINN_NOT_IMPLEMENTED return nullptr;
839839
}
840840

841+
llvm::Function *CallHostFallBack(const llvm::Module *m, const ir::Call *op) {
842+
std::string fallback_func_name =
843+
"cinn_host_" + op->name + "_" + common::Type2Str(op->type());
844+
VLOG(6) << "Warn: host side has no func named '" << op->name
845+
<< "', trying a fallback version '" << fallback_func_name << "'";
846+
return m->getFunction(fallback_func_name);
847+
}
848+
841849
llvm::Value *CodeGenLLVM::Visit(const ir::Call *op) {
842850
if (op->name == runtime::intrinsic::debug_log_repr) {
843851
return EmitCall_debug_info(op);
@@ -854,6 +862,9 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Call *op) {
854862
}
855863

856864
llvm::Function *callee = m_->getFunction(op->name);
865+
if (!callee) {
866+
callee = CallHostFallBack(m_, op);
867+
}
857868
CHECK(callee) << "Unknown function referenced. [" << op->name << "]";
858869

859870
std::vector<llvm::Value *> args;

paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ void ApplyCinnPreprocessPass(
172172
if (has_dynamic_shape) {
173173
pass_manager->AddPass(
174174
cinn::dialect::ir::CreateFuseShapeOpsIntoGenerateShapeOpPass());
175+
pass_manager->AddPass(
176+
cinn::dialect::ir::CreatePdOpToDynamicShapeCinnOpPass());
175177
pass_manager->AddPass(pir::CreateDeadCodeEliminationPass());
176178
}
177179
pass_manager->Run(program);

paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc

Lines changed: 107 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -514,18 +514,68 @@ class SliceOpPattern : public pir::OpRewritePattern<paddle::dialect::SliceOp> {
514514
}
515515
};
516516

517+
/**
518+
* CINN ArangeOp supports two kinds of input:
519+
* input from pd_op.full (static) and input from cinn_op.generate_shape
520+
* An example for the latter:
521+
* ```c++
522+
* x = paddle.zeros([3, 10])
523+
* batch_size = paddle.shape(x)[1]
524+
* stop = batch_size * 2
525+
* paddle.arange(
526+
* 0, // static start (from pd_op.full)
527+
* stop, // symbolic stop (from cinn_op.generate_shape)
528+
* 2 // static step (from pd_op.full)
529+
* )
530+
* ``` Note that step is not allowed to be symbolic, and when
531+
* the inputs are symbolic, the start and end must be of integer type
532+
*/
517533
class ArangeOpPattern
518534
: public pir::OpRewritePattern<paddle::dialect::ArangeOp> {
519535
public:
520536
using pir::OpRewritePattern<paddle::dialect::ArangeOp>::OpRewritePattern;
521537

522538
bool Match(paddle::dialect::ArangeOp op) const override {
523-
// ArangeOp for CINN must have static start, end, step to calculate
524-
// the shape of output tensor. Otherwise, it will be denied
525-
// due to CauseNewSymbolicShape returning false
526539
bool is_denied = CompatibleInfo::IsDeniedForCinn(*op.operation());
527-
return !is_denied && IsDefinedBy<FullOp>(op, 0) &&
528-
IsDefinedBy<FullOp>(op, 1) && IsDefinedBy<FullOp>(op, 2);
540+
if (is_denied) return false;
541+
// step is not allowed to be symbolic
542+
if (IsDefinedBy<FullOp>(op, 2)) {
543+
const FullOp full_op = CastDefinedTo<FullOp>(op, 2);
544+
phi::Scalar step = full_op.attribute("value")
545+
.dyn_cast<paddle::dialect::ScalarAttribute>()
546+
.data();
547+
bool positive_step = true;
548+
#define MATCH_TYPE_TEST(TypeEnum, Dtype) \
549+
case phi::DataType::TypeEnum: \
550+
positive_step = step.to<Dtype>() > 0; \
551+
break;
552+
553+
switch (step.dtype()) {
554+
MATCH_TYPE_TEST(FLOAT32, float)
555+
MATCH_TYPE_TEST(FLOAT64, double)
556+
MATCH_TYPE_TEST(INT32, int)
557+
MATCH_TYPE_TEST(INT64, int64_t)
558+
MATCH_TYPE_TEST(FLOAT16, float)
559+
MATCH_TYPE_TEST(BFLOAT16, float)
560+
#undef MATCH_TYPE_TEST
561+
default:
562+
positive_step = false;
563+
}
564+
if (positive_step) {
565+
const auto &dtype = op.attributes()
566+
.at("dtype")
567+
.dyn_cast<paddle::dialect::DataTypeAttribute>()
568+
.data();
569+
return (IsDefinedBy<FullOp>(op, 0) ||
570+
IsDefinedBy<GenerateShapeOp>(op, 0)) &&
571+
(IsDefinedBy<FullOp>(op, 1) ||
572+
IsDefinedBy<GenerateShapeOp>(op, 1)) &&
573+
(dtype == phi::DataType::INT32 || dtype == phi::DataType::INT64);
574+
} else {
575+
return IsDefinedBy<FullOp>(op, 0) && IsDefinedBy<FullOp>(op, 1);
576+
}
577+
}
578+
return false;
529579
}
530580

531581
void Rewrite(paddle::dialect::ArangeOp op,
@@ -537,31 +587,39 @@ class ArangeOpPattern
537587

538588
std::array<phi::Scalar, 3> input_list;
539589
for (int i = 0; i < 3; i++) {
540-
const FullOp full_op = CastDefinedTo<FullOp>(op, i);
541-
phi::Scalar input = full_op.attribute("value")
542-
.dyn_cast<paddle::dialect::ScalarAttribute>()
543-
.data();
544-
if (input.dtype() != dtype) {
545-
// FullOp creates a tensor (scalar) with fp64 type by default
546-
// therefore, we might need to perform type casting
547-
switch (dtype) {
548-
case phi::DataType::FLOAT32:
549-
input = phi::Scalar(input.to<float>());
550-
break;
551-
case phi::DataType::FLOAT64:
552-
input = phi::Scalar(input.to<double>());
553-
break;
554-
case phi::DataType::INT32:
555-
input = phi::Scalar(input.to<int>());
556-
break;
557-
case phi::DataType::FLOAT16:
558-
input = phi::Scalar(input.to<float>());
559-
break;
560-
case phi::DataType::BFLOAT16:
561-
input = phi::Scalar(input.to<float>());
562-
break;
563-
default:
564-
input = phi::Scalar(input.to<int64_t>());
590+
phi::Scalar input;
591+
if (IsDefinedBy<GenerateShapeOp>(op, i)) {
592+
// arange does not support bool, so if the input is boolean, this would
593+
// mean that there is dynamic shape
594+
input = phi::Scalar(false);
595+
input.SetFromTensor(true);
596+
} else {
597+
const FullOp full_op = CastDefinedTo<FullOp>(op, i);
598+
input = full_op.attribute("value")
599+
.dyn_cast<paddle::dialect::ScalarAttribute>()
600+
.data();
601+
if (input.dtype() != dtype) {
602+
// FullOp creates a tensor (scalar) with fp64 type by default
603+
// therefore, we might need to perform type casting
604+
switch (dtype) {
605+
case phi::DataType::FLOAT32:
606+
input = phi::Scalar(input.to<float>());
607+
break;
608+
case phi::DataType::FLOAT64:
609+
input = phi::Scalar(input.to<double>());
610+
break;
611+
case phi::DataType::INT32:
612+
input = phi::Scalar(input.to<int>());
613+
break;
614+
case phi::DataType::FLOAT16:
615+
input = phi::Scalar(input.to<float>());
616+
break;
617+
case phi::DataType::BFLOAT16:
618+
input = phi::Scalar(input.to<float>());
619+
break;
620+
default:
621+
input = phi::Scalar(input.to<int64_t>());
622+
}
565623
}
566624
}
567625
input_list[i] = input;
@@ -1436,6 +1494,7 @@ pir::RewritePatternSet PdOpToCinnOpPass::InitializePatterns(
14361494
ps.Add<
14371495
ArgMinMaxOpPattern<paddle::dialect::ArgmaxOp, cinn::dialect::ArgmaxOp>>(
14381496
context);
1497+
// Arange in this pass only handles static inputs
14391498
ps.Add<ArangeOpPattern>(context);
14401499
ps.Add<ProdOpPattern>(context);
14411500
ps.Add<ReshapeOpPattern>(context);
@@ -1469,6 +1528,24 @@ std::unique_ptr<pir::Pass> CreatePdOpToCinnOpPass() {
14691528
return std::make_unique<PdOpToCinnOpPass>();
14701529
}
14711530

1531+
PdOpToDynamicShapeCinnOpPass::PdOpToDynamicShapeCinnOpPass()
1532+
: pir::PatternRewritePass("pd_to_dyn_shape_cinn_pass", 1) {}
1533+
1534+
pir::RewritePatternSet PdOpToDynamicShapeCinnOpPass::InitializePatterns(
1535+
pir::IrContext *context) {
1536+
pir::RewritePatternSet ps(context);
1537+
ps.Add<ArangeOpPattern>(context);
1538+
return ps;
1539+
}
1540+
1541+
bool PdOpToDynamicShapeCinnOpPass::CanApplyOn(pir::Operation *op) const {
1542+
return op->num_regions() > 0;
1543+
}
1544+
1545+
std::unique_ptr<pir::Pass> CreatePdOpToDynamicShapeCinnOpPass() {
1546+
return std::make_unique<PdOpToDynamicShapeCinnOpPass>();
1547+
}
1548+
14721549
} // namespace ir
14731550
} // namespace dialect
14741551
} // namespace cinn

paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,17 @@ class PdOpToCinnOpPass : public pir::PatternRewritePass {
3131
bool CanApplyOn(pir::Operation *op) const override;
3232
};
3333

34+
class PdOpToDynamicShapeCinnOpPass : public pir::PatternRewritePass {
35+
public:
36+
PdOpToDynamicShapeCinnOpPass();
37+
38+
pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override;
39+
40+
bool CanApplyOn(pir::Operation *op) const override;
41+
};
42+
3443
IR_API std::unique_ptr<pir::Pass> CreatePdOpToCinnOpPass();
44+
IR_API std::unique_ptr<pir::Pass> CreatePdOpToDynamicShapeCinnOpPass();
3545

3646
} // namespace ir
3747
} // namespace dialect

paddle/cinn/hlir/framework/pir/op_lowering_impl.cc

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "paddle/cinn/common/dim_expr_converter.h"
2323
#include "paddle/cinn/common/shape_constraint.h"
2424
#include "paddle/cinn/common/target.h"
25+
#include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h"
2526
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
2627
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h"
2728
#include "paddle/cinn/hlir/framework/compile_error.h"
@@ -666,6 +667,79 @@ std::vector<ir::LoweredFunc> OpLowererImpl::DoOpLower(
666667
return funcs;
667668
}
668669

670+
/**
671+
* This function converts pir::Value::defining_op into ir::Tensor::operation
672+
* Normally, ir::Tensor::operation will only be used to record the name
673+
* of the compiler-generated var name, which is useless. However, operation
674+
* has an Attributes field, so it can be used to record the op info.
675+
*/
676+
ir::PlaceholderOp* TensorOperationRecording(const ::pir::Value& value) {
677+
// TODO(heqianyue): I think this is kinda ugly, since we should manually
678+
// specify the rules to convert all the op (and their attribute), yet current
679+
// implementation works and can be quickly written.
680+
const ::pir::Operation* define_op = value.defining_op();
681+
ir::PlaceholderOp* res = nullptr;
682+
if (!define_op) return res;
683+
res = cinn::common::make_shared<ir::PlaceholderOp>();
684+
res->name = define_op->name();
685+
// we filter some of the ops, and only record the **needed** attributes
686+
if (define_op->name() == "pd_op.full") {
687+
auto dtype = define_op->attribute("dtype")
688+
.dyn_cast<paddle::dialect::DataTypeAttribute>()
689+
.data();
690+
phi::Scalar data = define_op->attribute("value")
691+
.dyn_cast<paddle::dialect::ScalarAttribute>()
692+
.data();
693+
ir::Expr value;
694+
#define DEFINE_CASE(TypeFlag, Type) \
695+
case phi::DataType::TypeFlag: \
696+
value = ir::Expr(data.to<Type>()); \
697+
break;
698+
switch (dtype) {
699+
DEFINE_CASE(FLOAT32, float)
700+
DEFINE_CASE(FLOAT64, double)
701+
DEFINE_CASE(INT32, int)
702+
DEFINE_CASE(BFLOAT16, float)
703+
value->set_type(cinn::common::BFloat16());
704+
break;
705+
DEFINE_CASE(FLOAT16, float)
706+
value->set_type(cinn::common::Float16());
707+
break;
708+
default:
709+
value = ir::Expr(data.to<int64_t>());
710+
}
711+
#undef DEFINE_CASE
712+
res->attrs.emplace("value", value);
713+
} else if (define_op->name() == "cinn_op.generate_shape") {
714+
// pir::Attribute --> symbol::DimExpr --> ir::Expr
715+
716+
auto ir_dim_expr = [&]() {
717+
auto dim_expr_attr = define_op->attribute("output_dim_exprs");
718+
auto dim_exprs = dialect::ConvertAttributeToDimExprs(dim_expr_attr);
719+
720+
PADDLE_ENFORCE_EQ(
721+
dim_exprs.has_value(),
722+
true,
723+
::common::errors::PreconditionNotMet(
724+
"Required success to execute convert attribute to dim exprs."));
725+
726+
auto expr_vec = dim_exprs.value();
727+
PADDLE_ENFORCE_EQ(
728+
expr_vec.empty(),
729+
false,
730+
::common::errors::PreconditionNotMet(
731+
"Generate shape op can not yield empty symbolic shape."));
732+
// only the first dim_expr matters for ArangeOp
733+
return common::DimExprConverter().ConvertToIrExpr(expr_vec[0]);
734+
}();
735+
res->attrs.emplace("value", ir_dim_expr);
736+
} else {
737+
VLOG(6) << "Tensor defining op recording: not currently supported op.";
738+
return nullptr;
739+
}
740+
return res;
741+
}
742+
669743
ir::Tensor OpLowererImpl::GetTensor(const OpLoweringGroupPtr& group,
670744
const ::pir::Value& value) {
671745
auto type_info = value.type().dyn_cast<paddle::dialect::DenseTensorType>();
@@ -704,6 +778,9 @@ ir::Tensor OpLowererImpl::GetTensor(const OpLoweringGroupPtr& group,
704778
tensor->set_value(*tensor_value);
705779
}
706780
}
781+
if (auto op_ptr = TensorOperationRecording(value)) {
782+
tensor->operation = ir::FunctionRef(op_ptr);
783+
}
707784
return tensor;
708785
}
709786

0 commit comments

Comments
 (0)