53 changes: 27 additions & 26 deletions mypy-baseline.txt

Large diffs are not rendered by default.

109 changes: 69 additions & 40 deletions src/inmanta/data/__init__.py
@@ -70,6 +70,8 @@
from inmanta.types import JsonType, PrimitiveTypes, ResourceIdStr, ResourceType, ResourceVersionIdStr
from inmanta.util import parse_timestamp
from sqlalchemy import URL, AdaptedConnection, NullPool
from sqlalchemy.dialects import registry
from sqlalchemy.dialects.postgresql.asyncpg import PGDialect_asyncpg
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import ConnectionPoolEntry

@@ -1145,14 +1147,13 @@ def _from_db_single(self, name: str, value: object) -> object:

# asyncpg does not convert a jsonb field to a dict
if isinstance(value, str) and self.field_type is dict:
return json.loads(value)
return value
# asyncpg does not convert an enum field to an enum type
if isinstance(value, str) and issubclass(self.field_type, enum.Enum):
return self.field_type[value]
# decode typed json
if isinstance(value, str) and issubclass(self.field_type, pydantic.BaseModel):
jsv = json.loads(value)
return self.field_type(**jsv)
if isinstance(value, dict) and issubclass(self.field_type, pydantic.BaseModel):
return self.field_type(**value)
if self.field_type == pydantic.AnyHttpUrl:
return pydantic.TypeAdapter(pydantic.AnyHttpUrl).validate_python(value)
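
With the jsonb codec introduced by this PR, asyncpg already hands `_from_db_single` a dict for jsonb columns, so the pydantic branch above validates the dict directly instead of calling json.loads first. A minimal standalone illustration of that behaviour (the `Report` model is a hypothetical stand-in, not one of inmanta's models):

```python
import pydantic


class Report(pydantic.BaseModel):
    """Hypothetical stand-in for a typed-json field."""

    name: str
    returncode: int


# A jsonb value as asyncpg now returns it: already a dict, no json.loads needed.
raw = {"name": "compile", "returncode": 0}
report = Report(**raw)
assert report.returncode == 0
```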

@@ -3884,7 +3885,7 @@ async def get_compile_details(cls, environment: uuid.UUID, id: uuid.UUID) -> Opt
returncode=compile["returncode"],
)
)
for name, url in (json.loads(compile["links"]) if requested_compile["links"] else {}).items():
for name, url in cast(dict[str, list[str]], compile.get("links", {})).items():
links[name].add(*url)

return m.CompileDetails(
@@ -3898,20 +3899,16 @@ async def get_compile_details(cls, environment: uuid.UUID, id: uuid.UUID) -> Opt
version=requested_compile["version"],
do_export=requested_compile["do_export"],
force_update=requested_compile["force_update"],
metadata=json.loads(requested_compile["metadata"]) if requested_compile["metadata"] else {},
environment_variables=(
json.loads(requested_compile["used_environment_variables"])
if requested_compile["used_environment_variables"] is not None
else {}
),
requested_environment_variables=(json.loads(requested_compile["requested_environment_variables"])),
mergeable_environment_variables=(json.loads(requested_compile["mergeable_environment_variables"])),
metadata=requested_compile["metadata"] or {},
environment_variables=requested_compile["used_environment_variables"] or {},
requested_environment_variables=requested_compile["requested_environment_variables"],
mergeable_environment_variables=requested_compile["mergeable_environment_variables"],
partial=requested_compile["partial"],
removed_resource_sets=requested_compile["removed_resource_sets"],
exporter_plugin=requested_compile["exporter_plugin"],
notify_failed_compile=requested_compile["notify_failed_compile"],
failed_compile_message=requested_compile["failed_compile_message"],
compile_data=json.loads(requested_compile["compile_data"]) if requested_compile["compile_data"] else None,
compile_data=requested_compile["compile_data"],
reports=reports,
links={key: sorted(list(links)) for key, links in links.items()},
)
@@ -4066,20 +4063,9 @@ def __init__(self, from_postgres: bool = False, **kwargs: object) -> None:
if self.changes == {}:
self.changes = None

# load message json correctly
if from_postgres and self.messages:
new_messages = []
for message in self.messages:
# Not 100% sure why this is needed, could depend on whether the data is
# retrieved through sql alchemy (automatically deserializes json into dict)
# or through regular asyncpg queries (no automatic deserialization)
if isinstance(message, str):
pass
LOGGER.debug(f"GOT A STR {message=}")
if isinstance(message, dict):
pass
LOGGER.debug(f"GOT A DICT {message=}")
message = json.loads(message)
if "timestamp" in message:
ta = pydantic.TypeAdapter(datetime.datetime)
# use pydantic instead of datetime.strptime because strptime has trouble parsing isoformat timezone offset
@@ -5084,7 +5070,6 @@ async def get_resources_for_version(
async for record in con.cursor(query, *values):
if no_obj:
record = dict(record)
record["attributes"] = json.loads(record["attributes"])
cls.__mangle_dict(record)
resources_list.append(record)
else:
@@ -5107,11 +5092,7 @@ async def get_resources_for_version_raw(
(filter_statement, values) = cls._get_composed_filter(environment=environment, model=version)
query = "SELECT " + projection + " FROM " + cls.table_name() + " WHERE " + filter_statement
resource_records = await cls._fetch_query(query, *values, connection=connection)
resources = [dict(record) for record in resource_records]
for res in resources:
if "attributes" in res:
res["attributes"] = json.loads(res["attributes"])
return resources
return [dict(record) for record in resource_records]

@classmethod
async def get_resources_since_version_raw(
@@ -5149,8 +5130,6 @@ async def get_resources_since_version_raw(
# left join produced no resources
continue
resource: dict[str, object] = dict(raw_resource)
if "attributes" in resource:
resource["attributes"] = json.loads(resource["attributes"])
if projection is not None:
assert set(projection) <= resource.keys()
parsed_resources.append(resource)
@@ -5195,11 +5174,6 @@ def collect_projection(projection: Optional[Collection[str]], prefix: str) -> st
"""
resource_records = await cls._fetch_query(query, environment, version, connection=connection)
resources = [dict(record) for record in resource_records]
for res in resources:
if project_attributes:
for k in project_attributes:
if res[k]:
res[k] = json.loads(res[k])
return resources

@classmethod
@@ -5335,7 +5309,7 @@ def status_sub_query(resource_table_name: str) -> str:
return None
record = result[0]
parsed_id = resources.Id.parse_id(record["latest_resource_id"])
attributes = json.loads(record["attributes"])
attributes = record["attributes"]
# Due to a bug, the version field has always been present in the attributes dictionary.
# This bug has been fixed in the database. For backwards compatibility reasons, we make sure here that the
# version field is present in the attributes dictionary served out via the API.
@@ -6629,6 +6603,7 @@ async def connect_pool(
min_size=connection_pool_min_size,
max_size=connection_pool_max_size,
timeout=connection_timeout,
init=asyncpg_on_connect,
)
try:
set_connection_pool(pool)
@@ -6658,6 +6633,60 @@ async def disconnect_pool() -> None:
BaseDocument.remove_connection_pool()


class ExternalInitAsyncPG(PGDialect_asyncpg):
"""
Define our own postgres dialect to use in engine initialization. The parent dialect
reconfigures json serialization/deserialization each time a connection is
checked out from the pool.

Overriding the on_connect method here removes this redundant behaviour: the
configuration for json serialization/deserialization is set once, when the asyncpg
pool is created.
"""

def on_connect(self) -> None:
return None


registry.impls["postgresql.asyncpgnoi"] = lambda: ExternalInitAsyncPG
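
Registering the class under "postgresql.asyncpgnoi" lets SQLAlchemy resolve it from an engine URL's drivername (the "+" in the drivername is mapped to "." for the registry lookup). A minimal sketch of how the registered name is picked up; the credentials and host below are placeholders:

```python
from sqlalchemy.ext.asyncio import create_async_engine

import inmanta.data  # noqa: F401  # importing the module registers the dialect

# "postgresql+asyncpgnoi" resolves to the "postgresql.asyncpgnoi" registry entry,
# so the engine uses ExternalInitAsyncPG and skips the per-checkout JSON codec setup.
engine = create_async_engine("postgresql+asyncpgnoi://user:secret@localhost/inmanta")
```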


async def asyncpg_on_connect(connection: asyncpg.Connection) -> None:
"""
Helper method to configure json serialization/deserialization when
initializing the database connection pool.
"""

def _json_decoder(bin_value: bytes) -> object:
return json.loads(bin_value.decode())

await connection.set_type_codec(
"json",
encoder=str.encode,
decoder=_json_decoder,
schema="pg_catalog",
format="binary",
)

def _jsonb_encoder(str_value: str) -> bytes:
# \x01 is the prefix for jsonb used by PostgreSQL.
# asyncpg requires it when format='binary'
return b"\x01" + str_value.encode()

def _jsonb_decoder(bin_value: bytes) -> object:
# the first byte is the \x01 jsonb version prefix used by PostgreSQL.
# asyncpg returns it when format='binary', so it is stripped before json decoding
return json.loads(bin_value[1:].decode())

await connection.set_type_codec(
"jsonb",
encoder=_jsonb_encoder,
decoder=_jsonb_decoder,
schema="pg_catalog",
format="binary",
)
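
As a usage note, this initializer is meant for asyncpg's `init` hook (see the `init=asyncpg_on_connect` argument added to the pool creation in `connect_pool` above), so the codecs are registered once per new connection rather than on every checkout. A minimal sketch with a placeholder DSN:

```python
import asyncpg

from inmanta.data import asyncpg_on_connect


async def make_pool() -> asyncpg.Pool:
    # init= is awaited once for every new connection the pool opens,
    # registering the json/jsonb codecs a single time per connection.
    return await asyncpg.create_pool(
        dsn="postgresql://inmanta@localhost/inmanta",  # placeholder DSN
        init=asyncpg_on_connect,
    )
```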


async def start_engine(
*,
database_username: str,
@@ -6702,7 +6731,7 @@ async def start_engine(
)

url_object = URL.create(
drivername="postgresql+asyncpg",
drivername="postgresql+asyncpgnoi",
username=database_username,
password=database_password,
host=database_host,
27 changes: 12 additions & 15 deletions src/inmanta/data/dataview.py
@@ -17,7 +17,6 @@
"""

import abc
import json
import urllib.parse
from abc import ABC
from collections.abc import Sequence
@@ -623,7 +622,7 @@ def construct_dtos(self, records: Sequence[Record]) -> Sequence[model.LatestRele
resource_version_id=resource["resource_id"] + ",v=" + str(resource["model"]),
id_details=data.Resource.get_details_from_resource_id(resource["resource_id"]),
status=resource["status"],
requires=json.loads(resource["attributes"]).get("requires", []),
requires=resource["attributes"].get("requires", []),
)
for resource in records
if resource["attributes"] # filter out bad joins
@@ -683,7 +682,7 @@ def construct_dtos(self, records: Sequence[Record]) -> Sequence[model.VersionedR
resource_id=versioned_resource["resource_id"],
resource_version_id=versioned_resource["resource_id"] + f",v={self.version}",
id_details=data.Resource.get_details_from_resource_id(versioned_resource["resource_id"]),
requires=json.loads(versioned_resource["attributes"]).get("requires", []), # todo: broken
requires=versioned_resource["attributes"].get("requires", []),
)
for versioned_resource in records
]
@@ -751,12 +750,10 @@ def construct_dtos(self, records: Sequence[Record]) -> Sequence[model.CompileRep
version=compile["version"],
do_export=compile["do_export"],
force_update=compile["force_update"],
metadata=json.loads(compile["metadata"]) if compile["metadata"] else {},
environment_variables=(
json.loads(compile["used_environment_variables"]) if compile["used_environment_variables"] else {}
),
requested_environment_variables=json.loads(compile["requested_environment_variables"]),
mergeable_environment_variables=json.loads(compile["mergeable_environment_variables"]),
metadata=compile["metadata"] or {},
environment_variables=compile["used_environment_variables"] or {},
requested_environment_variables=compile["requested_environment_variables"],
mergeable_environment_variables=compile["mergeable_environment_variables"],
partial=compile["partial"],
removed_resource_sets=compile["removed_resource_sets"],
exporter_plugin=compile["exporter_plugin"],
@@ -915,7 +912,7 @@ def get_base_query(self) -> SimpleQueryBuilder:

def construct_dtos(self, records: Sequence[Record]) -> Sequence[ResourceHistory]:
def get_attributes(record: Record) -> JsonType:
attributes: JsonType = json.loads(record["attributes"])
attributes = record["attributes"]
if "version" not in attributes:
# Due to a bug, the version field has always been present in the attributes dictionary.
# This bug has been fixed in the database. For backwards compatibility reasons, we make sure here that the
@@ -929,7 +926,7 @@ def get_attributes(record: Record) -> JsonType:
attribute_hash=record["attribute_hash"],
attributes=get_attributes(record),
date=record["date"],
requires=[Id.parse_id(rid).resource_str() for rid in json.loads(record["attributes"]).get("requires", [])],
requires=[Id.parse_id(rid).resource_str() for rid in record["attributes"].get("requires", [])],
)
for record in records
]
@@ -1018,7 +1015,7 @@ def get_base_query(self) -> SimpleQueryBuilder:
def construct_dtos(self, records: Sequence[Record]) -> Sequence[ResourceLog]:
logs = []
for record in records:
message = json.loads(record["unnested_message"])
message = record["unnested_message"]
logs.append(
ResourceLog(
action_id=record["action_id"],
@@ -1085,7 +1082,7 @@ def construct_dtos(self, records: Sequence[Record]) -> Sequence[Fact]:
source=fact["source"],
updated=fact["updated"],
resource_id=fact["resource_id"],
metadata=json.loads(fact["metadata"]) if fact["metadata"] else None,
metadata=fact["metadata"],
environment=fact["environment"],
)
for fact in records
@@ -1203,7 +1200,7 @@ def construct_dtos(self, records: Sequence[Record]) -> Sequence[model.Parameter]
value=parameter["value"],
source=parameter["source"],
updated=parameter["updated"],
metadata=json.loads(parameter["metadata"]) if parameter["metadata"] else None,
metadata=parameter["metadata"],
environment=parameter["environment"],
)
for parameter in records
@@ -1377,7 +1374,7 @@ def construct_dtos(self, records: Sequence[Record]) -> Sequence[dict[str, str]]:
return [
model.DiscoveredResource(
discovered_resource_id=res["discovered_resource_id"],
values=json.loads(res["values"]),
values=res["values"],
managed_resource_uri=(
f"/api/v2/resource/{urllib.parse.quote(str(res['discovered_resource_id']), safe='')}"
if res["managed"]
3 changes: 1 addition & 2 deletions src/inmanta/db/versions/v202503030.py
@@ -16,7 +16,6 @@
Contact: [email protected]
"""

import json
import logging
from collections import defaultdict
from dataclasses import dataclass, field
@@ -136,7 +135,7 @@ async def fetch_code_data() -> tuple[VersionsPerEnv, dict[int, dict[str, set[str

assert isinstance(model_version, int)

for file_hash, file_data in json.loads(source_refs).items(): # type: ignore
for file_hash, file_data in source_refs.items(): # type: ignore
file_path, python_module_name, requirements = file_data
inmanta_module_name = python_module_name.split(".")[1]
assert isinstance(inmanta_module_name, str)
3 changes: 1 addition & 2 deletions src/inmanta/deploy/state.py
@@ -21,7 +21,6 @@
import datetime
import enum
import itertools
import json
import uuid
from collections import defaultdict
from collections.abc import Mapping, Sequence, Set
@@ -295,7 +294,7 @@ async def create_from_db(
resource_intent = ResourceIntent(
resource_id=resource_id,
attribute_hash=res["attribute_hash"],
attributes=json.loads(res["attributes"]),
attributes=res["attributes"],
)
result.intent[resource_id] = resource_intent

1 change: 0 additions & 1 deletion tests/deploy/e2e/test_code_loading.py
@@ -447,7 +447,6 @@ async def test_logging_on_code_loading_failure_missing_code(server, client, envi
for resource_action in result.result["data"]
for log_line in resource_action["messages"]
)
assert False


@pytest.mark.parametrize("auto_start_agent", [True])
6 changes: 3 additions & 3 deletions tests/test_loader.py
@@ -57,8 +57,8 @@ def get_module_source(module: str, code: str) -> ModuleSource:
@pytest.mark.parametrize(
"install_all_dependencies,expected_dependencies",
[
(True, ["inmanta-module-std", "lorem"]),
(False, ["lorem"]),
(True, {"inmanta-module-std", "lorem"}),
(False, {"lorem"}),
],
)
def test_code_manager(tmpdir: py.path.local, deactive_venv, install_all_dependencies, expected_dependencies):
@@ -86,7 +86,7 @@ def test_code_manager(tmpdir: py.path.local, deactive_venv, install_all_dependen
assert "multiple_plugin_files" in module_version_info.keys()
assert "single_plugin_file" in module_version_info.keys()

assert module_version_info["single_plugin_file"].requirements == expected_dependencies
assert set(module_version_info["single_plugin_file"].requirements) == expected_dependencies
assert len(module_version_info["single_plugin_file"].files_in_module) == 1
assert len(module_version_info["multiple_plugin_files"].files_in_module) == 3
