Skip to content

Commit 677b064

Browse files
authored
Merge pull request #1 from LonelyVikingMichael/usermodel-relationships
Usermodel relationships
2 parents b2b9fac + 7a2a2d3 commit 677b064

File tree

3 files changed

+69
-1
lines changed

3 files changed

+69
-1
lines changed

fastapi_users_db_ormar/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class OrmarUserDatabase(BaseUserDatabase[UD]):
4444
:param user_db_model: Pydantic model of a DB representation of a user.
4545
:param model: ormar ORM model.
4646
:param oauth_account_model: Optional ormar ORM model of a OAuth account.
47+
:param select_related: Optional list of relationship names to retrieve with User queries.
4748
"""
4849

4950
model: Type[OrmarBaseUserModel]
@@ -54,10 +55,12 @@ def __init__(
5455
user_db_model: Type[UD],
5556
model: Type[OrmarBaseUserModel],
5657
oauth_account_model: Optional[Type[OrmarBaseOAuthAccountModel]] = None,
58+
select_related: Optional[List[str]] = None
5759
):
5860
super().__init__(user_db_model)
5961
self.model = model
6062
self.oauth_account_model = oauth_account_model
63+
self.select_related = select_related
6164

6265
async def get(self, id: UUID4) -> Optional[UD]:
6366
return await self._get_user(id=id)
@@ -73,6 +76,7 @@ async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD
7376
async def create(self, user: UD) -> UD:
7477
oauth_accounts = getattr(user, "oauth_accounts", [])
7578
model = await self.model(**user.dict(exclude={"oauth_accounts"})).save()
79+
await model.save_related()
7680
if oauth_accounts and self.oauth_account_model:
7781
await self._create_oauth_models(model=model, oauth_accounts=oauth_accounts)
7882
user_db = await self._get_user(id=user.id)
@@ -105,6 +109,9 @@ async def _get_db_user(self, **kwargs: Any) -> OrmarBaseUserModel:
105109
query = self.model.objects.filter(**kwargs)
106110
if self.oauth_account_model is not None:
107111
query = query.select_related("oauth_accounts")
112+
if self.select_related is not None:
113+
for relation in self.select_related:
114+
query = query.select_related(relation)
108115
return cast(OrmarBaseUserModel, await query.get())
109116

110117
async def _get_user(self, **kwargs: Any) -> Optional[UD]:

tests/conftest.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,25 @@
1+
import uuid
12
import asyncio
2-
from typing import Optional
3+
from typing import Optional, List
34

45
import pytest
56
from fastapi_users import models
7+
from pydantic import BaseModel, UUID4, Field
8+
9+
10+
class Role(BaseModel):
11+
id: UUID4 = Field(default_factory=uuid.uuid4)
12+
name: str
613

714

815
class User(models.BaseUser):
916
first_name: Optional[str]
17+
roles: Optional[List[Role]]
1018

1119

1220
class UserCreate(models.BaseUserCreate):
1321
first_name: Optional[str]
22+
roles: Optional[List[Role]]
1423

1524

1625
class UserUpdate(models.BaseUserUpdate):

tests/test_fastapi_users_db_ormar.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,23 @@
2020
database = databases.Database(DATABASE_URL)
2121

2222

23+
class Role(ormar.Model):
24+
class Meta:
25+
tablename = "roles"
26+
metadata = metadata
27+
database = database
28+
29+
id = ormar.UUID(primary_key=True, uuid_format="string")
30+
name = ormar.String(nullable=False, max_length=255)
31+
32+
2333
class User(OrmarBaseUserModel):
2434
class Meta:
2535
metadata = metadata
2636
database = database
2737

2838
first_name = ormar.String(nullable=True, max_length=255)
39+
roles = ormar.ManyToMany(Role, skip_reverse=True)
2940

3041

3142
class OAuthAccount(OrmarBaseOAuthAccountModel):
@@ -68,6 +79,21 @@ async def ormar_user_db_oauth() -> AsyncGenerator[OrmarUserDatabase, None]:
6879
await database.disconnect()
6980

7081

82+
@pytest.fixture
83+
async def ormar_user_db_with_relations() -> AsyncGenerator[OrmarUserDatabase, None]:
84+
engine = sqlalchemy.create_engine(
85+
DATABASE_URL, connect_args={"check_same_thread": False}
86+
)
87+
metadata.create_all(engine)
88+
89+
await database.connect()
90+
91+
yield OrmarUserDatabase(user_db_model=UserDB, model=User, select_related=['roles'])
92+
93+
metadata.drop_all(engine)
94+
await database.disconnect()
95+
96+
7197
@pytest.mark.asyncio
7298
@pytest.mark.db
7399
async def test_queries(ormar_user_db: OrmarUserDatabase[UserDB]):
@@ -196,3 +222,29 @@ async def test_queries_oauth(
196222
# Unknown OAuth account
197223
unknown_oauth_user = await ormar_user_db_oauth.get_by_oauth_account("foo", "bar")
198224
assert unknown_oauth_user is None
225+
226+
227+
@pytest.mark.asyncio
228+
@pytest.mark.db
229+
async def test_queries_custom_fields_relations(
230+
ormar_user_db_with_relations: OrmarUserDatabase[UserDB]
231+
):
232+
# Create role to pair with
233+
role = await Role.objects.create(
234+
id=uuid.uuid4(),
235+
name="editor"
236+
)
237+
238+
assert role.id is not None
239+
240+
user = UserDB(
241+
242+
hashed_password="guinevere",
243+
roles=[role]
244+
)
245+
246+
# Create with relationship
247+
user_db = await ormar_user_db_with_relations.create(user)
248+
assert user_db.roles is not None
249+
assert len(user_db.roles) is not 0
250+
assert user_db.roles[0].id == role.id

0 commit comments

Comments
 (0)