Skip to content

Commit 29d0c0d

Browse files
author
The tunix Authors
committed
Merge pull request #772 from google:lance-fix1
PiperOrigin-RevId: 834528724
2 parents ad88ef5 + 26115c5 commit 29d0c0d

File tree

3 files changed

+20
-9
lines changed

3 files changed

+20
-9
lines changed

.github/workflows/tpu-tests.yml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ jobs:
3737
runs-on: [linux-x86-ct5lp-224-8tpu]
3838
environment: testing
3939
container:
40-
image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:jax0.7.1_rev1
40+
image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest
4141
options: --privileged
4242
env:
4343
CLOUD_TPU_ACCELERATOR: v5e-8
@@ -221,8 +221,13 @@ jobs:
221221
env:
222222
HF_TOKEN: ${{ secrets.HF_TOKEN }}
223223
run: |
224+
# Reinstall Tunix with prod dependencies
225+
pip install -e .[prod] --force-reinstall
226+
224227
# Loading tfds requires tensorflow.
225228
pip install tensorflow
229+
230+
export JAX_PLATFORMS=tpu,cpu
226231
./tests/sft/sft_tpu_smoke_test.sh
227232
- name: Run tunix cli tests
228233
env:

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "google-tunix"
3-
version = "0.1.3"
3+
version = "0.1.4"
44
authors = [
55
{ name = "Tunix Developers", email = "[email protected]" },
66
]
@@ -31,7 +31,7 @@ dependencies = [
3131
"omegaconf", # CLI config
3232
"pylatexenc", # Eval result parsing
3333
"python-dotenv", # Huggingface API key
34-
"qwix<=0.1.1", # Newer version of qwix depends on unreleased flax beyond 0.12.0
34+
"qwix",
3535
"sentencepiece",
3636
"sympy", # Eval result parsing
3737
"tensorflow_datasets",

tests/sft/dpo/orpo_trainer_test.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import jax.numpy as jnp
2222
import numpy as np
2323
import optax
24-
from tunix.rl import common
2524
from tunix.sft.dpo import dpo_trainer as orpo_lib
2625
from tunix.tests import test_common as tc
2726

@@ -231,21 +230,28 @@ def test_orpo_loss_fn(self):
231230
np.random.seed(0)
232231
model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))
233232
# Use negative log probs (as they should be in reality)
234-
per_token_logps = -np.abs(np.random.normal(2, 1, size=(8, 4)))
233+
per_token_logps = -np.abs(np.random.rand(8, 4))
234+
completion_mask = np.ones((8, 4))
235+
token_logps = (per_token_logps * completion_mask).sum(axis=-1)
236+
237+
batch_size = token_logps.shape[0]
238+
chosen_logps = token_logps[: batch_size // 2]
239+
rejected_logps = token_logps[batch_size // 2 :]
240+
235241
train_example = orpo_lib.TrainExample(
236242
input_ids=jnp.arange(0, 32).reshape(8, 4),
237243
positions=jnp.ones((8, 4)),
238244
attention_mask=jnp.ones((8, 4, 4)),
239245
ref_chosen_logps=None,
240246
ref_rejected_logps=None,
241-
completion_mask=jnp.ones((8, 4)),
247+
completion_mask=completion_mask,
242248
logits_to_keep=4,
243249
)
244250

245251
with mock.patch.object(
246-
common,
247-
"get_per_token_logps",
248-
return_value=jnp.array(per_token_logps),
252+
orpo_lib,
253+
"compute_logps",
254+
return_value=(jnp.array(chosen_logps), jnp.array(rejected_logps)),
249255
):
250256
loss, aux = orpo_lib.dpo_loss_fn(
251257
model,

0 commit comments

Comments
 (0)