Skip to content

Commit 29f99aa

Browse files
Classes separation into a different file
1 parent c63788e commit 29f99aa

File tree

9 files changed

+197
-225
lines changed

9 files changed

+197
-225
lines changed

src/codegate/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
_VERSION = "dev"
1111
_DESC = "CodeGate - A Generative AI security gateway."
1212

13+
1314
def __get_version_and_description() -> tuple[str, str]:
1415
try:
1516
version = metadata.version("codegate")
@@ -19,6 +20,7 @@ def __get_version_and_description() -> tuple[str, str]:
1920
description = _DESC
2021
return version, description
2122

23+
2224
__version__, __description__ = __get_version_and_description()
2325

2426
__all__ = ["Config", "ConfigurationError", "LogFormat", "LogLevel", "setup_logging"]

src/codegate/db/connection.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import structlog
88
from alembic import command as alembic_command
99
from alembic.config import Config as AlembicConfig
10-
from pydantic import BaseModel
10+
from pydantic import BaseModel, ValidationError
1111
from sqlalchemy import CursorResult, TextClause, text
1212
from sqlalchemy.exc import OperationalError
1313
from sqlalchemy.ext.asyncio import create_async_engine
@@ -233,7 +233,12 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
233233
logger.error(f"Failed to record context: {context}.", error=str(e))
234234

