Skip to content

Commit 6aafc5d

Browse files
committed
[ty] support kw_only=True for dataclasses
astral-sh/ty#111
1 parent 18ad284 commit 6aafc5d

File tree

6 files changed

+137
-10
lines changed

6 files changed

+137
-10
lines changed

crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,40 @@ OrderTrueOverwritten(1) < OrderTrueOverwritten(2)
195195

196196
### `kw_only_default`
197197

198-
To do
198+
When provided, sets the default value for the `kw_only` parameter of `field()`.
199+
200+
```py
201+
from typing import dataclass_transform
202+
from dataclasses import field
203+
204+
@dataclass_transform(kw_only_default=True)
205+
def create_model(*, init=True): ...
206+
@create_model()
207+
class A:
208+
name: str = field(default="Voldemort")
209+
210+
a = A()
211+
a = A(name="Harry")
212+
a = A("Harry") # error: [too-many-positional-arguments]
213+
```
214+
215+
TODO: This can be overridden by the call to the decorator function.
216+
217+
```py
218+
from typing import dataclass_transform
219+
220+
@dataclass_transform(kw_only_default=True)
221+
def create_model(*, kw_only: bool = True): ...
222+
@create_model(kw_only=False)
223+
class CustomerModel:
224+
id: int
225+
name: str
226+
227+
# TODO: Should not emit errors
228+
# error: [missing-argument]
229+
# error: [too-many-positional-arguments]
230+
c = CustomerModel(1, "Harry")
231+
```
199232

200233
### `field_specifiers`
201234

crates/ty_python_semantic/resources/mdtest/dataclasses/dataclasses.md

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,70 @@ To do
465465

466466
### `kw_only`
467467

468-
To do
468+
An error is emitted if a dataclass is defined with `kw_only=True` and positional arguments are
469+
passed to the constructor.
470+
471+
```toml
472+
[environment]
473+
python-version = "3.10"
474+
```
475+
476+
```py
477+
from dataclasses import dataclass
478+
479+
@dataclass(kw_only=True)
480+
class A:
481+
x: int
482+
y: int
483+
484+
# error: [missing-argument] "No arguments provided for required parameters `x`, `y`"
485+
# error: [too-many-positional-arguments] "Too many positional arguments: expected 0, got 2"
486+
a = A(1, 2)
487+
a = A(x=1, y=2)
488+
```
489+
490+
The class-level parameter can be overridden per-field.
491+
492+
```py
493+
from dataclasses import dataclass, field
494+
495+
@dataclass(kw_only=True)
496+
class A:
497+
a: str = field(kw_only=False)
498+
b: int = 0
499+
500+
A("hi")
501+
```
502+
503+
If some fields are `kw_only`, they should appear after all positional fields in the `__init__`
504+
signature.
505+
506+
```py
507+
@dataclass
508+
class A:
509+
b: int = field(kw_only=True, default=3)
510+
a: str
511+
512+
A("hi")
513+
```
514+
515+
### `kw_only` - Python < 3.10
516+
517+
For Python < 3.10, `kw_only` is not supported.
518+
519+
```toml
520+
[environment]
521+
python-version = "3.9"
522+
```
523+
524+
```py
525+
from dataclasses import dataclass
526+
527+
@dataclass(kw_only=True) # TODO: Emit a diagnostic here
528+
class A:
529+
x: int
530+
y: int
531+
```
469532

470533
### `slots`
471534

crates/ty_python_semantic/resources/mdtest/dataclasses/fields.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,12 @@ class Person:
6363
age: int | None = field(default=None, kw_only=True)
6464
role: str = field(default="user", kw_only=True)
6565

66-
# TODO: the `age` and `role` fields should be keyword-only
67-
# revealed: (self: Person, name: str, age: int | None = None, role: str = Literal["user"]) -> None
66+
# revealed: (self: Person, name: str, *, age: int | None = None, role: str = Literal["user"]) -> None
6867
reveal_type(Person.__init__)
6968

7069
alice = Person(role="admin", name="Alice")
7170

