Skip to content

Add support for inline from_queryset in model classes #1045

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 25 additions & 19 deletions mypy_django_plugin/lib/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,23 @@ def is_annotated_model_fullname(model_cls_fullname: str) -> bool:
return model_cls_fullname.startswith(WITH_ANNOTATIONS_FULLNAME + "[")


def create_type_info(name: str, module: str, bases: List[Instance]) -> TypeInfo:

# make new class expression
classdef = ClassDef(name, Block([]))
classdef.fullname = module + "." + name

# make new TypeInfo
new_typeinfo = TypeInfo(SymbolTable(), classdef, module)
new_typeinfo.bases = bases
calculate_mro(new_typeinfo)
new_typeinfo.calculate_metaclass_type()

classdef.info = new_typeinfo

return new_typeinfo


def add_new_class_for_module(
module: MypyFile,
name: str,
Expand All @@ -217,15 +234,7 @@ def add_new_class_for_module(
) -> TypeInfo:
new_class_unique_name = checker.gen_unique_name(name, module.names)

# make new class expression
classdef = ClassDef(new_class_unique_name, Block([]))
classdef.fullname = module.fullname + "." + new_class_unique_name

# make new TypeInfo
new_typeinfo = TypeInfo(SymbolTable(), classdef, module.fullname)
new_typeinfo.bases = bases
calculate_mro(new_typeinfo)
new_typeinfo.calculate_metaclass_type()
new_typeinfo = create_type_info(new_class_unique_name, module.fullname, bases)

# add fields
if fields:
Expand All @@ -237,7 +246,6 @@ def add_new_class_for_module(
MDEF, var, plugin_generated=True, no_serialize=no_serialize
)

classdef.info = new_typeinfo
module.names[new_class_unique_name] = SymbolTableNode(
GDEF, new_typeinfo, plugin_generated=True, no_serialize=no_serialize
)
Expand Down Expand Up @@ -382,29 +390,25 @@ def copy_method_to_another_class(
method_node: FuncDef,
return_type: Optional[MypyType] = None,
original_module_name: Optional[str] = None,
) -> None:
) -> bool:
semanal_api = get_semanal_api(ctx)
if method_node.type is None:
if not semanal_api.final_iteration:
semanal_api.defer()
return

arguments, return_type = build_unannotated_method_args(method_node)
add_method_to_class(
semanal_api, ctx.cls, new_method_name, args=arguments, return_type=return_type, self_type=self_type
)
return
return True

method_type = method_node.type
if not isinstance(method_type, CallableType):
if not semanal_api.final_iteration:
semanal_api.defer()
return
return False

if return_type is None:
return_type = bind_or_analyze_type(method_type.ret_type, semanal_api, original_module_name)
if return_type is None:
return
return False

# We build the arguments from the method signature (`CallableType`), because if we were to
# use the arguments from the method node (`FuncDef.arguments`) we're not compatible with
Expand All @@ -417,7 +421,7 @@ def copy_method_to_another_class(
):
bound_arg_type = bind_or_analyze_type(arg_type, semanal_api, original_module_name)
if bound_arg_type is None:
return
return False
if arg_name is None and hasattr(method_node, "arguments"):
arg_name = method_node.arguments[pos].variable.name
arguments.append(
Expand All @@ -435,6 +439,8 @@ def copy_method_to_another_class(
semanal_api, ctx.cls, new_method_name, args=arguments, return_type=return_type, self_type=self_type
)

return True


def add_new_manager_base(api: SemanticAnalyzerPluginInterface, fullname: str) -> None:
sym = api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME)
Expand Down
6 changes: 0 additions & 6 deletions mypy_django_plugin/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from mypy_django_plugin.transformers import fields, forms, init_create, meta, querysets, request, settings
from mypy_django_plugin.transformers.managers import (
create_new_manager_class_from_from_queryset_method,
fail_if_manager_type_created_in_model_body,
resolve_manager_method,
)
from mypy_django_plugin.transformers.models import (
Expand Down Expand Up @@ -237,11 +236,6 @@ def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], M
django_context=self.django_context,
)

elif method_name == "from_queryset":
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME):
return fail_if_manager_type_created_in_model_body

return None

def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
Expand Down
181 changes: 94 additions & 87 deletions mypy_django_plugin/transformers/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,19 @@
FuncBase,
FuncDef,
MemberExpr,
NameExpr,
OverloadedFuncDef,
RefExpr,
StrExpr,
SymbolTableNode,
TypeInfo,
Var,
)
from mypy.plugin import AttributeContext, ClassDefContext, DynamicClassDefContext, MethodContext
from mypy.plugin import AttributeContext, DynamicClassDefContext, SemanticAnalyzerPluginInterface
from mypy.types import AnyType, CallableType, Instance, ProperType
from mypy.types import Type as MypyType
from mypy.types import TypeOfAny
from typing_extensions import Final