235235
async def add_workspace(self, workspace_name: str) -> Optional[Workspace]:
236-
workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name)
236+
try:
237+
workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name)
238+
except ValidationError as e:
239+
logger.error(f"Failed to create workspace with name: {workspace_name}: {str(e)}")
240+
return None
241+
237242
sql = text(
238243
"""
239244
INSERT INTO workspaces (id, name)
@@ -365,7 +370,7 @@ async def get_workspace_by_name(self, name: str) -> List[Workspace]:
365370
)
366371
conditions = {"name": name}
367372
workspaces = await self._exec_select_conditions_to_pydantic(Workspace, sql, conditions)
368-
return workspaces
373+
return workspaces[0] if workspaces else None
369374

370375
async def get_sessions(self) -> List[Session]:
371376
sql = text(

src/codegate/db/models.py

+19-11
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import datetime
2+
import re
23
from typing import Any, Optional
34

4-
import pydantic
5+
from pydantic import BaseModel, field_validator
56

67

7-
class Alert(pydantic.BaseModel):
8+
class Alert(BaseModel):
89
id: Any
910
prompt_id: Any
1011
code_snippet: Optional[Any]
@@ -14,14 +15,14 @@ class Alert(pydantic.BaseModel):
1415
timestamp: Any
1516

1617

17-
class Output(pydantic.BaseModel):
18+
class Output(BaseModel):
1819
id: Any
1920
prompt_id: Any
2021
timestamp: Any
2122
output: Any
2223

2324

24-
class Prompt(pydantic.BaseModel):
25+
class Prompt(BaseModel):
2526
id: Any
2627
timestamp: Any
2728
provider: Optional[Any]
@@ -30,7 +31,7 @@ class Prompt(pydantic.BaseModel):
3031
workspace_id: Optional[str]
3132

3233

33-
class Setting(pydantic.BaseModel):
34+
class Setting(BaseModel):
3435
id: Any
3536
ip: Optional[Any]
3637
port: Optional[Any]
@@ -39,12 +40,19 @@ class Setting(pydantic.BaseModel):
3940
other_settings: Optional[Any]
4041

4142

42-
class Workspace(pydantic.BaseModel):
43+
class Workspace(BaseModel):
4344
id: str
4445
name: str
4546

47+
@field_validator("name", mode="plain")
48+
@classmethod
49+
def name_must_be_alphanumeric(cls, value):
50+
if not re.match(r"^[a-zA-Z0-9_-]+$", value):
51+
raise ValueError("name must be alphanumeric and can only contain _ and -")
52+
return value
4653

47-
class Session(pydantic.BaseModel):
54+
55+
class Session(BaseModel):
4856
id: str
4957
active_workspace_id: str
5058
last_update: datetime.datetime
@@ -53,7 +61,7 @@ class Session(pydantic.BaseModel):
5361
# Models for select queries
5462

5563

56-
class GetAlertsWithPromptAndOutputRow(pydantic.BaseModel):
64+
class GetAlertsWithPromptAndOutputRow(BaseModel):
5765
id: Any
5866
prompt_id: Any
5967
code_snippet: Optional[Any]
@@ -70,7 +78,7 @@ class GetAlertsWithPromptAndOutputRow(pydantic.BaseModel):
7078
output_timestamp: Optional[Any]
7179

7280

73-
class GetPromptWithOutputsRow(pydantic.BaseModel):
81+
class GetPromptWithOutputsRow(BaseModel):
7482
id: Any
7583
timestamp: Any
7684
provider: Optional[Any]
@@ -81,13 +89,13 @@ class GetPromptWithOutputsRow(pydantic.BaseModel):
8189
output_timestamp: Optional[Any]
8290

8391

84-
class WorkspaceActive(pydantic.BaseModel):
92+
class WorkspaceActive(BaseModel):
8593
id: str
8694
name: str
8795
active_workspace_id: Optional[str]
8896

8997

90-
class ActiveWorkspace(pydantic.BaseModel):
98+
class ActiveWorkspace(BaseModel):
9199
id: str
92100
name: str
93101
session_id: str

src/codegate/pipeline/extract_snippets/extract_snippets.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,8 @@ def extract_snippets(message: str) -> List[CodeSnippet]:
124124
lang = None
125125

126126
#  just correct the typescript exception
127-
lang_map = {
128-
"typescript": "javascript"
129-
}
130-
lang = lang_map.get(lang, lang)
127+
lang_map = {"typescript": "javascript"}
128+
lang = lang_map.get(lang, lang)
131129
snippets.append(CodeSnippet(filepath=filename, code=content, language=lang))
132130

133131
return snippets

src/codegate/pipeline/secrets/secrets.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ async def process_chunk(
366366
if match:
367367
# Found a complete marker, process it
368368
encrypted_value = match.group(1)
369-
if encrypted_value.startswith('$'):
369+
if encrypted_value.startswith("$"):
370370
encrypted_value = encrypted_value[1:]
371371
original_value = input_context.sensitive.manager.get_original_value(
372372
encrypted_value,
+157
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
import datetime
2+
from typing import Optional, Tuple
3+
4+
from codegate.db.connection import DbReader, DbRecorder
5+
from codegate.db.models import Session, Workspace
6+
7+
8+
class WorkspaceCrud:
9+
10+
def __init__(self):
11+
self._db_reader = DbReader()
12+
13+
async def add_workspace(self, new_workspace_name: str) -> bool:
14+
"""
15+
Add a workspace
16+
17+
Args:
18+
name (str): The name of the workspace
19+
"""
20+
db_recorder = DbRecorder()
21+
workspace_created = await db_recorder.add_workspace(new_workspace_name)
22+
return bool(workspace_created)
23+
24+
async def get_workspaces(self):
25+
"""
26+
Get all workspaces
27+
"""
28+
return await self._db_reader.get_workspaces()
29+
30+
async def _is_workspace_active_or_not_exist(
31+
self, workspace_name: str
32+
) -> Tuple[bool, Optional[Session], Optional[Workspace]]:
33+
"""
34+
Check if the workspace is active
35+
36+
Will return:
37+
- True if the workspace was activated
38+
- False if the workspace is already active or does not exist
39+
"""
40+
selected_workspace = await self._db_reader.get_workspace_by_name(workspace_name)
41+
if not selected_workspace:
42+
return True, None, None
43+
44+
sessions = await self._db_reader.get_sessions()
45+
# The current implementation expects only one active session
46+
if len(sessions) != 1:
47+
raise RuntimeError("Something went wrong. No active session found.")
48+
49+
session = sessions[0]
50+
if session.active_workspace_id == selected_workspace.id:
51+
return True, None, None
52+
return False, session, selected_workspace
53+
54+
async def activate_workspace(self, workspace_name: str) -> bool:
55+
"""
56+
Activate a workspace
57+
58+
Will return:
59+
- True if the workspace was activated
60+
- False if the workspace is already active or does not exist
61+
"""
62+
is_active, session, workspace = await self._is_workspace_active_or_not_exist(workspace_name)
63+
if is_active:
64+
return False
65+
66+
session.active_workspace_id = workspace.id
67+
session.last_update = datetime.datetime.now(datetime.timezone.utc)
68+
db_recorder = DbRecorder()
69+
await db_recorder.update_session(session)
70+
return True
71+
72+
73+
class WorkspaceCommands:
74+
75+
def __init__(self):
76+
self.workspace_crud = WorkspaceCrud()
77+
self.commands = {
78+
"list": self._list_workspaces,
79+
"add": self._add_workspace,
80+
"activate": self._activate_workspace,
81+
}
82+
83+
async def _list_workspaces(self, *args) -> str:
84+
"""
85+
List all workspaces
86+
"""
87+
workspaces = await self.workspace_crud.get_workspaces()
88+
respond_str = ""
89+
for workspace in workspaces:
90+
respond_str += f"- {workspace.name}"
91+
if workspace.active_workspace_id:
92+
respond_str += " **(active)**"
93+
respond_str += "\n"
94+
return respond_str
95+
96+
async def _add_workspace(self, *args) -> str:
97+
"""
98+
Add a workspace
99+
"""
100+
if args is None or len(args) == 0:
101+
return "Please provide a name. Use `codegate-workspace add your_workspace_name`"
102+
103+
new_workspace_name = args[0]
104+
if not new_workspace_name:
105+
return "Please provide a name. Use `codegate-workspace add your_workspace_name`"
106+
107+
workspace_created = await self.workspace_crud.add_workspace(new_workspace_name)
108+
if not workspace_created:
109+
return (
110+
"Something went wrong. Workspace could not be added.\n"
111+
"1. Check if the name is alphanumeric and only contains dashes, and underscores.\n"
112+
"2. Check if the workspace already exists."
113+
)
114+
return f"Workspace **{new_workspace_name}** has been added"
115+
116+
async def _activate_workspace(self, *args) -> str:
117+
"""
118+
Activate a workspace
119+
"""
120+
if args is None or len(args) == 0:
121+
return "Please provide a name. Use `codegate-workspace activate workspace_name`"
122+
123+
workspace_name = args[0]
124+
if not workspace_name:
125+
return "Please provide a name. Use `codegate-workspace activate workspace_name`"
126+
127+
was_activated = await self.workspace_crud.activate_workspace(workspace_name)
128+
if not was_activated:
129+
return (
130+
f"Workspace **{workspace_name}** does not exist or was already active. "
131+
f"Use `codegate-workspace add {workspace_name}` to add it"
132+
)
133+
return f"Workspace **{workspace_name}** has been activated"
134+
135+
async def execute(self, command: str, *args) -> str:
136+
"""
137+
Execute the given command
138+
139+
Args:
140+
command (str): The command to execute
141+
"""
142+
command_to_execute = self.commands.get(command)
143+
if command_to_execute is not None:
144+
return await command_to_execute(*args)
145+
else:
146+
return "Command not found"
147+
148+
async def parse_execute_cmd(self, last_user_message: str) -> str:
149+
"""
150+
Parse the last user message and execute the command
151+
152+
Args:
153+
last_user_message (str): The last user message
154+
"""
155+
command_and_args = last_user_message.lower().split("codegate-workspace ")[1]
156+
command, *args = command_and_args.split(" ")
157+
return await self.execute(command, *args)

0 commit comments

Comments
 (0)