Skip to content

Commit 47a8611

Browse files
authored
feat(users): ✨ add user auth and the option for no expiry of… (#37)
1 parent de3f3ac commit 47a8611

File tree

13 files changed

+1212
-307
lines changed

13 files changed

+1212
-307
lines changed

Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ ENV PGID=1000
4444
ENV DB_PATH=/config/comparisons.db
4545
ENV UPLOADS_PATH=/data/uploads
4646
ENV RETENTION_DAYS=7
47+
ENV ADMIN_INVITATION_CODE=change-me-in-production
4748

4849
# Expose port
4950
EXPOSE 8000

auth.py

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
"""
2+
Authentication module for Comps.
3+
Handles user authentication, invitation codes, and session management.
4+
"""
5+
import hashlib
6+
import secrets
7+
import sqlite3
8+
import time
9+
from typing import Optional, Dict, Any
10+
from fastapi import Request, HTTPException, Depends, Cookie
11+
from fastapi.security import APIKeyCookie
12+
from jose import JWTError, jwt
13+
from datetime import datetime, timedelta
14+
15+
# Constants
16+
DB_PATH = 'comparisons.db'
17+
SECRET_KEY = secrets.token_hex(32) # Generate a random secret key
18+
ALGORITHM = "HS256"
19+
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 1 week
20+
21+
# Cookie security
22+
cookie_sec = APIKeyCookie(name="session")
23+
24+
def hash_invitation_code(code: str) -> str:
25+
"""Hash an invitation code for secure storage"""
26+
return hashlib.sha256(code.encode()).hexdigest()
27+
28+
def create_invitation_code(created_by_id: int) -> str:
29+
"""Create a new invitation code"""
30+
# Generate a random code
31+
code = secrets.token_urlsafe(16)
32+
33+
conn = sqlite3.connect(DB_PATH)
34+
c = conn.cursor()
35+
36+
# Store the code
37+
c.execute(
38+
'INSERT INTO invitation_codes (code, created_by) VALUES (?, ?)',
39+
(code, created_by_id)
40+
)
41+
42+
conn.commit()
43+
conn.close()
44+
45+
return code
46+
47+
def verify_invitation_code(code: str) -> bool:
48+
"""Verify if an invitation code is valid and unused"""
49+
conn = sqlite3.connect(DB_PATH)
50+
c = conn.cursor()
51+
52+
c.execute('SELECT is_used FROM invitation_codes WHERE code = ?', (code,))
53+
result = c.fetchone()
54+
55+
conn.close()
56+
57+
# Code is valid if it exists and is not used
58+
return result is not None and not result[0]
59+
60+
def register_user(username: str, invitation_code: str) -> Optional[Dict[str, Any]]:
61+
"""Register a new user with an invitation code"""
62+
if not verify_invitation_code(invitation_code):
63+
return None
64+
65+
# Hash the invitation code for storage
66+
code_hash = hash_invitation_code(invitation_code)
67+
68+
conn = sqlite3.connect(DB_PATH)
69+
c = conn.cursor()
70+
71+
try:
72+
# Check if username already exists
73+
c.execute('SELECT id FROM users WHERE username = ?', (username,))
74+
if c.fetchone():
75+
conn.close()
76+
return None
77+
78+
# Create the user
79+
c.execute(
80+
'INSERT INTO users (username, invitation_code_hash, never_expire_comparisons) VALUES (?, ?, ?)',
81+
(username, code_hash, 1) # All invited users get permanent comparisons
82+
)
83+
user_id = c.lastrowid
84+
85+
# Mark the invitation code as used
86+
c.execute(
87+
'UPDATE invitation_codes SET is_used = 1, used_by = ? WHERE code = ?',
88+
(user_id, invitation_code)
89+
)
90+
91+
# Get the user data
92+
c.execute('SELECT id, username, is_admin, never_expire_comparisons FROM users WHERE id = ?', (user_id,))
93+
user = c.fetchone()
94+
95+
conn.commit()
96+
97+
if user:
98+
return {
99+
"id": user[0],
100+
"username": user[1],
101+
"is_admin": bool(user[2]),
102+
"never_expire_comparisons": bool(user[3])
103+
}
104+
return None
105+
except Exception as e:
106+
print(f"Error registering user: {e}")
107+
conn.rollback()
108+
return None
109+
finally:
110+
conn.close()
111+
112+
def authenticate_user(username: str, invitation_code: str) -> Optional[Dict[str, Any]]:
113+
"""Authenticate a user with their username and invitation code"""
114+
# Hash the invitation code for comparison
115+
code_hash = hash_invitation_code(invitation_code)
116+
117+
conn = sqlite3.connect(DB_PATH)
118+
c = conn.cursor()
119+
120+
c.execute(
121+
'SELECT id, username, is_admin, never_expire_comparisons FROM users WHERE username = ? AND invitation_code_hash = ?',
122+
(username, code_hash)
123+
)
124+
user = c.fetchone()
125+
126+
conn.close()
127+
128+
if user:
129+
return {
130+
"id": user[0],
131+
"username": user[1],
132+
"is_admin": bool(user[2]),
133+
"never_expire_comparisons": bool(user[3])
134+
}
135+
return None
136+
137+
def create_access_token(data: dict) -> str:
138+
"""Create a JWT access token"""
139+
to_encode = data.copy()
140+
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
141+
to_encode.update({"exp": expire})
142+
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
143+
return encoded_jwt
144+
145+
def get_current_user(session: str = Depends(cookie_sec)) -> Optional[Dict[str, Any]]:
146+
"""Get the current user from the session cookie"""
147+
try:
148+
payload = jwt.decode(session, SECRET_KEY, algorithms=[ALGORITHM])
149+
user_id = payload.get("sub")
150+
if user_id is None:
151+
return None
152+
153+
conn = sqlite3.connect(DB_PATH)
154+
c = conn.cursor()
155+
156+
c.execute(
157+
'SELECT id, username, is_admin, never_expire_comparisons FROM users WHERE id = ?',
158+
(user_id,)
159+
)
160+
user = c.fetchone()
161+
162+
conn.close()
163+
164+
if user:
165+
return {
166+
"id": user[0],
167+
"username": user[1],
168+
"is_admin": bool(user[2]),
169+
"never_expire_comparisons": bool(user[3])
170+
}
171+
return None
172+
except JWTError:
173+
return None
174+
except Exception as e:
175+
print(f"Error getting current user: {e}")
176+
return None
177+
178+
def get_user_invitation_codes(user_id: int) -> list:
179+
"""Get all invitation codes created by a user"""
180+
conn = sqlite3.connect(DB_PATH)
181+
c = conn.cursor()
182+
183+
c.execute('''
184+
SELECT ic.code, ic.is_used, u.username, ic.created_at
185+
FROM invitation_codes ic
186+
LEFT JOIN users u ON ic.used_by = u.id
187+
WHERE ic.created_by = ?
188+
ORDER BY ic.created_at DESC
189+
''', (user_id,))
190+
191+
codes = []
192+
for code, is_used, used_by, created_at in c.fetchall():
193+
codes.append({
194+
"code": code,
195+
"is_used": bool(is_used),
196+
"used_by": used_by,
197+
"created_at": created_at
198+
})
199+
200+
conn.close()
201+
return codes
202+
203+
def is_admin(user: dict) -> bool:
204+
"""Check if a user is an admin"""
205+
return user and user.get("is_admin", False)
206+
207+
async def get_optional_user(request: Request) -> Optional[Dict[str, Any]]:
208+
"""Get the current user if logged in, otherwise return None"""
209+
session = request.cookies.get("session")
210+
if not session:
211+
return None
212+
213+
try:
214+
payload = jwt.decode(session, SECRET_KEY, algorithms=[ALGORITHM])
215+
user_id = payload.get("sub")
216+
if user_id is None:
217+
return None
218+
219+
conn = sqlite3.connect(DB_PATH)
220+
c = conn.cursor()
221+
222+
c.execute(
223+
'SELECT id, username, is_admin, never_expire_comparisons FROM users WHERE id = ?',
224+
(user_id,)
225+
)
226+
user = c.fetchone()
227+
228+
conn.close()
229+
230+
if user:
231+
return {
232+
"id": user[0],
233+
"username": user[1],
234+
"is_admin": bool(user[2]),
235+
"never_expire_comparisons": bool(user[3])
236+
}
237+
return None
238+
except JWTError:
239+
return None
240+
except Exception as e:
241+
print(f"Error getting optional user: {e}")
242+
return None

database.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import sqlite3
22
import os
3-
import shutil
3+
import shutil
44
from datetime import datetime, timedelta
55
from typing import List, Optional
66
from migrations.manager import MigrationManager
@@ -12,15 +12,31 @@ def init_db():
1212
migration_manager = MigrationManager()
1313
migration_manager.migrate(DB_PATH)
1414

15-
def create_comparison(comparison_id: str, name: Optional[str], show_name: Optional[str], tags: Optional[List[str]], metadata: dict):
15+
def create_comparison(comparison_id: str, name: Optional[str], show_name: Optional[str], tags: Optional[List[str]], metadata: dict, user_id: Optional[int] = None):
1616
conn = sqlite3.connect(DB_PATH)
1717
c = conn.cursor()
1818

1919
c.execute(
20-
'INSERT INTO comparisons (id, name, show_name, total_rows, total_columns, expiration_type, expiration_days) VALUES (?, ?, ?, ?, ?, ?, ?)',
21-
(comparison_id, name, show_name, metadata.get('total_rows', 1), metadata.get('total_columns', 2), metadata.get('expiration_type', 'from_last_access'), int(metadata.get('expiration_days', 7)))
20+
'INSERT INTO comparisons (id, name, show_name, total_rows, total_columns, expiration_type, expiration_days, never_expire) VALUES (?, ?, ?, ?, ?, ?, ?, ?)',
21+
(comparison_id, name, show_name, metadata.get('total_rows', 1), metadata.get('total_columns', 2),
22+
metadata.get('expiration_type', 'from_last_access'), int(metadata.get('expiration_days', 7)), 0)
2223
)
23-
c.execute('CREATE INDEX IF NOT EXISTS idx_image_positions ON image_positions(comparison_id, row_number, column_position)')
24+
25+
# Set default expiration based on metadata
26+
if metadata.get('never_expire') is not None:
27+
c.execute('UPDATE comparisons SET never_expire = ? WHERE id = ?',
28+
(1 if metadata.get('never_expire') else 0, comparison_id))
29+
30+
# If user is authenticated, associate the comparison with them and set never_expire if applicable
31+
if user_id is not None:
32+
c.execute('UPDATE comparisons SET user_id = ? WHERE id = ?', (user_id, comparison_id))
33+
c.execute('SELECT never_expire_comparisons FROM users WHERE id = ?', (user_id,))
34+
never_expire = c.fetchone()[0]
35+
if never_expire:
36+
# Only set never_expire if the user didn't explicitly choose to have it expire
37+
if metadata.get('never_expire') is None:
38+
c.execute('UPDATE comparisons SET never_expire = 1 WHERE id = ?', (comparison_id,))
39+
2440

2541
if tags:
2642
for tag in tags:
@@ -60,6 +76,11 @@ def get_comparison(comparison_id: str):
6076
c.execute('SELECT tag FROM tags WHERE comparison_id = ?', (comparison_id,))
6177
tags = [row[0] for row in c.fetchall()]
6278

79+
# Get user information if available
80+
c.execute('SELECT user_id, never_expire FROM comparisons WHERE id = ?', (comparison_id,))
81+
user_info = c.fetchone()
82+
user_id, never_expire = user_info if user_info else (None, 0)
83+
6384
return {
6485
'id': comparison[0],
6586
'name': comparison[1],
@@ -70,7 +91,9 @@ def get_comparison(comparison_id: str):
7091
'expiration_type': comparison[5] or 'from_last_access',
7192
'expiration_days': comparison[6] or 7,
7293
'created_at': comparison[7],
73-
'last_accessed': comparison[8]
94+
'last_accessed': comparison[8],
95+
'user_id': user_id,
96+
'never_expire': bool(never_expire)
7497
}
7598

7699
return None
@@ -180,16 +203,21 @@ def get_expired_comparisons(retention_days: int):
180203

181204
if 'last_accessed' in columns and 'expiration_type' in columns and 'expiration_days' in columns:
182205
# Get all comparisons with their expiration settings
183-
c.execute('SELECT id, expiration_type, expiration_days, created_at, last_accessed FROM comparisons')
206+
c.execute('SELECT id, expiration_type, expiration_days, created_at, last_accessed, never_expire FROM comparisons')
184207
comparisons = c.fetchall()
185-
208+
186209
print(f"Checking for expired comparisons with retention_days={retention_days}")
187210
print(f"Found {len(comparisons)} comparisons to check for expiration")
188211
current_time = datetime.now()
189-
for comp_id, exp_type, exp_days, created_at, last_accessed in comparisons:
212+
for comp_id, exp_type, exp_days, created_at, last_accessed, never_expire in comparisons:
190213
# Use comparison's own expiration days if available, otherwise use default
191214
days = exp_days if exp_days is not None else retention_days
192215
print(f"Checking comparison {comp_id}: type={exp_type}, days={days}, created={created_at}, last_accessed={last_accessed}")
216+
217+
# Check if this comparison is marked as never expire
218+
if never_expire:
219+
print(f" Comparison {comp_id} is marked as never expire, skipping")
220+
continue
193221

194222
if exp_type == 'from_creation' and created_at:
195223
cutoff_date = datetime.strptime(created_at, '%Y-%m-%d %H:%M:%S') + timedelta(days=days)

docker-compose.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ services:
1313
- PGID=1000
1414
- DB_PATH=/config/comparisons.db
1515
- UPLOADS_PATH=/data/uploads
16+
- ADMIN_INVITATION_CODE=your-secure-admin-code
1617
volumes:
1718
- ./config:/config
1819
- ./data:/data

0 commit comments

Comments
 (0)