Skip to content

Commit 7e49000

Browse files
authored
[OpenMP] fix openmp (rust-lang#243)
1 parent 514f044 commit 7e49000

File tree

12 files changed

+614
-66
lines changed

12 files changed

+614
-66
lines changed

enzyme/Enzyme/ActivityAnalysis.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,11 @@ std::set<std::string> KnownInactiveFunctions = {
112112
"__kmpc_dispatch_fini_4u",
113113
"__kmpc_dispatch_fini_8",
114114
"__kmpc_dispatch_fini_8u",
115+
"__kmpc_barrier",
116+
"__kmpc_barrier_master",
117+
"__kmpc_barrier_master_nowait",
118+
"__kmpc_barrier_end_barrier_master",
119+
"omp_get_max_threads",
115120
"malloc_usable_size",
116121
"malloc_size",
117122
"MPI_Init",

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 207 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2768,6 +2768,9 @@ class AdjointGenerator
27682768
std::vector<Instruction *> postCreate;
27692769
std::vector<Instruction *> userReplace;
27702770

2771+
SmallVector<Value *, 0> OutTypes;
2772+
SmallVector<Type *, 0> OutFPTypes;
2773+
27712774
for (unsigned i = 3; i < call.getNumArgOperands(); ++i) {
27722775

27732776
auto argi = gutils->getNewFromOriginal(call.getArgOperand(i));
@@ -2820,7 +2823,9 @@ class AdjointGenerator
28202823
assert(whatType(argType, Mode) == DIFFE_TYPE::DUP_ARG ||
28212824
whatType(argType, Mode) == DIFFE_TYPE::CONSTANT);
28222825
} else {
2823-
assert(0 && "out for omp not handled");
2826+
assert(TR.query(call.getArgOperand(i)).Inner0().isFloat());
2827+
OutTypes.push_back(call.getArgOperand(i));
2828+
OutFPTypes.push_back(argType);
28242829
argsInverted.push_back(DIFFE_TYPE::OUT_DIFF);
28252830
assert(whatType(argType, Mode) == DIFFE_TYPE::OUT_DIFF ||
28262831
whatType(argType, Mode) == DIFFE_TYPE::CONSTANT);
@@ -2905,7 +2910,6 @@ class AdjointGenerator
29052910
subdata->returns.end()) {
29062911
ValueToValueMapTy VMap;
29072912
newcalled = CloneFunction(newcalled, VMap);
2908-
llvm::errs() << *newcalled << "\n";
29092913
auto tapeArg = newcalled->arg_end();
29102914
tapeArg--;
29112915
std::vector<std::pair<ssize_t, Value *>> geps;
@@ -2926,11 +2930,14 @@ class AdjointGenerator
29262930
for (auto SI : storesToErase)
29272931
SI->eraseFromParent();
29282932
gepsToErase.insert(gep);
2929-
}
2930-
if (auto SI = dyn_cast<StoreInst>(a)) {
2933+
} else if (auto SI = dyn_cast<StoreInst>(a)) {
29312934
Value *op = SI->getValueOperand();
29322935
gepsToErase.insert(SI);
29332936
geps.emplace_back(-1, op);
2937+
} else {
2938+
llvm::errs() << "unknown tape user: " << a << "\n";
2939+
assert(0 && "unknown tape user");
2940+
llvm_unreachable("unknown tape user");
29342941
}
29352942
}
29362943
for (auto gep : gepsToErase)
@@ -2988,6 +2995,7 @@ class AdjointGenerator
29882995
: ph.CreateInBoundsGEP(tapeArg, Idxs)));
29892996
cast<Instruction>(op)->eraseFromParent();
29902997
}
2998+
assert(tape);
29912999
auto alloc =
29923000
IRBuilder<>(gutils->inversionAllocs)
29933001
.CreateAlloca(
@@ -3025,7 +3033,6 @@ class AdjointGenerator
30253033

30263034
auto found = subdata->returns.find(AugmentedStruct::DifferentialReturn);
30273035
assert(found == subdata->returns.end());
3028-
;
30293036

30303037
found = subdata->returns.find(AugmentedStruct::Return);
30313038
assert(found == subdata->returns.end());
@@ -3035,14 +3042,22 @@ class AdjointGenerator
30353042
IRBuilder<> Builder2(call.getParent());
30363043
getReverseBuilder(Builder2);
30373044

3038-
Value *newcalled = nullptr;
3045+
if (Mode == DerivativeMode::ReverseModeGradient)
3046+
eraseIfUnused(call, /*erase*/ true, /*check*/ false);
3047+
3048+
Function *newcalled = nullptr;
30393049
if (called) {
30403050
if (subdata->returns.find(AugmentedStruct::Tape) !=
30413051
subdata->returns.end()) {
30423052
if (Mode == DerivativeMode::ReverseModeGradient) {
3053+
if (tape == nullptr)
3054+
tape = Builder2.CreatePHI(Type::getInt8Ty(call.getContext()), 0,
3055+
"tapeArg");
30433056
tape = gutils->cacheForReverse(Builder2, tape,
3044-
getIndex(&call, CacheType::Tape));
3057+
getIndex(&call, CacheType::Tape),
3058+
/*ignoreType*/ true);
30453059
}
3060+
tape = lookup(tape, Builder2);
30463061
auto alloc = IRBuilder<>(gutils->inversionAllocs)
30473062
.CreateAlloca(tape->getType());
30483063
Builder2.CreateStore(tape, alloc);
@@ -3057,6 +3072,115 @@ class AdjointGenerator
30573072
nextTypeInfo, uncacheable_args, subdata, /*AtomicAdd*/ true,
30583073
/*postopt*/ false, /*omp*/ true);
30593074

3075+
Value *OutAlloc = nullptr;
3076+
if (OutTypes.size()) {
3077+
auto ST = StructType::get(newcalled->getContext(), OutFPTypes);
3078+
OutAlloc = IRBuilder<>(gutils->inversionAllocs).CreateAlloca(ST);
3079+
args.push_back(OutAlloc);
3080+
3081+
SmallVector<Type *, 3> MetaTypes;
3082+
for (auto P :
3083+
cast<Function>(newcalled)->getFunctionType()->params()) {
3084+
MetaTypes.push_back(P);
3085+
}
3086+
MetaTypes.push_back(PointerType::getUnqual(ST));
3087+
auto FT = FunctionType::get(Type::getVoidTy(newcalled->getContext()),
3088+
MetaTypes, false);
3089+
#if LLVM_VERSION_MAJOR >= 10
3090+
Function *F =
3091+
Function::Create(FT, GlobalVariable::InternalLinkage,
3092+
cast<Function>(newcalled)->getName() + "#out",
3093+
*task->getParent());
3094+
#else
3095+
Function *F = Function::Create(
3096+
FT, GlobalVariable::InternalLinkage,
3097+
cast<Function>(newcalled)->getName() + "#out", task->getParent());
3098+
#endif
3099+
BasicBlock *entry =
3100+
BasicBlock::Create(newcalled->getContext(), "entry", F);
3101+
IRBuilder<> B(entry);
3102+
SmallVector<Value *, 2> SubArgs;
3103+
for (auto &arg : F->args())
3104+
SubArgs.push_back(&arg);
3105+
Value *cacheArg = SubArgs.back();
3106+
SubArgs.pop_back();
3107+
Value *outdiff = B.CreateCall(newcalled, SubArgs);
3108+
for (size_t ee = 0; ee < OutTypes.size(); ee++) {
3109+
Value *dif = B.CreateExtractValue(outdiff, ee);
3110+
Value *Idxs[] = {
3111+
ConstantInt::get(Type::getInt64Ty(ST->getContext()), 0),
3112+
ConstantInt::get(Type::getInt32Ty(ST->getContext()), ee)};
3113+
Value *ptr = B.CreateInBoundsGEP(cacheArg, Idxs);
3114+
3115+
if (dif->getType()->isIntOrIntVectorTy()) {
3116+
3117+
ptr = B.CreateBitCast(
3118+
ptr,
3119+
PointerType::get(
3120+
IntToFloatTy(dif->getType()),
3121+
cast<PointerType>(ptr->getType())->getAddressSpace()));
3122+
dif = B.CreateBitCast(dif, IntToFloatTy(dif->getType()));
3123+
}
3124+
3125+
#if LLVM_VERSION_MAJOR >= 10
3126+
MaybeAlign align;
3127+
#elif LLVM_VERSION_MAJOR >= 9
3128+
unsigned align = 0;
3129+
#endif
3130+
3131+
#if LLVM_VERSION_MAJOR >= 9
3132+
AtomicRMWInst::BinOp op = AtomicRMWInst::FAdd;
3133+
if (auto vt = dyn_cast<VectorType>(dif->getType())) {
3134+
#if LLVM_VERSION_MAJOR >= 12
3135+
assert(!vt->getElementCount().isScalable());
3136+
size_t numElems = vt->getElementCount().getKnownMinValue();
3137+
#else
3138+
size_t numElems = vt->getNumElements();
3139+
#endif
3140+
for (size_t i = 0; i < numElems; ++i) {
3141+
auto vdif = B.CreateExtractElement(dif, i);
3142+
Value *Idxs[] = {
3143+
ConstantInt::get(Type::getInt64Ty(vt->getContext()), 0),
3144+
ConstantInt::get(Type::getInt32Ty(vt->getContext()), i)};
3145+
auto vptr = B.CreateGEP(ptr, Idxs);
3146+
#if LLVM_VERSION_MAJOR >= 13
3147+
B.CreateAtomicRMW(op, vptr, vdif, align,
3148+
AtomicOrdering::Monotonic, SyncScope::System);
3149+
#elif LLVM_VERSION_MAJOR >= 11
3150+
AtomicRMWInst *rmw =
3151+
B.CreateAtomicRMW(op, vptr, vdif, AtomicOrdering::Monotonic,
3152+
SyncScope::System);
3153+
if (align)
3154+
rmw->setAlignment(align.getValue());
3155+
#else
3156+
B.CreateAtomicRMW(op, vptr, vdif, AtomicOrdering::Monotonic,
3157+
SyncScope::System);
3158+
#endif
3159+
}
3160+
} else {
3161+
#if LLVM_VERSION_MAJOR >= 13
3162+
B.CreateAtomicRMW(op, ptr, dif, align, AtomicOrdering::Monotonic,
3163+
SyncScope::System);
3164+
#elif LLVM_VERSION_MAJOR >= 11
3165+
AtomicRMWInst *rmw = B.CreateAtomicRMW(
3166+
op, ptr, dif, AtomicOrdering::Monotonic, SyncScope::System);
3167+
if (align)
3168+
rmw->setAlignment(align.getValue());
3169+
#else
3170+
B.CreateAtomicRMW(op, ptr, dif, AtomicOrdering::Monotonic,
3171+
SyncScope::System);
3172+
#endif
3173+
}
3174+
#else
3175+
llvm::errs() << "unhandled atomic fadd on llvm version " << *ptr
3176+
<< " " << *dif << "\n";
3177+
llvm_unreachable("unhandled atomic fadd");
3178+
#endif
3179+
}
3180+
B.CreateRetVoid();
3181+
newcalled = F;
3182+
}
3183+
30603184
auto numargs = ConstantInt::get(Type::getInt32Ty(call.getContext()),
30613185
args.size() - 3);
30623186
args[0] =
@@ -3070,11 +3194,30 @@ class AdjointGenerator
30703194
diffes->setCallingConv(call.getCallingConv());
30713195
diffes->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc()));
30723196

