You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[rewriter] Decouple llama rule sets and make API explicit (#2388)
This PR addresses the misleading naming and tangled organization of
rewrite rules by decoupling the `llama_rule_sets.py` module and creating
a more explicit API.
## Problem
The original `llama_rule_sets.py` contained general optimization rules
that weren't specific to Llama models, making the naming misleading. The
API didn't explicitly specify what rules were being applied, making it
unclear what optimizations were happening.
```python
# Before: Unclear what this does
from onnxscript.rewriter import llama_rule_sets
rules = llama_rule_sets.llama_p0_rule_set() # What rules? Why "llama"? What's "p0"?
```
## Solution
### 1. Created `basic_rules.py` with explicit naming
- Moved all general optimization rules to a new `basic_rules.py` module
- Used descriptive function name: `basic_optimization_rules()`
- Added comprehensive documentation for each rule
### 2. Made API explicit for fine-grained control
```python
# New explicit API - users know exactly what they're getting
from onnxscript.rewriter import basic_rules
# Use all basic optimizations (recommended default)
rules = basic_rules.basic_optimization_rules()
# Or use specific individual rules
transpose_rule = basic_rules.transpose_identity_rule
cast_rule = basic_rules.cast_identity_rule
# Or create custom rule combinations
custom_rules = basic_rules.orp.RewriteRuleSet([
basic_rules.transpose_identity_rule,
basic_rules.cast_identity_rule,
])
```
### 3. Updated default rewriter to be explicit
```python
# Before (in rewriter/__init__.py)
*llama_rule_sets.llama_p0_rule_set().rules,
# After - much clearer what's being applied
*basic_rules.basic_optimization_rules().rules,
```
### 4. Maintained backward compatibility
- `llama_rule_sets.py` now serves as a compatibility wrapper
- All existing APIs continue to work with deprecation warnings
- Existing tests pass unchanged
## Available Rules
The new API provides access to these optimization rules:
- `cast_cast_rule` - Eliminates consecutive casts
- `cast_identity_rule` - Removes redundant casts
- `expand_identity_rule` - Removes no-op expands
- `reshape_reshape_rule` - Combines consecutive reshapes
- `slice_split_rule` - Converts slices to splits when beneficial
- `transpose_identity_rule` - Removes identity transposes
- `transpose_transpose_rule` - Combines consecutive transposes
- `unsqueeze_unsqueeze_rule` - Combines consecutive unsqueezes
- `squeeze_reshape_1d_rule` - Optimizes 1D squeeze+reshape patterns
## Migration
```python
# OLD (deprecated but still works)
from onnxscript.rewriter import llama_rule_sets
rules = llama_rule_sets.llama_p0_rule_set()
# NEW (recommended)
from onnxscript.rewriter import basic_rules
rules = basic_rules.basic_optimization_rules()
```
This change resolves the core issue by making the optimizer API
explicitly specify what rules are being applied, while providing users
with fine-grained control over optimization behavior.
Fixes#2128.
<!-- START COPILOT CODING AGENT TIPS -->
---
💬 Share your feedback on Copilot coding agent for the chance to win a
$200 gift card! Click
[here](https://survey.alchemer.com/s3/8343779/Copilot-Coding-agent) to
start the survey.
---------
Signed-off-by: Justin Chu <[email protected]>
Co-authored-by: copilot-swe-agent[bot] <[email protected]>
Co-authored-by: justinchuby <[email protected]>
Co-authored-by: Justin Chu <[email protected]>
0 commit comments