Add LlamaSafetyOptimizer for Runtime Safety Checks and Performance Optimization #1326

36 changes: 36 additions & 0 deletions examples/safety_example.py
@@ -0,0 +1,36 @@
import torch
from safety.wrapper import LlamaSafetyOptimizer
from llama import Transformer, ModelArgs

def main():
    # Initialize model
    params = ModelArgs(
        dim=512,  # Smaller for testing
        n_layers=8,
        n_heads=8,
        vocab_size=1000
    )
    model = Transformer(params)

    # Initialize safety wrapper
    safe_model = LlamaSafetyOptimizer(model)

    # Test input
    input_ids = torch.randint(0, 1000, (1, 512))

    # Run with safety checks
    result = safe_model.safe_forward(input_ids, start_pos=0)

    # Print results
    print("\nSafety Check Results:")
    print(f"Is Safe: {result['is_safe']}")
    print("\nSafety Metrics:")
    for metric, value in result['safety_metrics'].items():
        print(f"{metric}: {value}")

    print("\nPerformance Metrics:")
    for metric, value in result['performance'].items():
        print(f"{metric}: {value}")

if __name__ == "__main__":
    main()
25 changes: 25 additions & 0 deletions safety/README.md
@@ -0,0 +1,25 @@
# Llama Safety Optimizer

This module provides safety and optimization tools for Llama models.

## Features

- Runtime safety checks for model outputs
- Memory usage tracking and optimization
- Performance monitoring
- Automatic batch size optimization

## Usage

```python
from safety.wrapper import LlamaSafetyOptimizer

# Initialize with your model
optimizer = LlamaSafetyOptimizer(model)

# Use safe forward pass
result = optimizer.safe_forward(input_ids, start_pos=0)

# Get performance metrics
metrics = optimizer.get_performance_summary()
```
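
The batch-size probe can also be run on its own, and `safe_forward` withholds the output tensor whenever a safety check fails. A minimal sketch of both (the `max_size` value here is illustrative):

```python
# Probe for the largest batch size that fits in memory (max_size is illustrative)
optimal_batch = optimizer.optimize_batch_size(start_size=1, max_size=16)

# The output is withheld (set to None) when a safety check fails
result = optimizer.safe_forward(input_ids, start_pos=0)
if result['is_safe']:
    logits = result['output']
else:
    print("Output withheld:", result['safety_metrics'])
```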
Empty file added safety/__init__.py
Empty file.
164 changes: 164 additions & 0 deletions safety/wrapper.py
@@ -0,0 +1,164 @@
# safety/wrapper.py
import gc
import time
from dataclasses import asdict, dataclass
from typing import Dict, Tuple

import psutil
import torch

@dataclass
class PerformanceMetrics:
    inference_time: float
    memory_used: int
    peak_memory: int
    gpu_utilization: float  # allocated GPU memory in MB, not a utilization percentage

class LlamaSafetyOptimizer:
    def __init__(
        self,
        model: torch.nn.Module,
        safety_threshold: float = 0.8,
        enable_memory_tracking: bool = True
    ):
        self.model = model
        # Reserved for stricter checks; the default heuristics below use fixed bounds
        self.safety_threshold = safety_threshold
        self.enable_memory_tracking = enable_memory_tracking
        self.performance_history = []

    def _track_memory(self) -> Dict[str, int]:
        """Track current memory usage"""
        if not self.enable_memory_tracking:
            return {}

        memory_stats = {
            'cpu_percent': psutil.cpu_percent(),
            'ram_used': psutil.Process().memory_info().rss // 1024 // 1024
        }

        if torch.cuda.is_available():
            memory_stats.update({
                'gpu_used': torch.cuda.memory_allocated() // 1024 // 1024,
                'gpu_cached': torch.cuda.memory_reserved() // 1024 // 1024
            })

        return memory_stats

    def _check_safety(self, logits: torch.Tensor) -> Tuple[bool, Dict]:
        """Perform safety checks on model outputs"""
        with torch.no_grad():
            # Example safety checks - expand based on your needs
            max_value = torch.max(logits).item()
            mean_value = torch.mean(logits).item()
            std_value = torch.std(logits).item()

            safety_metrics = {
                'max_activation': max_value,
                'mean_activation': mean_value,
                'std_activation': std_value,
                'outlier_ratio': torch.sum(torch.abs(logits) > 5).item() / logits.numel()
            }

            # Simple safety check - can be made more sophisticated
            is_safe = (
                safety_metrics['outlier_ratio'] < 0.1 and
                abs(safety_metrics['mean_activation']) < 2
            )

        return is_safe, safety_metrics

    def optimize_batch_size(self, start_size: int = 1, max_size: int = 32) -> int:
        """Find optimal batch size based on memory constraints"""
        # nn.Module has no .device attribute; infer the device from the parameters
        device = next(self.model.parameters()).device
        current_size = start_size

        while current_size < max_size:
            try:
                # Create dummy batch on the model's device
                dummy_input = torch.randint(
                    0, 1000, (current_size, 512), device=device
                )

                # Test forward pass
                with torch.no_grad():
                    _ = self.model(dummy_input, start_pos=0)

                # If successful, try larger batch
                current_size *= 2
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                gc.collect()

            except RuntimeError:
                # Memory error - return last successful size (at least 1,
                # so callers can still step through the batch)
                return max(current_size // 2, 1)

        return max_size

    def safe_forward(
        self,
        input_ids: torch.Tensor,
        start_pos: int,
        optimize_memory: bool = True
    ) -> Dict:
        """Forward pass with safety checks and performance monitoring"""
        start_time = time.time()
        initial_memory = self._track_memory()

        # Optimize batch size if requested
        if optimize_memory:
            batch_size = input_ids.shape[0]
            optimal_batch_size = self.optimize_batch_size(max_size=batch_size)

            if optimal_batch_size < batch_size:
                # Split along the batch dimension; start_pos is a sequence
                # position, so it stays the same for every sub-batch
                outputs = []
                for i in range(0, batch_size, optimal_batch_size):
                    batch = input_ids[i:i + optimal_batch_size]
                    output = self.model(batch, start_pos)
                    outputs.append(output)
                output = torch.cat(outputs, dim=0)
            else:
                output = self.model(input_ids, start_pos)
        else:
            output = self.model(input_ids, start_pos)

        # Perform safety checks
        is_safe, safety_metrics = self._check_safety(output)

        # Track performance metrics
        end_time = time.time()
        final_memory = self._track_memory()

        performance = PerformanceMetrics(
            inference_time=end_time - start_time,
            memory_used=final_memory.get('ram_used', 0),
            peak_memory=max(
                initial_memory.get('ram_used', 0),
                final_memory.get('ram_used', 0)
            ),
            gpu_utilization=final_memory.get('gpu_used', 0)
        )

        self.performance_history.append(performance)

        return {
            'output': output if is_safe else None,
            'is_safe': is_safe,
            'safety_metrics': safety_metrics,
            'performance': asdict(performance),
            'memory_tracking': final_memory
        }

    def get_performance_summary(self) -> Dict:
        """Get summary statistics of model performance"""
        if not self.performance_history:
            return {}

        n = len(self.performance_history)
        avg_inference_time = sum(p.inference_time for p in self.performance_history) / n
        avg_memory_used = sum(p.memory_used for p in self.performance_history) / n
        peak_memory = max(p.peak_memory for p in self.performance_history)

        return {
            'average_inference_time': avg_inference_time,
            'average_memory_used': avg_memory_used,
            'peak_memory_usage': peak_memory,
            'total_inferences': n
        }
Empty file added tests/__init__.py
Empty file.
Empty file added tests/safety/__init__.py
Empty file.
37 changes: 37 additions & 0 deletions tests/safety/test_wrapper.py
@@ -0,0 +1,37 @@
import torch
from safety.wrapper import LlamaSafetyOptimizer

class DummyModel(torch.nn.Module):
    """Minimal stand-in that matches the (input_ids, start_pos) call signature."""
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x, start_pos=0):
        return self.linear(x)

def test_safety_optimizer_initialization():
    optimizer = LlamaSafetyOptimizer(DummyModel())
    assert optimizer.safety_threshold == 0.8
    assert optimizer.enable_memory_tracking is True

def test_memory_tracking():
    optimizer = LlamaSafetyOptimizer(DummyModel())
    memory_stats = optimizer._track_memory()
    assert 'cpu_percent' in memory_stats
    assert 'ram_used' in memory_stats

def test_safety_checks():
    optimizer = LlamaSafetyOptimizer(DummyModel())
    test_tensor = torch.randn(1, 10)
    is_safe, metrics = optimizer._check_safety(test_tensor)
    assert isinstance(is_safe, bool)
    assert 'max_activation' in metrics
    assert 'mean_activation' in metrics
    assert 'std_activation' in metrics

def test_safe_forward():
    optimizer = LlamaSafetyOptimizer(DummyModel())
    input_tensor = torch.randn(1, 10)
    result = optimizer.safe_forward(input_tensor, start_pos=0)

    assert 'output' in result
    assert 'is_safe' in result
    assert 'safety_metrics' in result
    assert 'performance' in result