8 changes: 8 additions & 0 deletions src/llmcompressor/args/dataset_arguments.py
@@ -247,6 +247,14 @@ class DatasetArguments(CustomDatasetArguments):
"Default is set to True."
},
)
dataloader_num_workers: int = field(
default=0,
metadata={
"help": "Number of worker processes for data loading. Set to 0 to disable "
"multiprocessing. Note: Custom data collators may not work with "
"multiprocessing. Default is 0."
},
)

def is_dataset_provided(self) -> bool:
return self.dataset is not None or self.dataset_path is not None
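For reference, the new field rides along like any other dataclass argument. A minimal sketch, assuming the class is importable as llmcompressor.args.DatasetArguments (inferred from the file path, not confirmed by this diff):

from llmcompressor.args import DatasetArguments

# Request 4 worker processes for calibration data loading; the default of 0
# keeps all loading in the main process.
args = DatasetArguments(dataset="open_platypus", dataloader_num_workers=4)
print(args.dataloader_num_workers)  # 4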
1 change: 1 addition & 0 deletions src/llmcompressor/datasets/utils.py
@@ -131,6 +131,7 @@ def format_calibration_data(
         sampler=_make_sampler(args, tokenized_dataset),
         collate_fn=_make_collate_fn(args, processor),
         pin_memory=False,
+        num_workers=args.dataloader_num_workers,
     )
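The help text's caveat about custom collators follows from how PyTorch multiprocessing works: with num_workers > 0, the collate_fn must be picklable so it can be shipped to worker processes. A standalone sketch of the failure mode (plain torch, not this repo's _make_collate_fn):

import torch
from torch.utils.data import DataLoader

def stack_collate(batch):
    # Top-level functions pickle cleanly, so they are safe with num_workers > 0.
    return {k: torch.stack([torch.tensor(item[k]) for item in batch]) for k in batch[0]}

dataset = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5, 6]}]

# Fine: the collate_fn is picklable.
loader = DataLoader(dataset, batch_size=2, collate_fn=stack_collate, num_workers=2)

# Risky: under the spawn start method (the default on macOS and Windows), a
# lambda collate_fn cannot be pickled, and the loader fails once workers start.
# loader = DataLoader(dataset, batch_size=2, collate_fn=lambda b: b, num_workers=2)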


4 changes: 4 additions & 0 deletions src/llmcompressor/entrypoints/oneshot.py
@@ -261,6 +261,7 @@ def oneshot(
     streaming: bool = False,
     overwrite_cache: bool = False,
     preprocessing_num_workers: int | None = None,
+    dataloader_num_workers: int = 0,
     min_tokens_per_module: float | None = None,
     moe_calibrate_all_experts: bool = True,
     quantization_aware_calibration: bool = True,
@@ -329,6 +330,9 @@
     :param streaming: True to stream data from a cloud dataset.
     :param overwrite_cache: Whether to overwrite the cached preprocessed datasets.
     :param preprocessing_num_workers: Number of processes for dataset preprocessing.
+    :param dataloader_num_workers: Number of worker processes for data loading. Set to 0
+        to disable multiprocessing. Note: Custom data collators may not work with
+        multiprocessing. Default is 0.
     :param min_tokens_per_module: Minimum percentage of tokens per
         module, relevant for MoE models.
     :param moe_calibrate_all_experts: Whether to calibrate all experts during MoE
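A hypothetical call showing the new argument in context (the model, dataset, and recipe values are placeholders; the other keyword arguments follow the signature above):

from llmcompressor import oneshot

oneshot(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    dataset="open_platypus",
    recipe="recipe.yaml",
    num_calibration_samples=512,
    dataloader_num_workers=4,  # parallel calibration data loading; 0 disables it
)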
8 changes: 7 additions & 1 deletion src/llmcompressor/pipelines/cache.py
@@ -69,6 +69,12 @@ def from_dataloader(
"""
Initialize a cache with data from the provided dataloader

This method iterates through all batches in the dataloader and offloads
them to the specified device. For faster cache preparation, consider:
- Increasing batch_size to reduce the number of iterations
- Using num_workers > 0 in the DataLoader for parallel loading
- Ensuring data preprocessing is done before creating the dataloader

:param dataloader: dataloader which generates values to be cached
:param model_device: device which values will be onloaded to when fetched
:param offload_device: device to offload values to
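To make the docstring's tips concrete, a hypothetical loader tuned for cache preparation (the sample data and collate function are placeholders, not this repo's API):

import torch
from torch.utils.data import DataLoader

# Preprocessing done up front: samples are already tokenized tensors.
tokenized = [{"input_ids": torch.tensor([1, 2, 3])} for _ in range(64)]

def collate(batch):
    return {"input_ids": torch.stack([b["input_ids"] for b in batch])}

loader = DataLoader(
    tokenized,
    batch_size=8,   # fewer, larger batches mean fewer cache iterations
    num_workers=2,  # fetch the next batch while the current one is offloaded
    collate_fn=collate,
)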
@@ -234,7 +240,7 @@ def _offload_value(
         match value:
             case torch.Tensor():
                 return IntermediateValue(
-                    value=value.to(device=offload_device),
+                    value=value.to(device=offload_device) if offload_device is not None else value,
                     device=(onload_device if onload_device else value.device),
                 )
             case list():

Collaborator review comment on the changed line: Is this addition necessary?
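The change to _offload_value means a None offload_device now leaves the tensor in place rather than unconditionally calling .to(). A minimal standalone sketch of that semantic (not llmcompressor's actual IntermediateValue implementation):

import torch

def offload(value: torch.Tensor, offload_device=None, onload_device=None):
    # None offload_device: keep the tensor where it is, avoiding a copy.
    stored = value.to(device=offload_device) if offload_device is not None else value
    # Remember which device to restore to when the value is fetched later.
    target = onload_device if onload_device else value.device
    return stored, target

t = torch.randn(4)
stored, target = offload(t, offload_device=None)
assert stored is t  # no device transfer when offloading is disabled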