
Significant memory usage increase since 4.36 #28024

@oraluben

Description

Bisected the regression to #26681.
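
One way to confirm the regression range before bisecting individual commits is to re-run the reproduction script below against both releases. A rough sketch of that comparison (the version list and script name are illustrative, not a record of the exact bisection):

import subprocess
import sys

for version in ["4.35.2", "4.36.0"]:
    # Install the candidate release, then re-run the repro script.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", f"transformers=={version}"],
        check=True,
    )
    result = subprocess.run(["torchrun", "llama.py"])
    print(version, "ok" if result.returncode == 0 else "failed (OOM)")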

System Info

Device: A10

  • huggingface_hub version: 0.19.4
  • Platform: Linux-5.10.134-15.al8.x86_64-x86_64-with-glibc2.32
  • Python version: 3.10.13
  • Running in iPython ?: No
  • Running in notebook ?: No
  • Running in Google Colab ?: No
  • Token path ?: /home/ecs-user/.cache/huggingface/token
  • Has saved token ?: False
  • Configured git credential helpers:
  • FastAI: N/A
  • Tensorflow: N/A
  • Torch: 2.2.0.dev20231213+cu121
  • Jinja2: 3.1.2
  • Graphviz: N/A
  • Pydot: N/A
  • Pillow: 9.3.0
  • hf_transfer: N/A
  • gradio: N/A
  • tensorboard: N/A
  • numpy: 1.24.1
  • pydantic: 2.5.2
  • aiohttp: 3.9.1
  • ENDPOINT: https://huggingface.co
  • HF_HUB_CACHE: /home/ecs-user/.cache/huggingface/hub
  • HF_ASSETS_CACHE: /home/ecs-user/.cache/huggingface/assets
  • HF_TOKEN_PATH: /home/ecs-user/.cache/huggingface/token
  • HF_HUB_OFFLINE: False
  • HF_HUB_DISABLE_TELEMETRY: False
  • HF_HUB_DISABLE_PROGRESS_BARS: None
  • HF_HUB_DISABLE_SYMLINKS_WARNING: False
  • HF_HUB_DISABLE_EXPERIMENTAL_WARNING: False
  • HF_HUB_DISABLE_IMPLICIT_TOKEN: False
  • HF_HUB_ENABLE_HF_TRANSFER: False
  • HF_HUB_ETAG_TIMEOUT: 10
  • HF_HUB_DOWNLOAD_TIMEOUT: 10

Who can help?

@tomaarsen

Information

  • My own modified scripts

Tasks

  • My own task or dataset (give details below)

Reproduction

Script:

import numpy as np
import torch.nn.functional as F
from datasets import Dataset
from transformers import LlamaConfig, LlamaForCausalLM, Trainer, TrainingArguments
from transformers.models.llama.modeling_llama import LlamaFlashAttention2

# Tiny 2-layer Llama so the repro fits on a single A10.
config = LlamaConfig(num_hidden_layers=2)
config._flash_attn_2_enabled = True

# Replace the flash-attn kernel with PyTorch's scaled_dot_product_attention
# so the repro runs without the flash-attn package installed.
def _flash_attention_forward(self, q, k, v, m, ql, dropout=0.0, softmax_scale=None):
    assert m is None  # no attention mask is produced for this dataset
    return F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
        is_causal=True).transpose(1, 2)

LlamaFlashAttention2._flash_attention_forward = _flash_attention_forward

model = LlamaForCausalLM(config)

# DeepSpeed ZeRO stage 3; "auto" values are resolved by the Trainer.
ds_config = {
    "optimizer": {
        "type": "AdamW",
        "params": {"lr": "auto", "betas": "auto", "eps": "auto", "weight_decay": "auto"},
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {"warmup_min_lr": "auto", "warmup_max_lr": "auto", "warmup_num_steps": "auto"},
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,
        "contiguous_gradients": True,
        "sub_group_size": 1e8,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": "auto",
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": False,
}

training_args = TrainingArguments(
    remove_unused_columns=False,
    log_level='info',
    per_device_train_batch_size=2,
    logging_steps=1,
    output_dir='./tmp',
    bf16=True,
    deepspeed=ds_config,
    gradient_checkpointing=True,
)

# Synthetic dataset: 1000 random sequences of length 2048, labels = inputs.
input_ids = np.random.randint(100, 30000, (1000, 2048))
data_set = Dataset.from_dict({
    "input_ids": input_ids,
    "labels": input_ids,
})

trainer = Trainer(
    model,
    args=training_args,
    train_dataset=data_set,
)
trainer.train()

  1. torchrun llama.py
  2. Fails with torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 16.06 GiB. GPU 0 has a total capacity of 21.99 GiB of which 2.79 GiB is free. Including non-PyTorch memory, this process has 19.19 GiB memory in use. Of the allocated memory 16.93 GiB is allocated by PyTorch, and 1.40 GiB is reserved by PyTorch but unallocated.
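
To quantify the regression instead of only observing the OOM, peak CUDA memory can be logged per step and compared between 4.35.2 and 4.36. A minimal sketch, assuming a single-GPU run (the callback name is made up for illustration):

import torch
from transformers import TrainerCallback

class PeakMemoryCallback(TrainerCallback):
    # Hypothetical helper: report peak allocated CUDA memory after each step.
    def on_step_end(self, args, state, control, **kwargs):
        peak_gib = torch.cuda.max_memory_allocated() / 2**30
        print(f"step {state.global_step}: peak allocated {peak_gib:.2f} GiB")
        torch.cuda.reset_peak_memory_stats()

Attach it with trainer.add_callback(PeakMemoryCallback()) before trainer.train().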

Expected behavior

The training runs normally.

With transformers==4.35.2:

$ nvidia-smi 
Thu Dec 14 11:24:56 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.06              Driver Version: 545.23.06    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A10                     On  | 00000000:00:07.0 Off |                    0 |
|  0%   37C    P0             157W / 150W |  20660MiB / 23028MiB |    100%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A   1281528      C   ...al/miniconda3/envs/zero3/bin/python    20648MiB |
+---------------------------------------------------------------------------------------+
