Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions tests/models/granite_speech/test_modeling_granite_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
)
from transformers.utils import (
is_datasets_available,
is_peft_available,
is_torch_available,
)

Expand Down Expand Up @@ -306,11 +307,17 @@ def test_sdpa_can_dispatch_composite_models(self):
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")

@pytest.mark.generate
@require_torch_sdpa
@slow
@unittest.skip(reason="Granite Speech doesn't support SDPA for all backbones")
def test_eager_matches_sdpa_generate(self):
pass


class GraniteSpeechForConditionalGenerationIntegrationTest(unittest.TestCase):
def setUp(self):
# TODO - use the actual model path on HF hub after release.
self.model_path = "ibm-granite/granite-speech"
self.model_path = "ibm-granite/granite-speech-3.3-2b"
self.processor = AutoProcessor.from_pretrained(self.model_path)
self.prompt = self._get_prompt(self.processor.tokenizer)

Expand Down Expand Up @@ -338,7 +345,7 @@ def _load_datasamples(self, num_samples):
return [x["array"] for x in speech_samples]

@slow
@pytest.mark.skip("Public models not yet available")
@pytest.mark.skipif(not is_peft_available(), reason="Outputs diverge without lora")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand this part very well. Could you explain the different situations — when peft is installed on the system versus when it is not — and how that causes the code in this test to behave differently and produce different output?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @ydshieh! Definitely — granite speech has a modality-specific LoRA that is only enabled when there are audio inputs. This lets users call it with audio for tasks like transcription (with the LoRA enabled), and then optionally pass raw text through a second generate call, which behaves the same as calling the underlying LLM (i.e., the LoRA is not enabled). This part of the model code will probably make it clearer:

if is_peft_available and self._hf_peft_config_loaded:

If PEFT isn't installed, the model will not load the bundled audio LoRA. This causes the integration tests to fail with wrong outputs, which is why they are skipped when PEFT isn't installed.

def test_small_model_integration_test_single(self):
model = GraniteSpeechForConditionalGeneration.from_pretrained(self.model_path).to(torch_device)
input_speech = self._load_datasamples(1)
Expand All @@ -364,9 +371,9 @@ def test_small_model_integration_test_single(self):
)

@slow
@pytest.mark.skip("Public models not yet available")
@pytest.mark.skipif(not is_peft_available(), reason="Outputs diverge without lora")
def test_small_model_integration_test_batch(self):
model = GraniteSpeechForConditionalGeneration.from_pretrained(self.model_path)
model = GraniteSpeechForConditionalGeneration.from_pretrained(self.model_path).to(torch_device)
input_speech = self._load_datasamples(2)
prompts = [self.prompt, self.prompt]

Expand All @@ -384,7 +391,7 @@ def test_small_model_integration_test_batch(self):

EXPECTED_DECODED_TEXT = [
"systemKnowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant\nusercan you transcribe the speech into a written format?\nassistantmister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
"systemKnowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant\nusercan you transcribe the speech into a written format?\nassistantnor is mister quilter's manner less interesting than his matter"
"systemKnowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant\nusercan you transcribe the speech into a written format?\nassistantnor is mister quilp's manner less interesting than his matter"
] # fmt: skip

self.assertEqual(
Expand Down
4 changes: 1 addition & 3 deletions tests/models/granite_speech/test_processor_granite_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,12 @@
from transformers import GraniteSpeechFeatureExtractor, GraniteSpeechProcessor


@pytest.skip("Public models not yet available", allow_module_level=True)
@require_torch
@require_torchaudio
class GraniteSpeechProcessorTest(unittest.TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
# TODO - use the actual model path on HF hub after release.
self.checkpoint = "ibm-granite/granite-speech"
self.checkpoint = "ibm-granite/granite-speech-3.3-8b"
processor = GraniteSpeechProcessor.from_pretrained(self.checkpoint)
processor.save_pretrained(self.tmpdirname)

Expand Down