From a2439b2ec22f3821e0f3cf607e1ac4e242f9aedb Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sun, 5 Jul 2020 14:30:25 -0700 Subject: [PATCH 1/2] [AutoDiff] Disallow differentiation of opaque-result-typed functions. Reject `@differentiable` and `@derivative` attribute for original functions with opaque result types. It is not possible to support derivative registration nor the differentiation transform for such functions. Resolves SR-12656. --- include/swift/AST/DiagnosticsSema.def | 2 ++ lib/Sema/TypeCheckAttr.cpp | 27 ++++++++++++++++++- .../Sema/derivative_attr_type_checking.swift | 12 ++++++++- .../differentiable_attr_type_checking.swift | 6 ++++- 4 files changed, 44 insertions(+), 3 deletions(-) diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 2ab5a7a338234..ced5a901f8a1e 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -3151,6 +3151,8 @@ ERROR(autodiff_attr_original_multiple_semantic_results,none, ERROR(autodiff_attr_result_not_differentiable,none, "can only differentiate functions with results that conform to " "'Differentiable', but %0 does not conform to 'Differentiable'", (Type)) +ERROR(autodiff_attr_opaque_result_type_unsupported,none, + "cannot differentiate functions returning opaque result types", ()) // differentiation `wrt` parameters clause ERROR(diff_function_no_parameters,none, diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index e2f05cfb4f7b4..8883713e5aa3f 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -4240,6 +4240,15 @@ IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate( auto *originalFnTy = original->getInterfaceType()->castTo(); + // Diagnose if original function has opaque result types. + if (auto *opaqueResultTypeDecl = original->getOpaqueResultTypeDecl()) { + diags.diagnose( + attr->getLocation(), + diag::autodiff_attr_opaque_result_type_unsupported); + attr->setInvalid(); + return nullptr; + } + // Diagnose if original function is an invalid class member. bool isOriginalClassMember = original->getDeclContext() && original->getDeclContext()->getSelfClassDecl(); @@ -4532,6 +4541,16 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D, return true; } } + + // Diagnose if original function has opaque result types. + if (auto *opaqueResultTypeDecl = originalAFD->getOpaqueResultTypeDecl()) { + diags.diagnose( + attr->getLocation(), + diag::autodiff_attr_opaque_result_type_unsupported); + attr->setInvalid(); + return true; + } + // Diagnose if original function is an invalid class member. bool isOriginalClassMember = originalAFD->getDeclContext() && @@ -5083,9 +5102,15 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) { attr->setInvalid(); return; } - attr->setOriginalFunction(originalAFD); + // Diagnose if original function has opaque result types. + if (auto *opaqueResultTypeDecl = originalAFD->getOpaqueResultTypeDecl()) { + diagnose(attr->getLocation(), diag::autodiff_attr_opaque_result_type_unsupported); + attr->setInvalid(); + return; + } + // Get the linearity parameter types. SmallVector linearParams; expectedOriginalFnType->getSubsetParameters(linearParamIndices, linearParams, diff --git a/test/AutoDiff/Sema/derivative_attr_type_checking.swift b/test/AutoDiff/Sema/derivative_attr_type_checking.swift index 1de073a5d35fa..0b17860a1b51a 100644 --- a/test/AutoDiff/Sema/derivative_attr_type_checking.swift +++ b/test/AutoDiff/Sema/derivative_attr_type_checking.swift @@ -1,4 +1,4 @@ -// RUN: %target-swift-frontend-typecheck -verify %s +// RUN: %target-swift-frontend-typecheck -verify -disable-availability-checking %s import _Differentiation @@ -1124,3 +1124,13 @@ extension Float { fatalError() } } + +// Test original function with opaque result type. + +func opaqueResult(_ x: Float) -> some Differentiable { x } + +// expected-error @+1 {{could not find function 'opaqueResult' with expected type '(Float) -> Float'}} +@derivative(of: opaqueResult) +func vjpOpaqueResult(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { + fatalError() +} diff --git a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift index 55cdfee43d9c6..a090b5de6a121 100644 --- a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift +++ b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift @@ -1,4 +1,4 @@ -// RUN: %target-swift-frontend-typecheck -verify %s +// RUN: %target-swift-frontend-typecheck -verify -disable-availability-checking %s import _Differentiation @@ -697,3 +697,7 @@ struct Accessors: Differentiable { _modify { yield &stored } } } + +// expected-error @+1 {{cannot differentiate functions returning opaque result types}} +@differentiable +func opaqueResult(_ x: Float) -> some Differentiable { x } From bacb532a1b7558ee1a66d662ec0ffb4b2d9baf12 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sun, 5 Jul 2020 17:36:44 -0700 Subject: [PATCH 2/2] Mark compiler crasher test as fixed. --- .../sr12656-differentiation-opaque-result-type.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename test/AutoDiff/{compiler_crashers => compiler_crashers_fixed}/sr12656-differentiation-opaque-result-type.swift (94%) diff --git a/test/AutoDiff/compiler_crashers/sr12656-differentiation-opaque-result-type.swift b/test/AutoDiff/compiler_crashers_fixed/sr12656-differentiation-opaque-result-type.swift similarity index 94% rename from test/AutoDiff/compiler_crashers/sr12656-differentiation-opaque-result-type.swift rename to test/AutoDiff/compiler_crashers_fixed/sr12656-differentiation-opaque-result-type.swift index 703786fcd8031..ed48abb31be0f 100644 --- a/test/AutoDiff/compiler_crashers/sr12656-differentiation-opaque-result-type.swift +++ b/test/AutoDiff/compiler_crashers_fixed/sr12656-differentiation-opaque-result-type.swift @@ -1,4 +1,4 @@ -// RUN: not --crash %target-swift-frontend -disable-availability-checking -emit-sil -verify %s +// RUN: not %target-swift-frontend -disable-availability-checking -emit-sil -verify %s // REQUIRES: asserts // SR-12656: Differentiation transform crashes for original function with opaque