Commit 9d7df23

Update
[ghstack-poisoned]
2 parents 2f8acf7 + ed4f44a commit 9d7df23


94 files changed: +14879 -116 lines changed

.github/scripts/pre-build-script-win.sh

Lines changed: 1 addition & 1 deletion
@@ -2,4 +2,4 @@
 
 pip install --upgrade setuptools
 
-export TORCHRL_BUILD_VERSION=0.8.0
+export TORCHRL_BUILD_VERSION=0.9.0

.github/scripts/td_script.sh

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-export TORCHRL_BUILD_VERSION=0.8.0
+export TORCHRL_BUILD_VERSION=0.9.0
 pip install --upgrade setuptools
 
 # Check if ARCH is set to aarch64

.github/scripts/version_script.bat

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 @echo off
-set TORCHRL_BUILD_VERSION=0.8.0
+set TORCHRL_BUILD_VERSION=0.9.0
 echo TORCHRL_BUILD_VERSION is set to %TORCHRL_BUILD_VERSION%
 
 @echo on

.github/unittest/linux/scripts/run_all.sh

Lines changed: 10 additions & 1 deletion
@@ -9,7 +9,14 @@ set -v
 
 if [[ $OSTYPE != 'darwin'* ]]; then
   apt-get update && apt-get upgrade -y
-  apt-get install -y vim git wget libsdl2-dev libsdl2-2.0-0 cmake
+  apt-get install -y vim git wget cmake
+
+  # Enable universe repository
+  # apt-get install -y software-properties-common
+  # add-apt-repository universe
+  # apt-get update
+
+  # apt-get install -y libsdl2-dev libsdl2-2.0-0
 
   apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev
   apt-get install -y libglvnd0 libgl1 libglx0 libegl1 libgles2 xvfb
@@ -208,11 +215,13 @@ pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_contro
 if [ "${CU_VERSION:-}" != cpu ] ; then
   python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
     --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \
+    --ignore test/llm \
     --timeout=120 --mp_fork_if_no_cuda
 else
   python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
     --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \
     --ignore test/test_distributed.py \
+    --ignore test/llm \
     --timeout=120 --mp_fork_if_no_cuda
 fi
 

.github/unittest/linux_olddeps/scripts_gym_0_13/run_test.sh

Lines changed: 8 additions & 1 deletion
@@ -28,7 +28,14 @@ python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_te
 export DISPLAY=:99
 Xvfb :99 -screen 0 1400x900x24 > /dev/null 2>&1 &
 
-CKPT_BACKEND=torch MUJOCO_GL=egl python .github/unittest/helpers/coverage_run_parallel.py -m pytest --instafail -v --durations 200 --ignore test/test_distributed.py --ignore test/test_rlhf.py
+CKPT_BACKEND=torch MUJOCO_GL=egl python .github/unittest/helpers/coverage_run_parallel.py -m pytest \
+  --instafail -v \
+  --durations 200 \
+  --ignore test/test_distributed.py \
+  --ignore test/test_rlhf.py \
+  --ignore test/llm \
+  --mp_fork_if_no_cuda
+
 #pytest --instafail -v --durations 200
 #python test/test_libs.py
 coverage combine

.github/unittest/linux_optdeps/scripts/run_all.sh

Lines changed: 1 addition & 0 deletions
@@ -159,6 +159,7 @@ export BATCHED_PIPE_TIMEOUT=60
 python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
   --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \
   --ignore test/test_distributed.py \
+  --ignore test/llm \
   --timeout=120 --mp_fork_if_no_cuda
 
 coverage combine

docs/source/reference/envs.rst

Lines changed: 1 addition & 0 deletions
@@ -1112,6 +1112,7 @@ to be able to create this other composition:
     CenterCrop
     ClipTransform
     Compose
+    ConditionalPolicySwitch
     ConditionalSkip
     Crop
     DataLoadingPrimer

docs/source/reference/index.rst

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ API Reference
     collectors
     data
     envs
+    llms
     modules
     objectives
     trainers

docs/source/reference/llms.rst

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
+.. currentmodule:: torchrl
+
+LLM interface
+=============
+
+.. _ref_llms:
+
+TorchRL offers a set of tools for LLM post-training, as well as some examples for training or setup.
+
+Collectors
+----------
+
+TorchRL offers a specialized collector class (:class:`~torchrl.collectors.llm.LLMCollector`) that is tailored for LLM
+use cases. We also provide dedicated updaters for some inference engines.
+
+.. currentmodule:: torchrl.collectors.llm
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    vLLMUpdater
+    LLMCollector
+
+
+Data structures
+---------------
+
+To handle text-based data structures (such as conversations etc.), we offer a few data structures dedicated to carrying
+data for LLM post-training.
+
+.. currentmodule:: torchrl.data.llm
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    History
+    LLMData
+
+Environments
+------------
+
+When fine-tuning an LLM using TorchRL, the environment is a crucial component of the inference pipeline, alongside the
+policy and collector. Environments manage operations that are not handled by the LLM itself, such as interacting with
+tools, loading prompts from datasets, computing rewards (when necessary), and formatting data.
+
+The design of environments in TorchRL allows for flexibility and modularity. By framing tasks as environments, users can
+easily extend or modify existing environments using transforms. This approach enables the isolation of individual
+components within specific :class:`~torchrl.envs.EnvBase` or :class:`~torchrl.envs.Transform` subclasses, making it
+simpler to augment or alter the environment logic.
+
+Available Environment Classes and Utilities
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+TorchRL provides various environment classes and utilities for working with LLMs, including:
+
+- Various environment classes (:class:`~torchrl.envs.llm.ChatEnv`, :class:`~torchrl.envs.llm.DatasetChatEnv`,
+  :class:`~torchrl.envs.llm.GSM8KEnv`, etc.)
+- Utility functions (:class:`~torchrl.envs.make_gsm8k_env`, :class:`~torchrl.envs.make_mlgym`, etc.)
+- Transforms and other supporting classes (:class:`~torchrl.envs.KLRewardTransform`,
+  :class:`~torchrl.envs.TemplateTransform`, :class:`~torchrl.envs.Tokenizer`, etc.)
+
+These components can be used to create customized environments tailored to specific use cases and requirements.
+
+.. currentmodule:: torchrl.envs.llm
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    ChatEnv
+    DatasetChatEnv
+    GSM8KEnv
+    make_gsm8k_env
+    GSM8KPrepareQuestion
+    GSM8KEnv
+    IFEvalEnv
+    IfEvalScorer
+    IFEvalScoreData
+    LLMEnv
+    LLMHashingEnv
+    make_mlgym
+    MLGymWrapper
+    GSM8KRewardParser
+    IfEvalScorer
+    as_nested_tensor
+    as_padded_tensor
+    DataLoadingPrimer
+    KLRewardTransform
+    TemplateTransform
+    Tokenizer
+
+Modules
+-------
+
+The :ref:`~torchrl.modules.llm` section provides a set of wrappers and utility functions for popular training and
+inference backends. The main goal of these primitives is to:
+
+- Unify the input / output data format across training and inference pipelines;
+- Unify the input / output data format across backends (to be able to use different backends across losses and
+  collectors, for instance)
+- Give appropriate tooling to construct these objects in typical RL settings (resource allocation, async execution,
+  weight update, etc.)
+
+Wrappers
+~~~~~~~~
+
+.. currentmodule:: torchrl.modules.llm
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    TransformersWrapper
+    vLLMWrapper
+
+Utils
+~~~~~
+
+.. currentmodule:: torchrl.modules.llm
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    CategoricalSequential
+    LLMOnDevice
+    make_vllm_worker
+    stateless_init_process_group
+    vLLMWorker
+
+Objectives
+----------
+
+LLM post training require some appropriate versions of the losses implemented in TorchRL.
+
+GRPO
+~~~~
+
+.. currentmodule:: torchrl.objectives.llm
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    GRPOLoss
+    GRPOLossOutput
+    MCAdvantage
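
Note on the GRPO entries above: the added page lists `GRPOLoss` and `MCAdvantage` without describing them. As background only, here is a minimal sketch of the group-relative advantage idea commonly associated with GRPO, where rewards for several completions of one prompt are standardised against that group's mean and standard deviation. The function name `group_relative_advantages` and the `eps` parameter are hypothetical; this is plain Python and is not TorchRL's implementation.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-6):
    """Standardise rewards within one group of completions for the same prompt.

    Hypothetical helper for illustration; not part of TorchRL.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population standard deviation of the group
    # eps avoids a division by zero when every completion gets the same reward
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four completions for one prompt, two rewarded and two not.
advantages = group_relative_advantages([1.0, 0.0, 0.0, 1.0])
```

Completions that score above the group mean get a positive advantage and those below get a negative one; a group with identical rewards yields all zeros.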

setup.py

Lines changed: 1 addition & 1 deletion
@@ -172,7 +172,7 @@ def _main(argv):
     if is_nightly:
         tensordict_dep = "tensordict-nightly"
     else:
-        tensordict_dep = "tensordict>=0.8.1,<0.9.0"
+        tensordict_dep = "tensordict>=0.9.0,<0.10.0"
 
     if is_nightly:
         version = get_nightly_version()
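
Note on the new pin: `>=0.9.0,<0.10.0` relies on numeric, component-wise version comparison, under which `0.10.0` is greater than `0.9.0` even though it sorts lower as a string. The sketch below illustrates this with the standard library only. It is a simplification that handles plain `X.Y.Z` releases and ignores the pre-release and post-release rules of real version specifiers; `parse_release` and `satisfies` are hypothetical helper names.

```python
def parse_release(version):
    """Split a plain 'X.Y.Z' string into a tuple of ints (no pre-release support)."""
    return tuple(int(part) for part in version.split("."))


def satisfies(version, lower="0.9.0", upper="0.10.0"):
    """Check lower <= version < upper using numeric, component-wise comparison."""
    return parse_release(lower) <= parse_release(version) < parse_release(upper)


accepted = [v for v in ["0.8.1", "0.9.0", "0.9.3", "0.10.0"] if satisfies(v)]

# String comparison would order these two the wrong way round.
string_order_is_misleading = "0.10.0" < "0.9.0"
```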

test/llm/libs/test_mlgym.py

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import argparse
+
+from functools import partial
+
+import pytest
+
+from torchrl import logger as torchrl_logger
+from torchrl.envs import SerialEnv
+
+from torchrl.envs.llm import make_mlgym
+from torchrl.modules.llm import TransformersWrapper
+
+
+class TestMLGYM:
+    def test_mlgym_specs(self):
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        model_name = "Qwen/Qwen2.5-7B-Instruct"
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.eos_token = "<|im_end|>"
+        policy = TransformersWrapper(
+            AutoModelForCausalLM.from_pretrained(model_name).cuda(),
+            tokenizer=tokenizer,
+            from_text=True,
+            generate=True,
+            device="cuda:0",
+            generate_kwargs={
+                # "temperature": 0.8,
+                # "repetition_penalty": 1.5,
+                "max_new_tokens": 1024
+            },
+        )
+
+        env = SerialEnv(
+            1,
+            [
+                partial(
+                    make_mlgym,
+                    task="prisonersDilemma",
+                    tokenizer=tokenizer,
+                    device="cuda:0",
+                )
+            ],
+        )
+        rollout = env.rollout(3, policy)
+        torchrl_logger.info(f"{rollout=}")
+        env.check_env_specs(break_when_any_done="both")
+
+    def test_mlgym_task_reset(self):
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        model_name = "Qwen/Qwen2.5-7B-Instruct"
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.eos_token = "<|im_end|>"
+        policy = TransformersWrapper(
+            AutoModelForCausalLM.from_pretrained(model_name).cuda(),
+            tokenizer=tokenizer,
+            from_text=True,
+            generate=True,
+            device="cuda:0",
+            generate_kwargs={
+                # "temperature": 0.8,
+                # "repetition_penalty": 1.5,
+                "max_new_tokens": 1024
+            },
+        )
+
+        env = SerialEnv(
+            1,
+            [
+                partial(
+                    make_mlgym,
+                    tasks=[
+                        "prisonersDilemma",
+                        "regressionKaggleHousePrice",
+                        "battleOfSexes",
+                    ],
+                    tokenizer=tokenizer,
+                    device="cuda:0",
+                )
+            ],
+        )
+        # We should get at least two tasks
+        rollout = env.rollout(100, policy, break_when_any_done=False)
+        torchrl_logger.info(f"{rollout=}")
+        torchrl_logger.info(rollout["task"])
+
+    def test_mlgym_wrong_format(self):
+        # A vanilla policy will not output anything useful, yet the env should run without error
+        ...
+
+
+if __name__ == "__main__":
+    args, unknown = argparse.ArgumentParser().parse_known_args()
+    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
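
Note on two standard-library idioms in the test file above: `functools.partial` is used to build zero-argument environment factories, and `parse_known_args` is used to collect unrecognised command-line flags so they can be forwarded to pytest. The sketch below shows both in isolation. `make_env` is a hypothetical stand-in for a constructor such as `make_mlgym`, and the argument list passed to `parse_known_args` is made up for illustration.

```python
import argparse
from functools import partial


def make_env(task, device="cpu"):
    # Hypothetical stand-in for an environment constructor such as make_mlgym.
    return {"task": task, "device": device}


# partial binds the keyword arguments now; the actual call happens later,
# with no arguments, which is the shape a list of env factories needs.
factories = [partial(make_env, task="prisonersDilemma", device="cuda:0")]
envs = [factory() for factory in factories]

# parse_known_args returns unrecognised arguments instead of exiting with an
# error, so a script can forward them to another tool (pytest in the test file).
parser = argparse.ArgumentParser()
args, unknown = parser.parse_known_args(["-k", "specs", "--maxfail=1"])
```

Deferring construction this way means each environment is only created when the factory is called, rather than at the point where the list is written down.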
