diff --git a/pyproject.toml b/pyproject.toml index ca5130b5e..f16522167 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "aiohttp-jinja2==1.6", "aioipfs~=0.7.1", "alembic==1.15.1", - "aleph-message==0.6.1", + "aleph-message @ git+https://github.com/aleph-im/aleph-message@108-upgrade-pydantic-version#egg=aleph-message", "aleph-nuls2==0.1", "aleph-p2p-client @ git+https://github.com/aleph-im/p2p-service-client-python@cbfebb871db94b2ca580e66104a67cd730c5020c", "asyncpg==0.30", @@ -158,9 +158,8 @@ dependencies = [ "isort==5.13.2", "check-sdist==0.1.3", "sqlalchemy[mypy]==1.4.41", - "yamlfix==1.16.1", - # because of aleph messages otherwise yamlfix install a too new version - "pydantic>=1.10.5,<2.0.0", + "yamlfix>=1.17", + "pydantic>=2,<3.0.0", "pyproject-fmt==2.2.1", "types-aiofiles", "types-protobuf", diff --git a/src/aleph/chains/chain_data_service.py b/src/aleph/chains/chain_data_service.py index 7eadf1afe..0ce3b3e33 100644 --- a/src/aleph/chains/chain_data_service.py +++ b/src/aleph/chains/chain_data_service.py @@ -1,4 +1,5 @@ import asyncio +import json from typing import Any, Dict, List, Mapping, Optional, Self, Set, Type, Union, cast import aio_pika.abc @@ -66,10 +67,12 @@ async def prepare_sync_event_payload( protocol=ChainSyncProtocol.ON_CHAIN_SYNC, version=1, content=OnChainContent( - messages=[OnChainMessage.from_orm(message) for message in messages] + messages=[ + OnChainMessage.model_validate(message) for message in messages + ] ), ) - archive_content: bytes = archive.json().encode("utf-8") + archive_content: bytes = archive.model_dump_json().encode("utf-8") ipfs_cid = await self.storage_service.add_file( session=session, file_content=archive_content, engine=ItemType.ipfs @@ -166,7 +169,9 @@ def _get_tx_messages_smart_contract_protocol(tx: ChainTxDb) -> List[Dict[str, An ) try: - payload = cast(GenericMessageEvent, payload_model.parse_obj(tx.content)) + payload = cast( + GenericMessageEvent, payload_model.model_validate(tx.content) + ) except ValidationError: raise InvalidContent(f"Incompatible tx content for {tx.chain}/{tx.hash}") @@ -189,7 +194,7 @@ def _get_tx_messages_smart_contract_protocol(tx: ChainTxDb) -> List[Dict[str, An item_hash=ItemHash(payload.content), metadata=None, ) - item_content = content.json(exclude_none=True) + item_content = json.dumps(content.model_dump(exclude_none=True)) else: item_content = payload.content diff --git a/src/aleph/chains/ethereum.py b/src/aleph/chains/ethereum.py index fe382595e..e3e683bba 100644 --- a/src/aleph/chains/ethereum.py +++ b/src/aleph/chains/ethereum.py @@ -342,7 +342,7 @@ async def packer(self, config: Config): account, int(gas_price * 1.1), nonce, - sync_event_payload.json(), + sync_event_payload.model_dump_json(), ) LOGGER.info("Broadcast %r on %s" % (response, CHAIN_NAME)) diff --git a/src/aleph/chains/indexer_reader.py b/src/aleph/chains/indexer_reader.py index db6e8d542..2930e5c74 100644 --- a/src/aleph/chains/indexer_reader.py +++ b/src/aleph/chains/indexer_reader.py @@ -93,7 +93,7 @@ def make_events_query( model = SyncEvent event_type_str = "syncEvents" - fields = "\n".join(model.__fields__.keys()) + fields = "\n".join(model.__annotations__.keys()) params: Dict[str, Any] = { "blockchain": f'"{blockchain.value}"', "limit": limit, @@ -147,7 +147,7 @@ async def _query(self, query: str, model: Type[T]) -> T: response = await self.http_session.post("/", json={"query": query}) response.raise_for_status() response_json = await response.json() - return 
model.parse_obj(response_json) + return model.model_validate(response_json) async def fetch_account_state( self, @@ -196,7 +196,7 @@ def indexer_event_to_chain_tx( if isinstance(indexer_event, MessageEvent): protocol = ChainSyncProtocol.SMART_CONTRACT protocol_version = 1 - content = indexer_event.dict() + content = indexer_event.model_dump() else: sync_message = aleph_json.loads(indexer_event.message) diff --git a/src/aleph/chains/nuls2.py b/src/aleph/chains/nuls2.py index 6062b5b0a..b1a2d3f65 100644 --- a/src/aleph/chains/nuls2.py +++ b/src/aleph/chains/nuls2.py @@ -210,7 +210,7 @@ async def packer(self, config: Config): # Required to apply update to the files table in get_chaindata session.commit() - content = sync_event_payload.json() + content = sync_event_payload.model_dump_json() tx = await prepare_transfer_tx( address, [(target_addr, CHEAP_UNIT_FEE)], diff --git a/src/aleph/chains/tezos.py b/src/aleph/chains/tezos.py index b8a951d8b..6cb34e6aa 100644 --- a/src/aleph/chains/tezos.py +++ b/src/aleph/chains/tezos.py @@ -162,7 +162,7 @@ async def fetch_messages( response.raise_for_status() response_json = await response.json() - return IndexerResponse[IndexerMessageEvent].parse_obj(response_json) + return IndexerResponse[IndexerMessageEvent].model_validate(response_json) def indexer_event_to_chain_tx( @@ -176,7 +176,7 @@ def indexer_event_to_chain_tx( publisher=indexer_event.source, protocol=ChainSyncProtocol.SMART_CONTRACT, protocol_version=1, - content=indexer_event.payload.dict(), + content=indexer_event.payload.model_dump(), ) return chain_tx diff --git a/src/aleph/db/models/messages.py b/src/aleph/db/models/messages.py index 0d2f8e894..eb06ae12f 100644 --- a/src/aleph/db/models/messages.py +++ b/src/aleph/db/models/messages.py @@ -14,7 +14,6 @@ StoreContent, ) from pydantic import ValidationError -from pydantic.error_wrappers import ErrorWrapper from sqlalchemy import ( ARRAY, TIMESTAMP, @@ -62,14 +61,14 @@ def validate_message_content( content_dict: Dict[str, Any], ) -> BaseContent: content_type = CONTENT_TYPE_MAP[message_type] - content = content_type.parse_obj(content_dict) + content = content_type.model_validate(content_dict) # Validate that the content time can be converted to datetime. 
This will # raise a ValueError and be caught # TODO: move this validation in aleph-message try: _ = dt.datetime.fromtimestamp(content_dict["time"]) except ValueError as e: - raise ValidationError([ErrorWrapper(e, loc="time")], model=content_type) from e + raise ValidationError(str(e)) from e return content diff --git a/src/aleph/handlers/content/vm.py b/src/aleph/handlers/content/vm.py index 35aa352e8..3f1896553 100644 --- a/src/aleph/handlers/content/vm.py +++ b/src/aleph/handlers/content/vm.py @@ -195,7 +195,7 @@ def vm_message_to_db(message: MessageDb) -> VmBaseDb: if content.on.message: vm.message_triggers = [ - subscription.dict() for subscription in content.on.message + subscription.model_dump() for subscription in content.on.message ] vm.code_volume = CodeVolumeDb( diff --git a/src/aleph/schemas/api/accounts.py b/src/aleph/schemas/api/accounts.py index e0e9347a5..50d3d883d 100644 --- a/src/aleph/schemas/api/accounts.py +++ b/src/aleph/schemas/api/accounts.py @@ -3,7 +3,7 @@ from typing import Dict, List, Optional from aleph_message.models import Chain -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, ConfigDict, Field, field_validator from aleph.types.files import FileType from aleph.types.sort_order import SortOrder @@ -53,7 +53,7 @@ class GetBalancesChainsQueryParams(BaseModel): ) min_balance: int = Field(default=0, ge=1, description="Minimum Balance needed") - @validator("chains", pre=True) + @field_validator("chains", mode="before") def split_str(cls, v): if isinstance(v, str): return v.split(LIST_FIELD_SEPARATOR) @@ -61,8 +61,7 @@ def split_str(cls, v): class AddressBalanceResponse(BaseModel): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) address: str balance: str @@ -78,8 +77,7 @@ class GetAccountFilesResponseItem(BaseModel): class GetAccountFilesResponse(BaseModel): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) address: str total_size: int diff --git a/src/aleph/schemas/api/costs.py b/src/aleph/schemas/api/costs.py index ef66ae81e..4e4878e8f 100644 --- a/src/aleph/schemas/api/costs.py +++ b/src/aleph/schemas/api/costs.py @@ -1,27 +1,25 @@ from typing import List -from pydantic import BaseModel, validator +from pydantic import BaseModel, ConfigDict, field_validator from aleph.toolkit.costs import format_cost_str class EstimatedCostDetailResponse(BaseModel): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) type: str name: str cost_hold: str cost_stream: str - @validator("cost_hold", "cost_stream") + @field_validator("cost_hold", "cost_stream", mode="after") def check_format_price(cls, v): return format_cost_str(v) class EstimatedCostsResponse(BaseModel): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) required_tokens: float payment_type: str diff --git a/src/aleph/schemas/api/messages.py b/src/aleph/schemas/api/messages.py index dea39f1aa..6857cdd93 100644 --- a/src/aleph/schemas/api/messages.py +++ b/src/aleph/schemas/api/messages.py @@ -25,10 +25,8 @@ ProgramContent, StoreContent, ) -from pydantic import BaseModel, Field -from pydantic.generics import GenericModel +from pydantic import BaseModel, ConfigDict, Field, field_serializer -import aleph.toolkit.json as aleph_json from aleph.db.models import MessageDb from aleph.types.message_status import ErrorCode, MessageStatus @@ -39,26 +37,26 @@ class MessageConfirmation(BaseModel): """Format of the result when a message has been confirmed on a 
blockchain""" - class Config: - orm_mode = True - json_encoders = {dt.datetime: lambda d: d.timestamp()} + model_config = ConfigDict(from_attributes=True) chain: Chain height: int hash: str + time: dt.datetime + + @field_serializer("time") + def serialize_time(self, dt: dt.datetime, _info) -> float: + return dt.timestamp() -class BaseMessage(GenericModel, Generic[MType, ContentType]): - class Config: - orm_mode = True - json_loads = aleph_json.loads - json_encoders = {dt.datetime: lambda d: d.timestamp()} +class BaseMessage(BaseModel, Generic[MType, ContentType]): + model_config = ConfigDict(from_attributes=True) sender: str chain: Chain - signature: Optional[str] + signature: Optional[str] = None type: MType - item_content: Optional[str] + item_content: Optional[str] = None item_type: ItemType item_hash: str time: dt.datetime @@ -67,6 +65,10 @@ class Config: confirmed: bool confirmations: List[MessageConfirmation] + @field_serializer("time") + def serialize_time(self, dt: dt.datetime, _info) -> float: + return dt.timestamp() + class AggregateMessage( BaseMessage[Literal[MessageType.aggregate], AggregateContent] # type: ignore @@ -127,13 +129,13 @@ def format_message(message: MessageDb) -> AlephMessage: message_type = message.type message_cls = MESSAGE_CLS_DICT[message_type] - return message_cls.from_orm(message) + return message_cls.model_validate(message) def format_message_dict(message: Dict[str, Any]) -> AlephMessage: message_type = message.get("type") message_cls = MESSAGE_CLS_DICT[message_type] - return message_cls.parse_obj(message) + return message_cls.model_validate(message) class BaseMessageStatus(BaseModel): @@ -145,45 +147,41 @@ class BaseMessageStatus(BaseModel): # We already have a model for the validation of pending messages, but this one # is only used for formatting and does not try to be smart. 
class PendingMessage(BaseModel): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) sender: str chain: Chain - signature: Optional[str] + signature: Optional[str] = None type: MessageType - item_content: Optional[str] + item_content: Optional[str] = None item_type: ItemType item_hash: str time: dt.datetime channel: Optional[str] = None - content: Optional[Dict[str, Any]] + content: Optional[Dict[str, Any]] = None reception_time: dt.datetime class PendingMessageStatus(BaseMessageStatus): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) status: MessageStatus = MessageStatus.PENDING messages: List[PendingMessage] class ProcessedMessageStatus(BaseMessageStatus): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) status: MessageStatus = MessageStatus.PROCESSED message: AlephMessage class ForgottenMessage(BaseModel): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) sender: str chain: Chain - signature: Optional[str] + signature: Optional[str] = None type: MessageType item_type: ItemType item_hash: str @@ -201,18 +199,15 @@ class RejectedMessageStatus(BaseMessageStatus): status: MessageStatus = MessageStatus.REJECTED message: Mapping[str, Any] error_code: ErrorCode - details: Any + details: Any = None class MessageStatusInfo(BaseMessageStatus): - class Config: - orm_mode = True - fields = {"item_hash": {"exclude": True}} + model_config = ConfigDict(from_attributes=True) class MessageHashes(BaseMessageStatus): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) MessageWithStatus = Union[ @@ -224,12 +219,13 @@ class Config: class MessageListResponse(BaseModel): - class Config: - json_encoders = {dt.datetime: lambda d: d.timestamp()} - json_loads = aleph_json.loads - messages: List[AlephMessage] pagination_page: int pagination_total: int pagination_per_page: int pagination_item: Literal["messages"] = "messages" + time: dt.datetime + + @field_serializer("time") + def serialize_time(self, dt: dt.datetime, _info) -> float: + return dt.timestamp() diff --git a/src/aleph/schemas/base_messages.py b/src/aleph/schemas/base_messages.py index c1837e9dd..99518e83b 100644 --- a/src/aleph/schemas/base_messages.py +++ b/src/aleph/schemas/base_messages.py @@ -7,9 +7,9 @@ from typing import Any, Generic, Mapping, Optional, TypeVar, cast from aleph_message.models import BaseContent, Chain, ItemType, MessageType -from pydantic import root_validator, validator -from pydantic.generics import GenericModel +from pydantic import BaseModel, ValidationInfo, field_validator +from aleph.toolkit.timestamp import timestamp_to_datetime from aleph.utils import item_type_from_hash MType = TypeVar("MType", bound=MessageType) @@ -71,7 +71,7 @@ def base_message_validator_check_item_hash(v: Any, values: Mapping[str, Any]): return v -class AlephBaseMessage(GenericModel, Generic[MType, ContentType]): +class AlephBaseMessage(BaseModel, Generic[MType, ContentType]): """ The base structure of an Aleph message. 
All the fields of this class appear in all the representations
@@ -89,10 +89,26 @@ class AlephBaseMessage(GenericModel, Generic[MType, ContentType]):
     channel: Optional[str] = None
     content: Optional[ContentType] = None
 
-    @root_validator()
-    def check_item_type(cls, values):
-        return base_message_validator_check_item_type(values)
-
-    @validator("item_hash")
-    def check_item_hash(cls, v: Any, values: Mapping[str, Any]):
-        return base_message_validator_check_item_hash(v, values)
+    @field_validator("item_hash", mode="after")
+    @classmethod
+    def check_item_type(cls, v: Any, info: ValidationInfo) -> Any:
+        base_message_validator_check_item_type(info.data)
+        return v
+
+    @field_validator("item_hash", mode="before")
+    @classmethod
+    def check_item_hash(cls, v: Any, info: ValidationInfo) -> Any:
+        return base_message_validator_check_item_hash(v, info.data)
+
+    @field_validator("time")
+    @classmethod
+    def check_time(cls, v: Any) -> Any:
+        """
+        Parses the time field as a UTC datetime. Contrary to the default datetime
+        validator, this implementation raises an exception if the time field is
+        too far in the future.
+        """
+        if isinstance(v, dt.datetime):
+            return v
+
+        return timestamp_to_datetime(v)
diff --git a/src/aleph/schemas/chains/indexer_response.py b/src/aleph/schemas/chains/indexer_response.py
index 8a7f9fae5..00ea74453 100644
--- a/src/aleph/schemas/chains/indexer_response.py
+++ b/src/aleph/schemas/chains/indexer_response.py
@@ -6,7 +6,7 @@
 from enum import Enum
 from typing import List, Protocol, Tuple
 
-from pydantic import BaseModel, Field, validator
+from pydantic import BaseModel, Field, field_validator
 
 
 class GenericMessageEvent(Protocol):
@@ -43,11 +43,11 @@ class AccountEntityState(BaseModel):
     pending: List[Tuple[dt.datetime, dt.datetime]]
     processed: List[Tuple[dt.datetime, dt.datetime]]
 
-    @validator("pending", "processed", pre=True, each_item=True)
-    def split_datetime_ranges(cls, v):
-        if isinstance(v, str):
-            return v.split("/")
-        return v
+    @field_validator("pending", "processed", mode="before")
+    def split_datetime_ranges(cls, values):
+        return map(
+            lambda value: value.split("/") if isinstance(value, str) else value, values
+        )
 
 
 class IndexerAccountStateResponseData(BaseModel):
diff --git a/src/aleph/schemas/chains/sync_events.py b/src/aleph/schemas/chains/sync_events.py
index ab4304847..3c30bbe17 100644
--- a/src/aleph/schemas/chains/sync_events.py
+++ b/src/aleph/schemas/chains/sync_events.py
@@ -2,27 +2,26 @@
 from typing import Annotated, List, Literal, Optional, Union
 
 from aleph_message.models import Chain, ItemHash, ItemType, MessageType
-from pydantic import BaseModel, Field, validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator
 
 from aleph.types.chain_sync import ChainSyncProtocol
 from aleph.types.channel import Channel
 
 
 class OnChainMessage(BaseModel):
-    class Config:
-        orm_mode = True
+    model_config = ConfigDict(from_attributes=True)
 
     sender: str
     chain: Chain
-    signature: Optional[str]
+    signature: Optional[str] = None
    type: MessageType
-    item_content: Optional[str]
+    item_content: Optional[str] = None
     item_type: ItemType
     item_hash: ItemHash
     time: float
     channel: Optional[Channel] = None
 
-    @validator("time", pre=True)
+    @field_validator("time", mode="before")
     def check_time(cls, v, values):
         if isinstance(v, dt.datetime):
             return v.timestamp()
diff --git a/src/aleph/schemas/chains/tezos_indexer_response.py b/src/aleph/schemas/chains/tezos_indexer_response.py
index d9204d427..4283dfa4b 100644
--- a/src/aleph/schemas/chains/tezos_indexer_response.py
+++ b/src/aleph/schemas/chains/tezos_indexer_response.py
@@ -2,8 +2,7 @@
 from enum import Enum
 from typing import Generic, List, TypeVar
 
-from pydantic import BaseModel, Field
-from pydantic.generics import GenericModel
+from pydantic import BaseModel, ConfigDict, Field
 
 PayloadType = TypeVar("PayloadType")
 
@@ -24,7 +23,7 @@ class IndexerStats(BaseModel):
     total_events: int = Field(alias="totalEvents")
 
 
-class IndexerEvent(GenericModel, Generic[PayloadType]):
+class IndexerEvent(BaseModel, Generic[PayloadType]):
     source: str
     timestamp: dt.datetime
     block_level: int = Field(alias="blockLevel")
@@ -34,8 +33,7 @@
 
 
 class MessageEventPayload(BaseModel):
-    class Config:
-        allow_population_by_field_name = True
+    model_config = ConfigDict(populate_by_name=True)
 
     timestamp: float
     addr: str
@@ -67,11 +65,11 @@ def timestamp_seconds(self) -> float:
 IndexerEventType = TypeVar("IndexerEventType", bound=IndexerEvent)
 
 
-class IndexerResponseData(GenericModel, Generic[IndexerEventType]):
+class IndexerResponseData(BaseModel, Generic[IndexerEventType]):
     index_status: IndexerStatus = Field(alias="indexStatus")
     stats: IndexerStats
     events: List[IndexerEventType]
 
 
-class IndexerResponse(GenericModel, Generic[IndexerEventType]):
+class IndexerResponse(BaseModel, Generic[IndexerEventType]):
     data: IndexerResponseData[IndexerEventType]
diff --git a/src/aleph/schemas/cost_estimation_messages.py b/src/aleph/schemas/cost_estimation_messages.py
index fdcb57baf..1679358cf 100644
--- a/src/aleph/schemas/cost_estimation_messages.py
+++ b/src/aleph/schemas/cost_estimation_messages.py
@@ -17,7 +17,7 @@
     ImmutableVolume,
     PersistentVolume,
 )
-from pydantic import Field, ValidationError, root_validator
+from pydantic import Field, ValidationError, model_validator
 
 from aleph.schemas.base_messages import AlephBaseMessage, ContentType, MType
 from aleph.schemas.pending_messages import base_pending_message_load_content
@@ -79,7 +79,8 @@ class BaseCostEstimationMessage(AlephBaseMessage, Generic[MType, ContentType]):
     type: MType
     item_hash: str
 
-    @root_validator(pre=True)
-    def load_content(cls, values):
+    @model_validator(mode="before")
+    @classmethod
+    def load_content(cls, values):
         return base_pending_message_load_content(values)
 
@@ -169,4 +169,4 @@ async def validate_cost_estimation_message_content(
 ) -> CostEstimationContent:
     content = await storage_service.get_message_content(message)
     content_type = COST_MESSAGE_TYPE_TO_CONTENT[message.type]
-    return content_type.parse_obj(content.value)
+    return content_type.model_validate(content.value)
diff --git a/src/aleph/schemas/pending_messages.py b/src/aleph/schemas/pending_messages.py
index 4437df81b..4b616dc42 100644
--- a/src/aleph/schemas/pending_messages.py
+++ b/src/aleph/schemas/pending_messages.py
@@ -31,7 +31,7 @@
     ProgramContent,
     StoreContent,
 )
-from pydantic import ValidationError, root_validator, validator
+from pydantic import ValidationError, field_validator, model_validator
 
 import aleph.toolkit.json as aleph_json
 from aleph.exceptions import UnknownHashError
@@ -110,11 +110,12 @@ class BasePendingMessage(AlephBaseMessage, Generic[MType, ContentType]):
     type: MType
     time: dt.datetime
 
-    @root_validator(pre=True)
-    def load_content(cls, values):
+    @model_validator(mode="before")
+    @classmethod
+    def load_content(cls, values: Any):
         return base_pending_message_load_content(values)
 
-    @validator("time", pre=True)
+    @field_validator("time", mode="before")
     def check_time(cls, v, values):
         return base_pending_message_validator_check_time(v, values)
diff --git 
a/src/aleph/types/message_processing_result.py b/src/aleph/types/message_processing_result.py index 5151cf0da..21e4a34d6 100644 --- a/src/aleph/types/message_processing_result.py +++ b/src/aleph/types/message_processing_result.py @@ -39,7 +39,7 @@ def item_hash(self) -> str: def to_dict(self) -> Dict[str, Any]: return { "status": self.status.value, - "message": format_message(self.message).dict(), + "message": format_message(self.message).model_dump(), } diff --git a/src/aleph/web/controllers/accounts.py b/src/aleph/web/controllers/accounts.py index 9225df29c..014e2e57e 100644 --- a/src/aleph/web/controllers/accounts.py +++ b/src/aleph/web/controllers/accounts.py @@ -4,7 +4,7 @@ from aiohttp import web from aleph_message.models import MessageType -from pydantic import ValidationError, parse_obj_as +from pydantic import ValidationError import aleph.toolkit.json as aleph_json from aleph.db.accessors.balances import ( @@ -78,7 +78,7 @@ async def get_account_balance(request: web.Request): address = _get_address_from_request(request) try: - query_params = GetAccountQueryParams.parse_obj(request.query) + query_params = GetAccountQueryParams.model_validate(request.query) except ValidationError as e: raise web.HTTPUnprocessableEntity(text=e.json(indent=4)) @@ -91,23 +91,25 @@ async def get_account_balance(request: web.Request): return web.json_response( text=GetAccountBalanceResponse( address=address, balance=balance, locked_amount=total_cost, details=details - ).json() + ).model_dump_json() ) async def get_chain_balances(request: web.Request) -> web.Response: try: - query_params = GetBalancesChainsQueryParams.parse_obj(request.query) + query_params = GetBalancesChainsQueryParams.model_validate(request.query) except ValidationError as e: raise web.HTTPUnprocessableEntity(text=e.json(indent=4)) - find_filters = query_params.dict(exclude_none=True) + find_filters = query_params.model_dump(exclude_none=True) session_factory: DbSessionFactory = get_session_factory_from_request(request) with session_factory() as session: balances = get_balances_by_chain(session, **find_filters) - formatted_balances = [AddressBalanceResponse.from_orm(b) for b in balances] + formatted_balances = [ + AddressBalanceResponse.model_validate(b) for b in balances + ] total_balances = count_balances_by_chain(session, **find_filters) @@ -128,7 +130,7 @@ async def get_account_files(request: web.Request) -> web.Response: address = _get_address_from_request(request) try: - query_params = GetAccountFilesQueryParams.parse_obj(request.query) + query_params = GetAccountFilesQueryParams.model_validate(request.query) except ValidationError as e: raise web.HTTPUnprocessableEntity(text=e.json(indent=4)) @@ -149,12 +151,16 @@ async def get_account_files(request: web.Request) -> web.Response: if not file_pins: raise web.HTTPNotFound() + validated_files = [ + GetAccountFilesResponseItem.model_validate(file) for file in file_pins + ] + response = GetAccountFilesResponse( address=address, total_size=total_size, - files=parse_obj_as(List[GetAccountFilesResponseItem], file_pins), + files=validated_files, pagination_page=query_params.page, pagination_total=nb_files, pagination_per_page=query_params.pagination, ) - return web.json_response(text=response.json()) + return web.json_response(text=response.model_dump_json()) diff --git a/src/aleph/web/controllers/aggregates.py b/src/aleph/web/controllers/aggregates.py index 09b25b0a6..66b910850 100644 --- a/src/aleph/web/controllers/aggregates.py +++ b/src/aleph/web/controllers/aggregates.py @@ -3,7 
+3,7 @@ from typing import Dict, List, Optional from aiohttp import web -from pydantic import BaseModel, ValidationError, validator +from pydantic import BaseModel, ValidationError, field_validator from sqlalchemy import select from aleph.db.accessors.aggregates import get_aggregates_by_owner, refresh_aggregate @@ -22,10 +22,8 @@ class AggregatesQueryParams(BaseModel): with_info: bool = False value_only: bool = False - @validator( - "keys", - pre=True, - ) + @field_validator("keys", mode="before") + @classmethod def split_str(cls, v): if isinstance(v, str): return v.split(LIST_FIELD_SEPARATOR) @@ -40,7 +38,7 @@ async def address_aggregate(request: web.Request) -> web.Response: address: str = request.match_info["address"] try: - query_params = AggregatesQueryParams.parse_obj(request.query) + query_params = AggregatesQueryParams.model_validate(request.query) except ValidationError as e: raise web.HTTPUnprocessableEntity( text=e.json(), content_type="application/json" diff --git a/src/aleph/web/controllers/main.py b/src/aleph/web/controllers/main.py index 7fcb7aa8e..5b5c5b79e 100644 --- a/src/aleph/web/controllers/main.py +++ b/src/aleph/web/controllers/main.py @@ -98,7 +98,7 @@ async def ccn_metric(request: web.Request) -> web.Response: """Fetch metrics for CCN node id""" session_factory: DbSessionFactory = get_session_factory_from_request(request) - query_params = Metrics.parse_obj(request.query) + query_params = Metrics.model_validate(request.query) node_id = _get_node_id_from_request(request) @@ -124,7 +124,7 @@ async def crn_metric(request: web.Request) -> web.Response: """Fetch Metric for crn.""" session_factory: DbSessionFactory = get_session_factory_from_request(request) - query_params = Metrics.parse_obj(request.query) + query_params = Metrics.model_validate(request.query) node_id = _get_node_id_from_request(request) diff --git a/src/aleph/web/controllers/messages.py b/src/aleph/web/controllers/messages.py index ed9fb3b61..38741fac1 100644 --- a/src/aleph/web/controllers/messages.py +++ b/src/aleph/web/controllers/messages.py @@ -6,7 +6,14 @@ import aiohttp.web_ws from aiohttp import WSMsgType, web from aleph_message.models import Chain, ItemHash, MessageType -from pydantic import BaseModel, Field, ValidationError, root_validator, validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + ValidationError, + field_validator, + model_validator, +) import aleph.toolkit.json as aleph_json from aleph.db.accessors.messages import ( @@ -114,19 +121,20 @@ class BaseMessageQueryParams(BaseModel): default=None, description="Accepted values for the 'item_hash' field." 
    )
 
-    @root_validator
-    def validate_field_dependencies(cls, values):
-        start_date = values.get("start_date")
-        end_date = values.get("end_date")
+    @model_validator(mode="after")
+    def validate_field_dependencies(self):
+        start_date = self.start_date
+        end_date = self.end_date
         if start_date and end_date and (end_date < start_date):
             raise ValueError("end date cannot be lower than start date.")
-        start_block = values.get("start_block")
-        end_block = values.get("end_block")
+        start_block = self.start_block
+        end_block = self.end_block
         if start_block and end_block and (end_block < start_block):
             raise ValueError("end block cannot be lower than start block.")
-        return values
-    @validator(
+        return self
+
+    @field_validator(
         "hashes",
         "addresses",
         "refs",
@@ -137,15 +145,15 @@
         "channels",
         "message_types",
         "tags",
-        pre=True,
+        mode="before",
     )
+    @classmethod
     def split_str(cls, v):
         if isinstance(v, str):
             return v.split(LIST_FIELD_SEPARATOR)
         return v
 
-    class Config:
-        allow_population_by_field_name = True
+    model_config = ConfigDict(populate_by_name=True)
 
 
 class MessageQueryParams(BaseMessageQueryParams):
@@ -237,7 +245,7 @@ class MessageHashesQueryParams(BaseModel):
         "Set this to false to include metadata alongside the hashes in the response.",
     )
 
-    @root_validator
-    def validate_field_dependencies(cls, values):
-        start_date = values.get("start_date")
-        end_date = values.get("end_date")
+    @model_validator(mode="after")
+    def validate_field_dependencies(self):
+        start_date = self.start_date
+        end_date = self.end_date
@@ -245,8 +253,7 @@ def validate_field_dependencies(cls, values):
             raise ValueError("end date cannot be lower than start date.")
-        return values
+        return self
 
-    class Config:
-        allow_population_by_field_name = True
+    model_config = ConfigDict(populate_by_name=True)
 
 
 def message_to_dict(message: MessageDb) -> Dict[str, Any]:
@@ -292,7 +299,7 @@ async def view_messages_list(request: web.Request) -> web.Response:
     """Messages list view with filters"""
 
     try:
-        query_params = MessageQueryParams.parse_obj(request.query)
+        query_params = MessageQueryParams.model_validate(request.query)
     except ValidationError as e:
         raise web.HTTPUnprocessableEntity(text=e.json(indent=4))
 
@@ -301,7 +308,7 @@
     if url_page_param := request.match_info.get("page"):
         query_params.page = int(url_page_param)
 
-    find_filters = query_params.dict(exclude_none=True)
+    find_filters = query_params.model_dump(exclude_none=True)
 
     pagination_page = query_params.page
     pagination_per_page = query_params.pagination
@@ -332,10 +339,10 @@ async def _send_history_to_ws(
             session=session,
             pagination=history,
             include_confirmations=True,
-            **query_params.dict(exclude_none=True),
+            **query_params.model_dump(exclude_none=True),
         )
         for message in messages:
-            await ws.send_str(format_message(message).json())
+            await ws.send_str(format_message(message).model_dump_json())
 
 
 def message_matches_filters(
@@ -412,7 +419,7 @@ async def _process_message(mq_message: aio_pika.abc.AbstractMessage):
         if message_matches_filters(message=message, query_params=query_params):
             try:
-                await ws.send_str(message.json())
+                await ws.send_str(message.model_dump_json())
             except ConnectionResetError:
                 # We can detect the WS closing in this task in addition to the main one.
                 # The main task will also detect the close event.
@@ -437,7 +444,7 @@ async def messages_ws(request: web.Request) -> web.WebSocketResponse:
     mq_channel = await get_mq_ws_channel_from_request(request=request, logger=LOGGER)
 
     try:
-        query_params = WsMessageQueryParams.parse_obj(request.query)
+        query_params = WsMessageQueryParams.model_validate(request.query)
     except ValidationError as e:
         raise web.HTTPUnprocessableEntity(text=e.json(indent=4))
 
@@ -508,7 +515,9 @@ def _get_message_with_status(
     if status == MessageStatus.PENDING:
         # There may be several instances of the same pending message, return the first.
         pending_messages_db = get_pending_messages(session=session, item_hash=item_hash)
-        pending_messages = [PendingMessage.from_orm(m) for m in pending_messages_db]
+        pending_messages = [
+            PendingMessage.model_validate(m) for m in pending_messages_db
+        ]
         return PendingMessageStatus(
             status=MessageStatus.PENDING,
             item_hash=item_hash,
@@ -540,7 +549,7 @@ def _get_message_with_status(
         return ForgottenMessageStatus(
             item_hash=item_hash,
             reception_time=reception_time,
-            message=ForgottenMessage.from_orm(forgotten_message_db),
+            message=ForgottenMessage.model_validate(forgotten_message_db),
             forgotten_by=forgotten_message_db.forgotten_by,
         )
 
@@ -579,7 +588,7 @@ async def view_message(request: web.Request):
             session=session, status_db=message_status_db
         )
 
-    return web.json_response(text=message_with_status.json())
+    return web.json_response(text=message_with_status.model_dump_json())
 
 
 async def view_message_content(request: web.Request):
@@ -637,17 +646,17 @@ async def view_message_status(request: web.Request):
     if message_status is None:
         raise web.HTTPNotFound()
 
-    status_info = MessageStatusInfo.from_orm(message_status)
-    return web.json_response(text=status_info.json())
+    status_info = MessageStatusInfo.model_validate(message_status)
+    return web.json_response(text=status_info.model_dump_json(exclude={"item_hash"}))
 
 
 async def view_message_hashes(request: web.Request):
     try:
-        query_params = MessageHashesQueryParams.parse_obj(request.query)
+        query_params = MessageHashesQueryParams.model_validate(request.query)
     except ValidationError as e:
         raise web.HTTPUnprocessableEntity(text=e.json(indent=4))
 
-    find_filters = query_params.dict(exclude_none=True)
+    find_filters = query_params.model_dump(exclude_none=True)
 
     pagination_page = query_params.page
     pagination_per_page = query_params.pagination
@@ -659,7 +668,7 @@ async def view_message_hashes(request: web.Request):
     if find_filters["hash_only"]:
         formatted_hashes = [h for h in hashes]
     else:
-        formatted_hashes = [MessageHashes.from_orm(h) for h in hashes]
+        formatted_hashes = [MessageHashes.model_validate(h) for h in hashes]
 
     total_hashes = count_matching_hashes(session, **find_filters)
     response = {
diff --git a/src/aleph/web/controllers/p2p.py b/src/aleph/web/controllers/p2p.py
index 88939d724..7c1310429 100644
--- a/src/aleph/web/controllers/p2p.py
+++ b/src/aleph/web/controllers/p2p.py
@@ -112,7 +112,7 @@ async def pub_json(request: 
web.Request): pub_status = PublicationStatus.from_failures(failed_publications) return web.json_response( - text=pub_status.json(), + text=pub_status.model_dump_json(), status=500 if pub_status == "error" else 200, ) @@ -125,7 +125,7 @@ class PubMessageRequest(BaseModel): @shielded async def pub_message(request: web.Request): try: - request_data = PubMessageRequest.parse_obj(await request.json()) + request_data = PubMessageRequest.model_validate(await request.json()) except ValidationError as e: raise web.HTTPUnprocessableEntity(text=e.json(indent=4)) except ValueError: @@ -142,4 +142,6 @@ async def pub_message(request: web.Request): ) status_code = broadcast_status_to_http_status(broadcast_status) - return web.json_response(text=broadcast_status.json(), status=status_code) + return web.json_response( + text=broadcast_status.model_dump_json(), status=status_code + ) diff --git a/src/aleph/web/controllers/posts.py b/src/aleph/web/controllers/posts.py index 773d7f6df..82240ecb5 100644 --- a/src/aleph/web/controllers/posts.py +++ b/src/aleph/web/controllers/posts.py @@ -2,7 +2,14 @@ from aiohttp import web from aleph_message.models import ItemHash -from pydantic import BaseModel, Field, ValidationError, root_validator, validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + ValidationError, + field_validator, + model_validator, +) from sqlalchemy import select from aleph.db.accessors.posts import ( @@ -82,30 +89,25 @@ class PostQueryParams(BaseModel): "-1 means most recent messages first, 1 means older messages first.", ) - @root_validator - def validate_field_dependencies(cls, values): - start_date = values.get("start_date") - end_date = values.get("end_date") + @model_validator(mode="after") + def validate_field_dependencies(self): + start_date = self.start_date + end_date = self.end_date if start_date and end_date and (end_date < start_date): raise ValueError("end date cannot be lower than start date.") - return values - - @validator( - "addresses", - "hashes", - "refs", - "post_types", - "channels", - "tags", - pre=True, + + return self + + @field_validator( + "addresses", "hashes", "refs", "post_types", "channels", "tags", mode="before" ) - def split_str(cls, v): + @classmethod + def split_str(cls, v) -> List[str]: if isinstance(v, str): return v.split(LIST_FIELD_SEPARATOR) return v - class Config: - allow_population_by_field_name = True + model_config = ConfigDict(populate_by_name=True) def merged_post_to_dict(merged_post: MergedPost) -> Dict[str, Any]: @@ -188,7 +190,7 @@ async def view_posts_list_v0(request: web.Request) -> web.Response: query_string = request.query_string query_params = get_query_params(request) - find_filters = query_params.dict(exclude_none=True) + find_filters = query_params.model_dump(exclude_none=True) pagination_page = query_params.page pagination_per_page = query_params.pagination @@ -239,7 +241,7 @@ async def view_posts_list_v1(request) -> web.Response: if path_page: query_params.page = path_page - find_filters = query_params.dict(exclude_none=True) + find_filters = query_params.model_dump(exclude_none=True) pagination_page = query_params.page pagination_per_page = query_params.pagination diff --git a/src/aleph/web/controllers/prices.py b/src/aleph/web/controllers/prices.py index 452b230bc..5d22ed3ce 100644 --- a/src/aleph/web/controllers/prices.py +++ b/src/aleph/web/controllers/prices.py @@ -118,7 +118,7 @@ async def message_price(request: web.Request): "detail": costs, } - response = EstimatedCostsResponse.parse_obj(model) + response = 
EstimatedCostsResponse.model_validate(model) return web.json_response(text=aleph_json.dumps(response).decode("utf-8")) @@ -134,7 +134,7 @@ async def message_price_estimate(request: web.Request): storage_service = get_storage_service_from_request(request) with session_factory() as session: - parsed_body = PubMessageRequest.parse_obj(await request.json()) + parsed_body = PubMessageRequest.model_validate(await request.json()) message = validate_cost_estimation_message_dict(parsed_body.message_dict) content = await validate_cost_estimation_message_content( message, storage_service @@ -157,6 +157,6 @@ async def message_price_estimate(request: web.Request): "detail": costs, } - response = EstimatedCostsResponse.parse_obj(model) + response = EstimatedCostsResponse.model_validate(model) return web.json_response(text=aleph_json.dumps(response).decode("utf-8")) diff --git a/src/aleph/web/controllers/programs.py b/src/aleph/web/controllers/programs.py index c5427ccf2..e898e7aca 100644 --- a/src/aleph/web/controllers/programs.py +++ b/src/aleph/web/controllers/programs.py @@ -1,5 +1,5 @@ from aiohttp import web -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel, ConfigDict, ValidationError from aleph.db.accessors.messages import get_programs_triggered_by_messages from aleph.types.db_session import DbSessionFactory @@ -8,9 +8,7 @@ class GetProgramQueryFields(BaseModel): sort_order: SortOrder = SortOrder.DESCENDING - - class Config: - extra = "forbid" + model_config = ConfigDict(extra="forbid") async def get_programs_on_message(request: web.Request) -> web.Response: diff --git a/src/aleph/web/controllers/storage.py b/src/aleph/web/controllers/storage.py index 6bcfee950..0ffc3a57d 100644 --- a/src/aleph/web/controllers/storage.py +++ b/src/aleph/web/controllers/storage.py @@ -232,8 +232,9 @@ async def _check_and_add_file( raise web.HTTPUnprocessableEntity(reason="Store message content needed") try: - message_content = CostEstimationStoreContent.parse_raw(message.item_content) - message_content.estimated_size_mib = uploaded_file.size + message_content = CostEstimationStoreContent.model_validate_json( + message.item_content + ) if message_content.item_hash != file_hash: raise web.HTTPUnprocessableEntity( @@ -334,7 +335,7 @@ async def storage_add_file(request: web.Request): metadata.file.read() if isinstance(metadata, FileField) else metadata ) try: - storage_metadata = StorageMetadata.parse_raw(metadata_bytes) + storage_metadata = StorageMetadata.model_validate_json(metadata_bytes) except ValidationError as e: raise web.HTTPUnprocessableEntity( reason=f"Could not decode metadata: {e.json()}" diff --git a/src/aleph/web/controllers/utils.py b/src/aleph/web/controllers/utils.py index 6cbcc0ebb..7e0a4ac7b 100644 --- a/src/aleph/web/controllers/utils.py +++ b/src/aleph/web/controllers/utils.py @@ -296,11 +296,11 @@ async def pub_on_p2p_topics( class BroadcastStatus(BaseModel): publication_status: PublicationStatus - message_status: Optional[MessageStatus] + message_status: Optional[MessageStatus] = None def broadcast_status_to_http_status(broadcast_status: BroadcastStatus) -> int: - if broadcast_status.publication_status == "error": + if broadcast_status.publication_status.status == "error": return 500 message_status = broadcast_status.message_status @@ -311,7 +311,7 @@ def broadcast_status_to_http_status(broadcast_status: BroadcastStatus) -> int: def format_pending_message_dict(pending_message: BasePendingMessage) -> Dict[str, Any]: - pending_message_dict = 
pending_message.dict(exclude_none=True) + pending_message_dict = pending_message.model_dump(exclude_none=True) pending_message_dict["time"] = pending_message_dict["time"].timestamp() return pending_message_dict diff --git a/tests/api/test_balance.py b/tests/api/test_balance.py index f1d927399..585eac6d5 100644 --- a/tests/api/test_balance.py +++ b/tests/api/test_balance.py @@ -1,3 +1,5 @@ +from decimal import Decimal + import pytest from aleph_message.models import Chain @@ -26,12 +28,12 @@ async def test_get_balance( response = await ccn_api_client.get(MESSAGES_URI) assert response.status == 200, await response.text() data = await response.json() - assert data["balance"] == user_balance.balance - assert data["locked_amount"] == 1001.8 + assert data["balance"] == str(user_balance.balance) + assert data["locked_amount"] == "1001.800000000000000000" details = data["details"] - assert details["ETH"] == user_balance.balance + assert details["ETH"] == str(user_balance.balance) @pytest.mark.asyncio @@ -48,7 +50,7 @@ async def test_get_balance_with_chain( _ = [message async for message in pipeline] assert fixture_instance_message.item_content - expected_locked_amount = 1001.8 + expected_locked_amount = "1001.800000000000000000" chain = Chain.AVAX.value # Test Avax avax_response = await ccn_api_client.get(f"{MESSAGES_URI}?chain={chain}") @@ -56,7 +58,7 @@ async def test_get_balance_with_chain( assert avax_response.status == 200, await avax_response.text() avax_data = await avax_response.json() avax_expected_balance = user_balance_eth_avax.balance - assert avax_data["balance"] == avax_expected_balance + assert avax_data["balance"] == str(avax_expected_balance) assert avax_data["locked_amount"] == expected_locked_amount # Verify ETH Value @@ -65,7 +67,7 @@ async def test_get_balance_with_chain( assert eth_response.status == 200, await eth_response.text() eth_data = await eth_response.json() eth_expected_balance = user_balance_eth_avax.balance - assert eth_data["balance"] == eth_expected_balance + assert eth_data["balance"] == str(eth_expected_balance) assert eth_data["locked_amount"] == expected_locked_amount # Verify All Chain @@ -73,13 +75,13 @@ async def test_get_balance_with_chain( assert total_response.status == 200, await total_response.text() total_data = await total_response.json() total_expected_balance = user_balance_eth_avax.balance * 2 - assert total_data["balance"] == total_expected_balance + assert total_data["balance"] == str(total_expected_balance) assert total_data["locked_amount"] == expected_locked_amount details = total_data["details"] assert details is not None - assert details["ETH"] == user_balance_eth_avax.balance - assert details["AVAX"] == user_balance_eth_avax.balance + assert details["ETH"] == str(user_balance_eth_avax.balance) + assert details["AVAX"] == str(user_balance_eth_avax.balance) @pytest.mark.asyncio @@ -90,15 +92,15 @@ async def test_get_balance_with_no_balance( assert response.status == 200, await response.text() data = await response.json() - assert data["balance"] == 0 - assert data["locked_amount"] == 0 + assert data["balance"] == "0" + assert str(Decimal(data["locked_amount"]).quantize(Decimal("0.01"))) == "0.00" # Test Eth Case response = await ccn_api_client.get(f"{MESSAGES_URI}?chain{Chain.ETH.value}") assert response.status == 200, await response.text() data = await response.json() - assert data["balance"] == 0 - assert data["locked_amount"] == 0 + assert data["balance"] == "0" + assert str(Decimal(data["locked_amount"]).quantize(Decimal("0.01"))) == 
"0.00" details = data["details"] assert not details diff --git a/tests/api/test_get_message.py b/tests/api/test_get_message.py index c6333da97..095ffc5f8 100644 --- a/tests/api/test_get_message.py +++ b/tests/api/test_get_message.py @@ -197,7 +197,7 @@ async def test_get_processed_message_status( ) assert response.status == 200, await response.text() response_json = await response.json() - parsed_response = ProcessedMessageStatus.parse_obj(response_json) + parsed_response = ProcessedMessageStatus.model_validate(response_json) assert parsed_response.status == MessageStatus.PROCESSED assert parsed_response.item_hash == processed_message.item_hash assert parsed_response.reception_time == RECEPTION_DATETIME @@ -236,7 +236,7 @@ async def test_get_rejected_message_status( ) assert response.status == 200, await response.text() response_json = await response.json() - parsed_response = RejectedMessageStatus.parse_obj(response_json) + parsed_response = RejectedMessageStatus.model_validate(response_json) assert parsed_response.status == MessageStatus.REJECTED assert parsed_response.item_hash == rejected_message.item_hash assert parsed_response.reception_time == RECEPTION_DATETIME @@ -260,7 +260,7 @@ async def test_get_forgotten_message_status( ) assert response.status == 200, await response.text() response_json = await response.json() - parsed_response = ForgottenMessageStatus.parse_obj(response_json) + parsed_response = ForgottenMessageStatus.model_validate(response_json) assert parsed_response.status == MessageStatus.FORGOTTEN assert parsed_response.item_hash == forgotten_message.item_hash assert parsed_response.reception_time == RECEPTION_DATETIME @@ -289,7 +289,7 @@ async def test_get_pending_message_status( ) assert response.status == 200, await response.text() response_json = await response.json() - parsed_response = PendingMessageStatus.parse_obj(response_json) + parsed_response = PendingMessageStatus.model_validate(response_json) assert parsed_response.status == MessageStatus.PENDING assert parsed_response.item_hash == processed_message.item_hash assert parsed_response.reception_time == RECEPTION_DATETIME diff --git a/tests/api/test_list_messages.py b/tests/api/test_list_messages.py index 6e9b2582a..b191e3a7b 100644 --- a/tests/api/test_list_messages.py +++ b/tests/api/test_list_messages.py @@ -14,7 +14,6 @@ from aleph_message.models.execution.volume import ( ImmutableVolume, ParentVolume, - PersistentVolumeSizeMib, VolumePersistence, ) @@ -228,16 +227,32 @@ async def test_get_messages_filter_by_tags( assert messages[0]["item_hash"] == amend_message_db.item_hash -@pytest.mark.parametrize("type_field", ("msgType", "msgTypes")) @pytest.mark.asyncio -async def test_get_by_message_type(fixture_messages, ccn_api_client, type_field: str): +async def test_get_by_deprecated_message_type(fixture_messages, ccn_api_client): messages_by_type = defaultdict(list) for message in fixture_messages: messages_by_type[message["type"]].append(message) for message_type, expected_messages in messages_by_type.items(): response = await ccn_api_client.get( - MESSAGES_URI, params={type_field: message_type} + MESSAGES_URI, params={"msgType": message_type} + ) + assert response.status == 200, await response.text() + messages = (await response.json())["messages"] + assert set(msg["item_hash"] for msg in messages) == set( + msg["item_hash"] for msg in expected_messages + ) + + +@pytest.mark.asyncio +async def test_get_by_message_type(fixture_messages, ccn_api_client): + messages_by_type = defaultdict(list) + for message in 
fixture_messages: + messages_by_type[message["type"]].append(message) + + for message_type, expected_messages in messages_by_type.items(): + response = await ccn_api_client.get( + MESSAGES_URI, params={"msgTypes": [message_type]} ) assert response.status == 200, await response.text() messages = (await response.json())["messages"] @@ -517,7 +532,7 @@ def instance_message_fixture() -> MessageDb: ) ), persistence=VolumePersistence("host"), - size_mib=PersistentVolumeSizeMib(1024), + size_mib=1024, ), volumes=[ ImmutableVolume( @@ -527,7 +542,7 @@ def instance_message_fixture() -> MessageDb: use_latest=True, ) ], - ).dict(), + ).model_dump(), size=3000, time=timestamp_to_datetime(1686572207.89381), channel=Channel("TEST"), diff --git a/tests/chains/test_chain_data_service.py b/tests/chains/test_chain_data_service.py index b34a5522c..722d8573f 100644 --- a/tests/chains/test_chain_data_service.py +++ b/tests/chains/test_chain_data_service.py @@ -1,4 +1,5 @@ import datetime as dt +import json import pytest from aleph_message.models import ( @@ -44,7 +45,7 @@ async def mock_add_file( session: DbSession, file_content: bytes, engine: ItemType = ItemType.ipfs ) -> str: content = file_content - archive = OnChainSyncEventPayload.parse_raw(content) + archive = OnChainSyncEventPayload.model_validate_json(content) assert archive.version == 1 assert len(archive.content.messages) == len(messages) @@ -86,7 +87,7 @@ async def test_smart_contract_protocol_ipfs_store( publisher="KT1BfL57oZfptdtMFZ9LNakEPvuPPA2urdSW", protocol=ChainSyncProtocol.SMART_CONTRACT, protocol_version=1, - content=payload.dict(), + content=payload.model_dump(), ) chain_data_service = ChainDataService( @@ -112,7 +113,7 @@ async def test_smart_contract_protocol_ipfs_store( assert pending_message.channel is None assert pending_message.item_content - message_content = StoreContent.parse_raw(pending_message.item_content) + message_content = StoreContent.model_validate_json(pending_message.item_content) assert message_content.item_hash == payload.message_content assert message_content.item_type == ItemType.ipfs assert message_content.address == payload.addr @@ -135,7 +136,7 @@ async def test_smart_contract_protocol_regular_message( timestamp=1668611900, addr="KT1VBeLD7hzKpj17aRJ3Kc6QQFeikCEXi7W6", msgtype="POST", - msgcontent=content.json(), + msgcontent=json.dumps(content.model_dump()), ) tx = ChainTxDb( @@ -146,7 +147,7 @@ async def test_smart_contract_protocol_regular_message( publisher="KT1BfL57oZfptdtMFZ9LNakEPvuPPA2urdSW", protocol=ChainSyncProtocol.SMART_CONTRACT, protocol_version=1, - content=payload.dict(), + content=payload.model_dump(mode="json"), ) chain_data_service = ChainDataService( @@ -172,7 +173,7 @@ async def test_smart_contract_protocol_regular_message( assert pending_message.channel is None assert pending_message.item_content - message_content = PostContent.parse_raw(pending_message.item_content) + message_content = PostContent.model_validate_json(pending_message.item_content) assert message_content.address == content.address assert message_content.time == content.time assert message_content.ref == content.ref diff --git a/tests/chains/test_cosmos.py b/tests/chains/test_cosmos.py index 6cc586d17..29ce207e7 100644 --- a/tests/chains/test_cosmos.py +++ b/tests/chains/test_cosmos.py @@ -3,7 +3,6 @@ from urllib.parse import unquote import pytest -from pydantic.tools import parse_obj_as from aleph.chains.cosmos import CosmosConnector from aleph.schemas.pending_messages import PendingPostMessage @@ -16,7 +15,7 @@ def 
cosmos_message() -> PendingPostMessage:
     message = json.loads(unquote(TEST_MESSAGE))
     message["signature"] = json.dumps(message["signature"])
     message["item_content"] = json.dumps(message["item_content"], separators=(",", ":"))
-    return parse_obj_as(PendingPostMessage, message)
+    return PendingPostMessage.model_validate(message)
 
 
 @pytest.mark.asyncio
@@ -28,8 +27,7 @@ async def test_verify_signature_real(cosmos_message: PendingPostMessage):
 @pytest.mark.asyncio
 async def test_verify_signature_bad_json():
     connector = CosmosConnector()
-    message = parse_obj_as(
-        PendingPostMessage,
+    message = PendingPostMessage.model_validate(
         {
             "chain": "CSDK",
             "time": 1737558660.737648,
@@ -37,7 +35,7 @@
             "type": "POST",
             "item_hash": sha256("ITEM_HASH".encode()).hexdigest(),
             "signature": "baba",
-        },
+        }
     )
     result = await connector.verify_signature(message)
     assert result is False
diff --git a/tests/chains/test_tezos.py b/tests/chains/test_tezos.py
index aa0d93e4c..ab8cea92b 100644
--- a/tests/chains/test_tezos.py
+++ b/tests/chains/test_tezos.py
@@ -101,7 +101,7 @@ def test_datetime_to_iso_8601():
                 type="my-type",
                 address="KT1VBeLD7hzKpj17aRJ3Kc6QQFeikCEXi7W6",
                 time=1000,
-            ).json(),
+            ).model_dump_json(),
         ),
         (
             MessageType.aggregate.value,
@@ -110,7 +110,7 @@
                 content={"body": "My first post on Tezos"},
                 address="KT1VBeLD7hzKpj17aRJ3Kc6QQFeikCEXi7W6",
                 time=1000,
-            ).json(),
+            ).model_dump_json(),
         ),
     ],
 )
@@ -139,4 +139,4 @@ def test_indexer_event_to_aleph_message(message_type: str, message_content: str)
 
     assert tx.protocol == ChainSyncProtocol.SMART_CONTRACT
     assert tx.protocol_version == 1
-    assert tx.content == indexer_event.payload.dict()
+    assert tx.content == indexer_event.payload.model_dump()
diff --git a/tests/conftest.py b/tests/conftest.py
index 62c876d9a..b50351ac0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -319,7 +319,7 @@ def insert_volume_refs(session: DbSession, message: PendingMessageDb):
     """
     Insert volume references in the DB to make the program processable.
     """
-    content = InstanceContent.parse_raw(message.item_content)
+    content = InstanceContent.model_validate_json(message.item_content)
     volumes = get_volume_refs(content)
 
     created = pytz.utc.localize(dt.datetime(2023, 1, 1))
diff --git a/tests/db/test_accounts.py b/tests/db/test_accounts.py
index 2d71d72cf..a02de7cf5 100644
--- a/tests/db/test_accounts.py
+++ b/tests/db/test_accounts.py
@@ -2,7 +2,6 @@
 from typing import List
 
 import pytest
-import pytz
 from aleph_message.models import Chain, ItemType, MessageType
 
 from aleph.db.accessors.messages import (
@@ -52,7 +51,7 @@ def fixture_messages():
             },
             item_type=ItemType.inline,
             size=2000,
-            time=pytz.utc.localize(dt.datetime.utcfromtimestamp(1664999872)),
+            time=dt.datetime.fromtimestamp(1664999872, tz=dt.timezone.utc),
             channel=Channel("CHANEL-N5"),
         ),
     ]
diff --git a/tests/db/test_cost.py b/tests/db/test_cost.py
index a1133097b..d7cd634ac 100644
--- a/tests/db/test_cost.py
+++ b/tests/db/test_cost.py
@@ -61,7 +61,7 @@ def insert_volume_refs(session: DbSession, message: MessageDb):
     """
 
     if message.item_content:
-        content = InstanceContent.parse_raw(message.item_content)
+        content = InstanceContent.model_validate_json(message.item_content)
         volumes = get_volume_refs(content)
 
         created = pytz.utc.localize(dt.datetime(2023, 1, 1))
@@ -98,7 +98,7 @@ async def insert_costs(session: DbSession, message: MessageDb):
     """
 
     if message.item_content:
-        content = InstanceContent.parse_raw(message.item_content)
+        content = InstanceContent.model_validate_json(message.item_content)
 
         _, costs = get_total_and_detailed_costs(session, content, message.item_hash)
 
diff --git a/tests/db/test_messages.py b/tests/db/test_messages.py
index d38ca7ec0..c16ba6b24 100644
--- a/tests/db/test_messages.py
+++ b/tests/db/test_messages.py
@@ -46,7 +46,7 @@ def fixture_message() -> MessageDb:
         },
         item_type=ItemType.inline,
         size=2000,
-        time=pytz.utc.localize(dt.datetime.utcfromtimestamp(1664999872)),
+        time=dt.datetime.fromtimestamp(1664999872, tz=dt.timezone.utc),
         channel=Channel("CHANEL-N5"),
     )
 
diff --git a/tests/message_processing/test_process_confidential.py b/tests/message_processing/test_process_confidential.py
index 2a2981797..625a771a3 100644
--- a/tests/message_processing/test_process_confidential.py
+++ b/tests/message_processing/test_process_confidential.py
@@ -194,7 +194,7 @@ def get_volume_refs(content: ExecutableContent) -> List[ImmutableVolume]:
 
 def insert_volume_refs(session: DbSession, message: PendingMessageDb):
     item_content = message.item_content if message.item_content is not None else ""
-    content = InstanceContent.parse_raw(item_content)
+    content = InstanceContent.model_validate_json(item_content)
     volumes = get_volume_refs(content)
 
     created = pytz.utc.localize(dt.datetime(2023, 1, 1))
diff --git a/tests/message_processing/test_process_instances.py b/tests/message_processing/test_process_instances.py
index bae788e5e..dcf851215 100644
--- a/tests/message_processing/test_process_instances.py
+++ b/tests/message_processing/test_process_instances.py
@@ -294,7 +294,7 @@ def fixture_forget_instance_message(
         sender=fixture_instance_message.sender,
         signature=None,
         item_type=ItemType.inline,
-        item_content=content.json(),
+        item_content=content.model_dump_json(),
         time=fixture_instance_message.time + dt.timedelta(seconds=1),
         channel=None,
         reception_time=fixture_instance_message.reception_time
@@ -338,7 +338,7 @@ def insert_volume_refs(session: DbSession, message: PendingMessageDb):
     """
     assert message.item_content
-    content = InstanceContent.parse_raw(message.item_content)
+    content = InstanceContent.model_validate_json(message.item_content)
     volumes = get_volume_refs(content)
 
     created = pytz.utc.localize(dt.datetime(2023, 1, 1))
@@ -488,7 +488,9 @@ async def test_process_instance_missing_volumes(
     assert rejected_message.error_code == ErrorCode.VM_VOLUME_NOT_FOUND
 
     if fixture_instance_message.item_content:
-        content = InstanceContent.parse_raw(fixture_instance_message.item_content)
+        content = InstanceContent.model_validate_json(
+            fixture_instance_message.item_content
+        )
         volume_refs = set(volume.ref for volume in get_volume_refs(content))
         assert isinstance(rejected_message.details, dict)
         assert set(rejected_message.details["errors"]) == volume_refs
@@ -570,7 +572,9 @@ async def test_get_additional_storage_price(
     session.commit()
 
     if fixture_instance_message.item_content:
-        content = InstanceContent.parse_raw(fixture_instance_message.item_content)
+        content = InstanceContent.model_validate_json(
+            fixture_instance_message.item_content
+        )
         with session_factory() as session:
             settings = _get_settings(session)
             pricing = _get_product_price(session, content, settings)
@@ -605,7 +609,9 @@ async def test_get_total_and_detailed_costs_from_db(
     _ = [message async for message in pipeline]
 
     if fixture_instance_message.item_content:
-        content = InstanceContent.parse_raw(fixture_instance_message.item_content)
+        content = InstanceContent.model_validate_json(
+            fixture_instance_message.item_content
+        )
         with session_factory() as session:
             cost, _ = get_total_and_detailed_costs(
                 session=session,
@@ -634,7 +640,7 @@ async def test_compare_account_cost_with_cost_function_hold(
     _ = [message async for message in pipeline]
 
     assert fixture_instance_message.item_content
-    content = InstanceContent.parse_raw(fixture_instance_message.item_content)
+    content = InstanceContent.model_validate_json(fixture_instance_message.item_content)
     with session_factory() as session:
         db_cost, _ = get_total_and_detailed_costs_from_db(
             session=session,
@@ -669,11 +675,9 @@ async def test_compare_account_cost_with_cost_payg_funct(
     _ = [message async for message in pipeline]
 
     assert fixture_instance_message_payg.item_content
-
-    content = InstanceContent.parse_raw(
+    content = InstanceContent.model_validate_json(
         fixture_instance_message_payg.item_content
-    )  # Parse again
-
+    )
     with session_factory() as session:
         assert content.payment.type == PaymentType.superfluid
         cost, details = get_total_and_detailed_costs(
@@ -777,7 +781,7 @@ async def test_compare_account_cost_with_cost_function_without_volume(
     _ = [message async for message in pipeline]
 
     assert fixture_instance_message_only_rootfs.item_content
-    content = InstanceContent.parse_raw(
+    content = InstanceContent.model_validate_json(
         fixture_instance_message_only_rootfs.item_content
     )
     with session_factory() as session:
diff --git a/tests/message_processing/test_process_pending_txs.py b/tests/message_processing/test_process_pending_txs.py
index d9f3b6086..69f6ada4e 100644
--- a/tests/message_processing/test_process_pending_txs.py
+++ b/tests/message_processing/test_process_pending_txs.py
@@ -2,7 +2,6 @@
 from typing import Dict, List, Set
 
 import pytest
-import pytz
 from aleph_message.models import Chain, MessageType, PostContent
 from configmanager import Config
 from sqlalchemy import select
@@ -56,7 +55,7 @@ async def test_process_pending_tx_on_chain_protocol(
     chain_tx = ChainTxDb(
         hash="0xf49cb176c1ce4f6eb7b9721303994b05074f8fadc37b5f41ac6f78bdf4b14b6c",
         chain=Chain.ETH,
-        datetime=pytz.utc.localize(dt.datetime.utcfromtimestamp(1632835747)),
+        datetime=dt.datetime.fromtimestamp(1632835747, tz=dt.timezone.utc),
         height=13314512,
         publisher="0x23eC28598DCeB2f7082Cc3a9D670592DfEd6e0dC",
         protocol=ChainSyncProtocol.ON_CHAIN_SYNC,
@@ -136,7 +135,7 @@ async def _process_smart_contract_tx(
         publisher="KT1BfL57oZfptdtMFZ9LNakEPvuPPA2urdSW",
         protocol=ChainSyncProtocol.SMART_CONTRACT,
         protocol_version=1,
-        content=payload.dict(),
+        content=payload.model_dump(),
     )
     pending_tx = PendingTxDb(tx=tx)
 
@@ -214,7 +213,7 @@ async def test_process_pending_smart_contract_tx_post(
             type="my-type",
             address="KT1VBeLD7hzKpj17aRJ3Kc6QQFeikCEXi7W6",
             time=1000,
-        ).json(),
+        ).model_dump_json(),
     )
 
     await _process_smart_contract_tx(
diff --git a/tests/message_processing/test_process_programs.py b/tests/message_processing/test_process_programs.py
index 94ce2a89e..99fbce657 100644
--- a/tests/message_processing/test_process_programs.py
+++ b/tests/message_processing/test_process_programs.py
@@ -117,7 +117,7 @@ def insert_volume_refs(session: DbSession, message: PendingMessageDb):
     """
     assert message.item_content
-    content = ProgramContent.parse_raw(message.item_content)
+    content = ProgramContent.model_validate_json(message.item_content)
     volumes = get_volumes_with_ref(content)
 
     created = pytz.utc.localize(dt.datetime(2023, 1, 1))
@@ -302,7 +302,7 @@ async def test_process_program_missing_volumes(
     assert rejected_message.error_code == ErrorCode.VM_VOLUME_NOT_FOUND
 
     assert program_message.item_content
-    content = ProgramContent.parse_raw(program_message.item_content)
+    content = ProgramContent.model_validate_json(program_message.item_content)
     volume_refs = set(volume.ref for volume in get_volumes_with_ref(content))
     assert isinstance(rejected_message.details, dict)
     assert set(rejected_message.details["errors"]) == volume_refs
diff --git a/tests/schemas/test_pending_messages.py b/tests/schemas/test_pending_messages.py
index e48ee076e..2b209469f 100644
--- a/tests/schemas/test_pending_messages.py
+++ b/tests/schemas/test_pending_messages.py
@@ -3,7 +3,6 @@
 from typing import Dict
 
 import pytest
-import pytz
 from aleph_message.models import ItemType
 
 from aleph.schemas.pending_messages import (
@@ -26,8 +25,8 @@ def check_basic_message_fields(pending_message: BasePendingMessage, message_dict
     assert pending_message.channel == message_dict["channel"]
     assert pending_message.signature == message_dict["signature"]
     assert pending_message.channel == message_dict["channel"]
-    assert pending_message.time == pytz.utc.localize(
-        dt.datetime.utcfromtimestamp(message_dict["time"])
+    assert pending_message.time == dt.datetime.fromtimestamp(
+        message_dict["time"], tz=dt.timezone.utc
     )
 
 
@@ -170,7 +169,7 @@ def test_parse_program_message():
     content = json.loads(message_dict["item_content"])
     assert message.content.address == content["address"]
    assert message.content.time == content["time"]
-    assert message.content.code.dict(exclude_none=True) == content["code"]
+    assert message.content.code.model_dump(exclude_none=True) == content["code"]
     assert message.content.type == content["type"]
 
 
diff --git a/tests/services/test_cost_service.py b/tests/services/test_cost_service.py
index d92548b68..8d8441841 100644
--- a/tests/services/test_cost_service.py
+++ b/tests/services/test_cost_service.py
@@ -49,7 +49,7 @@ def fixture_hold_instance_message() -> ExecutableContent:
         ],
     }
 
-    return InstanceContent.parse_obj(content)
+    return InstanceContent.model_validate(content)
 
 
 @pytest.fixture
@@ -109,7 +109,7 @@ def fixture_hold_instance_message_complete() -> ExecutableContent:
         ],
     }
 
-    return InstanceContent.parse_obj(content)
+    return InstanceContent.model_validate(content)
 
 
 @pytest.fixture
@@ -145,7 +145,7 @@ def fixture_flow_instance_message() -> ExecutableContent:
         ],
     }
 
-    return InstanceContent.parse_obj(content)
+    return InstanceContent.model_validate(content)
 
 
 @pytest.fixture
@@ -210,7 +210,7 @@ def fixture_flow_instance_message_complete() -> ExecutableContent:
         ],
     }
 
-    return InstanceContent.parse_obj(content)
+    return InstanceContent.model_validate(content)
 
 
 @pytest.fixture
@@ -267,7 +267,7 @@ def fixture_hold_program_message_complete() -> ExecutableContent:
         },
     }
 
-    return CostEstimationProgramContent.parse_obj(content)
+    return CostEstimationProgramContent.model_validate(content)
 
 
 def test_compute_cost(
@@ -293,7 +293,7 @@ def test_compute_cost_conf(
     fixture_settings_aggregate_in_db,
     fixture_hold_instance_message,
 ):
-    message_dict = fixture_hold_instance_message.dict()
+    message_dict = fixture_hold_instance_message.model_dump()
 
     # Convert the message to conf
     message_dict["environment"].update(
@@ -306,7 +306,7 @@ def test_compute_cost_conf(
         }
     )
 
-    rebuilt_message = InstanceContent.parse_obj(message_dict)
+    rebuilt_message = InstanceContent.model_validate(message_dict)
 
     file_db = StoredFileDb()
     mock = Mock()
@@ -410,7 +410,7 @@ def test_compute_flow_cost_conf(
     fixture_settings_aggregate_in_db,
     fixture_flow_instance_message,
 ):
-    message_dict = fixture_flow_instance_message.dict()
+    message_dict = fixture_flow_instance_message.model_dump()
 
     # Convert the message to conf
     message_dict["environment"].update(
@@ -423,7 +423,7 @@ def test_compute_flow_cost_conf(
         }
     )
 
-    rebuilt_message = InstanceContent.parse_obj(message_dict)
+    rebuilt_message = InstanceContent.model_validate(message_dict)
 
     # Proceed with the test
     file_db = StoredFileDb()
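# --- Reviewer note (illustrative sketch, not part of the patch) ---
# The test changes above apply a mechanical Pydantic v1 -> v2 rename plus a
# pytz-free datetime construction. The snippet below summarises the mapping on
# a hypothetical model; ExamplePayload is an illustrative stand-in, not a
# project class, and Pydantic >= 2 is assumed.
import datetime as dt
import json

from pydantic import BaseModel


class ExamplePayload(BaseModel):
    address: str
    time: float


raw = '{"address": "0xabc", "time": 1664999872}'

# v1 parse_raw / parse_obj  ->  v2 model_validate_json / model_validate
payload = ExamplePayload.model_validate_json(raw)
payload = ExamplePayload.model_validate(json.loads(raw))

# v1 .json() / .dict()  ->  v2 model_dump_json() / model_dump()
as_json = payload.model_dump_json()
as_dict = payload.model_dump(exclude_none=True)

# pytz.utc.localize(dt.datetime.utcfromtimestamp(ts)) becomes the stdlib-only,
# timezone-aware form used in the updated fixtures:
when = dt.datetime.fromtimestamp(1664999872, tz=dt.timezone.utc)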