Skip to content

Commit 9c20198

Browse files
authored
[AutoDiff] Disallow differentiation of opaque-result-typed functions. (#32714)
Reject `@differentiable` and `@derivative` attribute for original functions with opaque result types. It is not possible to support derivative registration nor the differentiation transform for such functions. Resolves SR-12656.
1 parent c07b3b0 commit 9c20198

File tree

5 files changed

+45
-4
lines changed

5 files changed

+45
-4
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3151,6 +3151,8 @@ ERROR(autodiff_attr_original_multiple_semantic_results,none,
31513151
ERROR(autodiff_attr_result_not_differentiable,none,
31523152
"can only differentiate functions with results that conform to "
31533153
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))
3154+
ERROR(autodiff_attr_opaque_result_type_unsupported,none,
3155+
"cannot differentiate functions returning opaque result types", ())
31543156

31553157
// differentiation `wrt` parameters clause
31563158
ERROR(diff_function_no_parameters,none,

lib/Sema/TypeCheckAttr.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4240,6 +4240,15 @@ IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate(
42404240

42414241
auto *originalFnTy = original->getInterfaceType()->castTo<AnyFunctionType>();
42424242

4243+
// Diagnose if original function has opaque result types.
4244+
if (auto *opaqueResultTypeDecl = original->getOpaqueResultTypeDecl()) {
4245+
diags.diagnose(
4246+
attr->getLocation(),
4247+
diag::autodiff_attr_opaque_result_type_unsupported);
4248+
attr->setInvalid();
4249+
return nullptr;
4250+
}
4251+
42434252
// Diagnose if original function is an invalid class member.
42444253
bool isOriginalClassMember = original->getDeclContext() &&
42454254
original->getDeclContext()->getSelfClassDecl();
@@ -4532,6 +4541,16 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
45324541
return true;
45334542
}
45344543
}
4544+
4545+
// Diagnose if original function has opaque result types.
4546+
if (auto *opaqueResultTypeDecl = originalAFD->getOpaqueResultTypeDecl()) {
4547+
diags.diagnose(
4548+
attr->getLocation(),
4549+
diag::autodiff_attr_opaque_result_type_unsupported);
4550+
attr->setInvalid();
4551+
return true;
4552+
}
4553+
45354554
// Diagnose if original function is an invalid class member.
45364555
bool isOriginalClassMember =
45374556
originalAFD->getDeclContext() &&
@@ -5083,9 +5102,15 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
50835102
attr->setInvalid();
50845103
return;
50855104
}
5086-
50875105
attr->setOriginalFunction(originalAFD);
50885106

5107+
// Diagnose if original function has opaque result types.
5108+
if (auto *opaqueResultTypeDecl = originalAFD->getOpaqueResultTypeDecl()) {
5109+
diagnose(attr->getLocation(), diag::autodiff_attr_opaque_result_type_unsupported);
5110+
attr->setInvalid();
5111+
return;
5112+
}
5113+
50895114
// Get the linearity parameter types.
50905115
SmallVector<AnyFunctionType::Param, 4> linearParams;
50915116
expectedOriginalFnType->getSubsetParameters(linearParamIndices, linearParams,

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend-typecheck -verify %s
1+
// RUN: %target-swift-frontend-typecheck -verify -disable-availability-checking %s
22

33
import _Differentiation
44

@@ -1124,3 +1124,13 @@ extension Float {
11241124
fatalError()
11251125
}
11261126
}
1127+
1128+
// Test original function with opaque result type.
1129+
1130+
func opaqueResult(_ x: Float) -> some Differentiable { x }
1131+
1132+
// expected-error @+1 {{could not find function 'opaqueResult' with expected type '(Float) -> Float'}}
1133+
@derivative(of: opaqueResult)
1134+
func vjpOpaqueResult(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
1135+
fatalError()
1136+
}

test/AutoDiff/Sema/differentiable_attr_type_checking.swift

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend-typecheck -verify %s
1+
// RUN: %target-swift-frontend-typecheck -verify -disable-availability-checking %s
22

33
import _Differentiation
44

@@ -697,3 +697,7 @@ struct Accessors: Differentiable {
697697
_modify { yield &stored }
698698
}
699699
}
700+
701+
// expected-error @+1 {{cannot differentiate functions returning opaque result types}}
702+
@differentiable
703+
func opaqueResult(_ x: Float) -> some Differentiable { x }

test/AutoDiff/compiler_crashers/sr12656-differentiation-opaque-result-type.swift renamed to test/AutoDiff/compiler_crashers_fixed/sr12656-differentiation-opaque-result-type.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: not --crash %target-swift-frontend -disable-availability-checking -emit-sil -verify %s
1+
// RUN: not %target-swift-frontend -disable-availability-checking -emit-sil -verify %s
22
// REQUIRES: asserts
33

44
// SR-12656: Differentiation transform crashes for original function with opaque

0 commit comments

Comments
 (0)