diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 3b3a07a14..ed7f10e52 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,14 @@ Changelog ========= +Version 0.5.0 +============= + +* Feature: Migrated from Pydantic v1 to Pydantic v2 +* Removed dependency on dataclasses-json in favor of Pydantic models +* Updated the JSON serialization to work with Pydantic v2 +* Updated schema validation to use the new Pydantic v2 API + Version 0.4.7 ============= diff --git a/deployment/scripts/sync_initial_messages.py b/deployment/scripts/sync_initial_messages.py index b1b76fc51..437fbf321 100644 --- a/deployment/scripts/sync_initial_messages.py +++ b/deployment/scripts/sync_initial_messages.py @@ -8,11 +8,14 @@ initial_messages_list = [ # Diagnostic VMs "cad11970efe9b7478300fd04d7cc91c646ca0a792b9cc718650f86e1ccfac73e", # Initial program - "3fc0aa9569da840c43e7bd2033c3c580abb46b007527d6d20f2d4e98e867f7af", # DiagVM + "3fc0aa9569da840c43e7bd2033c3c580abb46b007527d6d20f2d4e98e867f7af", # Old DiagVM Debian 12 + "63faf8b5db1cf8d965e6a464a0cb8062af8e7df131729e48738342d956f29ace", # Current Debian 12 DiagVM "67705389842a0a1b95eaa408b009741027964edc805997475e95c505d642edd8", # Legacy Diag VM # Volumes like runtimes, data, code, etc - "6b8618f5b8913c0f582f1a771a154a556ee3fa3437ef3cf91097819910cf383b", # Diag VM code volume - "f873715dc2feec3833074bd4b8745363a0e0093746b987b4c8191268883b2463", # Diag VM runtime volume + "6b8618f5b8913c0f582f1a771a154a556ee3fa3437ef3cf91097819910cf383b", # Old Diag VM code volume + "f873715dc2feec3833074bd4b8745363a0e0093746b987b4c8191268883b2463", # Old Diag VM runtime volume + "79f19811f8e843f37ff7535f634b89504da3d8f03e1f0af109d1791cf6add7af", # Diag VM code volume + "63f07193e6ee9d207b7d1fcf8286f9aee34e6f12f101d2ec77c1229f92964696", # Diag VM runtime volume "a92c81992e885d7a554fa78e255a5802404b7fdde5fbff20a443ccd13020d139", # Legacy Diag VM code volume "bd79839bf96e595a06da5ac0b6ba51dea6f7e2591bb913deccded04d831d29f4", # Legacy Diag VM runtime 
volume ] diff --git a/pyproject.toml b/pyproject.toml index ca5130b5e..7779d122b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "aiohttp-jinja2==1.6", "aioipfs~=0.7.1", "alembic==1.15.1", - "aleph-message==0.6.1", + "aleph-message==1.0.0", "aleph-nuls2==0.1", "aleph-p2p-client @ git+https://github.com/aleph-im/p2p-service-client-python@cbfebb871db94b2ca580e66104a67cd730c5020c", "asyncpg==0.30", @@ -46,7 +46,8 @@ dependencies = [ "multiaddr==0.0.9", # for libp2p-stubs "orjson>=3.7.7", # Minimum version for Python 3.11 "psycopg2-binary==2.9.10", # Note: psycopg3 is not yet supported by SQLAlchemy - "pycryptodome==3.22.0", # for libp2p-stubs + "pycryptodome==3.22.0", + "pydantic>=2.0.0,<3.0.0", "pymultihash==0.8.2", # for libp2p-stubs "pynacl==1.5", "pytezos-crypto==3.13.4.1", @@ -62,6 +63,7 @@ dependencies = [ "sqlalchemy-utils==0.41.2", "substrate-interface==1.7.11", "types-aiofiles==24.1.0.20241221", + "typing-extensions>=4.6.1", "ujson==5.10.0", # required by aiocache "urllib3==2.3", "uvloop==0.21", @@ -158,9 +160,7 @@ dependencies = [ "isort==5.13.2", "check-sdist==0.1.3", "sqlalchemy[mypy]==1.4.41", - "yamlfix==1.16.1", - # because of aleph messages otherwise yamlfix install a too new version - "pydantic>=1.10.5,<2.0.0", + "yamlfix==1.17.0", "pyproject-fmt==2.2.1", "types-aiofiles", "types-protobuf", diff --git a/src/aleph/chains/chain_data_service.py b/src/aleph/chains/chain_data_service.py index 7eadf1afe..0ce3b3e33 100644 --- a/src/aleph/chains/chain_data_service.py +++ b/src/aleph/chains/chain_data_service.py @@ -1,4 +1,5 @@ import asyncio +import json from typing import Any, Dict, List, Mapping, Optional, Self, Set, Type, Union, cast import aio_pika.abc @@ -66,10 +67,12 @@ async def prepare_sync_event_payload( protocol=ChainSyncProtocol.ON_CHAIN_SYNC, version=1, content=OnChainContent( - messages=[OnChainMessage.from_orm(message) for message in messages] + messages=[ + 
OnChainMessage.model_validate(message) for message in messages + ] ), ) - archive_content: bytes = archive.json().encode("utf-8") + archive_content: bytes = archive.model_dump_json().encode("utf-8") ipfs_cid = await self.storage_service.add_file( session=session, file_content=archive_content, engine=ItemType.ipfs @@ -166,7 +169,9 @@ def _get_tx_messages_smart_contract_protocol(tx: ChainTxDb) -> List[Dict[str, An ) try: - payload = cast(GenericMessageEvent, payload_model.parse_obj(tx.content)) + payload = cast( + GenericMessageEvent, payload_model.model_validate(tx.content) + ) except ValidationError: raise InvalidContent(f"Incompatible tx content for {tx.chain}/{tx.hash}") @@ -189,7 +194,7 @@ def _get_tx_messages_smart_contract_protocol(tx: ChainTxDb) -> List[Dict[str, An item_hash=ItemHash(payload.content), metadata=None, ) - item_content = content.json(exclude_none=True) + item_content = json.dumps(content.model_dump(exclude_none=True)) else: item_content = payload.content diff --git a/src/aleph/chains/indexer_reader.py b/src/aleph/chains/indexer_reader.py index db6e8d542..0e9f81f72 100644 --- a/src/aleph/chains/indexer_reader.py +++ b/src/aleph/chains/indexer_reader.py @@ -84,16 +84,16 @@ def make_events_query( if not datetime_range and not block_range: raise ValueError("A range of datetimes or blocks must be specified.") - model: Union[Type[MessageEvent], Type[SyncEvent]] + model_fields: List[str] if event_type == ChainEventType.MESSAGE: - model = MessageEvent event_type_str = "messageEvents" + model_fields = list(MessageEvent.model_fields.keys()) else: - model = SyncEvent event_type_str = "syncEvents" + model_fields = list(SyncEvent.model_fields.keys()) - fields = "\n".join(model.__fields__.keys()) + fields = "\n".join(model_fields) params: Dict[str, Any] = { "blockchain": f'"{blockchain.value}"', "limit": limit, @@ -147,7 +147,7 @@ async def _query(self, query: str, model: Type[T]) -> T: response = await self.http_session.post("/", json={"query": query}) 
response.raise_for_status() response_json = await response.json() - return model.parse_obj(response_json) + return model.model_validate(response_json) async def fetch_account_state( self, @@ -196,7 +196,7 @@ def indexer_event_to_chain_tx( if isinstance(indexer_event, MessageEvent): protocol = ChainSyncProtocol.SMART_CONTRACT protocol_version = 1 - content = indexer_event.dict() + content = indexer_event.model_dump() else: sync_message = aleph_json.loads(indexer_event.message) diff --git a/src/aleph/chains/tezos.py b/src/aleph/chains/tezos.py index b8a951d8b..6cb34e6aa 100644 --- a/src/aleph/chains/tezos.py +++ b/src/aleph/chains/tezos.py @@ -162,7 +162,7 @@ async def fetch_messages( response.raise_for_status() response_json = await response.json() - return IndexerResponse[IndexerMessageEvent].parse_obj(response_json) + return IndexerResponse[IndexerMessageEvent].model_validate(response_json) def indexer_event_to_chain_tx( @@ -176,7 +176,7 @@ def indexer_event_to_chain_tx( publisher=indexer_event.source, protocol=ChainSyncProtocol.SMART_CONTRACT, protocol_version=1, - content=indexer_event.payload.dict(), + content=indexer_event.payload.model_dump(), ) return chain_tx diff --git a/src/aleph/db/models/messages.py b/src/aleph/db/models/messages.py index 0d2f8e894..eb06ae12f 100644 --- a/src/aleph/db/models/messages.py +++ b/src/aleph/db/models/messages.py @@ -14,7 +14,6 @@ StoreContent, ) from pydantic import ValidationError -from pydantic.error_wrappers import ErrorWrapper from sqlalchemy import ( ARRAY, TIMESTAMP, @@ -62,14 +61,14 @@ def validate_message_content( content_dict: Dict[str, Any], ) -> BaseContent: content_type = CONTENT_TYPE_MAP[message_type] - content = content_type.parse_obj(content_dict) + content = content_type.model_validate(content_dict) # Validate that the content time can be converted to datetime. 
This will # raise a ValueError and be caught # TODO: move this validation in aleph-message try: _ = dt.datetime.fromtimestamp(content_dict["time"]) except ValueError as e: - raise ValidationError([ErrorWrapper(e, loc="time")], model=content_type) from e + raise ValidationError(str(e)) from e return content diff --git a/src/aleph/handlers/content/vm.py b/src/aleph/handlers/content/vm.py index 35aa352e8..3f1896553 100644 --- a/src/aleph/handlers/content/vm.py +++ b/src/aleph/handlers/content/vm.py @@ -195,7 +195,7 @@ def vm_message_to_db(message: MessageDb) -> VmBaseDb: if content.on.message: vm.message_triggers = [ - subscription.dict() for subscription in content.on.message + subscription.model_dump() for subscription in content.on.message ] vm.code_volume = CodeVolumeDb( diff --git a/src/aleph/schemas/api/accounts.py b/src/aleph/schemas/api/accounts.py index e0e9347a5..c9283fa6b 100644 --- a/src/aleph/schemas/api/accounts.py +++ b/src/aleph/schemas/api/accounts.py @@ -1,9 +1,9 @@ import datetime as dt from decimal import Decimal -from typing import Dict, List, Optional +from typing import Annotated, Dict, List, Optional from aleph_message.models import Chain -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, ConfigDict, Field, PlainSerializer, field_validator from aleph.types.files import FileType from aleph.types.sort_order import SortOrder @@ -16,11 +16,16 @@ class GetAccountQueryParams(BaseModel): ) +FloatDecimal = Annotated[ + Decimal, PlainSerializer(lambda x: float(x), return_type=float, when_used="json") +] + + class GetAccountBalanceResponse(BaseModel): address: str - balance: Decimal - details: Optional[Dict[str, Decimal]] - locked_amount: Decimal + balance: FloatDecimal + details: Optional[Dict[str, FloatDecimal]] = None + locked_amount: FloatDecimal class GetAccountFilesQueryParams(BaseModel): @@ -53,7 +58,7 @@ class GetBalancesChainsQueryParams(BaseModel): ) min_balance: int = Field(default=0, ge=1, description="Minimum 
Balance needed") - @validator("chains", pre=True) + @field_validator("chains", mode="before") def split_str(cls, v): if isinstance(v, str): return v.split(LIST_FIELD_SEPARATOR) @@ -61,8 +66,7 @@ def split_str(cls, v): class AddressBalanceResponse(BaseModel): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) address: str balance: str @@ -78,8 +82,7 @@ class GetAccountFilesResponseItem(BaseModel): class GetAccountFilesResponse(BaseModel): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) address: str total_size: int diff --git a/src/aleph/schemas/api/costs.py b/src/aleph/schemas/api/costs.py index ef66ae81e..03610bbf6 100644 --- a/src/aleph/schemas/api/costs.py +++ b/src/aleph/schemas/api/costs.py @@ -1,27 +1,25 @@ from typing import List -from pydantic import BaseModel, validator +from pydantic import BaseModel, ConfigDict, field_validator from aleph.toolkit.costs import format_cost_str class EstimatedCostDetailResponse(BaseModel): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) type: str name: str cost_hold: str cost_stream: str - @validator("cost_hold", "cost_stream") + @field_validator("cost_hold", "cost_stream") def check_format_price(cls, v): return format_cost_str(v) class EstimatedCostsResponse(BaseModel): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) required_tokens: float payment_type: str diff --git a/src/aleph/schemas/api/messages.py b/src/aleph/schemas/api/messages.py index dea39f1aa..c15d76ea7 100644 --- a/src/aleph/schemas/api/messages.py +++ b/src/aleph/schemas/api/messages.py @@ -25,10 +25,8 @@ ProgramContent, StoreContent, ) -from pydantic import BaseModel, Field -from pydantic.generics import GenericModel +from pydantic import BaseModel, ConfigDict, Field, field_serializer -import aleph.toolkit.json as aleph_json from aleph.db.models import MessageDb from aleph.types.message_status import ErrorCode, 
MessageStatus @@ -39,26 +37,26 @@ class MessageConfirmation(BaseModel): """Format of the result when a message has been confirmed on a blockchain""" - class Config: - orm_mode = True - json_encoders = {dt.datetime: lambda d: d.timestamp()} + model_config = ConfigDict(from_attributes=True) chain: Chain height: int hash: str + datetime: dt.datetime + @field_serializer("datetime") + def serialize_time(self, dt: dt.datetime, _info) -> float: + return dt.timestamp() -class BaseMessage(GenericModel, Generic[MType, ContentType]): - class Config: - orm_mode = True - json_loads = aleph_json.loads - json_encoders = {dt.datetime: lambda d: d.timestamp()} + +class BaseMessage(BaseModel, Generic[MType, ContentType]): + model_config = ConfigDict(from_attributes=True) sender: str chain: Chain - signature: Optional[str] + signature: Optional[str] = None type: MType - item_content: Optional[str] + item_content: Optional[str] = None item_type: ItemType item_hash: str time: dt.datetime @@ -67,6 +65,10 @@ class Config: confirmed: bool confirmations: List[MessageConfirmation] + @field_serializer("time") + def serialize_time(self, dt: dt.datetime, _info) -> float: + return dt.timestamp() + class AggregateMessage( BaseMessage[Literal[MessageType.aggregate], AggregateContent] # type: ignore @@ -127,13 +129,14 @@ def format_message(message: MessageDb) -> AlephMessage: message_type = message.type message_cls = MESSAGE_CLS_DICT[message_type] - return message_cls.from_orm(message) + + return message_cls.model_validate(message) def format_message_dict(message: Dict[str, Any]) -> AlephMessage: message_type = message.get("type") message_cls = MESSAGE_CLS_DICT[message_type] - return message_cls.parse_obj(message) + return message_cls.model_validate(message) class BaseMessageStatus(BaseModel): @@ -145,45 +148,41 @@ class BaseMessageStatus(BaseModel): # We already have a model for the validation of pending messages, but this one # is only used for formatting and does not try to be smart. 
class PendingMessage(BaseModel): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) sender: str chain: Chain - signature: Optional[str] + signature: Optional[str] = None type: MessageType - item_content: Optional[str] + item_content: Optional[str] = None item_type: ItemType item_hash: str time: dt.datetime channel: Optional[str] = None - content: Optional[Dict[str, Any]] + content: Optional[Dict[str, Any]] = None reception_time: dt.datetime class PendingMessageStatus(BaseMessageStatus): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) status: MessageStatus = MessageStatus.PENDING messages: List[PendingMessage] class ProcessedMessageStatus(BaseMessageStatus): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) status: MessageStatus = MessageStatus.PROCESSED message: AlephMessage class ForgottenMessage(BaseModel): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) sender: str chain: Chain - signature: Optional[str] + signature: Optional[str] = None type: MessageType item_type: ItemType item_hash: str @@ -201,18 +200,15 @@ class RejectedMessageStatus(BaseMessageStatus): status: MessageStatus = MessageStatus.REJECTED message: Mapping[str, Any] error_code: ErrorCode - details: Any + details: Any = None class MessageStatusInfo(BaseMessageStatus): - class Config: - orm_mode = True - fields = {"item_hash": {"exclude": True}} + model_config = ConfigDict(from_attributes=True) class MessageHashes(BaseMessageStatus): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) MessageWithStatus = Union[ @@ -224,12 +220,13 @@ class Config: class MessageListResponse(BaseModel): - class Config: - json_encoders = {dt.datetime: lambda d: d.timestamp()} - json_loads = aleph_json.loads - messages: List[AlephMessage] pagination_page: int pagination_total: int pagination_per_page: int pagination_item: Literal["messages"] = 
"messages" + time: dt.datetime + + @field_serializer("time") + def serialize_time(self, dt: dt.datetime, _info) -> float: + return dt.timestamp() diff --git a/src/aleph/schemas/base_messages.py b/src/aleph/schemas/base_messages.py index c1837e9dd..8b42bc817 100644 --- a/src/aleph/schemas/base_messages.py +++ b/src/aleph/schemas/base_messages.py @@ -7,8 +7,8 @@ from typing import Any, Generic, Mapping, Optional, TypeVar, cast from aleph_message.models import BaseContent, Chain, ItemType, MessageType -from pydantic import root_validator, validator -from pydantic.generics import GenericModel +from pydantic import BaseModel, field_validator, model_validator +from pydantic_core.core_schema import ValidationInfo from aleph.utils import item_type_from_hash @@ -71,7 +71,7 @@ def base_message_validator_check_item_hash(v: Any, values: Mapping[str, Any]): return v -class AlephBaseMessage(GenericModel, Generic[MType, ContentType]): +class AlephBaseMessage(BaseModel, Generic[MType, ContentType]): """ The base structure of an Aleph message. 
All the fields of this class appear in all the representations @@ -89,10 +89,12 @@ class AlephBaseMessage(GenericModel, Generic[MType, ContentType]): channel: Optional[str] = None content: Optional[ContentType] = None - @root_validator() - def check_item_type(cls, values): - return base_message_validator_check_item_type(values) + @model_validator(mode="after") + def check_item_type(self) -> "AlephBaseMessage": + values = self.model_dump() + base_message_validator_check_item_type(values) + return self - @validator("item_hash") - def check_item_hash(cls, v: Any, values: Mapping[str, Any]): - return base_message_validator_check_item_hash(v, values) + @field_validator("item_hash") + def check_item_hash(cls, v: Any, info: ValidationInfo): + return base_message_validator_check_item_hash(v, info.data) diff --git a/src/aleph/schemas/chains/indexer_response.py b/src/aleph/schemas/chains/indexer_response.py index 8a7f9fae5..4bb344505 100644 --- a/src/aleph/schemas/chains/indexer_response.py +++ b/src/aleph/schemas/chains/indexer_response.py @@ -4,9 +4,9 @@ import datetime as dt from enum import Enum -from typing import List, Protocol, Tuple +from typing import Annotated, List, Protocol, Tuple -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, BeforeValidator, Field class GenericMessageEvent(Protocol): @@ -33,6 +33,17 @@ class EntityType(str, Enum): STATE = "state" +def split_datetime_ranges(v): + if isinstance(v, str): + return v.split("/") + return v + + +DateTimeRange = Annotated[ + Tuple[dt.datetime, dt.datetime], BeforeValidator(split_datetime_ranges) +] + + class AccountEntityState(BaseModel): blockchain: IndexerBlockchain type: EntityType @@ -40,14 +51,8 @@ class AccountEntityState(BaseModel): account: str completeHistory: bool progress: float - pending: List[Tuple[dt.datetime, dt.datetime]] - processed: List[Tuple[dt.datetime, dt.datetime]] - - @validator("pending", "processed", pre=True, each_item=True) - def split_datetime_ranges(cls, 
v): - if isinstance(v, str): - return v.split("/") - return v + pending: List[DateTimeRange] + processed: List[DateTimeRange] class IndexerAccountStateResponseData(BaseModel): diff --git a/src/aleph/schemas/chains/sync_events.py b/src/aleph/schemas/chains/sync_events.py index ab4304847..8cbbae22e 100644 --- a/src/aleph/schemas/chains/sync_events.py +++ b/src/aleph/schemas/chains/sync_events.py @@ -2,28 +2,27 @@ from typing import Annotated, List, Literal, Optional, Union from aleph_message.models import Chain, ItemHash, ItemType, MessageType -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, ConfigDict, Field, field_validator from aleph.types.chain_sync import ChainSyncProtocol from aleph.types.channel import Channel class OnChainMessage(BaseModel): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) sender: str chain: Chain - signature: Optional[str] + signature: Optional[str] = None type: MessageType - item_content: Optional[str] + item_content: Optional[str] = None item_type: ItemType item_hash: ItemHash time: float channel: Optional[Channel] = None - @validator("time", pre=True) - def check_time(cls, v, values): + @field_validator("time", mode="before") + def check_time(cls, v, info): if isinstance(v, dt.datetime): return v.timestamp() diff --git a/src/aleph/schemas/chains/tezos_indexer_response.py b/src/aleph/schemas/chains/tezos_indexer_response.py index d9204d427..1b15f5ce7 100644 --- a/src/aleph/schemas/chains/tezos_indexer_response.py +++ b/src/aleph/schemas/chains/tezos_indexer_response.py @@ -2,8 +2,7 @@ from enum import Enum from typing import Generic, List, TypeVar -from pydantic import BaseModel, Field -from pydantic.generics import GenericModel +from pydantic import BaseModel, ConfigDict, Field PayloadType = TypeVar("PayloadType") @@ -24,7 +23,7 @@ class IndexerStats(BaseModel): total_events: int = Field(alias="totalEvents") -class IndexerEvent(GenericModel, Generic[PayloadType]): +class 
IndexerEvent(BaseModel, Generic[PayloadType]): source: str timestamp: dt.datetime block_level: int = Field(alias="blockLevel") @@ -34,8 +33,7 @@ class IndexerEvent(GenericModel, Generic[PayloadType]): class MessageEventPayload(BaseModel): - class Config: - allow_population_by_field_name = True + model_config = ConfigDict(from_attributes=True, populate_by_name=True) timestamp: float addr: str @@ -67,11 +65,11 @@ def timestamp_seconds(self) -> float: IndexerEventType = TypeVar("IndexerEventType", bound=IndexerEvent) -class IndexerResponseData(GenericModel, Generic[IndexerEventType]): +class IndexerResponseData(BaseModel, Generic[IndexerEventType]): index_status: IndexerStatus = Field(alias="indexStatus") stats: IndexerStats events: List[IndexerEventType] -class IndexerResponse(GenericModel, Generic[IndexerEventType]): +class IndexerResponse(BaseModel, Generic[IndexerEventType]): data: IndexerResponseData[IndexerEventType] diff --git a/src/aleph/schemas/cost_estimation_messages.py b/src/aleph/schemas/cost_estimation_messages.py index fdcb57baf..3ff99cc82 100644 --- a/src/aleph/schemas/cost_estimation_messages.py +++ b/src/aleph/schemas/cost_estimation_messages.py @@ -17,7 +17,7 @@ ImmutableVolume, PersistentVolume, ) -from pydantic import Field, ValidationError, root_validator +from pydantic import Field, ValidationError, model_validator from aleph.schemas.base_messages import AlephBaseMessage, ContentType, MType from aleph.schemas.pending_messages import base_pending_message_load_content @@ -79,7 +79,7 @@ class BaseCostEstimationMessage(AlephBaseMessage, Generic[MType, ContentType]): type: MType item_hash: str - @root_validator(pre=True) + @model_validator(mode="before") def load_content(cls, values): return base_pending_message_load_content(values) @@ -169,4 +169,4 @@ async def validate_cost_estimation_message_content( ) -> CostEstimationContent: content = await storage_service.get_message_content(message) content_type = COST_MESSAGE_TYPE_TO_CONTENT[message.type] - 
return content_type.parse_obj(content.value) + return content_type.model_validate(content.value) diff --git a/src/aleph/schemas/pending_messages.py b/src/aleph/schemas/pending_messages.py index 4437df81b..7df5bf454 100644 --- a/src/aleph/schemas/pending_messages.py +++ b/src/aleph/schemas/pending_messages.py @@ -31,7 +31,7 @@ ProgramContent, StoreContent, ) -from pydantic import ValidationError, root_validator, validator +from pydantic import ValidationError, field_validator, model_validator import aleph.toolkit.json as aleph_json from aleph.exceptions import UnknownHashError @@ -110,13 +110,13 @@ class BasePendingMessage(AlephBaseMessage, Generic[MType, ContentType]): type: MType time: dt.datetime - @root_validator(pre=True) + @model_validator(mode="before") def load_content(cls, values): return base_pending_message_load_content(values) - @validator("time", pre=True) - def check_time(cls, v, values): - return base_pending_message_validator_check_time(v, values) + @field_validator("time", mode="before") + def check_time(cls, v, info): + return base_pending_message_validator_check_time(v, info.data) class PendingAggregateMessage( diff --git a/src/aleph/toolkit/json.py b/src/aleph/toolkit/json.py index 95d7a5054..eeb89bf15 100644 --- a/src/aleph/toolkit/json.py +++ b/src/aleph/toolkit/json.py @@ -8,7 +8,7 @@ from typing import IO, Any, Union import orjson -from pydantic.json import pydantic_encoder +import pydantic # The actual type of serialized JSON as returned by the JSON serializer. 
SerializedJson = bytes @@ -49,8 +49,10 @@ def extended_json_encoder(obj: Any) -> Any: return obj.toordinal() elif isinstance(obj, time): return obj.hour * 3600 + obj.minute * 60 + obj.second + obj.microsecond / 1e6 + elif isinstance(obj, pydantic.BaseModel): + return obj.model_dump() else: - return pydantic_encoder(obj) + raise TypeError(f"Object of type {type(obj)} is not JSON serializable") def dumps(obj: Any) -> bytes: diff --git a/src/aleph/types/message_processing_result.py b/src/aleph/types/message_processing_result.py index 5151cf0da..21e4a34d6 100644 --- a/src/aleph/types/message_processing_result.py +++ b/src/aleph/types/message_processing_result.py @@ -39,7 +39,7 @@ def item_hash(self) -> str: def to_dict(self) -> Dict[str, Any]: return { "status": self.status.value, - "message": format_message(self.message).dict(), + "message": format_message(self.message).model_dump(), } diff --git a/src/aleph/web/controllers/accounts.py b/src/aleph/web/controllers/accounts.py index 9225df29c..487f6d0cf 100644 --- a/src/aleph/web/controllers/accounts.py +++ b/src/aleph/web/controllers/accounts.py @@ -4,7 +4,7 @@ from aiohttp import web from aleph_message.models import MessageType -from pydantic import ValidationError, parse_obj_as +from pydantic import ValidationError import aleph.toolkit.json as aleph_json from aleph.db.accessors.balances import ( @@ -78,9 +78,9 @@ async def get_account_balance(request: web.Request): address = _get_address_from_request(request) try: - query_params = GetAccountQueryParams.parse_obj(request.query) + query_params = GetAccountQueryParams.model_validate(request.query) except ValidationError as e: - raise web.HTTPUnprocessableEntity(text=e.json(indent=4)) + raise web.HTTPUnprocessableEntity(text=e.json()) session_factory: DbSessionFactory = get_session_factory_from_request(request) with session_factory() as session: @@ -91,23 +91,25 @@ async def get_account_balance(request: web.Request): return web.json_response( 
text=GetAccountBalanceResponse( address=address, balance=balance, locked_amount=total_cost, details=details - ).json() + ).model_dump_json() ) async def get_chain_balances(request: web.Request) -> web.Response: try: - query_params = GetBalancesChainsQueryParams.parse_obj(request.query) + query_params = GetBalancesChainsQueryParams.model_validate(request.query) except ValidationError as e: - raise web.HTTPUnprocessableEntity(text=e.json(indent=4)) + raise web.HTTPUnprocessableEntity(text=e.json()) - find_filters = query_params.dict(exclude_none=True) + find_filters = query_params.model_dump(exclude_none=True) session_factory: DbSessionFactory = get_session_factory_from_request(request) with session_factory() as session: balances = get_balances_by_chain(session, **find_filters) - formatted_balances = [AddressBalanceResponse.from_orm(b) for b in balances] + formatted_balances = [ + AddressBalanceResponse.model_validate(b) for b in balances + ] total_balances = count_balances_by_chain(session, **find_filters) @@ -128,9 +130,9 @@ async def get_account_files(request: web.Request) -> web.Response: address = _get_address_from_request(request) try: - query_params = GetAccountFilesQueryParams.parse_obj(request.query) + query_params = GetAccountFilesQueryParams.model_validate(request.query) except ValidationError as e: - raise web.HTTPUnprocessableEntity(text=e.json(indent=4)) + raise web.HTTPUnprocessableEntity(text=e.json()) session_factory: DbSessionFactory = get_session_factory_from_request(request) @@ -152,9 +154,12 @@ async def get_account_files(request: web.Request) -> web.Response: response = GetAccountFilesResponse( address=address, total_size=total_size, - files=parse_obj_as(List[GetAccountFilesResponseItem], file_pins), + files=[ + GetAccountFilesResponseItem.model_validate(file_pin) + for file_pin in file_pins + ], pagination_page=query_params.page, pagination_total=nb_files, pagination_per_page=query_params.pagination, ) - return 
web.json_response(text=response.json()) + return web.json_response(text=response.model_dump_json()) diff --git a/src/aleph/web/controllers/aggregates.py b/src/aleph/web/controllers/aggregates.py index 09b25b0a6..c9f61c3be 100644 --- a/src/aleph/web/controllers/aggregates.py +++ b/src/aleph/web/controllers/aggregates.py @@ -3,7 +3,7 @@ from typing import Dict, List, Optional from aiohttp import web -from pydantic import BaseModel, ValidationError, validator +from pydantic import BaseModel, ValidationError, field_validator from sqlalchemy import select from aleph.db.accessors.aggregates import get_aggregates_by_owner, refresh_aggregate @@ -22,10 +22,7 @@ class AggregatesQueryParams(BaseModel): with_info: bool = False value_only: bool = False - @validator( - "keys", - pre=True, - ) + @field_validator("keys", mode="before") def split_str(cls, v): if isinstance(v, str): return v.split(LIST_FIELD_SEPARATOR) @@ -40,7 +37,7 @@ async def address_aggregate(request: web.Request) -> web.Response: address: str = request.match_info["address"] try: - query_params = AggregatesQueryParams.parse_obj(request.query) + query_params = AggregatesQueryParams.model_validate(request.query) except ValidationError as e: raise web.HTTPUnprocessableEntity( text=e.json(), content_type="application/json" diff --git a/src/aleph/web/controllers/main.py b/src/aleph/web/controllers/main.py index 7fcb7aa8e..c254b9dae 100644 --- a/src/aleph/web/controllers/main.py +++ b/src/aleph/web/controllers/main.py @@ -76,7 +76,9 @@ async def metrics_json(request: web.Request) -> web.Response: with session_factory() as session: return web.Response( - text=(await get_metrics(session=session, node_cache=node_cache)).to_json(), + text=( + await get_metrics(session=session, node_cache=node_cache) + ).model_dump_json(), content_type="application/json", ) @@ -98,7 +100,7 @@ async def ccn_metric(request: web.Request) -> web.Response: """Fetch metrics for CCN node id""" session_factory: DbSessionFactory = 
get_session_factory_from_request(request) - query_params = Metrics.parse_obj(request.query) + query_params = Metrics.model_validate(request.query) node_id = _get_node_id_from_request(request) @@ -124,7 +126,7 @@ async def crn_metric(request: web.Request) -> web.Response: """Fetch Metric for crn.""" session_factory: DbSessionFactory = get_session_factory_from_request(request) - query_params = Metrics.parse_obj(request.query) + query_params = Metrics.model_validate(request.query) node_id = _get_node_id_from_request(request) diff --git a/src/aleph/web/controllers/messages.py b/src/aleph/web/controllers/messages.py index ed9fb3b61..bb132907f 100644 --- a/src/aleph/web/controllers/messages.py +++ b/src/aleph/web/controllers/messages.py @@ -6,7 +6,14 @@ import aiohttp.web_ws from aiohttp import WSMsgType, web from aleph_message.models import Chain, ItemHash, MessageType -from pydantic import BaseModel, Field, ValidationError, root_validator, validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + ValidationError, + field_validator, + model_validator, +) import aleph.toolkit.json as aleph_json from aleph.db.accessors.messages import ( @@ -114,19 +121,50 @@ class BaseMessageQueryParams(BaseModel): default=None, description="Accepted values for the 'item_hash' field." ) - @root_validator - def validate_field_dependencies(cls, values): - start_date = values.get("start_date") - end_date = values.get("end_date") + start_date: float = Field( + default=0, + ge=0, + alias="startDate", + description="Start date timestamp. If specified, only messages with " + "a time field greater or equal to this value will be returned.", + ) + end_date: float = Field( + default=0, + ge=0, + alias="endDate", + description="End date timestamp. If specified, only messages with " + "a time field lower than this value will be returned.", + ) + + start_block: int = Field( + default=0, + ge=0, + alias="startBlock", + description="Start block number. 
If specified, only messages with " + "a block number greater or equal to this value will be returned.", + ) + end_block: int = Field( + default=0, + ge=0, + alias="endBlock", + description="End block number. If specified, only messages with " + "a block number lower than this value will be returned.", + ) + + @model_validator(mode="after") + def validate_field_dependencies(self): + start_date = self.start_date + end_date = self.end_date if start_date and end_date and (end_date < start_date): raise ValueError("end date cannot be lower than start date.") - start_block = values.get("start_block") - end_block = values.get("end_block") + start_block = self.start_block + end_block = self.end_block if start_block and end_block and (end_block < start_block): raise ValueError("end block cannot be lower than start block.") - return values - @validator( + return self + + @field_validator( "hashes", "addresses", "refs", @@ -137,15 +175,14 @@ def validate_field_dependencies(cls, values): "channels", "message_types", "tags", - pre=True, + mode="before", ) def split_str(cls, v): if isinstance(v, str): return v.split(LIST_FIELD_SEPARATOR) return v - class Config: - allow_population_by_field_name = True + model_config = ConfigDict(populate_by_name=True) class MessageQueryParams(BaseMessageQueryParams): @@ -158,36 +195,6 @@ class MessageQueryParams(BaseMessageQueryParams): default=DEFAULT_PAGE, ge=1, description="Offset in pages. Starts at 1." ) - start_date: float = Field( - default=0, - ge=0, - alias="startDate", - description="Start date timestamp. If specified, only messages with " - "a time field greater or equal to this value will be returned.", - ) - end_date: float = Field( - default=0, - ge=0, - alias="endDate", - description="End date timestamp. If specified, only messages with " - "a time field lower than this value will be returned.", - ) - - start_block: int = Field( - default=0, - ge=0, - alias="startBlock", - description="Start block number. 
If specified, only messages with " - "a block number greater or equal to this value will be returned.", - ) - end_block: int = Field( - default=0, - ge=0, - alias="endBlock", - description="End block number. If specified, only messages with " - "a block number lower than this value will be returned.", - ) - class WsMessageQueryParams(BaseMessageQueryParams): history: Optional[int] = Field( @@ -237,7 +244,7 @@ class MessageHashesQueryParams(BaseModel): "Set this to false to include metadata alongside the hashes in the response.", ) - @root_validator - def validate_field_dependencies(cls, values): - start_date = values.get("start_date") - end_date = values.get("end_date") + @model_validator(mode="after") + def validate_field_dependencies(self): + start_date = self.start_date + end_date = self.end_date @@ -245,8 +252,7 @@ def validate_field_dependencies(cls, values): raise ValueError("end date cannot be lower than start date.") - return values + return self - class Config: - allow_population_by_field_name = True + model_config = ConfigDict(populate_by_name=True) def message_to_dict(message: MessageDb) -> Dict[str, Any]: @@ -292,16 +298,16 @@ async def view_messages_list(request: web.Request) -> web.Response: """Messages list view with filters""" try: - query_params = MessageQueryParams.parse_obj(request.query) + query_params = MessageQueryParams.model_validate(request.query) except ValidationError as e: - raise web.HTTPUnprocessableEntity(text=e.json(indent=4)) + raise web.HTTPUnprocessableEntity(text=e.json()) # If called from the messages/page/{page}.json endpoint, override the page # parameters with the URL one if url_page_param := request.match_info.get("page"): query_params.page = int(url_page_param) - find_filters = query_params.dict(exclude_none=True) + find_filters = query_params.model_dump(exclude_none=True) pagination_page = query_params.page pagination_per_page = query_params.pagination @@ -332,10 +338,10 @@ async def _send_history_to_ws( session=session, pagination=history, include_confirmations=True, - **query_params.dict(exclude_none=True), + 
**query_params.model_dump(exclude_none=True), ) for message in messages: - await ws.send_str(format_message(message).json()) + await ws.send_str(format_message(message).model_dump_json()) def message_matches_filters( @@ -412,7 +418,7 @@ async def _process_message(mq_message: aio_pika.abc.AbstractMessage): if message_matches_filters(message=message, query_params=query_params): try: - await ws.send_str(message.json()) + await ws.send_str(message.model_dump_json()) except ConnectionResetError: # We can detect the WS closing in this task in addition to the main one. # The main task will also detect the close event. @@ -437,9 +443,9 @@ async def messages_ws(request: web.Request) -> web.WebSocketResponse: mq_channel = await get_mq_ws_channel_from_request(request=request, logger=LOGGER) try: - query_params = WsMessageQueryParams.parse_obj(request.query) + query_params = WsMessageQueryParams.model_validate(request.query) except ValidationError as e: - raise web.HTTPUnprocessableEntity(text=e.json(indent=4)) + raise web.HTTPUnprocessableEntity(text=e.json()) history = query_params.history @@ -508,7 +514,9 @@ def _get_message_with_status( if status == MessageStatus.PENDING: # There may be several instances of the same pending message, return the first. 
pending_messages_db = get_pending_messages(session=session, item_hash=item_hash) - pending_messages = [PendingMessage.from_orm(m) for m in pending_messages_db] + pending_messages = [ + PendingMessage.model_validate(m) for m in pending_messages_db + ] return PendingMessageStatus( status=MessageStatus.PENDING, item_hash=item_hash, @@ -540,7 +548,7 @@ def _get_message_with_status( return ForgottenMessageStatus( item_hash=item_hash, reception_time=reception_time, - message=ForgottenMessage.from_orm(forgotten_message_db), + message=ForgottenMessage.model_validate(forgotten_message_db), forgotten_by=forgotten_message_db.forgotten_by, ) @@ -579,7 +587,7 @@ async def view_message(request: web.Request): session=session, status_db=message_status_db ) - return web.json_response(text=message_with_status.json()) + return web.json_response(text=message_with_status.model_dump_json()) async def view_message_content(request: web.Request): @@ -637,17 +645,17 @@ async def view_message_status(request: web.Request): if message_status is None: raise web.HTTPNotFound() - status_info = MessageStatusInfo.from_orm(message_status) - return web.json_response(text=status_info.json()) + status_info = MessageStatusInfo.model_validate(message_status) + return web.json_response(text=status_info.model_dump_json()) async def view_message_hashes(request: web.Request): try: - query_params = MessageHashesQueryParams.parse_obj(request.query) + query_params = MessageHashesQueryParams.model_validate(request.query) except ValidationError as e: - raise web.HTTPUnprocessableEntity(text=e.json(indent=4)) + raise web.HTTPUnprocessableEntity(text=e.json()) - find_filters = query_params.dict(exclude_none=True) + find_filters = query_params.model_dump(exclude_none=True) pagination_page = query_params.page pagination_per_page = query_params.pagination @@ -659,7 +667,7 @@ async def view_message_hashes(request: web.Request): if find_filters["hash_only"]: formatted_hashes = [h for h in hashes] else: - 
formatted_hashes = [MessageHashes.from_orm(h) for h in hashes] + formatted_hashes = [MessageHashes.model_validate(h) for h in hashes] total_hashes = count_matching_hashes(session, **find_filters) response = { diff --git a/src/aleph/web/controllers/metrics.py b/src/aleph/web/controllers/metrics.py index f429f8708..c9ed4ec19 100644 --- a/src/aleph/web/controllers/metrics.py +++ b/src/aleph/web/controllers/metrics.py @@ -79,7 +79,7 @@ class Metrics(DataClassJsonMixin): pyaleph_status_sync_pending_messages_total: int pyaleph_status_sync_pending_txs_total: int - pyaleph_status_chain_eth_last_committed_height: Optional[int] + pyaleph_status_chain_eth_last_committed_height: Optional[int] = None pyaleph_processing_pending_messages_seen_ids_total: Optional[int] = None pyaleph_processing_pending_messages_tasks_total: Optional[int] = None diff --git a/src/aleph/web/controllers/p2p.py b/src/aleph/web/controllers/p2p.py index 88939d724..ce1d3815d 100644 --- a/src/aleph/web/controllers/p2p.py +++ b/src/aleph/web/controllers/p2p.py @@ -112,7 +112,7 @@ async def pub_json(request: web.Request): pub_status = PublicationStatus.from_failures(failed_publications) return web.json_response( - text=pub_status.json(), - status=500 if pub_status == "error" else 200, + text=pub_status.model_dump_json(), + status=500 if pub_status.status == "error" else 200, ) @@ -125,9 +125,9 @@ class PubMessageRequest(BaseModel): @shielded async def pub_message(request: web.Request): try: - request_data = PubMessageRequest.parse_obj(await request.json()) + request_data = PubMessageRequest.model_validate(await request.json()) except ValidationError as e: - raise web.HTTPUnprocessableEntity(text=e.json(indent=4)) + raise web.HTTPUnprocessableEntity(text=e.json()) except ValueError: # Body must be valid JSON raise web.HTTPUnprocessableEntity() @@ -142,4 +142,6 @@ async def pub_message(request: web.Request): ) status_code = broadcast_status_to_http_status(broadcast_status) - return web.json_response(text=broadcast_status.json(), status=status_code) + return 
web.json_response( + text=broadcast_status.model_dump_json(), status=status_code + ) diff --git a/src/aleph/web/controllers/posts.py b/src/aleph/web/controllers/posts.py index 773d7f6df..b2e4ab7c8 100644 --- a/src/aleph/web/controllers/posts.py +++ b/src/aleph/web/controllers/posts.py @@ -2,7 +2,14 @@ from aiohttp import web from aleph_message.models import ItemHash -from pydantic import BaseModel, Field, ValidationError, root_validator, validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + ValidationError, + field_validator, + model_validator, +) from sqlalchemy import select from aleph.db.accessors.posts import ( @@ -82,7 +89,7 @@ class PostQueryParams(BaseModel): "-1 means most recent messages first, 1 means older messages first.", ) - @root_validator - def validate_field_dependencies(cls, values): - start_date = values.get("start_date") - end_date = values.get("end_date") + @model_validator(mode="after") + def validate_field_dependencies(self): + start_date = self.start_date + end_date = self.end_date @@ -90,22 +97,15 @@ def validate_field_dependencies(cls, values): raise ValueError("end date cannot be lower than start date.") - return values + return self - @validator( - "addresses", - "hashes", - "refs", - "post_types", - "channels", - "tags", - pre=True, + @field_validator( + "addresses", "hashes", "refs", "post_types", "channels", "tags", mode="before" ) def split_str(cls, v): if isinstance(v, str): return v.split(LIST_FIELD_SEPARATOR) return v - class Config: - allow_population_by_field_name = True + model_config = ConfigDict(populate_by_name=True) def merged_post_to_dict(merged_post: MergedPost) -> Dict[str, Any]: @@ -173,9 +173,9 @@ def get_query_params(request: web.Request) -> PostQueryParams: try: - query_params = PostQueryParams.parse_obj(request.query) + query_params = PostQueryParams.model_validate(request.query) except ValidationError as e: - raise web.HTTPUnprocessableEntity(text=e.json(indent=4)) + raise web.HTTPUnprocessableEntity(text=e.json()) path_page = get_path_page(request) if path_page: @@ -188,7 +188,7 @@ 
async def view_posts_list_v0(request: web.Request) -> web.Response: query_string = request.query_string query_params = get_query_params(request) - find_filters = query_params.dict(exclude_none=True) + find_filters = query_params.model_dump(exclude_none=True) pagination_page = query_params.page pagination_per_page = query_params.pagination @@ -231,15 +231,15 @@ async def view_posts_list_v1(request) -> web.Response: query_string = request.query_string try: - query_params = PostQueryParams.parse_obj(request.query) + query_params = PostQueryParams.model_validate(request.query) except ValidationError as e: - raise web.HTTPUnprocessableEntity(text=e.json(indent=4)) + raise web.HTTPUnprocessableEntity(text=e.json()) path_page = get_path_page(request) if path_page: query_params.page = path_page - find_filters = query_params.dict(exclude_none=True) + find_filters = query_params.model_dump(exclude_none=True) pagination_page = query_params.page pagination_per_page = query_params.pagination diff --git a/src/aleph/web/controllers/prices.py b/src/aleph/web/controllers/prices.py index 452b230bc..5d22ed3ce 100644 --- a/src/aleph/web/controllers/prices.py +++ b/src/aleph/web/controllers/prices.py @@ -118,7 +118,7 @@ async def message_price(request: web.Request): "detail": costs, } - response = EstimatedCostsResponse.parse_obj(model) + response = EstimatedCostsResponse.model_validate(model) return web.json_response(text=aleph_json.dumps(response).decode("utf-8")) @@ -134,7 +134,7 @@ async def message_price_estimate(request: web.Request): storage_service = get_storage_service_from_request(request) with session_factory() as session: - parsed_body = PubMessageRequest.parse_obj(await request.json()) + parsed_body = PubMessageRequest.model_validate(await request.json()) message = validate_cost_estimation_message_dict(parsed_body.message_dict) content = await validate_cost_estimation_message_content( message, storage_service @@ -157,6 +157,6 @@ async def message_price_estimate(request: 
web.Request): "detail": costs, } - response = EstimatedCostsResponse.parse_obj(model) + response = EstimatedCostsResponse.model_validate(model) return web.json_response(text=aleph_json.dumps(response).decode("utf-8")) diff --git a/src/aleph/web/controllers/programs.py b/src/aleph/web/controllers/programs.py index c5427ccf2..408178fc6 100644 --- a/src/aleph/web/controllers/programs.py +++ b/src/aleph/web/controllers/programs.py @@ -1,5 +1,5 @@ from aiohttp import web -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel, ConfigDict, ValidationError from aleph.db.accessors.messages import get_programs_triggered_by_messages from aleph.types.db_session import DbSessionFactory @@ -9,8 +9,7 @@ class GetProgramQueryFields(BaseModel): sort_order: SortOrder = SortOrder.DESCENDING - class Config: - extra = "forbid" + model_config = ConfigDict(extra="forbid") async def get_programs_on_message(request: web.Request) -> web.Response: diff --git a/src/aleph/web/controllers/utils.py b/src/aleph/web/controllers/utils.py index 6cbcc0ebb..7e0a4ac7b 100644 --- a/src/aleph/web/controllers/utils.py +++ b/src/aleph/web/controllers/utils.py @@ -296,11 +296,11 @@ async def pub_on_p2p_topics( class BroadcastStatus(BaseModel): publication_status: PublicationStatus - message_status: Optional[MessageStatus] + message_status: Optional[MessageStatus] = None def broadcast_status_to_http_status(broadcast_status: BroadcastStatus) -> int: - if broadcast_status.publication_status == "error": + if broadcast_status.publication_status.status == "error": return 500 message_status = broadcast_status.message_status @@ -311,7 +311,7 @@ def broadcast_status_to_http_status(broadcast_status: BroadcastStatus) -> int: def format_pending_message_dict(pending_message: BasePendingMessage) -> Dict[str, Any]: - pending_message_dict = pending_message.dict(exclude_none=True) + pending_message_dict = pending_message.model_dump(exclude_none=True) pending_message_dict["time"] = 
pending_message_dict["time"].timestamp() return pending_message_dict diff --git a/tests/api/test_get_message.py b/tests/api/test_get_message.py index c6333da97..095ffc5f8 100644 --- a/tests/api/test_get_message.py +++ b/tests/api/test_get_message.py @@ -197,7 +197,7 @@ async def test_get_processed_message_status( ) assert response.status == 200, await response.text() response_json = await response.json() - parsed_response = ProcessedMessageStatus.parse_obj(response_json) + parsed_response = ProcessedMessageStatus.model_validate(response_json) assert parsed_response.status == MessageStatus.PROCESSED assert parsed_response.item_hash == processed_message.item_hash assert parsed_response.reception_time == RECEPTION_DATETIME @@ -236,7 +236,7 @@ async def test_get_rejected_message_status( ) assert response.status == 200, await response.text() response_json = await response.json() - parsed_response = RejectedMessageStatus.parse_obj(response_json) + parsed_response = RejectedMessageStatus.model_validate(response_json) assert parsed_response.status == MessageStatus.REJECTED assert parsed_response.item_hash == rejected_message.item_hash assert parsed_response.reception_time == RECEPTION_DATETIME @@ -260,7 +260,7 @@ async def test_get_forgotten_message_status( ) assert response.status == 200, await response.text() response_json = await response.json() - parsed_response = ForgottenMessageStatus.parse_obj(response_json) + parsed_response = ForgottenMessageStatus.model_validate(response_json) assert parsed_response.status == MessageStatus.FORGOTTEN assert parsed_response.item_hash == forgotten_message.item_hash assert parsed_response.reception_time == RECEPTION_DATETIME @@ -289,7 +289,7 @@ async def test_get_pending_message_status( ) assert response.status == 200, await response.text() response_json = await response.json() - parsed_response = PendingMessageStatus.parse_obj(response_json) + parsed_response = PendingMessageStatus.model_validate(response_json) assert 
parsed_response.status == MessageStatus.PENDING assert parsed_response.item_hash == processed_message.item_hash assert parsed_response.reception_time == RECEPTION_DATETIME diff --git a/tests/api/test_list_messages.py b/tests/api/test_list_messages.py index 6e9b2582a..73bb6b282 100644 --- a/tests/api/test_list_messages.py +++ b/tests/api/test_list_messages.py @@ -527,7 +527,7 @@ def instance_message_fixture() -> MessageDb: use_latest=True, ) ], - ).dict(), + ).model_dump(), size=3000, time=timestamp_to_datetime(1686572207.89381), channel=Channel("TEST"), diff --git a/tests/chains/test_chain_data_service.py b/tests/chains/test_chain_data_service.py index b34a5522c..e88c9df2a 100644 --- a/tests/chains/test_chain_data_service.py +++ b/tests/chains/test_chain_data_service.py @@ -1,4 +1,5 @@ import datetime as dt +import json import pytest from aleph_message.models import ( @@ -86,7 +87,7 @@ async def test_smart_contract_protocol_ipfs_store( publisher="KT1BfL57oZfptdtMFZ9LNakEPvuPPA2urdSW", protocol=ChainSyncProtocol.SMART_CONTRACT, protocol_version=1, - content=payload.dict(), + content=payload.model_dump(mode="json"), ) chain_data_service = ChainDataService( @@ -135,7 +136,7 @@ async def test_smart_contract_protocol_regular_message( timestamp=1668611900, addr="KT1VBeLD7hzKpj17aRJ3Kc6QQFeikCEXi7W6", msgtype="POST", - msgcontent=content.json(), + msgcontent=json.dumps(content.model_dump()), ) tx = ChainTxDb( @@ -146,7 +147,7 @@ async def test_smart_contract_protocol_regular_message( publisher="KT1BfL57oZfptdtMFZ9LNakEPvuPPA2urdSW", protocol=ChainSyncProtocol.SMART_CONTRACT, protocol_version=1, - content=payload.dict(), + content=payload.model_dump(mode="json"), ) chain_data_service = ChainDataService( diff --git a/tests/chains/test_cosmos.py b/tests/chains/test_cosmos.py index 6cc586d17..29ce207e7 100644 --- a/tests/chains/test_cosmos.py +++ b/tests/chains/test_cosmos.py @@ -3,7 +3,6 @@ from urllib.parse import unquote import pytest -from pydantic.tools import 
parse_obj_as from aleph.chains.cosmos import CosmosConnector from aleph.schemas.pending_messages import PendingPostMessage @@ -16,7 +15,7 @@ def cosmos_message() -> PendingPostMessage: message = json.loads(unquote(TEST_MESSAGE)) message["signature"] = json.dumps(message["signature"]) message["item_content"] = json.dumps(message["item_content"], separators=(",", ":")) - return parse_obj_as(PendingPostMessage, message) + return PendingPostMessage.model_validate(message) @pytest.mark.asyncio @@ -28,8 +27,7 @@ async def test_verify_signature_real(cosmos_message: PendingPostMessage): @pytest.mark.asyncio async def test_verify_signature_bad_json(): connector = CosmosConnector() - message = parse_obj_as( - PendingPostMessage, + message = PendingPostMessage.model_validate( { "chain": "CSDK", "time": 1737558660.737648, @@ -37,7 +35,7 @@ async def test_verify_signature_bad_json(): "type": "POST", "item_hash": sha256("ITEM_HASH".encode()).hexdigest(), "signature": "baba", - }, + } ) result = await connector.verify_signature(message) assert result is False diff --git a/tests/chains/test_tezos.py b/tests/chains/test_tezos.py index aa0d93e4c..ab8cea92b 100644 --- a/tests/chains/test_tezos.py +++ b/tests/chains/test_tezos.py @@ -101,7 +101,7 @@ def test_datetime_to_iso_8601(): type="my-type", address="KT1VBeLD7hzKpj17aRJ3Kc6QQFeikCEXi7W6", time=1000, - ).json(), + ).model_dump_json(), ), ( MessageType.aggregate.value, @@ -110,7 +110,7 @@ def test_datetime_to_iso_8601(): content={"body": "My first post on Tezos"}, address="KT1VBeLD7hzKpj17aRJ3Kc6QQFeikCEXi7W6", time=1000, - ).json(), + ).model_dump_json(), ), ], ) @@ -139,4 +139,4 @@ def test_indexer_event_to_aleph_message(message_type: str, message_content: str) assert tx.protocol == ChainSyncProtocol.SMART_CONTRACT assert tx.protocol_version == 1 - assert tx.content == indexer_event.payload.dict() + assert tx.content == indexer_event.payload.model_dump() diff --git a/tests/message_processing/test_process_pending_txs.py 
b/tests/message_processing/test_process_pending_txs.py index d9f3b6086..1f1f98980 100644 --- a/tests/message_processing/test_process_pending_txs.py +++ b/tests/message_processing/test_process_pending_txs.py @@ -136,7 +136,7 @@ async def _process_smart_contract_tx( publisher="KT1BfL57oZfptdtMFZ9LNakEPvuPPA2urdSW", protocol=ChainSyncProtocol.SMART_CONTRACT, protocol_version=1, - content=payload.dict(), + content=payload.model_dump(), ) pending_tx = PendingTxDb(tx=tx) diff --git a/tests/services/test_cost_service.py b/tests/services/test_cost_service.py index d92548b68..8d8441841 100644 --- a/tests/services/test_cost_service.py +++ b/tests/services/test_cost_service.py @@ -49,7 +49,7 @@ def fixture_hold_instance_message() -> ExecutableContent: ], } - return InstanceContent.parse_obj(content) + return InstanceContent.model_validate(content) @pytest.fixture @@ -109,7 +109,7 @@ def fixture_hold_instance_message_complete() -> ExecutableContent: ], } - return InstanceContent.parse_obj(content) + return InstanceContent.model_validate(content) @pytest.fixture @@ -145,7 +145,7 @@ def fixture_flow_instance_message() -> ExecutableContent: ], } - return InstanceContent.parse_obj(content) + return InstanceContent.model_validate(content) @pytest.fixture @@ -210,7 +210,7 @@ def fixture_flow_instance_message_complete() -> ExecutableContent: ], } - return InstanceContent.parse_obj(content) + return InstanceContent.model_validate(content) @pytest.fixture @@ -267,7 +267,7 @@ def fixture_hold_program_message_complete() -> ExecutableContent: }, } - return CostEstimationProgramContent.parse_obj(content) + return CostEstimationProgramContent.model_validate(content) def test_compute_cost( @@ -293,7 +293,7 @@ def test_compute_cost_conf( fixture_settings_aggregate_in_db, fixture_hold_instance_message, ): - message_dict = fixture_hold_instance_message.dict() + message_dict = fixture_hold_instance_message.model_dump() # Convert the message to conf message_dict["environment"].update( @@ -306,7 
+306,7 @@ def test_compute_cost_conf( } ) - rebuilt_message = InstanceContent.parse_obj(message_dict) + rebuilt_message = InstanceContent.model_validate(message_dict) file_db = StoredFileDb() mock = Mock() @@ -410,7 +410,7 @@ def test_compute_flow_cost_conf( fixture_settings_aggregate_in_db, fixture_flow_instance_message, ): - message_dict = fixture_flow_instance_message.dict() + message_dict = fixture_flow_instance_message.model_dump() # Convert the message to conf message_dict["environment"].update( @@ -423,7 +423,7 @@ def test_compute_flow_cost_conf( } ) - rebuilt_message = InstanceContent.parse_obj(message_dict) + rebuilt_message = InstanceContent.model_validate(message_dict) # Proceed with the test file_db = StoredFileDb()