Skip to content

Commit b0cecb6

Browse files
committed
New required field handling
1 parent 3c246b3 commit b0cecb6

6 files changed

Lines changed: 287 additions & 28 deletions

File tree

pydantic_forms/core/shared.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import structlog
1717
from pydantic import BaseModel, ConfigDict, PydanticUndefinedAnnotation, version
1818

19+
from pydantic_forms.utils.required import determine_required_form_fields
20+
1921
logger = structlog.get_logger(__name__)
2022

2123

@@ -45,6 +47,16 @@ def get_value(k: str, v: Any) -> Any:
4547
mutable_data = {k: get_value(k, v) for k, v in data.items()}
4648
super().__init__(**mutable_data)
4749

50+
@classmethod
51+
def model_json_schema(cls, *args: Any, **kwargs: Any) -> dict[str, Any]:
52+
schema = super().model_json_schema(*args, **kwargs)
53+
required_fields = determine_required_form_fields(cls)
54+
55+
# TODO add toggle
56+
if new_required := [k for k, v in required_fields.items() if v]:
57+
schema["required"] = new_required
58+
return schema
59+
4860
if PYDANTIC_VERSION in ("2.9", "2.10", "2.11"):
4961

5062
@classmethod

pydantic_forms/utils/required.py

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
# Copyright 2019-2026 SURF.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
import types
14+
from collections.abc import Iterable
15+
from typing import (
16+
Annotated,
17+
Any,
18+
TypeVar,
19+
Union,
20+
get_args,
21+
get_origin,
22+
)
23+
24+
from more_itertools import first
25+
from pydantic import BaseModel
26+
from pydantic.fields import FieldInfo
27+
28+
29+
def is_union(tp: type[Any] | None) -> bool:
30+
return tp is Union or tp is types.UnionType # type: ignore[comparison-overlap]
31+
32+
33+
def get_origin_and_args(t: Any) -> tuple[Any, tuple[Any, ...]]:
34+
"""Return the origin and args of the given type.
35+
36+
When wrapped in Annotated[] this is removed.
37+
"""
38+
origin, args = get_origin(t), get_args(t)
39+
if origin is not Annotated:
40+
return origin, args
41+
42+
t_unwrapped = first(args)
43+
return get_origin(t_unwrapped), get_args(t_unwrapped)
44+
45+
46+
def is_union_type(t: Any, test_type: type | None = None) -> bool:
47+
"""Check if `t` is union type (Union[Type, AnotherType]).
48+
49+
Optionally check if T is of `test_type` We cannot check for literal Nones.
50+
51+
>>> is_union_type(Union[int, str])
52+
True
53+
>>> is_union_type(Annotated[Union[int, str], "foo"])
54+
True
55+
>>> is_union_type(Union[int, str], str)
56+
True
57+
>>> is_union_type(Union[int, str], bool)
58+
False
59+
>>> is_union_type(Union[int, str], Union[int, str])
60+
True
61+
>>> is_union_type(Union[int, None])
62+
True
63+
>>> is_union_type(Annotated[Union[int, None], "foo"])
64+
True
65+
>>> is_union_type(int)
66+
False
67+
"""
68+
origin, args = get_origin_and_args(t)
69+
if not is_union(origin):
70+
return False
71+
if not test_type:
72+
return True
73+
74+
if is_of_type(t, test_type):
75+
return True
76+
77+
for arg in args:
78+
result = is_of_type(arg, test_type)
79+
if result:
80+
return result
81+
return False
82+
83+
84+
def is_of_type(t: Any, test_type: Any) -> bool:
85+
"""Check if annotation type is valid for type.
86+
87+
>>> is_of_type(list, list)
88+
True
89+
>>> is_of_type(list[int], list[int])
90+
True
91+
>>> is_of_type(strEnum, str)
92+
True
93+
>>> is_of_type(strEnum, Enum)
94+
True
95+
>>> is_of_type(int, str)
96+
False
97+
>>> is_of_type(Any, Any)
98+
True
99+
>>> is_of_type(Any, int)
100+
True
101+
"""
102+
if t is Any:
103+
return True
104+
105+
if is_union_type(test_type):
106+
return any(get_origin(t) is get_origin(arg) for arg in get_args(test_type))
107+
108+
if (
109+
get_origin(t)
110+
and get_origin(test_type)
111+
and get_origin(t) is get_origin(test_type)
112+
and get_args(t) == get_args(test_type)
113+
):
114+
return True
115+
116+
if test_type is t:
117+
# Test type is a typing type instance and matches
118+
return True
119+
120+
# Workaround for the fact that you can't call issubclass on typing types
121+
try:
122+
return issubclass(t, test_type)
123+
except TypeError:
124+
return False
125+
126+
127+
def filter_nonetype(types_: Iterable[Any]) -> Iterable[Any]:
128+
def not_nonetype(type_: Any) -> bool:
129+
return type_ is not None.__class__
130+
131+
return filter(not_nonetype, types_)
132+
133+
134+
def is_optional_type(t: Any, test_type: type | None = None) -> bool:
135+
"""Check if `t` is optional type (Union[None, ...]).
136+
137+
And optionally check if T is of `test_type`
138+
139+
>>> is_optional_type(Optional[int])
140+
True
141+
>>> is_optional_type(Annotated[Optional[int], "foo"])
142+
True
143+
>>> is_optional_type(Annotated[int, "foo"])
144+
False
145+
>>> is_optional_type(Union[None, int])
146+
True
147+
>>> is_optional_type(Union[int, str, None])
148+
True
149+
>>> is_optional_type(Union[int, str])
150+
False
151+
>>> is_optional_type(Optional[int], int)
152+
True
153+
>>> is_optional_type(Optional[int], str)
154+
False
155+
>>> is_optional_type(Annotated[Optional[int], "foo"], int)
156+
True
157+
>>> is_optional_type(Annotated[Optional[int], "foo"], str)
158+
False
159+
>>> is_optional_type(Optional[State], int)
160+
False
161+
>>> is_optional_type(Optional[State], State)
162+
True
163+
"""
164+
origin, args = get_origin_and_args(t)
165+
166+
if is_union(origin) and None.__class__ in args:
167+
field_type = first(filter_nonetype(args))
168+
return test_type is None or is_of_type(field_type, test_type)
169+
return False
170+
171+
172+
# TODO The above code is copy-pasted from orchestrator-core/types.py.
173+
# Maybe something to move to a shared lib some day.
174+
175+
176+
def _is_required(field: FieldInfo) -> bool:
177+
"""Determine whether a FormPage field is required.
178+
179+
Our logic extends that of Pydantic because of our common practice to use FormPage to transmit data.
180+
TODO explain better
181+
"""
182+
match field.annotation, field.is_required(), field.default, field.json_schema_extra:
183+
case _, True, _, _:
184+
# Pydantic considers the field as required
185+
return True
186+
case _, False, None, _:
187+
# Pydantic considers the field as optional, and the default is none
188+
return False
189+
case _, _, _, {"format": "read_only_field"}:
190+
# pydantic-forms fields which we never want to mark as required
191+
# TODO: is this complete?
192+
return False
193+
case t, _, _, _:
194+
# A field is required if it's not optional (makes sense, doesn't it?)
195+
return not is_optional_type(t)
196+
case _:
197+
# For any combination we've missed, the safest assumption is that it's not required
198+
return False
199+
200+
201+
BaseModelDerivative = TypeVar("BaseModelDerivative", bound=BaseModel)
202+
203+
204+
def determine_required_form_fields(form: type[BaseModelDerivative]) -> dict[str, bool]:
205+
return {name: _is_required(field) for name, field in form.model_fields.items()}

