Skip to content

[AutoDiff] Fix differentiation transform crashers in library evolution mode. #34704

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lib/SILOptimizer/Differentiation/LinearMapInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ LinearMapInfo::createBranchingTraceDecl(SILBasicBlock *originalBB,
// Note: must mark enum as implicit to satisfy assertion in
// `Parser::parseDeclListDelayed`.
branchingTraceDecl->setImplicit();
// Branching trace enums shall not be resilient.
branchingTraceDecl->getAttrs().add(new (astCtx) FrozenAttr(/*implicit*/ true));
if (genericSig)
branchingTraceDecl->setGenericSignature(genericSig);
computeAccessLevel(branchingTraceDecl, original->getEffectiveSymbolLinkage());
Expand Down Expand Up @@ -201,6 +203,8 @@ LinearMapInfo::createLinearMapStruct(SILBasicBlock *originalBB,
// Note: must mark struct as implicit to satisfy assertion in
// `Parser::parseDeclListDelayed`.
linearMapStruct->setImplicit();
// Linear map structs shall not be resilient.
linearMapStruct->getAttrs().add(new (astCtx) FrozenAttr(/*implicit*/ true));
if (genericSig)
linearMapStruct->setGenericSignature(genericSig);
computeAccessLevel(linearMapStruct, original->getEffectiveSymbolLinkage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,29 @@
import _Differentiation

protocol P {
@differentiable
func req(_ input: Float) -> Float
@differentiable
func req(_ input: Float) -> Float
}

extension P {
@differentiable
func foo(_ input: Float) -> Float {
return req(input)
}
@differentiable
func foo(_ input: Float) -> Float {
return req(input)
}
}

struct Dummy: P {
@differentiable
func req(_ input: Float) -> Float {
input
}
@differentiable
func req(_ input: Float) -> Float {
input
}
}

struct DummyComposition: P {
var layer = Dummy()
var layer = Dummy()

@differentiable
func req(_ input: Float) -> Float {
layer.foo(input)
}
@differentiable
func req(_ input: Float) -> Float {
layer.foo(input)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: %target-build-swift -enable-library-evolution %s
// RUN: %target-build-swift -O -enable-library-evolution %s
// RUN: %target-build-swift -O -g -enable-library-evolution %s

// rdar://71319547

import _Differentiation


// Assertion failed: (mainPullbackStruct->getType() == pbStructLoweredType), function run, file swift/lib/SILOptimizer/Differentiation/PullbackCloner.cpp, line 1899.
// Stack dump:
// 1. Swift version 5.3-dev (LLVM 618cb952e0f199a, Swift d74c261f098665c)
// 2. While evaluating request ExecuteSILPipelineRequest(Run pipelines { Mandatory Diagnostic Passes + Enabling Optimization Passes } on SIL for main.main)
// 3. While running pass #17 SILModuleTransform "Differentiation".
// 4. While processing // differentiability witness for foo(_:)
// sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s4main3fooyS2fF : $@convention(thin) (Float) -> Float {
// }
@differentiable(wrt: x)
public func i_have_a_pullback_struct(_ x: Float) -> Float {
return x
}


// Assertion failed: (v->getType().isObject()), function operator(), file swift/lib/SIL/Utils/ValueUtils.cpp, line 22.
// Stack dump:
// 1. Swift version 5.3-dev (LLVM 618cb952e0f199a, Swift d74c261f098665c)
// 2. While evaluating request ExecuteSILPipelineRequest(Run pipelines { Mandatory Diagnostic Passes + Enabling Optimization Passes } on SIL for main.main)
// 3. While running pass #24 SILModuleTransform "Differentiation".
// 4. While processing // differentiability witness for i_have_a_branching_trace_enum(_:)
// sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s4main29i_have_a_branching_trace_enumyS2fF : $@convention(thin) (Float) -> Float {
// }
@differentiable(wrt: x)
public func i_have_a_branching_trace_enum(_ x: Float) -> Float {
if true {
return x
} else {
return x.squareRoot()
}
}