Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
fe0e49b
dac model
kamilakesbi Jun 11, 2024
b985da9
original dac works
kamilakesbi Jun 11, 2024
47afe66
add dac model
kamilakesbi Jun 14, 2024
f42af54
dac can be instatiated
kamilakesbi Jun 14, 2024
255e479
add forward pass
kamilakesbi Jun 14, 2024
44aa197
load weights
kamilakesbi Jun 17, 2024
a1b6b2b
all weights are used
kamilakesbi Jun 17, 2024
31b1e6f
convert checkpoint script ready
kamilakesbi Jun 17, 2024
375f826
test
kamilakesbi Jun 18, 2024
21a7146
add feature extractor
kamilakesbi Jun 18, 2024
c4dce70
up
kamilakesbi Jun 18, 2024
59166dd
make style
kamilakesbi Jun 19, 2024
27811d7
apply cookicutter
kamilakesbi Jun 19, 2024
1408e3a
fix tests
kamilakesbi Jun 19, 2024
4366412
iterate on FeatureExtractor
kamilakesbi Jun 19, 2024
ace2197
nit
kamilakesbi Jun 19, 2024
a563b4f
update dac doc
kamilakesbi Jun 20, 2024
3c21b38
replace nn.Sequential with nn.ModuleList
kamilakesbi Jun 20, 2024
21072a9
nit
kamilakesbi Jun 20, 2024
95d0d18
apply review suggestions 1/2
kamilakesbi Jun 21, 2024
cae002f
Update src/transformers/models/dac/modeling_dac.py
kamilakesbi Jun 21, 2024
6b52abe
up
kamilakesbi Jun 23, 2024
af9cd69
apply review suggestions 2/2
kamilakesbi Jun 24, 2024
bf09ca8
update padding in FeatureExtractor
kamilakesbi Jun 24, 2024
54a1ec6
apply review suggestions
kamilakesbi Jun 26, 2024
3bc40c6
iterate on design and tests
kamilakesbi Jun 26, 2024
01511b7
add integration tests
kamilakesbi Jun 27, 2024
5cdf0ae
feature extractor tests
kamilakesbi Jun 27, 2024
167cb8f
make style
kamilakesbi Jun 27, 2024
a4d1261
all tests pass
kamilakesbi Jun 27, 2024
1fd2496
make style
kamilakesbi Jun 27, 2024
09ec8b5
fixup
kamilakesbi Jun 27, 2024
a5ac7c6
apply review suggestions
kamilakesbi Jul 3, 2024
284c75b
fix-copies
kamilakesbi Jul 3, 2024
7512886
apply review suggestions
kamilakesbi Jul 4, 2024
dc2e85c
apply review suggestions
kamilakesbi Jul 10, 2024
c7318d5
Update docs/source/en/model_doc/dac.md
kamilakesbi Jul 8, 2024
fdb8ced
Update docs/source/en/model_doc/dac.md
kamilakesbi Jul 8, 2024
5388663
anticipate transfer weights to descript
kamilakesbi Jul 10, 2024
fac14fd
up
kamilakesbi Jul 12, 2024
e088e0d
make style
kamilakesbi Jul 12, 2024
bfaef5e
apply review suggestions
kamilakesbi Jul 22, 2024
a473975
update slow test values
kamilakesbi Jul 23, 2024
2be0f36
update slow tests
kamilakesbi Jul 25, 2024
c13180e
update test values
kamilakesbi Jul 25, 2024
8c72cda
update with CI values
kamilakesbi Jul 25, 2024
89b7143
update with vorace values
kamilakesbi Jul 26, 2024
5b02249
update test with slice
kamilakesbi Jul 26, 2024
1671917
make style
kamilakesbi Aug 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,8 @@
title: Bark
- local: model_doc/clap
title: CLAP
- local: model_doc/dac
title: dac
- local: model_doc/encodec
title: EnCodec
- local: model_doc/hiera
Expand Down
1 change: 1 addition & 0 deletions docs/source/en/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ Flax), PyTorch, and/or TensorFlow.
| [CPM-Ant](model_doc/cpmant) | ✅ | ❌ | ❌ |
| [CTRL](model_doc/ctrl) | ✅ | ✅ | ❌ |
| [CvT](model_doc/cvt) | ✅ | ✅ | ❌ |
| [DAC](model_doc/dac) | ✅ | ❌ | ❌ |
| [Data2VecAudio](model_doc/data2vec) | ✅ | ❌ | ❌ |
| [Data2VecText](model_doc/data2vec) | ✅ | ❌ | ❌ |
| [Data2VecVision](model_doc/data2vec) | ✅ | ✅ | ❌ |
Expand Down
80 changes: 80 additions & 0 deletions docs/source/en/model_doc/dac.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# DAC

## Overview


