Skip to content

Commit a714309

Browse files
committed
fix
1 parent 4fbc405 commit a714309

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

mypy/stubgen.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -500,11 +500,18 @@ def _get_func_args(self, o: FuncDef, ctx: FunctionContext) -> list[ArgSig]:
500500
name = f"**{name}"
501501

502502
args.append(ArgSig(name, typename, default=bool(arg_.initializer)))
503-
if o.name == "__init__" and is_dataclass_generated and "**" in args:
503+
504+
is_dataclass_generated = (
505+
self.analyzed and self.processing_dataclass and o.info.names[o.name].plugin_generated
506+
)
507+
if o.name == "__init__" and is_dataclass_generated and "**" in [a.name for a in args]:
504508
# The dataclass plugin generates invalid nameless "*" and "**" arguments
505-
new_name = "".join(a.split(":", 1)[0] for a in args).replace("*", "")
506-
args[args.index("*")] = f"*{new_name}_" # this name is guaranteed to be unique
507-
args[args.index("**")] = f"**{new_name}__" # same here
509+
new_name = "".join(a.name.strip("*") for a in args)
510+
for arg in args:
511+
if arg.name == "*":
512+
arg.name = f"*{new_name}_" # this name is guaranteed to be unique
513+
elif arg.name == "**":
514+
arg.name = f"**{new_name}__" # same here
508515
return args
509516

510517
def _get_func_return(self, o: FuncDef, ctx: FunctionContext) -> str | None:

mypy/stubutil.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,10 @@ def import_lines(self) -> list[str]:
453453
# be imported from it. the names can also be alias in the form 'original as alias'
454454
module_map: Mapping[str, list[str]] = defaultdict(list)
455455

456-
for name in sorted(self.required_names):
456+
for name in sorted(
457+
self.required_names,
458+
key=lambda n: (self.reverse_alias[n], n) if n in self.reverse_alias else (n, ""),
459+
):
457460
# If we haven't seen this name in an import statement, ignore it
458461
if name not in self.module_for:
459462
continue
@@ -477,7 +480,7 @@ def import_lines(self) -> list[str]:
477480
assert "." not in name # Because reexports only has nonqualified names
478481
result.append(f"import {name} as {name}\n")
479482
else:
480-
result.append(f"import {self.direct_imports[name]}\n")
483+
result.append(f"import {name}\n")
481484

482485
# Now generate all the from ... import ... lines collected in module_map
483486
for module, names in sorted(module_map.items()):

0 commit comments

Comments
 (0)