-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathtextgrad_optimizer.py
More file actions
572 lines (451 loc) · 22.8 KB
/
textgrad_optimizer.py
File metadata and controls
572 lines (451 loc) · 22.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
#!/usr/bin/env python3
"""
TextGrad optimizer for context compression prompts.
This script loads observation data, extracts context and query pairs,
and uses TextGrad to optimize the contextual compression prompt.
Similar to dspy_gepa_optimizer.py but uses TextGrad for text-based optimization.
"""
import json
import os
import re
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import logging
import warnings
import numpy as np
from datetime import datetime
import time
import textgrad as tg
import weave
import wandb
from dotenv import load_dotenv
from tqdm import tqdm
# Suppress wandb warnings (wandb is chatty; UserWarnings add noise to the optimization logs)
warnings.filterwarnings("ignore", category=UserWarning, module="wandb")
os.environ["WANDB_SILENT"] = "true"
# Load environment variables (OPENAI_API_KEY, WANDB_API_KEY, WANDB_PROJECT from .env)
load_dotenv()
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class DataLoader:
    """Load and parse observation data to extract context/query pairs.

    Reads observation JSON files from ``observations_dir`` and, when a file
    with a matching id exists in ``gpt4o_dir``, attaches that GPT-4o output
    as the training target.
    """

    # Cap on context length (chars); longer contexts are truncated for stability.
    MAX_CONTEXT_CHARS = 25000

    def __init__(self, observations_dir: str, gpt4o_dir: str):
        self.observations_dir = Path(observations_dir)
        self.gpt4o_dir = Path(gpt4o_dir)

    def extract_context_and_query(self, input_messages: List[Dict]) -> Tuple[Optional[str], Optional[str]]:
        """Extract context and query from input messages.

        Scans each message's "content" for <context>...</context> and
        <query>...</query> tags; the first match of each tag wins.

        Returns:
            (context, query) — either element may be None when its tag
            is absent from every message.
        """
        context = None
        query = None
        for message in input_messages:
            content = message.get("content", "")
            # Extract context from <context>...</context> tags (usually in system message)
            if not context:
                context_match = re.search(r'<context>(.*?)</context>', content, re.DOTALL)
                if context_match:
                    context = context_match.group(1).strip()
            # Extract query from <query>...</query> tags (usually in user message)
            if not query:
                query_match = re.search(r'<query>(.*?)</query>', content, re.DOTALL)
                if query_match:
                    query = query_match.group(1).strip()
        return context, query

    def load_observations(self) -> List[Dict]:
        """Load all observation files and extract relevant data.

        Returns:
            List of dicts with keys id/context/query/original_output plus
            has_target (and target_output when a GPT-4o reference exists).
            Files that fail to parse are skipped with a warning.
        """
        observations = []
        # Get all JSON files except those starting with "_"
        observation_files = [f for f in self.observations_dir.glob("*.json") if not f.name.startswith("_")]
        logger.info(f"Loading {len(observation_files)} observation files...")
        for obs_file in tqdm(observation_files, desc="Loading observations"):
            try:
                with open(obs_file, 'r') as f:
                    data = json.load(f)
                # Extract context and query from input
                context, query = self.extract_context_and_query(data.get("input", []))
                # Truncate observations that are too large to process efficiently
                if context and len(context) > self.MAX_CONTEXT_CHARS:
                    # BUGFIX: capture the length *before* truncating so the debug
                    # message reports the original size, not the truncated one.
                    original_len = len(context)
                    context = context[:self.MAX_CONTEXT_CHARS] + "... [truncated]"
                    logger.debug(f"Truncated context for observation {data.get('id')} from {original_len} to 25k chars")
                if context and query:
                    obs_data = {
                        "id": data.get("id"),
                        "context": context,
                        "query": query,
                        "original_output": data.get("output", {}).get("content", ""),
                    }
                    # Load corresponding GPT-4o success if available
                    # (distinct handle name so the outer observation file
                    # handle is not shadowed)
                    gpt4o_file = self.gpt4o_dir / f"{data.get('id')}.json"
                    if gpt4o_file.exists():
                        with open(gpt4o_file, 'r') as gf:
                            gpt4o_data = json.load(gf)
                        obs_data["target_output"] = gpt4o_data.get("output", "")
                        obs_data["has_target"] = True
                    else:
                        obs_data["has_target"] = False
                    observations.append(obs_data)
            except Exception as e:
                logger.warning(f"Error processing {obs_file}: {e}")
                continue
        logger.info(f"Loaded {len(observations)} valid observations")
        logger.info(f"Observations with GPT-4o targets: {sum(1 for obs in observations if obs['has_target'])}")
        return observations
