Skip to content

Restore alternatives lowering #1116

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 42 additions & 12 deletions src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -998,14 +998,33 @@ void print(Stream &stream, const std::tuple<Args...> &t) {
TuplePrinter<Stream, decltype(t), sizeof...(Args)>::print(stream, t);
}

#if 0
enum PolygeistAlternativesMode { PAM_Static, PAM_PGO_Profile, PAM_PGO_Opt };
enum PolygeistGPUStructureMode {
PGSM_Discard,
PGSM_BlockThreadWrappers,
PGSM_ThreadNoop,
PGSM_BlockThreadNoops
};

const std::string POLYGEIST_PGO_DEFAULT_DATA_DIR = ".polygeist_pgo";

llvm::cl::opt<PolygeistAlternativesMode> PolygeistAlternativesMode(
"polygeist-alternatives-mode", llvm::cl::init(PAM_Static),
llvm::cl::desc("Polygeist alternatives op mode"),
llvm::cl::values(
clEnumValN(PAM_Static, "static", "Pick at compile time"),
clEnumValN(PAM_PGO_Profile, "pgo_prof",
"Profile Guided Optimization - profiling mode"),
clEnumValN(PAM_PGO_Opt, "pgo_opt",
"Profile Guided Optimization - optimization mode")));

struct LowerGPUAlternativesOp
: public OpRewritePattern<polygeist::AlternativesOp>,
: public OpRewritePattern<enzymexla::AlternativesOp>,
public GpuRuntimeCallBuilders {
using OpRewritePattern<polygeist::AlternativesOp>::OpRewritePattern;
using OpRewritePattern<enzymexla::AlternativesOp>::OpRewritePattern;
const char *PATTERN = "lower-gpu-alternatives";

LogicalResult matchAndRewrite(polygeist::AlternativesOp gao,
LogicalResult matchAndRewrite(enzymexla::AlternativesOp gao,
PatternRewriter &rewriter) const override {

if (gao->getAttrOfType<StringAttr>("alternatives.type").getValue() !=
Expand All @@ -1014,7 +1033,7 @@ struct LowerGPUAlternativesOp

Location loc = gao->getLoc();
std::string locStr =
gao->getAttrOfType<StringAttr>("polygeist.altop.id").data();
gao->getAttrOfType<StringAttr>("enzymexla.altop.id").data();

auto descs = gao->getAttrOfType<ArrayAttr>("alternatives.descs");

Expand Down Expand Up @@ -1080,6 +1099,8 @@ struct LowerGPUAlternativesOp
};
};

auto gpuBinaryAnnotation = "gpu.binary";

#if POLYGEIST_ENABLE_CUDA
if (gpuTarget == "cuda") {
char cuErrorBuffer[4096] = {0};
Expand Down Expand Up @@ -1337,7 +1358,7 @@ struct LowerGPUAlternativesOp
nullTermLocStr.push_back('\0');
auto kernelId = LLVM::createGlobalString(
loc, rewriter, std::string("kernelId.") + std::to_string(num++),
nullTermLocStr, LLVM::Linkage::Internal, /*opaquePointers*/ true);
nullTermLocStr, LLVM::Linkage::Internal);
auto totalAlternatives = rewriter.create<LLVM::ConstantOp>(
loc, llvmInt32Type, gao->getNumRegions());
auto alternative =
Expand Down Expand Up @@ -1373,7 +1394,7 @@ struct LowerGPUAlternativesOp
return success();
} else if (PolygeistAlternativesMode == PAM_PGO_Opt) {
std::string dirname = []() {
if (char *d = getenv(POLYGEIST_PGO_DATA_DIR_ENV_VAR)) {
if (char *d = getenv("POLYGEIST_PGO_DATA_DIR")) {
return std::string(d);
} else {
return std::string(POLYGEIST_PGO_DEFAULT_DATA_DIR);
Expand Down Expand Up @@ -1433,15 +1454,13 @@ struct LowerGPUAlternativesOp
}

LowerGPUAlternativesOp(MLIRContext *context, LLVMTypeConverter &typeConverter,
StringRef gpuBinaryAnnotation, StringRef gpuTarget)
: OpRewritePattern<polygeist::AlternativesOp>(context),
StringRef gpuTarget)
: OpRewritePattern<enzymexla::AlternativesOp>(context),
GpuRuntimeCallBuilders(context, typeConverter),
gpuBinaryAnnotation(gpuBinaryAnnotation), gpuTarget(gpuTarget) {}
gpuTarget(gpuTarget) {}

llvm::SmallString<32> gpuBinaryAnnotation;
llvm::SmallString<4> gpuTarget;
};
#endif

// Creates a struct containing all kernel parameters on the stack and returns
// an array of type-erased pointers to the fields of the struct. The array can
Expand Down Expand Up @@ -3118,6 +3137,17 @@ struct ConvertPolygeistToLLVMPass
signalPassFailure();
});
}
{

{
// This op must be lowered before converting to LLVM but it still needs
// information about LLVM types thus it needs the converter
RewritePatternSet patterns(&getContext());
patterns.add<LowerGPUAlternativesOp>(&getContext(), converter,
backend);
if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns))))
signalPassFailure();
}

LLVMConversionTarget target(getContext());
RewritePatternSet patterns(&getContext());
Expand Down
Loading