Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion sqlmodel/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,13 @@ def get_field_metadata(field: Any) -> Any:
return FakeMetadata()

def post_init_field_info(field_info: FieldInfo) -> None:
return None
if IS_PYDANTIC_V2:
if field_info.alias and not field_info.validation_alias:
field_info.validation_alias = field_info.alias
if field_info.alias and not field_info.serialization_alias:
field_info.serialization_alias = field_info.alias
else:
field_info._validate() # type: ignore[attr-defined]

# Dummy to make it importable
def _calculate_keys(
Expand Down
102 changes: 66 additions & 36 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ def Field(
*,
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
validation_alias: Optional[str] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pydantic's validation_alias type annotation is wider (validation_alias: str | AliasPath | AliasChoices | None), but I think it's fine if we only support str for now and extend it later

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a strong preference on this.

serialization_alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
exclude: Union[
Expand Down Expand Up @@ -260,6 +262,8 @@ def Field(
*,
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
validation_alias: Optional[str] = None,
serialization_alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
exclude: Union[
Expand Down Expand Up @@ -314,6 +318,8 @@ def Field(
*,
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
validation_alias: Optional[str] = None,
serialization_alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
exclude: Union[
Expand Down Expand Up @@ -349,6 +355,8 @@ def Field(
*,
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
validation_alias: Optional[str] = None,
serialization_alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
exclude: Union[
Expand Down Expand Up @@ -387,43 +395,65 @@ def Field(
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
current_schema_extra = schema_extra or {}
field_info = FieldInfo(
default,
default_factory=default_factory,
alias=alias,
title=title,
description=description,
exclude=exclude,
include=include,
const=const,
gt=gt,
ge=ge,
lt=lt,
le=le,
multiple_of=multiple_of,
max_digits=max_digits,
decimal_places=decimal_places,
min_items=min_items,
max_items=max_items,
unique_items=unique_items,
min_length=min_length,
max_length=max_length,
allow_mutation=allow_mutation,
regex=regex,
discriminator=discriminator,
repr=repr,
primary_key=primary_key,
foreign_key=foreign_key,
ondelete=ondelete,
unique=unique,
nullable=nullable,
index=index,
sa_type=sa_type,
sa_column=sa_column,
sa_column_args=sa_column_args,
sa_column_kwargs=sa_column_kwargs,
field_info_kwargs = {
"alias": alias,
"title": title,
"description": description,
"exclude": exclude,
"include": include,
"const": const,
"gt": gt,
"ge": ge,
"lt": lt,
"le": le,
"multiple_of": multiple_of,
"max_digits": max_digits,
"decimal_places": decimal_places,
"min_items": min_items,
"max_items": max_items,
"unique_items": unique_items,
"min_length": min_length,
"max_length": max_length,
"allow_mutation": allow_mutation,
"regex": regex,
"discriminator": discriminator,
"repr": repr,
"primary_key": primary_key,
"foreign_key": foreign_key,
"ondelete": ondelete,
"unique": unique,
"nullable": nullable,
"index": index,
"sa_type": sa_type,
"sa_column": sa_column,
"sa_column_args": sa_column_args,
"sa_column_kwargs": sa_column_kwargs,
**current_schema_extra,
)
}
if IS_PYDANTIC_V2:
# Add Pydantic v2 specific parameters
field_info_kwargs.update(
{
"validation_alias": validation_alias,
"serialization_alias": serialization_alias,
}
)
field_info = FieldInfo(
default,
default_factory=default_factory,
**field_info_kwargs,
)
else:
if validation_alias:
raise RuntimeError("validation_alias is not supported in Pydantic v1")
if serialization_alias:
raise RuntimeError("serialization_alias is not supported in Pydantic v1")
field_info = FieldInfo(
default,
default_factory=default_factory,
**field_info_kwargs,
)

post_init_field_info(field_info)
return field_info

Expand Down
176 changes: 176 additions & 0 deletions tests/test_aliases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
from typing import Type, Union

import pytest
from pydantic import VERSION, BaseModel, ValidationError
from pydantic import Field as PField
from sqlmodel import Field, SQLModel

# -----------------------------------------------------------------------------------
# Models


class PydanticUser(BaseModel):
full_name: str = PField(alias="fullName")


class SQLModelUser(SQLModel):
full_name: str = Field(alias="fullName")


# Models with config (validate_by_name=True)


if VERSION.startswith("2."):

class PydanticUserWithConfig(PydanticUser):
model_config = {"validate_by_name": True}

class SQLModelUserWithConfig(SQLModelUser):
model_config = {"validate_by_name": True}

else:

class PydanticUserWithConfig(PydanticUser):
class Config:
allow_population_by_field_name = True

class SQLModelUserWithConfig(SQLModelUser):
class Config:
allow_population_by_field_name = True


# -----------------------------------------------------------------------------------
# Tests

# Test validate by name


@pytest.mark.parametrize("model", [PydanticUser, SQLModelUser])
def test_create_with_field_name(model: Union[Type[PydanticUser], Type[SQLModelUser]]):
with pytest.raises(ValidationError):
model(full_name="Alice")


@pytest.mark.parametrize("model", [PydanticUserWithConfig, SQLModelUserWithConfig])
def test_create_with_field_name_with_config(
model: Union[Type[PydanticUserWithConfig], Type[SQLModelUserWithConfig]],
):
user = model(full_name="Alice")
assert user.full_name == "Alice"


# Test validate by alias


@pytest.mark.parametrize(
"model",
[PydanticUser, SQLModelUser, PydanticUserWithConfig, SQLModelUserWithConfig],
)
def test_create_with_alias(
model: Union[
Type[PydanticUser],
Type[SQLModelUser],
Type[PydanticUserWithConfig],
Type[SQLModelUserWithConfig],
],
):
user = model(fullName="Bob") # using alias
assert user.full_name == "Bob"


# Test validate by name and alias


@pytest.mark.parametrize("model", [PydanticUserWithConfig, SQLModelUserWithConfig])
def test_create_with_both_prefers_alias(
model: Union[Type[PydanticUserWithConfig], Type[SQLModelUserWithConfig]],
):
user = model(full_name="IGNORED", fullName="Charlie")
assert user.full_name == "Charlie" # alias should take precedence


# Test serialize


@pytest.mark.parametrize("model", [PydanticUser, SQLModelUser])
def test_dict_default_uses_field_names(
model: Union[Type[PydanticUser], Type[SQLModelUser]],
):
user = model(fullName="Dana")
data = user.dict()
assert "full_name" in data
assert "fullName" not in data
assert data["full_name"] == "Dana"


# Test serialize by alias


@pytest.mark.parametrize("model", [PydanticUser, SQLModelUser])
def test_dict_default_uses_aliases(
model: Union[Type[PydanticUser], Type[SQLModelUser]],
):
user = model(fullName="Dana")
data = user.dict(by_alias=True)
assert "fullName" in data
assert "full_name" not in data
assert data["fullName"] == "Dana"


# Test json by alias


@pytest.mark.parametrize("model", [PydanticUser, SQLModelUser])
def test_json_by_alias(
model: Union[Type[PydanticUser], Type[SQLModelUser]],
):
user = model(fullName="Frank")
json_data = user.json(by_alias=True)
assert ('"fullName":"Frank"' in json_data) or ('"fullName": "Frank"' in json_data)
assert "full_name" not in json_data


# Pydantic v2 specific models - only define if we're running Pydantic v2
if VERSION.startswith("2."):

class PydanticUserV2(BaseModel):
first_name: str = PField(
validation_alias="firstName", serialization_alias="f_name"
)

class SQLModelUserV2(SQLModel):
first_name: str = Field(
validation_alias="firstName", serialization_alias="f_name"
)
else:
# Dummy classes for Pydantic v1 to prevent import errors
PydanticUserV2 = None
SQLModelUserV2 = None


@pytest.mark.skipif(
not VERSION.startswith("2."),
reason="validation_alias and serialization_alias are not supported in Pydantic v1",
)
@pytest.mark.parametrize("model", [PydanticUserV2, SQLModelUserV2])
def test_create_with_validation_alias(
model: Union[Type[PydanticUserV2], Type[SQLModelUserV2]],
):
user = model(firstName="John")
assert user.first_name == "John"


@pytest.mark.skipif(
not VERSION.startswith("2."),
reason="validation_alias and serialization_alias are not supported in Pydantic v1",
)
@pytest.mark.parametrize("model", [PydanticUserV2, SQLModelUserV2])
def test_serialize_with_serialization_alias(
model: Union[Type[PydanticUserV2], Type[SQLModelUserV2]],
):
user = model(firstName="Jane")
data = user.dict(by_alias=True)
assert "f_name" in data
assert "firstName" not in data
assert "first_name" not in data
assert data["f_name"] == "Jane"
Loading