Skip to content
This repository was archived by the owner on Mar 23, 2023. It is now read-only.

Commit b9a0c8a

Browse files
author
Dylan Trotter
committed
Move import logic into util for reuse
Import logic is currently embedded in the StatementVisitor class which makes it difficult to reuse. This change factors that logic out into an ImportVisitor class in util.py. The newly exposed logic can now be used by other tools.
1 parent de7e668 commit b9a0c8a

File tree

5 files changed

+221
-73
lines changed

5 files changed

+221
-73
lines changed

compiler/stmt.py

Lines changed: 49 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from grumpy.compiler import util
3131

3232

33-
_NATIVE_MODULE_PREFIX = '__go__.'
3433
_NATIVE_TYPE_PREFIX = 'type_'
3534

3635
# Partial list of known vcs for go module import
@@ -371,58 +370,39 @@ def visit_If(self, node):
371370

372371
def visit_Import(self, node):
373372
self._write_py_context(node.lineno)
374-
for alias in node.names:
375-
if alias.name.startswith(_NATIVE_MODULE_PREFIX):
376-
raise util.ParseError(
377-
node, 'for native imports use "from __go__.xyz import ..." syntax')
378-
with self._import(alias.name, 0) as mod:
379-
asname = alias.asname or alias.name.split('.')[0]
380-
self.block.bind_var(self.writer, asname, mod.expr)
373+
for imp in util.ImportVisitor().visit(node):
374+
self._import_and_bind(imp)
381375

382376
def visit_ImportFrom(self, node):
383-
# Wildcard imports are not yet supported.
384-
for alias in node.names:
385-
if alias.name == '*':
386-
msg = 'wildcard member import is not implemented: from %s import %s' % (
387-
node.module, alias.name)
388-
raise util.ParseError(node, msg)
389377
self._write_py_context(node.lineno)
390-
if node.module.startswith(_NATIVE_MODULE_PREFIX):
391-
values = [alias.name for alias in node.names]
392-
with self._import_native(node.module, values) as mod:
393-
for alias in node.names:
394-
# Strip the 'type_' prefix when populating the module. This means
395-
# that, e.g. 'from __go__.foo import type_Bar' will populate foo with
396-
# a member called Bar, not type_Bar (although the symbol in the
397-
# importing module will still be type_Bar unless aliased). This bends
398-
# the semantics of import but makes native module contents more
399-
# sensible.
400-
name = alias.name
401-
if name.startswith(_NATIVE_TYPE_PREFIX):
402-
name = name[len(_NATIVE_TYPE_PREFIX):]
403-
with self.block.alloc_temp() as member:
404-
self.writer.write_checked_call2(
405-
member, 'πg.GetAttr(πF, {}, {}, nil)',
406-
mod.expr, self.block.root.intern(name))
407-
self.block.bind_var(
408-
self.writer, alias.asname or alias.name, member.expr)
409-
elif node.module == '__future__':
410-
# At this stage all future imports are done in an initial pass (see
411-
# visit() above), so if they are encountered here after the last valid
412-
# __future__ then it's a syntax error.
413-
if node.lineno > self.future_features.future_lineno:
414-
raise util.ParseError(node, late_future)
415-
else:
416-
# NOTE: Assume that the names being imported are all modules within a
417-
# package. E.g. "from a.b import c" is importing the module c from package
418-
# a.b, not some member of module b. We cannot distinguish between these
419-
# two cases at compile time and the Google style guide forbids the latter
420-
# so we support that use case only.
421-
for alias in node.names:
422-
name = '{}.{}'.format(node.module, alias.name)
423-
with self._import(name, name.count('.')) as mod:
424-
asname = alias.asname or alias.name
425-
self.block.bind_var(self.writer, asname, mod.expr)
378+
for imp in util.ImportVisitor().visit(node):
379+
if imp.is_native:
380+
values = [b.value for b in imp.bindings]
381+
with self._import_native(imp.name, values) as mod:
382+
for binding in imp.bindings:
383+
# Strip the 'type_' prefix when populating the module. This means
384+
# that, e.g. 'from __go__.foo import type_Bar' will populate foo
385+
# with a member called Bar, not type_Bar (although the symbol in
386+
# the importing module will still be type_Bar unless aliased). This
387+
# bends the semantics of import but makes native module contents
388+
# more sensible.
389+
name = binding.value
390+
if name.startswith(_NATIVE_TYPE_PREFIX):
391+
name = name[len(_NATIVE_TYPE_PREFIX):]
392+
with self.block.alloc_temp() as member:
393+
self.writer.write_checked_call2(
394+
member, 'πg.GetAttr(πF, {}, {}, nil)',
395+
mod.expr, self.block.root.intern(name))
396+
self.block.bind_var(
397+
self.writer, binding.alias, member.expr)
398+
elif node.module == '__future__':
399+
# At this stage all future imports are done in an initial pass (see
400+
# visit() above), so if they are encountered here after the last valid
401+
# __future__ then it's a syntax error.
402+
if node.lineno > self.future_features.future_lineno:
403+
raise util.ImportError(node, late_future)
404+
else:
405+
self._import_and_bind(imp)
426406

