Skip to content

Commit 517bcc4

Browse files
authored
[AutoDiff] Fix differentiation transform crashers in library evolution mode. (#34704)
AD-generated data structures (linear map structs and branching trace enums) do not need to be resilient data structures. These decls ade missing a `@frozen` attribute. Resolves rdar://71319547.
1 parent 3343a6a commit 517bcc4

File tree

3 files changed

+58
-15
lines changed

3 files changed

+58
-15
lines changed

lib/SILOptimizer/Differentiation/LinearMapInfo.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ LinearMapInfo::createBranchingTraceDecl(SILBasicBlock *originalBB,
138138
// Note: must mark enum as implicit to satisfy assertion in
139139
// `Parser::parseDeclListDelayed`.
140140
branchingTraceDecl->setImplicit();
141+
// Branching trace enums shall not be resilient.
142+
branchingTraceDecl->getAttrs().add(new (astCtx) FrozenAttr(/*implicit*/ true));
141143
if (genericSig)
142144
branchingTraceDecl->setGenericSignature(genericSig);
143145
computeAccessLevel(branchingTraceDecl, original->getEffectiveSymbolLinkage());
@@ -201,6 +203,8 @@ LinearMapInfo::createLinearMapStruct(SILBasicBlock *originalBB,
201203
// Note: must mark struct as implicit to satisfy assertion in
202204
// `Parser::parseDeclListDelayed`.
203205
linearMapStruct->setImplicit();
206+
// Linear map structs shall not be resilient.
207+
linearMapStruct->getAttrs().add(new (astCtx) FrozenAttr(/*implicit*/ true));
204208
if (genericSig)
205209
linearMapStruct->setGenericSignature(genericSig);
206210
computeAccessLevel(linearMapStruct, original->getEffectiveSymbolLinkage());

test/AutoDiff/compiler_crashers_fixed/rdar71191415-nested-differentiation-of-extension-method-optimized.swift

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,29 @@
55
import _Differentiation
66

77
protocol P {
8-
@differentiable
9-
func req(_ input: Float) -> Float
8+
@differentiable
9+
func req(_ input: Float) -> Float
1010
}
1111

1212
extension P {
13-
@differentiable
14-
func foo(_ input: Float) -> Float {
15-
return req(input)
16-
}
13+
@differentiable
14+
func foo(_ input: Float) -> Float {
15+
return req(input)
16+
}
1717
}
1818

1919
struct Dummy: P {
20-
@differentiable
21-
func req(_ input: Float) -> Float {
22-
input
23-
}
20+
@differentiable
21+
func req(_ input: Float) -> Float {
22+
input
23+
}
2424
}
2525

2626
struct DummyComposition: P {
27-
var layer = Dummy()
27+
var layer = Dummy()
2828

29-
@differentiable
30-
func req(_ input: Float) -> Float {
31-
layer.foo(input)
32-
}
29+
@differentiable
30+
func req(_ input: Float) -> Float {
31+
layer.foo(input)
32+
}
3333
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// RUN: %target-build-swift -enable-library-evolution %s
2+
// RUN: %target-build-swift -O -enable-library-evolution %s
3+
// RUN: %target-build-swift -O -g -enable-library-evolution %s
4+
5+
// rdar://71319547
6+
7+
import _Differentiation
8+
9+
10+
// Assertion failed: (mainPullbackStruct->getType() == pbStructLoweredType), function run, file swift/lib/SILOptimizer/Differentiation/PullbackCloner.cpp, line 1899.
11+
// Stack dump:
12+
// 1. Swift version 5.3-dev (LLVM 618cb952e0f199a, Swift d74c261f098665c)
13+
// 2. While evaluating request ExecuteSILPipelineRequest(Run pipelines { Mandatory Diagnostic Passes + Enabling Optimization Passes } on SIL for main.main)
14+
// 3. While running pass #17 SILModuleTransform "Differentiation".
15+
// 4. While processing // differentiability witness for foo(_:)
16+
// sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s4main3fooyS2fF : $@convention(thin) (Float) -> Float {
17+
// }
18+
@differentiable(wrt: x)
19+
public func i_have_a_pullback_struct(_ x: Float) -> Float {
20+
return x
21+
}
22+
23+
24+
// Assertion failed: (v->getType().isObject()), function operator(), file swift/lib/SIL/Utils/ValueUtils.cpp, line 22.
25+
// Stack dump:
26+
// 1. Swift version 5.3-dev (LLVM 618cb952e0f199a, Swift d74c261f098665c)
27+
// 2. While evaluating request ExecuteSILPipelineRequest(Run pipelines { Mandatory Diagnostic Passes + Enabling Optimization Passes } on SIL for main.main)
28+
// 3. While running pass #24 SILModuleTransform "Differentiation".
29+
// 4. While processing // differentiability witness for i_have_a_branching_trace_enum(_:)
30+
// sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s4main29i_have_a_branching_trace_enumyS2fF : $@convention(thin) (Float) -> Float {
31+
// }
32+
@differentiable(wrt: x)
33+
public func i_have_a_branching_trace_enum(_ x: Float) -> Float {
34+
if true {
35+
return x
36+
} else {
37+
return x.squareRoot()
38+
}
39+
}

0 commit comments

Comments
 (0)