Skip to content

Commit 738ef73

Browse files
authored
[AutoDiff] Fix @differentiable attribute derivative configurations. (swiftlang#31524)
In `AbstractFunctionDecl::getDerivativeFunctionConfigurations`, type-check `@differentiable` attributes. This is important to populate derivative configurations for original functions in other files. Resolves TF-1271. Exposes TF-1272: fix derivative configurations for cross-file `@derivative` attributes. This is a more difficult issue.
1 parent fae995a commit 738ef73

File tree

3 files changed

+49
-0
lines changed

3 files changed

+49
-0
lines changed

lib/AST/Decl.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7106,6 +7106,11 @@ void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() {
71067106
ArrayRef<AutoDiffConfig>
71077107
AbstractFunctionDecl::getDerivativeFunctionConfigurations() {
71087108
prepareDerivativeFunctionConfigurations();
7109+
// Resolve derivative function configurations from `@differentiable`
7110+
// attributes by type-checking them.
7111+
for (auto *diffAttr : getAttrs().getAttributes<DifferentiableAttr>())
7112+
(void)diffAttr->getParameterIndices();
7113+
// Load derivative configurations from imported modules.
71097114
auto &ctx = getASTContext();
71107115
if (ctx.getCurrentGeneration() > DerivativeFunctionConfigGeneration) {
71117116
unsigned previousGeneration = DerivativeFunctionConfigGeneration;
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import _Differentiation
2+
3+
protocol Protocol: Differentiable {
4+
// Test cross-file `@differentiable` attribute.
5+
@differentiable(wrt: self)
6+
func identityDifferentiableAttr() -> Self
7+
}
8+
9+
extension Protocol {
10+
func identityDerivativeAttr() -> Self { self }
11+
12+
// Test cross-file `@derivative` attribute.
13+
@derivative(of: identityDerivativeAttr)
14+
func vjpIdentityDerivativeAttr() -> (
15+
value: Self, pullback: (TangentVector) -> TangentVector
16+
) {
17+
fatalError()
18+
}
19+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify -primary-file %s %S/Inputs/differentiation_diagnostics_other_file.swift -module-name main -o /dev/null
2+
3+
// Test differentiation transform cross-file diagnostics.
4+
5+
import _Differentiation
6+
7+
// TF-1271: Test `@differentiable` original function in other file.
8+
@differentiable
9+
func crossFileDifferentiableAttr<T: Protocol>(
10+
_ input: T
11+
) -> T {
12+
return input.identityDifferentiableAttr()
13+
}
14+
15+
// TF-1272: Test original function with registered derivatives in other files.
16+
// FIXME(TF-1272): Find a way to type-check `@derivative` attributes in other
17+
// files.
18+
@differentiable
19+
func crossFileDerivativeAttr<T: Protocol>(
20+
_ input: T
21+
) -> T {
22+
// expected-error @+2 {{expression is not differentiable}}
23+
// expected-note @+1 {{cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files}}
24+
return input.identityDerivativeAttr()
25+
}

0 commit comments

Comments
 (0)