Add Whisper language detection #1097

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 7 commits into main
2 changes: 2 additions & 0 deletions package-lock.json

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions package.json
@@ -65,6 +65,7 @@
"@types/node": "^22.10.1",
"@webgpu/types": "^0.1.51",
"catharsis": "github:xenova/catharsis",
"fastest-levenshtein": "^1.0.16",
"jest": "^30.0.0-alpha.6",
"jest-environment-node": "^30.0.0-alpha.6",
"jsdoc-to-markdown": "^9.1.1",
7 changes: 7 additions & 0 deletions src/generation/configuration_utils.js
@@ -197,6 +197,13 @@ export class GenerationConfig {
     */
    bad_words_ids = null;

+    /**
+     * List of token ids that are allowed to be generated.
+     * @type {number[][]}
+     * @default null
+     */
+    good_words_ids = null;
+
    /**
     * List of token ids that must be generated.
     * If given a `number[][]`, this is treated as a simple list of words that must be included, the opposite to `bad_words_ids`.
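For orientation, a hedged usage sketch of the new option (not part of the diff; `model`, `model_inputs`, and the token IDs are placeholders): like `bad_words_ids`, it travels through the `generate()` kwargs into the effective `GenerationConfig`.

```js
// Hypothetical: restrict decoding to an allow-list of token IDs.
const output = await model.generate({
    ...model_inputs,
    generation_config: {
        ...model.generation_config,
        good_words_ids: [[50259, 50266]], // e.g. two language-token IDs
    },
});
```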
34 changes: 34 additions & 0 deletions src/generation/logits_process.js
@@ -572,6 +572,40 @@ export class NoBadWordsLogitsProcessor extends LogitsProcessor {
    }
}

+export class OnlyGoodWordsLogitsProcessor extends LogitsProcessor {
+    /**
+     * Create an `OnlyGoodWordsLogitsProcessor`.
+     * @param {number[][]} good_words_ids List of lists of token ids that are allowed to be generated.
+     * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+     */
+    constructor(good_words_ids, eos_token_id) {
+        super();
+        this.good_words_ids = good_words_ids;
+        this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
+    }
+
+    /**
+     * Apply logit processor.
+     * @param {bigint[][]} input_ids The input IDs.
+     * @param {Tensor} logits The logits.
+     * @returns {Object} The processed logits.
+     */
+    _call(input_ids, logits) {
+        const good_ids = this.good_words_ids.flat();
+        // Iterate over batches of input IDs and logits
+        for (let i = 0; i < input_ids.length; ++i) {
+            const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
+            // For every ID, set its logit score to -Infinity unless it's in our list of valid token IDs
+            for (let j = 0; j < batch_logits_data.length; ++j) {
+                if (!good_ids.includes(j)) {
+                    batch_logits_data[j] = -Infinity;
+                }
+            }
+        }
+        return logits;
+    }
+}

/**
* [`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension,
* where the first half correspond to the conditional logits (predicted from the input prompt) and the second half
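To see the masking in isolation, a minimal sketch (not part of the PR), assuming only that `_call` reads `logits[i].data`; plain objects stand in for the library's `Tensor`:

```js
import { OnlyGoodWordsLogitsProcessor } from "./src/generation/logits_process.js";

// Allow only token IDs 2 and 5; eos_token_id is stored but unused by _call.
const processor = new OnlyGoodWordsLogitsProcessor([[2, 5]], 0);

const input_ids = [[1n]]; // one batch item
const logits = [{ data: new Float32Array([0.9, 0.8, 0.7, 0.6, 0.5, 0.4]) }];

processor._call(input_ids, logits);
console.log(logits[0].data);
// -> [-Infinity, -Infinity, 0.7, -Infinity, -Infinity, 0.4]
// Greedy decoding can now only ever pick ID 2 or ID 5.
```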
13 changes: 13 additions & 0 deletions src/generation/stopping_criteria.js
@@ -154,3 +154,16 @@ export class InterruptableStoppingCriteria extends StoppingCriteria {
        return new Array(input_ids.length).fill(this.interrupted);
    }
}

+/**
+ * This class can be used to always stop generation after one pass.
+ */
+export class AlwaysStopCriteria extends StoppingCriteria {
+    constructor() {
+        super();
+    }
+
+    _call(input_ids, scores) {
+        return new Array(input_ids.length).fill(true);
+    }
+}
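A short sketch of the intended wiring (mirroring how `_detect_language()` uses it in `src/models.js` below):

```js
import { AlwaysStopCriteria, StoppingCriteriaList } from "./src/generation/stopping_criteria.js";

// Force generate() to stop after a single decoding step, i.e. emit exactly
// one token per batch item.
const stopping_criteria = new StoppingCriteriaList();
stopping_criteria.push(new AlwaysStopCriteria());
// ...then: await model.generate({ ...model_inputs, stopping_criteria });
```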
101 changes: 78 additions & 23 deletions src/models.js
@@ -90,6 +90,7 @@ import {
    TopKLogitsWarper,
    TopPLogitsWarper,
    ClassifierFreeGuidanceLogitsProcessor,
+    OnlyGoodWordsLogitsProcessor,
} from './generation/logits_process.js';

import {
@@ -112,7 +113,7 @@ import {
import { RawImage } from './utils/image.js';

import { dynamic_time_warping, max, medianFilter } from './utils/maths.js';
-import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js';
+import { AlwaysStopCriteria, EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js';
import { LogitsSampler } from './generation/logits_sampler.js';
import { apis } from './env.js';

@@ -1212,6 +1213,10 @@ export class PreTrainedModel extends Callable {
            processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id));
        }

+        if (generation_config.good_words_ids !== null) {
+            processors.push(new OnlyGoodWordsLogitsProcessor(generation_config.good_words_ids, generation_config.eos_token_id));
+        }
+
        if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) {
            processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id));
        }
@@ -3119,10 +3124,48 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
}

+    /**
+     * Detects language by running input through the model and checking for language tokens in the output.
+     *
+     * @param {import('./models/whisper/generation_whisper.js').WhisperGenerationFunctionParameters} options
+     * @returns {Promise<number[]>} A list of language token IDs detected.
+     */
+    async _detect_language(options) {
+        const inputs = options.inputs;
Review thread on `const inputs = options.inputs`:

[Reviewer] When testing this PR, "inputs" was in my case not present; instead I had "input_features". I noticed the type documentation says: "(Tensor of varying shape depending on the modality, optional): The sequence used as a prompt for the generation or as model inputs to the encoder. If null the method initializes it with bos_token_id and a batch size of 1. For decoder-only models inputs should be in the format of input_ids. For encoder-decoder models inputs can represent any of input_ids, input_values, input_features, or pixel_values." By the way, thanks for adding language detection; hope it will be merged soon :)

[Author, @ae9is, Jan 13, 2025] Sorry, I can't reproduce. And reading the typing, it sounds like input_ids/input_values/input_features should always be stored as inputs. Even if the typing is sometimes wrong, patching _detect_language() to use e.g. options?.inputs ?? options?.input_features still won't fix the generate() function that is currently in main, so it sounds like this is worth filing a separate issue and/or PR. But if you're interested in trying an alternative build, my "develop" branch is a fork of v3.0.2 with the language detection patch applied, and it works for me in a real app. Hope it helps!

[Reviewer] I think it depends on the model used; you can probably reproduce it with https://huggingface.co/onnx-community/whisper-large-v3-turbo. I was able to fix it on my side with `const inputs = options.inputs ?? options.input_features;` in _detect_language.

[Author, @ae9is] I've already used turbo and it works fine for me, sorry! (I do get an unrelated error when using turbo instead of small in the test suite.) I guess it's up to the maintainer to decide what to do with this PR, and edits are enabled. But I don't understand why you're not also getting issues with the generate() code that's currently in main; if you are, that's worth a separate issue and PR.

+        const generation_config = options.generation_config;
+        const batch_size = inputs?.dims?.[0];
+        if (!inputs || batch_size <= 0 || inputs.size <= 0) {
+            throw new Error("Cannot detect language for empty input");
+        }
+        const start_of_transcript = generation_config.decoder_start_token_id;
+        const decoder_input_ids = full([batch_size, 1], Number(1.0)).mul_(start_of_transcript).tolist();
+        const all_lang_ids = Object.values(generation_config.lang_to_id);
+        if (!all_lang_ids || all_lang_ids.length <= 0) {
+            throw new Error("Cannot detect language without language code to token ID map for model");
+        }
+        const stopping_criteria = new StoppingCriteriaList();
+        stopping_criteria.push(new AlwaysStopCriteria());
+        const good_words_ids = [all_lang_ids];
+        const output = await this.generate({
+            ...options,
+            generation_config: {
+                ...generation_config,
+                good_words_ids,
+                num_beams: 1,
+                do_sample: false,
+            },
+            stopping_criteria,
+            decoder_input_ids,
+        });
Comment on lines +3148 to +3158:

[Collaborator] We should be able to replace this with a single forward pass (by calling this.forward(...)) instead of using a generation step.

[Author, @ae9is] There's a lot of user options for (and logic in) generate, and I wanted to respect them while running language detection. It was simpler to extend generate to just stop after one pass than to duplicate that logic and use forward directly. Hypothetically, say a user adds a logits processor that suppresses the first 10 seconds' worth of tokens, and a 15s audio clip is in two languages with the context switching at 10s: the language detection should then detect the second language, not the first.
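For illustration, a hedged sketch of the single-forward-pass alternative discussed above (not part of this PR; `this.forward(...)` returning `{ logits }` and the index-style `logits[0]` access mirror conventions used elsewhere in this file, but the exact call and output shape are assumptions, and `model_inputs` stands in for the encoder inputs):

```js
// Hypothetical: score language tokens from one forward pass seeded with
// <|startoftranscript|>, instead of a full generation step.
const { logits } = await this.forward({ ...model_inputs, decoder_input_ids });
const scores = logits[0].data; // assumed layout: vocab scores at the first decoding position
let best = all_lang_ids[0];
for (const id of all_lang_ids) {
    if (scores[id] > scores[best]) best = id; // argmax over language tokens only
}
```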

+        const sane = Array.from(/** @type {Tensor} */(output).data).flatMap(x => Number(x));
+        const lang_ids = sane.filter(x => Object.values(generation_config.lang_to_id).includes(x));
+        return lang_ids;
+    }
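As a usage note, a small sketch (not part of the diff) of mapping the returned token IDs back to language codes; the reverse lookup is the caller's job, since `generation_config` only defines `lang_to_id`:

```js
// Hypothetical caller: invert lang_to_id to turn token IDs into "<|ja|>"-style codes.
// Per the review thread above, some models may surface the features as
// `input_features` rather than `inputs`.
const lang_token_ids = await model._detect_language({ inputs, generation_config });
const id_to_lang = Object.fromEntries(
    Object.entries(generation_config.lang_to_id).map(([token, id]) => [id, token]),
);
console.log(id_to_lang[lang_token_ids[0]]); // e.g. "<|ja|>"
```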

    /**
     * @param {import('./models/whisper/generation_whisper.js').WhisperGenerationFunctionParameters} options
     */
-    _retrieve_init_tokens(generation_config) {
+    async _retrieve_init_tokens(options) {
+        const generation_config = options.generation_config;
        // prefix tokens are of the form:
        //  - Multilingual: <|startoftranscript|> <|lang_id|> <|task|> [<|notimestamps|>]
        //  - English-only: <|startoftranscript|> [<|notimestamps|>]
@@ -3134,16 +3177,26 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
        let language = generation_config.language;
        const task = generation_config.task;
        if (generation_config.is_multilingual) {
+            let lang_id;
            if (!language) {
-                // TODO: Implement language detection
-                console.warn('No language specified - defaulting to English (en).');
-                language = 'en';
+                try {
+                    const lang_token_ids = await this._detect_language(options);
+                    lang_id = lang_token_ids[0];
+                    if (!lang_id) {
+                        throw new Error("No language detected");
+                    }
+                } catch (err) {
+                    console.warn("No language detected - defaulting to English (en).");
+                    language = "en";
+                }
            }

-            // Add language token
-            const language_code = whisper_language_to_code(language);
-            const language_token = `<|${language_code}|>`;
-            init_tokens.push(generation_config.lang_to_id[language_token])
+            if (language) {
+                // Add language token
+                const language_code = whisper_language_to_code(language);
+                const language_token = `<|${language_code}|>`;
+                lang_id = generation_config.lang_to_id[language_token];
+            }
+            init_tokens.push(lang_id);

            // Add task token
            // NOTE: Defaults to 'transcribe' if no task is specified
@@ -3180,22 +3233,24 @@
     * @param {import('./models/whisper/generation_whisper.js').WhisperGenerationFunctionParameters} options
     * @returns {Promise<ModelOutput|Tensor>} The output of the model, which can contain the generated token ids, attentions, and scores.
     */
-    async generate({
-        inputs = null,
-        generation_config = null,
-        logits_processor = null,
-        stopping_criteria = null,
-
-        // Whisper-specific options (passed to kwargs)
-        // prompt_ids = null,
-        // language = null,
-        // task = null,
-
-        ...kwargs
-    }) {
+    async generate(options) {
+        let {
+            inputs = null,
+            generation_config = null,
+            logits_processor = null,
+            //stopping_criteria = null,
+
+            // Whisper-specific options (passed to kwargs)
+            // prompt_ids = null,
+            // language = null,
+            // task = null,
+
+            ...kwargs
+        } = options;
+
        generation_config = this._prepare_generation_config(generation_config, kwargs);

-        const init_tokens = kwargs.decoder_input_ids ?? this._retrieve_init_tokens(generation_config);
+        const init_tokens = kwargs.decoder_input_ids ?? await this._retrieve_init_tokens({ ...options, generation_config });

        if (generation_config.return_timestamps) {
            logits_processor ??= new LogitsProcessorList();
18 changes: 16 additions & 2 deletions tests/models/whisper/test_modeling_whisper.js
@@ -53,9 +53,8 @@ export default () => {
    it(
        "language unset; task unset",
        async () => {
-            // language defaults to 'en'
+            // language defaults to detect, falling back to 'en'
            // task defaults to 'transcribe'
-
            const outputs = await model.generate({
                input_features,
                max_new_tokens: 1,

@@ -66,6 +65,21 @@

        MAX_TEST_EXECUTION_TIME,
    );

+    it(
+        "language unset; task set",
+        async () => {
+            // language defaults to detect, falling back to 'en'
+            const outputs = await model.generate({
+                input_features,
+                max_new_tokens: 1,
+                task: "translate",
+            });
+
+            expect(outputs.tolist()).toEqual([[/* Prefix */ 50258n, 50259n, 50358n, 50363n, /* Generated */ 45084n]]);
+        },
+        MAX_TEST_EXECUTION_TIME,
+    );
+
    it(
        "language set; task unset",
        async () => {
33 changes: 28 additions & 5 deletions tests/pipelines.test.js
@@ -1,6 +1,6 @@
import { pipeline, cos_sim } from "../src/transformers.js";
import { init, MAX_TEST_EXECUTION_TIME } from "./init.js";
-import { collect_and_execute_pipeline_tests, compare, loadAudio } from "./test_utils.js";
+import { collect_and_execute_pipeline_tests, compare, compareString, loadAudio } from "./test_utils.js";
import { collect_and_execute_pipeline_tests, compare, compareString, loadAudio } from "./test_utils.js";

// Initialise the testing environment
init();
@@ -724,7 +724,8 @@ xdescribe("Pipelines (ignored)", () => {
                // Transcribe English
                let output = await transcriber(audioData);
                expect(output.text.length).toBeGreaterThan(50);
-                // { text: " And so my fellow Americans ask not what your country can do for you, ask what you can do for your country." }
+                const expected = " And so my fellow Americans ask not what your country can do for you, ask what you can do for your country.";
+                compareString(expected, output.text);
            }

            {
@@ -757,20 +758,42 @@
                // Transcribe French
                let output = await transcriber(audioData, { language: "french", task: "transcribe" });
                expect(output.text.length).toBeGreaterThan(20);
-                // { text: " J'adore, j'aime, je n'aime pas, je déteste." }
+                const expected = " J'adore, j'aime, je n'aime pas, je déteste.";
+                compareString(expected, output.text);
            }

            {
                // Translate French to English.
                let output = await transcriber(audioData, { language: "french", task: "translate" });
                expect(output.text.length).toBeGreaterThan(20);
-                // { text: " I love, I like, I don't like, I hate." }
+                const expected = " I love, I like, I don't like, I hate.";
+                compareString(expected, output.text);
            }
            await transcriber.dispose();
        },
        MAX_TEST_EXECUTION_TIME,
    );


+    it(
+        `${models[1]}-language-detect`,
+        async () => {
+            let transcriber = await pipeline("automatic-speech-recognition", models[1]);
+            let url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/japanese-audio.wav";
+            let audioData = await loadAudio(url);
+            {
+                // Transcribe Japanese by autodetecting the language.
+                // Note: this sample needs to be hard enough that Whisper cannot transcribe it properly
+                // with the fallback 'en' language set!
+                let output = await transcriber(audioData, { language: null, task: "transcribe" });
+                expect(output.text.length).toBeGreaterThan(20);
+                const expected = "モリナガの美味しい牛乳は濃い青色に牛乳瓶を払ったゼザインのパック牛乳である。";
+                compareString(expected, output.text, 0.8);
+            }
+            await transcriber.dispose();
+        },
+        MAX_TEST_EXECUTION_TIME,
+    );

    it(
        models[2].join(" + "),
        async () => {
18 changes: 18 additions & 0 deletions tests/test_utils.js
@@ -2,6 +2,8 @@ import fs from "fs";
import path from "path";
import { fileURLToPath } from "url";

+import { distance } from "fastest-levenshtein";

export async function loadAudio(url) {
    // NOTE: Since the Web Audio API is not available in Node.js, we will need to use the `wavefile` library to obtain the raw audio data.
    // For more information, see: https://huggingface.co/docs/transformers.js/guides/node-audio-processing
@@ -68,6 +70,22 @@ export function compare(val1, val2, tol = 0.1) {
    }
}

+/**
+ * Compare two strings, adding some tolerance for variation between model outputs.
+ *
+ * The similarity score is computed using the Levenshtein distance (n_diff) between the two strings, as a fraction of the first string's length:
+ * similarity score = 1 - n_diff / str1.length.
+ *
+ * @param {string} str1 The first string
+ * @param {string} str2 The second string
+ * @param {number} tol Tolerance score for similarity between strings, from -Infinity to 1.0 (100% match).
+ */
+export function compareString(str1, str2, tol = 0.9) {
+    const dist = distance(str1, str2);
+    // Note: `|| 1` (not `??`) so an empty first string cannot cause division by zero.
+    const score = 1 - dist / (str1.length || 1);
+    expect(score).toBeGreaterThanOrEqual(tol);
+}

const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);
const models_dir = path.join(__dirname, "models");
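A quick worked example of the tolerance (illustrative strings, not from the test suite):

```js
// One edit in a 25-character string: score = 1 - 1/25 = 0.96, so this
// passes at the default tolerance of 0.9.
compareString("ask not what your country", "ask not what your countri");
```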