The DAC model was proposed in [Descript Audio Codec: High-Fidelity Audio Compression with Improved RVQGAN](https://arxiv.org/abs/2306.06546) by Rithesh Kumar, Prem Seetharaman, Alejandro Luebs, Ishaan Kumar, Kundan Kumar.

The Descript Audio Codec (DAC) model is a powerful tool for compressing audio data, making it highly efficient for storage and transmission. By compressing 44.1 KHz audio into tokens at just 8kbps bandwidth, the DAC model enables high-quality audio processing while significantly reducing the data footprint. This is particularly useful in scenarios where bandwidth is limited or storage space is at a premium, such as in streaming applications, remote conferencing, and archiving large audio datasets.

The abstract from the paper is the following:

*Language models have been successfully used to model natural signals, such as images, speech, and music. A key component of these models is a high quality neural compression model that can compress high-dimensional natural signals into lower dimensional discrete tokens. To that end, we introduce a high-fidelity universal neural audio compression algorithm that achieves ~90x compression of 44.1 KHz audio into tokens at just 8kbps bandwidth. We achieve this by combining advances in high-fidelity audio generation with better vector quantization techniques from the image domain, along with improved adversarial and reconstruction losses. We compress all domains (speech, environment, music, etc.) with a single universal model, making it widely applicable to generative modeling of all audio. We compare with competing audio compression algorithms, and find our method outperforms them significantly. We provide thorough ablations for every design choice, as well as open-source code and trained model weights. We hope our work can lay the foundation for the next generation of high-fidelity audio modeling.*

This model was contributed by [Kamil Akesbi](https://huggingface.co/kamilakesbi).
The original code can be found [here](https://github.com/descriptinc/descript-audio-codec/tree/main?tab=readme-ov-file).


## Model structure

The Descript Audio Codec (DAC) model is structured into three distinct stages:

1. Encoder Model: This stage compresses the input audio, reducing its size while retaining essential information.
2. Residual Vector Quantizer (RVQ) Model: Working in tandem with the encoder, this model quantizes the latent codes of the audio, refining the compression and ensuring high-quality reconstruction.
3. Decoder Model: This final stage reconstructs the audio from its compressed form, restoring it to a state that closely resembles the original input.

## Usage example

Here is a quick example of how to encode and decode an audio using this model:

```python
>>> from datasets import load_dataset, Audio
>>> from transformers import DacModel, AutoProcessor
>>> librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

>>> model = DacModel.from_pretrained("descript/dac_16khz")
>>> processor = AutoProcessor.from_pretrained("descript/dac_16khz")
>>> librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
>>> audio_sample = librispeech_dummy[-1]["audio"]["array"]
>>> inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt")

>>> encoder_outputs = model.encode(inputs["input_values"])
>>> # Get the intermediate audio codes
>>> audio_codes = encoder_outputs.audio_codes
>>> # Reconstruct the audio from its quantized representation
>>> audio_values = model.decode(encoder_outputs.quantized_representation)
>>> # or the equivalent with a forward pass
>>> audio_values = model(inputs["input_values"]).audio_values
```

## DacConfig

[[autodoc]] DacConfig

## DacFeatureExtractor

[[autodoc]] DacFeatureExtractor
- __call__

## DacModel

[[autodoc]] DacModel
- decode
- encode
- forward
15 changes: 15 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@
"CTRLTokenizer",
],
"models.cvt": ["CvtConfig"],
"models.dac": ["DacConfig", "DacFeatureExtractor"],
"models.data2vec": [
"Data2VecAudioConfig",
"Data2VecTextConfig",
Expand Down Expand Up @@ -1757,6 +1758,12 @@
"CvtPreTrainedModel",
]
)
_import_structure["models.dac"].extend(
[
"DacModel",
"DacPreTrainedModel",
]
)
_import_structure["models.data2vec"].extend(
[
"Data2VecAudioForAudioFrameClassification",
Expand Down Expand Up @@ -5026,6 +5033,10 @@
CTRLTokenizer,
)
from .models.cvt import CvtConfig
from .models.dac import (
DacConfig,
DacFeatureExtractor,
)
from .models.data2vec import (
Data2VecAudioConfig,
Data2VecTextConfig,
Expand Down Expand Up @@ -6450,6 +6461,10 @@
CvtModel,
CvtPreTrainedModel,
)
from .models.dac import (
DacModel,
DacPreTrainedModel,
)
from .models.data2vec import (
Data2VecAudioForAudioFrameClassification,
Data2VecAudioForCTC,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
cpmant,
ctrl,
cvt,
dac,
data2vec,
dbrx,
deberta,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
("cpmant", "CpmAntConfig"),
("ctrl", "CTRLConfig"),
("cvt", "CvtConfig"),
("dac", "DacConfig"),
("data2vec-audio", "Data2VecAudioConfig"),
("data2vec-text", "Data2VecTextConfig"),
("data2vec-vision", "Data2VecVisionConfig"),
Expand Down Expand Up @@ -354,6 +355,7 @@
("cpmant", "CPM-Ant"),
("ctrl", "CTRL"),
("cvt", "CvT"),
("dac", "DAC"),
("data2vec-audio", "Data2VecAudio"),
("data2vec-text", "Data2VecText"),
("data2vec-vision", "Data2VecVision"),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/feature_extraction_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
("conditional_detr", "ConditionalDetrFeatureExtractor"),
("convnext", "ConvNextFeatureExtractor"),
("cvt", "ConvNextFeatureExtractor"),
("dac", "DacFeatureExtractor"),
("data2vec-audio", "Wav2Vec2FeatureExtractor"),
("data2vec-vision", "BeitFeatureExtractor"),
("deformable_detr", "DeformableDetrFeatureExtractor"),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
("cpmant", "CpmAntModel"),
("ctrl", "CTRLModel"),
("cvt", "CvtModel"),
("dac", "DacModel"),
("data2vec-audio", "Data2VecAudioModel"),
("data2vec-text", "Data2VecTextModel"),
("data2vec-vision", "Data2VecVisionModel"),
Expand Down
60 changes: 60 additions & 0 deletions src/transformers/models/dac/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# coding=utf-8
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)


_import_structure = {
"configuration_dac": ["DacConfig"],
"feature_extraction_dac": ["DacFeatureExtractor"],
}

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_dac"] = [
"DacModel",
"DacPreTrainedModel",
]

if TYPE_CHECKING:
from .configuration_dac import (
DacConfig,
)
from .feature_extraction_dac import DacFeatureExtractor

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_dac import (
DacModel,
DacPreTrainedModel,
)

else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
111 changes: 111 additions & 0 deletions src/transformers/models/dac/configuration_dac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# coding=utf-8
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dac model configuration"""

import math

import numpy as np

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class DacConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of an [`DacModel`]. It is used to instantiate a
Dac model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the
[descript/dac_16khz](https://huggingface.co/descript/dac_16khz) architecture.

Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.

Args:
encoder_hidden_size (`int`, *optional*, defaults to 64):
Intermediate representation dimension for the encoder.
downsampling_ratios (`List[int]`, *optional*, defaults to `[2, 4, 8, 8]`):
Ratios for downsampling in the encoder. These are used in reverse order for upsampling in the decoder.
decoder_hidden_size (`int`, *optional*, defaults to 1536):
Intermediate representation dimension for the decoder.
n_codebooks (`int`, *optional*, defaults to 9):
Number of codebooks in the VQVAE.
codebook_size (`int`, *optional*, defaults to 1024):
Number of discrete codes in each codebook.
codebook_dim (`int`, *optional*, defaults to 8):
Dimension of the codebook vectors. If not defined, uses `encoder_hidden_size`.
quantizer_dropout (`bool`, *optional*, defaults to 0):
Whether to apply dropout to the quantizer.
commitment_loss_weight (float, *optional*, defaults to 0.25):
Weight of the commitment loss term in the VQVAE loss function.
codebook_loss_weight (float, *optional*, defaults to 1.0):
Weight of the codebook loss term in the VQVAE loss function.
sampling_rate (`int`, *optional*, defaults to 16000):
The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
Example:

```python
>>> from transformers import DacModel, DacConfig

>>> # Initializing a "descript/dac_16khz" style configuration
>>> configuration = DacConfig()

>>> # Initializing a model (with random weights) from the "descript/dac_16khz" style configuration
>>> model = DacModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```"""

model_type = "dac"

def __init__(
self,
encoder_hidden_size=64,
downsampling_ratios=[2, 4, 8, 8],
decoder_hidden_size=1536,
n_codebooks=9,
codebook_size=1024,
codebook_dim=8,
quantizer_dropout=0,
commitment_loss_weight=0.25,
codebook_loss_weight=1.0,
sampling_rate=16000,
**kwargs,
):
self.encoder_hidden_size = encoder_hidden_size
self.downsampling_ratios = downsampling_ratios
self.decoder_hidden_size = decoder_hidden_size
self.upsampling_ratios = downsampling_ratios[::-1]
self.n_codebooks = n_codebooks
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.quantizer_dropout = quantizer_dropout
self.sampling_rate = sampling_rate

self.hidden_size = encoder_hidden_size * (2 ** len(downsampling_ratios))

self.hop_length = int(np.prod(downsampling_ratios))
self.commitment_loss_weight = commitment_loss_weight
self.codebook_loss_weight = codebook_loss_weight

super().__init__(**kwargs)

@property
def frame_rate(self) -> int:
hop_length = np.prod(self.upsampling_ratios)
return math.ceil(self.sampling_rate / hop_length)
Loading