from mypy_django_plugin import errorcodes
from mypy_django_plugin.lib import fullnames, helpers

MANAGER_METHODS_RETURNING_QUERYSET: Final = frozenset(
Expand Down Expand Up @@ -182,81 +180,110 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
"""
semanal_api = helpers.get_semanal_api(ctx)

# TODO: Emit an error when called in a class scope
if semanal_api.is_class_scope():
return

# Don't redeclare the manager class if we've already defined it.
manager_node = semanal_api.lookup_current_scope(ctx.name)
if manager_node and isinstance(manager_node.node, TypeInfo):
# This is just a deferral run where our work is already finished
return

callee = ctx.call.callee
assert isinstance(callee, MemberExpr)
assert isinstance(callee.expr, RefExpr)

base_manager_info = callee.expr.node
if base_manager_info is None:
if not semanal_api.final_iteration:
semanal_api.defer()
new_manager_info = create_manager_info_from_from_queryset_call(ctx.api, ctx.call, ctx.name)
if new_manager_info is None:
if not ctx.api.final_iteration:
ctx.api.defer()
return

assert isinstance(base_manager_info, TypeInfo)
# So that the plugin will reparameterize the manager when it is constructed inside of a Model definition
helpers.add_new_manager_base(semanal_api, new_manager_info.fullname)

passed_queryset = ctx.call.args[0]
assert isinstance(passed_queryset, NameExpr)

derived_queryset_fullname = passed_queryset.fullname
if derived_queryset_fullname is None:
# In some cases, due to the way the semantic analyzer works, only passed_queryset.name is available.
# But it should be analyzed again, so this isn't a problem.
return
def create_manager_info_from_from_queryset_call(
api: SemanticAnalyzerPluginInterface, call_expr: CallExpr, name: Optional[str] = None
) -> Optional[TypeInfo]:
"""
Extract manager and queryset TypeInfo from a from_queryset call.
"""

base_manager_instance = fill_typevars(base_manager_info)
assert isinstance(base_manager_instance, Instance)
new_manager_info = semanal_api.basic_new_typeinfo(
ctx.name, basetype_or_fallback=base_manager_instance, line=ctx.call.line
)
if (
# Check that this is a from_queryset call on a manager subclass
not isinstance(call_expr.callee, MemberExpr)
or not isinstance(call_expr.callee.expr, RefExpr)
or not isinstance(call_expr.callee.expr.node, TypeInfo)
or not call_expr.callee.expr.node.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME)
or not call_expr.callee.name == "from_queryset"
# Check that the call has one or two arguments and that the first is a
# QuerySet subclass
or not 1 <= len(call_expr.args) <= 2
or not isinstance(call_expr.args[0], RefExpr)
or not isinstance(call_expr.args[0].node, TypeInfo)
or not call_expr.args[0].node.has_base(fullnames.QUERYSET_CLASS_FULLNAME)
):
return None

sym = semanal_api.lookup_fully_qualified_or_none(derived_queryset_fullname)
assert sym is not None
if sym.node is None:
if not semanal_api.final_iteration:
semanal_api.defer()
else:
# inherit from Any to prevent false-positives, if queryset class cannot be resolved
new_manager_info.fallback_to_any = True
return
base_manager_info, queryset_info = call_expr.callee.expr.node, call_expr.args[0].node
if queryset_info.fullname is None:
# In some cases, due to the way the semantic analyzer works, only
# passed_queryset.name is available. But it should be analyzed again,
# so this isn't a problem.
return None

derived_queryset_info = sym.node
assert isinstance(derived_queryset_info, TypeInfo)

new_manager_info.line = ctx.call.line
new_manager_info.type_vars = base_manager_info.type_vars
new_manager_info.defn.type_vars = base_manager_info.defn.type_vars
new_manager_info.defn.line = ctx.call.line
new_manager_info.metaclass_type = new_manager_info.calculate_metaclass_type()
# Stash the queryset fullname which was passed to .from_queryset
# So that our 'resolve_manager_method' attribute hook can fetch the method from that QuerySet class
new_manager_info.metadata["django"] = {"from_queryset_manager": derived_queryset_fullname}

if len(ctx.call.args) > 1:
expr = ctx.call.args[1]
assert isinstance(expr, StrExpr)
custom_manager_generated_name = expr.value
if len(call_expr.args) == 2 and isinstance(call_expr.args[1], StrExpr):
manager_name = call_expr.args[1].value
else:
custom_manager_generated_name = base_manager_info.name + "From" + derived_queryset_info.name
manager_name = f"{base_manager_info.name}From{queryset_info.name}"

custom_manager_generated_fullname = ".".join(["django.db.models.manager", custom_manager_generated_name])
new_manager_info = create_manager_class(api, base_manager_info, name or manager_name, call_expr.line)

popuplate_manager_from_queryset(new_manager_info, queryset_info)

manager_fullname = ".".join(["django.db.models.manager", manager_name])

base_manager_info = new_manager_info.mro[1]
base_manager_info.metadata.setdefault("from_queryset_managers", {})
base_manager_info.metadata["from_queryset_managers"][custom_manager_generated_fullname] = new_manager_info.fullname
base_manager_info.metadata["from_queryset_managers"][manager_fullname] = new_manager_info.fullname

# Add the new manager to the current module
module = api.modules[api.cur_mod_id]
module.names[name or manager_name] = SymbolTableNode(
GDEF, new_manager_info, plugin_generated=True, no_serialize=False
)

return new_manager_info

# So that the plugin will reparameterize the manager when it is constructed inside of a Model definition
helpers.add_new_manager_base(semanal_api, new_manager_info.fullname)

class_def_context = ClassDefContext(cls=new_manager_info.defn, reason=ctx.call, api=semanal_api)
self_type = fill_typevars(new_manager_info)
assert isinstance(self_type, Instance)
def create_manager_class(
api: SemanticAnalyzerPluginInterface, base_manager_info: TypeInfo, name: str, line: int
) -> TypeInfo:

base_manager_instance = fill_typevars(base_manager_info)
assert isinstance(base_manager_instance, Instance)

manager_info = helpers.create_type_info(name, api.cur_mod_id, bases=[base_manager_instance])
manager_info.line = line
manager_info.type_vars = base_manager_info.type_vars
manager_info.defn.type_vars = base_manager_info.defn.type_vars
manager_info.defn.line = line
manager_info.metaclass_type = manager_info.calculate_metaclass_type()

return manager_info


def popuplate_manager_from_queryset(manager_info: TypeInfo, queryset_info: TypeInfo) -> None:
"""
Add methods from the QuerySet class to the manager.
"""

# Stash the queryset fullname which was passed to .from_queryset So that
# our 'resolve_manager_method' attribute hook can fetch the method from
# that QuerySet class
django_metadata = helpers.get_django_metadata(manager_info)
django_metadata["from_queryset_manager"] = queryset_info.fullname

# We collect and mark up all methods before django.db.models.query.QuerySet as class members
for class_mro_info in derived_queryset_info.mro:
for class_mro_info in queryset_info.mro:
if class_mro_info.fullname == fullnames.QUERYSET_CLASS_FULLNAME:
break
for name, sym in class_mro_info.names.items():
Expand All @@ -270,39 +297,19 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
# queryset_method: Any = ...
#
helpers.add_new_sym_for_info(
new_manager_info,
manager_info,
name=name,
sym_type=AnyType(TypeOfAny.special_form),
)

# For methods on BaseManager that return a queryset we need to update the
# return type to be the actual queryset subclass used. This is done by
# adding the methods as attributes with type Any to the manager class,
# similar to how custom queryset methods are handled above. The actual type
# of these methods are resolved in resolve_manager_method.
for name in MANAGER_METHODS_RETURNING_QUERYSET:
# For methods on BaseManager that return a queryset we need to update
# the return type to be the actual queryset subclass used. This is done
# by adding the methods as attributes with type Any to the manager
# class. The actual type of these methods are resolved in
# resolve_manager_method.
for method_name in MANAGER_METHODS_RETURNING_QUERYSET:
helpers.add_new_sym_for_info(
new_manager_info,
name=name,
manager_info,
name=method_name,
sym_type=AnyType(TypeOfAny.special_form),
)

# Insert the new manager (dynamic) class
assert semanal_api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, new_manager_info, plugin_generated=True))


def fail_if_manager_type_created_in_model_body(ctx: MethodContext) -> MypyType:
"""
Method hook that checks if method `<Manager>.from_queryset` is called inside a model class body.

Doing so won't, for instance, trigger the dynamic class hook(`create_new_manager_class_from_from_queryset_method`)
for managers.
"""
api = helpers.get_typechecker_api(ctx)
outer_model_info = api.scope.active_class()
if not outer_model_info or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME):
# Not inside a model class definition
return ctx.default_return_type

api.fail("`.from_queryset` called from inside model class body", ctx.context, code=errorcodes.MANAGER_UNTYPED)
return ctx.default_return_type
Loading