pydantic_forms/validators/components/read_only.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def _get_read_only_schema(default: Any) -> dict:
6262
"uniforms": forms_schema, # Deprecated
6363
constants.EXTRA_PROPERTIES: forms_schema,
6464
"type": _get_json_type(default),
65+
"format": "read_only_field",
6566
}
6667

6768

tests/unit_tests/test_core.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
from inline_snapshot import snapshot
23
from pydantic import ConfigDict, model_validator
34

45
from pydantic_forms.core import FormPage, generate_form, post_form
@@ -7,6 +8,7 @@
78

89
# TODO: Remove when generic forms of pydantic_forms are ready
910
from pydantic_forms.utils.json import json_dumps, json_loads
11+
from pydantic_forms.utils.required import determine_required_form_fields
1012

1113

1214
class TestChoices(strEnum):
@@ -223,3 +225,34 @@ def validator(cls, values: dict) -> dict:
223225
assert len(e.value.errors) == 1
224226
assert e.value.errors[0]["loc"] == ("__root__",)
225227
assert e.value.errors[0]["msg"] == "too high"
228+
229+
230+
class FormWithAllDefaultScenarios(FormPage):
231+
field1: int
232+
field2: int = 1
233+
field3: int | None # Probably not used
234+
field4: int | None = None # Dito
235+
field5: int | None = 1
236+
237+
238+
def test_defaults():
239+
assert FormWithAllDefaultScenarios.model_json_schema() == snapshot(
240+
{
241+
"additionalProperties": False,
242+
"properties": {
243+
"field1": {"title": "Field1", "type": "integer"},
244+
"field2": {"default": 1, "title": "Field2", "type": "integer"},
245+
"field3": {"anyOf": [{"type": "integer"}, {"type": "null"}], "title": "Field3"},
246+
"field4": {"anyOf": [{"type": "integer"}, {"type": "null"}], "default": None, "title": "Field4"},
247+
"field5": {"anyOf": [{"type": "integer"}, {"type": "null"}], "default": 1, "title": "Field5"},
248+
},
249+
"required": ["field1", "field2", "field3"],
250+
"title": "unknown",
251+
"type": "object",
252+
}
253+
)
254+
255+
256+
def test_defaults2():
257+
requireds = determine_required_form_fields(FormWithAllDefaultScenarios)
258+
assert requireds == snapshot({"field1": True, "field2": True, "field3": True, "field4": False, "field5": False})
Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from uuid import uuid4
22