3197+
for (size_t i = 0; i < OutTypes.size(); i++) {
3198+
3199+
size_t size = 1;
3200+
if (OutTypes[i]->getType()->isSized())
3201+
size = (gutils->newFunc->getParent()
3202+
->getDataLayout()
3203+
.getTypeSizeInBits(OutTypes[i]->getType()) +
3204+
7) /
3205+
8;
3206+
Value *Idxs[] = {
3207+
ConstantInt::get(Type::getInt64Ty(call.getContext()), 0),
3208+
ConstantInt::get(Type::getInt32Ty(call.getContext()), i)};
3209+
((DiffeGradientUtils *)gutils)
3210+
->addToDiffe(OutTypes[i],
3211+
Builder2.CreateLoad(
3212+
Builder2.CreateInBoundsGEP(OutAlloc, Idxs)),
3213+
Builder2, TR.addingType(size, OutTypes[i]));
3214+
}
3215+
30733216
if (tape) {
30743217
for (auto idx : subdata->tapeIndiciesToFree) {
30753218
auto ci = cast<CallInst>(CallInst::CreateFree(
30763219
Builder2.CreatePointerCast(
3077-
Builder2.CreateExtractValue(tape, idx),
3220+
idx == -1 ? tape : Builder2.CreateExtractValue(tape, idx),
30783221
Type::getInt8PtrTy(Builder2.getContext())),
30793222
Builder2.GetInsertBlock()));
30803223
ci->addAttribute(AttributeList::FirstArgIndex, Attribute::NonNull);
@@ -3759,7 +3902,7 @@ class AdjointGenerator
37593902
augmentedReturn->tapeIndices.find(std::make_pair(
37603903
orig, CacheType::Tape)) != augmentedReturn->tapeIndices.end()) {
37613904
tape = Builder2.CreatePHI(Type::getInt32Ty(orig->getContext()), 0);
3762-
tape = gutils->cacheForReverse(Builder2, (Value *)tape,
3905+
tape = gutils->cacheForReverse(Builder2, tape,
37633906
getIndex(orig, CacheType::Tape),
37643907
/*ignoreType*/ true);
37653908
}
@@ -3814,25 +3957,6 @@ class AdjointGenerator
38143957
return;
38153958
}
38163959

3817-
if (Mode != DerivativeMode::ReverseModePrimal && called) {
3818-
if (funcName == "__kmpc_for_static_init_4" ||
3819-
funcName == "__kmpc_for_static_init_4u" ||
3820-
funcName == "__kmpc_for_static_init_8" ||
3821-
funcName == "__kmpc_for_static_init_8u") {
3822-
IRBuilder<> Builder2(call.getParent());
3823-
getReverseBuilder(Builder2);
3824-
auto fini = called->getParent()->getFunction("__kmpc_for_static_fini");
3825-
assert(fini);
3826-
Value *args[] = {
3827-
lookup(gutils->getNewFromOriginal(call.getArgOperand(0)), Builder2),
3828-
lookup(gutils->getNewFromOriginal(call.getArgOperand(1)),
3829-
Builder2)};
3830-
auto fcall = Builder2.CreateCall(fini->getFunctionType(), fini, args);
3831-
fcall->setCallingConv(fini->getCallingConv());
3832-
return;
3833-
}
3834-
}
3835-
38363960
if (funcName.startswith("MPI_") && !gutils->isConstantInstruction(&call)) {
38373961
handleMPI(call, called, funcName);
38383962
return;
@@ -3860,6 +3984,61 @@ class AdjointGenerator
38603984
visitOMPCall(call);
38613985
return;
38623986
}
3987+
3988+
if (funcName == "__kmpc_for_static_init_4" ||
3989+
funcName == "__kmpc_for_static_init_4u" ||
3990+
funcName == "__kmpc_for_static_init_8" ||
3991+
funcName == "__kmpc_for_static_init_8u") {
3992+
if (Mode != DerivativeMode::ReverseModePrimal) {
3993+
IRBuilder<> Builder2(call.getParent());
3994+
getReverseBuilder(Builder2);
3995+
auto fini =
3996+
called->getParent()->getFunction("__kmpc_for_static_fini");
3997+
assert(fini);
3998+
Value *args[] = {
3999+
lookup(gutils->getNewFromOriginal(call.getArgOperand(0)),
4000+
Builder2),
4001+
lookup(gutils->getNewFromOriginal(call.getArgOperand(1)),
4002+
Builder2)};
4003+
auto fcall = Builder2.CreateCall(fini->getFunctionType(), fini, args);
4004+
fcall->setCallingConv(fini->getCallingConv());
4005+
}
4006+
return;
4007+
}
4008+
if (funcName == "__kmpc_for_static_fini") {
4009+
if (Mode != DerivativeMode::ReverseModePrimal) {
4010+
eraseIfUnused(call, /*erase*/ true, /*check*/ false);
4011+
}
4012+
return;
4013+
}
4014+
// TODO check
4015+
// Adjoint of barrier is to place a barrier at the corresponding
4016+
// location in the reverse.
4017+
if (funcName == "__kmpc_barrier") {
4018+
if (Mode == DerivativeMode::ReverseModeGradient ||
4019+
Mode == DerivativeMode::ReverseModeCombined) {
4020+
IRBuilder<> Builder2(call.getParent());
4021+
getReverseBuilder(Builder2);
4022+
#if LLVM_VERSION_MAJOR >= 11
4023+
auto callval = call.getCalledOperand();
4024+
#else
4025+
auto callval = call.getCalledValue();
4026+
#endif
4027+
Value *args[] = {
4028+
lookup(gutils->getNewFromOriginal(call.getOperand(0)), Builder2),
4029+
lookup(gutils->getNewFromOriginal(call.getOperand(1)), Builder2)};
4030+
Builder2.CreateCall(call.getFunctionType(), callval, args);
4031+
}
4032+
return;
4033+
}
4034+
4035+
if (funcName.startswith("__kmpc")) {
4036+
llvm::errs() << *gutils->oldFunc << "\n";
4037+
llvm::errs() << call << "\n";
4038+
assert(0 && "unhandled openmp function");
4039+
llvm_unreachable("unhandled openmp function");
4040+
}
4041+
38634042
if (funcName == "asin" || funcName == "asinf" || funcName == "asinl") {
38644043
if (gutils->knownRecomputeHeuristic.find(orig) !=
38654044
gutils->knownRecomputeHeuristic.end()) {

enzyme/Enzyme/DifferentialUseAnalysis.h

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,24 @@ static inline bool is_use_directly_needed_in_reverse(
5454

5555
// We don't need any of the input operands to compute the adjoint of a store
5656
// instance
57-
if (isa<StoreInst>(user)) {
57+
if (auto SI = dyn_cast<StoreInst>(user)) {
58+
// The one exception to this is stores to the loop bounds.
59+
if (SI->getValueOperand() == val) {
60+
for (auto U : SI->getPointerOperand()->users()) {
61+
if (auto CI = dyn_cast<CallInst>(U)) {
62+
if (auto F = CI->getCalledFunction()) {
63+
if (F->getName() == "__kmpc_for_static_init_4" ||
64+
F->getName() == "__kmpc_for_static_init_4u" ||
65+
F->getName() == "__kmpc_for_static_init_8" ||
66+
F->getName() == "__kmpc_for_static_init_8u") {
67+
if (CI->getArgOperand(4) == val || CI->getArgOperand(5) == val ||
68+
CI->getArgOperand(6))
69+
return true;
70+
}
71+
}
72+
}
73+
}
74+
}
5875
return false;
5976
}
6077

@@ -137,6 +154,28 @@ static inline bool is_use_directly_needed_in_reverse(
137154
return !gutils->isConstantValue(const_cast<SelectInst *>(si));
138155
}
139156

157+
if (auto CI = dyn_cast<CallInst>(user)) {
158+
if (auto F = CI->getCalledFunction()) {
159+
// Only need primal length and datatype for reverse
160+
if (F->getName() == "MPI_Isend" || F->getName() == "MPI_Irecv") {
161+
if (val != CI->getArgOperand(1) && val != CI->getArgOperand(2)) {
162+
return false;
163+
}
164+
}
165+
// Don't need any primal arguments for mpi_wait
166+
if (F->getName() == "MPI_Wait")
167+
return false;
168+
// Only need element count for reverse of waitall
169+
if (F->getName() == "MPI_Waitall")
170+
if (val != CI->getArgOperand(0))
171+
return false;
172+
// Since adjoint of barrier is another barrier in reverse
173+
// we still need even if instruction is inactive
174+
if (F->getName() == "__kmpc_barrier" || F->getName() == "MPI_Barrier")
175+
return true;
176+
}
177+
}
178+
140179
return !gutils->isConstantInstruction(user) ||
141180
!gutils->isConstantValue(const_cast<Instruction *>(user));
142181
}

0 commit comments

Comments
 (0)