diff --git a/.github/workflows/import_packages.yml b/.github/workflows/import_packages.yml index 07884eb2..eac7e0ad 100644 --- a/.github/workflows/import_packages.yml +++ b/.github/workflows/import_packages.yml @@ -70,9 +70,9 @@ jobs: export BACKUP_FOLDER=backup # Conditionally export the variables only if artifact download is enabled if [ "${{ github.event.inputs.enable_artifact_download }}" == "true" ]; then - python scripts/import_packages.py + python scripts/import_packages.py --jsonl-dir /tmp/jsonl-files/ else - python scripts/import_packages.py --restore-backup False + python scripts/import_packages.py --restore-backup False --jsonl-dir /tmp/jsonl-files/ fi - name: 'Upload Backup Files' diff --git a/scripts/import_packages.py b/scripts/import_packages.py index 15546232..3575eb1e 100644 --- a/scripts/import_packages.py +++ b/scripts/import_packages.py @@ -14,7 +14,7 @@ class PackageImporter: - def __init__(self, take_backup=True, restore_backup=True): + def __init__(self, jsonl_dir='data', take_backup=True, restore_backup=True): self.take_backup_flag = take_backup self.restore_backup_flag = restore_backup @@ -29,9 +29,9 @@ def __init__(self, take_backup=True, restore_backup=True): ) ) self.json_files = [ - "data/archived.jsonl", - "data/deprecated.jsonl", - "data/malicious.jsonl", + os.path.join(jsonl_dir, "archived.jsonl"), + os.path.join(jsonl_dir, "deprecated.jsonl"), + os.path.join(jsonl_dir, "malicious.jsonl"), ] self.client.connect() self.inference_engine = LlamaCppInferenceEngine() @@ -149,9 +149,16 @@ async def run_import(self): help="Specify whether to restore a backup before " "data import (True or False). Default is True.", ) + parser.add_argument( + "--jsonl-dir", + type=str, + default="data", + help="Directory containing JSONL files. Default is 'data'." + ) args = parser.parse_args() - importer = PackageImporter(take_backup=args.take_backup, restore_backup=args.restore_backup) + importer = PackageImporter(jsonl_dir=args.jsonl_dir, take_backup=args.take_backup, + restore_backup=args.restore_backup) asyncio.run(importer.run_import()) try: assert importer.client.is_live()