Skip to content

Commit 004b893

Browse files
authored
Backport recent improvements to the implementation of Protocol (#324)
1 parent f84880d commit 004b893

File tree

3 files changed

+146
-27
lines changed

3 files changed

+146
-27
lines changed

CHANGELOG.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,14 @@
1+
# Unreleased
2+
3+
- Speedup `issubclass()` checks against simple runtime-checkable protocols by
4+
around 6% (backporting https://github.com/python/cpython/pull/112717, by Alex
5+
Waygood).
6+
- Fix a regression in the implementation of protocols where `typing.Protocol`
7+
classes that were not marked as `@runtime_checkable` would be unnecessarily
8+
introspected, potentially causing exceptions to be raised if the protocol had
9+
problematic members. Patch by Alex Waygood, backporting
10+
https://github.com/python/cpython/pull/113401.
11+
112
# Release 4.9.0 (December 9, 2023)
213

314
This feature release adds `typing_extensions.ReadOnly`, as specified

src/test_typing_extensions.py

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2817,8 +2817,8 @@ def meth(self): pass # noqa: B027
28172817

28182818
self.assertNotIn("__protocol_attrs__", vars(NonP))
28192819
self.assertNotIn("__protocol_attrs__", vars(NonPR))
2820-
self.assertNotIn("__callable_proto_members_only__", vars(NonP))
2821-
self.assertNotIn("__callable_proto_members_only__", vars(NonPR))
2820+
self.assertNotIn("__non_callable_proto_members__", vars(NonP))
2821+
self.assertNotIn("__non_callable_proto_members__", vars(NonPR))
28222822

28232823
acceptable_extra_attrs = {
28242824
'_is_protocol', '_is_runtime_protocol', '__parameters__',
@@ -2891,11 +2891,26 @@ def __subclasshook__(cls, other):
28912891
@skip_if_py312b1
28922892
def test_issubclass_fails_correctly(self):
28932893
@runtime_checkable
2894-
class P(Protocol):
2894+
class NonCallableMembers(Protocol):
28952895
x = 1
2896+
2897+
class NotRuntimeCheckable(Protocol):
2898+
def callable_member(self) -> int: ...
2899+
2900+
@runtime_checkable
2901+
class RuntimeCheckable(Protocol):
2902+
def callable_member(self) -> int: ...
2903+
28962904
class C: pass
2897-
with self.assertRaisesRegex(TypeError, r"issubclass\(\) arg 1 must be a class"):
2898-
issubclass(C(), P)
2905+
2906+
# These three all exercise different code paths,
2907+
# but should result in the same error message:
2908+
for protocol in NonCallableMembers, NotRuntimeCheckable, RuntimeCheckable:
2909+
with self.subTest(proto_name=protocol.__name__):
2910+
with self.assertRaisesRegex(
2911+
TypeError, r"issubclass\(\) arg 1 must be a class"
2912+
):
2913+
issubclass(C(), protocol)
28992914

29002915
def test_defining_generic_protocols(self):
29012916
T = TypeVar('T')
@@ -3456,6 +3471,7 @@ def method(self) -> None: ...
34563471

34573472
@skip_if_early_py313_alpha
34583473
def test_protocol_issubclass_error_message(self):
3474+
@runtime_checkable
34593475
class Vec2D(Protocol):
34603476
x: float
34613477
y: float
@@ -3471,6 +3487,39 @@ def square_norm(self) -> float:
34713487
with self.assertRaisesRegex(TypeError, re.escape(expected_error_message)):
34723488
issubclass(int, Vec2D)
34733489

3490+
def test_nonruntime_protocol_interaction_with_evil_classproperty(self):
3491+
class classproperty:
3492+
def __get__(self, instance, type):
3493+
raise RuntimeError("NO")
3494+
3495+
class Commentable(Protocol):
3496+
evil = classproperty()
3497+
3498+
# recognised as a protocol attr,
3499+
# but not actually accessed by the protocol metaclass
3500+
# (which would raise RuntimeError) for non-runtime protocols.
3501+
# See gh-113320
3502+
self.assertEqual(get_protocol_members(Commentable), {"evil"})
3503+
3504+
def test_runtime_protocol_interaction_with_evil_classproperty(self):
3505+
class CustomError(Exception): pass
3506+
3507+
class classproperty:
3508+
def __get__(self, instance, type):
3509+
raise CustomError
3510+
3511+
with self.assertRaises(TypeError) as cm:
3512+
@runtime_checkable
3513+
class Commentable(Protocol):
3514+
evil = classproperty()
3515+
3516+
exc = cm.exception
3517+
self.assertEqual(
3518+
exc.args[0],
3519+
"Failed to determine whether protocol member 'evil' is a method member"
3520+
)
3521+
self.assertIs(type(exc.__cause__), CustomError)
3522+
34743523

34753524
class Point2DGeneric(Generic[T], TypedDict):
34763525
a: T
@@ -5263,7 +5312,7 @@ def test_typing_extensions_defers_when_possible(self):
52635312
'SupportsRound', 'Unpack',
52645313
}
52655314
if sys.version_info < (3, 13):
5266-
exclude |= {'NamedTuple', 'Protocol'}
5315+
exclude |= {'NamedTuple', 'Protocol', 'runtime_checkable'}
52675316
if not hasattr(typing, 'ReadOnly'):
52685317
exclude |= {'TypedDict', 'is_typeddict'}
52695318
for item in typing_extensions.__all__:

src/typing_extensions.py

Lines changed: 80 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ def clear_overloads():
473473
"_is_runtime_protocol", "__dict__", "__slots__", "__parameters__",
474474
"__orig_bases__", "__module__", "_MutableMapping__marker", "__doc__",
475475
"__subclasshook__", "__orig_class__", "__init__", "__new__",
476-
"__protocol_attrs__", "__callable_proto_members_only__",
476+
"__protocol_attrs__", "__non_callable_proto_members__",
477477
"__match_args__",
478478
}
479479

