15 changes: 11 additions & 4 deletions src/transformers/pipelines/audio_classification.py
@@ -91,8 +91,11 @@ class AudioClassificationPipeline(Pipeline):
     """

     def __init__(self, *args, **kwargs):
-        # Default, might be overriden by the model.config.
-        kwargs["top_k"] = kwargs.get("top_k", 5)
+        # Only set default top_k if explicitly provided
+        if "top_k" in kwargs and kwargs["top_k"] is None:
+            kwargs["top_k"] = None
+        elif "top_k" not in kwargs:
+            kwargs["top_k"] = 5
         super().__init__(*args, **kwargs)

         if self.framework != "pt":
@@ -141,12 +144,16 @@ def __call__(
         return super().__call__(inputs, **kwargs)

     def _sanitize_parameters(self, top_k=None, function_to_apply=None, **kwargs):
-        # No parameters on this pipeline right now
         postprocess_params = {}
-        if top_k is not None:
+
+        # If top_k is None, use all labels
+        if top_k is None:
+            postprocess_params["top_k"] = self.model.config.num_labels
+        else:
             if top_k > self.model.config.num_labels:
                 top_k = self.model.config.num_labels
             postprocess_params["top_k"] = top_k

         if function_to_apply is not None:
             if function_to_apply not in ["softmax", "sigmoid", "none"]:
                 raise ValueError(
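Taken together, the two hunks let a caller pass `top_k=None` to get a score for every label, keep the old default of 5 when `top_k` is not given, and still cap an over-large `top_k` at `model.config.num_labels`. A minimal usage sketch of that behavior (the checkpoint name is borrowed from the new tests below and is illustrative, not part of this diff):

```python
import numpy as np

from transformers import pipeline

# Checkpoint taken from the new tests below; any audio-classification
# checkpoint should behave the same way, only the label count differs.
clf = pipeline("audio-classification", model="superb/wav2vec2-base-superb-ks", top_k=None)

signal = np.zeros((16000,), dtype=np.float32)  # one second of silence at 16 kHz

# With top_k=None the patched pipeline returns one score per label
# instead of the previous hard-coded default of 5.
assert len(clf(signal)) == clf.model.config.num_labels

# An oversized top_k is still capped at the number of labels.
clf_capped = pipeline("audio-classification", model="superb/wav2vec2-base-superb-ks", top_k=100)
assert len(clf_capped(signal)) == clf_capped.model.config.num_labels
```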
60 changes: 60 additions & 0 deletions tests/test_audio_classification_top_k.py
Collaborator: cc @Rocketknight1 this should not need a single file on its own, this should go in the pipeline tests

Member: Will fix, my bad for not noticing it was a separate file in the review!

@@ -0,0 +1,60 @@
import unittest

import numpy as np

from transformers import pipeline
from transformers.testing_utils import require_torch


@require_torch
class AudioClassificationTopKTest(unittest.TestCase):
    def test_top_k_none_returns_all_labels(self):
        model_name = "superb/wav2vec2-base-superb-ks"  # model with more than 5 labels
        classification_pipeline = pipeline(
            "audio-classification",
            model=model_name,
            top_k=None,
        )

        # Create dummy input
        sampling_rate = 16000
        signal = np.zeros((sampling_rate,), dtype=np.float32)

        result = classification_pipeline(signal)
        num_labels = classification_pipeline.model.config.num_labels

        self.assertEqual(len(result), num_labels, "Should return all labels when top_k is None")

    def test_top_k_none_with_few_labels(self):
        model_name = "superb/hubert-base-superb-er"  # model with fewer labels
        classification_pipeline = pipeline(
            "audio-classification",
            model=model_name,
            top_k=None,
        )

        # Create dummy input
        sampling_rate = 16000
        signal = np.zeros((sampling_rate,), dtype=np.float32)

        result = classification_pipeline(signal)
        num_labels = classification_pipeline.model.config.num_labels

        self.assertEqual(len(result), num_labels, "Should handle models with fewer labels correctly")

    def test_top_k_greater_than_labels(self):
        model_name = "superb/hubert-base-superb-er"
        classification_pipeline = pipeline(
            "audio-classification",
            model=model_name,
            top_k=100,  # intentionally large number
        )

        # Create dummy input
        sampling_rate = 16000
        signal = np.zeros((sampling_rate,), dtype=np.float32)

        result = classification_pipeline(signal)
        num_labels = classification_pipeline.model.config.num_labels

        self.assertEqual(len(result), num_labels, "Should cap top_k to number of labels")