Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 8bb29e8

Browse files
author
Luke Hinds
authored
Merge pull request #334 from stacklok/cert-create-logic
Add logic to check if certs exist for generate_certs
2 parents cd7a3b2 + c30abfe commit 8bb29e8

File tree

6 files changed

+121
-111
lines changed

6 files changed

+121
-111
lines changed

src/codegate/ca/codegate_ca.py

Lines changed: 12 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -368,35 +368,6 @@ def generate_certificates(self) -> Tuple[str, str]:
368368
)
369369

370370
# Print instructions for trusting the certificates
371-
logger.info(
372-
"""
373-
Certificates generated successfully in the 'certs' directory
374-
To trust these certificates:
375-
376-
On macOS:
377-
`sudo security add-trusted-cert -d -r trustRoot -k /Library/Keychains/System.keychain certs/ca.crt`
378-
379-
On Windows (PowerShell as Admin):
380-
`Import-Certificate -FilePath "certs\\ca.crt" -CertStoreLocation Cert:\\LocalMachine\\Root`
381-
382-
On Linux:
383-
`sudo cp certs/ca.crt /usr/local/share/ca-certificates/codegate.crt`
384-
`sudo update-ca-certificates`
385-
386-
For VSCode, add to settings.json:
387-
{
388-
"http.proxy": "https://localhost:8990",
389-
"http.proxyStrictSSL": true,
390-
"http.proxySupport": "on",
391-
"github.copilot.advanced": {
392-
"debug.useNodeFetcher": true,
393-
"debug.useElectronFetcher": true,
394-
"debug.testOverrideProxyUrl": "https://localhost:8990",
395-
"debug.overrideProxyUrl": "https://localhost:8990"
396-
},
397-
}
398-
"""
399-
)
400371
logger.debug("Certificates generated successfully")
401372
return server_cert, server_key
402373

@@ -422,23 +393,21 @@ def create_ssl_context(self) -> ssl.SSLContext:
422393
logger.debug("SSL context created successfully")
423394
return ssl_context
424395

425-
def ensure_certificates_exist(self) -> None:
396+
def check_certificates_exist(self) -> bool:
397+
"""Check if SSL certificates exist"""
398+
logger.debug("Checking if certificates exist fn: check_certificates_exist")
399+
return os.path.exists(
400+
os.path.join(Config.get_config().certs_dir, Config.get_config().server_cert)
401+
) and os.path.exists(
402+
os.path.join(Config.get_config().certs_dir, Config.get_config().server_key)
403+
)
404+
405+
def ensure_certificates_exist(self) -> bool:
426406
"""Ensure SSL certificates exist, generate if they don't"""
427407
logger.debug("Ensuring certificates exist. fn ensure_certificates_exist")
428-
if not (
429-
os.path.exists(
430-
os.path.join(Config.get_config().certs_dir, Config.get_config().server_cert)
431-
)
432-
and os.path.exists(
433-
os.path.join(Config.get_config().certs_dir, Config.get_config().server_key)
434-
)
435-
):
436-
logger.debug("Certificates not found, generating new certificates")
408+
if not self.check_certificates_exist():
409+
logger.info("Certificates not found. Generating new certificates.")
437410
self.generate_certificates()
438-
else:
439-
server_cert = Config.get_config().server_cert
440-
server_key = Config.get_config().server_key
441-
logger.debug(f"Certificates found at: {server_cert} and {server_key}.")
442411

443412
def get_ssl_context(self) -> ssl.SSLContext:
444413
"""Get SSL context with certificates"""

