From b3bcd67a68f4be507c63d578ff2ac2af4c110769 Mon Sep 17 00:00:00 2001 From: Tomas Roun Date: Tue, 14 Jan 2025 21:22:37 +0100 Subject: [PATCH 01/15] Add module autocomplete to PyREPL --- Lib/_pyrepl/completing_reader.py | 4 + Lib/_pyrepl/readline.py | 332 ++++++++++++++++++++++++++++ Lib/test/test_pyrepl/test_pyrepl.py | 194 +++++++++++++++- 3 files changed, 529 insertions(+), 1 deletion(-) diff --git a/Lib/_pyrepl/completing_reader.py b/Lib/_pyrepl/completing_reader.py index e856bb9807c7f6..09e00cc762334d 100644 --- a/Lib/_pyrepl/completing_reader.py +++ b/Lib/_pyrepl/completing_reader.py @@ -288,3 +288,7 @@ def get_stem(self) -> str: def get_completions(self, stem: str) -> list[str]: return [] + + def get_line(self): + """Return the current line until the cursor position.""" + return ''.join(self.buffer[:self.pos]) diff --git a/Lib/_pyrepl/readline.py b/Lib/_pyrepl/readline.py index 888185eb03be66..e1a2d400494617 100644 --- a/Lib/_pyrepl/readline.py +++ b/Lib/_pyrepl/readline.py @@ -28,8 +28,14 @@ from __future__ import annotations +import importlib +import pkgutil +import tokenize import warnings +from io import StringIO +from contextlib import contextmanager from dataclasses import dataclass, field +from tokenize import TokenInfo import os from site import gethistoryfile # type: ignore[attr-defined] @@ -59,6 +65,7 @@ if TYPE_CHECKING: from typing import Any, Mapping + from types import ModuleType MoreLinesCallable = Callable[[str], bool] @@ -132,6 +139,8 @@ def get_stem(self) -> str: return "".join(b[p + 1 : self.pos]) def get_completions(self, stem: str) -> list[str]: + if module_completions := self.get_module_completions(): + return module_completions if len(stem) == 0 and self.more_lines is not None: b = self.buffer p = self.pos @@ -161,6 +170,11 @@ def get_completions(self, stem: str) -> list[str]: result.sort() return result + def get_module_completions(self) -> list[str]: + completer = ModuleCompleter(namespace={'__package__': '_pyrepl'}) # TODO: namespace? + line = self.get_line() + return completer.get_completions(line) + def get_trimmed_history(self, maxlength: int) -> list[str]: if maxlength >= 0: cut = len(self.history) - maxlength @@ -596,3 +610,321 @@ def _setup(namespace: Mapping[str, Any]) -> None: raw_input: Callable[[object], str] | None = None + + +class ModuleCompleter: + """A completer for Python import statements. + + Examples: + - import + - import foo + - import foo. + - import foo as bar, baz + + - from + - from foo + - from foo import + - from foo import bar + - from foo import (bar as baz, qux + """ + + def __init__(self, namespace: Mapping[str, Any] | None = None): + self.namespace = namespace or {} + self._global_cache: list[str] = [] + self._curr_sys_path: list[str] = sys.path[:] + + def get_completions(self, line: str) -> list[str]: + """Return the next possible import completions for 'line'.""" + + parser = ImportParser(line) + if not (result := parser.parse()): + return [] + return self.complete(*result) + + def complete(self, from_name: str | None, name: str | None) -> list[str]: + # import x.y.z + if from_name is None: + if not name: + return [] + return self.complete_import(name) + + # from x.y.z + if name is None: + if not from_name: + return [] + return self.complete_import(from_name) + + # from x.y import z + if not (module := self.import_module(from_name)): + return [] + + submodules = self.filter_submodules(module, name) + attributes = self.filter_attributes(module, name) + return list(set(submodules + attributes)) + + def complete_import(self, name: str) -> list[str]: + is_relative = name.startswith('.') + path, prefix = self.get_path_and_prefix(name) + + if not is_relative and not path: + return [name for name in self.global_cache if name.startswith(prefix)] + + if not (module := self.import_module(path)): + return [] + + submodules = self.filter_submodules(module, prefix) + if not is_relative: + return [f'{path}.{name}' for name in submodules] + return [f'.{name}' for name in submodules] + + def import_module(self, path: str) -> ModuleType | None: + package = self.namespace.get('__package__') + is_relative = path.startswith('.') + if is_relative and not package: + return None + try: + module = importlib.import_module( + path, + package=package if is_relative else None) + except ImportError: + return None + return module + + def filter_submodules(self, module: ModuleType, prefix: str) -> list[str]: + if not hasattr(module, '__path__'): + return [] + return [name for _, name, _ in pkgutil.iter_modules(module.__path__) + if name.startswith(prefix)] + + def filter_attributes(self, module: ModuleType, prefix: str) -> list[str]: + return [attr for attr in module.__dict__ if attr.startswith(prefix)] + + def get_path_and_prefix(self, dotted_name: str) -> tuple[str, str]: + if '.' not in dotted_name: + return '', dotted_name + if dotted_name.startswith('.'): + stripped = dotted_name.lstrip('.') + dots = '.' * (len(dotted_name) - len(stripped)) + if '.' not in stripped: + return dots, stripped + path, prefix = stripped.rsplit('.', 1) + return dots + path, prefix + path, prefix = dotted_name.rsplit('.', 1) + return path, prefix + + @property + def global_cache(self) -> list[str]: + if not self._global_cache or self._curr_sys_path != sys.path: + self._curr_sys_path = sys.path[:] + self._global_cache = [ + name for _, name, _ in pkgutil.iter_modules()] + return self._global_cache + + +class ImportParser: + """ + Parses incomplete import statements that are + suitable for autocomplete suggestions. + + Examples: + - import foo -> Result(from_name=None, name='foo') + - import foo. -> Result(from_name=None, name='foo.') + - from foo -> Result(from_name='foo', name=None) + - from foo import bar -> Result(from_name='foo', name='bar') + - from .foo import ( -> Result(from_name='.foo', name='') + + Note that the parser works in reverse order, starting from the + last token in the input string. This makes the parser more robust + when parsing multiple statements. + """ + _ignored_tokens = { + tokenize.INDENT, tokenize.DEDENT, tokenize.COMMENT, + tokenize.NL, tokenize.NEWLINE, tokenize.ENDMARKER + } + _keywords = {'import', 'from', 'as'} + + def __init__(self, code: str): + self.code = code + tokens = [] + try: + for t in tokenize.generate_tokens(StringIO(code).readline): + if t.type not in self._ignored_tokens: + tokens.append(t) + except tokenize.TokenError as e: + if 'unexpected EOF' not in str(e): + # unexpected EOF is fine, since we're parsing an + # incomplete statement, but other errors are not + # because we may not have all the tokens so it's + # safer to bail out + tokens = [] + except SyntaxError: + tokens = [] + self.tokens = TokenQueue(tokens[::-1]) + + def parse(self): + if not (res := self._parse()): + return None + return res.from_name, res.name + + def _parse(self): + with self.tokens.save_state(): + return self.parse_from_import() + with self.tokens.save_state(): + return self.parse_import() + + def parse_import(self): + if self.code.rstrip().endswith('import') and self.code.endswith(' '): + return Result(name='') + if self.tokens.peek_string(','): + name = '' + else: + if self.code.endswith(' '): + raise ParseError('parse_import') + name = self.parse_dotted_name() + if name.startswith('.'): + raise ParseError('parse_import') + while self.tokens.peek_string(','): + self.tokens.pop() + self.parse_dotted_as_name() + if self.tokens.peek_string('import'): + return Result(name=name) + raise ParseError('parse_import') + + def parse_from_import(self): + if self.code.rstrip().endswith('import') and self.code.endswith(' '): + return Result(from_name=self.parse_empty_from_import(), name='') + if self.code.rstrip().endswith('from') and self.code.endswith(' '): + return Result(from_name='') + if self.tokens.peek_string('(') or self.tokens.peek_string(','): + return Result(from_name=self.parse_empty_from_import(), name='') + if self.code.endswith(' '): + raise ParseError('parse_from_import') + name = self.parse_dotted_name() + if '.' in name: + self.tokens.pop_string('from') + return Result(from_name=name) + if self.tokens.peek_string('from'): + return Result(from_name=name) + from_name = self.parse_empty_from_import() + return Result(from_name=from_name, name=name) + + def parse_empty_from_import(self): + if self.tokens.peek_string(','): + self.tokens.pop() + self.parse_as_names() + if self.tokens.peek_string('('): + self.tokens.pop() + self.tokens.pop_string('import') + return self.parse_from() + + def parse_from(self): + from_name = self.parse_dotted_name() + self.tokens.pop_string('from') + return from_name + + def parse_dotted_as_name(self): + self.tokens.pop_name() + if self.tokens.peek_string('as'): + self.tokens.pop() + with self.tokens.save_state(): + return self.parse_dotted_name() + + def parse_dotted_name(self): + name = [] + if self.tokens.peek_string('.'): + name.append('.') + self.tokens.pop() + if self.tokens.peek_name() and self.tokens.peek().string not in self._keywords: + name.append(self.tokens.pop_name()) + if not name: + raise ParseError('parse_dotted_name') + while self.tokens.peek_string('.'): + name.append('.') + self.tokens.pop() + if self.tokens.peek_name() and self.tokens.peek().string not in self._keywords: + name.append(self.tokens.pop_name()) + else: + break + + while self.tokens.peek_string('.'): + name.append('.') + self.tokens.pop() + return ''.join(name[::-1]) + + def parse_as_names(self): + self.parse_as_name() + while self.tokens.peek_string(','): + self.tokens.pop() + self.parse_as_name() + + def parse_as_name(self): + self.tokens.pop_name() + if self.tokens.peek_string('as'): + self.tokens.pop() + self.tokens.pop_name() + + +class ParseError(Exception): + pass + + +@dataclass(frozen=True) +class Result: + from_name: str | None = None + name: str | None = None + + +class TokenQueue: + """Provides helper functions for working with a sequence of tokens.""" + + def __init__(self, tokens: list[TokenInfo]) -> None: + self.tokens: list[TokenInfo] = tokens + self.index: int = 0 + self.stack: list[int] = [] + + @contextmanager + def save_state(self): + try: + self.stack.append(self.index) + yield + except ParseError: + self.index = self.stack.pop() + else: + self.stack.pop() + + def __bool__(self): + return self.index < len(self.tokens) + + def peek(self) -> TokenInfo | None: + if not self: + return None + return self.tokens[self.index] + + def peek_name(self) -> bool: + if not (tok := self.peek()): + return False + return tok.type == tokenize.NAME + + def pop_name(self) -> str: + tok = self.pop() + if tok.type != tokenize.NAME: + raise ParseError('pop_name') + return tok.string + + def peek_string(self, string: str) -> bool: + if not (tok := self.peek()): + return False + return tok.string == string + + def pop_string(self, string: str) -> str: + tok = self.pop() + if tok.string != string: + raise ParseError('pop_string') + return tok.string + + def pop(self) -> TokenInfo: + if not self: + raise ParseError('pop') + tok = self.tokens[self.index] + self.index += 1 + return tok diff --git a/Lib/test/test_pyrepl/test_pyrepl.py b/Lib/test/test_pyrepl/test_pyrepl.py index f29a7ffbd7cafd..1d46aefd43c7bf 100644 --- a/Lib/test/test_pyrepl/test_pyrepl.py +++ b/Lib/test/test_pyrepl/test_pyrepl.py @@ -27,7 +27,7 @@ ) from _pyrepl.console import Event from _pyrepl.readline import (ReadlineAlikeReader, ReadlineConfig, - _ReadlineWrapper) + _ReadlineWrapper, ImportParser, ModuleCompleter) from _pyrepl.readline import multiline_input as readline_multiline_input try: @@ -896,6 +896,198 @@ def test_func(self): self.assertEqual(mock_stderr.getvalue(), "") +class TestPyReplModuleCompleter(TestCase): + def prepare_reader(self, events, namespace): + console = FakeConsole(events) + config = ReadlineConfig() + config.readline_completer = rlcompleter.Completer(namespace).complete + reader = ReadlineAlikeReader(console=console, config=config) + return reader + + def test_import(self): + cases = [ + ("import path\t\n", "import pathlib"), + ("import importlib.\t\tres\t\n", "import importlib.resources"), + ("import importlib.resources.\t\ta\t\n", "import importlib.resources.abc"), + ("import foo, impo\t\n", "import foo, importlib"), + ("import foo as bar, impo\t\n", "import foo as bar, importlib"), + ] + + for code, expected in cases: + with self.subTest(code=code): + events = code_to_events(code) + reader = self.prepare_reader(events, namespace={}) + output = reader.readline() + self.assertEqual(output, expected) + + def test_from_import(self): + cases = [ + ("from impo\t\n", "from importlib"), + ("from importlib.res\t\n", "from importlib.resources"), + ("from importlib.\t\tres\t\n", "from importlib.resources"), + ("from importlib.resources.ab\t\n", "from importlib.resources.abc"), + ] + + for code, expected in cases: + with self.subTest(code=code): + events = code_to_events(code) + reader = self.prepare_reader(events, namespace={}) + output = reader.readline() + self.assertEqual(output, expected) + + def test_from_import_attributes(self): + cases = [ + ("from importlib import mac\t\n", "from importlib import machinery"), + ("from importlib import res\t\n", "from importlib import resources"), + ("from importlib import invalidate_\t\n", "from importlib import invalidate_caches"), + ("from importlib import (inval\t\n", "from importlib import (invalidate_caches"), + ("from importlib import foo, invalidate_\t\n", "from importlib import foo, invalidate_caches"), + ("from importlib import (foo, invalidate_\t\n", "from importlib import (foo, invalidate_caches"), + ("from importlib import foo as bar, invalidate_\t\n", "from importlib import foo as bar, invalidate_caches"), + ("from importlib import (foo as bar, invalidate_\t\n", "from importlib import (foo as bar, invalidate_caches"), + ] + + for code, expected in cases: + with self.subTest(code=code): + events = code_to_events(code) + reader = self.prepare_reader(events, namespace={}) + output = reader.readline() + self.assertEqual(output, expected) + + def test_relative_from_import(self): + cases = [ + ("from .readl\t\n", "from .readline"), + ("from .readline import Mod\t\n", "from .readline import ModuleCompleter"), + ] + + for code, expected in cases: + with self.subTest(code=code): + events = code_to_events(code) + reader = self.prepare_reader(events, namespace={}) + output = reader.readline() + self.assertEqual(output, expected) + + def test_get_path_and_prefix(self): + cases = [ + ('', ('', '')), + ('.', ('.', '')), + ('..', ('..', '')), + ('.foo', ('.', 'foo')), + ('..foo', ('..', 'foo')), + ('..foo.', ('..foo', '')), + ('..foo.bar', ('..foo', 'bar')), + ('.foo.bar.', ('.foo.bar', '')), + ('..foo.bar.', ('..foo.bar', '')), + ('foo', ('', 'foo')), + ('foo.', ('foo', '')), + ('foo.bar', ('foo', 'bar')), + ('foo.bar.', ('foo.bar', '')), + ('foo.bar.baz', ('foo.bar', 'baz')), + ] + + completer = ModuleCompleter() + for name, expected in cases: + with self.subTest(name=name): + self.assertEqual(completer.get_path_and_prefix(name), expected) + + def test_parse(self): + cases = [ + ('import ', (None, '')), + ('import foo', (None, 'foo')), + ('import foo,', (None, '')), + ('import foo, ', (None, '')), + ('import foo, bar', (None, 'bar')), + ('import foo, bar, baz', (None, 'baz')), + ('import foo as bar,', (None, '')), + ('import foo as bar, ', (None, '')), + ('import foo as bar, baz', (None, 'baz')), + ('import a.', (None, 'a.')), + ('import a.b', (None, 'a.b')), + ('import a.b.', (None, 'a.b.')), + ('import a.b.c', (None, 'a.b.c')), + ('import a.b.c, foo', (None, 'foo')), + ('import a.b.c, foo.bar', (None, 'foo.bar')), + ('import a.b.c, foo.bar,', (None, '')), + ('import a.b.c, foo.bar, ', (None, '')), + ('from foo', ('foo', None)), + ('from a.', ('a.', None)), + ('from a.b', ('a.b', None)), + ('from a.b.', ('a.b.', None)), + ('from a.b.c', ('a.b.c', None)), + ('from foo import ', ('foo', '')), + ('from foo import a', ('foo', 'a')), + ('from ', ('', None)), + ('from . import a', ('.', 'a')), + ('from .foo import a', ('.foo', 'a')), + ('from ..foo import a', ('..foo', 'a')), + ('from foo import (', ('foo', '')), + ('from foo import ( ', ('foo', '')), + ('from foo import (a', ('foo', 'a')), + ('from foo import (a,', ('foo', '')), + ('from foo import (a, ', ('foo', '')), + ('from foo import (a, c', ('foo', 'c')), + ('from foo import (a as b, c', ('foo', 'c')), + ] + + for code, parsed in cases: + parser = ImportParser(code) + actual = parser.parse() + with self.subTest(code=code): + self.assertEqual(actual, parsed) + # The parser should not get tripped up by any + # other preceding statements + code = f'import xyz\n{code}' + with self.subTest(code=code): + self.assertEqual(actual, parsed) + + def test_parse_error(self): + cases = [ + '', + 'import foo ', + 'from foo ', + 'import foo. ', + 'import foo.bar ', + 'from foo ', + 'from foo. ', + 'from foo.bar ', + 'from foo import bar ', + 'from foo import (bar ', + 'from foo import bar, baz ', + 'import foo as', + 'import a. as', + 'import a.b as', + 'import a.b. as', + 'import a.b.c as', + 'import (foo', + 'import (', + 'import foo; x = 1', + 'import a.; x = 1', + 'import a.b; x = 1', + 'import a.b.; x = 1', + 'import a.b.c; x = 1', + 'from foo import a as', + 'from foo import a. as', + 'from foo import a.b as', + 'from foo import a.b. as', + 'from foo import a.b.c as', + 'from foo impo', + 'import import', + 'import from', + 'import as', + 'from import', + 'from from', + 'from as', + 'from foo import import', + 'from foo import from', + 'from foo import as', + ] + + for code in cases: + parser = ImportParser(code) + actual = parser.parse() + with self.subTest(code=code): + self.assertEqual(actual, None) + class TestPasteEvent(TestCase): def prepare_reader(self, events): console = FakeConsole(events) From bcd35274bc8765dbaa3a38d76abea491bb6c8175 Mon Sep 17 00:00:00 2001 From: Tomas Roun Date: Sun, 26 Jan 2025 23:46:48 +0100 Subject: [PATCH 02/15] Add news entry --- .../2025-01-26-23-46-43.gh-issue-69605._2Qc1w.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 Misc/NEWS.d/next/Core_and_Builtins/2025-01-26-23-46-43.gh-issue-69605._2Qc1w.rst diff --git a/Misc/NEWS.d/next/Core_and_Builtins/2025-01-26-23-46-43.gh-issue-69605._2Qc1w.rst b/Misc/NEWS.d/next/Core_and_Builtins/2025-01-26-23-46-43.gh-issue-69605._2Qc1w.rst new file mode 100644 index 00000000000000..1f85be8bc61fd0 --- /dev/null +++ b/Misc/NEWS.d/next/Core_and_Builtins/2025-01-26-23-46-43.gh-issue-69605._2Qc1w.rst @@ -0,0 +1 @@ +Add module autocomplete to PyREPL. From 0917d24e9e0eb41ab492bfcb9ba0f05b18a7d93d Mon Sep 17 00:00:00 2001 From: Tomas Roun Date: Mon, 3 Mar 2025 23:39:09 +0100 Subject: [PATCH 03/15] Remove attribute completion, never import modules --- Lib/_pyrepl/readline.py | 132 ++++++++++++++++------------ Lib/test/test_pyrepl/test_pyrepl.py | 66 +++++--------- 2 files changed, 97 insertions(+), 101 deletions(-) diff --git a/Lib/_pyrepl/readline.py b/Lib/_pyrepl/readline.py index e1a2d400494617..9deb03da1a66af 100644 --- a/Lib/_pyrepl/readline.py +++ b/Lib/_pyrepl/readline.py @@ -35,6 +35,7 @@ from io import StringIO from contextlib import contextmanager from dataclasses import dataclass, field +from itertools import chain from tokenize import TokenInfo import os @@ -635,71 +636,69 @@ def __init__(self, namespace: Mapping[str, Any] | None = None): def get_completions(self, line: str) -> list[str]: """Return the next possible import completions for 'line'.""" - - parser = ImportParser(line) - if not (result := parser.parse()): + result = ImportParser(line).parse() + if not result: return [] return self.complete(*result) def complete(self, from_name: str | None, name: str | None) -> list[str]: - # import x.y.z if from_name is None: - if not name: - return [] - return self.complete_import(name) + # import x.y.z + path, prefix = self.get_path_and_prefix(name) + modules = self.find_modules(path, prefix) + return [self.format_completion(path, module) for module in modules] - # from x.y.z if name is None: - if not from_name: - return [] - return self.complete_import(from_name) + # from x.y.z + path, prefix = self.get_path_and_prefix(from_name) + modules = self.find_modules(path, prefix) + return [self.format_completion(path, module) for module in modules] # from x.y import z - if not (module := self.import_module(from_name)): - return [] - - submodules = self.filter_submodules(module, name) - attributes = self.filter_attributes(module, name) - return list(set(submodules + attributes)) - - def complete_import(self, name: str) -> list[str]: - is_relative = name.startswith('.') - path, prefix = self.get_path_and_prefix(name) - - if not is_relative and not path: - return [name for name in self.global_cache if name.startswith(prefix)] - - if not (module := self.import_module(path)): - return [] - - submodules = self.filter_submodules(module, prefix) - if not is_relative: - return [f'{path}.{name}' for name in submodules] - return [f'.{name}' for name in submodules] - - def import_module(self, path: str) -> ModuleType | None: - package = self.namespace.get('__package__') - is_relative = path.startswith('.') - if is_relative and not package: - return None - try: - module = importlib.import_module( - path, - package=package if is_relative else None) - except ImportError: - return None - return module - - def filter_submodules(self, module: ModuleType, prefix: str) -> list[str]: - if not hasattr(module, '__path__'): - return [] - return [name for _, name, _ in pkgutil.iter_modules(module.__path__) - if name.startswith(prefix)] + return self.find_modules(from_name, name) + + def find_modules(self, path: str, prefix: str) -> list[str]: + """Find all modules under 'path' that start with 'prefix'.""" + if not path: + # Top-level import (e.g. `import foo`` or `from foo`)` + return [name for _, name, _ in self.global_cache + if name.startswith(prefix)] + + if path.startswith('.'): + # Convert relative path to absolute path + package = self.namespace.get('__package__') + path = self.resolve_relative_name(path, package) + if path is None: + return [] - def filter_attributes(self, module: ModuleType, prefix: str) -> list[str]: - return [attr for attr in module.__dict__ if attr.startswith(prefix)] + modules = self.global_cache + for segment in path.split('.'): + modules = [mod_info for mod_info in modules + if mod_info.ispkg and mod_info.name == segment] + modules = self.iter_submodules(modules) + return [module.name for module in modules + if module.name.startswith(prefix)] + + def iter_submodules(self, parent_modules): + """Iterate over all submodules of the given parent modules.""" + specs = [info.module_finder.find_spec(info.name) + for info in parent_modules if info.ispkg] + search_locations = set(chain.from_iterable( + getattr(spec, 'submodule_search_locations', []) + for spec in specs if spec + )) + return pkgutil.iter_modules(search_locations) def get_path_and_prefix(self, dotted_name: str) -> tuple[str, str]: + """ + Split a dotted name into an import path and a + final prefix that is to be completed. + + Examples: + 'foo.bar' -> 'foo', 'bar' + 'foo.' -> 'foo', '' + '.foo' -> '.', 'foo' + """ if '.' not in dotted_name: return '', dotted_name if dotted_name.startswith('.'): @@ -712,12 +711,35 @@ def get_path_and_prefix(self, dotted_name: str) -> tuple[str, str]: path, prefix = dotted_name.rsplit('.', 1) return path, prefix + def format_completion(self, path: str, module: str) -> str: + if path == '' or path.endswith('.'): + return f'{path}{module}' + return f'{path}.{module}' + + def resolve_relative_name(self, name, package): + """Resolve a relative module name to an absolute name. + + Example: resolve_relative_name('.foo', 'bar') -> 'bar.foo' + """ + # taken from importlib._bootstrap + level = 0 + for character in name: + if character != '.': + break + level += 1 + bits = package.rsplit('.', level - 1) + if len(bits) < level: + return None + base = bits[0] + name = name[level:] + return f'{base}.{name}' if name else base + @property def global_cache(self) -> list[str]: + """Global module cache""" if not self._global_cache or self._curr_sys_path != sys.path: self._curr_sys_path = sys.path[:] - self._global_cache = [ - name for _, name, _ in pkgutil.iter_modules()] + self._global_cache = list(pkgutil.iter_modules()) return self._global_cache diff --git a/Lib/test/test_pyrepl/test_pyrepl.py b/Lib/test/test_pyrepl/test_pyrepl.py index d4547f9a7e403a..d51f7c747b7689 100644 --- a/Lib/test/test_pyrepl/test_pyrepl.py +++ b/Lib/test/test_pyrepl/test_pyrepl.py @@ -904,49 +904,21 @@ def prepare_reader(self, events, namespace): reader = ReadlineAlikeReader(console=console, config=config) return reader - def test_import(self): - cases = [ + def test_import_completions(self): + cases = ( ("import path\t\n", "import pathlib"), ("import importlib.\t\tres\t\n", "import importlib.resources"), ("import importlib.resources.\t\ta\t\n", "import importlib.resources.abc"), ("import foo, impo\t\n", "import foo, importlib"), ("import foo as bar, impo\t\n", "import foo as bar, importlib"), - ] - - for code, expected in cases: - with self.subTest(code=code): - events = code_to_events(code) - reader = self.prepare_reader(events, namespace={}) - output = reader.readline() - self.assertEqual(output, expected) - - def test_from_import(self): - cases = [ ("from impo\t\n", "from importlib"), ("from importlib.res\t\n", "from importlib.resources"), ("from importlib.\t\tres\t\n", "from importlib.resources"), ("from importlib.resources.ab\t\n", "from importlib.resources.abc"), - ] - - for code, expected in cases: - with self.subTest(code=code): - events = code_to_events(code) - reader = self.prepare_reader(events, namespace={}) - output = reader.readline() - self.assertEqual(output, expected) - - def test_from_import_attributes(self): - cases = [ ("from importlib import mac\t\n", "from importlib import machinery"), ("from importlib import res\t\n", "from importlib import resources"), - ("from importlib import invalidate_\t\n", "from importlib import invalidate_caches"), - ("from importlib import (inval\t\n", "from importlib import (invalidate_caches"), - ("from importlib import foo, invalidate_\t\n", "from importlib import foo, invalidate_caches"), - ("from importlib import (foo, invalidate_\t\n", "from importlib import (foo, invalidate_caches"), - ("from importlib import foo as bar, invalidate_\t\n", "from importlib import foo as bar, invalidate_caches"), - ("from importlib import (foo as bar, invalidate_\t\n", "from importlib import (foo as bar, invalidate_caches"), - ] - + ("from importlib.res\t import a\t\n", "from importlib.resources import abc"), + ) for code, expected in cases: with self.subTest(code=code): events = code_to_events(code) @@ -954,12 +926,11 @@ def test_from_import_attributes(self): output = reader.readline() self.assertEqual(output, expected) - def test_relative_from_import(self): - cases = [ + def test_relative_import_completions(self): + cases = ( ("from .readl\t\n", "from .readline"), - ("from .readline import Mod\t\n", "from .readline import ModuleCompleter"), - ] - + ("from . import readl\t\n", "from . import readline"), + ) for code, expected in cases: with self.subTest(code=code): events = code_to_events(code) @@ -968,7 +939,7 @@ def test_relative_from_import(self): self.assertEqual(output, expected) def test_get_path_and_prefix(self): - cases = [ + cases = ( ('', ('', '')), ('.', ('.', '')), ('..', ('..', '')), @@ -983,15 +954,14 @@ def test_get_path_and_prefix(self): ('foo.bar', ('foo', 'bar')), ('foo.bar.', ('foo.bar', '')), ('foo.bar.baz', ('foo.bar', 'baz')), - ] - + ) completer = ModuleCompleter() for name, expected in cases: with self.subTest(name=name): self.assertEqual(completer.get_path_and_prefix(name), expected) def test_parse(self): - cases = [ + cases = ( ('import ', (None, '')), ('import foo', (None, 'foo')), ('import foo,', (None, '')), @@ -1027,8 +997,7 @@ def test_parse(self): ('from foo import (a, ', ('foo', '')), ('from foo import (a, c', ('foo', 'c')), ('from foo import (a as b, c', ('foo', 'c')), - ] - + ) for code, parsed in cases: parser = ImportParser(code) actual = parser.parse() @@ -1039,9 +1008,12 @@ def test_parse(self): code = f'import xyz\n{code}' with self.subTest(code=code): self.assertEqual(actual, parsed) + code = f'import xyz;{code}' + with self.subTest(code=code): + self.assertEqual(actual, parsed) def test_parse_error(self): - cases = [ + cases = ( '', 'import foo ', 'from foo ', @@ -1060,6 +1032,9 @@ def test_parse_error(self): 'import a.b.c as', 'import (foo', 'import (', + 'import .foo', + 'import ..foo', + 'import .foo.bar', 'import foo; x = 1', 'import a.; x = 1', 'import a.b; x = 1', @@ -1080,8 +1055,7 @@ def test_parse_error(self): 'from foo import import', 'from foo import from', 'from foo import as', - ] - + ) for code in cases: parser = ImportParser(code) actual = parser.parse() From 589cf63aa50376fbb6740523176ac89b72e1d580 Mon Sep 17 00:00:00 2001 From: Tomas Roun Date: Sat, 8 Mar 2025 15:06:55 +0100 Subject: [PATCH 04/15] Add type annotations --- Lib/_pyrepl/completing_reader.py | 2 +- Lib/_pyrepl/readline.py | 36 +++++++++++++++----------------- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/Lib/_pyrepl/completing_reader.py b/Lib/_pyrepl/completing_reader.py index 033de8bd729679..9d2d43be5144e8 100644 --- a/Lib/_pyrepl/completing_reader.py +++ b/Lib/_pyrepl/completing_reader.py @@ -294,6 +294,6 @@ def get_stem(self) -> str: def get_completions(self, stem: str) -> list[str]: return [] - def get_line(self): + def get_line(self) -> str: """Return the current line until the cursor position.""" return ''.join(self.buffer[:self.pos]) diff --git a/Lib/_pyrepl/readline.py b/Lib/_pyrepl/readline.py index 9deb03da1a66af..56fc271b936b2b 100644 --- a/Lib/_pyrepl/readline.py +++ b/Lib/_pyrepl/readline.py @@ -28,7 +28,6 @@ from __future__ import annotations -import importlib import pkgutil import tokenize import warnings @@ -65,8 +64,7 @@ TYPE_CHECKING = False if TYPE_CHECKING: - from typing import Any, Mapping - from types import ModuleType + from typing import Any, Iterator, Mapping MoreLinesCallable = Callable[[str], bool] @@ -629,7 +627,7 @@ class ModuleCompleter: - from foo import (bar as baz, qux """ - def __init__(self, namespace: Mapping[str, Any] | None = None): + def __init__(self, namespace: Mapping[str, Any] | None = None) -> None: self.namespace = namespace or {} self._global_cache: list[str] = [] self._curr_sys_path: list[str] = sys.path[:] @@ -679,7 +677,7 @@ def find_modules(self, path: str, prefix: str) -> list[str]: return [module.name for module in modules if module.name.startswith(prefix)] - def iter_submodules(self, parent_modules): + def iter_submodules(self, parent_modules) -> Iterator[pkgutil.ModuleInfo]: """Iterate over all submodules of the given parent modules.""" specs = [info.module_finder.find_spec(info.name) for info in parent_modules if info.ispkg] @@ -716,7 +714,7 @@ def format_completion(self, path: str, module: str) -> str: return f'{path}{module}' return f'{path}.{module}' - def resolve_relative_name(self, name, package): + def resolve_relative_name(self, name, package) -> str | None: """Resolve a relative module name to an absolute name. Example: resolve_relative_name('.foo', 'bar') -> 'bar.foo' @@ -765,7 +763,7 @@ class ImportParser: } _keywords = {'import', 'from', 'as'} - def __init__(self, code: str): + def __init__(self, code: str) -> None: self.code = code tokens = [] try: @@ -783,18 +781,18 @@ def __init__(self, code: str): tokens = [] self.tokens = TokenQueue(tokens[::-1]) - def parse(self): + def parse(self) -> tuple[str | None, str | None] | None: if not (res := self._parse()): return None return res.from_name, res.name - def _parse(self): + def _parse(self) -> Result | None: with self.tokens.save_state(): return self.parse_from_import() with self.tokens.save_state(): return self.parse_import() - def parse_import(self): + def parse_import(self) -> Result: if self.code.rstrip().endswith('import') and self.code.endswith(' '): return Result(name='') if self.tokens.peek_string(','): @@ -812,7 +810,7 @@ def parse_import(self): return Result(name=name) raise ParseError('parse_import') - def parse_from_import(self): + def parse_from_import(self) -> Result: if self.code.rstrip().endswith('import') and self.code.endswith(' '): return Result(from_name=self.parse_empty_from_import(), name='') if self.code.rstrip().endswith('from') and self.code.endswith(' '): @@ -830,7 +828,7 @@ def parse_from_import(self): from_name = self.parse_empty_from_import() return Result(from_name=from_name, name=name) - def parse_empty_from_import(self): + def parse_empty_from_import(self) -> str: if self.tokens.peek_string(','): self.tokens.pop() self.parse_as_names() @@ -839,19 +837,19 @@ def parse_empty_from_import(self): self.tokens.pop_string('import') return self.parse_from() - def parse_from(self): + def parse_from(self) -> str: from_name = self.parse_dotted_name() self.tokens.pop_string('from') return from_name - def parse_dotted_as_name(self): + def parse_dotted_as_name(self) -> str: self.tokens.pop_name() if self.tokens.peek_string('as'): self.tokens.pop() with self.tokens.save_state(): return self.parse_dotted_name() - def parse_dotted_name(self): + def parse_dotted_name(self) -> str: name = [] if self.tokens.peek_string('.'): name.append('.') @@ -873,13 +871,13 @@ def parse_dotted_name(self): self.tokens.pop() return ''.join(name[::-1]) - def parse_as_names(self): + def parse_as_names(self) -> None: self.parse_as_name() while self.tokens.peek_string(','): self.tokens.pop() self.parse_as_name() - def parse_as_name(self): + def parse_as_name(self) -> None: self.tokens.pop_name() if self.tokens.peek_string('as'): self.tokens.pop() @@ -905,7 +903,7 @@ def __init__(self, tokens: list[TokenInfo]) -> None: self.stack: list[int] = [] @contextmanager - def save_state(self): + def save_state(self) -> Any: try: self.stack.append(self.index) yield @@ -914,7 +912,7 @@ def save_state(self): else: self.stack.pop() - def __bool__(self): + def __bool__(self) -> bool: return self.index < len(self.tokens) def peek(self) -> TokenInfo | None: From 62d0b55e16637c78863bfd9a987257d99299b89e Mon Sep 17 00:00:00 2001 From: Tomas Roun Date: Sat, 8 Mar 2025 15:51:55 +0100 Subject: [PATCH 05/15] fix some mypy issues --- Lib/_pyrepl/readline.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/Lib/_pyrepl/readline.py b/Lib/_pyrepl/readline.py index 56fc271b936b2b..8b6b11393ea8ae 100644 --- a/Lib/_pyrepl/readline.py +++ b/Lib/_pyrepl/readline.py @@ -64,7 +64,7 @@ TYPE_CHECKING = False if TYPE_CHECKING: - from typing import Any, Iterator, Mapping + from typing import Any, Iterable, Iterator, Mapping MoreLinesCallable = Callable[[str], bool] @@ -629,7 +629,7 @@ class ModuleCompleter: def __init__(self, namespace: Mapping[str, Any] | None = None) -> None: self.namespace = namespace or {} - self._global_cache: list[str] = [] + self._global_cache: list[pkgutil.ModuleInfo] = [] self._curr_sys_path: list[str] = sys.path[:] def get_completions(self, line: str) -> list[str]: @@ -642,6 +642,7 @@ def get_completions(self, line: str) -> list[str]: def complete(self, from_name: str | None, name: str | None) -> list[str]: if from_name is None: # import x.y.z + assert name is not None path, prefix = self.get_path_and_prefix(name) modules = self.find_modules(path, prefix) return [self.format_completion(path, module) for module in modules] @@ -664,12 +665,12 @@ def find_modules(self, path: str, prefix: str) -> list[str]: if path.startswith('.'): # Convert relative path to absolute path - package = self.namespace.get('__package__') - path = self.resolve_relative_name(path, package) + package = self.namespace.get('__package__', '') + path = self.resolve_relative_name(path, package) # type: ignore[assignment] if path is None: return [] - modules = self.global_cache + modules: Iterable[pkgutil.ModuleInfo] = self.global_cache for segment in path.split('.'): modules = [mod_info for mod_info in modules if mod_info.ispkg and mod_info.name == segment] @@ -677,7 +678,7 @@ def find_modules(self, path: str, prefix: str) -> list[str]: return [module.name for module in modules if module.name.startswith(prefix)] - def iter_submodules(self, parent_modules) -> Iterator[pkgutil.ModuleInfo]: + def iter_submodules(self, parent_modules: list[pkgutil.ModuleInfo]) -> Iterator[pkgutil.ModuleInfo]: """Iterate over all submodules of the given parent modules.""" specs = [info.module_finder.find_spec(info.name) for info in parent_modules if info.ispkg] @@ -714,7 +715,7 @@ def format_completion(self, path: str, module: str) -> str: return f'{path}{module}' return f'{path}.{module}' - def resolve_relative_name(self, name, package) -> str | None: + def resolve_relative_name(self, name: str, package: str) -> str | None: """Resolve a relative module name to an absolute name. Example: resolve_relative_name('.foo', 'bar') -> 'bar.foo' @@ -733,7 +734,7 @@ def resolve_relative_name(self, name, package) -> str | None: return f'{base}.{name}' if name else base @property - def global_cache(self) -> list[str]: + def global_cache(self) -> list[pkgutil.ModuleInfo]: """Global module cache""" if not self._global_cache or self._curr_sys_path != sys.path: self._curr_sys_path = sys.path[:] @@ -854,14 +855,18 @@ def parse_dotted_name(self) -> str: if self.tokens.peek_string('.'): name.append('.') self.tokens.pop() - if self.tokens.peek_name() and self.tokens.peek().string not in self._keywords: + if (self.tokens.peek_name() + and (tok := self.tokens.peek()) + and tok.string not in self._keywords): name.append(self.tokens.pop_name()) if not name: raise ParseError('parse_dotted_name') while self.tokens.peek_string('.'): name.append('.') self.tokens.pop() - if self.tokens.peek_name() and self.tokens.peek().string not in self._keywords: + if (self.tokens.peek_name() + and (tok := self.tokens.peek()) + and tok.string not in self._keywords): name.append(self.tokens.pop_name()) else: break From 46ca249668de0edcdc22a452e674620ed5ff1859 Mon Sep 17 00:00:00 2001 From: Tomas Roun Date: Sat, 8 Mar 2025 21:27:18 +0100 Subject: [PATCH 06/15] Pass explicit None to find_spec --- Lib/_pyrepl/readline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/_pyrepl/readline.py b/Lib/_pyrepl/readline.py index 8b6b11393ea8ae..0e4fdda85833ab 100644 --- a/Lib/_pyrepl/readline.py +++ b/Lib/_pyrepl/readline.py @@ -680,7 +680,7 @@ def find_modules(self, path: str, prefix: str) -> list[str]: def iter_submodules(self, parent_modules: list[pkgutil.ModuleInfo]) -> Iterator[pkgutil.ModuleInfo]: """Iterate over all submodules of the given parent modules.""" - specs = [info.module_finder.find_spec(info.name) + specs = [info.module_finder.find_spec(info.name, None) for info in parent_modules if info.ispkg] search_locations = set(chain.from_iterable( getattr(spec, 'submodule_search_locations', []) From 3c13f865db618c4118d18a1422a0f3863feff5c5 Mon Sep 17 00:00:00 2001 From: Tomas Roun Date: Thu, 17 Apr 2025 19:17:26 +0200 Subject: [PATCH 07/15] Do not suggest modules which are not legal identifiers --- Lib/_pyrepl/readline.py | 6 ++++++ Lib/test/test_pyrepl/test_pyrepl.py | 17 +++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/Lib/_pyrepl/readline.py b/Lib/_pyrepl/readline.py index 7e65fcd5060f70..27dc54e25c09b6 100644 --- a/Lib/_pyrepl/readline.py +++ b/Lib/_pyrepl/readline.py @@ -658,6 +658,12 @@ def complete(self, from_name: str | None, name: str | None) -> list[str]: def find_modules(self, path: str, prefix: str) -> list[str]: """Find all modules under 'path' that start with 'prefix'.""" + modules = self._find_modules(path, prefix) + # Filter out invalid module names + # (for example those containing dashes that cannot be imported with 'import') + return [mod for mod in modules if mod.isidentifier()] + + def _find_modules(self, path: str, prefix: str) -> list[str]: if not path: # Top-level import (e.g. `import foo`` or `from foo`)` return [name for _, name, _ in self.global_cache diff --git a/Lib/test/test_pyrepl/test_pyrepl.py b/Lib/test/test_pyrepl/test_pyrepl.py index 48049fc10f33fe..0793124a7df300 100644 --- a/Lib/test/test_pyrepl/test_pyrepl.py +++ b/Lib/test/test_pyrepl/test_pyrepl.py @@ -938,6 +938,23 @@ def test_relative_import_completions(self): output = reader.readline() self.assertEqual(output, expected) + @patch("pkgutil.iter_modules", lambda: [(None, 'valid_name', None), + (None, 'invalid-name', None)]) + def test_invalid_identifiers(self): + # Make sure modules which are not valid identifiers + # are not suggested as those cannot be imported via 'import'. + cases = ( + ("import valid\t\n", "import valid_name"), + # 'invalid-name' contains a dash and should not be completed + ("import invalid\t\n", "import invalid"), + ) + for code, expected in cases: + with self.subTest(code=code): + events = code_to_events(code) + reader = self.prepare_reader(events, namespace={}) + output = reader.readline() + self.assertEqual(output, expected) + def test_get_path_and_prefix(self): cases = ( ('', ('', '')), From 8eb656f7a1942c8652338447464e5914be01602d Mon Sep 17 00:00:00 2001 From: Tomas Roun Date: Thu, 17 Apr 2025 19:39:36 +0200 Subject: [PATCH 08/15] Make the tests more robust --- Lib/test/test_pyrepl/test_pyrepl.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/Lib/test/test_pyrepl/test_pyrepl.py b/Lib/test/test_pyrepl/test_pyrepl.py index 0793124a7df300..5b7424580a2cec 100644 --- a/Lib/test/test_pyrepl/test_pyrepl.py +++ b/Lib/test/test_pyrepl/test_pyrepl.py @@ -10,7 +10,7 @@ import tempfile from unittest import TestCase, skipUnless, skipIf from unittest.mock import patch -from test.support import force_not_colorized, make_clean_env +from test.support import force_not_colorized, make_clean_env, patch_list from test.support import SHORT_TIMEOUT from test.support.import_helper import import_module from test.support.os_helper import unlink @@ -904,7 +904,17 @@ def prepare_reader(self, events, namespace): reader = ReadlineAlikeReader(console=console, config=config) return reader + @patch_list(sys.meta_path) def test_import_completions(self): + from importlib.machinery import BuiltinImporter + # Remove all importers except for the builtin one + # to prevent searching anything but the builtin modules. + # This makes the test more reliable in case there are + # other user packages/scripts on PYTHONPATH which can + # intefere with the completions. + sys.meta_path = [finder for finder in sys.meta_path + if isinstance(finder, BuiltinImporter)] + cases = ( ("import path\t\n", "import pathlib"), ("import importlib.\t\tres\t\n", "import importlib.resources"), From 7a2fde06e0b5cb9310b1947e0f449b2e22a0d39d Mon Sep 17 00:00:00 2001 From: Tomas Roun Date: Sat, 19 Apr 2025 07:34:26 +0200 Subject: [PATCH 09/15] Remove todo comment --- Lib/_pyrepl/readline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Lib/_pyrepl/readline.py b/Lib/_pyrepl/readline.py index 27dc54e25c09b6..282b66dadbfd05 100644 --- a/Lib/_pyrepl/readline.py +++ b/Lib/_pyrepl/readline.py @@ -170,7 +170,8 @@ def get_completions(self, stem: str) -> list[str]: return result def get_module_completions(self) -> list[str]: - completer = ModuleCompleter(namespace={'__package__': '_pyrepl'}) # TODO: namespace? + # Inside pyrepl, __package__ is set to '_pyrepl' + completer = ModuleCompleter(namespace={'__package__': '_pyrepl'}) line = self.get_line() return completer.get_completions(line) From 5c11124c39bbac4e9ec1a1de4b29b7ce445abe5c Mon Sep 17 00:00:00 2001 From: Tomas Roun Date: Sat, 19 Apr 2025 07:55:15 +0200 Subject: [PATCH 10/15] Move to a separate file and cache ModuleCompleter --- Lib/_pyrepl/_module_completer.py | 366 +++++++++++++++++++++++++++++++ Lib/_pyrepl/readline.py | 366 +------------------------------ 2 files changed, 371 insertions(+), 361 deletions(-) create mode 100644 Lib/_pyrepl/_module_completer.py diff --git a/Lib/_pyrepl/_module_completer.py b/Lib/_pyrepl/_module_completer.py new file mode 100644 index 00000000000000..d480085a3339cc --- /dev/null +++ b/Lib/_pyrepl/_module_completer.py @@ -0,0 +1,366 @@ +from __future__ import annotations + +import pkgutil +import sys +import tokenize +from io import StringIO +from contextlib import contextmanager +from dataclasses import dataclass +from itertools import chain +from tokenize import TokenInfo + +TYPE_CHECKING = False + +if TYPE_CHECKING: + from typing import Any, Iterable, Iterator, Mapping + + +class ModuleCompleter: + """A completer for Python import statements. + + Examples: + - import + - import foo + - import foo. + - import foo as bar, baz + + - from + - from foo + - from foo import + - from foo import bar + - from foo import (bar as baz, qux + """ + + def __init__(self, namespace: Mapping[str, Any] | None = None) -> None: + self.namespace = namespace or {} + self._global_cache: list[pkgutil.ModuleInfo] = [] + self._curr_sys_path: list[str] = sys.path[:] + + def get_completions(self, line: str) -> list[str]: + """Return the next possible import completions for 'line'.""" + result = ImportParser(line).parse() + if not result: + return [] + return self.complete(*result) + + def complete(self, from_name: str | None, name: str | None) -> list[str]: + if from_name is None: + # import x.y.z + assert name is not None + path, prefix = self.get_path_and_prefix(name) + modules = self.find_modules(path, prefix) + return [self.format_completion(path, module) for module in modules] + + if name is None: + # from x.y.z + path, prefix = self.get_path_and_prefix(from_name) + modules = self.find_modules(path, prefix) + return [self.format_completion(path, module) for module in modules] + + # from x.y import z + return self.find_modules(from_name, name) + + def find_modules(self, path: str, prefix: str) -> list[str]: + """Find all modules under 'path' that start with 'prefix'.""" + modules = self._find_modules(path, prefix) + # Filter out invalid module names + # (for example those containing dashes that cannot be imported with 'import') + return [mod for mod in modules if mod.isidentifier()] + + def _find_modules(self, path: str, prefix: str) -> list[str]: + if not path: + # Top-level import (e.g. `import foo`` or `from foo`)` + return [name for _, name, _ in self.global_cache + if name.startswith(prefix)] + + if path.startswith('.'): + # Convert relative path to absolute path + package = self.namespace.get('__package__', '') + path = self.resolve_relative_name(path, package) # type: ignore[assignment] + if path is None: + return [] + + modules: Iterable[pkgutil.ModuleInfo] = self.global_cache + for segment in path.split('.'): + modules = [mod_info for mod_info in modules + if mod_info.ispkg and mod_info.name == segment] + modules = self.iter_submodules(modules) + return [module.name for module in modules + if module.name.startswith(prefix)] + + def iter_submodules(self, parent_modules: list[pkgutil.ModuleInfo]) -> Iterator[pkgutil.ModuleInfo]: + """Iterate over all submodules of the given parent modules.""" + specs = [info.module_finder.find_spec(info.name, None) + for info in parent_modules if info.ispkg] + search_locations = set(chain.from_iterable( + getattr(spec, 'submodule_search_locations', []) + for spec in specs if spec + )) + return pkgutil.iter_modules(search_locations) + + def get_path_and_prefix(self, dotted_name: str) -> tuple[str, str]: + """ + Split a dotted name into an import path and a + final prefix that is to be completed. + + Examples: + 'foo.bar' -> 'foo', 'bar' + 'foo.' -> 'foo', '' + '.foo' -> '.', 'foo' + """ + if '.' not in dotted_name: + return '', dotted_name + if dotted_name.startswith('.'): + stripped = dotted_name.lstrip('.') + dots = '.' * (len(dotted_name) - len(stripped)) + if '.' not in stripped: + return dots, stripped + path, prefix = stripped.rsplit('.', 1) + return dots + path, prefix + path, prefix = dotted_name.rsplit('.', 1) + return path, prefix + + def format_completion(self, path: str, module: str) -> str: + if path == '' or path.endswith('.'): + return f'{path}{module}' + return f'{path}.{module}' + + def resolve_relative_name(self, name: str, package: str) -> str | None: + """Resolve a relative module name to an absolute name. + + Example: resolve_relative_name('.foo', 'bar') -> 'bar.foo' + """ + # taken from importlib._bootstrap + level = 0 + for character in name: + if character != '.': + break + level += 1 + bits = package.rsplit('.', level - 1) + if len(bits) < level: + return None + base = bits[0] + name = name[level:] + return f'{base}.{name}' if name else base + + @property + def global_cache(self) -> list[pkgutil.ModuleInfo]: + """Global module cache""" + if not self._global_cache or self._curr_sys_path != sys.path: + self._curr_sys_path = sys.path[:] + # print('getting packages') + self._global_cache = list(pkgutil.iter_modules()) + return self._global_cache + + +class ImportParser: + """ + Parses incomplete import statements that are + suitable for autocomplete suggestions. + + Examples: + - import foo -> Result(from_name=None, name='foo') + - import foo. -> Result(from_name=None, name='foo.') + - from foo -> Result(from_name='foo', name=None) + - from foo import bar -> Result(from_name='foo', name='bar') + - from .foo import ( -> Result(from_name='.foo', name='') + + Note that the parser works in reverse order, starting from the + last token in the input string. This makes the parser more robust + when parsing multiple statements. + """ + _ignored_tokens = { + tokenize.INDENT, tokenize.DEDENT, tokenize.COMMENT, + tokenize.NL, tokenize.NEWLINE, tokenize.ENDMARKER + } + _keywords = {'import', 'from', 'as'} + + def __init__(self, code: str) -> None: + self.code = code + tokens = [] + try: + for t in tokenize.generate_tokens(StringIO(code).readline): + if t.type not in self._ignored_tokens: + tokens.append(t) + except tokenize.TokenError as e: + if 'unexpected EOF' not in str(e): + # unexpected EOF is fine, since we're parsing an + # incomplete statement, but other errors are not + # because we may not have all the tokens so it's + # safer to bail out + tokens = [] + except SyntaxError: + tokens = [] + self.tokens = TokenQueue(tokens[::-1]) + + def parse(self) -> tuple[str | None, str | None] | None: + if not (res := self._parse()): + return None + return res.from_name, res.name + + def _parse(self) -> Result | None: + with self.tokens.save_state(): + return self.parse_from_import() + with self.tokens.save_state(): + return self.parse_import() + + def parse_import(self) -> Result: + if self.code.rstrip().endswith('import') and self.code.endswith(' '): + return Result(name='') + if self.tokens.peek_string(','): + name = '' + else: + if self.code.endswith(' '): + raise ParseError('parse_import') + name = self.parse_dotted_name() + if name.startswith('.'): + raise ParseError('parse_import') + while self.tokens.peek_string(','): + self.tokens.pop() + self.parse_dotted_as_name() + if self.tokens.peek_string('import'): + return Result(name=name) + raise ParseError('parse_import') + + def parse_from_import(self) -> Result: + if self.code.rstrip().endswith('import') and self.code.endswith(' '): + return Result(from_name=self.parse_empty_from_import(), name='') + if self.code.rstrip().endswith('from') and self.code.endswith(' '): + return Result(from_name='') + if self.tokens.peek_string('(') or self.tokens.peek_string(','): + return Result(from_name=self.parse_empty_from_import(), name='') + if self.code.endswith(' '): + raise ParseError('parse_from_import') + name = self.parse_dotted_name() + if '.' in name: + self.tokens.pop_string('from') + return Result(from_name=name) + if self.tokens.peek_string('from'): + return Result(from_name=name) + from_name = self.parse_empty_from_import() + return Result(from_name=from_name, name=name) + + def parse_empty_from_import(self) -> str: + if self.tokens.peek_string(','): + self.tokens.pop() + self.parse_as_names() + if self.tokens.peek_string('('): + self.tokens.pop() + self.tokens.pop_string('import') + return self.parse_from() + + def parse_from(self) -> str: + from_name = self.parse_dotted_name() + self.tokens.pop_string('from') + return from_name + + def parse_dotted_as_name(self) -> str: + self.tokens.pop_name() + if self.tokens.peek_string('as'): + self.tokens.pop() + with self.tokens.save_state(): + return self.parse_dotted_name() + + def parse_dotted_name(self) -> str: + name = [] + if self.tokens.peek_string('.'): + name.append('.') + self.tokens.pop() + if (self.tokens.peek_name() + and (tok := self.tokens.peek()) + and tok.string not in self._keywords): + name.append(self.tokens.pop_name()) + if not name: + raise ParseError('parse_dotted_name') + while self.tokens.peek_string('.'): + name.append('.') + self.tokens.pop() + if (self.tokens.peek_name() + and (tok := self.tokens.peek()) + and tok.string not in self._keywords): + name.append(self.tokens.pop_name()) + else: + break + + while self.tokens.peek_string('.'): + name.append('.') + self.tokens.pop() + return ''.join(name[::-1]) + + def parse_as_names(self) -> None: + self.parse_as_name() + while self.tokens.peek_string(','): + self.tokens.pop() + self.parse_as_name() + + def parse_as_name(self) -> None: + self.tokens.pop_name() + if self.tokens.peek_string('as'): + self.tokens.pop() + self.tokens.pop_name() + + +class ParseError(Exception): + pass + + +@dataclass(frozen=True) +class Result: + from_name: str | None = None + name: str | None = None + + +class TokenQueue: + """Provides helper functions for working with a sequence of tokens.""" + + def __init__(self, tokens: list[TokenInfo]) -> None: + self.tokens: list[TokenInfo] = tokens + self.index: int = 0 + self.stack: list[int] = [] + + @contextmanager + def save_state(self) -> Any: + try: + self.stack.append(self.index) + yield + except ParseError: + self.index = self.stack.pop() + else: + self.stack.pop() + + def __bool__(self) -> bool: + return self.index < len(self.tokens) + + def peek(self) -> TokenInfo | None: + if not self: + return None + return self.tokens[self.index] + + def peek_name(self) -> bool: + if not (tok := self.peek()): + return False + return tok.type == tokenize.NAME + + def pop_name(self) -> str: + tok = self.pop() + if tok.type != tokenize.NAME: + raise ParseError('pop_name') + return tok.string + + def peek_string(self, string: str) -> bool: + if not (tok := self.peek()): + return False + return tok.string == string + + def pop_string(self, string: str) -> str: + tok = self.pop() + if tok.string != string: + raise ParseError('pop_string') + return tok.string + + def pop(self) -> TokenInfo: + if not self: + raise ParseError('pop') + tok = self.tokens[self.index] + self.index += 1 + return tok diff --git a/Lib/_pyrepl/readline.py b/Lib/_pyrepl/readline.py index 282b66dadbfd05..71c3e441d9d261 100644 --- a/Lib/_pyrepl/readline.py +++ b/Lib/_pyrepl/readline.py @@ -28,14 +28,8 @@ from __future__ import annotations -import pkgutil -import tokenize import warnings -from io import StringIO -from contextlib import contextmanager from dataclasses import dataclass, field -from itertools import chain -from tokenize import TokenInfo import os from site import gethistoryfile @@ -45,6 +39,7 @@ from . import commands, historical_reader from .completing_reader import CompletingReader from .console import Console as ConsoleType +from ._module_completer import ModuleCompleter Console: type[ConsoleType] _error: tuple[type[Exception], ...] | type[Exception] @@ -64,7 +59,7 @@ TYPE_CHECKING = False if TYPE_CHECKING: - from typing import Any, Iterable, Iterator, Mapping + from typing import Any, Mapping MoreLinesCallable = Callable[[str], bool] @@ -105,7 +100,8 @@ class ReadlineConfig: readline_completer: Completer | None = None completer_delims: frozenset[str] = frozenset(" \t\n`~!@#$%^&*()-=+[{]}\\|;:'\",<>/?") - + # Inside pyrepl, __package__ is set to '_pyrepl' + module_completer: ModuleCompleter = ModuleCompleter(namespace={'__package__': '_pyrepl'}) @dataclass(kw_only=True) class ReadlineAlikeReader(historical_reader.HistoricalReader, CompletingReader): @@ -170,10 +166,8 @@ def get_completions(self, stem: str) -> list[str]: return result def get_module_completions(self) -> list[str]: - # Inside pyrepl, __package__ is set to '_pyrepl' - completer = ModuleCompleter(namespace={'__package__': '_pyrepl'}) line = self.get_line() - return completer.get_completions(line) + return self.config.module_completer.get_completions(line) def get_trimmed_history(self, maxlength: int) -> list[str]: if maxlength >= 0: @@ -610,353 +604,3 @@ def _setup(namespace: Mapping[str, Any]) -> None: raw_input: Callable[[object], str] | None = None - - -class ModuleCompleter: - """A completer for Python import statements. - - Examples: - - import - - import foo - - import foo. - - import foo as bar, baz - - - from - - from foo - - from foo import - - from foo import bar - - from foo import (bar as baz, qux - """ - - def __init__(self, namespace: Mapping[str, Any] | None = None) -> None: - self.namespace = namespace or {} - self._global_cache: list[pkgutil.ModuleInfo] = [] - self._curr_sys_path: list[str] = sys.path[:] - - def get_completions(self, line: str) -> list[str]: - """Return the next possible import completions for 'line'.""" - result = ImportParser(line).parse() - if not result: - return [] - return self.complete(*result) - - def complete(self, from_name: str | None, name: str | None) -> list[str]: - if from_name is None: - # import x.y.z - assert name is not None - path, prefix = self.get_path_and_prefix(name) - modules = self.find_modules(path, prefix) - return [self.format_completion(path, module) for module in modules] - - if name is None: - # from x.y.z - path, prefix = self.get_path_and_prefix(from_name) - modules = self.find_modules(path, prefix) - return [self.format_completion(path, module) for module in modules] - - # from x.y import z - return self.find_modules(from_name, name) - - def find_modules(self, path: str, prefix: str) -> list[str]: - """Find all modules under 'path' that start with 'prefix'.""" - modules = self._find_modules(path, prefix) - # Filter out invalid module names - # (for example those containing dashes that cannot be imported with 'import') - return [mod for mod in modules if mod.isidentifier()] - - def _find_modules(self, path: str, prefix: str) -> list[str]: - if not path: - # Top-level import (e.g. `import foo`` or `from foo`)` - return [name for _, name, _ in self.global_cache - if name.startswith(prefix)] - - if path.startswith('.'): - # Convert relative path to absolute path - package = self.namespace.get('__package__', '') - path = self.resolve_relative_name(path, package) # type: ignore[assignment] - if path is None: - return [] - - modules: Iterable[pkgutil.ModuleInfo] = self.global_cache - for segment in path.split('.'): - modules = [mod_info for mod_info in modules - if mod_info.ispkg and mod_info.name == segment] - modules = self.iter_submodules(modules) - return [module.name for module in modules - if module.name.startswith(prefix)] - - def iter_submodules(self, parent_modules: list[pkgutil.ModuleInfo]) -> Iterator[pkgutil.ModuleInfo]: - """Iterate over all submodules of the given parent modules.""" - specs = [info.module_finder.find_spec(info.name, None) - for info in parent_modules if info.ispkg] - search_locations = set(chain.from_iterable( - getattr(spec, 'submodule_search_locations', []) - for spec in specs if spec - )) - return pkgutil.iter_modules(search_locations) - - def get_path_and_prefix(self, dotted_name: str) -> tuple[str, str]: - """ - Split a dotted name into an import path and a - final prefix that is to be completed. - - Examples: - 'foo.bar' -> 'foo', 'bar' - 'foo.' -> 'foo', '' - '.foo' -> '.', 'foo' - """ - if '.' not in dotted_name: - return '', dotted_name - if dotted_name.startswith('.'): - stripped = dotted_name.lstrip('.') - dots = '.' * (len(dotted_name) - len(stripped)) - if '.' not in stripped: - return dots, stripped - path, prefix = stripped.rsplit('.', 1) - return dots + path, prefix - path, prefix = dotted_name.rsplit('.', 1) - return path, prefix - - def format_completion(self, path: str, module: str) -> str: - if path == '' or path.endswith('.'): - return f'{path}{module}' - return f'{path}.{module}' - - def resolve_relative_name(self, name: str, package: str) -> str | None: - """Resolve a relative module name to an absolute name. - - Example: resolve_relative_name('.foo', 'bar') -> 'bar.foo' - """ - # taken from importlib._bootstrap - level = 0 - for character in name: - if character != '.': - break - level += 1 - bits = package.rsplit('.', level - 1) - if len(bits) < level: - return None - base = bits[0] - name = name[level:] - return f'{base}.{name}' if name else base - - @property - def global_cache(self) -> list[pkgutil.ModuleInfo]: - """Global module cache""" - if not self._global_cache or self._curr_sys_path != sys.path: - self._curr_sys_path = sys.path[:] - self._global_cache = list(pkgutil.iter_modules()) - return self._global_cache - - -class ImportParser: - """ - Parses incomplete import statements that are - suitable for autocomplete suggestions. - - Examples: - - import foo -> Result(from_name=None, name='foo') - - import foo. -> Result(from_name=None, name='foo.') - - from foo -> Result(from_name='foo', name=None) - - from foo import bar -> Result(from_name='foo', name='bar') - - from .foo import ( -> Result(from_name='.foo', name='') - - Note that the parser works in reverse order, starting from the - last token in the input string. This makes the parser more robust - when parsing multiple statements. - """ - _ignored_tokens = { - tokenize.INDENT, tokenize.DEDENT, tokenize.COMMENT, - tokenize.NL, tokenize.NEWLINE, tokenize.ENDMARKER - } - _keywords = {'import', 'from', 'as'} - - def __init__(self, code: str) -> None: - self.code = code - tokens = [] - try: - for t in tokenize.generate_tokens(StringIO(code).readline): - if t.type not in self._ignored_tokens: - tokens.append(t) - except tokenize.TokenError as e: - if 'unexpected EOF' not in str(e): - # unexpected EOF is fine, since we're parsing an - # incomplete statement, but other errors are not - # because we may not have all the tokens so it's - # safer to bail out - tokens = [] - except SyntaxError: - tokens = [] - self.tokens = TokenQueue(tokens[::-1]) - - def parse(self) -> tuple[str | None, str | None] | None: - if not (res := self._parse()): - return None - return res.from_name, res.name - - def _parse(self) -> Result | None: - with self.tokens.save_state(): - return self.parse_from_import() - with self.tokens.save_state(): - return self.parse_import() - - def parse_import(self) -> Result: - if self.code.rstrip().endswith('import') and self.code.endswith(' '): - return Result(name='') - if self.tokens.peek_string(','): - name = '' - else: - if self.code.endswith(' '): - raise ParseError('parse_import') - name = self.parse_dotted_name() - if name.startswith('.'): - raise ParseError('parse_import') - while self.tokens.peek_string(','): - self.tokens.pop() - self.parse_dotted_as_name() - if self.tokens.peek_string('import'): - return Result(name=name) - raise ParseError('parse_import') - - def parse_from_import(self) -> Result: - if self.code.rstrip().endswith('import') and self.code.endswith(' '): - return Result(from_name=self.parse_empty_from_import(), name='') - if self.code.rstrip().endswith('from') and self.code.endswith(' '): - return Result(from_name='') - if self.tokens.peek_string('(') or self.tokens.peek_string(','): - return Result(from_name=self.parse_empty_from_import(), name='') - if self.code.endswith(' '): - raise ParseError('parse_from_import') - name = self.parse_dotted_name() - if '.' in name: - self.tokens.pop_string('from') - return Result(from_name=name) - if self.tokens.peek_string('from'): - return Result(from_name=name) - from_name = self.parse_empty_from_import() - return Result(from_name=from_name, name=name) - - def parse_empty_from_import(self) -> str: - if self.tokens.peek_string(','): - self.tokens.pop() - self.parse_as_names() - if self.tokens.peek_string('('): - self.tokens.pop() - self.tokens.pop_string('import') - return self.parse_from() - - def parse_from(self) -> str: - from_name = self.parse_dotted_name() - self.tokens.pop_string('from') - return from_name - - def parse_dotted_as_name(self) -> str: - self.tokens.pop_name() - if self.tokens.peek_string('as'): - self.tokens.pop() - with self.tokens.save_state(): - return self.parse_dotted_name() - - def parse_dotted_name(self) -> str: - name = [] - if self.tokens.peek_string('.'): - name.append('.') - self.tokens.pop() - if (self.tokens.peek_name() - and (tok := self.tokens.peek()) - and tok.string not in self._keywords): - name.append(self.tokens.pop_name()) - if not name: - raise ParseError('parse_dotted_name') - while self.tokens.peek_string('.'): - name.append('.') - self.tokens.pop() - if (self.tokens.peek_name() - and (tok := self.tokens.peek()) - and tok.string not in self._keywords): - name.append(self.tokens.pop_name()) - else: - break - - while self.tokens.peek_string('.'): - name.append('.') - self.tokens.pop() - return ''.join(name[::-1]) - - def parse_as_names(self) -> None: - self.parse_as_name() - while self.tokens.peek_string(','): - self.tokens.pop() - self.parse_as_name() - - def parse_as_name(self) -> None: - self.tokens.pop_name() - if self.tokens.peek_string('as'): - self.tokens.pop() - self.tokens.pop_name() - - -class ParseError(Exception): - pass - - -@dataclass(frozen=True) -class Result: - from_name: str | None = None - name: str | None = None - - -class TokenQueue: - """Provides helper functions for working with a sequence of tokens.""" - - def __init__(self, tokens: list[TokenInfo]) -> None: - self.tokens: list[TokenInfo] = tokens - self.index: int = 0 - self.stack: list[int] = [] - - @contextmanager - def save_state(self) -> Any: - try: - self.stack.append(self.index) - yield - except ParseError: - self.index = self.stack.pop() - else: - self.stack.pop() - - def __bool__(self) -> bool: - return self.index < len(self.tokens) - - def peek(self) -> TokenInfo | None: - if not self: - return None - return self.tokens[self.index] - - def peek_name(self) -> bool: - if not (tok := self.peek()): - return False - return tok.type == tokenize.NAME - - def pop_name(self) -> str: - tok = self.pop() - if tok.type != tokenize.NAME: - raise ParseError('pop_name') - return tok.string - - def peek_string(self, string: str) -> bool: - if not (tok := self.peek()): - return False - return tok.string == string - - def pop_string(self, string: str) -> str: - tok = self.pop() - if tok.string != string: - raise ParseError('pop_string') - return tok.string - - def pop(self) -> TokenInfo: - if not self: - raise ParseError('pop') - tok = self.tokens[self.index] - self.index += 1 - return tok From 10da15b3b1509333a296663ac54a7a8822945664 Mon Sep 17 00:00:00 2001 From: Tomas Roun Date: Sat, 19 Apr 2025 07:58:33 +0200 Subject: [PATCH 11/15] Avoid calling rstrip more than once --- Lib/_pyrepl/_module_completer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Lib/_pyrepl/_module_completer.py b/Lib/_pyrepl/_module_completer.py index d480085a3339cc..1e4337c9f0fbcb 100644 --- a/Lib/_pyrepl/_module_completer.py +++ b/Lib/_pyrepl/_module_completer.py @@ -223,9 +223,10 @@ def parse_import(self) -> Result: raise ParseError('parse_import') def parse_from_import(self) -> Result: - if self.code.rstrip().endswith('import') and self.code.endswith(' '): + stripped = self.code.rstrip() + if stripped.endswith('import') and self.code.endswith(' '): return Result(from_name=self.parse_empty_from_import(), name='') - if self.code.rstrip().endswith('from') and self.code.endswith(' '): + if stripped.endswith('from') and self.code.endswith(' '): return Result(from_name='') if self.tokens.peek_string('(') or self.tokens.peek_string(','): return Result(from_name=self.parse_empty_from_import(), name='') From fd81999df165f15a901cccef6e6b0ca35b4b2324 Mon Sep 17 00:00:00 2001 From: Tomas Roun Date: Sat, 19 Apr 2025 09:01:18 +0200 Subject: [PATCH 12/15] Catch exceptions --- Lib/_pyrepl/_module_completer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/Lib/_pyrepl/_module_completer.py b/Lib/_pyrepl/_module_completer.py index 1e4337c9f0fbcb..0af335ba7f4ff9 100644 --- a/Lib/_pyrepl/_module_completer.py +++ b/Lib/_pyrepl/_module_completer.py @@ -41,7 +41,12 @@ def get_completions(self, line: str) -> list[str]: result = ImportParser(line).parse() if not result: return [] - return self.complete(*result) + try: + return self.complete(*result) + except Exception: + # Some unexpected error occurred, make it look like + # no completions are available + return [] def complete(self, from_name: str | None, name: str | None) -> list[str]: if from_name is None: From 8fba3d3d57158b095eaa78472c2f5cfe2bb6c3e0 Mon Sep 17 00:00:00 2001 From: Tomas Roun Date: Sat, 19 Apr 2025 10:08:41 +0200 Subject: [PATCH 13/15] Fix tests --- Lib/test/test_pyrepl/test_pyrepl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Lib/test/test_pyrepl/test_pyrepl.py b/Lib/test/test_pyrepl/test_pyrepl.py index 5b7424580a2cec..376c5759a0d6e9 100644 --- a/Lib/test/test_pyrepl/test_pyrepl.py +++ b/Lib/test/test_pyrepl/test_pyrepl.py @@ -25,8 +25,9 @@ code_to_events, ) from _pyrepl.console import Event +from _pyrepl._module_completer import ImportParser, ModuleCompleter from _pyrepl.readline import (ReadlineAlikeReader, ReadlineConfig, - _ReadlineWrapper, ImportParser, ModuleCompleter) + _ReadlineWrapper) from _pyrepl.readline import multiline_input as readline_multiline_input try: From f4e290a03f1f8e42f47175e372248c62b14b726d Mon Sep 17 00:00:00 2001 From: Tomas Roun Date: Sat, 19 Apr 2025 10:30:36 +0200 Subject: [PATCH 14/15] Every Reader has its own ModuleCompleter instance --- Lib/_pyrepl/_module_completer.py | 5 +++++ Lib/_pyrepl/readline.py | 5 ++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/Lib/_pyrepl/_module_completer.py b/Lib/_pyrepl/_module_completer.py index 0af335ba7f4ff9..1fb043e0b70479 100644 --- a/Lib/_pyrepl/_module_completer.py +++ b/Lib/_pyrepl/_module_completer.py @@ -15,6 +15,11 @@ from typing import Any, Iterable, Iterator, Mapping +def make_default_module_completer() -> ModuleCompleter: + # Inside pyrepl, __package__ is set to '_pyrepl' + return ModuleCompleter(namespace={'__package__': '_pyrepl'}) + + class ModuleCompleter: """A completer for Python import statements. diff --git a/Lib/_pyrepl/readline.py b/Lib/_pyrepl/readline.py index 71c3e441d9d261..27037f730c200a 100644 --- a/Lib/_pyrepl/readline.py +++ b/Lib/_pyrepl/readline.py @@ -39,7 +39,7 @@ from . import commands, historical_reader from .completing_reader import CompletingReader from .console import Console as ConsoleType -from ._module_completer import ModuleCompleter +from ._module_completer import ModuleCompleter, make_default_module_completer Console: type[ConsoleType] _error: tuple[type[Exception], ...] | type[Exception] @@ -100,8 +100,7 @@ class ReadlineConfig: readline_completer: Completer | None = None completer_delims: frozenset[str] = frozenset(" \t\n`~!@#$%^&*()-=+[{]}\\|;:'\",<>/?") - # Inside pyrepl, __package__ is set to '_pyrepl' - module_completer: ModuleCompleter = ModuleCompleter(namespace={'__package__': '_pyrepl'}) + module_completer: ModuleCompleter = field(default_factory=make_default_module_completer) @dataclass(kw_only=True) class ReadlineAlikeReader(historical_reader.HistoricalReader, CompletingReader): From 602121da0ff3bcbe313e7be0a2ff4dac3997083d Mon Sep 17 00:00:00 2001 From: Tomas Roun Date: Wed, 23 Apr 2025 21:43:18 +0200 Subject: [PATCH 15/15] tests: Only look for modules in the stdlib --- Lib/test/test_pyrepl/test_pyrepl.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/Lib/test/test_pyrepl/test_pyrepl.py b/Lib/test/test_pyrepl/test_pyrepl.py index 376c5759a0d6e9..5600dcb8a82bc0 100644 --- a/Lib/test/test_pyrepl/test_pyrepl.py +++ b/Lib/test/test_pyrepl/test_pyrepl.py @@ -10,7 +10,7 @@ import tempfile from unittest import TestCase, skipUnless, skipIf from unittest.mock import patch -from test.support import force_not_colorized, make_clean_env, patch_list +from test.support import force_not_colorized, make_clean_env from test.support import SHORT_TIMEOUT from test.support.import_helper import import_module from test.support.os_helper import unlink @@ -898,6 +898,12 @@ def test_func(self): class TestPyReplModuleCompleter(TestCase): + def setUp(self): + self._saved_sys_path = sys.path + + def tearDown(self): + sys.path = self._saved_sys_path + def prepare_reader(self, events, namespace): console = FakeConsole(events) config = ReadlineConfig() @@ -905,16 +911,14 @@ def prepare_reader(self, events, namespace): reader = ReadlineAlikeReader(console=console, config=config) return reader - @patch_list(sys.meta_path) def test_import_completions(self): - from importlib.machinery import BuiltinImporter - # Remove all importers except for the builtin one - # to prevent searching anything but the builtin modules. + import importlib + # Make iter_modules() search only the standard library. # This makes the test more reliable in case there are # other user packages/scripts on PYTHONPATH which can # intefere with the completions. - sys.meta_path = [finder for finder in sys.meta_path - if isinstance(finder, BuiltinImporter)] + lib_path = os.path.dirname(importlib.__path__[0]) + sys.path = [lib_path] cases = ( ("import path\t\n", "import pathlib"),