427407
def visit_Module(self, node):
428408
self._visit_each(node.body)
@@ -681,18 +661,14 @@ def _build_assign_target(self, target, assigns):
681661
tmpl = 'πg.TieTarget{Target: &$temp}'
682662
return string.Template(tmpl).substitute(temp=temp.name)
683663

684-
def _import(self, name, index):
685-
"""Returns an expression for a Module object returned from ImportModule.
664+
def _import_and_bind(self, imp):
665+
"""Generates code that imports a module and binds it to a variable.
686666
687667
Args:
688-
name: The fully qualified Python module name, e.g. foo.bar.
689-
index: The element in the list of modules that this expression should
690-
select. E.g. for 'foo.bar', 0 corresponds to the package foo and 1
691-
corresponds to the module bar.
692-
Returns:
693-
A Go expression evaluating to an *Object (upcast from a *Module.)
668+
imp: Import object representing an import of the form "import x.y.z" or
669+
"from x.y import z". Expects only a single binding.
694670
"""
695-
parts = name.split('.')
671+
parts = imp.name.split('.')
696672
code_objs = []
697673
for i in xrange(len(parts)):
698674
package_name = '/'.join(parts[:i + 1])
@@ -701,27 +677,33 @@ def _import(self, name, index):
701677
code_objs.append('{}.Code'.format(package.alias))
702678
else:
703679
code_objs.append('Code')
704-
mod = self.block.alloc_temp()
705-
with self.block.alloc_temp('[]*πg.Object') as mod_slice:
680+
with self.block.alloc_temp() as mod, \
681+
self.block.alloc_temp('[]*πg.Object') as mod_slice:
706682
handles_expr = '[]*πg.Code{' + ', '.join(code_objs) + '}'
707683
self.writer.write_checked_call2(
708684
mod_slice, 'πg.ImportModule(πF, {}, {})',
709-
util.go_str(name), handles_expr)
685+
util.go_str(imp.name), handles_expr)
686+
# This method only handles simple module imports (i.e. not member
687+
# imports) which always have a single binding.
688+
binding = imp.bindings[0]
689+
if binding.value == util.Import.ROOT:
690+
index = 0
691+
else:
692+
index = len(parts) - 1
710693
self.writer.write('{} = {}[{}]'.format(mod.name, mod_slice.expr, index))
711-
return mod
694+
self.block.bind_var(self.writer, binding.alias, mod.expr)
712695

713696
def _import_native(self, name, values):
714697
reflect_package = self.block.root.add_native_import('reflect')
715-
import_name = name[len(_NATIVE_MODULE_PREFIX):]
716698
# Work-around for importing go module from VCS
717699
# TODO: support bzr|git|hg|svn from any server
718700
package_name = None
719701
for x in _KNOWN_VCS:
720-
if import_name.startswith(x):
721-
package_name = x + import_name[len(x):].replace('.', '/')
702+
if name.startswith(x):
703+
package_name = x + name[len(x):].replace('.', '/')
722704
break
723705
if not package_name:
724-
package_name = import_name.replace('.', '/')
706+
package_name = name.replace('.', '/')
725707

726708
package = self.block.root.add_native_import(package_name)
727709
mod = self.block.alloc_temp()

compiler/stmt_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def testImportGrump(self):
313313

314314
def testImportNativeModuleRaises(self):
315315
regexp = r'for native imports use "from __go__\.xyz import \.\.\." syntax'
316-
self.assertRaisesRegexp(util.ParseError, regexp, _ParseAndVisit,
316+
self.assertRaisesRegexp(util.ImportError, regexp, _ParseAndVisit,
317317
'import __go__.foo')
318318

319319
def testImportNativeType(self):
@@ -368,11 +368,11 @@ def testImportFromFutureParseError(self):
368368

369369
def testImportWildcardMemberRaises(self):
370370
regexp = r'wildcard member import is not implemented: from foo import *'
371-
self.assertRaisesRegexp(util.ParseError, regexp, _ParseAndVisit,
371+
self.assertRaisesRegexp(util.ImportError, regexp, _ParseAndVisit,
372372
'from foo import *')
373373
regexp = (r'wildcard member import is not '
374374
r'implemented: from __go__.foo import *')
375-
self.assertRaisesRegexp(util.ParseError, regexp, _ParseAndVisit,
375+
self.assertRaisesRegexp(util.ImportError, regexp, _ParseAndVisit,
376376
'from __go__.foo import *')
377377

378378
def testVisitFuture(self):

compiler/util.py

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,15 @@
1919
from __future__ import unicode_literals
2020

2121
import codecs
22+
import collections
2223
import contextlib
2324
import cStringIO
2425
import string
2526
import StringIO
2627
import textwrap
2728

29+
from pythonparser import algorithm
30+
2831

2932
_SIMPLE_CHARS = set(string.digits + string.letters + string.punctuation + " ")
3033
_ESCAPES = {'\t': r'\t', '\r': r'\r', '\n': r'\n', '"': r'\"', '\\': r'\\'}
@@ -34,13 +37,96 @@
3437
# This should match the number of specializations found in tuple.go.
3538
MAX_DIRECT_TUPLE = 6
3639

40+
_NATIVE_MODULE_PREFIX = '__go__.'
41+
3742

38-
class ParseError(Exception):
43+
class CompileError(Exception):
3944

4045
def __init__(self, node, msg):
4146
if hasattr(node, 'lineno'):
4247
msg = 'line {}: {}'.format(node.lineno, msg)
43-
super(ParseError, self).__init__(msg)
48+
super(CompileError, self).__init__(msg)
49+
50+
51+
class ParseError(CompileError):
52+
pass
53+
54+
55+
class ImportError(CompileError): # pylint: disable=redefined-builtin
56+
pass
57+
58+
59+
class Import(object):
60+
"""Represents a single module import and all its associated bindings.
61+
62+
Each import pertains to a single module that is imported. Thus one import
63+
statement may produce multiple Import objects. E.g. "import foo, bar" makes
64+
an Import object for module foo and another one for module bar.
65+
"""
66+
67+
Binding = collections.namedtuple('Binding', ('bind_type', 'alias', 'value'))
68+
69+
MODULE = "<BindType 'module'>"
70+
MEMBER = "<BindType 'member'>"
71+
72+
ROOT = "<BindValue 'root'>"
73+
LEAF = "<BindValue 'leaf'>"
74+
75+
def __init__(self, name, is_native=False):
76+
self.name = name
77+
self.is_native = is_native
78+
self.bindings = []
79+
80+
def add_binding(self, bind_type, alias, value):
81+
self.bindings.append(Import.Binding(bind_type, alias, value))
82+
83+
84+
class ImportVisitor(algorithm.Visitor):
85+
"""Visits import nodes and produces corresponding Import objects."""
86+
87+
# pylint: disable=invalid-name,missing-docstring,no-init
88+
89+
def visit_Import(self, node):
90+
imports = []
91+
for alias in node.names:
92+
if alias.name.startswith(_NATIVE_MODULE_PREFIX):
93+
raise ImportError(
94+
node, 'for native imports use "from __go__.xyz import ..." syntax')
95+
imp = Import(alias.name)
96+
if alias.asname:
97+
imp.add_binding(Import.MODULE, alias.asname, Import.LEAF)
98+
else:
99+
imp.add_binding(Import.MODULE, alias.name.split('.')[-1], Import.ROOT)
100+
imports.append(imp)
101+
return imports
102+
103+
def visit_ImportFrom(self, node):
104+
if any(a.name == '*' for a in node.names):
105+
msg = 'wildcard member import is not implemented: from %s import *' % (
106+
node.module)
107+
raise ImportError(node, msg)
108+
109+
if node.module == '__future__':
110+
return []
111+
112+
if node.module.startswith(_NATIVE_MODULE_PREFIX):
113+
imp = Import(node.module[len(_NATIVE_MODULE_PREFIX):], is_native=True)
114+
for alias in node.names:
115+
asname = alias.asname or alias.name
116+
imp.add_binding(Import.MEMBER, asname, alias.name)
117+
return [imp]
118+
119+
# NOTE: Assume that the names being imported are all modules within a
120+
# package. E.g. "from a.b import c" is importing the module c from package
121+
# a.b, not some member of module b. We cannot distinguish between these
122+
# two cases at compile time and the Google style guide forbids the latter
123+
# so we support that use case only.
124+
imports = []
125+
for alias in node.names:
126+
imp = Import('{}.{}'.format(node.module, alias.name))
127+
imp.add_binding(Import.MODULE, alias.asname or alias.name, Import.LEAF)
128+
imports.append(imp)
129+
return imports
44130

45131

46132
class Writer(object):

compiler/util_test.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,91 @@
2020

2121
import unittest
2222

23+
import pythonparser
24+
2325
from grumpy.compiler import block
2426
from grumpy.compiler import util
2527
from grumpy.compiler import stmt
2628

2729

30+
class ImportVisitorTest(unittest.TestCase):
31+
32+
def testImport(self):
33+
imp = util.Import('foo')
34+
imp.add_binding(util.Import.MODULE, 'foo', util.Import.ROOT)
35+
self._assert_imports_equal(imp, self._visit_import('import foo'))
36+
37+
def testImportMultiple(self):
38+
imp1 = util.Import('foo')
39+
imp1.add_binding(util.Import.MODULE, 'foo', util.Import.ROOT)
40+
imp2 = util.Import('bar')
41+
imp2.add_binding(util.Import.MODULE, 'bar', util.Import.ROOT)
42+
self._assert_imports_equal(
43+
[imp1, imp2], self._visit_import('import foo, bar'))
44+
45+
def testImportAs(self):
46+
imp = util.Import('foo')
47+
imp.add_binding(util.Import.MODULE, 'bar', util.Import.LEAF)
48+
self._assert_imports_equal(imp, self._visit_import('import foo as bar'))
49+
50+
def testImportNativeRaises(self):
51+
self.assertRaises(util.ImportError, self._visit_import, 'import __go__.fmt')
52+
53+
def testImportFrom(self):
54+
imp = util.Import('foo.bar')
55+
imp.add_binding(util.Import.MODULE, 'bar', util.Import.LEAF)
56+
self._assert_imports_equal(imp, self._visit_import('from foo import bar'))
57+
58+
def testImportFromMultiple(self):
59+
imp1 = util.Import('foo.bar')
60+
imp1.add_binding(util.Import.MODULE, 'bar', util.Import.LEAF)
61+
imp2 = util.Import('foo.baz')
62+
imp2.add_binding(util.Import.MODULE, 'baz', util.Import.LEAF)
63+
self._assert_imports_equal(
64+
[imp1, imp2], self._visit_import('from foo import bar, baz'))
65+
66+
def testImportFromAs(self):
67+
imp = util.Import('foo.bar')
68+
imp.add_binding(util.Import.MODULE, 'baz', util.Import.LEAF)
69+
self._assert_imports_equal(
70+
imp, self._visit_import('from foo import bar as baz'))
71+
72+
def testImportFromWildcardRaises(self):
73+
self.assertRaises(util.ImportError, self._visit_import, 'from foo import *')
74+
75+
def testImportFromFuture(self):
76+
result = self._visit_import('from __future__ import print_function')
77+
self.assertEqual([], result)
78+
79+
def testImportFromNative(self):
80+
imp = util.Import('fmt', is_native=True)
81+
imp.add_binding(util.Import.MEMBER, 'Printf', 'Printf')
82+
self._assert_imports_equal(
83+
imp, self._visit_import('from __go__.fmt import Printf'))
84+
85+
def testImportFromNativeMultiple(self):
86+
imp = util.Import('fmt', is_native=True)
87+
imp.add_binding(util.Import.MEMBER, 'Printf', 'Printf')
88+
imp.add_binding(util.Import.MEMBER, 'Println', 'Println')
89+
self._assert_imports_equal(
90+
imp, self._visit_import('from __go__.fmt import Printf, Println'))
91+
92+
def testImportFromNativeAs(self):
93+
imp = util.Import('fmt', is_native=True)
94+
imp.add_binding(util.Import.MEMBER, 'foo', 'Printf')
95+
self._assert_imports_equal(
96+
imp, self._visit_import('from __go__.fmt import Printf as foo'))
97+
98+
def _visit_import(self, source):
99+
return util.ImportVisitor().visit(pythonparser.parse(source).body[0])
100+
101+
def _assert_imports_equal(self, want, got):
102+
if isinstance(want, util.Import):
103+
want = [want]
104+
self.assertEqual([imp.__dict__ for imp in want],
105+
[imp.__dict__ for imp in got])
106+
107+
28108
class WriterTest(unittest.TestCase):
29109

30110
def testIndentBlock(self):

0 commit comments

Comments
 (0)