From c8fb9cfafe07c54ba4dd5aa61b994469bd8ac1b9 Mon Sep 17 00:00:00 2001 From: Andrew Savonichev Date: Fri, 15 Sep 2023 15:27:07 +0900 Subject: [PATCH 1/2] [AutoDiff] Fix custom derivative thunk for Optional The patch fixes the issue #55882 and enables the nil coalescing operator (aka `??`) for Optional type. --- lib/SILGen/SILGenPoly.cpp | 9 +++++--- test/AutoDiff/SILGen/nil_coalescing.swift | 25 +++++++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) create mode 100644 test/AutoDiff/SILGen/nil_coalescing.swift diff --git a/lib/SILGen/SILGenPoly.cpp b/lib/SILGen/SILGenPoly.cpp index 36790109e2dfa..56f03c956859e 100644 --- a/lib/SILGen/SILGenPoly.cpp +++ b/lib/SILGen/SILGenPoly.cpp @@ -6305,10 +6305,13 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk( arguments.push_back(indErrorRes.getLValueAddress()); forwardFunctionArguments(thunkSGF, loc, fnRefType, params, arguments); + SubstitutionMap subs = thunk->getForwardingSubstitutionMap(); + SILType substFnType = fnRef->getType().substGenericArgs( + M, subs, thunk->getTypeExpansionContext()); + // Apply function argument. - auto apply = thunkSGF.emitApplyWithRethrow( - loc, fnRef, /*substFnType*/ fnRef->getType(), - thunk->getForwardingSubstitutionMap(), arguments); + auto apply = + thunkSGF.emitApplyWithRethrow(loc, fnRef, substFnType, subs, arguments); // Self reordering thunk is necessary if wrt at least two parameters, // including self. diff --git a/test/AutoDiff/SILGen/nil_coalescing.swift b/test/AutoDiff/SILGen/nil_coalescing.swift new file mode 100644 index 0000000000000..994fbeed691fa --- /dev/null +++ b/test/AutoDiff/SILGen/nil_coalescing.swift @@ -0,0 +1,25 @@ +// RUN: %target-swift-frontend -emit-sil -verify %s | %FileCheck %s + +import _Differentiation + +// CHECK: sil @test_nil_coalescing +// CHECK: bb0(%{{.*}} : $*T, %[[ARG_OPT:.*]] : $*Optional, %[[ARG_PB:.*]] : +// CHECK: $@noescape @callee_guaranteed @substituted <τ_0_0> () -> (@out τ_0_0, @error any Error) for ): +// CHECK: %[[ALLOC_OPT:.*]] = alloc_stack [lexical] $Optional +// CHECK: copy_addr %[[ARG_OPT]] to [init] %[[ALLOC_OPT]] : $*Optional +// CHECK: switch_enum_addr %[[ALLOC_OPT]] : $*Optional, case #Optional.some!enumelt: {{.*}}, case #Optional.none!enumelt: {{.*}} +// CHECK: try_apply %[[ARG_PB]](%{{.*}}) : $@noescape @callee_guaranteed @substituted <τ_0_0> () -> (@out τ_0_0, @error any Error) for , normal {{.*}}, error {{.*}} +// +@_silgen_name("test_nil_coalescing") +@derivative(of: ??) +@usableFromInline +func nilCoalescing(optional: T?, defaultValue: @autoclosure () throws -> T) + rethrows -> (value: T, pullback: (T.TangentVector) -> Optional.TangentVector) +{ + let hasValue = optional != nil + let value = try optional ?? defaultValue() + func pullback(_ v: T.TangentVector) -> Optional.TangentVector { + return hasValue ? .init(v) : .zero + } + return (value, pullback) +} From c653a7d56761f38e5976099993f9e81050a790d3 Mon Sep 17 00:00:00 2001 From: Anton Korobeynikov Date: Wed, 5 Jun 2024 16:50:48 -0700 Subject: [PATCH 2/2] Operator ?? is hidden_external and therefore might be serialized. empty JVPs / VJPs has the same linkage as the original function sans external flag and also inherit serialization flag. This works for all cases but not for hidden_external ones. Drop serializable flag for such case --- .../Mandatory/Differentiation.cpp | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index b8251eccfdd09..3584fba568a3c 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -906,6 +906,15 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness( traceMessage.c_str(), witness->getOriginalFunction()); assert(witness->isDefinition()); + SILFunction *orig = witness->getOriginalFunction(); + + // We can generate empty JVP / VJP for functions available externally. These + // functions have the same linkage as the original ones sans `external` + // flag. Important exception here hidden_external functions as they are + // serializable but corresponding hidden ones would be not and the SIL + // verifier will fail. Patch `serializeFunctions` for this case. + if (orig->getLinkage() == SILLinkage::HiddenExternal) + serializeFunctions = IsNotSerialized; // If the JVP doesn't exist, need to synthesize it. if (!witness->getJVP()) { @@ -914,9 +923,8 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness( // - Functions with unsupported control flow. if (context.getASTContext() .LangOpts.hasFeature(Feature::ForwardModeDifferentiation) && - (diagnoseNoReturn(context, witness->getOriginalFunction(), invoker) || - diagnoseUnsupportedControlFlow( - context, witness->getOriginalFunction(), invoker))) + (diagnoseNoReturn(context, orig, invoker) || + diagnoseUnsupportedControlFlow(context, orig, invoker))) return true; // Create empty JVP. @@ -933,10 +941,10 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness( !witness->getVJP()) { // JVP and differential generation do not currently support functions with // multiple basic blocks. - if (witness->getOriginalFunction()->size() > 1) { - context.emitNondifferentiabilityError( - witness->getOriginalFunction()->getLocation().getSourceLoc(), - invoker, diag::autodiff_jvp_control_flow_not_supported); + if (orig->size() > 1) { + context.emitNondifferentiabilityError(orig->getLocation().getSourceLoc(), + invoker, + diag::autodiff_jvp_control_flow_not_supported); return true; } // Emit JVP function. @@ -950,7 +958,7 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness( "_fatalErrorForwardModeDifferentiationDisabled"); LLVM_DEBUG(getADDebugStream() << "Generated empty JVP for " - << witness->getOriginalFunction()->getName() << ":\n" + << orig->getName() << ":\n" << *jvp); } } @@ -960,9 +968,8 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness( // Diagnose: // - Functions with no return. // - Functions with unsupported control flow. - if (diagnoseNoReturn(context, witness->getOriginalFunction(), invoker) || - diagnoseUnsupportedControlFlow( - context, witness->getOriginalFunction(), invoker)) + if (diagnoseNoReturn(context, orig, invoker) || + diagnoseUnsupportedControlFlow(context, orig, invoker)) return true; // Create empty VJP.