# Define the base prompt from README.md (same as DSPy optimizer)
# This is the starting value of the trainable system-prompt variable;
# TextGrad mutates it during optimization. The text itself is runtime data.
BASE_COMPRESSION_PROMPT = """You are tasked with performing a contextual compression of a document as part of a system that processes multiple documents. Your goal is to extract only the essential parts of the given context that are relevant to a specific query.
This process helps in focusing on the most important information and reducing noise in the context.
The query might refer to multiple documents, consider how does apply to a single document in the context as multiple documents might be relevant.
Your task is to extract any parts of the context that are directly relevant to answering this question. Follow these guidelines:
1. Only extract text *AS IS* that is directly related to the query.
2. Do not modify, paraphrase, or summarize the extracted text. Copy it exactly as it appears in the context.
3. You may extract multiple separate parts if necessary.
4. If a header relates to the query, extract also the text under that section.
5. Preserve headings and subheadings when extracting.
6. If you find no relevant information in the context, output "NO_OUTPUT"."""
class ContextualCompressionModel:
    """TextGrad model wrapper for the contextual compression task.

    Holds the trainable system-prompt variable and exposes a callable
    forward pass plus a ``parameters()`` accessor for the TGD optimizer.
    """

    def __init__(self, system_prompt: tg.Variable, engine):
        # Keep direct references so the optimizer can mutate the prompt in place.
        self.system_prompt = system_prompt
        self.llm_engine = engine
        # BlackboxLLM bundles the engine with the (trainable) system prompt.
        self.model = tg.BlackboxLLM(engine=engine, system_prompt=system_prompt)

    def __call__(self, user_message: tg.Variable) -> tg.Variable:
        """Run one forward pass through the LLM with the current system prompt."""
        return self.model(user_message)

    def parameters(self):
        """Return the trainable variables — just the system prompt."""
        return [self.system_prompt]
def create_contextual_compression_loss_fn(query: str, expected_output: str) -> tg.TextLoss:
    """Build a TextGrad TextLoss that critiques a compression output.

    The "loss" is an LLM-evaluation prompt: it embeds the query and the
    expected (reference) output and asks the critic engine for specific,
    actionable textual feedback used as the gradient signal.
    """
    # The evaluation prompt is passed straight to TextLoss; the critic
    # produces textual feedback rather than a numeric value.
    return tg.TextLoss(
        f"""Evaluate the quality of this contextual compression output.
Original Query: {query}
Expected Output: {expected_output}
Evaluation Criteria:
1. Relevance: Does the extracted content directly relate to the query?
2. Completeness: Are all relevant parts extracted without missing important information?
3. Exactness: Is the text copied exactly without modification or paraphrasing?
4. NO_OUTPUT handling: Is NO_OUTPUT used appropriately when no relevant info exists?
5. Format preservation: Are headings and structure preserved correctly?
Provide specific, actionable feedback on how to improve the system prompt for better contextual compression.
Focus on what instructions would help the model extract more relevant content exactly as it appears.
Be constructive and specific about what changes would improve performance."""
    )
def evaluate_compression_simple(pred_output: str, target_output: str, has_target: bool) -> float:
    """Score a compression prediction against its target (DSPy-comparable proxy).

    Scoring rules:
      * no target available                -> 0.0
      * both sides "NO_OUTPUT"             -> 1.0
      * exactly one side "NO_OUTPUT"       -> 0.0
      * both have content                  -> length-ratio similarity, capped at 0.8
        (empty target content also yields 0.0)
    """
    # Without a reference output there is nothing to score against.
    if not has_target:
        return 0.0

    pred_is_empty = pred_output == "NO_OUTPUT"
    target_is_empty = target_output == "NO_OUTPUT"

    if pred_is_empty and target_is_empty:
        # Correctly recognized that nothing relevant exists.
        return 1.0
    if pred_is_empty or target_is_empty:
        # One side claims "nothing relevant" while the other has content.
        return 0.0

    # Both sides carry content: use a simple length ratio as a similarity proxy.
    target_len = len(target_output.strip())
    pred_len = len(pred_output.strip())
    if target_len == 0:
        return 0.0
    length_ratio = min(pred_len, target_len) / max(pred_len, target_len)
    return length_ratio * 0.8  # Max score 0.8 for having content
