Add Whisper language detection #1097
base: main
Changes from all commits: 242d33c, 467851a, 88bba08, fc56bdc, 7cd642c, ecdd598, db84540
@@ -90,6 +90,7 @@ import {
     TopKLogitsWarper,
     TopPLogitsWarper,
     ClassifierFreeGuidanceLogitsProcessor,
+    OnlyGoodWordsLogitsProcessor,
 } from './generation/logits_process.js';

 import {
@@ -112,7 +113,7 @@ import {
 import { RawImage } from './utils/image.js';

 import { dynamic_time_warping, max, medianFilter } from './utils/maths.js';
-import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js';
+import { AlwaysStopCriteria, EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js';
 import { LogitsSampler } from './generation/logits_sampler.js';
 import { apis } from './env.js';
@@ -1212,6 +1213,10 @@ export class PreTrainedModel extends Callable {
         processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id));
     }

+    if (generation_config.good_words_ids !== null) {
+        processors.push(new OnlyGoodWordsLogitsProcessor(generation_config.good_words_ids, generation_config.eos_token_id));
+    }
+
     if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) {
         processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id));
     }
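(Note: OnlyGoodWordsLogitsProcessor itself is not included in this excerpt. Judging from the NoBadWordsLogitsProcessor pattern it inverts, the idea is an allow-list mask: every logit whose token id is not in the allow-list is set to -Infinity, so greedy search or sampling can only emit allowed tokens. A minimal sketch of that masking step, written against a plain logits array rather than the library's Tensor type:)

    // Sketch only: keep logits for allowed token ids, mask out everything else.
    // A real implementation would follow the LogitsProcessor interface
    // from './generation/logits_process.js'.
    function maskToAllowedTokens(logits /* Float32Array */, allowedIds /* number[] */) {
        const allowed = new Set(allowedIds);
        for (let id = 0; id < logits.length; ++id) {
            if (!allowed.has(id)) {
                logits[id] = -Infinity;
            }
        }
        return logits;
    }

In _detect_language below, the allow-list is the set of language token ids, which is what constrains the first decoding step to a language token.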
@@ -3119,10 +3124,48 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
     }

+    /**
+     * Detects language by running input through the model and checking for language tokens in the output.
+     *
+     * @param {import('./models/whisper/generation_whisper.js').WhisperGenerationFunctionParameters} options
+     * @returns {Promise<number[]>} A list of language token IDs detected.
+     */
+    async _detect_language(options) {
+        const inputs = options.inputs;
+        const generation_config = options.generation_config;
+        const batch_size = inputs?.dims?.[0];
+        if (!inputs || batch_size <= 0 || inputs.size <= 0) {
+            throw new Error("Cannot detect language for empty input");
+        }
+        const start_of_transcript = generation_config.decoder_start_token_id;
+        const decoder_input_ids = full([batch_size, 1], Number(1.0)).mul_(start_of_transcript).tolist();
+        const all_lang_ids = Object.values(generation_config.lang_to_id);
+        if (!all_lang_ids || all_lang_ids.length <= 0) {
+            throw new Error("Cannot detect language without language code to token ID map for model");
+        }
+        const stopping_criteria = new StoppingCriteriaList();
+        stopping_criteria.push(new AlwaysStopCriteria());
+        const good_words_ids = [all_lang_ids];
+        const output = await this.generate({
+            ...options,
+            generation_config: {
+                ...generation_config,
+                good_words_ids,
+                num_beams: 1,
+                do_sample: false,
+            },
+            stopping_criteria,
+            decoder_input_ids,
+        });
Comment on lines +3148 to +3158

We should be able to replace this with a single forward pass (by calling …

There's a lot of user options for (and logic in) generate, and I wanted to respect them while running language detection. It was simpler to extend generate to just stop after one pass than to duplicate all of that and call forward directly. For example, suppose a user adds a logits processor that suppresses the first 10 seconds' worth of tokens, and a 15-second audio clip contains two languages with the switch at 10 seconds: language detection should then detect the second language, not the first.
||
const sane = Array.from((/**@type {Tensor}**/(output)).data).flatMap(x => Number(x)); | ||
const lang_ids = sane.filter(x => Object.values(generation_config.lang_to_id).includes(x)); | ||
return lang_ids; | ||
} | ||
|
||
/** | ||
* @param {import('./models/whisper/generation_whisper.js').WhisperGenerationFunctionParameters} options | ||
*/ | ||
_retrieve_init_tokens(generation_config) { | ||
async _retrieve_init_tokens(options) { | ||
const generation_config = options.generation_config | ||
// prefix tokens are of the form: | ||
// - Multilingual: <|startoftranscript|> <|lang_id|> <|task|> [<|notimestamps|>] | ||
// - English-only: <|startoftranscript|> [<|notimestamps|>] | ||
|
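(AlwaysStopCriteria is also not shown in this excerpt. Its role in _detect_language above is to halt generation after the very first decoding step, so the only token produced is the allow-list-constrained language token. A sketch of a criterion with that behavior, assuming the convention used by criteria such as EosTokenCriteria of returning one boolean per batch item:)

    // Sketch only: a stopping criterion that fires unconditionally, so
    // generate() performs exactly one decoding step before stopping
    // every sequence in the batch.
    class AlwaysStopCriteriaSketch {
        _call(input_ids, scores) {
            return new Array(input_ids.length).fill(true);
        }
    }

Combined with the allow-list processor, the single permitted step yields exactly one language token per batch item.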
@@ -3134,16 +3177,26 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
         let language = generation_config.language;
         const task = generation_config.task;
         if (generation_config.is_multilingual) {
+            let lang_id;
             if (!language) {
-                // TODO: Implement language detection
-                console.warn('No language specified - defaulting to English (en).');
-                language = 'en';
+                try {
+                    const lang_token_ids = await this._detect_language(options);
+                    lang_id = lang_token_ids[0];
+                    if (!lang_id) {
+                        throw new Error("No language detected");
+                    }
+                } catch (err) {
+                    console.warn("No language detected - defaulting to English (en).");
+                    language = "en";
+                }
             }

-            // Add language token
-            const language_code = whisper_language_to_code(language);
-            const language_token = `<|${language_code}|>`;
-            init_tokens.push(generation_config.lang_to_id[language_token])
+            if (language) {
+                // Add language token
+                const language_code = whisper_language_to_code(language);
+                const language_token = `<|${language_code}|>`;
+                lang_id = generation_config.lang_to_id[language_token];
+            }
+            init_tokens.push(lang_id);

             // Add task token
             // NOTE: Defaults to 'transcribe' if no task is specified
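(To make the branch above concrete: for a French clip where detection succeeds, the prefix ends up as <|startoftranscript|> <|fr|> <|transcribe|>. The token ids below match the common multilingual Whisper vocabulary, but treat them as illustrative only:)

    // Illustrative ids from the multilingual Whisper tokenizer:
    // <|startoftranscript|> = 50258, <|fr|> = 50265, <|transcribe|> = 50359
    // so after this block, init_tokens = [50258, 50265, 50359]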
@@ -3180,22 +3233,24 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
      * @param {import('./models/whisper/generation_whisper.js').WhisperGenerationFunctionParameters} options
      * @returns {Promise<ModelOutput|Tensor>} The output of the model, which can contain the generated token ids, attentions, and scores.
      */
-    async generate({
-        inputs = null,
-        generation_config = null,
-        logits_processor = null,
-        stopping_criteria = null,
-
-        // Whisper-specific options (passed to kwargs)
-        // prompt_ids = null,
-        // language = null,
-        // task = null,
-
-        ...kwargs
-    }) {
+    async generate(options) {
+        let {
+            inputs = null,
+            generation_config = null,
+            logits_processor = null,
+            //stopping_criteria = null,
+
+            // Whisper-specific options (passed to kwargs)
+            // prompt_ids = null,
+            // language = null,
+            // task = null,
+
+            ...kwargs
+        } = options;
+
         generation_config = this._prepare_generation_config(generation_config, kwargs);

-        const init_tokens = kwargs.decoder_input_ids ?? this._retrieve_init_tokens(generation_config);
+        const init_tokens = kwargs.decoder_input_ids ?? await this._retrieve_init_tokens({ ...options, generation_config });

         if (generation_config.return_timestamps) {
             logits_processor ??= new LogitsProcessorList();
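(End to end, the intended effect is that omitting language now triggers detection instead of the old warn-and-default-to-English path. A hypothetical usage sketch via the pipeline API; the model id is taken from the discussion below, and the file name is illustrative:)

    import { pipeline } from '@huggingface/transformers';

    const transcriber = await pipeline(
        'automatic-speech-recognition',
        'onnx-community/whisper-large-v3-turbo',
    );

    // No `language` option: with this patch, a single constrained decoding
    // pass picks the language token, which then becomes the <|lang_id|>
    // init token for the full transcription pass.
    const result = await transcriber('clip.wav');
    console.log(result.text);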
When testing this PR, inputs was not present in my case; instead I had input_features. I noticed the type annotation says: "(Tensor of varying shape depending on the modality, optional): The sequence used as a prompt for the generation or as model inputs to the encoder. If null the method initializes it with bos_token_id and a batch size of 1. For decoder-only models inputs should be in the format of input_ids. For encoder-decoder models inputs can represent any of input_ids, input_values, input_features, or pixel_values."
By the way, thanks for adding language detection; hope it will be merged soon :)
Sorry, I can't reproduce. And reading the typing, it sounds like input_ids / input_values / input_features should always be stored as inputs. Even if the typing is sometimes wrong, patching _detect_language() to use e.g. options?.inputs ?? options?.input_features still won't fix the generate() function that's currently in main, so it sounds like it may be worth filing a separate issue and/or PR. But if you're interested in trying an alternative build, my "develop" branch is a fork of v3.0.2 with the language detection patch applied; it works for me in a real app. Hope it helps!
I think it depends on the model used; you can probably reproduce it with https://huggingface.co/onnx-community/whisper-large-v3-turbo. I was able to fix it on my side with const inputs = options.inputs ?? options.input_features; in _detect_language.
I've already used turbo and it works fine for me, sorry! (I do get an unrelated error when using turbo instead of small in the test suite.) I guess it's up to the maintainer to decide what to do with this PR, and edits are enabled. But I don't understand why you're not also getting issues with the generate() code that's currently in main; if you are, that's worth a separate issue and PR.
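(For anyone hitting the same mismatch, the workaround discussed in this thread amounts to a one-line fallback at the top of _detect_language(); an untested sketch, since whether the kwargs arrive as inputs or input_features appears to vary by call path:)

    // Sketch of the workaround from this thread: accept either key, since
    // encoder-decoder kwargs may arrive as `input_features` rather than `inputs`.
    const inputs = options.inputs ?? options.input_features;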