diff --git a/grumpy-runtime-src/runtime/module.go b/grumpy-runtime-src/runtime/module.go index 5711591b..4c5a5b55 100644 --- a/grumpy-runtime-src/runtime/module.go +++ b/grumpy-runtime-src/runtime/module.go @@ -178,6 +178,58 @@ func importOne(f *Frame, name string) (*Object, *BaseException) { return o, nil } +// LoadMembers scans over all the members in module +// and populates globals with them, taking __all__ into +// account. +func LoadMembers(f *Frame, module *Object) *BaseException { + allAttr, raised := GetAttr(f, module, NewStr("__all__"), nil) + if raised != nil && !raised.isInstance(AttributeErrorType) { + return raised + } + f.RestoreExc(nil, nil) + + if raised == nil { + raised = loadMembersFromIterable(f, module, allAttr, nil) + if raised != nil { + return raised + } + return nil + } + + // Fall back on __dict__ + dictAttr := module.dict.ToObject() + raised = loadMembersFromIterable(f, module, dictAttr, func(key *Object) bool { + return strings.HasPrefix(toStrUnsafe(key).value, "_") + }) + if raised != nil { + return raised + } + return nil +} + +func loadMembersFromIterable(f *Frame, module, iterable *Object, filterF func(*Object) bool) *BaseException { + globals := f.Globals() + raised := seqForEach(f, iterable, func(memberName *Object) *BaseException { + if !memberName.isInstance(StrType) { + errorMessage := fmt.Sprintf("attribute name must be string, not '%v'", memberName.typ.Name()) + return f.RaiseType(AttributeErrorType, errorMessage) + } + member, raised := GetAttr(f, module, toStrUnsafe(memberName), nil) + if raised != nil { + return raised + } + if filterF != nil && filterF(memberName) { + return nil + } + raised = globals.SetItem(f, memberName, member) + if raised != nil { + return raised + } + return nil + }) + return raised +} + // newModule creates a new Module object with the given fully qualified name // (e.g a.b.c) and its corresponding Python filename and package. func newModule(name, filename string) *Module { diff --git a/grumpy-runtime-src/runtime/module_test.go b/grumpy-runtime-src/runtime/module_test.go index 2edc4bbe..3e89a711 100644 --- a/grumpy-runtime-src/runtime/module_test.go +++ b/grumpy-runtime-src/runtime/module_test.go @@ -184,7 +184,6 @@ func TestImportModule(t *testing.T) { } } } - func TestModuleGetNameAndFilename(t *testing.T) { fun := wrapFuncForTest(func(f *Frame, m *Module) (*Tuple, *BaseException) { name, raised := m.GetName(f) diff --git a/grumpy-tools-src/grumpy_tools/compiler/imputil.py b/grumpy-tools-src/grumpy_tools/compiler/imputil.py index b7a9fb89..f045338e 100644 --- a/grumpy-tools-src/grumpy_tools/compiler/imputil.py +++ b/grumpy-tools-src/grumpy_tools/compiler/imputil.py @@ -46,6 +46,7 @@ class Import(object): MODULE = "" MEMBER = "" + STAR = "" def __init__(self, name, script=None, is_native=False): self.name = name @@ -104,7 +105,14 @@ def visit_Import(self, node): def visit_ImportFrom(self, node): if any(a.name == '*' for a in node.names): - raise util.ImportError(node, 'wildcard member import is not implemented') + if len(node.names) != 1: + # TODO: Change to SyntaxError, as CPython does on "from foo import *, bar" + raise util.ImportError(node, 'invalid syntax on wildcard import') + + # Imported name is * (star). Will bind __all__ the module contents. + imp = self._resolve_import(node, node.module) + imp.add_binding(Import.STAR, '*', imp.name.count('.')) + return [imp] if not node.level and node.module == '__future__': return [] diff --git a/grumpy-tools-src/grumpy_tools/compiler/imputil_test.py b/grumpy-tools-src/grumpy_tools/compiler/imputil_test.py index c3cd6aa6..80e06fea 100644 --- a/grumpy-tools-src/grumpy_tools/compiler/imputil_test.py +++ b/grumpy-tools-src/grumpy_tools/compiler/imputil_test.py @@ -174,9 +174,8 @@ def testImportFromAsMembers(self): imp.add_binding(imputil.Import.MEMBER, 'baz', 'bar') self._check_imports('from foo import bar as baz', [imp]) - def testImportFromWildcardRaises(self): - self.assertRaises(util.ImportError, self.importer.visit, - pythonparser.parse('from foo import *').body[0]) + # def testImportFromWildcardRaises(self): + # self._check_imports('from foo import *', []) def testImportFromFuture(self): self._check_imports('from __future__ import print_function', []) diff --git a/grumpy-tools-src/grumpy_tools/compiler/stmt.py b/grumpy-tools-src/grumpy_tools/compiler/stmt.py index c02f2664..dfa79c20 100644 --- a/grumpy-tools-src/grumpy_tools/compiler/stmt.py +++ b/grumpy-tools-src/grumpy_tools/compiler/stmt.py @@ -629,6 +629,8 @@ def _import_and_bind(self, imp): self.writer.write('{} = {}[{}]'.format( mod.name, mod_slice.expr, binding.value)) self.block.bind_var(self.writer, binding.alias, mod.expr) + elif binding.bind_type == imputil.Import.STAR: + self.writer.write_checked_call1('πg.LoadMembers(πF, {}[0])', mod_slice.name) else: self.writer.write('{} = {}[{}]'.format( mod.name, mod_slice.expr, imp.name.count('.'))) diff --git a/grumpy-tools-src/grumpy_tools/compiler/stmt_test.py b/grumpy-tools-src/grumpy_tools/compiler/stmt_test.py index aa766823..edeb1271 100644 --- a/grumpy-tools-src/grumpy_tools/compiler/stmt_test.py +++ b/grumpy-tools-src/grumpy_tools/compiler/stmt_test.py @@ -333,19 +333,19 @@ def testImportNativeType(self): from "__go__/time" import Duration print Duration"""))) - def testImportWildcardMemberRaises(self): - regexp = 'wildcard member import is not implemented' - self.assertRaisesRegexp(util.ImportError, regexp, _ParseAndVisit, - 'from foo import *') - self.assertRaisesRegexp(util.ImportError, regexp, _ParseAndVisit, - 'from "__go__/foo" import *') - def testPrintStatement(self): self.assertEqual((0, 'abc 123\nfoo bar\n'), _GrumpRun(textwrap.dedent("""\ print 'abc', print '123' print 'foo', 'bar'"""))) + def testImportWildcard(self): + result = _GrumpRun(textwrap.dedent("""\ + from time import * + print sleep""")) + self.assertEqual(0, result[0]) + self.assertIn('