diff --git a/include/swift/AST/TypeCheckRequests.h b/include/swift/AST/TypeCheckRequests.h index 04fe9c2bc2183..48dd1a5b58301 100644 --- a/include/swift/AST/TypeCheckRequests.h +++ b/include/swift/AST/TypeCheckRequests.h @@ -2192,7 +2192,7 @@ class DerivativeAttrOriginalDeclRequest /// property in a `Differentiable`-conforming type. class TangentStoredPropertyRequest : public SimpleRequest { public: using SimpleRequest::SimpleRequest; @@ -2201,8 +2201,8 @@ class TangentStoredPropertyRequest friend SimpleRequest; // Evaluation. - TangentPropertyInfo evaluate(Evaluator &evaluator, - VarDecl *originalField) const; + TangentPropertyInfo evaluate(Evaluator &evaluator, VarDecl *originalField, + CanType parentType) const; public: // Caching. diff --git a/include/swift/AST/TypeCheckerTypeIDZone.def b/include/swift/AST/TypeCheckerTypeIDZone.def index c77b86bde2ed1..2f5ee6f22796a 100644 --- a/include/swift/AST/TypeCheckerTypeIDZone.def +++ b/include/swift/AST/TypeCheckerTypeIDZone.def @@ -204,7 +204,7 @@ SWIFT_REQUEST(TypeChecker, SynthesizeAccessorRequest, AccessorDecl *(AbstractStorageDecl *, AccessorKind), SeparatelyCached, NoLocationInfo) SWIFT_REQUEST(TypeChecker, TangentStoredPropertyRequest, - llvm::Expected(VarDecl *), Cached, NoLocationInfo) + llvm::Expected(VarDecl *, CanType), Cached, NoLocationInfo) SWIFT_REQUEST(TypeChecker, TypeCheckFunctionBodyRequest, bool(AbstractFunctionDecl *), Cached, NoLocationInfo) SWIFT_REQUEST(TypeChecker, TypeCheckFunctionBodyAtLocRequest, diff --git a/include/swift/SILOptimizer/Differentiation/Common.h b/include/swift/SILOptimizer/Differentiation/Common.h index 82bdb4a2a492d..2f996740b08b6 100644 --- a/include/swift/SILOptimizer/Differentiation/Common.h +++ b/include/swift/SILOptimizer/Differentiation/Common.h @@ -157,16 +157,18 @@ SILLocation getValidLocation(SILInstruction *inst); // Tangent property lookup utilities //===----------------------------------------------------------------------===// -/// Returns the tangent stored property of `originalField`. On error, emits -/// diagnostic and returns nullptr. +/// Returns the tangent stored property of the given original stored property +/// and base type. On error, emits diagnostic and returns nullptr. VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField, - SILLocation loc, + CanType baseType, SILLocation loc, DifferentiationInvoker invoker); /// Returns the tangent stored property of the original stored property -/// referenced by `inst`. On error, emits diagnostic and returns nullptr. +/// referenced by the given projection instruction with the given base type. +/// On error, emits diagnostic and returns nullptr. VarDecl *getTangentStoredProperty(ADContext &context, FieldIndexCacheBase *projectionInst, + CanType baseType, DifferentiationInvoker invoker); //===----------------------------------------------------------------------===// diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index 87c65561dcf6c..83b7d981f08db 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -475,17 +475,16 @@ void swift::simple_display(llvm::raw_ostream &os, TangentPropertyInfo info) { os << " }"; } -TangentPropertyInfo -TangentStoredPropertyRequest::evaluate(Evaluator &evaluator, - VarDecl *originalField) const { - assert(originalField->hasStorage() && originalField->isInstanceMember() && - "Expected stored property"); +TangentPropertyInfo TangentStoredPropertyRequest::evaluate( + Evaluator &evaluator, VarDecl *originalField, CanType baseType) const { + assert(((originalField->hasStorage() && originalField->isInstanceMember()) || + originalField->hasAttachedPropertyWrapper()) && + "Expected a stored property or a property-wrapped property"); auto *parentDC = originalField->getDeclContext(); assert(parentDC->isTypeContext()); - auto parentType = parentDC->getDeclaredTypeInContext(); auto *moduleDecl = originalField->getModuleContext(); - auto parentTan = parentType->getAutoDiffTangentSpace( - LookUpConformanceInModule(moduleDecl)); + auto parentTan = + baseType->getAutoDiffTangentSpace(LookUpConformanceInModule(moduleDecl)); // Error if parent nominal type does not conform to `Differentiable`. if (!parentTan) { return TangentPropertyInfo( @@ -497,13 +496,18 @@ TangentStoredPropertyRequest::evaluate(Evaluator &evaluator, TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty); } // Error if original property's type does not conform to `Differentiable`. - auto originalFieldTan = originalField->getType()->getAutoDiffTangentSpace( + auto originalFieldType = baseType->getTypeOfMember( + originalField->getModuleContext(), originalField); + auto originalFieldTan = originalFieldType->getAutoDiffTangentSpace( LookUpConformanceInModule(moduleDecl)); if (!originalFieldTan) { return TangentPropertyInfo( TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable); } - auto parentTanType = parentTan->getType(); + // Get the parent `TangentVector` type. + auto parentTanType = + baseType->getAutoDiffTangentSpace(LookUpConformanceInModule(moduleDecl)) + ->getType(); auto *parentTanStruct = parentTanType->getStructOrBoundGenericStruct(); // Error if parent `TangentVector` is not a struct. if (!parentTanStruct) { @@ -533,7 +537,9 @@ TangentStoredPropertyRequest::evaluate(Evaluator &evaluator, // Error if tangent property's type is not equal to the original property's // `TangentVector` type. auto originalFieldTanType = originalFieldTan->getType(); - if (!originalFieldTanType->isEqual(tanField->getType())) { + auto tanFieldType = + parentTanType->getTypeOfMember(tanField->getModuleContext(), tanField); + if (!originalFieldTanType->isEqual(tanFieldType)) { return TangentPropertyInfo( TangentPropertyInfo::Error::Kind::TangentPropertyWrongType, originalFieldTanType); diff --git a/lib/SILOptimizer/Differentiation/Common.cpp b/lib/SILOptimizer/Differentiation/Common.cpp index d9226575dafc6..57c925e54ca0a 100644 --- a/lib/SILOptimizer/Differentiation/Common.cpp +++ b/lib/SILOptimizer/Differentiation/Common.cpp @@ -269,11 +269,11 @@ SILLocation getValidLocation(SILInstruction *inst) { //===----------------------------------------------------------------------===// VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField, - SILLocation loc, + CanType baseType, SILLocation loc, DifferentiationInvoker invoker) { auto &astCtx = context.getASTContext(); auto tanFieldInfo = evaluateOrDefault( - astCtx.evaluator, TangentStoredPropertyRequest{originalField}, + astCtx.evaluator, TangentStoredPropertyRequest{originalField, baseType}, TangentPropertyInfo(nullptr)); // If no error, return the tangent property. if (tanFieldInfo) @@ -328,13 +328,14 @@ VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField, VarDecl *getTangentStoredProperty(ADContext &context, FieldIndexCacheBase *projectionInst, + CanType baseType, DifferentiationInvoker invoker) { assert(isa(projectionInst) || isa(projectionInst) || isa(projectionInst)); auto loc = getValidLocation(projectionInst); - return getTangentStoredProperty(context, projectionInst->getField(), loc, - invoker); + return getTangentStoredProperty(context, projectionInst->getField(), baseType, + loc, invoker); } //===----------------------------------------------------------------------===// diff --git a/lib/SILOptimizer/Differentiation/JVPEmitter.cpp b/lib/SILOptimizer/Differentiation/JVPEmitter.cpp index 9e0cd4f47a2ae..d27ea634d36cf 100644 --- a/lib/SILOptimizer/Differentiation/JVPEmitter.cpp +++ b/lib/SILOptimizer/Differentiation/JVPEmitter.cpp @@ -548,17 +548,18 @@ CLONE_AND_EMIT_TANGENT(StructExtract, sei) { "`struct_extract` with `@noDerivative` field should not be " "differentiated; activity analysis should not marked as varied."); auto diffBuilder = getDifferentialBuilder(); + auto loc = getValidLocation(sei); // Find the corresponding field in the tangent space. - auto *tanField = getTangentStoredProperty(context, sei, invoker); + auto structType = + remapSILTypeInDifferential(sei->getOperand()->getType()).getASTType(); + auto *tanField = getTangentStoredProperty(context, sei, structType, invoker); if (!tanField) { errorOccurred = true; return; } // Emit tangent `struct_extract`. - auto tanStruct = - materializeTangent(getTangentValue(sei->getOperand()), sei->getLoc()); - auto tangentInst = - diffBuilder.createStructExtract(sei->getLoc(), tanStruct, tanField); + auto tanStruct = materializeTangent(getTangentValue(sei->getOperand()), loc); + auto tangentInst = diffBuilder.createStructExtract(loc, tanStruct, tanField); // Update tangent value mapping for `struct_extract` result. auto tangentResult = makeConcreteTangentValue(tangentInst); setTangentValue(sei->getParent(), sei, tangentResult); @@ -575,8 +576,11 @@ CLONE_AND_EMIT_TANGENT(StructElementAddr, seai) { "differentiated; activity analysis should not marked as varied."); auto diffBuilder = getDifferentialBuilder(); auto *bb = seai->getParent(); + auto loc = getValidLocation(seai); // Find the corresponding field in the tangent space. - auto *tanField = getTangentStoredProperty(context, seai, invoker); + auto structType = + remapSILTypeInDifferential(seai->getOperand()->getType()).getASTType(); + auto *tanField = getTangentStoredProperty(context, seai, structType, invoker); if (!tanField) { errorOccurred = true; return; @@ -584,7 +588,7 @@ CLONE_AND_EMIT_TANGENT(StructElementAddr, seai) { // Emit tangent `struct_element_addr`. auto tanOperand = getTangentBuffer(bb, seai->getOperand()); auto tangentInst = - diffBuilder.createStructElementAddr(seai->getLoc(), tanOperand, tanField); + diffBuilder.createStructElementAddr(loc, tanOperand, tanField); // Update tangent buffer map for `struct_element_addr`. setTangentBuffer(bb, seai, tangentInst); } diff --git a/lib/SILOptimizer/Differentiation/PullbackEmitter.cpp b/lib/SILOptimizer/Differentiation/PullbackEmitter.cpp index 48e937ec63176..e6c1844cb7a9e 100644 --- a/lib/SILOptimizer/Differentiation/PullbackEmitter.cpp +++ b/lib/SILOptimizer/Differentiation/PullbackEmitter.cpp @@ -371,7 +371,9 @@ SILValue PullbackEmitter::getAdjointProjection(SILBasicBlock *origBB, assert(!seai->getField()->getAttrs().hasAttribute() && "`@noDerivative` struct projections should never be active"); auto adjSource = getAdjointBuffer(origBB, seai->getOperand()); - auto *tanField = getTangentStoredProperty(getContext(), seai, getInvoker()); + auto structType = remapType(seai->getOperand()->getType()).getASTType(); + auto *tanField = + getTangentStoredProperty(getContext(), seai, structType, getInvoker()); assert(tanField && "Invalid projections should have been diagnosed"); return builder.createStructElementAddr(seai->getLoc(), adjSource, tanField); } @@ -400,7 +402,10 @@ SILValue PullbackEmitter::getAdjointProjection(SILBasicBlock *origBB, auto loc = reai->getLoc(); // Get the class operand, stripping `begin_borrow`. auto classOperand = stripBorrow(reai->getOperand()); - auto *tanField = getTangentStoredProperty(getContext(), reai, getInvoker()); + auto classType = remapType(reai->getOperand()->getType()).getASTType(); + auto *tanField = + getTangentStoredProperty(getContext(), reai->getField(), classType, + reai->getLoc(), getInvoker()); assert(tanField && "Invalid projections should have been diagnosed"); // Create a local allocation for the element adjoint buffer. auto eltTanType = tanField->getValueInterfaceType()->getCanonicalType(); @@ -666,8 +671,9 @@ bool PullbackEmitter::runForSemanticMemberGetter() { // Look up the corresponding field in the tangent space. auto *origField = cast(accessor->getStorage()); - auto *tanField = - getTangentStoredProperty(getContext(), origField, pbLoc, getInvoker()); + auto baseType = remapType(origSelf->getType()).getASTType(); + auto *tanField = getTangentStoredProperty(getContext(), origField, baseType, + pbLoc, getInvoker()); if (!tanField) { errorOccurred = true; return true; @@ -772,8 +778,9 @@ bool PullbackEmitter::runForSemanticMemberSetter() { // Look up the corresponding field in the tangent space. auto *origField = cast(accessor->getStorage()); - auto *tanField = - getTangentStoredProperty(getContext(), origField, pbLoc, getInvoker()); + auto baseType = remapType(origSelf->getType()).getASTType(); + auto *tanField = getTangentStoredProperty(getContext(), origField, baseType, + pbLoc, getInvoker()); if (!tanField) { errorOccurred = true; return true; @@ -882,7 +889,10 @@ bool PullbackEmitter::run() { } // Diagnose unsupported stored property projections. if (auto *inst = dyn_cast(v)) { - if (!getTangentStoredProperty(getContext(), inst, getInvoker())) { + assert(inst->getNumOperands() == 1); + auto baseType = remapType(inst->getOperand(0)->getType()).getASTType(); + if (!getTangentStoredProperty(getContext(), inst, baseType, + getInvoker())) { errorOccurred = true; return true; } @@ -1694,8 +1704,8 @@ void PullbackEmitter::visitStructInst(StructInst *si) { if (field->getAttrs().hasAttribute()) continue; // Find the corresponding field in the tangent space. - auto *tanField = - getTangentStoredProperty(getContext(), field, loc, getInvoker()); + auto *tanField = getTangentStoredProperty(getContext(), field, structTy, + loc, getInvoker()); if (!tanField) { errorOccurred = true; return; @@ -1727,6 +1737,7 @@ void PullbackEmitter::visitBeginApplyInst(BeginApplyInst *bai) { void PullbackEmitter::visitStructExtractInst(StructExtractInst *sei) { auto *bb = sei->getParent(); + auto loc = getValidLocation(sei); auto structTy = remapType(sei->getOperand()->getType()).getASTType(); auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType(); assert(!getTypeLowering(tangentVectorTy).isAddressOnly()); @@ -1734,14 +1745,15 @@ void PullbackEmitter::visitStructExtractInst(StructExtractInst *sei) { auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct(); assert(tangentVectorDecl); // Find the corresponding field in the tangent space. - auto *tanField = getTangentStoredProperty(getContext(), sei, getInvoker()); + auto *tanField = + getTangentStoredProperty(getContext(), sei, structTy, getInvoker()); assert(tanField && "Invalid projections should have been diagnosed"); // Accumulate adjoint for the `struct_extract` operand. auto av = getAdjointValue(bb, sei); switch (av.getKind()) { case AdjointValueKind::Zero: addAdjointValue(bb, sei->getOperand(), - makeZeroAdjointValue(tangentVectorSILTy), sei->getLoc()); + makeZeroAdjointValue(tangentVectorSILTy), loc); break; case AdjointValueKind::Concrete: case AdjointValueKind::Aggregate: { @@ -1760,7 +1772,7 @@ void PullbackEmitter::visitStructExtractInst(StructExtractInst *sei) { } addAdjointValue(bb, sei->getOperand(), makeAggregateAdjointValue(tangentVectorSILTy, eltVals), - sei->getLoc()); + loc); } } } @@ -1770,7 +1782,9 @@ void PullbackEmitter::visitRefElementAddrInst(RefElementAddrInst *reai) { auto loc = reai->getLoc(); auto adjBuf = getAdjointBuffer(bb, reai); auto classOperand = reai->getOperand(); - auto *tanField = getTangentStoredProperty(getContext(), reai, getInvoker()); + auto classType = remapType(reai->getOperand()->getType()).getASTType(); + auto *tanField = + getTangentStoredProperty(getContext(), reai, classType, getInvoker()); assert(tanField && "Invalid projections should have been diagnosed"); switch (getTangentValueCategory(classOperand)) { case SILValueCategory::Object: { diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp index 713b7ca8129b8..afa901b6185d5 100644 --- a/lib/Sema/DerivedConformanceDifferentiable.cpp +++ b/lib/Sema/DerivedConformanceDifferentiable.cpp @@ -668,7 +668,7 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) { tangentProperty->setSetterAccess(member->getFormalAccess()); // Cache the tangent property. - C.evaluator.cacheOutput(TangentStoredPropertyRequest{member}, + C.evaluator.cacheOutput(TangentStoredPropertyRequest{member, CanType()}, TangentPropertyInfo(tangentProperty)); // Now that the original property has a corresponding tangent property, it diff --git a/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift b/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift index a038e01dafbf2..9e754a92e10e0 100644 --- a/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift +++ b/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift @@ -636,6 +636,22 @@ func testClassTangentPropertyNotStored(_ c: ClassTangentPropertyNotStored) -> Fl // CHECK-LABEL: sil {{.*}} @test_class_tangent_property_not_stored // CHECK: ref_element_addr {{%.*}} : $ClassTangentPropertyNotStored, #ClassTangentPropertyNotStored.x +// SR-13134: Test stored property access with conditionally `Differentiable` base type. + +struct Complex { + var real: T + var imaginary: T +} +extension Complex: Differentiable where T: Differentiable { + typealias TangentVector = Complex +} +extension Complex: AdditiveArithmetic {} + +@differentiable +func SR_13134(lhs: Complex, rhs: Complex) -> Float { + return lhs.real + rhs.real +} + //===----------------------------------------------------------------------===// // Wrapped property differentiation //===----------------------------------------------------------------------===//