class TextGradOptimizer:
    """Main optimizer class using TextGrad.

    Wires together the TextGrad engines (target + critic), optional W&B
    Weave tracking, example preparation, and the TGD optimization loop
    that rewrites the compression system prompt.
    """

    def __init__(self, observations: List[Dict]):
        # observations: output of DataLoader.load_observations()
        self.observations = observations
        self.setup_textgrad()
        self.setup_weave()

    def setup_textgrad(self):
        """Configure TextGrad with engines.

        Raises:
            ValueError: when OPENAI_API_KEY is missing from the environment.
        """
        openai_api_key = os.getenv("OPENAI_API_KEY")
        if not openai_api_key:
            raise ValueError("OPENAI_API_KEY not found in environment variables")
        # Use GPT-4o-mini as the target model we want to optimize for
        self.target_engine = tg.get_engine('gpt-4o-mini')
        # Use GPT-4o as the critic for gradients (like DSPy example)
        self.critic_engine = tg.get_engine('gpt-4o')
        tg.set_backward_engine(self.critic_engine)
        logger.info("TextGrad configured with GPT-4o-mini target and GPT-4o critic")

    def setup_weave(self):
        """Initialize W&B Weave for experiment tracking.

        Returns True when Weave was initialized, False when skipped or
        failed; callers ignore the return value, so failure is non-fatal.
        """
        wandb_project = os.getenv("WANDB_PROJECT", "context-compression-experiments")
        wandb_api_key = os.getenv("WANDB_API_KEY")
        try:
            if wandb_api_key and not wandb_api_key.startswith("#"):  # Skip if commented out
                wandb.login(key=wandb_api_key, relogin=True)
                weave.init(wandb_project)
                logger.info(f"Weave initialized for project: {wandb_project}")
                return True
            else:
                logger.info("WANDB_API_KEY not found or commented out, skipping Weave initialization")
                return False
        except Exception as e:
            logger.warning(f"Failed to initialize Weave: {e}. Continuing without tracking.")
            return False

    def prepare_examples(self, max_examples: Optional[int] = None) -> List[Dict]:
        """Convert observations to training examples.

        Observations that have a GPT-4o target come first; if max_examples
        allows, the remainder is filled with target-less observations.
        """
        examples = []
        # Prioritize examples with targets for training
        observations_with_targets = [obs for obs in self.observations if obs.get("has_target", False)]
        observations_without_targets = [obs for obs in self.observations if not obs.get("has_target", False)]
        # Use observations with targets first, then fill with others
        selected_obs = observations_with_targets
        if max_examples and len(selected_obs) < max_examples:
            remaining = max_examples - len(selected_obs)
            selected_obs.extend(observations_without_targets[:remaining])
        elif max_examples:
            selected_obs = selected_obs[:max_examples]
        for obs in selected_obs:
            # Create user message template like DSPy optimizer
            # (tag layout must match what extract_context_and_query expects)
            user_message = f"""Here is the context document:
<context>
{obs["context"]}
</context>
Now, consider the following query:
<query>
{obs["query"]}
</query>
Now, proceed with the task using the provided context and query."""
            example = {
                "user_message": user_message,
                "context": obs["context"],
                "query": obs["query"],
                "has_target": obs["has_target"],
                # target_output defaults to "" for target-less observations
                "target_output": obs.get("target_output", ""),
                "original_output": obs["original_output"]
            }
            examples.append(example)
        logger.info(f"Prepared {len(examples)} examples for optimization")
        return examples

    def evaluate_model_on_examples(self, model: ContextualCompressionModel, examples: List[Dict],
                                   max_examples: Optional[int] = None) -> List[float]:
        """Evaluate model performance on a set of examples.

        Runs a forward pass per example and scores it with
        evaluate_compression_simple; failed examples score 0.0.
        """
        if max_examples is None:
            max_examples = len(examples)
        scores = []
        logger.info(f"Evaluating model on {min(max_examples, len(examples))} examples...")
        for i, example in enumerate(tqdm(examples[:max_examples], desc="Evaluating")):
            try:
                # Create TextGrad variable (frozen — only the system prompt trains)
                user_msg = tg.Variable(
                    example['user_message'],
                    requires_grad=False,
                    role_description="user input for contextual compression"
                )
                # Get model prediction
                prediction = model(user_msg)
                # Simple scoring using the same logic as DSPy optimizer
                pred_output = prediction.value.strip()
                target_output = example.get('target_output', '').strip()
                has_target = example.get('has_target', False)
                score = evaluate_compression_simple(pred_output, target_output, has_target)
                scores.append(score)
            except Exception as e:
                logger.warning(f"Error evaluating example {i+1}: {e}")
                scores.append(0.0)
        return scores

    @weave.op()
    def run_optimization(
        self,
        max_examples: int = 100,
        num_iterations: int = 8,
        batch_size: int = 5,
        train_split: float = 0.8
    ) -> Dict:
        """Run TextGrad optimization.

        Performs num_iterations rounds of: sample a batch, compute textual
        gradients via the critic, step the TGD optimizer, then score on a
        small validation slice. Returns a dict with the model, result
        metrics/history, and the prepared examples.
        """
        # Prepare examples
        examples = self.prepare_examples(max_examples)
        # Split into train/validation
        split_idx = int(len(examples) * train_split)
        train_examples = examples[:split_idx]
        val_examples = examples[split_idx:]
        logger.info(f"Training on {len(train_examples)} examples, validating on {len(val_examples)}")
        # Initialize TextGrad system prompt variable (this is what we'll optimize)
        logger.info("Initializing TextGrad system prompt...")
        system_prompt = tg.Variable(
            BASE_COMPRESSION_PROMPT,
            requires_grad=True,
            role_description="system prompt for contextual compression that guides the LLM to extract relevant text exactly as it appears in the context"
        )
        # Initialize model
        model = ContextualCompressionModel(system_prompt, self.target_engine)
        # Initialize optimizer (TextGrad's TGD - Textual Gradient Descent)
        optimizer = tg.TGD(parameters=[system_prompt])
        # Test initial model performance (first 10 validation examples only, to bound API cost)
        logger.info("Testing initial model performance...")
        initial_scores = self.evaluate_model_on_examples(model, val_examples[:10], max_examples=10)
        initial_avg = np.mean(initial_scores) if initial_scores else 0.0
        logger.info(f"Initial average score: {initial_avg:.3f}")
        # TextGrad optimization loop
        logger.info("Running TextGrad optimization...")
        logger.info("Using textual gradients to iteratively improve the system prompt...")
        best_score = initial_avg
        best_prompt = system_prompt.value
        results = {
            "iteration": [],
            "score": [],
            "prompt": [],
            "best_score": initial_avg
        }
        for iteration in range(num_iterations):
            logger.info(f"Iteration {iteration + 1}/{num_iterations}")
            # Sample training examples for this iteration (indices, without replacement)
            train_subset = np.random.choice(
                len(train_examples),
                min(batch_size, len(train_examples)),
                replace=False
            )
            iteration_losses = []
            # Process batch of examples
            for idx in train_subset:
                example = train_examples[idx]
                try:
                    # Clear gradients
                    optimizer.zero_grad()
                    # Create variables
                    user_msg = tg.Variable(
                        example['user_message'],
                        requires_grad=False,
                        role_description="user input for contextual compression"
                    )
                    # Forward pass
                    prediction = model(user_msg)
                    # Create loss function for this example
                    loss_fn = create_contextual_compression_loss_fn(
                        example['query'],
                        example['target_output']
                    )
                    # Calculate loss
                    loss = loss_fn(prediction)
                    iteration_losses.append(loss)
                    # Backward pass to compute textual gradients
                    loss.backward()
                except Exception as e:
                    logger.warning(f"Error in iteration {iteration + 1}, example {idx}: {e}")
                    continue
            if iteration_losses:
                # Apply optimization step (this will update the system prompt)
                optimizer.step()
                logger.info(f"Processed {len(iteration_losses)} examples in this iteration")
            # Evaluate on validation set
            val_scores = self.evaluate_model_on_examples(model, val_examples[:10], max_examples=10)
            current_score = np.mean(val_scores) if val_scores else 0.0
            logger.info(f"Current validation score: {current_score:.3f}")
            # Track results
            results["iteration"].append(iteration + 1)
            results["score"].append(current_score)
            results["prompt"].append(system_prompt.value)
            # Update best if improved (keep a snapshot of the prompt at its best)
            if current_score > best_score:
                best_score = current_score
                best_prompt = system_prompt.value
                logger.info(f"✅ New best score: {best_score:.3f}")
                results["best_score"] = best_score
            else:
                logger.info(f"Score: {current_score:.3f} (best: {best_score:.3f})")
            # Add small delay to avoid rate limits
            time.sleep(2)
        # Final evaluation over the full validation split
        final_scores = self.evaluate_model_on_examples(model, val_examples, max_examples=len(val_examples))
        final_avg = np.mean(final_scores) if final_scores else 0.0
        optimization_results = {
            "initial_avg": initial_avg,
            "final_avg": final_avg,
            "best_score": best_score,
            "best_prompt": best_prompt,
            "train_examples": len(train_examples),
            "val_examples": len(val_examples),
            "num_iterations": num_iterations,
            "batch_size": batch_size,
            "optimization_history": results
        }
        logger.info(f"Final validation accuracy: {final_avg:.3f}")
        logger.info(f"Best score achieved: {best_score:.3f}")
        return {
            "model": model,
            "results": optimization_results,
            "examples": examples
        }
