Skip to content

[AutoDiff] Fix unexpected non-differentiable property access error. #32670

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions include/swift/AST/TypeCheckRequests.h
Original file line number Diff line number Diff line change
Expand Up @@ -2192,7 +2192,7 @@ class DerivativeAttrOriginalDeclRequest
/// property in a `Differentiable`-conforming type.
class TangentStoredPropertyRequest
: public SimpleRequest<TangentStoredPropertyRequest,
TangentPropertyInfo(VarDecl *),
TangentPropertyInfo(VarDecl *, CanType),
Copy link
Contributor Author

@dan-zheng dan-zheng Jul 1, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: adding a base type parameter to TangentStoredPropertyRequest results in fewer cache hits.
It may be possible to fix SR-13134 without adding a base type parameter, I may take a look later.

RequestFlags::Cached> {
public:
using SimpleRequest::SimpleRequest;
Expand All @@ -2201,8 +2201,8 @@ class TangentStoredPropertyRequest
friend SimpleRequest;

// Evaluation.
TangentPropertyInfo evaluate(Evaluator &evaluator,
VarDecl *originalField) const;
TangentPropertyInfo evaluate(Evaluator &evaluator, VarDecl *originalField,
CanType parentType) const;

public:
// Caching.
Expand Down
2 changes: 1 addition & 1 deletion include/swift/AST/TypeCheckerTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ SWIFT_REQUEST(TypeChecker, SynthesizeAccessorRequest,
AccessorDecl *(AbstractStorageDecl *, AccessorKind),
SeparatelyCached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, TangentStoredPropertyRequest,
llvm::Expected<VarDecl *>(VarDecl *), Cached, NoLocationInfo)
llvm::Expected<VarDecl *>(VarDecl *, CanType), Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, TypeCheckFunctionBodyRequest,
bool(AbstractFunctionDecl *), Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, TypeCheckFunctionBodyAtLocRequest,
Expand Down
10 changes: 6 additions & 4 deletions include/swift/SILOptimizer/Differentiation/Common.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,16 +157,18 @@ SILLocation getValidLocation(SILInstruction *inst);
// Tangent property lookup utilities
//===----------------------------------------------------------------------===//

/// Returns the tangent stored property of `originalField`. On error, emits
/// diagnostic and returns nullptr.
/// Returns the tangent stored property of the given original stored property
/// and base type. On error, emits diagnostic and returns nullptr.
VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField,
SILLocation loc,
CanType baseType, SILLocation loc,
DifferentiationInvoker invoker);

/// Returns the tangent stored property of the original stored property
/// referenced by `inst`. On error, emits diagnostic and returns nullptr.
/// referenced by the given projection instruction with the given base type.
/// On error, emits diagnostic and returns nullptr.
VarDecl *getTangentStoredProperty(ADContext &context,
FieldIndexCacheBase *projectionInst,
CanType baseType,
DifferentiationInvoker invoker);

//===----------------------------------------------------------------------===//
Expand Down
28 changes: 17 additions & 11 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,17 +475,16 @@ void swift::simple_display(llvm::raw_ostream &os, TangentPropertyInfo info) {
os << " }";
}

