File tree 3 files changed +49
-0
lines changed
test/AutoDiff/SILOptimizer
3 files changed +49
-0
lines changed Original file line number Diff line number Diff line change @@ -7106,6 +7106,11 @@ void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() {
7106
7106
ArrayRef<AutoDiffConfig>
7107
7107
AbstractFunctionDecl::getDerivativeFunctionConfigurations () {
7108
7108
prepareDerivativeFunctionConfigurations ();
7109
+ // Resolve derivative function configurations from `@differentiable`
7110
+ // attributes by type-checking them.
7111
+ for (auto *diffAttr : getAttrs ().getAttributes <DifferentiableAttr>())
7112
+ (void )diffAttr->getParameterIndices ();
7113
+ // Load derivative configurations from imported modules.
7109
7114
auto &ctx = getASTContext ();
7110
7115
if (ctx.getCurrentGeneration () > DerivativeFunctionConfigGeneration) {
7111
7116
unsigned previousGeneration = DerivativeFunctionConfigGeneration;
Original file line number Diff line number Diff line change
1
+ import _Differentiation
2
+
3
+ protocol Protocol : Differentiable {
4
+ // Test cross-file `@differentiable` attribute.
5
+ @differentiable ( wrt: self )
6
+ func identityDifferentiableAttr( ) -> Self
7
+ }
8
+
9
+ extension Protocol {
10
+ func identityDerivativeAttr( ) -> Self { self }
11
+
12
+ // Test cross-file `@derivative` attribute.
13
+ @derivative ( of: identityDerivativeAttr)
14
+ func vjpIdentityDerivativeAttr( ) -> (
15
+ value: Self , pullback: ( TangentVector ) -> TangentVector
16
+ ) {
17
+ fatalError ( )
18
+ }
19
+ }
Original file line number Diff line number Diff line change
1
+ // RUN: %target-swift-frontend -emit-sil -verify -primary-file %s %S/Inputs/differentiation_diagnostics_other_file.swift -module-name main -o /dev/null
2
+
3
+ // Test differentiation transform cross-file diagnostics.
4
+
5
+ import _Differentiation
6
+
7
+ // TF-1271: Test `@differentiable` original function in other file.
8
+ @differentiable
9
+ func crossFileDifferentiableAttr< T: Protocol > (
10
+ _ input: T
11
+ ) -> T {
12
+ return input. identityDifferentiableAttr ( )
13
+ }
14
+
15
+ // TF-1272: Test original function with registered derivatives in other files.
16
+ // FIXME(TF-1272): Find a way to type-check `@derivative` attributes in other
17
+ // files.
18
+ @differentiable
19
+ func crossFileDerivativeAttr< T: Protocol > (
20
+ _ input: T
21
+ ) -> T {
22
+ // expected-error @+2 {{expression is not differentiable}}
23
+ // expected-note @+1 {{cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files}}
24
+ return input. identityDerivativeAttr ( )
25
+ }
You can’t perform that action at this time.
0 commit comments