Skip to content

feat: ensure an overriding member matches the overridden one #758

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,26 @@
import { SafeDsServices } from '../safe-ds-module.js';
import {
CompositeGeneratorNode,
expandToNode,
expandTracedToNode,
findRootNode,
getContainerOfType,
getDocument,
joinToNode,
joinTracedToNode,
LangiumDocument,
NL,
streamAllContents,
toStringAndTrace,
TraceRegion,
traceToNode,
TreeStreamImpl,
URI,
} from 'langium';
import path from 'path';
import { SourceMapGenerator, StartOfSourceMap } from 'source-map';
import { TextDocument } from 'vscode-languageserver-textdocument';
import { groupBy } from '../../helpers/collectionUtils.js';
import { SafeDsAnnotations } from '../builtins/safe-ds-annotations.js';
import {
isSdsAbstractResult,
isSdsAssignment,
Expand Down Expand Up @@ -46,49 +68,27 @@ import {
SdsStatement,
} from '../generated/ast.js';
import { isInStubFile, isStubFile } from '../helpers/fileExtensions.js';
import path from 'path';
import {
CompositeGeneratorNode,
expandToNode,
expandTracedToNode,
findRootNode,
getContainerOfType,
getDocument,
joinToNode,
joinTracedToNode,
LangiumDocument,
NL,
streamAllContents,
toStringAndTrace,
TraceRegion,
traceToNode,
TreeStreamImpl,
URI,
} from 'langium';
import { IdManager } from '../helpers/idManager.js';
import {
getAbstractResults,
getAssignees,
getImportedDeclarations,
getImports,
getModuleMembers,
getStatements,
isRequiredParameter,
Parameter,
streamBlockLambdaResults,
} from '../helpers/nodeProperties.js';
import { groupBy } from '../../helpers/collectionUtils.js';
import { SafeDsNodeMapper } from '../helpers/safe-ds-node-mapper.js';
import {
BooleanConstant,
FloatConstant,
IntConstant,
NullConstant,
StringConstant,
} from '../partialEvaluation/model.js';
import { IdManager } from '../helpers/idManager.js';
import { TextDocument } from 'vscode-languageserver-textdocument';
import { SafeDsAnnotations } from '../builtins/safe-ds-annotations.js';
import { SafeDsNodeMapper } from '../helpers/safe-ds-node-mapper.js';
import { SafeDsPartialEvaluator } from '../partialEvaluation/safe-ds-partial-evaluator.js';
import { SourceMapGenerator, StartOfSourceMap } from 'source-map';
import { SafeDsServices } from '../safe-ds-module.js';

export const CODEGEN_PREFIX = '__gen_';
const BLOCK_LAMBDA_PREFIX = `${CODEGEN_PREFIX}block_lambda_`;
Expand Down Expand Up @@ -685,7 +685,7 @@ export class SafeDsPythonGenerator {
private generateArgument(argument: SdsArgument, frame: GenerationInfoFrame): CompositeGeneratorNode {
const parameter = this.nodeMapper.argumentToParameter(argument);
return expandTracedToNode(argument)`${
parameter !== undefined && !isRequiredParameter(parameter)
parameter !== undefined && !Parameter.isRequired(parameter)
? expandToNode`${this.generateParameter(parameter, frame, false)}=`
: ''
}${this.generateExpression(argument.value, frame)}`;
Expand Down
48 changes: 26 additions & 22 deletions packages/safe-ds-lang/src/language/helpers/nodeProperties.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import {
isSdsLambda,
isSdsModule,
isSdsModuleMember,
isSdsParameter,
isSdsPlaceholder,
isSdsSegment,
isSdsTypeParameterList,
Expand Down Expand Up @@ -85,28 +86,30 @@ export const isPositionalArgument = (node: SdsArgument): boolean => {
return !node.parameter;
};

export const isNamedTypeArgument = (node: SdsTypeArgument): boolean => {
return Boolean(node.typeParameter);
};
export namespace Parameter {
export const isConstant = (node: SdsParameter | undefined): boolean => {
if (!node) {
return false;
}

export const isConstantParameter = (node: SdsParameter | undefined): boolean => {
if (!node) {
return false;
}
const containingCallable = getContainerOfType(node, isSdsCallable);

const containingCallable = getContainerOfType(node, isSdsCallable);
// In those cases, the const modifier is not applicable
if (isSdsCallableType(containingCallable) || isSdsLambda(containingCallable)) {
return false;
}

// In those cases, the const modifier is not applicable
if (isSdsCallableType(containingCallable) || isSdsLambda(containingCallable)) {
return false;
}
return isSdsAnnotation(containingCallable) || node.isConstant;
};

return isSdsAnnotation(containingCallable) || node.isConstant;
};
export const isOptional = (node: SdsParameter | undefined): boolean => {
return Boolean(node?.defaultValue);
};

export const isRequiredParameter = (node: SdsParameter): boolean => {
return !node.defaultValue;
};
export const isRequired = (node: SdsParameter | undefined): boolean => {
return isSdsParameter(node) && !node.defaultValue;
};
}

export const isStatic = (node: SdsClassMember): boolean => {
if (isSdsClass(node) || isSdsEnum(node)) {
Expand All @@ -121,6 +124,10 @@ export const isStatic = (node: SdsClassMember): boolean => {
}
};

export const isNamedTypeArgument = (node: SdsTypeArgument): boolean => {
return Boolean(node.typeParameter);
};

// -------------------------------------------------------------------------------------------------
// Accessors for list elements
// -------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -190,11 +197,8 @@ export const streamBlockLambdaResults = (node: SdsBlockLambda | undefined): Stre
.filter(isSdsBlockLambdaResult);
};

export const getMatchingClassMembers = (
node: SdsClass | undefined,
filterFunction: (member: SdsClassMember) => boolean = () => true,
): SdsClassMember[] => {
return node?.body?.members?.filter(filterFunction) ?? [];
export const getClassMembers = (node: SdsClass | undefined): SdsClassMember[] => {
return node?.body?.members ?? [];
};

export const getColumns = (node: SdsSchema | undefined): SdsColumn[] => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ import {
getAbstractResults,
getAnnotationCallTarget,
getAssignees,
getClassMembers,
getEnumVariants,
getImportedDeclarations,
getImports,
getMatchingClassMembers,
getPackageName,
getParameters,
getResults,
Expand Down Expand Up @@ -185,7 +185,7 @@ export class SafeDsScopeProvider extends DefaultScopeProvider {
// Static access
const declaration = this.getUniqueReferencedDeclarationForExpression(node.receiver);
if (isSdsClass(declaration)) {
const ownStaticMembers = getMatchingClassMembers(declaration, isStatic);
const ownStaticMembers = getClassMembers(declaration).filter(isStatic);
const superclassStaticMembers = this.classHierarchy.streamSuperclassMembers(declaration).filter(isStatic);

return this.createScopeForNodes(ownStaticMembers, this.createScopeForNodes(superclassStaticMembers));
Expand Down Expand Up @@ -215,7 +215,7 @@ export class SafeDsScopeProvider extends DefaultScopeProvider {
}

if (receiverType instanceof ClassType) {
const ownInstanceMembers = getMatchingClassMembers(receiverType.declaration, (it) => !isStatic(it));
const ownInstanceMembers = getClassMembers(receiverType.declaration).filter((it) => !isStatic(it));
const superclassInstanceMembers = this.classHierarchy
.streamSuperclassMembers(receiverType.declaration)
.filter((it) => !isStatic(it));
Expand Down
7 changes: 6 additions & 1 deletion packages/safe-ds-lang/src/language/typing/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
SdsEnumVariant,
SdsParameter,
} from '../generated/ast.js';
import { Parameter } from '../helpers/nodeProperties.js';
import { Constant, NullConstant } from '../partialEvaluation/model.js';

/**
Expand Down Expand Up @@ -73,7 +74,11 @@ export class CallableType extends Type {
}

override toString(): string {
return `${this.inputType} -> ${this.outputType}`;
const inputTypeString = this.inputType.entries
.map((it) => `${it.name}${Parameter.isOptional(it.declaration) ? '?' : ''}: ${it.type}`)
.join(', ');

return `(${inputTypeString}) -> ${this.outputType}`;
}

override unwrap(): CallableType {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { EMPTY_STREAM, stream, Stream } from 'langium';
import { EMPTY_STREAM, getContainerOfType, stream, Stream } from 'langium';
import { SafeDsClasses } from '../builtins/safe-ds-classes.js';
import { isSdsClass, isSdsNamedType, SdsClass, type SdsClassMember } from '../generated/ast.js';
import { getMatchingClassMembers, getParentTypes } from '../helpers/nodeProperties.js';
import { getClassMembers, getParentTypes, isStatic } from '../helpers/nodeProperties.js';
import { SafeDsServices } from '../safe-ds-module.js';

export class SafeDsClassHierarchy {
Expand Down Expand Up @@ -61,7 +61,7 @@ export class SafeDsClassHierarchy {
return EMPTY_STREAM;
}

return this.streamSuperclasses(node).flatMap(getMatchingClassMembers);
return this.streamSuperclasses(node).flatMap(getClassMembers);
}

/**
Expand All @@ -79,4 +79,31 @@ export class SafeDsClassHierarchy {

return undefined;
}

/**
* Returns the member that is overridden by the given member, or `undefined` if the member does not override
* anything.
*/
getOverriddenMember(node: SdsClassMember | undefined): SdsClassMember | undefined {
// Static members cannot override anything
if (!node || isStatic(node)) {
return undefined;
}

// Don't consider members with the same name as a previous member
const containingClass = getContainerOfType(node, isSdsClass);
if (!containingClass) {
return undefined;
}
const firstMemberWithSameName = getClassMembers(containingClass).find(
(it) => !isStatic(it) && it.name === node.name,
);
if (firstMemberWithSameName !== node) {
return undefined;
}

return this.streamSuperclassMembers(containingClass)
.filter((it) => !isStatic(it) && it.name === node.name)
.head();
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { getContainerOfType } from 'langium';
import type { SafeDsClasses } from '../builtins/safe-ds-classes.js';
import { isSdsEnum, type SdsAbstractResult, SdsDeclaration } from '../generated/ast.js';
import { getParameters } from '../helpers/nodeProperties.js';
import { getParameters, Parameter } from '../helpers/nodeProperties.js';
import { Constant } from '../partialEvaluation/model.js';
import { SafeDsServices } from '../safe-ds-module.js';
import {
Expand Down Expand Up @@ -84,6 +84,11 @@ export class SafeDsTypeChecker {
return false;
}

// Optionality must match (all but required to optional is OK)
if (Parameter.isRequired(typeEntry.declaration) && Parameter.isOptional(otherEntry.declaration)) {
return false;
}

// Types must be contravariant
if (!this.isAssignableTo(otherEntry.type, typeEntry.type)) {
return false;
Expand All @@ -93,7 +98,7 @@ export class SafeDsTypeChecker {
// Additional parameters must be optional
for (let i = other.inputType.length; i < type.inputType.length; i++) {
const typeEntry = type.inputType.entries[i]!;
if (!typeEntry.declaration?.defaultValue) {
if (!Parameter.isOptional(typeEntry.declaration)) {
return false;
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { ValidationAcceptor } from 'langium';
import { DiagnosticTag } from 'vscode-languageserver';
import {
isSdsParameter,
isSdsResult,
Expand All @@ -10,10 +11,9 @@ import {
SdsParameter,
SdsReference,
} from '../../generated/ast.js';
import { Parameter } from '../../helpers/nodeProperties.js';
import { SafeDsServices } from '../../safe-ds-module.js';
import { isRequiredParameter } from '../../helpers/nodeProperties.js';
import { parameterCanBeAnnotated } from '../other/declarations/annotationCalls.js';
import { DiagnosticTag } from 'vscode-languageserver';

export const CODE_DEPRECATED_ASSIGNED_RESULT = 'deprecated/assigned-result';
export const CODE_DEPRECATED_CALLED_ANNOTATION = 'deprecated/called-annotation';
Expand Down Expand Up @@ -108,7 +108,7 @@ export const referenceTargetShouldNotBeDeprecated =

export const requiredParameterMustNotBeDeprecated =
(services: SafeDsServices) => (node: SdsParameter, accept: ValidationAcceptor) => {
if (isRequiredParameter(node) && parameterCanBeAnnotated(node)) {
if (Parameter.isRequired(node) && parameterCanBeAnnotated(node)) {
if (services.builtins.Annotations.isDeprecated(node)) {
accept('error', 'A deprecated parameter must be optional.', {
node,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import { ValidationAcceptor } from 'langium';
import { SdsParameter } from '../../generated/ast.js';
import { Parameter } from '../../helpers/nodeProperties.js';
import { SafeDsServices } from '../../safe-ds-module.js';
import { isRequiredParameter } from '../../helpers/nodeProperties.js';
import { parameterCanBeAnnotated } from '../other/declarations/annotationCalls.js';

export const CODE_EXPERT_TARGET_PARAMETER = 'expert/target-parameter';

export const requiredParameterMustNotBeExpert =
(services: SafeDsServices) => (node: SdsParameter, accept: ValidationAcceptor) => {
if (isRequiredParameter(node) && parameterCanBeAnnotated(node)) {
if (Parameter.isRequired(node) && parameterCanBeAnnotated(node)) {
if (services.builtins.Annotations.isExpert(node)) {
accept('error', 'An expert parameter must be optional.', {
node,
Expand Down
39 changes: 36 additions & 3 deletions packages/safe-ds-lang/src/language/validation/inheritance.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,47 @@
import { ValidationAcceptor } from 'langium';
import { SdsClass } from '../generated/ast.js';
import { expandToStringWithNL, ValidationAcceptor } from 'langium';
import { isEmpty } from '../../helpers/collectionUtils.js';
import { SdsClass, type SdsClassMember } from '../generated/ast.js';
import { getParentTypes } from '../helpers/nodeProperties.js';
import { SafeDsServices } from '../safe-ds-module.js';
import { ClassType, UnknownType } from '../typing/model.js';
import { isEmpty } from '../../helpers/collectionUtils.js';

export const CODE_INHERITANCE_CYCLE = 'inheritance/cycle';
export const CODE_INHERITANCE_MULTIPLE_INHERITANCE = 'inheritance/multiple-inheritance';
export const CODE_INHERITANCE_MUST_MATCH_OVERRIDDEN_MEMBER = 'inheritance/must-match-overridden-member';
export const CODE_INHERITANCE_NOT_A_CLASS = 'inheritance/not-a-class';

export const classMemberMustMatchOverriddenMember = (services: SafeDsServices) => {
const classHierarchy = services.types.ClassHierarchy;
const typeChecker = services.types.TypeChecker;
const typeComputer = services.types.TypeComputer;

return (node: SdsClassMember, accept: ValidationAcceptor): void => {
const overriddenMember = classHierarchy.getOverriddenMember(node);
if (!overriddenMember) {
return;
}

const ownMemberType = typeComputer.computeType(node);
const overriddenMemberType = typeComputer.computeType(overriddenMember);

if (!typeChecker.isAssignableTo(ownMemberType, overriddenMemberType)) {
accept(
'error',
expandToStringWithNL`
Overriding member does not match the overridden member:
- Expected type: ${overriddenMemberType}
- Actual type: ${ownMemberType}
`,
{
node,
property: 'name',
code: CODE_INHERITANCE_MUST_MATCH_OVERRIDDEN_MEMBER,
},
);
}
};
};

export const classMustOnlyInheritASingleClass = (services: SafeDsServices) => {
const typeComputer = services.types.TypeComputer;
const computeType = typeComputer.computeType.bind(typeComputer);
Expand Down
Loading