3+
from inline_snapshot import Is, snapshot
4+
35
from pydantic_forms.core import FormPage
46
from pydantic_forms.validators import DisplaySubscription, Label, migration_summary
57

@@ -35,34 +37,37 @@ class Form(FormPage):
3537
label: Label
3638
summary: Summary
3739

38-
expected = {
39-
"$defs": {"MigrationSummaryValue": {"properties": {}, "title": "MigrationSummaryValue", "type": "object"}},
40-
"additionalProperties": False,
41-
"properties": {
42-
"display_sub": {
43-
"default": str(some_sub_id),
44-
"format": "subscription",
45-
"title": "Display Sub",
46-
"type": "string",
47-
},
48-
"label": {
49-
"anyOf": [{"type": "string"}, {"type": "null"}],
50-
"format": "label",
51-
"default": None,
52-
"title": "Label",
53-
"type": "string",
54-
},
55-
"summary": {
56-
"$ref": "#/$defs/MigrationSummaryValue",
57-
"default": None,
58-
"format": "summary",
59-
"type": "string",
60-
"uniforms": {"data": {"headers": ["one"]}},
61-
"extraProperties": {"data": {"headers": ["one"]}},
40+
expected = snapshot(
41+
{
42+
"$defs": {"MigrationSummaryValue": {"properties": {}, "title": "MigrationSummaryValue", "type": "object"}},
43+
"additionalProperties": False,
44+
"properties": {
45+
"display_sub": {
46+
"default": Is(str(some_sub_id)),
47+
"format": "subscription",
48+
"title": "Display Sub",
49+
"type": "string",
50+
},
51+
"label": {
52+
"anyOf": [{"type": "string"}, {"type": "null"}],
53+
"default": None,
54+
"format": "label",
55+
"title": "Label",
56+
"type": "string",
57+
},
58+
"summary": {
59+
"$ref": "#/$defs/MigrationSummaryValue",
60+
"default": None,
61+
"extraProperties": {"data": {"headers": ["one"]}},
62+
"format": "summary",
63+
"type": "string",
64+
"uniforms": {"data": {"headers": ["one"]}},
65+
},
6266
},
63-
},
64-
"title": "unknown",
65-
"type": "object",
66-
}
67+
"title": "unknown",
68+
"type": "object",
69+
"required": ["display_sub"],
70+
}
71+
)
6772

6873
assert Form.model_json_schema() == expected

tests/unit_tests/test_read_only_field.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class Form(FormPage):
4949
**enum,
5050
"const": schema_value,
5151
"default": schema_value,
52+
"format": "read_only_field",
5253
"title": "Read Only",
5354
"uniforms": {"disabled": True, "value": schema_value},
5455
"extraProperties": {"disabled": True, "value": schema_value},
@@ -88,6 +89,7 @@ class Form(FormPage):
8889
"read_only_list": {
8990
"default": schema_value,
9091
"items": expected_item_type,
92+
"format": "read_only_field",
9193
"title": "Read Only List",
9294
"uniforms": {"disabled": True, "value": schema_value},
9395
"extraProperties": {"disabled": True, "value": schema_value},
@@ -175,6 +177,7 @@ class Form(FormPage):
175177
"read_only_list": {
176178
"default": schema_value,
177179
"items": expected_item_type,
180+
"format": "read_only_field",
178181
"title": "Read Only List",
179182
"uniforms": {"disabled": True, "value": schema_value},
180183
"extraProperties": {"disabled": True, "value": schema_value},

0 commit comments

Comments
 (0)