def main():
    """Main execution function.

    Loads observation data, runs TextGrad optimization, and persists the
    optimized prompt plus experiment metadata/summary under data/results/.

    Raises:
        FileNotFoundError: when a required data directory is missing.
        ValueError: when no valid observations could be loaded.
    """
    # Setup paths (script is assumed to live one level below the project root)
    project_root = Path(__file__).parent.parent
    observations_dir = project_root / "data" / "observations"
    gpt4o_dir = project_root / "data" / "gpt-4o"
    # Verify data directories exist
    if not observations_dir.exists():
        raise FileNotFoundError(f"Observations directory not found: {observations_dir}")
    if not gpt4o_dir.exists():
        raise FileNotFoundError(f"GPT-4o directory not found: {gpt4o_dir}")
    # Load data
    logger.info("Loading observation data...")
    loader = DataLoader(str(observations_dir), str(gpt4o_dir))
    observations = loader.load_observations()
    if not observations:
        raise ValueError("No valid observations found")
    # Run optimization
    optimizer = TextGradOptimizer(observations)
    # Start with a smaller subset for initial testing (similar to DSPy optimizer)
    optimization_results = optimizer.run_optimization(
        max_examples=50,  # Start small
        num_iterations=8,
        batch_size=5,
        train_split=0.8
    )
    # Save results with timestamp and specific naming
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = project_root / "data" / "results"
    results_dir.mkdir(parents=True, exist_ok=True)
    # Create experiment-specific subdirectory
    experiment_name = f"textgrad_context_compression_{timestamp}"
    experiment_dir = results_dir / experiment_name
    experiment_dir.mkdir(parents=True, exist_ok=True)
    # Save optimized prompt
    prompt_file = experiment_dir / "optimized_prompt.txt"
    with open(prompt_file, 'w') as f:
        f.write(optimization_results["results"]["best_prompt"])
    # Save detailed results with experiment parameters
    # NOTE: parameters here must be kept in sync with the run_optimization call above.
    results_with_metadata = {
        "experiment_name": experiment_name,
        "timestamp": timestamp,
        "script": "textgrad_optimizer.py",
        "optimization_type": "TextGrad TGD",
        "model_target": "gpt-4o-mini",
        "critic_model": "gpt-4o",
        "original_prompt": BASE_COMPRESSION_PROMPT,
        "optimized_prompt": optimization_results["results"]["best_prompt"],
        "parameters": {
            "max_examples": 50,
            "num_iterations": 8,
            "batch_size": 5,
            "train_split": 0.8
        },
        "results": optimization_results["results"]
    }
    with open(experiment_dir / "experiment_results.json", "w") as f:
        json.dump(results_with_metadata, f, indent=2)
    # Also save a summary in the main results directory
    with open(results_dir / f"{experiment_name}_summary.json", "w") as f:
        summary = {
            "experiment_name": experiment_name,
            "timestamp": timestamp,
            "initial_accuracy": optimization_results["results"]["initial_avg"],
            "final_accuracy": optimization_results["results"]["final_avg"],
            "best_accuracy": optimization_results["results"]["best_score"],
            "improvement": optimization_results["results"]["best_score"] - optimization_results["results"]["initial_avg"],
            "script": "textgrad_optimizer.py"
        }
        json.dump(summary, f, indent=2)
    logger.info(f"Optimization complete! Results saved to {experiment_dir}")
    logger.info(f"Experiment: {experiment_name}")
    logger.info(f"Initial accuracy: {optimization_results['results']['initial_avg']:.3f}")
    logger.info(f"Final accuracy: {optimization_results['results']['final_avg']:.3f}")
    logger.info(f"Best accuracy: {optimization_results['results']['best_score']:.3f}")
    logger.info(f"Improvement: {optimization_results['results']['best_score'] - optimization_results['results']['initial_avg']:.3f}")


if __name__ == "__main__":
    main()