src/codegate/cli.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,12 @@ def restore_backup(backup_path: Path, backup_name: str) -> None:
448448
default=None,
449449
help="Name that will be given to the created server-key.",
450450
)
451+
@click.option(
452+
"--force-certs",
453+
is_flag=True,
454+
default=False,
455+
help="Force the generation of certificates even if they already exist.",
456+
)
451457
@click.option(
452458
"--log-level",
453459
type=click.Choice([level.value for level in LogLevel]),
@@ -466,6 +472,7 @@ def generate_certs(
466472
ca_key_name: Optional[str],
467473
server_cert_name: Optional[str],
468474
server_key_name: Optional[str],
475+
force_certs: bool,
469476
log_level: Optional[str],
470477
log_format: Optional[str],
471478
) -> None:
@@ -476,12 +483,22 @@ def generate_certs(
476483
ca_key=ca_key_name,
477484
server_cert=server_cert_name,
478485
server_key=server_key_name,
486+
force_certs=force_certs,
479487
cli_log_level=log_level,
480488
cli_log_format=log_format,
481489
)
482490
setup_logging(cfg.log_level, cfg.log_format)
491+
483492
ca = CertificateAuthority.get_instance()
484-
ca.generate_certificates()
493+
should_generate = force_certs or not ca.check_certificates_exist()
494+
495+
if should_generate:
496+
ca.generate_certificates()
497+
click.echo("Certificates generated successfully.")
498+
click.echo(f"Certificates saved to: {cfg.certs_dir}")
499+
click.echo("Make sure to add the new CA certificate to the operating system trust store.")
500+
else:
501+
click.echo("Certificates already exist. Skipping generation...")
485502

486503

487504
def main() -> None:

src/codegate/config.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class Config:
5050
ca_key: str = "ca.key"
5151
server_cert: str = "server.crt"
5252
server_key: str = "server.key"
53+
force_certs: bool = False
5354

5455
# Provider URLs with defaults
5556
provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy())
@@ -142,6 +143,7 @@ def from_file(cls, config_path: Union[str, Path]) -> "Config":
142143
ca_key=config_data.get("ca_key", cls.ca_key),
143144
server_cert=config_data.get("server_cert", cls.server_cert),
144145
server_key=config_data.get("server_key", cls.server_key),
146+
force_certs=config_data.get("force_certs", cls.force_certs),
145147
prompts=prompts_config,
146148
provider_urls=provider_urls,
147149
)
@@ -187,6 +189,8 @@ def from_env(cls) -> "Config":
187189
config.server_cert = os.environ["CODEGATE_SERVER_CERT"]
188190
if "CODEGATE_SERVER_KEY" in os.environ:
189191
config.server_key = os.environ["CODEGATE_SERVER_KEY"]
192+
if "CODEGATE_FORCE_CERTS" in os.environ:
193+
config.force_certs = os.environ["CODEGATE_FORCE_CERTS"]
190194

191195
# Load provider URLs from environment variables
192196
for provider in DEFAULT_PROVIDER_URLS.keys():
@@ -216,6 +220,7 @@ def load(
216220
ca_key: Optional[str] = None,
217221
server_cert: Optional[str] = None,
218222
server_key: Optional[str] = None,
223+
force_certs: Optional[bool] = None,
219224
db_path: Optional[str] = None,
220225
) -> "Config":
221226
"""Load configuration with priority resolution.
@@ -242,6 +247,7 @@ def load(
242247
ca_key: Optional path to CA key
243248
server_cert: Optional path to server certificate
244249
server_key: Optional path to server key
250+
force_certs: Optional flag to force certificate generation
245251
db_path: Optional path to the SQLite database file
246252
247253
Returns:
@@ -289,6 +295,8 @@ def load(
289295
config.server_cert = env_config.server_cert
290296
if "CODEGATE_SERVER_KEY" in os.environ:
291297
config.server_key = env_config.server_key
298+
if "CODEGATE_FORCE_CERTS" in os.environ:
299+
config.force_certs = env_config.force_certs
292300

293301
# Override provider URLs from environment
294302
for provider, url in env_config.provider_urls.items():
@@ -325,16 +333,8 @@ def load(
325333
config.server_key = server_key
326334
if db_path is not None:
327335
config.db_path = db_path
328-
if certs_dir is not None:
329-
config.certs_dir = certs_dir
330-
if ca_cert is not None:
331-
config.ca_cert = ca_cert
332-
if ca_key is not None:
333-
config.ca_key = ca_key
334-
if server_cert is not None:
335-
config.server_cert = server_cert
336-
if server_key is not None:
337-
config.server_key = server_key
336+
if force_certs is not None:
337+
config.force_certs = force_certs
338338

339339
# Set the __config class attribute
340340
Config.__config = config

src/codegate/pipeline/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def get_last_user_message_idx(request: ChatCompletionRequest) -> int:
187187
if request.get("messages") is None:
188188
return -1
189189

190-
for idx, message in reversed(list(enumerate(request['messages']))):
190+
for idx, message in reversed(list(enumerate(request["messages"]))):
191191
if message.get("role", "") == "user":
192192
return idx
193193

src/codegate/pipeline/codegate_context_retriever/codegate.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,7 @@ async def process(
110110
return PipelineResult(request=request)
111111

112112
# Look for matches in vector DB using list of packages as filter
113-
searched_objects = await self.get_objects_from_search(
114-
user_messages, ecosystem, packages
115-
)
113+
searched_objects = await self.get_objects_from_search(user_messages, ecosystem, packages)
116114

117115
logger.info(
118116
f"Found {len(searched_objects)} matches in the database",
@@ -149,4 +147,3 @@ async def process(
149147
message["content"] = context_msg
150148

151149
return PipelineResult(request=new_request, context=context)
152-

0 commit comments

Comments
 (0)