diff --git a/src/Analysis/Ast/Impl/Analyzer/Evaluation/ExpressionEval.Scopes.cs b/src/Analysis/Ast/Impl/Analyzer/Evaluation/ExpressionEval.Scopes.cs index 7ecf1dee4..5c775f42a 100644 --- a/src/Analysis/Ast/Impl/Analyzer/Evaluation/ExpressionEval.Scopes.cs +++ b/src/Analysis/Ast/Impl/Analyzer/Evaluation/ExpressionEval.Scopes.cs @@ -43,15 +43,18 @@ public void DeclareVariable(string name, IMember value, VariableSource source) public void DeclareVariable(string name, IMember value, VariableSource source, IPythonModule module) => DeclareVariable(name, value, source, new Location(module)); - public void DeclareVariable(string name, IMember value, VariableSource source, Node location, bool overwrite = false) + public void DeclareVariable(string name, IMember value, VariableSource source, Node location, bool overwrite = true) => DeclareVariable(name, value, source, GetLocationOfName(location), overwrite); - public void DeclareVariable(string name, IMember value, VariableSource source, Location location, bool overwrite = false) { + public void DeclareVariable(string name, IMember value, VariableSource source, Location location, bool overwrite = true) { + var member = GetInScope(name); + if (member != null && !overwrite) { + return; + } if (source == VariableSource.Import && value is IVariable v) { CurrentScope.LinkVariable(name, v, location); return; } - var member = GetInScope(name); if (member != null) { if (!value.IsUnknown()) { CurrentScope.DeclareVariable(name, value, source, location); diff --git a/src/Analysis/Ast/Impl/Analyzer/Handlers/AssignmentHandler.cs b/src/Analysis/Ast/Impl/Analyzer/Handlers/AssignmentHandler.cs index 6c887ef0f..20f9f9d12 100644 --- a/src/Analysis/Ast/Impl/Analyzer/Handlers/AssignmentHandler.cs +++ b/src/Analysis/Ast/Impl/Analyzer/Handlers/AssignmentHandler.cs @@ -111,7 +111,7 @@ private void TryHandleClassVariable(AssignmentStatement node, IMember value) { var cls = m.GetPythonType(); if (cls != null) { using (Eval.OpenScope(Eval.Module, cls.ClassDefinition, out _)) { - Eval.DeclareVariable(mex.Name, value, VariableSource.Declaration, Eval.GetLocationOfName(mex), true); + Eval.DeclareVariable(mex.Name, value, VariableSource.Declaration, Eval.GetLocationOfName(mex)); } } } diff --git a/src/Analysis/Ast/Impl/Analyzer/Handlers/FromImportHandler.cs b/src/Analysis/Ast/Impl/Analyzer/Handlers/FromImportHandler.cs index 182f83d5d..50edd0261 100644 --- a/src/Analysis/Ast/Impl/Analyzer/Handlers/FromImportHandler.cs +++ b/src/Analysis/Ast/Impl/Analyzer/Handlers/FromImportHandler.cs @@ -59,7 +59,7 @@ private void AssignVariables(FromImportStatement node, IImportSearchResult impor // TODO: warn this is not a good style per // TODO: https://docs.python.org/3/faq/programming.html#what-are-the-best-practices-for-using-import-in-a-module // TODO: warn this is invalid if not in the global scope. - HandleModuleImportStar(variableModule, imports is ImplicitPackageImport); + HandleModuleImportStar(variableModule, imports is ImplicitPackageImport, node.StartIndex); return; } @@ -68,14 +68,16 @@ private void AssignVariables(FromImportStatement node, IImportSearchResult impor if (!string.IsNullOrEmpty(memberName)) { var nameExpression = asNames[i] ?? names[i]; var variableName = nameExpression?.Name ?? memberName; - var exported = variableModule.Analysis?.GlobalScope.Variables[memberName] ?? variableModule.GetMember(memberName); + var variable = variableModule.Analysis?.GlobalScope?.Variables[memberName]; + var exported = variable ?? variableModule.GetMember(memberName); var value = exported ?? GetValueFromImports(variableModule, imports as IImportChildrenSource, memberName); - Eval.DeclareVariable(variableName, value, VariableSource.Import, nameExpression); + // Do not allow imported variables to override local declarations + Eval.DeclareVariable(variableName, value, VariableSource.Import, nameExpression, CanOverwriteVariable(variableName, node.StartIndex)); } } } - private void HandleModuleImportStar(PythonVariableModule variableModule, bool isImplicitPackage) { + private void HandleModuleImportStar(PythonVariableModule variableModule, bool isImplicitPackage, int importPosition) { if (variableModule.Module == Module) { // from self import * won't define any new members return; @@ -100,10 +102,31 @@ private void HandleModuleImportStar(PythonVariableModule variableModule, bool is } var variable = variableModule.Analysis?.GlobalScope?.Variables[memberName]; - Eval.DeclareVariable(memberName, variable ?? member, VariableSource.Import); + // Do not allow imported variables to override local declarations + Eval.DeclareVariable(memberName, variable ?? member, VariableSource.Import, Eval.DefaultLocation, CanOverwriteVariable(memberName, importPosition)); } } + private bool CanOverwriteVariable(string name, int importPosition) { + var v = Eval.CurrentScope.Variables[name]; + if(v == null) { + return true; // Variable does not exist + } + // Allow overwrite if import is below the variable. Consider + // x = 1 + // x = 2 + // from A import * # brings another x + // x = 3 + var references = v.References.Where(r => r.DocumentUri == Module.Uri).ToArray(); + if(references.Length == 0) { + // No references to the variable in this file - the variable + // is imported from another module. OK to overwrite. + return true; + } + var firstAssignmentPosition = references.Min(r => r.Span.ToIndexSpan(Ast).Start); + return firstAssignmentPosition < importPosition; + } + private IMember GetValueFromImports(PythonVariableModule parentModule, IImportChildrenSource childrenSource, string memberName) { if (childrenSource == null || !childrenSource.TryGetChildImport(memberName, out var childImport)) { return Interpreter.UnknownType; diff --git a/src/Analysis/Ast/Test/ImportTests.cs b/src/Analysis/Ast/Test/ImportTests.cs index 43da7a3af..5ce08f316 100644 --- a/src/Analysis/Ast/Test/ImportTests.cs +++ b/src/Analysis/Ast/Test/ImportTests.cs @@ -231,5 +231,19 @@ public async Task PreferTypeToAny() { var a = analysis.Should().HaveClass("A").Which; a.GetMember("x").Should().HaveType(BuiltinTypeId.Int); } + + [TestMethod, Priority(0)] + public async Task StarImportDoesNotOverwriteFunction() { + const string code = @" +from sys import * + +def exit(): + return 1234 + +x = exit() +"; + var analysis = await GetAnalysisAsync(code); + analysis.Should().HaveVariable("x").OfType(BuiltinTypeId.Int); + } } }