gvcallen/refrax


Refrax

Refrax is a small, pragmatic library implementing the optics pattern for JAX PyTrees.

It focuses on elegant syntax and composable conditional chains to provide easy yet powerful PyTree manipulation.
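As background on the optics pattern: a lens pairs a getter for one part of an immutable structure with a setter that returns an updated copy of the whole. Lenses compose, so a chain of them can reach nested parts. The following is a minimal plain-Python sketch over frozen dataclasses. Lens, field_lens, Linear and ToyModel are hypothetical names chosen for illustration. They are not part of Refrax's API, and this is not how Refrax is implemented.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable


@dataclass(frozen=True)
class Lens:
    """Hypothetical lens: a getter plus a setter that returns a new whole."""

    get: Callable[[Any], Any]  # whole -> part
    set: Callable[[Any, Any], Any]  # (whole, new_part) -> new whole

    def apply(self, whole: Any, fn: Callable[[Any], Any]) -> Any:
        # Read the focused part, transform it, and write it back.
        return self.set(whole, fn(self.get(whole)))

    def then(self, inner: "Lens") -> "Lens":
        # Compose: focus on self's part, then on inner's part within it.
        return Lens(
            get=lambda whole: inner.get(self.get(whole)),
            set=lambda whole, new: self.set(whole, inner.set(self.get(whole), new)),
        )


def field_lens(name: str) -> Lens:
    # Focus on a single named field of a frozen dataclass (hypothetical helper).
    return Lens(
        get=lambda obj: getattr(obj, name),
        set=lambda obj, new: replace(obj, **{name: new}),
    )


@dataclass(frozen=True)
class Linear:
    bias: tuple


@dataclass(frozen=True)
class ToyModel:
    core: Linear
    dropout: float


model = ToyModel(core=Linear(bias=(0.0, 1.0)), dropout=0.5)

# Similar in spirit to focus(model).dropout.set(0.1)
model2 = field_lens("dropout").set(model, 0.1)

# Similar in spirit to focus(model).core.bias.apply(...)
core_bias = field_lens("core").then(field_lens("bias"))
model3 = core_bias.apply(model, lambda b: tuple(x + 1.0 for x in b))
```

Every update returns a new object and leaves the original model untouched. Refrax makes the same guarantee for PyTrees, whose leaves are typically immutable JAX arrays.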

Installation

Refrax can be installed using pip:

pip install refrax

Quick example

import equinox as eqx
import jax

class Model(eqx.Module):
    core: eqx.nn.Linear
    head: eqx.nn.Linear
    dropout: float

key1, key2 = jax.random.split(jax.random.key(0))
model = Model(
    core=eqx.nn.Linear(in_features=5, out_features=5, key=key1),
    head=eqx.nn.Linear(in_features=5, out_features=2, key=key2),
    dropout=0.5
)

We can then update the model using focus:

from refrax import focus

model = focus(model).dropout.set(0.1)
model = focus(model).select("core", "head").bias.apply(lambda b: b + 1.0)

model.dropout
# 0.1

model.core.bias
# [0.646, 0.860, 0.670, 1.277, 0.727]

model.head.bias
# [1.634, 1.877]
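Because select("core", "head") focuses on several fields at once, the chained .bias.apply(...) updates both biases in a single call. An optic with several foci is usually called a traversal. Below is a hypothetical plain-Python sketch of that idea that uses frozen dataclasses in place of Equinox modules. Traversal and fields are illustrative names, not Refrax internals.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable


@dataclass(frozen=True)
class Traversal:
    """Hypothetical traversal: zero or more foci inside an immutable whole."""

    get_all: Callable[[Any], list]  # whole -> list of focused parts
    modify: Callable[[Any, Callable[[Any], Any]], Any]  # (whole, fn) -> new whole

    def then(self, inner: "Traversal") -> "Traversal":
        # Compose: every focus of `inner` inside every focus of `self`.
        return Traversal(
            get_all=lambda whole: [p for part in self.get_all(whole) for p in inner.get_all(part)],
            modify=lambda whole, fn: self.modify(whole, lambda part: inner.modify(part, fn)),
        )


def fields(*names: str) -> Traversal:
    # Focus on one or more named fields of a frozen dataclass (hypothetical helper).
    return Traversal(
        get_all=lambda obj: [getattr(obj, n) for n in names],
        modify=lambda obj, fn: replace(obj, **{n: fn(getattr(obj, n)) for n in names}),
    )


@dataclass(frozen=True)
class Linear:
    bias: float


@dataclass(frozen=True)
class ToyModel:
    core: Linear
    head: Linear
    dropout: float


model = ToyModel(core=Linear(bias=1.0), head=Linear(bias=2.0), dropout=0.5)

# Mirrors the shape of focus(model).select("core", "head").bias.apply(...)
biases = fields("core", "head").then(fields("bias"))
updated = biases.modify(model, lambda b: b + 1.0)
```

A lens is the special case of a traversal with exactly one focus. That is why a multi-field selection can be followed by ordinary attribute steps such as .bias.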

One useful optic simply traverses all of a PyTree's leaves. This model has five: the weight and bias of each linear layer, plus dropout:

len(focus(model).leaves().get())
# 5
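Conceptually, a leaves traversal flattens a tree into its leaves. It can also map a function over those leaves while rebuilding the same structure. The sketch below is a simplified, hypothetical version that handles only nested dicts, lists and tuples. In JAX, jax.tree_util.tree_leaves and jax.tree_util.tree_map do this for every registered PyTree node type, including Equinox modules.

```python
from typing import Any, Callable


def leaves(tree: Any) -> list:
    """Collect the leaves of a nested dict/list/tuple in a fixed order."""
    if isinstance(tree, dict):
        # Visit dict keys in sorted order so the result is deterministic.
        return [leaf for key in sorted(tree) for leaf in leaves(tree[key])]
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in leaves(child)]
    # Anything else is a leaf here (JAX, by contrast, treats None as an empty subtree).
    return [tree]


def map_leaves(tree: Any, fn: Callable[[Any], Any]) -> Any:
    """Return a new tree with the same structure and fn applied to every leaf."""
    if isinstance(tree, dict):
        return {key: map_leaves(value, fn) for key, value in tree.items()}
    if isinstance(tree, (list, tuple)):
        # Rebuild a plain list or tuple (namedtuples are not handled in this sketch).
        return type(tree)(map_leaves(child, fn) for child in tree)
    return fn(tree)


tree = {"core": {"w": 1.0, "b": 2.0}, "dropout": 0.5, "head": (3.0, 4.0)}
n_leaves = len(leaves(tree))  # 5
doubled = map_leaves(tree, lambda x: x * 2)
```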

Documentation

Documentation is available here.

Related

The library uses Equinox (specifically eqx.tree_at) under the hood.

About

Chainable optics for JAX PyTrees
