Skip to content

Commit 3b8f003

Browse files
corbt and claude committed
feat: Add exclude parameter to S3 pull functionality
- Add a type-safe `exclude` parameter to filter out directories during S3 sync
- Valid exclude options: "checkpoints", "logs", "trajectories"
- Update the `pull_model_trajectories` helper to exclude checkpoints and logs by default
- Enables pulling only trajectories for analysis without downloading large model weights

This is useful when analyzing trajectory data locally without needing the model checkpoints, which can be very large and consume significant bandwidth.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 7ed50bf commit 3b8f003

4 files changed

Lines changed: 125 additions & 3 deletions

File tree

src/art/local/backend.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from transformers.models.auto.tokenization_auto import AutoTokenizer
2323
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
2424
from tqdm import auto as tqdm
25-
from typing import AsyncIterator, cast
25+
from typing import AsyncIterator, cast, Literal
2626
import wandb
2727
from wandb.sdk.wandb_run import Run
2828
import weave
@@ -443,6 +443,7 @@ async def _experimental_pull_from_s3(
443443
prefix: str | None = None,
444444
verbose: bool = False,
445445
delete: bool = False,
446+
exclude: list[Literal["checkpoints", "logs", "trajectories"]] | None = None,
446447
) -> None:
447448
"""Download the model directory from S3 into local Backend storage. Right now this can be used to pull trajectory logs for processing or model checkpoints.
448449
Args:
@@ -452,7 +453,19 @@ async def _experimental_pull_from_s3(
452453
prefix: The prefix to pull from S3. If None, the model name will be used.
453454
verbose: Whether to print verbose output.
454455
delete: Whether to delete the local model directory.
456+
exclude: List of directories to exclude from sync. Valid options: "checkpoints", "logs", "trajectories".
455457
"""
458+
# Validate exclude options
459+
validated_exclude = None
460+
if exclude:
461+
validated_exclude = []
462+
for item in exclude:
463+
if item not in ["checkpoints", "logs", "trajectories"]:
464+
raise ValueError(
465+
f"Invalid exclude option: {item}. Valid options are: checkpoints, logs, trajectories"
466+
)
467+
validated_exclude.append(item) # type: ignore
468+
456469
await pull_model_from_s3(
457470
model_name=model.name,
458471
project=model.project,
@@ -462,6 +475,7 @@ async def _experimental_pull_from_s3(
462475
verbose=verbose,
463476
delete=delete,
464477
art_path=self._path,
478+
exclude=validated_exclude, # type: ignore
465479
)
466480

467481
async def _experimental_push_to_s3(

src/art/utils/benchmarking/load_trajectories.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ async def pull_model_trajectories(model: ArtModel) -> None:
238238
model,
239239
s3_bucket=bucket,
240240
verbose=True,
241+
exclude=["checkpoints", "logs"],
241242
)
242243

243244
print("Finished pulling trajectories.", flush=True)

src/art/utils/s3.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import asyncio
55
from asyncio.subprocess import DEVNULL
66
import tempfile
7-
from typing import Optional, Sequence
7+
from typing import Optional, Sequence, Literal
88
import zipfile
99

1010
from art.errors import ForbiddenBucketCreationError
@@ -17,6 +17,8 @@
1717

1818
__all__: Sequence[str] = ("s3_sync",)
1919

20+
ExcludableOption = Literal["checkpoints", "logs", "trajectories"]
21+
2022

2123
class S3SyncError(RuntimeError):
2224
"""Raised when the underlying *aws s3 sync* command exits with a non‑zero status."""
@@ -67,6 +69,7 @@ async def s3_sync(
6769
profile: Optional[str] = None,
6870
verbose: bool = False,
6971
delete: bool = False,
72+
exclude: list[ExcludableOption] | None = None,
7073
) -> None:
7174
"""Synchronise *source* and *destination* using the AWS CLI.
7275
@@ -82,6 +85,7 @@ async def s3_sync(
8285
profile: Optional AWS profile name to pass to the CLI.
8386
verbose: When *True*, the output of the AWS CLI is streamed to the
8487
calling process; otherwise it is suppressed.
88+
exclude: List of directories to exclude from sync.
8589
8690
Raises:
8791
S3SyncError: If the *aws s3 sync* command exits with a non‑zero status.
@@ -100,6 +104,12 @@ async def s3_sync(
100104

101105
if delete:
102106
cmd.append("--delete")
107+
108+
# Add exclude patterns for each excluded directory
109+
if exclude:
110+
for excluded_dir in exclude:
111+
cmd.extend(["--exclude", f"{excluded_dir}/*"])
112+
103113
cmd += [source, destination]
104114

105115
# Suppress output unless verbose mode is requested.
@@ -156,6 +166,7 @@ async def pull_model_from_s3(
156166
verbose: bool = False,
157167
delete: bool = False,
158168
art_path: str | None = None,
169+
exclude: list[ExcludableOption] | None = None,
159170
) -> str:
160171
"""Pull a model from S3 to the local directory.
161172
@@ -169,6 +180,7 @@ async def pull_model_from_s3(
169180
calling process; otherwise it is suppressed.
170181
delete: When *True*, delete the local model directory if it exists.
171182
art_path: The path to the ART directory.
183+
exclude: List of directories to exclude from sync.
172184
173185
Returns:
174186
The local directory path.
@@ -196,7 +208,7 @@ async def pull_model_from_s3(
196208
prefix=prefix,
197209
)
198210
await ensure_bucket_exists(s3_bucket)
199-
await s3_sync(s3_path, local_dir, verbose=verbose, delete=delete)
211+
await s3_sync(s3_path, local_dir, verbose=verbose, delete=delete, exclude=exclude)
200212

201213
# After pulling, migrate to new structure if needed
202214
if step is not None:

test_s3_exclude.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
#!/usr/bin/env python3
2+
"""Test script to verify S3 exclude functionality."""
3+
4+
import asyncio
5+
import os
6+
from dotenv import load_dotenv
7+
from art.model import Model
8+
from art.local import LocalBackend
9+
from art.utils.benchmarking.load_trajectories import pull_model_trajectories
10+
11+
# Load environment variables
12+
load_dotenv()
13+
14+
async def test_s3_exclude():
    """Manually exercise the S3 ``exclude`` option and report results on stdout.

    Three stages run in order, each announced by a printed banner:

    1. Call ``pull_model_trajectories`` for a fixed model.
    2. Call ``LocalBackend._experimental_pull_from_s3`` directly with
       ``exclude=["checkpoints", "logs"]``, using the bucket named by the
       ``BACKUP_BUCKET`` environment variable. If that variable is unset or
       empty the function prints a message and returns without running
       stage 3.
    3. Inspect ``.art/<project>/models/<name>`` in the current directory and
       print whether ``trajectories`` exists and whether ``checkpoints`` and
       ``logs`` are absent or empty, followed by a listing of the directory.

    Exceptions raised by either pull are caught and printed rather than
    propagated, so later stages still run. Nothing is returned.
    """
    divider = "-" * 60

    # Fixed model identity shared by every stage.
    model = Model(name="email-agent-216-4", project="email_agent")

    # --- Stage 1: helper function ---------------------------------------
    print("Test 1: Using pull_model_trajectories helper (excludes checkpoints and logs)")
    print(divider)

    try:
        await pull_model_trajectories(model)
        print("✓ Helper function completed successfully")
    except Exception as e:
        print(f"✗ Helper function failed: {e}")

    # --- Stage 2: direct backend call -----------------------------------
    print("\nTest 2: Direct backend call with exclude=['checkpoints', 'logs']")
    print(divider)

    with LocalBackend() as backend:
        bucket = os.environ.get("BACKUP_BUCKET")
        if not bucket:
            print("✗ BACKUP_BUCKET environment variable not set")
            return

        try:
            await backend._experimental_pull_from_s3(
                model,
                s3_bucket=bucket,
                verbose=True,
                exclude=["checkpoints", "logs"],
            )
            print("✓ Direct backend call completed successfully")
        except Exception as e:
            print(f"✗ Direct backend call failed: {e}")

    # --- Stage 3: inspect what landed on disk ---------------------------
    # NOTE(review): nesting was reconstructed from syntax; this stage is
    # assumed to run after the backend context has exited - confirm.
    print("\nTest 3: Verifying pulled structure")
    print(divider)

    # Only the .art folder under the current working directory is checked.
    art_path = ".art"
    model_path = f"{art_path}/{model.project}/models/{model.name}"

    if not os.path.exists(model_path):
        print(f"✗ Model directory not found: {model_path}")
        return

    print(f"Model directory exists: {model_path}")

    trajectories_path, checkpoints_path, logs_path = (
        os.path.join(model_path, subdir)
        for subdir in ("trajectories", "checkpoints", "logs")
    )
    trajectories_exists, checkpoints_exists, logs_exists = (
        os.path.exists(candidate)
        for candidate in (trajectories_path, checkpoints_path, logs_path)
    )

    print(f"{'✓' if trajectories_exists else '✗'} Trajectories directory exists: {trajectories_exists}")

    # An excluded directory passes when it is either missing or has no entries.
    checkpoints_empty = not (checkpoints_exists and os.listdir(checkpoints_path))
    logs_empty = not (logs_exists and os.listdir(logs_path))

    print(f"{'✓' if checkpoints_empty else '✗'} Checkpoints directory empty: {checkpoints_empty}")
    print(f"{'✓' if logs_empty else '✗'} Logs directory empty: {logs_empty}")

    # Listing of the model directory; sub-directories show their entry count.
    print("\nDirectory contents:")
    for item in os.listdir(model_path):
        item_path = os.path.join(model_path, item)
        if not os.path.isdir(item_path):
            print(f" - {item}")
            continue
        print(f" - {item}/ ({len(os.listdir(item_path))} items)")
93+
94+
if __name__ == "__main__":
95+
asyncio.run(test_s3_exclude())

0 commit comments

Comments
 (0)