@@ -521,6 +521,22 @@ def _no_init(self, *args, **kwargs):
521521
if type(self)._is_protocol:
522522
raise TypeError('Protocols cannot be instantiated')
523523

524+
def _type_check_issubclass_arg_1(arg):
525+
"""Raise TypeError if `arg` is not an instance of `type`
526+
in `issubclass(arg, <protocol>)`.
527+
528+
In most cases, this is verified by type.__subclasscheck__.
529+
Checking it again unnecessarily would slow down issubclass() checks,
530+
so, we don't perform this check unless we absolutely have to.
531+
532+
For various error paths, however,
533+
we want to ensure that *this* error message is shown to the user
534+
where relevant, rather than a typing.py-specific error message.
535+
"""
536+
if not isinstance(arg, type):
537+
# Same error message as for issubclass(1, int).
538+
raise TypeError('issubclass() arg 1 must be a class')
539+
524540
# Inheriting from typing._ProtocolMeta isn't actually desirable,
525541
# but is necessary to allow typing.Protocol and typing_extensions.Protocol
526542
# to mix without getting TypeErrors about "metaclass conflict"
@@ -551,11 +567,6 @@ def __init__(cls, *args, **kwargs):
551567
abc.ABCMeta.__init__(cls, *args, **kwargs)
552568
if getattr(cls, "_is_protocol", False):
553569
cls.__protocol_attrs__ = _get_protocol_attrs(cls)
554-
# PEP 544 prohibits using issubclass()
555-
# with protocols that have non-method members.
556-
cls.__callable_proto_members_only__ = all(
557-
callable(getattr(cls, attr, None)) for attr in cls.__protocol_attrs__
558-
)
559570

