@@ -2768,6 +2768,9 @@ class AdjointGenerator
     std::vector<Instruction *> postCreate;
     std::vector<Instruction *> userReplace;
 
+    SmallVector<Value *, 0> OutTypes;
+    SmallVector<Type *, 0> OutFPTypes;
+
     for (unsigned i = 3; i < call.getNumArgOperands(); ++i) {
 
       auto argi = gutils->getNewFromOriginal(call.getArgOperand(i));
@@ -2820,7 +2823,9 @@ class AdjointGenerator
         assert(whatType(argType, Mode) == DIFFE_TYPE::DUP_ARG ||
                whatType(argType, Mode) == DIFFE_TYPE::CONSTANT);
       } else {
-        assert(0 && "out for omp not handled");
+        assert(TR.query(call.getArgOperand(i)).Inner0().isFloat());
+        OutTypes.push_back(call.getArgOperand(i));
+        OutFPTypes.push_back(argType);
         argsInverted.push_back(DIFFE_TYPE::OUT_DIFF);
         assert(whatType(argType, Mode) == DIFFE_TYPE::OUT_DIFF ||
                whatType(argType, Mode) == DIFFE_TYPE::CONSTANT);
@@ -2905,7 +2910,6 @@ class AdjointGenerator
             subdata->returns.end()) {
           ValueToValueMapTy VMap;
           newcalled = CloneFunction(newcalled, VMap);
-          llvm::errs() << *newcalled << "\n";
           auto tapeArg = newcalled->arg_end();
           tapeArg--;
           std::vector<std::pair<ssize_t, Value *>> geps;
@@ -2926,11 +2930,14 @@ class AdjointGenerator
                 for (auto SI : storesToErase)
                   SI->eraseFromParent();
                 gepsToErase.insert(gep);
-              }
-              if (auto SI = dyn_cast<StoreInst>(a)) {
+              } else if (auto SI = dyn_cast<StoreInst>(a)) {
                 Value *op = SI->getValueOperand();
                 gepsToErase.insert(SI);
                 geps.emplace_back(-1, op);
+              } else {
+                llvm::errs() << "unknown tape user: " << a << "\n";
+                assert(0 && "unknown tape user");
+                llvm_unreachable("unknown tape user");
               }
             }
             for (auto gep : gepsToErase)
@@ -2988,6 +2995,7 @@ class AdjointGenerator
                          : ph.CreateInBoundsGEP(tapeArg, Idxs)));
             cast<Instruction>(op)->eraseFromParent();
           }
+          assert(tape);
           auto alloc =
               IRBuilder<>(gutils->inversionAllocs)
                   .CreateAlloca(
@@ -3025,7 +3033,6 @@ class AdjointGenerator
 
       auto found = subdata->returns.find(AugmentedStruct::DifferentialReturn);
       assert(found == subdata->returns.end());
-      ;
 
       found = subdata->returns.find(AugmentedStruct::Return);
       assert(found == subdata->returns.end());
@@ -3035,14 +3042,22 @@ class AdjointGenerator
       IRBuilder<> Builder2(call.getParent());
       getReverseBuilder(Builder2);
 
-      Value *newcalled = nullptr;
+      if (Mode == DerivativeMode::ReverseModeGradient)
+        eraseIfUnused(call, /*erase*/ true, /*check*/ false);
+
+      Function *newcalled = nullptr;
       if (called) {
         if (subdata->returns.find(AugmentedStruct::Tape) !=
             subdata->returns.end()) {
           if (Mode == DerivativeMode::ReverseModeGradient) {
+            if (tape == nullptr)
+              tape = Builder2.CreatePHI(Type::getInt8Ty(call.getContext()), 0,
+                                        "tapeArg");
             tape = gutils->cacheForReverse(Builder2, tape,
-                                           getIndex(&call, CacheType::Tape));
+                                           getIndex(&call, CacheType::Tape),
+                                           /*ignoreType*/ true);
           }
+          tape = lookup(tape, Builder2);
           auto alloc = IRBuilder<>(gutils->inversionAllocs)
                            .CreateAlloca(tape->getType());
           Builder2.CreateStore(tape, alloc);
@@ -3057,6 +3072,115 @@ class AdjointGenerator
           nextTypeInfo, uncacheable_args, subdata, /*AtomicAdd*/ true,
           /*postopt*/ false, /*omp*/ true);
 
+      Value *OutAlloc = nullptr;
+      if (OutTypes.size()) {
+        auto ST = StructType::get(newcalled->getContext(), OutFPTypes);
+        OutAlloc = IRBuilder<>(gutils->inversionAllocs).CreateAlloca(ST);
+        args.push_back(OutAlloc);
+
+        SmallVector<Type *, 3> MetaTypes;
+        for (auto P :
+             cast<Function>(newcalled)->getFunctionType()->params()) {
+          MetaTypes.push_back(P);
+        }
+        MetaTypes.push_back(PointerType::getUnqual(ST));
+        auto FT = FunctionType::get(Type::getVoidTy(newcalled->getContext()),
+                                    MetaTypes, false);
+#if LLVM_VERSION_MAJOR >= 10
+        Function *F =
+            Function::Create(FT, GlobalVariable::InternalLinkage,
+                             cast<Function>(newcalled)->getName() + "#out",
+                             *task->getParent());
+#else
+        Function *F = Function::Create(
+            FT, GlobalVariable::InternalLinkage,
+            cast<Function>(newcalled)->getName() + "#out", task->getParent());
+#endif
+        BasicBlock *entry =
+            BasicBlock::Create(newcalled->getContext(), "entry", F);
+        IRBuilder<> B(entry);
+        SmallVector<Value *, 2> SubArgs;
+        for (auto &arg : F->args())
+          SubArgs.push_back(&arg);
+        Value *cacheArg = SubArgs.back();
+        SubArgs.pop_back();
+        Value *outdiff = B.CreateCall(newcalled, SubArgs);
+        for (size_t ee = 0; ee < OutTypes.size(); ee++) {
+          Value *dif = B.CreateExtractValue(outdiff, ee);
+          Value *Idxs[] = {
+              ConstantInt::get(Type::getInt64Ty(ST->getContext()), 0),
+              ConstantInt::get(Type::getInt32Ty(ST->getContext()), ee)};
+          Value *ptr = B.CreateInBoundsGEP(cacheArg, Idxs);
+
+          if (dif->getType()->isIntOrIntVectorTy()) {
+
+            ptr = B.CreateBitCast(
+                ptr,
+                PointerType::get(
+                    IntToFloatTy(dif->getType()),
+                    cast<PointerType>(ptr->getType())->getAddressSpace()));
+            dif = B.CreateBitCast(dif, IntToFloatTy(dif->getType()));
+          }
+
+#if LLVM_VERSION_MAJOR >= 10
+          MaybeAlign align;
+#elif LLVM_VERSION_MAJOR >= 9
+          unsigned align = 0;
+#endif
+
+#if LLVM_VERSION_MAJOR >= 9
+          AtomicRMWInst::BinOp op = AtomicRMWInst::FAdd;
+          if (auto vt = dyn_cast<VectorType>(dif->getType())) {
+#if LLVM_VERSION_MAJOR >= 12
+            assert(!vt->getElementCount().isScalable());
+            size_t numElems = vt->getElementCount().getKnownMinValue();
+#else
+            size_t numElems = vt->getNumElements();
+#endif
+            for (size_t i = 0; i < numElems; ++i) {
+              auto vdif = B.CreateExtractElement(dif, i);
+              Value *Idxs[] = {
+                  ConstantInt::get(Type::getInt64Ty(vt->getContext()), 0),
+                  ConstantInt::get(Type::getInt32Ty(vt->getContext()), i)};
+              auto vptr = B.CreateGEP(ptr, Idxs);
+#if LLVM_VERSION_MAJOR >= 13
+              B.CreateAtomicRMW(op, vptr, vdif, align,
+                                AtomicOrdering::Monotonic, SyncScope::System);
+#elif LLVM_VERSION_MAJOR >= 11
+              AtomicRMWInst *rmw =
+                  B.CreateAtomicRMW(op, vptr, vdif, AtomicOrdering::Monotonic,
+                                    SyncScope::System);
+              if (align)
+                rmw->setAlignment(align.getValue());
+#else
+              B.CreateAtomicRMW(op, vptr, vdif, AtomicOrdering::Monotonic,
+                                SyncScope::System);
+#endif
+            }
+          } else {
+#if LLVM_VERSION_MAJOR >= 13
+            B.CreateAtomicRMW(op, ptr, dif, align, AtomicOrdering::Monotonic,
+                              SyncScope::System);
+#elif LLVM_VERSION_MAJOR >= 11
+            AtomicRMWInst *rmw = B.CreateAtomicRMW(
+                op, ptr, dif, AtomicOrdering::Monotonic, SyncScope::System);
+            if (align)
+              rmw->setAlignment(align.getValue());
+#else
+            B.CreateAtomicRMW(op, ptr, dif, AtomicOrdering::Monotonic,
+                              SyncScope::System);
+#endif
+          }
+#else
+          llvm::errs() << "unhandled atomic fadd on llvm version " << *ptr
+                       << " " << *dif << "\n";
+          llvm_unreachable("unhandled atomic fadd");
+#endif
+        }
+        B.CreateRetVoid();
+        newcalled = F;
+      }
+
       auto numargs = ConstantInt::get(Type::getInt32Ty(call.getContext()),
                                       args.size() - 3);
       args[0] =
@@ -3070,11 +3194,30 @@ class AdjointGenerator
       diffes->setCallingConv(call.getCallingConv());
       diffes->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc()));
 
+      for (size_t i = 0; i < OutTypes.size(); i++) {
+
+        size_t size = 1;
+        if (OutTypes[i]->getType()->isSized())
+          size = (gutils->newFunc->getParent()
+                      ->getDataLayout()
+                      .getTypeSizeInBits(OutTypes[i]->getType()) +
+                  7) /
+                 8;
+        Value *Idxs[] = {
+            ConstantInt::get(Type::getInt64Ty(call.getContext()), 0),
+            ConstantInt::get(Type::getInt32Ty(call.getContext()), i)};
+        ((DiffeGradientUtils *)gutils)
+            ->addToDiffe(OutTypes[i],
+                         Builder2.CreateLoad(
+                             Builder2.CreateInBoundsGEP(OutAlloc, Idxs)),
+                         Builder2, TR.addingType(size, OutTypes[i]));
+      }
+
       if (tape) {
         for (auto idx : subdata->tapeIndiciesToFree) {
           auto ci = cast<CallInst>(CallInst::CreateFree(
               Builder2.CreatePointerCast(
-                  Builder2.CreateExtractValue(tape, idx),
+                  idx == -1 ? tape : Builder2.CreateExtractValue(tape, idx),
                   Type::getInt8PtrTy(Builder2.getContext())),
               Builder2.GetInsertBlock()));
           ci->addAttribute(AttributeList::FirstArgIndex, Attribute::NonNull);
@@ -3759,7 +3902,7 @@ class AdjointGenerator
             augmentedReturn->tapeIndices.find(std::make_pair(
                 orig, CacheType::Tape)) != augmentedReturn->tapeIndices.end()) {
       tape = Builder2.CreatePHI(Type::getInt32Ty(orig->getContext()), 0);
-      tape = gutils->cacheForReverse(Builder2, (Value *)tape,
+      tape = gutils->cacheForReverse(Builder2, tape,
                                      getIndex(orig, CacheType::Tape),
                                      /*ignoreType*/ true);
     }
@@ -3814,25 +3957,6 @@ class AdjointGenerator
       return;
     }
 
-    if (Mode != DerivativeMode::ReverseModePrimal && called) {
-      if (funcName == "__kmpc_for_static_init_4" ||
-          funcName == "__kmpc_for_static_init_4u" ||
-          funcName == "__kmpc_for_static_init_8" ||
-          funcName == "__kmpc_for_static_init_8u") {
-        IRBuilder<> Builder2(call.getParent());
-        getReverseBuilder(Builder2);
-        auto fini = called->getParent()->getFunction("__kmpc_for_static_fini");
-        assert(fini);
-        Value *args[] = {
-            lookup(gutils->getNewFromOriginal(call.getArgOperand(0)), Builder2),
-            lookup(gutils->getNewFromOriginal(call.getArgOperand(1)),
-                   Builder2)};
-        auto fcall = Builder2.CreateCall(fini->getFunctionType(), fini, args);
-        fcall->setCallingConv(fini->getCallingConv());
-        return;
-      }
-    }
-
     if (funcName.startswith("MPI_") && !gutils->isConstantInstruction(&call)) {
       handleMPI(call, called, funcName);
       return;
@@ -3860,6 +3984,61 @@ class AdjointGenerator
       visitOMPCall(call);
       return;
     }
+
+    if (funcName == "__kmpc_for_static_init_4" ||
+        funcName == "__kmpc_for_static_init_4u" ||
+        funcName == "__kmpc_for_static_init_8" ||
+        funcName == "__kmpc_for_static_init_8u") {
+      if (Mode != DerivativeMode::ReverseModePrimal) {
+        IRBuilder<> Builder2(call.getParent());
+        getReverseBuilder(Builder2);
+        auto fini =
+            called->getParent()->getFunction("__kmpc_for_static_fini");
+        assert(fini);
+        Value *args[] = {
+            lookup(gutils->getNewFromOriginal(call.getArgOperand(0)),
+                   Builder2),
+            lookup(gutils->getNewFromOriginal(call.getArgOperand(1)),
+                   Builder2)};
+        auto fcall = Builder2.CreateCall(fini->getFunctionType(), fini, args);
+        fcall->setCallingConv(fini->getCallingConv());
+      }
+      return;
+    }
+    if (funcName == "__kmpc_for_static_fini") {
+      if (Mode != DerivativeMode::ReverseModePrimal) {
+        eraseIfUnused(call, /*erase*/ true, /*check*/ false);
+      }
+      return;
+    }
+    // TODO check
+    // Adjoint of barrier is to place a barrier at the corresponding
+    // location in the reverse.
+    if (funcName == "__kmpc_barrier") {
+      if (Mode == DerivativeMode::ReverseModeGradient ||
+          Mode == DerivativeMode::ReverseModeCombined) {
+        IRBuilder<> Builder2(call.getParent());
+        getReverseBuilder(Builder2);
+#if LLVM_VERSION_MAJOR >= 11
+        auto callval = call.getCalledOperand();
+#else
+        auto callval = call.getCalledValue();
+#endif
+        Value *args[] = {
+            lookup(gutils->getNewFromOriginal(call.getOperand(0)), Builder2),
+            lookup(gutils->getNewFromOriginal(call.getOperand(1)), Builder2)};
+        Builder2.CreateCall(call.getFunctionType(), callval, args);
+      }
+      return;
+    }
+
+    if (funcName.startswith("__kmpc")) {
+      llvm::errs() << *gutils->oldFunc << "\n";
+      llvm::errs() << call << "\n";
+      assert(0 && "unhandled openmp function");
+      llvm_unreachable("unhandled openmp function");
+    }
+
     if (funcName == "asin" || funcName == "asinf" || funcName == "asinl") {
       if (gutils->knownRecomputeHeuristic.find(orig) !=
           gutils->knownRecomputeHeuristic.end()) {
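
Note on the "#out" wrapper added above: each OpenMP worker's OUT_DIFF contributions are reduced into a caller-allocated struct with an atomic floating-point add rather than a plain read-modify-write, so concurrent threads cannot race on the same gradient slot. A minimal standalone sketch of that one primitive, written against the LLVM 13+ CreateAtomicRMW signature used in the "#if LLVM_VERSION_MAJOR >= 13" branch of the patch (helper name atomicAddFloat is hypothetical and illustrative, not part of the patch):

#include "llvm/IR/IRBuilder.h"
#include <cassert>

using namespace llvm;

// Sketch: atomically accumulate a scalar floating-point `dif` into `*ptr`,
// mirroring the per-element reduction the generated "#out" wrapper performs.
static void atomicAddFloat(IRBuilder<> &B, Value *ptr, Value *dif) {
  assert(dif->getType()->isFloatingPointTy());
  B.CreateAtomicRMW(AtomicRMWInst::FAdd, ptr, dif, MaybeAlign(),
                    AtomicOrdering::Monotonic, SyncScope::System);
}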