This folder demonstrates how to train TE-accelerated Llama 3 with a native PyTorch training loop, including sequence packing and FP8 precision, using fully sharded data parallel (FSDP) for distributed training. This recipe is configured for genomic sequences using a custom nucleotide tokenizer.
This folder contains an independent, minimal training example. It does not depend on any other code in the top-level bionemo-framework repository. You can download a zipped directory of this folder alone by clicking here.
🚧 Under development
| Model | BF16 | FP8[1] | THD Input Format | FP8 with THD Input Format | MXFP8[2] | Context Parallelism | Tensor Parallelism |
|---|---|---|---|---|---|---|---|
| Llama 3 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 |
✅: Supported
🚧: Under development
❌: Not supported
[1]: Requires compute capability 9.0 and above (Hopper+)
[2]: Requires compute capability 10.0 and 10.3 (Blackwell), 12.0 support pending
The easiest way to get started with this recipe is to use the provided Dockerfile, which uses the latest NVIDIA PyTorch base image to provide optimized versions of PyTorch and TransformerEngine. To build the container, run:
```bash
docker build -t llama3_native_te .
```

To run the container, run:

```bash
docker run -it --gpus all --network host --ipc=host --rm -v ${PWD}:/workspace/bionemo llama3_native_te /bin/bash
```

Alternatively, the dependencies can be installed manually in an environment with CUDA support. See `requirements.txt` for the list of dependencies.
Scaling Llama 3 70B with Context Parallelism (CP) on 32x NVIDIA GB300 GPUs (NVL32) with synthetic data of increasing sequence length. MFU was calculated using a 2.5 PFLOPS/GPU maximum theoretical bf16 throughput, with model FLOPS calculated with the formula:

```python
def compute_model_pflops(seq_len, global_batch_size, step_time_s):
    B, S, H, L, V = global_batch_size, seq_len, HIDDEN_DIM, N_LAYERS, VOCAB_SIZE
    model_flops = (
        (24 * B * S * H * H + 4 * B * S * S * H) * (3 * L) + (6 * B * S * H * V)
    ) / step_time_s
    return model_flops / 1e15
```

We compared the convergence of this Llama 3 recipe (with FSDP2) against NeMo 2.0 (https://github.com/NVIDIA-NeMo/NeMo) and the facebookresearch/lingua implementation on the DCLM Baseline 1.0 dataset. See Training on Natural Language Data (Lingua Reproduction) for more details. The figure above shows similar loss convergence and step time to the NeMo 2.0 training example, and the following table shows downstream performance on various tasks using the lm-eval library. The variations in training step time every 10,000 steps are due to checkpointing; further work will be done to improve training step time stability.
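For example, MFU for a single step follows from dividing these model FLOPS by the aggregate peak throughput. All shape and timing numbers below are illustrative assumptions (a Llama 3 70B-like configuration), not measurements from this recipe:

```python
# Assumed Llama 3 70B-like shape constants (illustrative, not read from the
# recipe config): hidden size, layer count, vocabulary size.
HIDDEN_DIM, N_LAYERS, VOCAB_SIZE = 8192, 80, 128256

# Hypothetical step: global batch of 8 sequences of 8192 tokens in 1 second.
B, S, H, L, V = 8, 8192, HIDDEN_DIM, N_LAYERS, VOCAB_SIZE
step_time_s = 1.0

model_pflops = (
    ((24 * B * S * H * H + 4 * B * S * S * H) * (3 * L) + (6 * B * S * H * V))
    / step_time_s
    / 1e15
)
mfu = model_pflops / (2.5 * 32)  # 32 GPUs at 2.5 PFLOPS/GPU peak bf16
print(round(mfu, 3))  # 0.375
```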
| name | arc_challenge | arc_easy | boolq | copa | hella_swag | piqa | winogrande |
|---|---|---|---|---|---|---|---|
| Llama3 1B (Lingua) | 32 | 69.32 | 62.94 | 77 | 47.16 | 73.34 | 63.06 |
| Llama3 1B (NeMo2) | 32.34 | 67 | 59.79 | 78 | 59.09 | 74.43 | 59.83 |
| Llama3 1B (Recipes) | 30.63 | 65.57 | 55.6 | 80 | 44.57 | 73.29 | 58.72 |
Models were trained on 64 NVIDIA H100 GPUs with a micro batch size of 4 and a context length of 4096 for 60,000 steps. Training was performed with BF16 precision.
This recipe supports distributed training using DDP, FSDP2, and FSDP2 with Context Parallelism, shown in three separate training entrypoints:

- Distributed Data Parallel (DDP), shown in `train_ddp.py`
- Fully Sharded Data Parallel 2 (FSDP2), shown in `train_fsdp2.py`
- FSDP2 with Context Parallelism, shown in `train_fsdp2_cp.py`
To run single-process training on one GPU, run:

```bash
python train_fsdp2.py  # or train_ddp.py
```

To run multi-process training locally on 2+ GPUs, run (e.g., for 2 GPUs):

```bash
torchrun --nproc_per_node=2 train_fsdp2.py  # or train_ddp.py
```

Multi-node training is supported with both strategies.
A convergence test configuration (L0_convergence) is also available, which uses a tiny Llama model
to verify that the training loop can overfit on a small dataset:
```bash
python train_fsdp2.py --config-name L0_convergence
```

Gradient accumulation is supported with both strategies. To enable gradient accumulation, set `grad_acc_steps` to the number of steps to accumulate gradients before updating the model parameters. This is useful to scale the effective batch size while running on a smaller number of GPUs.
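For intuition, the effective batch size scales multiplicatively with the accumulation steps and the data-parallel world size. The numbers below are hypothetical:

```python
# Hypothetical values; grad_acc_steps mirrors the recipe's config knob.
micro_batch_size = 4   # sequences per GPU per forward/backward pass
grad_acc_steps = 8     # micro-batches accumulated before each optimizer step
num_gpus = 8           # data-parallel world size

effective_batch_size = micro_batch_size * grad_acc_steps * num_gpus
print(effective_batch_size)  # 256
```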
```bash
python train_fsdp2.py --config-name L0_sanity grad_acc_steps=2
```

To run training with FP8, enable it by overriding the `fp8_config.enabled=true` configuration parameter. Additional FP8 configuration parameters, including switching to `MXFP8BlockScaling`, can be set via the Hydra configuration.
```bash
python train_fsdp2.py --config-name L0_sanity fp8_config.enabled=true
```

When training with FP8, you can initialize model weights directly in the target quantized format by setting `config_kwargs.use_quantized_model_init=true`. This tells TransformerEngine to create weights inside a `te.quantized_model_init` context, avoiding a separate quantization step after initialization.
```bash
python train_fsdp2.py --config-name L0_sanity \
    fp8_config.enabled=true \
    +config_kwargs.use_quantized_model_init=true
```

We also provide a mechanism to collect tensor data for FP8 layers during training, which may include activations, weights, and gradients. To enable this, set the following config options:
```bash
python train_fsdp2.py \
    fp8_stats_config.enabled=True \
    fp8_stats_config.fp8_log_dir=./logs/fp8_stats_logs_dummy \
    fp8_stats_config.fp8_stats_file=./fp8_debugging_stats.yaml \
    fp8_config.enabled=True
```

Note: This feature is available for the `train_ddp` and `train_fsdp2` scripts.
The structure of the `fp8_debugging_stats.yaml` config file is explained in more detail in the NVIDIA Transformer Engine config file documentation. Below we cover some basic elements of the file structure.
This comes at a performance cost that depends on the `freq` parameter mentioned above. `freq=1` collects stats on every step, which in our experiments caused a ~29% decrease in throughput (executed on a single RTX 5090). We recommend using `freq>=10` to reduce this performance hit.
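As a rough sketch only, a minimal stats config file might look like the fragment below. The section and field names here follow the conventions of TE's debug tooling and are assumptions, not taken from this recipe; consult the Transformer Engine config file documentation for the authoritative schema:

```yaml
# Hypothetical fragment; see the TE config file documentation for the
# authoritative schema and feature names.
fp8_stats:
  enabled: True
  layers:
    layer_name_regex_pattern: .*
  transformer_engine:
    LogTensorStats:
      enabled: True
      stats: [min, max, mean]
      tensors: [activation, gradient, weight]
      freq: 10  # collect stats every 10 steps to limit overhead
```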
Sequence packing is handled via a padding-free collator (in `collator.py`) that provides input arguments (e.g. `cu_seq_lens_q`) needed for padding-free attention. To enable sequence packing, set `use_sequence_packing=true` in the hydra configuration.
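For intuition, cumulative sequence lengths for a packed batch can be sketched as follows (toy lengths, chosen for illustration):

```python
# Toy example: cu_seq_lens marks sequence boundaries in a packed (THD) batch.
seq_lens = [5, 3, 6]          # lengths of three packed sequences
cu_seq_lens = [0]
for n in seq_lens:
    cu_seq_lens.append(cu_seq_lens[-1] + n)
print(cu_seq_lens)            # [0, 5, 8, 14]

# For FP8, the total token count must be a multiple of 16, so a collator
# would pad these 14 tokens up to 16.
padded_total = -(-cu_seq_lens[-1] // 16) * 16
print(padded_total)           # 16
```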
```bash
python train_fsdp2.py --config-name L0_sanity use_sequence_packing=true
```

To combine FP8 training with sequence packing, the number of unpadded input tokens must be a multiple of 16. The data collator will automatically pad packed sequences to the maximum number of tokens per batch.
```bash
python train_fsdp2.py --config-name L0_sanity \
    fp8_config.enabled=true \
    use_sequence_packing=true
```

Context parallelism splits each sequence across multiple GPUs along the sequence dimension, enabling training with very long sequences. Use `train_fsdp2_cp.py` with the `L0_sanity_cp` configuration and set `cp_size` to the number of context parallelism ranks. It works with both BSHD (no padding) and THD (padding) input formats. Only TE models are supported.
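The core idea can be sketched with a naive even split (illustrative only; TE's context-parallel attention uses a load-balanced chunk assignment rather than this simple slicing):

```python
# Naive sketch: shard a 16-token sequence across cp_size=4 ranks.
tokens = list(range(16))
cp_size = 4
chunk = len(tokens) // cp_size
shards = [tokens[r * chunk : (r + 1) * chunk] for r in range(cp_size)]
print(shards[0], shards[3])  # [0, 1, 2, 3] [12, 13, 14, 15]
```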
```bash
torchrun --nproc_per_node=4 train_fsdp2_cp.py --config-name L0_sanity_cp cp_size=2
```

This recipe is configured to use genomic sequences. The default configuration uses a local test file, `test_genomic_sequences.parquet`.
For large-scale training, you can point the dataset configuration to a directory containing Parquet files or a Hugging Face dataset, e.g., for the DCLM Baseline 1.0 dataset:
```bash
hf download mlfoundations/dclm-baseline-1.0 \
    --repo-type dataset --include 'global-shard_01_of_10/*' \
    --local-dir /path/to/download/directory
```

```bash
python train_fsdp2.py --config-name L0_sanity \
    dataset.load_dataset_kwargs.path=/path/to/download/directory
```

We provide a configuration to reproduce the Llama-3.2-1B training experiments from Meta Lingua, using the DCLM Baseline 1.0 dataset.
To launch this training, use the L2_lingua_1b configuration:
```bash
python train_fsdp2.py --config-name L2_lingua_1b
```

This configuration:

- Uses the `meta-llama/Llama-3-8B` tokenizer (requires Hugging Face login).
- Streams the `mlfoundations/dclm-baseline-1.0` dataset.
- Sets up training parameters to match the Lingua 1B baseline.
Note: Ensure you have authenticated with Hugging Face (huggingface-cli login) to access the tokenizer if it requires
gated access, or update the config to use a public Llama 3 tokenizer.
To enable checkpoint saving, ensure that checkpoint.ckpt_dir is set to a writable directory. Checkpointing frequency is
controlled by the checkpoint.save_every_n_steps configuration parameter.
```bash
python train_fsdp2.py --config-name L0_sanity \
    checkpoint.ckpt_dir=/path/to/ckpt_dir \
    checkpoint.save_every_n_steps=100
```

To enable checkpoint loading, set `checkpoint.resume_from_checkpoint=true` to resume from the latest checkpoint.
```bash
python train_fsdp2.py --config-name L0_sanity \
    checkpoint.ckpt_dir=/path/to/ckpt_dir \
    checkpoint.resume_from_checkpoint=true
```

We also show how to export a final model at the end of training, which is suitable for uploading to the Hugging Face Hub or for local inference as a more durable format than torch distributed checkpoints. To enable this, set `checkpoint.save_final_model=true` in the hydra configuration. The resulting model will be saved to the `final_model` directory within the checkpoint directory.
These examples show how to save and resume your dataloader by passing the dataloader instance to our save_checkpoint_*
and load_checkpoint_* functions using the StatefulDataLoader class from torchdata. See checkpoint.py for
implementation details.
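The idea behind resumable dataloading can be sketched in plain Python. This toy class is illustrative only; the recipe uses torchdata's `StatefulDataLoader`, whose `state_dict()`/`load_state_dict()` methods play the same role:

```python
# Toy stand-in for a stateful dataloader: its state is just the read position.
class TinyStatefulLoader:
    def __init__(self, data):
        self.data = data
        self.pos = 0

    def __iter__(self):
        while self.pos < len(self.data):
            item = self.data[self.pos]
            self.pos += 1
            yield item

    def state_dict(self):
        return {"pos": self.pos}

    def load_state_dict(self, state):
        self.pos = state["pos"]

loader = TinyStatefulLoader([10, 20, 30, 40])
it = iter(loader)
next(it), next(it)                 # consume two items, then "checkpoint"
state = loader.state_dict()

resumed = TinyStatefulLoader([10, 20, 30, 40])
resumed.load_state_dict(state)     # resume mid-epoch
print(list(resumed))               # [30, 40]
```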
This recipe includes built-in support for profiling with NVIDIA Nsight Systems, which provides detailed performance traces including CUDA kernels, CPU activities, memory operations, and NVTX ranges. The profiler allows you to specify the exact training step range to profile.
To profile a training run on a single GPU:
```bash
nsys profile \
    -o nsight_trace \
    --trace=cuda,nvtx,osrt,cudnn,cublas \
    --pytorch=autograd-nvtx \
    --capture-range=cudaProfilerApi \
    --capture-range-end=stop \
    python train_fsdp2.py \
    profiler.enabled=true \
    profiler.start_step=10 \
    profiler.end_step=15
```

Profiler Configuration Parameters:
- `profiler.enabled`: Enable/disable profiling (default: `false`)
- `profiler.start_step`: Training step at which to start profiling (default: `10`)
- `profiler.end_step`: Training step at which to end profiling (default: `15`)
Nsight Systems Flags:
- `--pytorch=autograd-nvtx`: Adds NVTX markers for PyTorch autograd operations (forward/backward passes, optimizer steps). This helps visualize the training loop structure and identify bottlenecks in the computation graph.
- `--pytorch-backtrace=cuda`: Captures Python backtraces for CUDA kernel launches, helping identify which Python code triggered each kernel. This is invaluable for debugging performance issues and understanding which operations are expensive.
- `--python-sampling=true` (optional): Periodically samples Python call stacks to identify CPU-side bottlenecks. Useful when investigating data loading, preprocessing, or Python overhead. Adds ~5-15% overhead, so only use when needed.
Note: The PyTorch-specific flags (--pytorch=autograd-nvtx and --pytorch-backtrace=cuda) add minimal overhead but provide significantly more detailed insights into PyTorch operations, making them highly recommended for training workload profiling. Use --python-sampling=true only when investigating CPU/Python performance.
The profiler will start capturing performance data at start_step and stop at end_step. It's recommended to start profiling after a few steps to allow training to stabilize.
For distributed training, profiling is only performed on global rank 0 to minimize overhead and avoid redundant data collection. Other ranks will skip profiling automatically.
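The gating logic can be sketched as a pure-Python predicate (a hypothetical helper, not the recipe's actual code; in practice the start/stop would drive the CUDA profiler API, e.g. via `torch.cuda.profiler.start()`/`stop()`, so that `--capture-range=cudaProfilerApi` begins and ends capture at these steps):

```python
# Sketch: profile only on rank 0, between start_step (inclusive) and
# end_step (exclusive).
def should_profile(step, rank, enabled=True, start_step=10, end_step=15):
    return enabled and rank == 0 and start_step <= step < end_step

print(should_profile(step=12, rank=0))  # True
print(should_profile(step=12, rank=1))  # False
print(should_profile(step=20, rank=0))  # False
```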
```bash
nsys profile \
    -o nsight_trace_rank0 \
    --trace=cuda,nvtx,osrt,cudnn,cublas \
    --pytorch=autograd-nvtx \
    --pytorch-backtrace=cuda \
    --capture-range=cudaProfilerApi \
    --capture-range-end=stop \
    torchrun --nproc_per_node=2 train_fsdp2.py \
    profiler.enabled=true
```

For more information on Nsight Systems, see the official documentation.
Models can be loaded from the final checkpoint directory using `AutoModelForCausalLM` (or `NVLlamaForCausalLM` if using TE layers directly, though standard HF loading works if the `config.json` is updated to include an `auto_map` entry for `modeling_llama_te.NVLlamaForCausalLM` and the custom forward pass is packaged in the checkpoint directory).
If you trained with TE layers (which is the default in this recipe), you can use the NVLlamaForCausalLM class for
inference, which supports high-performance generation with TE-provided InferenceParams key-value cache class.
```python
import torch
from transformers import AutoTokenizer
from transformer_engine.pytorch.attention import InferenceParams
from modeling_llama_te import NVLlamaForCausalLM, NVLlamaConfig

# Load the model configuration and weights
config = NVLlamaConfig.from_pretrained("path/to/final_model")
model = NVLlamaForCausalLM.from_pretrained("path/to/final_model", config=config)
tokenizer = AutoTokenizer.from_pretrained("./tokenizers/nucleotide_fast_tokenizer")
model.to("cuda")
model.eval()

# Example genomic sequence
sequence = "ACGTACGT"
inputs = tokenizer(sequence, return_tensors="pt").to("cuda")

# Setup inference parameters for efficient generation
past_key_values = InferenceParams(
    max_batch_size=1,
    max_sequence_length=256,
    num_heads_kv=model.config.num_key_value_heads,
    head_dim_k=model.config.hidden_size // model.config.num_attention_heads,
    dtype=torch.bfloat16,
    qkv_format="thd",
    max_ctx_len=256,
)
for layer_number in range(1, model.config.num_hidden_layers + 1):
    past_key_values.allocate_memory(layer_number)

# Generate
with torch.no_grad():
    output_ids = model.generate(
        **inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values
    )
generated_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
print(generated_text)
```

If you want to convert the trained TE model back to a standard Hugging Face `LlamaForCausalLM` (e.g., to use with vLLM or SGLang), you can use the provided conversion script in `../../models/llama3/convert.py`.
```python
import torch
from transformers import AutoTokenizer
from modeling_llama_te import NVLlamaForCausalLM, NVLlamaConfig
from convert import convert_llama_te_to_hf  # Import from ../../models/llama3/convert.py

# Load TE model
config = NVLlamaConfig.from_pretrained("path/to/final_model")
model_te = NVLlamaForCausalLM.from_pretrained("path/to/final_model", config=config)

# Convert to HF
model_hf = convert_llama_te_to_hf(model_te)

# Save HF model
model_hf.save_pretrained("path/to/hf_converted_model")
tokenizer = AutoTokenizer.from_pretrained("./tokenizers/nucleotide_fast_tokenizer")
tokenizer.save_pretrained("path/to/hf_converted_model")
```

Once converted, the model can be loaded by any library that supports Llama 3, such as vLLM or SGLang.

```bash
# Example vLLM usage
vllm serve path/to/hf_converted_model
```

To run tests locally, run `recipes_local_test.py` from the repository root with the recipe directory as an argument.
```bash
./ci/scripts/recipes_local_test.py bionemo-recipes/recipes/llama3_native_te/
```

To use the provided devcontainer, use "Dev Containers: Reopen in Container" from the VSCode menu, and choose the "BioNeMo Recipes Dev Container" option. To run the tests inside the container, run `pytest -v .` in the recipe directory.
Hydra is a powerful configuration management library for Python. This recipe uses Hydra to manage training configurations, allowing for easy modification of training hyper-parameters and model settings.
Configuration parameters can be overridden from the command line, e.g.:

```bash
python train_fsdp2.py --config-name L0_sanity fp8_config.enabled=true
```


