Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Add instance table along with init code. #1234

Merged
merged 1 commit into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""add installation table

Revision ID: e4c05d7591a8
Revises: 3ec2b4ab569c
Create Date: 2025-03-05 21:26:19.034319+00:00

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "e4c05d7591a8"
down_revision: Union[str, None] = "3ec2b4ab569c"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
op.execute("BEGIN TRANSACTION;")

op.execute(
"""
CREATE TABLE IF NOT EXISTS instance (
id TEXT PRIMARY KEY, -- UUID stored as TEXT
created_at DATETIME NOT NULL
);
"""
)

op.execute(
"""
-- The following trigger prevents multiple insertions in the
-- instance table. It is safe since the dimension of the table
-- is fixed.

CREATE TRIGGER single_instance
BEFORE INSERT ON instance
WHEN (SELECT COUNT(*) FROM instance) >= 1
BEGIN
SELECT RAISE(FAIL, 'only one instance!');
END;
"""
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we have a trigger that prevents deletion of the one row as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would avoid adding a constraint on deletion, it becomes cumbersome to modify the table since AFAIK you cannot disable triggers temporarily in sqlite, which means you'd have to rely on DROP TABLE/CREATE TABLE to fix any sort of issue.

Additionally, it's hard to delete that record by mistake since there's no delete_instance routine in DbRecorder and initialization takes care of recreating it.

I'm inclined to let this as is, but I'm not religious about it, we can add it now or later anyway.


# Finish transaction
op.execute("COMMIT;")


def downgrade() -> None:
op.execute("BEGIN TRANSACTION;")

op.execute(
"""
DROP TABLE instance;
"""
)

# Finish transaction
op.execute("COMMIT;")
7 changes: 6 additions & 1 deletion src/codegate/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
from codegate.ca.codegate_ca import CertificateAuthority
from codegate.codegate_logging import LogFormat, LogLevel, setup_logging
from codegate.config import Config, ConfigurationError
from codegate.db.connection import init_db_sync, init_session_if_not_exists
from codegate.db.connection import (
init_db_sync,
init_session_if_not_exists,
init_instance,
)
from codegate.pipeline.factory import PipelineFactory
from codegate.pipeline.sensitive_data.manager import SensitiveDataManager
from codegate.providers import crud as provendcrud
Expand Down Expand Up @@ -318,6 +322,7 @@ def serve( # noqa: C901
logger = structlog.get_logger("codegate").bind(origin="cli")

init_db_sync(cfg.db_path)
init_instance(cfg.db_path)
init_session_if_not_exists(cfg.db_path)

# Check certificates and create CA if necessary
Expand Down
46 changes: 44 additions & 2 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import datetime
import json
import sqlite3
import uuid
Expand All @@ -23,6 +24,7 @@
Alert,
GetPromptWithOutputsRow,
GetWorkspaceByNameConditions,
Instance,
IntermediatePromptWithOutputUsageAlerts,
MuxRule,
Output,
Expand Down Expand Up @@ -596,6 +598,27 @@ async def delete_persona(self, persona_id: str) -> None:
conditions = {"id": persona_id}
await self._execute_with_no_return(sql, conditions)

async def init_instance(self) -> None:
"""
Initializes instance details in the database.
"""
sql = text(
"""
INSERT INTO instance (id, created_at)
VALUES (:id, :created_at)
"""
)

try:
instance = Instance(
id=str(uuid.uuid4()),
created_at=datetime.datetime.now(datetime.timezone.utc),
)
await self._execute_with_no_return(sql, instance.model_dump())
except IntegrityError as e:
logger.debug(f"Exception type: {type(e)}")
raise AlreadyExistsError(f"Instance already initialized.")


class DbReader(DbCodeGate):
def __init__(self, sqlite_path: Optional[str] = None, *args, **kwargs):
Expand Down Expand Up @@ -1098,6 +1121,13 @@ async def get_all_personas(self) -> List[Persona]:
personas = await self._execute_select_pydantic_model(Persona, sql, should_raise=True)
return personas

async def get_instance(self) -> Instance:
"""
Get the details of the instance.
"""
sql = text("SELECT id, created_at FROM instance")
return await self._execute_select_pydantic_model(Instance, sql)


class DbTransaction:
def __init__(self):
Expand Down Expand Up @@ -1148,8 +1178,6 @@ def init_db_sync(db_path: Optional[str] = None):


def init_session_if_not_exists(db_path: Optional[str] = None):
import datetime

db_reader = DbReader(db_path)
sessions = asyncio.run(db_reader.get_sessions())
# If there are no sessions, create a new one
Expand All @@ -1169,5 +1197,19 @@ def init_session_if_not_exists(db_path: Optional[str] = None):
logger.info("Session in DB initialized successfully.")


def init_instance(db_path: Optional[str] = None):
db_reader = DbReader(db_path)
instance = asyncio.run(db_reader.get_instance())
# Initialize instance if not already initialized.
if not instance:
db_recorder = DbRecorder(db_path)
try:
asyncio.run(db_recorder.init_instance())
except Exception as e:
logger.error(f"Failed to initialize instance in DB: {e}")
raise
logger.info("Instance initialized successfully.")


if __name__ == "__main__":
init_db_sync()
5 changes: 5 additions & 0 deletions src/codegate/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ class Session(BaseModel):
last_update: datetime.datetime


class Instance(BaseModel):
id: str
created_at: datetime.datetime


# Models for select queries


Expand Down