diff --git a/Lib/test/test_clinic.py b/Lib/test/test_clinic.py index f067a26d1fb3ae..a67cd301f6eaa8 100644 --- a/Lib/test/test_clinic.py +++ b/Lib/test/test_clinic.py @@ -609,6 +609,35 @@ def test_directive_output_invalid_command(self): """ self.expect_failure(block, err, lineno=2) + def test_validate_cloned_init(self): + block = """ + /*[clinic input] + class C "void *" "" + C.meth + a: int + [clinic start generated code]*/ + /*[clinic input] + @classmethod + C.__init__ = C.meth + [clinic start generated code]*/ + """ + err = "'__init__' must be a normal method, not a class or static method" + self.expect_failure(block, err, lineno=8) + + def test_validate_cloned_new(self): + block = """ + /*[clinic input] + class C "void *" "" + C.meth + a: int + [clinic start generated code]*/ + /*[clinic input] + C.__new__ = C.meth + [clinic start generated code]*/ + """ + err = "'__new__' must be a class method" + self.expect_failure(block, err, lineno=7) + class ParseFileUnitTest(TestCase): def expect_parsing_failure( @@ -1918,7 +1947,7 @@ class Foo "" "" self.parse_function(block) def test_new_must_be_a_class_method(self): - err = "__new__ must be a class method!" + err = "'__new__' must be a class method!" block = """ module foo class Foo "" "" @@ -1927,7 +1956,7 @@ class Foo "" "" self.expect_failure(block, err, lineno=2) def test_init_must_be_a_normal_method(self): - err = "__init__ must be a normal method, not a class or static method!" + err = "'__init__' must be a normal method, not a class or static method!" block = """ module foo class Foo "" "" @@ -2030,7 +2059,7 @@ def test_illegal_c_identifier(self): self.expect_failure(block, err, lineno=2) def test_cannot_convert_special_method(self): - err = "__len__ is a special method and cannot be converted" + err = "'__len__' is a special method and cannot be converted" block = """ class T "" "" T.__len__ diff --git a/Tools/clinic/clinic.py b/Tools/clinic/clinic.py index 1e0303c77087eb..11dbfb3fbe858e 100755 --- a/Tools/clinic/clinic.py +++ b/Tools/clinic/clinic.py @@ -4840,6 +4840,21 @@ def state_dsl_start(self, line: str) -> None: self.next(self.state_modulename_name, line) + def update_function_kind(self, fullname: str) -> None: + fields = fullname.split('.') + name = fields.pop() + _, cls = self.clinic._module_and_class(fields) + if name in unsupported_special_methods: + fail(f"{name!r} is a special method and cannot be converted to Argument Clinic!") + if name == '__new__': + if (self.kind is not CLASS_METHOD) or (not cls): + fail("'__new__' must be a class method!") + self.kind = METHOD_NEW + elif name == '__init__': + if (self.kind is not CALLABLE) or (not cls): + fail("'__init__' must be a normal method, not a class or static method!") + self.kind = METHOD_INIT + def state_modulename_name(self, line: str) -> None: # looking for declaration, which establishes the leftmost column # line should be @@ -4888,6 +4903,7 @@ def state_modulename_name(self, line: str) -> None: function_name = fields.pop() module, cls = self.clinic._module_and_class(fields) + self.update_function_kind(full_name) overrides: dict[str, Any] = { "name": function_name, "full_name": full_name, @@ -4948,20 +4964,9 @@ def state_modulename_name(self, line: str) -> None: function_name = fields.pop() module, cls = self.clinic._module_and_class(fields) - fields = full_name.split('.') - if fields[-1] in unsupported_special_methods: - fail(f"{fields[-1]} is a special method and cannot be converted to Argument Clinic! (Yet.)") - - if fields[-1] == '__new__': - if (self.kind is not CLASS_METHOD) or (not cls): - fail("__new__ must be a class method!") - self.kind = METHOD_NEW - elif fields[-1] == '__init__': - if (self.kind is not CALLABLE) or (not cls): - fail("__init__ must be a normal method, not a class or static method!") - self.kind = METHOD_INIT - if not return_converter: - return_converter = init_return_converter() + self.update_function_kind(full_name) + if self.kind is METHOD_INIT and not return_converter: + return_converter = init_return_converter() if not return_converter: return_converter = CReturnConverter()