8 changes: 4 additions & 4 deletions nemo/collections/common/data/lhotse/cutset.py
@@ -190,7 +190,7 @@ def read_dataset_config(config) -> tuple[CutSet, bool]:
        "force_finite": config.get("force_finite", False),
        "max_open_streams": config.get("max_open_streams", None),
        "token_equivalent_duration": config.get("token_equivalent_duration", None),
-        "tarred_random_access": config.get("tarred_random_access", False),
+        "skip_missing_manifest_entries": config.get("skip_missing_manifest_entries", False),
    }
    input_cfg = config.input_cfg
    if isinstance(input_cfg, (str, Path)):
@@ -510,11 +510,11 @@ def read_nemo_manifest(config) -> tuple[CutSet, bool]:
            LazyNeMoTarredIterator(
                config.manifest_filepath,
                tar_paths=config.tarred_audio_filepaths,
-                tarred_random_access=config.tarred_random_access,
+                skip_missing_manifest_entries=config.skip_missing_manifest_entries,
                **common_kwargs,
            )
        )
-        if not config.tarred_random_access and not force_finite:
+        if not force_finite:
            cuts = cuts.repeat()
    else:
        cuts = CutSet(LazyNeMoIterator(config.manifest_filepath, **notar_kwargs, **common_kwargs))
@@ -552,7 +552,7 @@ def read_nemo_manifest(config) -> tuple[CutSet, bool]:
            nemo_iter = LazyNeMoTarredIterator(
                manifest_path=manifest_path,
                tar_paths=tar_path,
-                tarred_random_access=config.tarred_random_access,
+                skip_missing_manifest_entries=config.skip_missing_manifest_entries,
                **common_kwargs,
            )
        else:
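Taken together, these changes thread the new flag from the user config down to the tarred iterator. A minimal sketch of a config that exercises this path, assuming the option names shown in this diff (all paths and values below are placeholders):

```python
from omegaconf import OmegaConf

# Hypothetical dataloader config; skip_missing_manifest_entries is the new part.
config = OmegaConf.create(
    {
        "manifest_filepath": "manifests/train_shard_{0..9}.jsonl",
        "tarred_audio_filepaths": "audio/shard_{0..9}.tar",
        # Tolerate manifests that reference only a subset of each tar file:
        "skip_missing_manifest_entries": True,
        "force_finite": False,
    }
)
```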
51 changes: 33 additions & 18 deletions nemo/collections/common/data/lhotse/dataloader.py
@@ -33,7 +33,7 @@
    make_worker_init_fn,
)
from lhotse.dataset.dataloading import resolve_seed
-from lhotse.dataset.sampling.base import CutSampler, TimeConstraint
+from lhotse.dataset.sampling.base import CutSampler, SamplingConstraint, TimeConstraint
from lhotse.lazy import LazyFlattener
from lhotse.utils import fastcopy, fix_random_seed
from omegaconf import DictConfig, OmegaConf
@@ -44,6 +44,7 @@
    read_cutset_from_config,
)
from nemo.collections.common.data.lhotse.sampling import (
+    BucketingFilter,
    DurationFilter,
    FixedBucketBatchSizeConstraint2D,
    MultimodalFixedBucketBatchSizeConstraint2D,
@@ -76,7 +77,8 @@ class LhotseDataLoadingConfig:
    cuts_path: str | None = None
    shar_path: Any = None  # str | list[str | tuple[str, float | int]] | None = None
    # Enable this to support dataloading from JSON manifests that reference subsets of audio tar files.
-    tarred_random_access: bool = False
+    skip_missing_manifest_entries: bool = False
+    tarred_random_access: bool = False  # deprecated, replaced by: skip_missing_manifest_entries
    # 2. Batch size.
    # a. Existing NeMo options.
    batch_size: int | None = None
@@ -91,6 +93,7 @@ class LhotseDataLoadingConfig:
    bucket_duration_bins: Any = None  # list[float] | list[list[float]] | None = None
    bucket_buffer_size: int = 10000
    concurrent_bucketing: bool = True  # fetches data in a background thread
+    bucketing_2d_strict_mode: bool = True  # reduces padding by discarding significant outliers
    # d. Other Lhotse sampling options.
    shuffle_buffer_size: int | None = 10000
    drop_last: bool = False
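Since `max_tps`/`max_tpt` below now accept a list as well as a single float (see the next hunk), the 2D bucketing options can be tuned per bucket. A hedged sketch of what such a config fragment might look like; the bin values and batch sizes are illustrative only, not recommendations:

```python
# Illustrative 2D bucketing fragment; all numbers are made up.
bucketing_cfg = {
    "use_bucketing": True,
    # 2D bins: (max_duration_seconds, max_token_count) per bucket.
    "bucket_duration_bins": [[2.0, 10], [2.0, 20], [4.0, 30], [4.0, 60]],
    "bucket_batch_size": [64, 48, 32, 16],
    # Strict mode discards significant outliers to reduce padding.
    "bucketing_2d_strict_mode": True,
    # A list gives each bucket its own tokens-per-second cap (max_ratio).
    "max_tps": [5.0, 10.0, 7.5, 15.0],
}
```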
@@ -117,15 +120,15 @@ class LhotseDataLoadingConfig:
    min_duration: float | None = -1
    max_duration: float | None = float("inf")
    min_tps: int = -1  # allowed tokens per second (audio-only)
-    max_tps: float = float("inf")
+    max_tps: Any = float("inf")  # float | list[float]
    # * Text input
    min_tokens: int | None = None
    max_tokens: int | None = None
    # When true, combine context+answer lengths into a total length; otherwise report context length.
    # For 2D bucketing it's always false, as we report a tuple of (context_len, answer_len).
    measure_total_length: bool = True
    min_tpt: int = -1  # allowed tokens per token (text-only)
-    max_tpt: float = float("inf")
+    max_tpt: Any = float("inf")  # float | list[float]

    # 3. Supported existing NeMo options.
    shuffle: bool = False
Expand Down Expand Up @@ -530,7 +533,7 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No
# Select the strategy customizing Lhotse sampler behaviour.
# Provides support for dynamic batch sizes, multimodal dataloading, 2D bucketing, etc.
bucket_duration_bins = determine_bucket_duration_bins(config)
constraint = determine_sampling_constraint(bucket_duration_bins, config)
cuts, constraint = determine_sampling_constraint(cuts, bucket_duration_bins, config)

# 3. The sampler.
if config.use_bucketing:
@@ -608,13 +611,15 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=None):
    return sampler, use_iterable_dataset


-def determine_sampling_constraint(bucket_duration_bins, config):
+def determine_sampling_constraint(cuts: CutSet, bucket_duration_bins, config) -> tuple[CutSet, SamplingConstraint]:
    """
    Select an appropriate sampling strategy (constraint) for Lhotse samplers based on the configuration.
    The sampling constraint affects the batch size (static/dynamic) and bucketing behaviour (1D/2D).
    It is the appropriate customization point for introducing support for other modalities,
    as it defines how an example's sequence length is measured (audio duration, text tokens, etc.).
+
+    Some constraints apply an extra filter on ``cuts``, which is why we accept and return the ``CutSet``.

    Lhotse's default is :class:`TimeConstraint` for regular audio data; other available options are
    multimodal constraints (joint text + audio) and their 2D bucketing extensions.
    """
@@ -627,7 +632,10 @@ def determine_sampling_constraint(bucket_duration_bins, config):
                max_seq_len_buckets=bucket_duration_bins,
                batch_sizes=config.bucket_batch_size,
                token_equivalent_duration=config.token_equivalent_duration,
+                strict_2d=config.bucketing_2d_strict_mode,
+                max_ratio=config.max_tpt if isinstance(config.max_tpt, Sequence) else None,
            )
+            cuts = cuts.filter(BucketingFilter(constraint))
        else:
            constraint = MultimodalSamplingConstraint(
                token_equivalent_duration=config.token_equivalent_duration,
@@ -643,14 +651,17 @@ def determine_sampling_constraint(bucket_duration_bins, config):
            constraint = FixedBucketBatchSizeConstraint2D(
                max_seq_len_buckets=bucket_duration_bins,
                batch_sizes=config.bucket_batch_size,
+                strict_2d=config.bucketing_2d_strict_mode,
+                max_ratio=config.max_tps if isinstance(config.max_tps, Sequence) else None,
            )
+            cuts = cuts.filter(BucketingFilter(constraint))
        else:
            constraint = TimeConstraint(
                max_cuts=config.batch_size,
                max_duration=config.batch_duration,
                quadratic_duration=config.quadratic_duration,
            )
-    return constraint
+    return cuts, constraint


def determine_bucket_duration_bins(config):
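The pattern above, where the constraint both shapes batches and doubles as a filter, can be summarized with a small sketch. This is not the actual `BucketingFilter` implementation, and `fits_some_bucket` is a hypothetical helper standing in for whatever bucket-selection logic the constraint provides:

```python
# Sketch: a predicate wrapper that drops examples no bucket can accept,
# e.g. 2D outliers under strict mode. CutSet.filter() is lazy in Lhotse,
# so this composes with the infinite/repeated tarred data streams.
class BucketingFilterSketch:
    def __init__(self, constraint):
        self.constraint = constraint

    def __call__(self, example) -> bool:
        # Hypothetical helper on the constraint; returns False for outliers.
        return self.constraint.fits_some_bucket(example)

# Usage mirrors the diff: cuts = cuts.filter(BucketingFilterSketch(constraint))
```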
@@ -698,28 +709,32 @@ def make_structured_with_schema_warnings(config: DictConfig | dict) -> DictConfig:
    if not isinstance(config, DictConfig):
        config = DictConfig(config)

+    if config.get("tarred_random_access", False):
+        logging.warning(
+            "Option 'tarred_random_access' is deprecated and replaced with 'skip_missing_manifest_entries'.",
> Review comment (Member): maybe also add the version from which this option will be removed.

+        )
+        config.skip_missing_manifest_entries = True
+
+    if config.skip_missing_manifest_entries:
+        logging.warning(
+            "Note: skip_missing_manifest_entries is set to True. "
+            "If any of your manifests and tar files are mismatched, the entire tar file will be skipped without warning. "
+            "It's your responsibility to ensure data integrity with this setting."
+        )

    # Remove unsupported keys and warn about them.
    supported_keys = set(OmegaConf.to_container(default).keys())
    received_keys = set(OmegaConf.to_container(config).keys())
    unsupported_keys = received_keys - supported_keys
    if unsupported_keys:
-        warnings.warn(
-            f"The following configuration keys are no longer supported " f"and ignored: {','.join(unsupported_keys)}",
-            category=DeprecationWarning,
+        logging.warning(
+            f"The following configuration keys are ignored by Lhotse dataloader: {','.join(unsupported_keys)}",
        )
    config = OmegaConf.masked_copy(config, list(supported_keys))

    return OmegaConf.merge(default, config)


def determine_use_iterable_dataset(use_iterable_dataset: bool, config: DictConfig) -> bool:
    assert not (
        config.force_map_dataset and config.force_iterable_dataset
    ), "Conflicting options: force_map_dataset=True and force_iterable_dataset=True"
    use_iterable_dataset = (use_iterable_dataset or config.force_iterable_dataset) and not config.force_map_dataset
    return use_iterable_dataset


def tokenize(example, tokenizer):
    if isinstance(example, Cut):
        for s in example.supervisions:
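Assuming the shim above, the deprecated key should transparently enable its replacement. A quick sanity check might look like this, relying only on behavior shown in the diff and treating the exact log output as an implementation detail:

```python
from omegaconf import DictConfig

# Hypothetical check of the deprecation shim.
cfg = make_structured_with_schema_warnings(DictConfig({"tarred_random_access": True}))
assert cfg.skip_missing_manifest_entries is True  # old flag mapped to the new one
```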
42 changes: 19 additions & 23 deletions nemo/collections/common/data/lhotse/nemo_adapters.py
@@ -223,6 +223,11 @@ class LazyNeMoTarredIterator:
    This can be used for other cloud storage APIs such as S3, GCS, etc.
    The same mechanism applies to ``manifest_path``.

+    If your data has been filtered so that the JSON manifests refer to just a subset of recordings,
+    set ``skip_missing_manifest_entries`` to ``True``.
+    This will still read the tar files sequentially (very fast) and discard the audio files that
+    are not present in the corresponding manifest.
+
    The ``shard_seed`` argument is used to seed the RNG shuffling the shards.
    By default, it's ``trng``, which samples a seed number from an OS-provided TRNG (see the Python ``secrets`` module).
    The seed is resolved lazily so that every dataloading worker may sample a different one.
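Based on the constructor shown below, enabling the new behavior from user code might look like this. The paths and shard ranges are placeholders, and the `CutSet(...)` wrapping follows the usage in cutset.py above:

```python
from lhotse import CutSet
from nemo.collections.common.data.lhotse.nemo_adapters import LazyNeMoTarredIterator

# Manifests were filtered offline, so each JSONL covers a subset of its tar file.
cuts = CutSet(
    LazyNeMoTarredIterator(
        manifest_path="manifests/manifest_{0..9}.jsonl",  # placeholder shard paths
        tar_paths="audio/audio_{0..9}.tar",
        skip_missing_manifest_entries=True,  # silently drop unreferenced tar members
    )
)
```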
@@ -264,10 +269,10 @@ def __init__(
        shard_seed: int | Literal["trng", "randomized"] = "trng",
        text_field: str = "text",
        lang_field: str = "lang",
-        tarred_random_access: bool = False,
+        skip_missing_manifest_entries: bool = False,
        extra_fields: list[dict[str, str]] | None = None,
    ) -> None:
-        self.tarred_random_access = tarred_random_access
+        self.skip_missing_manifest_entries = skip_missing_manifest_entries
        self.shard_id_to_manifest: dict[int, Iterable[dict]]
        self.paths = expand_sharded_filepaths(manifest_path)
        if len(self.paths) == 1:
@@ -346,29 +351,21 @@ def _validate(self) -> None:
    def shard_ids(self) -> List[int]:
        return sorted(self.shard_id_to_manifest.keys())

-    def _iter_random_read(self, tar_path, shard_manifest, manifest_path) -> Generator[tuple[dict, bytes], None, None]:
-        with tarfile.open(fileobj=BytesIO(open_best(tar_path, mode="rb").read()), mode="r") as tar:
-            for data in shard_manifest:
+    def _iter_sequential(self, tar_path, shard_manifest, manifest_path) -> Generator[tuple[dict, bytes], None, None]:
+        with tarfile.open(fileobj=open_best(tar_path, mode="rb"), mode="r|*") as tar:
+            for tar_info in tar:
                try:
-                    tar_info = tar.getmember(data)
+                    data = shard_manifest[tar_info.name]
                    raw_audio = tar.extractfile(tar_info).read()
                    yield data, raw_audio, tar_info
                except KeyError as e:
-                    raise RuntimeError(
-                        f"Mismatched entry between JSON manifest ('{manifest_path}') and tar file ('{tar_path}'). "
-                        f"The following audio_filepath='{data['audio_filepath']}' was not found in the tar file."
-                    ) from e
-
-    def _iter_sequential(self, tar_path, shard_manifest, manifest_path) -> Generator[tuple[dict, bytes], None, None]:
-        with tarfile.open(fileobj=open_best(tar_path, mode="rb"), mode="r|*") as tar:
-            for tar_info in tar:
-                assert tar_info.name in shard_manifest, (
-                    f"Mismatched entry between JSON manifest ('{manifest_path}') and tar file ('{tar_path}'). "
-                    f"Cannot locate JSON entry for tar file '{tar_info.name}'"
-                )
-                data = shard_manifest[tar_info.name]
-                raw_audio = tar.extractfile(tar_info).read()
-                yield data, raw_audio, tar_info
+                    if self.skip_missing_manifest_entries:
+                        continue
+                    else:
+                        raise RuntimeError(
+                            f"Mismatched entry between JSON manifest ('{manifest_path}') and tar file ('{tar_path}'). "
+                            f"Cannot locate JSON entry for tar file '{tar_info.name}'"
+                        ) from e

    def __iter__(self) -> Generator[Cut, None, None]:
        shard_ids = self.shard_ids
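The sequential strategy is easy to reproduce outside NeMo. Below is a standalone sketch of the same idea with plain `tarfile`; the function name and the `shard_manifest` shape (a dict keyed by member basename) are assumptions for illustration:

```python
import tarfile

def iter_tar_matching_manifest(tar_path: str, shard_manifest: dict):
    # mode="r|*" streams the archive without seeking, which keeps reads fast
    # over pipes and object stores, unlike random access via getmember().
    with tarfile.open(tar_path, mode="r|*") as tar:
        for tar_info in tar:
            entry = shard_manifest.get(tar_info.name)
            if entry is None:
                continue  # the skip_missing_manifest_entries=True behavior
            yield entry, tar.extractfile(tar_info).read()
```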
@@ -384,7 +381,6 @@ def __iter__(self) -> Generator[Cut, None, None]:
        # They have multiple JSONL entries where audio paths end with '-sub1', '-sub2', etc. for each offset.
        offset_pattern = re.compile(r'^(?P<stem>.+)(?P<sub>-sub\d+)(?P<ext>\.\w+)?$')

-        iter_fn = self._iter_random_read if self.tarred_random_access else self._iter_sequential
        for sid in shard_ids:
            manifest_path = self.paths[sid] if len(self.paths) > 1 else self.paths[0]

@@ -398,7 +394,7 @@ def basename(d: dict) -> str:
            shard_manifest: dict[str, list[dict]] = groupby(basename, self.shard_id_to_manifest[sid])
            tar_path = self.shard_id_to_tar_path[sid]
            try:
-                for data, raw_audio, tar_info in iter_fn(tar_path, shard_manifest, manifest_path):
+                for data, raw_audio, tar_info in self._iter_sequential(tar_path, shard_manifest, manifest_path):
                    meta = soundfile.info(BytesIO(raw_audio))
                    recording = Recording(
                        id=tar_info.path,