Skip to content

Commit 2b8c2cf

Browse files
authored
[AutoDiff] Fix unexpected non-differentiable property access error. (#32671)
Add base type parameter to `TangentStoredPropertyRequest`. Use `TypeBase::getTypeOfMember` instead of `VarDecl::getType` to correctly compute the member type of original stored properties, using the base type. Resolves SR-13134.
1 parent 8cced04 commit 2b8c2cf

File tree

9 files changed

+87
-44
lines changed

9 files changed

+87
-44
lines changed

include/swift/AST/TypeCheckRequests.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2192,7 +2192,7 @@ class DerivativeAttrOriginalDeclRequest
21922192
/// property in a `Differentiable`-conforming type.
21932193
class TangentStoredPropertyRequest
21942194
: public SimpleRequest<TangentStoredPropertyRequest,
2195-
TangentPropertyInfo(VarDecl *),
2195+
TangentPropertyInfo(VarDecl *, CanType),
21962196
RequestFlags::Cached> {
21972197
public:
21982198
using SimpleRequest::SimpleRequest;
@@ -2201,8 +2201,8 @@ class TangentStoredPropertyRequest
22012201
friend SimpleRequest;
22022202

22032203
// Evaluation.
2204-
TangentPropertyInfo evaluate(Evaluator &evaluator,
2205-
VarDecl *originalField) const;
2204+
TangentPropertyInfo evaluate(Evaluator &evaluator, VarDecl *originalField,
2205+
CanType parentType) const;
22062206

22072207
public:
22082208
// Caching.

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ SWIFT_REQUEST(TypeChecker, SynthesizeAccessorRequest,
204204
AccessorDecl *(AbstractStorageDecl *, AccessorKind),
205205
SeparatelyCached, NoLocationInfo)
206206
SWIFT_REQUEST(TypeChecker, TangentStoredPropertyRequest,
207-
llvm::Expected<VarDecl *>(VarDecl *), Cached, NoLocationInfo)
207+
llvm::Expected<VarDecl *>(VarDecl *, CanType), Cached, NoLocationInfo)
208208
SWIFT_REQUEST(TypeChecker, TypeCheckFunctionBodyRequest,
209209
bool(AbstractFunctionDecl *), Cached, NoLocationInfo)
210210
SWIFT_REQUEST(TypeChecker, TypeCheckFunctionBodyAtLocRequest,

include/swift/SILOptimizer/Differentiation/Common.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,16 +158,18 @@ SILLocation getValidLocation(SILInstruction *inst);
158158
// Tangent property lookup utilities
159159
//===----------------------------------------------------------------------===//
160160

161-
/// Returns the tangent stored property of `originalField`. On error, emits
162-
/// diagnostic and returns nullptr.
161+
/// Returns the tangent stored property of the given original stored property
162+
/// and base type. On error, emits diagnostic and returns nullptr.
163163
VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField,
164-
SILLocation loc,
164+
CanType baseType, SILLocation loc,
165165
DifferentiationInvoker invoker);
166166

167167
/// Returns the tangent stored property of the original stored property
168-
/// referenced by `inst`. On error, emits diagnostic and returns nullptr.
168+
/// referenced by the given projection instruction with the given base type.
169+
/// On error, emits diagnostic and returns nullptr.
169170
VarDecl *getTangentStoredProperty(ADContext &context,
170171
FieldIndexCacheBase *projectionInst,
172+
CanType baseType,
171173
DifferentiationInvoker invoker);
172174

173175
//===----------------------------------------------------------------------===//

lib/AST/AutoDiff.cpp

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -475,17 +475,16 @@ void swift::simple_display(llvm::raw_ostream &os, TangentPropertyInfo info) {
475475
os << " }";
476476
}
477477

478-
TangentPropertyInfo
479-
TangentStoredPropertyRequest::evaluate(Evaluator &evaluator,
480-
VarDecl *originalField) const {
481-
assert(originalField->hasStorage() && originalField->isInstanceMember() &&
482-
"Expected stored property");
478+
TangentPropertyInfo TangentStoredPropertyRequest::evaluate(
479+
Evaluator &evaluator, VarDecl *originalField, CanType baseType) const {
480+
assert(((originalField->hasStorage() && originalField->isInstanceMember()) ||
481+
originalField->hasAttachedPropertyWrapper()) &&
482+
"Expected a stored property or a property-wrapped property");
483483
auto *parentDC = originalField->getDeclContext();
484484
assert(parentDC->isTypeContext());
485-
auto parentType = parentDC->getDeclaredTypeInContext();
486485
auto *moduleDecl = originalField->getModuleContext();
487-
auto parentTan = parentType->getAutoDiffTangentSpace(
488-
LookUpConformanceInModule(moduleDecl));
486+
auto parentTan =
487+
baseType->getAutoDiffTangentSpace(LookUpConformanceInModule(moduleDecl));
489488
// Error if parent nominal type does not conform to `Differentiable`.
490489
if (!parentTan) {
491490
return TangentPropertyInfo(
@@ -497,13 +496,18 @@ TangentStoredPropertyRequest::evaluate(Evaluator &evaluator,
497496
TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty);
498497
}
499498
// Error if original property's type does not conform to `Differentiable`.
500-
auto originalFieldTan = originalField->getType()->getAutoDiffTangentSpace(
499+
auto originalFieldType = baseType->getTypeOfMember(
500+
originalField->getModuleContext(), originalField);
501+
auto originalFieldTan = originalFieldType->getAutoDiffTangentSpace(
501502
LookUpConformanceInModule(moduleDecl));
502503
if (!originalFieldTan) {
503504
return TangentPropertyInfo(
504505
TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable);
505506
}
506-
auto parentTanType = parentTan->getType();
507+
// Get the parent `TangentVector` type.
508+
auto memberSubs = baseType->getMemberSubstitutionMap(
509+
originalField->getModuleContext(), originalField);
510+
auto parentTanType = parentTan->getType().subst(memberSubs);
507511
auto *parentTanStruct = parentTanType->getStructOrBoundGenericStruct();
508512
// Error if parent `TangentVector` is not a struct.
509513
if (!parentTanStruct) {
@@ -533,7 +537,9 @@ TangentStoredPropertyRequest::evaluate(Evaluator &evaluator,
533537
// Error if tangent property's type is not equal to the original property's
534538
// `TangentVector` type.
535539
auto originalFieldTanType = originalFieldTan->getType();
536-
if (!originalFieldTanType->isEqual(tanField->getType())) {
540+
auto tanFieldType =
541+
parentTanType->getTypeOfMember(tanField->getModuleContext(), tanField);
542+
if (!originalFieldTanType->isEqual(tanFieldType)) {
537543
return TangentPropertyInfo(
538544
TangentPropertyInfo::Error::Kind::TangentPropertyWrongType,
539545
originalFieldTanType);

lib/SILOptimizer/Differentiation/Common.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -269,11 +269,11 @@ SILLocation getValidLocation(SILInstruction *inst) {
269269
//===----------------------------------------------------------------------===//
270270

271271
VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField,
272-
SILLocation loc,
272+
CanType baseType, SILLocation loc,
273273
DifferentiationInvoker invoker) {
274274
auto &astCtx = context.getASTContext();
275275
auto tanFieldInfo = evaluateOrDefault(
276-
astCtx.evaluator, TangentStoredPropertyRequest{originalField},
276+
astCtx.evaluator, TangentStoredPropertyRequest{originalField, baseType},
277277
TangentPropertyInfo(nullptr));
278278
// If no error, return the tangent property.
279279
if (tanFieldInfo)
@@ -328,13 +328,14 @@ VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField,
328328

329329
VarDecl *getTangentStoredProperty(ADContext &context,
330330
FieldIndexCacheBase *projectionInst,
331+
CanType baseType,
331332
DifferentiationInvoker invoker) {
332333
assert(isa<StructExtractInst>(projectionInst) ||
333334
isa<StructElementAddrInst>(projectionInst) ||
334335
isa<RefElementAddrInst>(projectionInst));
335336
auto loc = getValidLocation(projectionInst);
336-
return getTangentStoredProperty(context, projectionInst->getField(), loc,
337-
invoker);
337+
return getTangentStoredProperty(context, projectionInst->getField(), baseType,
338+
loc, invoker);
338339
}
339340

340341
//===----------------------------------------------------------------------===//

lib/SILOptimizer/Differentiation/JVPEmitter.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -548,17 +548,18 @@ CLONE_AND_EMIT_TANGENT(StructExtract, sei) {
548548
"`struct_extract` with `@noDerivative` field should not be "
549549
"differentiated; activity analysis should not marked as varied.");
550550
auto diffBuilder = getDifferentialBuilder();
551+
auto loc = getValidLocation(sei);
551552
// Find the corresponding field in the tangent space.
552-
auto *tanField = getTangentStoredProperty(context, sei, invoker);
553+
auto structType =
554+
remapSILTypeInDifferential(sei->getOperand()->getType()).getASTType();
555+
auto *tanField = getTangentStoredProperty(context, sei, structType, invoker);
553556
if (!tanField) {
554557
errorOccurred = true;
555558
return;
556559
}
557560
// Emit tangent `struct_extract`.
558-
auto tanStruct =
559-
materializeTangent(getTangentValue(sei->getOperand()), sei->getLoc());
560-
auto tangentInst =
561-
diffBuilder.createStructExtract(sei->getLoc(), tanStruct, tanField);
561+
auto tanStruct = materializeTangent(getTangentValue(sei->getOperand()), loc);
562+
auto tangentInst = diffBuilder.createStructExtract(loc, tanStruct, tanField);
562563
// Update tangent value mapping for `struct_extract` result.
563564
auto tangentResult = makeConcreteTangentValue(tangentInst);
564565
setTangentValue(sei->getParent(), sei, tangentResult);
@@ -575,16 +576,19 @@ CLONE_AND_EMIT_TANGENT(StructElementAddr, seai) {
575576
"differentiated; activity analysis should not marked as varied.");
576577
auto diffBuilder = getDifferentialBuilder();
577578
auto *bb = seai->getParent();
579+
auto loc = getValidLocation(seai);
578580
// Find the corresponding field in the tangent space.
579-
auto *tanField = getTangentStoredProperty(context, seai, invoker);
581+
auto structType =
582+
remapSILTypeInDifferential(seai->getOperand()->getType()).getASTType();
583+
auto *tanField = getTangentStoredProperty(context, seai, structType, invoker);
580584
if (!tanField) {
581585
errorOccurred = true;
582586
return;
583587
}
584588
// Emit tangent `struct_element_addr`.
585589
auto tanOperand = getTangentBuffer(bb, seai->getOperand());
586590
auto tangentInst =
587-
diffBuilder.createStructElementAddr(seai->getLoc(), tanOperand, tanField);
591+
diffBuilder.createStructElementAddr(loc, tanOperand, tanField);
588592
// Update tangent buffer map for `struct_element_addr`.
589593
setTangentBuffer(bb, seai, tangentInst);
590594
}

lib/SILOptimizer/Differentiation/PullbackEmitter.cpp

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,9 @@ SILValue PullbackEmitter::getAdjointProjection(SILBasicBlock *origBB,
371371
assert(!seai->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
372372
"`@noDerivative` struct projections should never be active");
373373
auto adjSource = getAdjointBuffer(origBB, seai->getOperand());
374-
auto *tanField = getTangentStoredProperty(getContext(), seai, getInvoker());
374+
auto structType = remapType(seai->getOperand()->getType()).getASTType();
375+
auto *tanField =
376+
getTangentStoredProperty(getContext(), seai, structType, getInvoker());
375377
assert(tanField && "Invalid projections should have been diagnosed");
376378
return builder.createStructElementAddr(seai->getLoc(), adjSource, tanField);
377379
}
@@ -400,7 +402,10 @@ SILValue PullbackEmitter::getAdjointProjection(SILBasicBlock *origBB,
400402
auto loc = reai->getLoc();
401403
// Get the class operand, stripping `begin_borrow`.
402404
auto classOperand = stripBorrow(reai->getOperand());
403-
auto *tanField = getTangentStoredProperty(getContext(), reai, getInvoker());
405+
auto classType = remapType(reai->getOperand()->getType()).getASTType();
406+
auto *tanField =
407+
getTangentStoredProperty(getContext(), reai->getField(), classType,
408+
reai->getLoc(), getInvoker());
404409
assert(tanField && "Invalid projections should have been diagnosed");
405410
// Create a local allocation for the element adjoint buffer.
406411
auto eltTanType = tanField->getValueInterfaceType()->getCanonicalType();
@@ -666,8 +671,9 @@ bool PullbackEmitter::runForSemanticMemberGetter() {
666671

667672
// Look up the corresponding field in the tangent space.
668673
auto *origField = cast<VarDecl>(accessor->getStorage());
669-
auto *tanField =
670-
getTangentStoredProperty(getContext(), origField, pbLoc, getInvoker());
674+
auto baseType = remapType(origSelf->getType()).getASTType();
675+
auto *tanField = getTangentStoredProperty(getContext(), origField, baseType,
676+
pbLoc, getInvoker());
671677
if (!tanField) {
672678
errorOccurred = true;
673679
return true;
@@ -772,8 +778,9 @@ bool PullbackEmitter::runForSemanticMemberSetter() {
772778

773779
// Look up the corresponding field in the tangent space.
774780
auto *origField = cast<VarDecl>(accessor->getStorage());
775-
auto *tanField =
776-
getTangentStoredProperty(getContext(), origField, pbLoc, getInvoker());
781+
auto baseType = remapType(origSelf->getType()).getASTType();
782+
auto *tanField = getTangentStoredProperty(getContext(), origField, baseType,
783+
pbLoc, getInvoker());
777784
if (!tanField) {
778785
errorOccurred = true;
779786
return true;
@@ -882,7 +889,10 @@ bool PullbackEmitter::run() {
882889
}
883890
// Diagnose unsupported stored property projections.
884891
if (auto *inst = dyn_cast<FieldIndexCacheBase>(v)) {
885-
if (!getTangentStoredProperty(getContext(), inst, getInvoker())) {
892+
assert(inst->getNumOperands() == 1);
893+
auto baseType = remapType(inst->getOperand(0)->getType()).getASTType();
894+
if (!getTangentStoredProperty(getContext(), inst, baseType,
895+
getInvoker())) {
886896
errorOccurred = true;
887897
return true;
888898
}
@@ -1694,8 +1704,8 @@ void PullbackEmitter::visitStructInst(StructInst *si) {
16941704
if (field->getAttrs().hasAttribute<NoDerivativeAttr>())
16951705
continue;
16961706
// Find the corresponding field in the tangent space.
1697-
auto *tanField =
1698-
getTangentStoredProperty(getContext(), field, loc, getInvoker());
1707+
auto *tanField = getTangentStoredProperty(getContext(), field, structTy,
1708+
loc, getInvoker());
16991709
if (!tanField) {
17001710
errorOccurred = true;
17011711
return;
@@ -1727,21 +1737,23 @@ void PullbackEmitter::visitBeginApplyInst(BeginApplyInst *bai) {
17271737

17281738
void PullbackEmitter::visitStructExtractInst(StructExtractInst *sei) {
17291739
auto *bb = sei->getParent();
1740+
auto loc = getValidLocation(sei);
17301741
auto structTy = remapType(sei->getOperand()->getType()).getASTType();
17311742
auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType();
17321743
assert(!getTypeLowering(tangentVectorTy).isAddressOnly());
17331744
auto tangentVectorSILTy = SILType::getPrimitiveObjectType(tangentVectorTy);
17341745
auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct();
17351746
assert(tangentVectorDecl);
17361747
// Find the corresponding field in the tangent space.
1737-
auto *tanField = getTangentStoredProperty(getContext(), sei, getInvoker());
1748+
auto *tanField =
1749+
getTangentStoredProperty(getContext(), sei, structTy, getInvoker());
17381750
assert(tanField && "Invalid projections should have been diagnosed");
17391751
// Accumulate adjoint for the `struct_extract` operand.
17401752
auto av = getAdjointValue(bb, sei);
17411753
switch (av.getKind()) {
17421754
case AdjointValueKind::Zero:
17431755
addAdjointValue(bb, sei->getOperand(),
1744-
makeZeroAdjointValue(tangentVectorSILTy), sei->getLoc());
1756+
makeZeroAdjointValue(tangentVectorSILTy), loc);
17451757
break;
17461758
case AdjointValueKind::Concrete:
17471759
case AdjointValueKind::Aggregate: {
@@ -1760,7 +1772,7 @@ void PullbackEmitter::visitStructExtractInst(StructExtractInst *sei) {
17601772
}
17611773
addAdjointValue(bb, sei->getOperand(),
17621774
makeAggregateAdjointValue(tangentVectorSILTy, eltVals),
1763-
sei->getLoc());
1775+
loc);
17641776
}
17651777
}
17661778
}
@@ -1770,7 +1782,9 @@ void PullbackEmitter::visitRefElementAddrInst(RefElementAddrInst *reai) {
17701782
auto loc = reai->getLoc();
17711783
auto adjBuf = getAdjointBuffer(bb, reai);
17721784
auto classOperand = reai->getOperand();
1773-
auto *tanField = getTangentStoredProperty(getContext(), reai, getInvoker());
1785+
auto classType = remapType(reai->getOperand()->getType()).getASTType();
1786+
auto *tanField =
1787+
getTangentStoredProperty(getContext(), reai, classType, getInvoker());
17741788
assert(tanField && "Invalid projections should have been diagnosed");
17751789
switch (getTangentValueCategory(classOperand)) {
17761790
case SILValueCategory::Object: {

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -764,7 +764,7 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
764764
tangentProperty->setSetterAccess(member->getFormalAccess());
765765

766766
// Cache the tangent property.
767-
C.evaluator.cacheOutput(TangentStoredPropertyRequest{member},
767+
C.evaluator.cacheOutput(TangentStoredPropertyRequest{member, CanType()},
768768
TangentPropertyInfo(tangentProperty));
769769

770770
// Now that the original property has a corresponding tangent property, it

test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,22 @@ func testClassTangentPropertyNotStored(_ c: ClassTangentPropertyNotStored) -> Fl
636636
// CHECK-LABEL: sil {{.*}} @test_class_tangent_property_not_stored
637637
// CHECK: ref_element_addr {{%.*}} : $ClassTangentPropertyNotStored, #ClassTangentPropertyNotStored.x
638638

639+
// SR-13134: Test stored property access with conditionally `Differentiable` base type.
640+
641+
struct Complex<T: FloatingPoint> {
642+
var real: T
643+
var imaginary: T
644+
}
645+
extension Complex: Differentiable where T: Differentiable {
646+
typealias TangentVector = Complex
647+
}
648+
extension Complex: AdditiveArithmetic {}
649+
650+
@differentiable
651+
func SR_13134(lhs: Complex<Float>, rhs: Complex<Float>) -> Float {
652+
return lhs.real + rhs.real
653+
}
654+
639655
//===----------------------------------------------------------------------===//
640656
// Wrapped property differentiation
641657
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)