From 293fb0aad9504ddccc68423a5648f5d456ee1068 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sun, 22 Sep 2024 14:25:04 -0700 Subject: [PATCH 1/8] gh-119180: Fix bug where fwdrefs were evaluated in the annotationlib module scope We were sometimes passing None as the globals argument to eval(), which makes it inherit the globals from the calling scope. Instead, ensure that globals is always non-None. The test was passing accidentally because I passed "annotationlib" as a module object; fix that. Also document the parameters to ForwardRef() and remove two unused private ones. --- Lib/annotationlib.py | 32 +++++++++++++++++++------------- Lib/test/test_annotationlib.py | 6 ++++-- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py index 9d1943b27e8e9c..ffd47497056289 100644 --- a/Lib/annotationlib.py +++ b/Lib/annotationlib.py @@ -45,7 +45,17 @@ class Format(enum.IntEnum): class ForwardRef: - """Wrapper that holds a forward reference.""" + """Wrapper that holds a forward reference. + + Constructor arguments: + - arg: a string representing the code to be evaluated. + - module: the module where the forward reference was created. Must be a string, + not a module object. + - owner: The owning object (module, class, or function). + - is_argument: Does nothing, retained for compatibility. + - is_class: True if the forward reference was created in class scope. 
+ + """ __slots__ = _SLOTS @@ -57,8 +67,6 @@ def __init__( owner=None, is_argument=True, is_class=False, - _globals=None, - _cell=None, ): if not isinstance(arg, str): raise TypeError(f"Forward reference must be a string -- got {arg!r}") @@ -71,8 +79,8 @@ def __init__( self.__forward_module__ = module self.__code__ = None self.__ast_node__ = None - self.__globals__ = _globals - self.__cell__ = _cell + self.__globals__ = None + self.__cell__ = None self.__owner__ = owner def __init_subclass__(cls, /, *args, **kwds): @@ -115,6 +123,10 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None): elif callable(owner): globals = getattr(owner, "__globals__", None) + # If we pass None to eval() below, the globals of this module are used. + if globals is None: + globals = {} + if locals is None: locals = {} if isinstance(owner, type): @@ -134,14 +146,8 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None): # but should in turn be overridden by names in the class scope # (which here are called `globalns`!) 
if type_params is not None: - if globals is None: - globals = {} - else: - globals = dict(globals) - if locals is None: - locals = {} - else: - locals = dict(locals) + globals = dict(globals) + locals = dict(locals) for param in type_params: param_name = param.__name__ if not self.__forward_is_class__ or param_name not in globals: diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py index ce4f92624d9036..321cd38a9a7fff 100644 --- a/Lib/test/test_annotationlib.py +++ b/Lib/test/test_annotationlib.py @@ -1,6 +1,7 @@ """Tests for the annotations module.""" import annotationlib +import collections import functools import itertools import pickle @@ -278,11 +279,12 @@ class Gen[T]: ) def test_fwdref_with_module(self): - self.assertIs(ForwardRef("Format", module=annotationlib).evaluate(), Format) + self.assertIs(ForwardRef("Format", module="annotationlib").evaluate(), Format) + self.assertIs(ForwardRef("Counter", module="collections").evaluate(), collections.Counter) with self.assertRaises(NameError): # If globals are passed explicitly, we don't look at the module dict - ForwardRef("Format", module=annotationlib).evaluate(globals={}) + ForwardRef("Format", module="annotationlib").evaluate(globals={}) def test_fwdref_value_is_cached(self): fr = ForwardRef("hello") From 3cfa5048b7de0fae7d412845a11efdc497e90ceb Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sun, 22 Sep 2024 14:40:24 -0700 Subject: [PATCH 2/8] gh-119180: Optimize annotationlib to avoid creation of AST nodes and eval() Often, ForwardRefs represent a single simple name. In that case, we can avoid going through the overhead of creating AST nodes and code objects and calling eval(): we can simply look up the name directly in the relevant namespaces. 
--- Lib/annotationlib.py | 77 ++++++++++++++++++++++++++++---------------- 1 file changed, 50 insertions(+), 27 deletions(-) diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py index ffd47497056289..ea2ba087e60873 100644 --- a/Lib/annotationlib.py +++ b/Lib/annotationlib.py @@ -1,6 +1,7 @@ """Helpers for introspecting and wrapping annotations.""" import ast +import builtins import enum import functools import sys @@ -154,8 +155,19 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None): globals[param_name] = param locals.pop(param_name, None) - code = self.__forward_code__ - value = eval(code, globals=globals, locals=locals) + arg = self.__forward_arg__ + if arg.isidentifier(): + if arg in locals: + value = locals[arg] + elif globals is not None and arg in globals: + value = globals[arg] + elif hasattr(builtins, arg): + return getattr(builtins, arg) + else: + raise NameError(arg) + else: + code = self.__forward_code__ + value = eval(code, globals=globals, locals=locals) self.__forward_evaluated__ = True self.__forward_value__ = value return value @@ -254,7 +266,9 @@ class _Stringifier: __slots__ = _SLOTS def __init__(self, node, globals=None, owner=None, is_class=False, cell=None): - assert isinstance(node, ast.AST) + # Either an AST node or a simple str (for the common case where a ForwardRef + # represents a single name). 
+ assert isinstance(node, (ast.AST, str)) self.__arg__ = None self.__forward_evaluated__ = False self.__forward_value__ = None @@ -267,18 +281,25 @@ def __init__(self, node, globals=None, owner=None, is_class=False, cell=None): self.__cell__ = cell self.__owner__ = owner - def __convert(self, other): + def __convert_to_ast(self, other): if isinstance(other, _Stringifier): + if isinstance(other.__ast_node__, str): + return ast.Name(id=other.__ast_node__) return other.__ast_node__ elif isinstance(other, slice): return ast.Slice( - lower=self.__convert(other.start) if other.start is not None else None, - upper=self.__convert(other.stop) if other.stop is not None else None, - step=self.__convert(other.step) if other.step is not None else None, + lower=self.__convert_to_ast(other.start) if other.start is not None else None, + upper=self.__convert_to_ast(other.stop) if other.stop is not None else None, + step=self.__convert_to_ast(other.step) if other.step is not None else None, ) else: return ast.Constant(value=other) + def __get_ast(self): + if isinstance(self.__ast_node__, str): + return ast.Name(id=self.__ast_node__) + return self.__ast_node__ + def __make_new(self, node): return _Stringifier( node, self.__globals__, self.__owner__, self.__forward_is_class__ @@ -292,38 +313,37 @@ def __hash__(self): def __getitem__(self, other): # Special case, to avoid stringifying references to class-scoped variables # as '__classdict__["x"]'. 
- if ( - isinstance(self.__ast_node__, ast.Name) - and self.__ast_node__.id == "__classdict__" - ): + if self.__ast_node__ == "__classdict__": raise KeyError if isinstance(other, tuple): - elts = [self.__convert(elt) for elt in other] + elts = [self.__convert_to_ast(elt) for elt in other] other = ast.Tuple(elts) else: - other = self.__convert(other) + other = self.__convert_to_ast(other) assert isinstance(other, ast.AST), repr(other) - return self.__make_new(ast.Subscript(self.__ast_node__, other)) + return self.__make_new(ast.Subscript(self.__get_ast(), other)) def __getattr__(self, attr): - return self.__make_new(ast.Attribute(self.__ast_node__, attr)) + return self.__make_new(ast.Attribute(self.__get_ast(), attr)) def __call__(self, *args, **kwargs): return self.__make_new( ast.Call( - self.__ast_node__, - [self.__convert(arg) for arg in args], + self.__get_ast(), + [self.__convert_to_ast(arg) for arg in args], [ - ast.keyword(key, self.__convert(value)) + ast.keyword(key, self.__convert_to_ast(value)) for key, value in kwargs.items() ], ) ) def __iter__(self): - yield self.__make_new(ast.Starred(self.__ast_node__)) + yield self.__make_new(ast.Starred(self.__get_ast())) def __repr__(self): + if isinstance(self.__ast_node__, str): + return self.__ast_node__ return ast.unparse(self.__ast_node__) def __format__(self, format_spec): @@ -332,7 +352,7 @@ def __format__(self, format_spec): def _make_binop(op: ast.AST): def binop(self, other): return self.__make_new( - ast.BinOp(self.__ast_node__, op, self.__convert(other)) + ast.BinOp(self.__get_ast(), op, self.__convert_to_ast(other)) ) return binop @@ -356,7 +376,7 @@ def binop(self, other): def _make_rbinop(op: ast.AST): def rbinop(self, other): return self.__make_new( - ast.BinOp(self.__convert(other), op, self.__ast_node__) + ast.BinOp(self.__convert_to_ast(other), op, self.__get_ast()) ) return rbinop @@ -381,9 +401,9 @@ def _make_compare(op): def compare(self, other): return self.__make_new( ast.Compare( - 
left=self.__ast_node__, + left=self.__get_ast(), ops=[op], - comparators=[self.__convert(other)], + comparators=[self.__convert_to_ast(other)], ) ) @@ -400,7 +420,7 @@ def compare(self, other): def _make_unary_op(op): def unary_op(self): - return self.__make_new(ast.UnaryOp(op, self.__ast_node__)) + return self.__make_new(ast.UnaryOp(op, self.__get_ast())) return unary_op @@ -422,7 +442,7 @@ def __init__(self, namespace, globals=None, owner=None, is_class=False): def __missing__(self, key): fwdref = _Stringifier( - ast.Name(id=key), + key, globals=self.globals, owner=self.owner, is_class=self.is_class, @@ -480,7 +500,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False): name = freevars[i] else: name = "__cell__" - fwdref = _Stringifier(ast.Name(id=name)) + fwdref = _Stringifier(name) new_closure.append(types.CellType(fwdref)) closure = tuple(new_closure) else: @@ -532,7 +552,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False): else: name = "__cell__" fwdref = _Stringifier( - ast.Name(id=name), + name, cell=cell, owner=owner, globals=annotate.__globals__, @@ -555,6 +575,9 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False): result = func(Format.VALUE) for obj in globals.stringifiers: obj.__class__ = ForwardRef + if isinstance(obj.__ast_node__, str): + obj.__arg__ = obj.__ast_node__ + obj.__ast_node__ = None return result elif format == Format.VALUE: # Should be impossible because __annotate__ functions must not raise From eaea393ec63c19a9a4a13dff86dbdbe0c3b2b17a Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Mon, 23 Sep 2024 07:19:53 -0700 Subject: [PATCH 3/8] add a test case --- Lib/test/test_annotationlib.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py index 321cd38a9a7fff..1425fcc05180a0 100644 --- a/Lib/test/test_annotationlib.py +++ b/Lib/test/test_annotationlib.py @@ 
-280,7 +280,13 @@ class Gen[T]: def test_fwdref_with_module(self): self.assertIs(ForwardRef("Format", module="annotationlib").evaluate(), Format) - self.assertIs(ForwardRef("Counter", module="collections").evaluate(), collections.Counter) + self.assertIs( + ForwardRef("Counter", module="collections").evaluate(), collections.Counter + ) + self.assertEqual( + ForwardRef("Counter[int]", module="collections").evaluate(), + collections.Counter[int], + ) with self.assertRaises(NameError): # If globals are passed explicitly, we don't look at the module dict From e274f998644153e817d5f7a267934e9004039c1a Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Mon, 23 Sep 2024 07:24:25 -0700 Subject: [PATCH 4/8] more tests --- Lib/test/test_annotationlib.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py index 1425fcc05180a0..44fd70a24a3fa4 100644 --- a/Lib/test/test_annotationlib.py +++ b/Lib/test/test_annotationlib.py @@ -299,6 +299,24 @@ def test_fwdref_value_is_cached(self): self.assertIs(fr.evaluate(globals={"hello": str}), str) self.assertIs(fr.evaluate(), str) + def test_fwdref_with_owner(self): + self.assertEqual( + ForwardRef("Counter[int]", owner=collections).evaluate(), + collections.Counter[int], + ) + + def test_name_lookup_without_eval(self): + # test the codepath where we look up simple names directly in the + # namespaces without going through eval() + self.assertIs(ForwardRef("int").evaluate(), int) + self.assertIs(ForwardRef("int").evaluate(locals={"int": str}), str) + self.assertIs(ForwardRef("int").evaluate(locals={"int": float}, globals={"int": str}), float) + self.assertIs(ForwardRef("int").evaluate(globals={"int": str}), str) + + with self.assertRaises(NameError): + ForwardRef("doesntexist").evaluate() + + class TestGetAnnotations(unittest.TestCase): def test_builtin_type(self): From 0dc0a68239fd45621cc665c75e1dbd4f5c166e03 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra 
Date: Mon, 23 Sep 2024 07:34:28 -0700 Subject: [PATCH 5/8] fmt --- Lib/test/test_annotationlib.py | 1 - 1 file changed, 1 deletion(-) diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py index 44fd70a24a3fa4..36918410507627 100644 --- a/Lib/test/test_annotationlib.py +++ b/Lib/test/test_annotationlib.py @@ -317,7 +317,6 @@ def test_name_lookup_without_eval(self): ForwardRef("doesntexist").evaluate() - class TestGetAnnotations(unittest.TestCase): def test_builtin_type(self): self.assertEqual(annotationlib.get_annotations(int), {}) From fefa1263cb4c943bc432d75a98ad2313bfa5c39a Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Tue, 24 Sep 2024 10:22:15 -0700 Subject: [PATCH 6/8] one more test --- Lib/test/test_annotationlib.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py index 343d57a279eda3..073df4ec50d095 100644 --- a/Lib/test/test_annotationlib.py +++ b/Lib/test/test_annotationlib.py @@ -1,6 +1,7 @@ """Tests for the annotations module.""" import annotationlib +import builtins import collections import functools import itertools @@ -324,6 +325,8 @@ def test_name_lookup_without_eval(self): self.assertIs(ForwardRef("int").evaluate(locals={"int": str}), str) self.assertIs(ForwardRef("int").evaluate(locals={"int": float}, globals={"int": str}), float) self.assertIs(ForwardRef("int").evaluate(globals={"int": str}), str) + with support.swap_attr(builtins, "int", dict): + self.assertIs(ForwardRef("int").evaluate(), dict) with self.assertRaises(NameError): ForwardRef("doesntexist").evaluate() From ab30124a3d1f233b5c4bcc5d6ad393f27b592c56 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Tue, 24 Sep 2024 14:14:54 -0700 Subject: [PATCH 7/8] Apply suggestions from code review Co-authored-by: Victor Stinner --- Lib/annotationlib.py | 7 ++++--- Lib/test/test_annotationlib.py | 3 ++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/Lib/annotationlib.py 
b/Lib/annotationlib.py index fcc386786e144a..77aac8bf2f0eb3 100644 --- a/Lib/annotationlib.py +++ b/Lib/annotationlib.py @@ -296,9 +296,10 @@ def __convert_to_ast(self, other): return ast.Constant(value=other) def __get_ast(self): - if isinstance(self.__ast_node__, str): - return ast.Name(id=self.__ast_node__) - return self.__ast_node__ + node = self.__ast_node__ + if isinstance(node, str): + return ast.Name(id=node) + return node def __make_new(self, node): return _Stringifier( diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py index 073df4ec50d095..c0820556cc952b 100644 --- a/Lib/test/test_annotationlib.py +++ b/Lib/test/test_annotationlib.py @@ -282,7 +282,8 @@ class Gen[T]: def test_fwdref_with_module(self): self.assertIs(ForwardRef("Format", module="annotationlib").evaluate(), Format) self.assertIs( - ForwardRef("Counter", module="collections").evaluate(), collections.Counter + ForwardRef("Counter", module="collections").evaluate(), + collections.Counter ) self.assertEqual( ForwardRef("Counter[int]", module="collections").evaluate(), From cebb78e4ecb5ce7b703ed8995e319549ba5ba2ab Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Wed, 25 Sep 2024 13:45:33 -0700 Subject: [PATCH 8/8] PR feedback --- Lib/annotationlib.py | 5 +++-- Lib/test/test_annotationlib.py | 8 ++++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py index 77aac8bf2f0eb3..be3bc275817f50 100644 --- a/Lib/annotationlib.py +++ b/Lib/annotationlib.py @@ -4,6 +4,7 @@ import builtins import enum import functools +import keyword import sys import types @@ -156,10 +157,10 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None): locals.pop(param_name, None) arg = self.__forward_arg__ - if arg.isidentifier(): + if arg.isidentifier() and not keyword.iskeyword(arg): if arg in locals: value = locals[arg] - elif globals is not None and arg in globals: + elif arg in globals: value = globals[arg] 
elif hasattr(builtins, arg): return getattr(builtins, arg) diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py index c0820556cc952b..cc051ef3b93658 100644 --- a/Lib/test/test_annotationlib.py +++ b/Lib/test/test_annotationlib.py @@ -332,6 +332,14 @@ def test_name_lookup_without_eval(self): with self.assertRaises(NameError): ForwardRef("doesntexist").evaluate() + def test_fwdref_invalid_syntax(self): + fr = ForwardRef("if") + with self.assertRaises(SyntaxError): + fr.evaluate() + fr = ForwardRef("1+") + with self.assertRaises(SyntaxError): + fr.evaluate() + class TestGetAnnotations(unittest.TestCase): def test_builtin_type(self):