Skip to content

Commit c24f3fb

Browse files
seantaltsGoogle-ML-Automation
authored andcommitted
[XLA:CPU] Refactor Intrinsic and use it in all math intrinsics.
Will rename codegen/math to codegen/intrinsic in subsequent CL. PiperOrigin-RevId: 784775357
1 parent ae18ab8 commit c24f3fb

34 files changed

+640
-901
lines changed

xla/codegen/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,9 @@ cc_library(
198198
"@com_google_absl//absl/container:flat_hash_map",
199199
"@com_google_absl//absl/container:flat_hash_set",
200200
"@com_google_absl//absl/log:check",
201+
"@com_google_absl//absl/strings",
201202
"@com_google_absl//absl/strings:string_view",
203+
"@com_google_absl//absl/types:span",
202204
"@llvm-project//llvm:Analysis",
203205
"@llvm-project//llvm:ExecutionEngine",
204206
"@llvm-project//llvm:IPO",

xla/codegen/emitters/transforms/BUILD

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ cc_library(
6262
":atomic_rmw_utils",
6363
":convert_pure_call_ops_pass",
6464
":passes_inc_gen",
65-
":propagate_alias_scopes",
65+
":propagate_alias_scopes", # buildcleaner: keep
6666
":simplify_affine_pass",
6767
":simplify_arith_pass",
6868
"//xla:shape_util",
@@ -75,11 +75,11 @@ cc_library(
7575
"//xla/codegen/emitters:implicit_arith_op_builder",
7676
"//xla/codegen/emitters/ir:xla",
7777
"//xla/codegen/math:erf",
78+
"//xla/codegen/math:exp",
7879
"//xla/codegen/math:fptrunc",
7980
"//xla/codegen/math:intrinsic",
8081
"//xla/codegen/math:log1p",
8182
"//xla/hlo/analysis:indexing_analysis",
82-
"//xla/mlir/utils:type_util",
8383
"//xla/mlir_hlo",
8484
"//xla/mlir_hlo:map_mhlo_to_scalar_op",
8585
"//xla/service/gpu:ir_emission_utils",

xla/codegen/emitters/transforms/lower_xla_math_lib.cc

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ limitations under the License.
3333
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
3434
#include "xla/codegen/emitters/transforms/passes.h"
3535
#include "xla/codegen/math/erf.h"
36+
#include "xla/codegen/math/exp.h"
3637
#include "xla/codegen/math/fptrunc.h"
3738
#include "xla/codegen/math/intrinsic.h"
3839
#include "xla/codegen/math/log1p.h"
39-
#include "xla/mlir/utils/type_util.h"
4040

4141
namespace xla {
4242
namespace emitters {
@@ -46,7 +46,8 @@ namespace emitters {
4646

4747
namespace {
4848

49-
using Intrinsic = ::xla::codegen::Intrinsic;
49+
using Type = ::xla::codegen::intrinsics::Type;
50+
// TODO(talts): Add LowerMathOpPattern based on MathFunction instances.
5051

5152
mlir::func::FuncOp GetOrInsertDeclaration(mlir::PatternRewriter& rewriter,
5253
mlir::ModuleOp& module_op,
@@ -91,8 +92,8 @@ class LowerExpOpPattern : public mlir::OpRewritePattern<mlir::math::ExpOp> {
9192
mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
9293

9394
mlir::func::FuncOp xla_exp_func =
94-
GetOrInsertDeclaration(rewriter, module_op_, "xla.exp.f64",
95-
rewriter.getFunctionType(op_type, op_type));
95+
codegen::intrinsics::Exp::GetOrInsertDeclaration(
96+
rewriter, module_op_, Type::TypeFromIrType(op_type));
9697

9798
// Replace math.exp with call to xla.exp.f64
9899
auto call_op =
@@ -114,14 +115,11 @@ class LowerLog1pPattern : public mlir::OpRewritePattern<mlir::math::Log1pOp> {
114115
mlir::LogicalResult matchAndRewrite(
115116
mlir::math::Log1pOp op, mlir::PatternRewriter& rewriter) const override {
116117
mlir::Type type = op.getType();
117-
PrimitiveType primitive_type = ConvertMlirTypeToPrimitiveType(type);
118118

119119
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
120120

121-
auto log1p_decl = GetOrInsertDeclaration(
122-
rewriter, module_op_,
123-
codegen::math::Log1pFunctionName(1, primitive_type),
124-
rewriter.getFunctionType(type, type));
121+
auto log1p_decl = codegen::intrinsics::Log1p::GetOrInsertDeclaration(
122+
rewriter, module_op_, Type::TypeFromIrType(type));
125123
auto call_op = b.create<mlir::func::CallOp>(log1p_decl, op.getOperand());
126124
rewriter.replaceOp(op, call_op->getResults());
127125
return mlir::success();
@@ -149,9 +147,8 @@ class LowerErfPattern : public mlir::OpRewritePattern<mlir::math::ErfOp> {
149147
mlir::Value input_value =
150148
b.create<mlir::arith::ExtFOp>(f32_type, op.getOperand());
151149

152-
auto erf_decl = GetOrInsertDeclaration(
153-
rewriter, module_op_, codegen::math::ErfFunctionName(1, F32),
154-
rewriter.getFunctionType(f32_type, f32_type));
150+
auto erf_decl = codegen::intrinsics::Erf::GetOrInsertDeclaration(
151+
rewriter, module_op_, Type::TypeFromIrType(f32_type));
155152
auto call_op = b.create<mlir::func::CallOp>(erf_decl, input_value);
156153

157154
mlir::Value f32_result = call_op.getResult(0);
@@ -204,8 +201,11 @@ class LowerTruncF32BF16FPattern
204201

205202
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
206203

207-
auto f32_to_bf16_decl = Intrinsic::FpTrunc::GetOrInsertDeclaration(
208-
rewriter, module_op_, Intrinsic::S(F32), Intrinsic::S(BF16));
204+
Type src_type = Type::S(F32);
205+
Type dst_type = Type::S(BF16);
206+
auto f32_to_bf16_decl =
207+
codegen::intrinsics::FpTrunc::GetOrInsertDeclaration(
208+
rewriter, module_op_, src_type, dst_type);
209209
auto call_op =
210210
b.create<mlir::func::CallOp>(f32_to_bf16_decl, op.getOperand());
211211
rewriter.replaceOp(op, call_op->getResults());

xla/codegen/emitters/transforms/tests/lower_xla_math_lib.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ module {
3636

3737
// CHECK: func @exp_f64_vector
3838
// CHECK-NOT: math.exp %arg0 : vector<4xf64>
39-
// CHECK: @xla.exp.f64
39+
// CHECK: @xla.exp.v4f64
4040

4141
// -----
4242

xla/codegen/math/BUILD

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,14 @@ cc_library(
3131
"//xla:xla_data_proto_cc",
3232
"//xla/mlir/utils:type_util",
3333
"//xla/service/llvm_ir:llvm_util",
34-
"@com_google_absl//absl/functional:overload",
3534
"@com_google_absl//absl/status",
3635
"@com_google_absl//absl/status:statusor",
3736
"@com_google_absl//absl/strings",
38-
"@llvm-project//llvm:Core",
3937
"@llvm-project//llvm:Support",
4038
"@llvm-project//llvm:ir_headers",
4139
"@llvm-project//mlir:FuncDialect",
4240
"@llvm-project//mlir:IR",
41+
"@llvm-project//mlir:Support",
4342
],
4443
)
4544

@@ -61,19 +60,14 @@ cc_library(
6160
hdrs = ["fptrunc.h"],
6261
deps = [
6362
":intrinsic",
64-
"//xla:shape_util",
6563
"//xla:util",
6664
"//xla:xla_data_proto_cc",
67-
"//xla/service/llvm_ir:llvm_util",
6865
"//xla/tsl/platform:errors",
6966
"@com_google_absl//absl/log:check",
7067
"@com_google_absl//absl/status:statusor",
7168
"@com_google_absl//absl/strings",
72-
"@llvm-project//llvm:Core",
7369
"@llvm-project//llvm:Support",
7470
"@llvm-project//llvm:ir_headers",
75-
"@llvm-project//mlir:FuncDialect",
76-
"@llvm-project//mlir:IR",
7771
],
7872
)
7973

@@ -96,6 +90,7 @@ cc_library(
9690
":intrinsic",
9791
"//xla:xla_data_proto_cc",
9892
"@com_google_absl//absl/log:check",
93+
"@com_google_absl//absl/status:statusor",
9994
"@com_google_absl//absl/strings",
10095
"@llvm-project//llvm:Core", # buildcleaner: keep
10196
"@llvm-project//llvm:Support",
@@ -114,7 +109,6 @@ xla_cc_test(
114109
"//xla:xla_data_proto_cc",
115110
"@com_google_googletest//:gtest_main",
116111
"@llvm-project//llvm:JITLink",
117-
"@llvm-project//llvm:Support",
118112
"@llvm-project//llvm:ir_headers",
119113
],
120114
)
@@ -207,6 +201,7 @@ cc_library(
207201
hdrs = ["math_compiler_lib.h"],
208202
deps = [
209203
"@com_google_absl//absl/container:flat_hash_set",
204+
"@com_google_absl//absl/strings",
210205
"@com_google_absl//absl/strings:string_view",
211206
"@llvm-project//llvm:Analysis",
212207
"@llvm-project//llvm:IPO",
@@ -239,9 +234,9 @@ cc_library(
239234
deps = [
240235
":intrinsic",
241236
":ldexp",
242-
"//xla:shape_util",
243237
"//xla:xla_data_proto_cc",
244238
"@com_google_absl//absl/log:check",
239+
"@com_google_absl//absl/status:statusor",
245240
"@com_google_absl//absl/strings",
246241
"@llvm-project//llvm:Analysis",
247242
"@llvm-project//llvm:Core", # buildcleaner: keep
@@ -277,6 +272,7 @@ xla_cc_test(
277272
name = "log1p_test",
278273
srcs = ["log1p_test.cc"],
279274
deps = [
275+
":intrinsic",
280276
":log1p",
281277
":simple_jit_runner",
282278
":test_matchers",
@@ -299,7 +295,6 @@ xla_cc_test(
299295
"//xla:xla_data_proto_cc",
300296
"@com_google_googletest//:gtest_main",
301297
"@llvm-project//llvm:JITLink",
302-
"@llvm-project//llvm:Support",
303298
"@llvm-project//llvm:ir_headers",
304299
],
305300
)
@@ -332,6 +327,7 @@ cc_library(
332327
"//xla:xla_data_proto_cc",
333328
"//xla/service/llvm_ir:llvm_util",
334329
"@com_google_absl//absl/log:check",
330+
"@com_google_absl//absl/status:statusor",
335331
"@com_google_absl//absl/strings",
336332
"@llvm-project//llvm:Support",
337333
"@llvm-project//llvm:ir_headers",
@@ -343,6 +339,7 @@ xla_cc_test(
343339
srcs = ["erf_test.cc"],
344340
deps = [
345341
":erf",
342+
":intrinsic",
346343
":simple_jit_runner",
347344
":test_matchers",
348345
"@com_google_googletest//:gtest_main",
@@ -361,7 +358,6 @@ xla_cc_test(
361358
":test_matchers",
362359
"//xla:shape_util",
363360
"//xla:xla_data_proto_cc",
364-
"//xla/service/llvm_ir:llvm_util",
365361
"@com_google_googletest//:gtest_main",
366362
"@llvm-project//llvm:JITLink",
367363
"@llvm-project//llvm:Support",
@@ -381,10 +377,8 @@ xla_cc_test(
381377
":simple_jit_runner",
382378
"//xla:shape_util",
383379
"//xla:xla_data_proto_cc",
384-
"//xla/service/llvm_ir:llvm_util",
385380
"//xla/tsl/platform:test_benchmark",
386381
"//xla/tsl/platform:test_main",
387-
"@llvm-project//llvm:Support",
388382
"@llvm-project//llvm:ir_headers",
389383
],
390384
)

xla/codegen/math/erf.cc

Lines changed: 6 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,7 @@ limitations under the License.
1515

1616
#include "xla/codegen/math/erf.h"
1717

18-
#include <cstddef>
19-
#include <cstdint>
20-
#include <string>
21-
2218
#include "absl/log/check.h"
23-
#include "absl/strings/str_cat.h"
2419
#include "llvm/IR/Argument.h"
2520
#include "llvm/IR/BasicBlock.h"
2621
#include "llvm/IR/Constants.h"
@@ -34,7 +29,7 @@ limitations under the License.
3429
#include "xla/codegen/math/intrinsic.h"
3530
#include "xla/service/llvm_ir/llvm_util.h"
3631

37-
namespace xla::codegen {
32+
namespace xla::codegen::intrinsics {
3833

3934
// Emits an approximation of erf. The implementation uses the same rational
4035
// interpolant as implemented in Eigen3.
@@ -109,33 +104,9 @@ static llvm::Value* EmitErfF32(llvm::IRBuilderBase* b, llvm::Value* x) {
109104
return result;
110105
}
111106

112-
std::string Intrinsic::Erf::Name(PrimitiveType type) {
113-
return absl::StrCat("xla.erf.", ScalarName(type));
114-
}
115-
116-
std::string Intrinsic::Erf::Name(PrimitiveType type, int64_t vector_width) {
117-
return absl::StrCat(Name(type), ".v", vector_width);
118-
}
119-
120-
llvm::Function* Intrinsic::Erf::GetOrInsertDeclaration(llvm::Module* module,
121-
PrimitiveType type) {
122-
auto* llvm_type = llvm_ir::PrimitiveTypeToIrType(type, module->getContext());
123-
auto* function_type = llvm::FunctionType::get(llvm_type, {llvm_type}, false);
124-
return llvm::cast<llvm::Function>(
125-
module->getOrInsertFunction(Name(type), function_type).getCallee());
126-
}
127-
128-
namespace math {
129-
130-
std::string ErfFunctionName(size_t num_elements, PrimitiveType type) {
131-
if (num_elements > 1) {
132-
return Intrinsic::Erf::Name(type, num_elements);
133-
}
134-
135-
return Intrinsic::Erf::Name(type);
136-
}
137-
138-
llvm::Function* CreateErf(llvm::Module* module, llvm::Type* type) {
107+
absl::StatusOr<llvm::Function*> Erf::CreateDefinition(
108+
llvm::Module* module, const Type intrinsic_type) {
109+
llvm::Type* type = Type::TypeToIrType(intrinsic_type, module->getContext());
139110
CHECK(type != nullptr);
140111
CHECK(type->isFloatTy() ||
141112
(type->isVectorTy() && type->getScalarType()->isFloatTy()))
@@ -149,14 +120,10 @@ llvm::Function* CreateErf(llvm::Module* module, llvm::Type* type) {
149120
num_elements = vec_ty->getElementCount().getKnownMinValue();
150121
}
151122

152-
PrimitiveType primitive_type = llvm_ir::PrimitiveTypeFromIrType(type);
153-
154123
llvm::FunctionType* function_type =
155124
llvm::FunctionType::get(type, {type}, false);
156125
llvm::Function* func = llvm::dyn_cast<llvm::Function>(
157-
module
158-
->getOrInsertFunction(ErfFunctionName(num_elements, primitive_type),
159-
function_type)
126+
module->getOrInsertFunction(Name(intrinsic_type), function_type)
160127
.getCallee());
161128

162129
llvm::Argument* input_value = func->getArg(0);
@@ -170,5 +137,4 @@ llvm::Function* CreateErf(llvm::Module* module, llvm::Type* type) {
170137
return func;
171138
}
172139

173-
} // namespace math
174-
} // namespace xla::codegen
140+
} // namespace xla::codegen::intrinsics

xla/codegen/math/erf.h

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,35 +16,30 @@ limitations under the License.
1616
#ifndef XLA_CODEGEN_MATH_ERF_H_
1717
#define XLA_CODEGEN_MATH_ERF_H_
1818

19-
#include <cstddef>
20-
#include <cstdint>
21-
#include <string>
19+
#include <vector>
2220

21+
#include "absl/status/statusor.h"
22+
#include "absl/strings/string_view.h"
2323
#include "llvm/IR/Function.h"
2424
#include "xla/codegen/math/intrinsic.h"
2525
#include "xla/xla_data.pb.h"
2626

27-
namespace xla::codegen {
27+
namespace xla::codegen::intrinsics {
2828

29-
class Intrinsic::Erf {
29+
class Erf : public Intrinsic<Erf> {
3030
public:
31-
static std::string Name(PrimitiveType type);
32-
static std::string Name(PrimitiveType type, int64_t vector_width);
33-
34-
static llvm::Function* GetOrInsertDeclaration(llvm::Module* module,
35-
PrimitiveType type);
31+
static constexpr absl::string_view kName = "erf";
32+
static std::vector<std::vector<Type>> SupportedVectorTypes() {
33+
return {{Type::S(F32)},
34+
{Type::V(F32, 2)},
35+
{Type::V(F32, 4)},
36+
{Type::V(F32, 8)}};
37+
}
38+
39+
static absl::StatusOr<llvm::Function*> CreateDefinition(llvm::Module* module,
40+
Type type);
3641
};
3742

38-
namespace math {
39-
40-
// Return the XLA intrinsic name for the erf function:
41-
//
42-
// `xla.erf.v<num_elements><type>`
43-
std::string ErfFunctionName(size_t num_elements, PrimitiveType type);
44-
45-
llvm::Function* CreateErf(llvm::Module* module, llvm::Type* type);
46-
47-
} // namespace math
48-
} // namespace xla::codegen
43+
} // namespace xla::codegen::intrinsics
4944

5045
#endif // XLA_CODEGEN_MATH_ERF_H_

0 commit comments

Comments
 (0)