560571
def __subclasscheck__(cls, other):
561572
if cls is Protocol:
@@ -564,26 +575,23 @@ def __subclasscheck__(cls, other):
564575
getattr(cls, '_is_protocol', False)
565576
and not _allow_reckless_class_checks()
566577
):
567-
if not isinstance(other, type):
568-
# Same error message as for issubclass(1, int).
569-
raise TypeError('issubclass() arg 1 must be a class')
578+
if not getattr(cls, '_is_runtime_protocol', False):
579+
_type_check_issubclass_arg_1(other)
580+
raise TypeError(
581+
"Instance and class checks can only be used with "
582+
"@runtime_checkable protocols"
583+
)
570584
if (
571-
not cls.__callable_proto_members_only__
585+
# this attribute is set by @runtime_checkable:
586+
cls.__non_callable_proto_members__
572587
and cls.__dict__.get("__subclasshook__") is _proto_hook
573588
):
574-
non_method_attrs = sorted(
575-
attr for attr in cls.__protocol_attrs__
576-
if not callable(getattr(cls, attr, None))
577-
)
589+
_type_check_issubclass_arg_1(other)
590+
non_method_attrs = sorted(cls.__non_callable_proto_members__)
578591
raise TypeError(
579592
"Protocols with non-method members don't support issubclass()."
580593
f" Non-method members: {str(non_method_attrs)[1:-1]}."
581594
)
582-
if not getattr(cls, '_is_runtime_protocol', False):
583-
raise TypeError(
584-
"Instance and class checks can only be used with "
585-
"@runtime_checkable protocols"
586-
)
587595
return abc.ABCMeta.__subclasscheck__(cls, other)
588596

589597
def __instancecheck__(cls, instance):
@@ -610,7 +618,8 @@ def __instancecheck__(cls, instance):
610618
val = inspect.getattr_static(instance, attr)
611619
except AttributeError:
612620
break
613-
if val is None and callable(getattr(cls, attr, None)):
621+
# this attribute is set by @runtime_checkable:
622+
if val is None and attr not in cls.__non_callable_proto_members__:
614623
break
615624
else:
616625
return True
@@ -678,8 +687,58 @@ def __init_subclass__(cls, *args, **kwargs):
678687
cls.__init__ = _no_init
679688

680689

690+
if sys.version_info >= (3, 13):
691+
runtime_checkable = typing.runtime_checkable
692+
else:
693+
def runtime_checkable(cls):
694+
"""Mark a protocol class as a runtime protocol.
695+
696+
Such protocol can be used with isinstance() and issubclass().
697+
Raise TypeError if applied to a non-protocol class.
698+
This allows a simple-minded structural check very similar to
699+
one trick ponies in collections.abc such as Iterable.
700+
701+
For example::
702+
703+
@runtime_checkable
704+
class Closable(Protocol):
705+
def close(self): ...
706+
707+
assert isinstance(open('/some/file'), Closable)
708+
709+
Warning: this will check only the presence of the required methods,
710+
not their type signatures!
711+
"""
712+
if not issubclass(cls, typing.Generic) or not getattr(cls, '_is_protocol', False):
713+
raise TypeError('@runtime_checkable can be only applied to protocol classes,'
714+
' got %r' % cls)
715+
cls._is_runtime_protocol = True
716+
717+
# Only execute the following block if it's a typing_extensions.Protocol class.
718+
# typing.Protocol classes don't need it.
719+
if isinstance(cls, _ProtocolMeta):
720+
# PEP 544 prohibits using issubclass()
721+
# with protocols that have non-method members.
722+
# See gh-113320 for why we compute this attribute here,
723+
# rather than in `_ProtocolMeta.__init__`
724+
cls.__non_callable_proto_members__ = set()
725+
for attr in cls.__protocol_attrs__:
726+
try:
727+
is_callable = callable(getattr(cls, attr, None))
728+
except Exception as e:
729+
raise TypeError(
730+
f"Failed to determine whether protocol member {attr!r} "
731+
"is a method member"
732+
) from e
733+
else:
734+
if not is_callable:
735+
cls.__non_callable_proto_members__.add(attr)
736+
737+
return cls
738+
739+
681740
# The "runtime" alias exists for backwards compatibility.
682-
runtime = runtime_checkable = typing.runtime_checkable
741+
runtime = runtime_checkable
683742

684743

685744
# Our version of runtime-checkable protocols is faster on Python 3.8-3.11

0 commit comments

Comments
 (0)