8 changes: 8 additions & 0 deletions src/llmcompressor/args/dataset_arguments.py
@@ -247,6 +247,14 @@ class DatasetArguments(CustomDatasetArguments):
"Default is set to True."
},
)
dataloader_num_workers: int = field(
default=0,
metadata={
"help": "Number of worker processes for data loading. Set to 0 to disable "
"multiprocessing. Note: Custom data collators may not work with "
"multiprocessing. Default is 0."
},
)

def is_dataset_provided(self) -> bool:
return self.dataset is not None or self.dataset_path is not None
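
A short usage sketch of the new argument (not from this PR; the dataset value is only an illustrative placeholder, and the import path assumes the module shown above):

from llmcompressor.args.dataset_arguments import DatasetArguments

dataset_args = DatasetArguments(
    dataset="open_platypus",     # existing field, placeholder value
    dataloader_num_workers=4,    # new field: DataLoader worker processes
)

# Leaving dataloader_num_workers at its default of 0 keeps loading in the
# main process, which is the safe choice for custom (possibly unpicklable)
# data collators.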
12 changes: 12 additions & 0 deletions src/llmcompressor/datasets/utils.py
@@ -125,12 +125,24 @@ def format_calibration_data(
    tokenized_dataset: Dataset,
    processor: Processor,
) -> DataLoader:
    # Parallel data loading is only safe when the collate function is
    # picklable, since DataLoader workers run in separate processes.
    num_workers = getattr(args, "dataloader_num_workers", 0)

    # Disable multiprocessing for small datasets (worker startup overhead
    # outweighs any benefit) and for custom callable collators (which may not
    # be picklable). The built-in "truncation" and "padding" collators are
    # handled by _make_collate_fn, which returns picklable functions.
    if len(tokenized_dataset) < 100 or isinstance(args.data_collator, Callable):
        num_workers = 0

    return DataLoader(
        tokenized_dataset,
        batch_size=args.batch_size,
        sampler=_make_sampler(args, tokenized_dataset),
        collate_fn=_make_collate_fn(args, processor),
        pin_memory=False,
        num_workers=num_workers,
    )
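
A standalone sketch (not part of this PR) of why the callable-collator guard matters: with the "spawn" start method, DataLoader workers receive the collate function via pickle, and lambdas or closures typically cannot be pickled:

import pickle


def padded_collate(batch):
    # Module-level function: picklable, safe with num_workers > 0
    return batch


unsafe_collate = lambda batch: batch  # lambdas cannot be pickled by reference


def is_safe_for_workers(collate_fn) -> bool:
    try:
        pickle.dumps(collate_fn)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


print(is_safe_for_workers(padded_collate))  # True
print(is_safe_for_workers(unsafe_collate))  # False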


11 changes: 10 additions & 1 deletion src/llmcompressor/pipelines/cache.py
@@ -69,6 +69,12 @@ def from_dataloader(
"""
Initialize a cache with data from the provided dataloader

This method iterates through all batches in the dataloader and offloads
them to the specified device. For faster cache preparation, consider:
- Increasing batch_size to reduce the number of iterations
- Using num_workers > 0 in the DataLoader for parallel loading
- Ensuring data preprocessing is done before creating the dataloader

:param dataloader: dataloader which generates values to be cached
:param model_device: device which values will be onloaded to when fetched
:param offload_device: device to offload values to
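
The tips above, sketched with a plain torch DataLoader (hypothetical data; the batch_size and num_workers values are illustrative, not recommendations):

import torch
from torch.utils.data import DataLoader, TensorDataset

if __name__ == "__main__":
    # Preprocessing (here: fake "token ids") is done once, up front, so the
    # DataLoader only hands out ready-made tensors.
    dataset = TensorDataset(torch.randint(0, 32000, (512, 128)))

    dataloader = DataLoader(
        dataset,
        batch_size=8,   # fewer, larger batches -> fewer iterations to cache
        num_workers=2,  # parallel loading; needs a picklable collate_fn
    )

    for (input_ids,) in dataloader:
        pass  # from_dataloader would offload each batch here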
@@ -233,8 +239,11 @@ def _offload_value(
        kwargs = {"offload_device": offload_device, "onload_device": onload_device}
        match value:
            case torch.Tensor():
                # Skip the device transfer if the tensor is already on the
                # target device (or if no offload device was given)
                if offload_device is not None and value.device != offload_device:
                    value = value.to(device=offload_device)
                return IntermediateValue(
                    value=value,  # previously: value=value.to(device=offload_device)
                    device=(onload_device if onload_device else value.device),
                )
            case list():
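
A self-contained sketch (plain PyTorch, not the actual _offload_value implementation) of what the new guard does: a tensor already on the offload device, or with no offload device given, is returned unchanged:

import torch


def offload(value: torch.Tensor, offload_device: torch.device | None) -> torch.Tensor:
    # Only move the tensor when an offload device is given and the tensor
    # is not already there.
    if offload_device is not None and value.device != offload_device:
        value = value.to(device=offload_device)
    return value


t = torch.ones(4)                              # already on CPU
assert offload(t, torch.device("cpu")) is t    # no transfer, same object
assert offload(t, None) is t                   # no offload device: unchanged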