72-
# TODO: this should be an error
71+
# error: [too-many-positional-arguments] "Too many positional arguments: expected 1, got 2"
7372
bob = Person("Bob", 30)
7473
```
7574

crates/ty_python_semantic/src/types.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6675,6 +6675,9 @@ pub struct FieldInstance<'db> {
66756675

66766676
/// Whether this field is part of the `__init__` signature, or not.
66776677
pub init: bool,
6678+
6679+
/// Whether or not this field can only be passed as a keyword argument to `__init__`.
6680+
pub kw_only: Option<bool>,
66786681
}
66796682

66806683
// The Salsa heap is tracked separately.
@@ -6690,6 +6693,7 @@ impl<'db> FieldInstance<'db> {
66906693
db,
66916694
self.default_type(db).normalized_impl(db, visitor),
66926695
self.init(db),
6696+
self.kw_only(db),
66936697
)
66946698
}
66956699
}

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use ruff_db::parsed::parsed_module;
1212
use smallvec::{SmallVec, smallvec, smallvec_inline};
1313

1414
use super::{Argument, CallArguments, CallError, CallErrorKind, InferContext, Signature, Type};
15+
use crate::Program;
1516
use crate::db::Db;
1617
use crate::dunder_all::dunder_all_names;
1718
use crate::place::{Boundness, Place};
@@ -33,7 +34,7 @@ use crate::types::{
3334
WrapperDescriptorKind, enums, ide_support, todo_type,
3435
};
3536
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
36-
use ruff_python_ast as ast;
37+
use ruff_python_ast::{self as ast, PythonVersion};
3738

3839
/// Binding information for a possible union of callables. At a call site, the arguments must be
3940
/// compatible with _all_ of the types in the union for the call to be valid.
@@ -860,7 +861,11 @@ impl<'db> Bindings<'db> {
860861
params |= DataclassParams::MATCH_ARGS;
861862
}
862863
if to_bool(kw_only, false) {
863-
params |= DataclassParams::KW_ONLY;
864+
if Program::get(db).python_version(db) >= PythonVersion::PY310 {
865+
params |= DataclassParams::KW_ONLY;
866+
} else {
867+
// TODO: emit diagnostic
868+
}
864869
}
865870
if to_bool(slots, false) {
866871
params |= DataclassParams::SLOTS;
@@ -919,7 +924,8 @@ impl<'db> Bindings<'db> {
919924
}
920925

921926
Some(KnownFunction::Field) => {
922-
if let [default, default_factory, init, ..] = overload.parameter_types()
927+
if let [default, default_factory, init, .., kw_only] =
928+
overload.parameter_types()
923929
{
924930
let default_ty = match (default, default_factory) {
925931
(Some(default_ty), _) => *default_ty,
@@ -933,6 +939,14 @@ impl<'db> Bindings<'db> {
933939
.map(|init| !init.bool(db).is_always_false())
934940
.unwrap_or(true);
935941

942+
let kw_only = if Program::get(db).python_version(db)
943+
>= PythonVersion::PY310
944+
{
945+
kw_only.map(|kw_only| !kw_only.bool(db).is_always_false())
946+
} else {
947+
None
948+
};
949+
936950
// `typeshed` pretends that `dataclasses.field()` returns the type of the
937951
// default value directly. At runtime, however, this function returns an
938952
// instance of `dataclasses.Field`. We also model it this way and return
@@ -942,7 +956,7 @@ impl<'db> Bindings<'db> {
942956
// to `T`. Otherwise, we would error on `name: str = field(default="")`.
943957
overload.set_return_type(Type::KnownInstance(
944958
KnownInstanceType::Field(FieldInstance::new(
945-
db, default_ty, init,
959+
db, default_ty, init, kw_only,
946960
)),
947961
));
948962
}

crates/ty_python_semantic/src/types/class.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1112,6 +1112,9 @@ pub(crate) struct DataclassField<'db> {
11121112

11131113
/// Whether or not this field should appear in the signature of `__init__`.
11141114
pub(crate) init: bool,
1115+
1116+
/// Whether or not this field can only be passed as a keyword argument to `__init__`.
1117+
pub(crate) kw_only: Option<bool>,
11151118
}
11161119

11171120
/// Representation of a class definition statement in the AST: either a non-generic class, or a
@@ -1863,6 +1866,7 @@ impl<'db> ClassLiteral<'db> {
18631866
mut default_ty,
18641867
init_only: _,
18651868
init,
1869+
kw_only,
18661870
},
18671871
) in self.fields(db, specialization, field_policy)
18681872
{
@@ -1927,7 +1931,10 @@ impl<'db> ClassLiteral<'db> {
19271931
}
19281932
}
19291933

1930-
let mut parameter = if kw_only_field_seen || name == "__replace__" {
1934+
let mut parameter = if kw_only_field_seen
1935+
|| name == "__replace__"
1936+
|| kw_only.unwrap_or(has_dataclass_param(DataclassParams::KW_ONLY))
1937+
{
19311938
Parameter::keyword_only(field_name)
19321939
} else {
19331940
Parameter::positional_or_keyword(field_name)
@@ -1946,6 +1953,10 @@ impl<'db> ClassLiteral<'db> {
19461953
parameters.push(parameter);
19471954
}
19481955

1956+
// In the event that we have a mix of keyword-only and positional parameters, we need to sort them
1957+
// so that the keyword-only parameters appear after positional parameters.
1958+
parameters.sort_by_key(Parameter::is_keyword_only);
1959+
19491960
let mut signature = Signature::new(Parameters::new(parameters), return_ty);
19501961
signature.inherited_generic_context = self.generic_context(db);
19511962
Some(CallableType::function_like(db, signature))
@@ -2235,9 +2246,11 @@ impl<'db> ClassLiteral<'db> {
22352246
default_ty.map(|ty| ty.apply_optional_specialization(db, specialization));
22362247

22372248
let mut init = true;
2249+
let mut kw_only = None;
22382250
if let Some(Type::KnownInstance(KnownInstanceType::Field(field))) = default_ty {
22392251
default_ty = Some(field.default_type(db));
22402252
init = field.init(db);
2253+
kw_only = field.kw_only(db);
22412254
}
22422255

22432256
attributes.insert(
@@ -2247,6 +2260,7 @@ impl<'db> ClassLiteral<'db> {
22472260
default_ty,
22482261
init_only: attr.is_init_var(),
22492262
init,
2263+
kw_only,
22502264
},
22512265
);
22522266
}

0 commit comments

Comments
 (0)