TangentPropertyInfo
TangentStoredPropertyRequest::evaluate(Evaluator &evaluator,
VarDecl *originalField) const {
assert(originalField->hasStorage() && originalField->isInstanceMember() &&
"Expected stored property");
TangentPropertyInfo TangentStoredPropertyRequest::evaluate(
Evaluator &evaluator, VarDecl *originalField, CanType baseType) const {
assert(((originalField->hasStorage() && originalField->isInstanceMember()) ||
originalField->hasAttachedPropertyWrapper()) &&
"Expected a stored property or a property-wrapped property");
auto *parentDC = originalField->getDeclContext();
assert(parentDC->isTypeContext());
auto parentType = parentDC->getDeclaredTypeInContext();
auto *moduleDecl = originalField->getModuleContext();
auto parentTan = parentType->getAutoDiffTangentSpace(
LookUpConformanceInModule(moduleDecl));
auto parentTan =
baseType->getAutoDiffTangentSpace(LookUpConformanceInModule(moduleDecl));
// Error if parent nominal type does not conform to `Differentiable`.
if (!parentTan) {
return TangentPropertyInfo(
Expand All @@ -497,13 +496,18 @@ TangentStoredPropertyRequest::evaluate(Evaluator &evaluator,
TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty);
}
// Error if original property's type does not conform to `Differentiable`.
auto originalFieldTan = originalField->getType()->getAutoDiffTangentSpace(
auto originalFieldType = baseType->getTypeOfMember(
originalField->getModuleContext(), originalField);
auto originalFieldTan = originalFieldType->getAutoDiffTangentSpace(
LookUpConformanceInModule(moduleDecl));
if (!originalFieldTan) {
return TangentPropertyInfo(
TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable);
}
auto parentTanType = parentTan->getType();
// Get the parent `TangentVector` type.
auto parentTanType =
baseType->getAutoDiffTangentSpace(LookUpConformanceInModule(moduleDecl))
->getType();
auto *parentTanStruct = parentTanType->getStructOrBoundGenericStruct();
// Error if parent `TangentVector` is not a struct.
if (!parentTanStruct) {
Expand Down Expand Up @@ -533,7 +537,9 @@ TangentStoredPropertyRequest::evaluate(Evaluator &evaluator,
// Error if tangent property's type is not equal to the original property's
// `TangentVector` type.
auto originalFieldTanType = originalFieldTan->getType();
if (!originalFieldTanType->isEqual(tanField->getType())) {
auto tanFieldType =
parentTanType->getTypeOfMember(tanField->getModuleContext(), tanField);
if (!originalFieldTanType->isEqual(tanFieldType)) {
return TangentPropertyInfo(
TangentPropertyInfo::Error::Kind::TangentPropertyWrongType,
originalFieldTanType);
Expand Down
9 changes: 5 additions & 4 deletions lib/SILOptimizer/Differentiation/Common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,11 +269,11 @@ SILLocation getValidLocation(SILInstruction *inst) {
//===----------------------------------------------------------------------===//

VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField,
SILLocation loc,
CanType baseType, SILLocation loc,
DifferentiationInvoker invoker) {
auto &astCtx = context.getASTContext();
auto tanFieldInfo = evaluateOrDefault(
astCtx.evaluator, TangentStoredPropertyRequest{originalField},
astCtx.evaluator, TangentStoredPropertyRequest{originalField, baseType},
TangentPropertyInfo(nullptr));
// If no error, return the tangent property.
if (tanFieldInfo)
Expand Down Expand Up @@ -328,13 +328,14 @@ VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField,

VarDecl *getTangentStoredProperty(ADContext &context,
FieldIndexCacheBase *projectionInst,
CanType baseType,
DifferentiationInvoker invoker) {
assert(isa<StructExtractInst>(projectionInst) ||
isa<StructElementAddrInst>(projectionInst) ||
isa<RefElementAddrInst>(projectionInst));
auto loc = getValidLocation(projectionInst);
return getTangentStoredProperty(context, projectionInst->getField(), loc,
invoker);
return getTangentStoredProperty(context, projectionInst->getField(), baseType,
loc, invoker);
}

//===----------------------------------------------------------------------===//
Expand Down
18 changes: 11 additions & 7 deletions lib/SILOptimizer/Differentiation/JVPEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -548,17 +548,18 @@ CLONE_AND_EMIT_TANGENT(StructExtract, sei) {
"`struct_extract` with `@noDerivative` field should not be "
"differentiated; activity analysis should not marked as varied.");
auto diffBuilder = getDifferentialBuilder();
auto loc = getValidLocation(sei);
// Find the corresponding field in the tangent space.
auto *tanField = getTangentStoredProperty(context, sei, invoker);
auto structType =
remapSILTypeInDifferential(sei->getOperand()->getType()).getASTType();
auto *tanField = getTangentStoredProperty(context, sei, structType, invoker);
if (!tanField) {
errorOccurred = true;
return;
}
// Emit tangent `struct_extract`.
auto tanStruct =
materializeTangent(getTangentValue(sei->getOperand()), sei->getLoc());
auto tangentInst =
diffBuilder.createStructExtract(sei->getLoc(), tanStruct, tanField);
auto tanStruct = materializeTangent(getTangentValue(sei->getOperand()), loc);
auto tangentInst = diffBuilder.createStructExtract(loc, tanStruct, tanField);
// Update tangent value mapping for `struct_extract` result.
auto tangentResult = makeConcreteTangentValue(tangentInst);
setTangentValue(sei->getParent(), sei, tangentResult);
Expand All @@ -575,16 +576,19 @@ CLONE_AND_EMIT_TANGENT(StructElementAddr, seai) {
"differentiated; activity analysis should not marked as varied.");
auto diffBuilder = getDifferentialBuilder();
auto *bb = seai->getParent();
auto loc = getValidLocation(seai);
// Find the corresponding field in the tangent space.
auto *tanField = getTangentStoredProperty(context, seai, invoker);
auto structType =
remapSILTypeInDifferential(seai->getOperand()->getType()).getASTType();
auto *tanField = getTangentStoredProperty(context, seai, structType, invoker);
if (!tanField) {
errorOccurred = true;
return;
}
// Emit tangent `struct_element_addr`.
auto tanOperand = getTangentBuffer(bb, seai->getOperand());
auto tangentInst =
diffBuilder.createStructElementAddr(seai->getLoc(), tanOperand, tanField);
diffBuilder.createStructElementAddr(loc, tanOperand, tanField);
// Update tangent buffer map for `struct_element_addr`.
setTangentBuffer(bb, seai, tangentInst);
}
Expand Down
40 changes: 27 additions & 13 deletions lib/SILOptimizer/Differentiation/PullbackEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,9 @@ SILValue PullbackEmitter::getAdjointProjection(SILBasicBlock *origBB,
assert(!seai->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
"`@noDerivative` struct projections should never be active");
auto adjSource = getAdjointBuffer(origBB, seai->getOperand());
auto *tanField = getTangentStoredProperty(getContext(), seai, getInvoker());
auto structType = remapType(seai->getOperand()->getType()).getASTType();
auto *tanField =
getTangentStoredProperty(getContext(), seai, structType, getInvoker());
assert(tanField && "Invalid projections should have been diagnosed");
return builder.createStructElementAddr(seai->getLoc(), adjSource, tanField);
}
Expand Down Expand Up @@ -400,7 +402,10 @@ SILValue PullbackEmitter::getAdjointProjection(SILBasicBlock *origBB,
auto loc = reai->getLoc();
// Get the class operand, stripping `begin_borrow`.
auto classOperand = stripBorrow(reai->getOperand());
auto *tanField = getTangentStoredProperty(getContext(), reai, getInvoker());
auto classType = remapType(reai->getOperand()->getType()).getASTType();
auto *tanField =
getTangentStoredProperty(getContext(), reai->getField(), classType,
reai->getLoc(), getInvoker());
assert(tanField && "Invalid projections should have been diagnosed");
// Create a local allocation for the element adjoint buffer.
auto eltTanType = tanField->getValueInterfaceType()->getCanonicalType();
Expand Down Expand Up @@ -666,8 +671,9 @@ bool PullbackEmitter::runForSemanticMemberGetter() {

// Look up the corresponding field in the tangent space.
auto *origField = cast<VarDecl>(accessor->getStorage());
auto *tanField =
getTangentStoredProperty(getContext(), origField, pbLoc, getInvoker());
auto baseType = remapType(origSelf->getType()).getASTType();
auto *tanField = getTangentStoredProperty(getContext(), origField, baseType,
pbLoc, getInvoker());
if (!tanField) {
errorOccurred = true;
return true;
Expand Down Expand Up @@ -772,8 +778,9 @@ bool PullbackEmitter::runForSemanticMemberSetter() {

// Look up the corresponding field in the tangent space.
auto *origField = cast<VarDecl>(accessor->getStorage());
auto *tanField =
getTangentStoredProperty(getContext(), origField, pbLoc, getInvoker());
auto baseType = remapType(origSelf->getType()).getASTType();
auto *tanField = getTangentStoredProperty(getContext(), origField, baseType,
pbLoc, getInvoker());
if (!tanField) {
errorOccurred = true;
return true;
Expand Down Expand Up @@ -882,7 +889,10 @@ bool PullbackEmitter::run() {
}
// Diagnose unsupported stored property projections.
if (auto *inst = dyn_cast<FieldIndexCacheBase>(v)) {
if (!getTangentStoredProperty(getContext(), inst, getInvoker())) {
assert(inst->getNumOperands() == 1);
auto baseType = remapType(inst->getOperand(0)->getType()).getASTType();
if (!getTangentStoredProperty(getContext(), inst, baseType,
getInvoker())) {
errorOccurred = true;
return true;
}
Expand Down Expand Up @@ -1694,8 +1704,8 @@ void PullbackEmitter::visitStructInst(StructInst *si) {
if (field->getAttrs().hasAttribute<NoDerivativeAttr>())
continue;
// Find the corresponding field in the tangent space.
auto *tanField =
getTangentStoredProperty(getContext(), field, loc, getInvoker());
auto *tanField = getTangentStoredProperty(getContext(), field, structTy,
loc, getInvoker());
if (!tanField) {
errorOccurred = true;
return;
Expand Down Expand Up @@ -1727,21 +1737,23 @@ void PullbackEmitter::visitBeginApplyInst(BeginApplyInst *bai) {

void PullbackEmitter::visitStructExtractInst(StructExtractInst *sei) {
auto *bb = sei->getParent();
auto loc = getValidLocation(sei);
auto structTy = remapType(sei->getOperand()->getType()).getASTType();
auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType();
assert(!getTypeLowering(tangentVectorTy).isAddressOnly());
auto tangentVectorSILTy = SILType::getPrimitiveObjectType(tangentVectorTy);
auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct();
assert(tangentVectorDecl);
// Find the corresponding field in the tangent space.
auto *tanField = getTangentStoredProperty(getContext(), sei, getInvoker());
auto *tanField =
getTangentStoredProperty(getContext(), sei, structTy, getInvoker());
assert(tanField && "Invalid projections should have been diagnosed");
// Accumulate adjoint for the `struct_extract` operand.
auto av = getAdjointValue(bb, sei);
switch (av.getKind()) {
case AdjointValueKind::Zero:
addAdjointValue(bb, sei->getOperand(),
makeZeroAdjointValue(tangentVectorSILTy), sei->getLoc());
makeZeroAdjointValue(tangentVectorSILTy), loc);
break;
case AdjointValueKind::Concrete:
case AdjointValueKind::Aggregate: {
Expand All @@ -1760,7 +1772,7 @@ void PullbackEmitter::visitStructExtractInst(StructExtractInst *sei) {
}
addAdjointValue(bb, sei->getOperand(),
makeAggregateAdjointValue(tangentVectorSILTy, eltVals),
sei->getLoc());
loc);
}
}
}
Expand All @@ -1770,7 +1782,9 @@ void PullbackEmitter::visitRefElementAddrInst(RefElementAddrInst *reai) {
auto loc = reai->getLoc();
auto adjBuf = getAdjointBuffer(bb, reai);
auto classOperand = reai->getOperand();
auto *tanField = getTangentStoredProperty(getContext(), reai, getInvoker());
auto classType = remapType(reai->getOperand()->getType()).getASTType();
auto *tanField =
getTangentStoredProperty(getContext(), reai, classType, getInvoker());
assert(tanField && "Invalid projections should have been diagnosed");
switch (getTangentValueCategory(classOperand)) {
case SILValueCategory::Object: {
Expand Down
2 changes: 1 addition & 1 deletion lib/Sema/DerivedConformanceDifferentiable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
tangentProperty->setSetterAccess(member->getFormalAccess());

// Cache the tangent property.
C.evaluator.cacheOutput(TangentStoredPropertyRequest{member},
C.evaluator.cacheOutput(TangentStoredPropertyRequest{member, CanType()},
TangentPropertyInfo(tangentProperty));

// Now that the original property has a corresponding tangent property, it
Expand Down
16 changes: 16 additions & 0 deletions test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,22 @@ func testClassTangentPropertyNotStored(_ c: ClassTangentPropertyNotStored) -> Fl
// CHECK-LABEL: sil {{.*}} @test_class_tangent_property_not_stored
// CHECK: ref_element_addr {{%.*}} : $ClassTangentPropertyNotStored, #ClassTangentPropertyNotStored.x

// SR-13134: Test stored property access with conditionally `Differentiable` base type.

struct Complex<T: FloatingPoint> {
var real: T
var imaginary: T
}
extension Complex: Differentiable where T: Differentiable {
typealias TangentVector = Complex
}
extension Complex: AdditiveArithmetic {}

@differentiable
func SR_13134(lhs: Complex<Float>, rhs: Complex<Float>) -> Float {
return lhs.real + rhs.real
}

//===----------------------------------------------------------------------===//
// Wrapped property differentiation
//===----------------------------------------------------------------------===//
Expand Down