Skip to content

Commit d4bbcf9

Browse files
authored
[AutoDiff] Fix unexpected non-differentiable property access error. (#32670)
Add base type parameter to `TangentStoredPropertyRequest`. Use `TypeBase::getTypeOfMember` instead of `VarDecl::getType` to correctly compute the member type of original stored properties, using the base type. Resolves SR-13134.
1 parent ca87af1 commit d4bbcf9

File tree

9 files changed

+87
-44
lines changed

9 files changed

+87
-44
lines changed

include/swift/AST/TypeCheckRequests.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2192,7 +2192,7 @@ class DerivativeAttrOriginalDeclRequest
21922192
/// property in a `Differentiable`-conforming type.
21932193
class TangentStoredPropertyRequest
21942194
: public SimpleRequest<TangentStoredPropertyRequest,
2195-
TangentPropertyInfo(VarDecl *),
2195+
TangentPropertyInfo(VarDecl *, CanType),
21962196
RequestFlags::Cached> {
21972197
public:
21982198
using SimpleRequest::SimpleRequest;
@@ -2201,8 +2201,8 @@ class TangentStoredPropertyRequest
22012201
friend SimpleRequest;
22022202

22032203
// Evaluation.
2204-
TangentPropertyInfo evaluate(Evaluator &evaluator,
2205-
VarDecl *originalField) const;
2204+
TangentPropertyInfo evaluate(Evaluator &evaluator, VarDecl *originalField,
2205+
CanType parentType) const;
22062206

22072207
public:
22082208
// Caching.

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ SWIFT_REQUEST(TypeChecker, SynthesizeAccessorRequest,
204204
AccessorDecl *(AbstractStorageDecl *, AccessorKind),
205205
SeparatelyCached, NoLocationInfo)
206206
SWIFT_REQUEST(TypeChecker, TangentStoredPropertyRequest,
207-
llvm::Expected<VarDecl *>(VarDecl *), Cached, NoLocationInfo)
207+
llvm::Expected<VarDecl *>(VarDecl *, CanType), Cached, NoLocationInfo)
208208
SWIFT_REQUEST(TypeChecker, TypeCheckFunctionBodyRequest,
209209
bool(AbstractFunctionDecl *), Cached, NoLocationInfo)
210210
SWIFT_REQUEST(TypeChecker, TypeCheckFunctionBodyAtLocRequest,

include/swift/SILOptimizer/Differentiation/Common.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,16 +157,18 @@ SILLocation getValidLocation(SILInstruction *inst);
157157
// Tangent property lookup utilities
158158
//===----------------------------------------------------------------------===//
159159

160-
/// Returns the tangent stored property of `originalField`. On error, emits
161-
/// diagnostic and returns nullptr.
160+
/// Returns the tangent stored property of the given original stored property
161+
/// and base type. On error, emits diagnostic and returns nullptr.
162162
VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField,
163-
SILLocation loc,
163+
CanType baseType, SILLocation loc,
164164
DifferentiationInvoker invoker);
165165

166166
/// Returns the tangent stored property of the original stored property
167-
/// referenced by `inst`. On error, emits diagnostic and returns nullptr.
167+
/// referenced by the given projection instruction with the given base type.
168+
/// On error, emits diagnostic and returns nullptr.
168169
VarDecl *getTangentStoredProperty(ADContext &context,
169170
FieldIndexCacheBase *projectionInst,
171+
CanType baseType,
170172
DifferentiationInvoker invoker);
171173

172174
//===----------------------------------------------------------------------===//

lib/AST/AutoDiff.cpp

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -475,17 +475,16 @@ void swift::simple_display(llvm::raw_ostream &os, TangentPropertyInfo info) {
475475
os << " }";
476476
}
477477

478-
TangentPropertyInfo
479-
TangentStoredPropertyRequest::evaluate(Evaluator &evaluator,
480-
VarDecl *originalField) const {
481-
assert(originalField->hasStorage() && originalField->isInstanceMember() &&
482-
"Expected stored property");
478+
TangentPropertyInfo TangentStoredPropertyRequest::evaluate(
479+
Evaluator &evaluator, VarDecl *originalField, CanType baseType) const {
480+
assert(((originalField->hasStorage() && originalField->isInstanceMember()) ||
481+
originalField->hasAttachedPropertyWrapper()) &&
482+
"Expected a stored property or a property-wrapped property");
483483
auto *parentDC = originalField->getDeclContext();
484484
assert(parentDC->isTypeContext());
485-
auto parentType = parentDC->getDeclaredTypeInContext();
486485
auto *moduleDecl = originalField->getModuleContext();
487-
auto parentTan = parentType->getAutoDiffTangentSpace(
488-
LookUpConformanceInModule(moduleDecl));
486+
auto parentTan =
487+
baseType->getAutoDiffTangentSpace(LookUpConformanceInModule(moduleDecl));
489488
// Error if parent nominal type does not conform to `Differentiable`.
490489
if (!parentTan) {
491490
return TangentPropertyInfo(
@@ -497,13 +496,18 @@ TangentStoredPropertyRequest::evaluate(Evaluator &evaluator,
497496
TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty);
498497
}
499498
// Error if original property's type does not conform to `Differentiable`.
500-
auto originalFieldTan = originalField->getType()->getAutoDiffTangentSpace(
499+
auto originalFieldType = baseType->getTypeOfMember(
500+
originalField->getModuleContext(), originalField);
501+
auto originalFieldTan = originalFieldType->getAutoDiffTangentSpace(
501502
LookUpConformanceInModule(moduleDecl));
502503
if (!originalFieldTan) {
503504
return TangentPropertyInfo(
504505
TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable);
505506
}
506-
auto parentTanType = parentTan->getType();
507+
// Get the parent `TangentVector` type.
508+
auto parentTanType =
509+
baseType->getAutoDiffTangentSpace(LookUpConformanceInModule(moduleDecl))
510+
->getType();
507511
auto *parentTanStruct = parentTanType->getStructOrBoundGenericStruct();
508512
// Error if parent `TangentVector` is not a struct.
509513
if (!parentTanStruct) {
@@ -533,7 +537,9 @@ TangentStoredPropertyRequest::evaluate(Evaluator &evaluator,
533537
// Error if tangent property's type is not equal to the original property's
534538
// `TangentVector` type.
535539
auto originalFieldTanType = originalFieldTan->getType();
536-
if (!originalFieldTanType->isEqual(tanField->getType())) {
540+
auto tanFieldType =
541+
parentTanType->getTypeOfMember(tanField->getModuleContext(), tanField);
542+
if (!originalFieldTanType->isEqual(tanFieldType)) {
537543
return TangentPropertyInfo(
538544
TangentPropertyInfo::Error::Kind::TangentPropertyWrongType,
539545
originalFieldTanType);

lib/SILOptimizer/Differentiation/Common.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,11 +271,11 @@ SILLocation getValidLocation(SILInstruction *inst) {
271271
//===----------------------------------------------------------------------===//
272272

273273
VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField,
274-
SILLocation loc,
274+
CanType baseType, SILLocation loc,
275275
DifferentiationInvoker invoker) {
276276
auto &astCtx = context.getASTContext();
277277
auto tanFieldInfo = evaluateOrDefault(
278-
astCtx.evaluator, TangentStoredPropertyRequest{originalField},
278+
astCtx.evaluator, TangentStoredPropertyRequest{originalField, baseType},
279279
TangentPropertyInfo(nullptr));
280280
// If no error, return the tangent property.
281281
if (tanFieldInfo)
@@ -330,13 +330,14 @@ VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField,
330330

331331
VarDecl *getTangentStoredProperty(ADContext &context,
332332
FieldIndexCacheBase *projectionInst,
333+
CanType baseType,
333334
DifferentiationInvoker invoker) {
334335
assert(isa<StructExtractInst>(projectionInst) ||
335336
isa<StructElementAddrInst>(projectionInst) ||
336337
isa<RefElementAddrInst>(projectionInst));
337338
auto loc = getValidLocation(projectionInst);
338-
return getTangentStoredProperty(context, projectionInst->getField(), loc,
339-
invoker);
339+
return getTangentStoredProperty(context, projectionInst->getField(), baseType,
340+
loc, invoker);
340341
}
341342

342343
//===----------------------------------------------------------------------===//

lib/SILOptimizer/Differentiation/JVPEmitter.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -550,17 +550,18 @@ CLONE_AND_EMIT_TANGENT(StructExtract, sei) {
550550
"`struct_extract` with `@noDerivative` field should not be "
551551
"differentiated; activity analysis should not marked as varied.");
552552
auto diffBuilder = getDifferentialBuilder();
553+
auto loc = getValidLocation(sei);
553554
// Find the corresponding field in the tangent space.
554-
auto *tanField = getTangentStoredProperty(context, sei, invoker);
555+
auto structType =
556+
remapSILTypeInDifferential(sei->getOperand()->getType()).getASTType();
557+
auto *tanField = getTangentStoredProperty(context, sei, structType, invoker);
555558
if (!tanField) {
556559
errorOccurred = true;
557560
return;
558561
}
559562
// Emit tangent `struct_extract`.
560-
auto tanStruct =
561-
materializeTangent(getTangentValue(sei->getOperand()), sei->getLoc());
562-
auto tangentInst =
563-
diffBuilder.createStructExtract(sei->getLoc(), tanStruct, tanField);
563+
auto tanStruct = materializeTangent(getTangentValue(sei->getOperand()), loc);
564+
auto tangentInst = diffBuilder.createStructExtract(loc, tanStruct, tanField);
564565
// Update tangent value mapping for `struct_extract` result.
565566
auto tangentResult = makeConcreteTangentValue(tangentInst);
566567
setTangentValue(sei->getParent(), sei, tangentResult);
@@ -577,16 +578,19 @@ CLONE_AND_EMIT_TANGENT(StructElementAddr, seai) {
577578
"differentiated; activity analysis should not marked as varied.");
578579
auto diffBuilder = getDifferentialBuilder();
579580
auto *bb = seai->getParent();
581+
auto loc = getValidLocation(seai);
580582
// Find the corresponding field in the tangent space.
581-
auto *tanField = getTangentStoredProperty(context, seai, invoker);
583+
auto structType =
584+
remapSILTypeInDifferential(seai->getOperand()->getType()).getASTType();
585+
auto *tanField = getTangentStoredProperty(context, seai, structType, invoker);
582586
if (!tanField) {
583587
errorOccurred = true;
584588
return;
585589
}
586590
// Emit tangent `struct_element_addr`.
587591
auto tanOperand = getTangentBuffer(bb, seai->getOperand());
588592
auto tangentInst =
589-
diffBuilder.createStructElementAddr(seai->getLoc(), tanOperand, tanField);
593+
diffBuilder.createStructElementAddr(loc, tanOperand, tanField);
590594
// Update tangent buffer map for `struct_element_addr`.
591595
setTangentBuffer(bb, seai, tangentInst);
592596
}

lib/SILOptimizer/Differentiation/PullbackEmitter.cpp

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,9 @@ SILValue PullbackEmitter::getAdjointProjection(SILBasicBlock *origBB,
371371
assert(!seai->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
372372
"`@noDerivative` struct projections should never be active");
373373
auto adjSource = getAdjointBuffer(origBB, seai->getOperand());
374-
auto *tanField = getTangentStoredProperty(getContext(), seai, getInvoker());
374+
auto structType = remapType(seai->getOperand()->getType()).getASTType();
375+
auto *tanField =
376+
getTangentStoredProperty(getContext(), seai, structType, getInvoker());
375377
assert(tanField && "Invalid projections should have been diagnosed");
376378
return builder.createStructElementAddr(seai->getLoc(), adjSource, tanField);
377379
}
@@ -400,7 +402,10 @@ SILValue PullbackEmitter::getAdjointProjection(SILBasicBlock *origBB,
400402
auto loc = reai->getLoc();
401403
// Get the class operand, stripping `begin_borrow`.
402404
auto classOperand = stripBorrow(reai->getOperand());
403-
auto *tanField = getTangentStoredProperty(getContext(), reai, getInvoker());
405+
auto classType = remapType(reai->getOperand()->getType()).getASTType();
406+
auto *tanField =
407+
getTangentStoredProperty(getContext(), reai->getField(), classType,
408+
reai->getLoc(), getInvoker());
404409
assert(tanField && "Invalid projections should have been diagnosed");
405410
// Create a local allocation for the element adjoint buffer.
406411
auto eltTanType = tanField->getValueInterfaceType()->getCanonicalType();
@@ -666,8 +671,9 @@ bool PullbackEmitter::runForSemanticMemberGetter() {
666671

667672
// Look up the corresponding field in the tangent space.
668673
auto *origField = cast<VarDecl>(accessor->getStorage());
669-
auto *tanField =
670-
getTangentStoredProperty(getContext(), origField, pbLoc, getInvoker());
674+
auto baseType = remapType(origSelf->getType()).getASTType();
675+
auto *tanField = getTangentStoredProperty(getContext(), origField, baseType,
676+
pbLoc, getInvoker());
671677
if (!tanField) {
672678
errorOccurred = true;
673679
return true;
@@ -772,8 +778,9 @@ bool PullbackEmitter::runForSemanticMemberSetter() {
772778

773779
// Look up the corresponding field in the tangent space.
774780
auto *origField = cast<VarDecl>(accessor->getStorage());
775-
auto *tanField =
776-
getTangentStoredProperty(getContext(), origField, pbLoc, getInvoker());
781+
auto baseType = remapType(origSelf->getType()).getASTType();
782+
auto *tanField = getTangentStoredProperty(getContext(), origField, baseType,
783+
pbLoc, getInvoker());
777784
if (!tanField) {
778785
errorOccurred = true;
779786
return true;
@@ -882,7 +889,10 @@ bool PullbackEmitter::run() {
882889
}
883890
// Diagnose unsupported stored property projections.
884891
if (auto *inst = dyn_cast<FieldIndexCacheBase>(v)) {
885-
if (!getTangentStoredProperty(getContext(), inst, getInvoker())) {
892+
assert(inst->getNumOperands() == 1);
893+
auto baseType = remapType(inst->getOperand(0)->getType()).getASTType();
894+
if (!getTangentStoredProperty(getContext(), inst, baseType,
895+
getInvoker())) {
886896
errorOccurred = true;
887897
return true;
888898
}
@@ -1699,8 +1709,8 @@ void PullbackEmitter::visitStructInst(StructInst *si) {
16991709
if (field->getAttrs().hasAttribute<NoDerivativeAttr>())
17001710
continue;
17011711
// Find the corresponding field in the tangent space.
1702-
auto *tanField =
1703-
getTangentStoredProperty(getContext(), field, loc, getInvoker());
1712+
auto *tanField = getTangentStoredProperty(getContext(), field, structTy,
1713+
loc, getInvoker());
17041714
if (!tanField) {
17051715
errorOccurred = true;
17061716
return;
@@ -1732,21 +1742,23 @@ void PullbackEmitter::visitBeginApplyInst(BeginApplyInst *bai) {
17321742

17331743
void PullbackEmitter::visitStructExtractInst(StructExtractInst *sei) {
17341744
auto *bb = sei->getParent();
1745+
auto loc = getValidLocation(sei);
17351746
auto structTy = remapType(sei->getOperand()->getType()).getASTType();
17361747
auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType();
17371748
assert(!getTypeLowering(tangentVectorTy).isAddressOnly());
17381749
auto tangentVectorSILTy = SILType::getPrimitiveObjectType(tangentVectorTy);
17391750
auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct();
17401751
assert(tangentVectorDecl);
17411752
// Find the corresponding field in the tangent space.
1742-
auto *tanField = getTangentStoredProperty(getContext(), sei, getInvoker());
1753+
auto *tanField =
1754+
getTangentStoredProperty(getContext(), sei, structTy, getInvoker());
17431755
assert(tanField && "Invalid projections should have been diagnosed");
17441756
// Accumulate adjoint for the `struct_extract` operand.
17451757
auto av = getAdjointValue(bb, sei);
17461758
switch (av.getKind()) {
17471759
case AdjointValueKind::Zero:
17481760
addAdjointValue(bb, sei->getOperand(),
1749-
makeZeroAdjointValue(tangentVectorSILTy), sei->getLoc());
1761+
makeZeroAdjointValue(tangentVectorSILTy), loc);
17501762
break;
17511763
case AdjointValueKind::Concrete:
17521764
case AdjointValueKind::Aggregate: {
@@ -1765,7 +1777,7 @@ void PullbackEmitter::visitStructExtractInst(StructExtractInst *sei) {
17651777
}
17661778
addAdjointValue(bb, sei->getOperand(),
17671779
makeAggregateAdjointValue(tangentVectorSILTy, eltVals),
1768-
sei->getLoc());
1780+
loc);
17691781
}
17701782
}
17711783
}
@@ -1775,7 +1787,9 @@ void PullbackEmitter::visitRefElementAddrInst(RefElementAddrInst *reai) {
17751787
auto loc = reai->getLoc();
17761788
auto adjBuf = getAdjointBuffer(bb, reai);
17771789
auto classOperand = reai->getOperand();
1778-
auto *tanField = getTangentStoredProperty(getContext(), reai, getInvoker());
1790+
auto classType = remapType(reai->getOperand()->getType()).getASTType();
1791+
auto *tanField =
1792+
getTangentStoredProperty(getContext(), reai, classType, getInvoker());
17791793
assert(tanField && "Invalid projections should have been diagnosed");
17801794
switch (getTangentValueCategory(classOperand)) {
17811795
case SILValueCategory::Object: {

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,7 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
668668
tangentProperty->setSetterAccess(member->getFormalAccess());
669669

670670
// Cache the tangent property.
671-
C.evaluator.cacheOutput(TangentStoredPropertyRequest{member},
671+
C.evaluator.cacheOutput(TangentStoredPropertyRequest{member, CanType()},
672672
TangentPropertyInfo(tangentProperty));
673673

674674
// Now that the original property has a corresponding tangent property, it

test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,22 @@ func testClassTangentPropertyNotStored(_ c: ClassTangentPropertyNotStored) -> Fl
624624
// CHECK-LABEL: sil {{.*}} @test_class_tangent_property_not_stored
625625
// CHECK: ref_element_addr {{%.*}} : $ClassTangentPropertyNotStored, #ClassTangentPropertyNotStored.x
626626

627+
// SR-13134: Test stored property access with conditionally `Differentiable` base type.
628+
629+
struct Complex<T: FloatingPoint> {
630+
var real: T
631+
var imaginary: T
632+
}
633+
extension Complex: Differentiable where T: Differentiable {
634+
typealias TangentVector = Complex
635+
}
636+
extension Complex: AdditiveArithmetic {}
637+
638+
@differentiable
639+
func SR_13134(lhs: Complex<Float>, rhs: Complex<Float>) -> Float {
640+
return lhs.real + rhs.real
641+
}
642+
627643
//===----------------------------------------------------------------------===//
628644
// Wrapped property differentiation
629645
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)