diff --git a/src/models.js b/src/models.js index 976d1c000..8236f1ff7 100644 --- a/src/models.js +++ b/src/models.js @@ -1,15 +1,15 @@ /** * @file Definitions of all models available in Transformers.js. - * + * * **Example:** Load and run an `AutoModel`. - * + * * ```javascript * import { AutoModel, AutoTokenizer } from '@huggingface/transformers'; - * + * * let tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased'); * let model = await AutoModel.from_pretrained('Xenova/bert-base-uncased'); - * + * * let inputs = await tokenizer('I love transformers!'); * let { logits } = await model(inputs); * // Tensor { @@ -19,22 +19,22 @@ * // size: 183132, * // } * ``` - * + * * We also provide other `AutoModel`s (listed below), which you can use in the same way as the Python library. For example: - * + * * **Example:** Load and run an `AutoModelForSeq2SeqLM`. * ```javascript * import { AutoModelForSeq2SeqLM, AutoTokenizer } from '@huggingface/transformers'; - * + * * let tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small'); * let model = await AutoModelForSeq2SeqLM.from_pretrained('Xenova/t5-small'); - * + * * let { input_ids } = await tokenizer('translate English to German: I love transformers!'); * let outputs = await model.generate(input_ids); * let decoded = tokenizer.decode(outputs[0], { skip_special_tokens: true }); * // 'Ich liebe Transformatoren!' * ``` - * + * * @module models */ @@ -309,7 +309,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) { /** * Helper function to create multiple InferenceSession objects. - * + * * @param {string} pretrained_model_name_or_path The path to the directory containing the model file. * @param {Record} names The names of the model files to load. * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model. @@ -394,7 +394,7 @@ function validateInputs(session, inputs) { * NOTE: `inputs` must contain at least the input names of the model. * - If additional inputs are passed, they will be ignored. * - If inputs are missing, an error will be thrown. - * + * * @param {Object} session The InferenceSession object to run. * @param {Object} inputs An object that maps input names to input tensors. * @returns {Promise} A Promise that resolves to an object that maps output names to output tensors. @@ -815,7 +815,7 @@ function cumsum_masked_fill(attention_mask, start_index = 0) { /** * If the model supports providing position_ids, we create position_ids on the fly for batch generation, * by computing the cumulative sum of the attention mask along the sequence length dimension. - * + * * Equivalent to: * ```python * position_ids = attention_mask.long().cumsum(-1) - 1 @@ -1046,17 +1046,17 @@ export class PreTrainedModel extends Callable { /** * Instantiate one of the model classes of the library from a pretrained model. - * + * * The model class to instantiate is selected based on the `model_type` property of the config object * (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible) - * + * * @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either: * - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. * Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a * user or organization name, like `dbmdz/bert-base-german-cased`. 
* - A path to a *directory* containing model weights, e.g., `./my_model_directory/`. * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model. - * + * * @returns {Promise} A new instance of the `PreTrainedModel` class. */ static async from_pretrained(pretrained_model_name_or_path, { @@ -1248,7 +1248,7 @@ export class PreTrainedModel extends Callable { * This function returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] * instances used for multinomial sampling. * @param {GenerationConfig} generation_config The generation config. - * @returns {LogitsProcessorList} generation_config + * @returns {LogitsProcessorList} generation_config */ _get_logits_warper(generation_config) { @@ -1271,7 +1271,7 @@ export class PreTrainedModel extends Callable { } /** - * @param {GenerationConfig} generation_config + * @param {GenerationConfig} generation_config * @param {number} input_ids_seq_length The starting sequence length for the input ids. * @returns {LogitsProcessorList} * @private @@ -1440,9 +1440,9 @@ export class PreTrainedModel extends Callable { } /** - * - * @param {GenerationConfig} generation_config - * @param {StoppingCriteriaList} [stopping_criteria=null] + * + * @param {GenerationConfig} generation_config + * @param {StoppingCriteriaList} [stopping_criteria=null] */ _get_stopping_criteria(generation_config, stopping_criteria = null) { const criteria = new StoppingCriteriaList(); @@ -1505,7 +1505,7 @@ export class PreTrainedModel extends Callable { } /** - * + * * @param {Object} inputs * @param {bigint[][]} inputs.generated_input_ids * @param {Object} inputs.outputs @@ -1617,7 +1617,7 @@ export class PreTrainedModel extends Callable { /** * Prepares `decoder_input_ids` for generation with encoder-decoder models - * @param {*} param0 + * @param {*} param0 */ _prepare_decoder_input_ids_for_generation({ batch_size, model_input_name, model_kwargs, decoder_start_token_id, bos_token_id, generation_config }) { let { decoder_input_ids, ...model_inputs } = model_kwargs; @@ -3281,11 +3281,11 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { } /** - * - * @param {WhisperGenerationConfig} generation_config + * + * @param {WhisperGenerationConfig} generation_config */ _retrieve_init_tokens(generation_config) { - // prefix tokens are of the form: + // prefix tokens are of the form: // - Multilingual: <|startoftranscript|> <|lang_id|> <|task|> [<|notimestamps|>] // - English-only: <|startoftranscript|> [<|notimestamps|>] @@ -3901,25 +3901,25 @@ export class CLIPPreTrainedModel extends PreTrainedModel { } /** * CLIP Text and Vision Model with a projection layers on top - * + * * **Example:** Perform zero-shot image classification with a `CLIPModel`. 
- * + * * ```javascript * import { AutoTokenizer, AutoProcessor, CLIPModel, RawImage } from '@huggingface/transformers'; - * + * * // Load tokenizer, processor, and model * let tokenizer = await AutoTokenizer.from_pretrained('Xenova/clip-vit-base-patch16'); * let processor = await AutoProcessor.from_pretrained('Xenova/clip-vit-base-patch16'); * let model = await CLIPModel.from_pretrained('Xenova/clip-vit-base-patch16'); - * + * * // Run tokenization * let texts = ['a photo of a car', 'a photo of a football match'] * let text_inputs = tokenizer(texts, { padding: true, truncation: true }); - * + * * // Read image and run processor * let image = await RawImage.read('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg'); * let image_inputs = await processor(image); - * + * * // Run model with both text and pixel inputs * let output = await model({ ...text_inputs, ...image_inputs }); * // { @@ -3960,20 +3960,20 @@ export class CLIPTextModel extends CLIPPreTrainedModel { /** * CLIP Text Model with a projection layer on top (a linear layer on top of the pooled output) - * + * * **Example:** Compute text embeddings with `CLIPTextModelWithProjection`. - * + * * ```javascript * import { AutoTokenizer, CLIPTextModelWithProjection } from '@huggingface/transformers'; - * + * * // Load tokenizer and text model * const tokenizer = await AutoTokenizer.from_pretrained('Xenova/clip-vit-base-patch16'); * const text_model = await CLIPTextModelWithProjection.from_pretrained('Xenova/clip-vit-base-patch16'); - * + * * // Run tokenization * let texts = ['a photo of a car', 'a photo of a football match']; * let text_inputs = tokenizer(texts, { padding: true, truncation: true }); - * + * * // Compute embeddings * const { text_embeds } = await text_model(text_inputs); * // Tensor { @@ -4011,20 +4011,20 @@ export class CLIPVisionModel extends CLIPPreTrainedModel { /** * CLIP Vision Model with a projection layer on top (a linear layer on top of the pooled output) - * + * * **Example:** Compute vision embeddings with `CLIPVisionModelWithProjection`. - * + * * ```javascript * import { AutoProcessor, CLIPVisionModelWithProjection, RawImage} from '@huggingface/transformers'; - * + * * // Load processor and vision model * const processor = await AutoProcessor.from_pretrained('Xenova/clip-vit-base-patch16'); * const vision_model = await CLIPVisionModelWithProjection.from_pretrained('Xenova/clip-vit-base-patch16'); - * + * * // Read image and run processor * let image = await RawImage.read('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg'); * let image_inputs = await processor(image); - * + * * // Compute embeddings * const { image_embeds } = await vision_model(image_inputs); * // Tensor { @@ -4054,25 +4054,25 @@ export class SiglipPreTrainedModel extends PreTrainedModel { } /** * SigLIP Text and Vision Model with a projection layers on top - * + * * **Example:** Perform zero-shot image classification with a `SiglipModel`. 
- * + * * ```javascript * import { AutoTokenizer, AutoProcessor, SiglipModel, RawImage } from '@huggingface/transformers'; - * + * * // Load tokenizer, processor, and model * const tokenizer = await AutoTokenizer.from_pretrained('Xenova/siglip-base-patch16-224'); * const processor = await AutoProcessor.from_pretrained('Xenova/siglip-base-patch16-224'); * const model = await SiglipModel.from_pretrained('Xenova/siglip-base-patch16-224'); - * + * * // Run tokenization * const texts = ['a photo of 2 cats', 'a photo of 2 dogs']; * const text_inputs = tokenizer(texts, { padding: 'max_length', truncation: true }); - * + * * // Read image and run processor * const image = await RawImage.read('http://images.cocodataset.org/val2017/000000039769.jpg'); * const image_inputs = await processor(image); - * + * * // Run model with both text and pixel inputs * const output = await model({ ...text_inputs, ...image_inputs }); * // { @@ -4099,20 +4099,20 @@ export class SiglipModel extends SiglipPreTrainedModel { } /** * The text model from SigLIP without any head or projection on top. - * + * * **Example:** Compute text embeddings with `SiglipTextModel`. - * + * * ```javascript * import { AutoTokenizer, SiglipTextModel } from '@huggingface/transformers'; - * + * * // Load tokenizer and text model * const tokenizer = await AutoTokenizer.from_pretrained('Xenova/siglip-base-patch16-224'); * const text_model = await SiglipTextModel.from_pretrained('Xenova/siglip-base-patch16-224'); - * + * * // Run tokenization * const texts = ['a photo of 2 cats', 'a photo of 2 dogs']; * const text_inputs = tokenizer(texts, { padding: 'max_length', truncation: true }); - * + * * // Compute embeddings * const { pooler_output } = await text_model(text_inputs); * // Tensor { @@ -4136,20 +4136,20 @@ export class SiglipTextModel extends SiglipPreTrainedModel { /** * The vision model from SigLIP without any head or projection on top. - * + * * **Example:** Compute vision embeddings with `SiglipVisionModel`. - * + * * ```javascript * import { AutoProcessor, SiglipVisionModel, RawImage} from '@huggingface/transformers'; - * + * * // Load processor and vision model * const processor = await AutoProcessor.from_pretrained('Xenova/siglip-base-patch16-224'); * const vision_model = await SiglipVisionModel.from_pretrained('Xenova/siglip-base-patch16-224'); - * + * * // Read image and run processor * const image = await RawImage.read('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg'); * const image_inputs = await processor(image); - * + * * // Compute embeddings * const { pooler_output } = await vision_model(image_inputs); * // Tensor { @@ -4251,25 +4251,25 @@ export class CLIPSegModel extends CLIPSegPreTrainedModel { } /** * CLIPSeg model with a Transformer-based decoder on top for zero-shot and one-shot image segmentation. - * + * * **Example:** Perform zero-shot image segmentation with a `CLIPSegForImageSegmentation` model. 
- * + * * ```javascript * import { AutoTokenizer, AutoProcessor, CLIPSegForImageSegmentation, RawImage } from '@huggingface/transformers'; - * + * * // Load tokenizer, processor, and model * const tokenizer = await AutoTokenizer.from_pretrained('Xenova/clipseg-rd64-refined'); * const processor = await AutoProcessor.from_pretrained('Xenova/clipseg-rd64-refined'); * const model = await CLIPSegForImageSegmentation.from_pretrained('Xenova/clipseg-rd64-refined'); - * + * * // Run tokenization * const texts = ['a glass', 'something to fill', 'wood', 'a jar']; * const text_inputs = tokenizer(texts, { padding: true, truncation: true }); - * + * * // Read image and run processor * const image = await RawImage.read('https://github.com/timojl/clipseg/blob/master/example_image.jpg?raw=true'); * const image_inputs = await processor(image); - * + * * // Run model with both text and pixel inputs * const { logits } = await model({ ...text_inputs, ...image_inputs }); * // logits: Tensor { @@ -4279,7 +4279,7 @@ export class CLIPSegModel extends CLIPSegPreTrainedModel { } * // size: 495616 * // } * ``` - * + * * You can visualize the predictions as follows: * ```javascript * const preds = logits @@ -4288,7 +4288,7 @@ export class CLIPSegModel extends CLIPSegPreTrainedModel { } * .mul_(255) * .round_() * .to('uint8'); - * + * * for (let i = 0; i < preds.dims[0]; ++i) { * const img = RawImage.fromTensor(preds[i]); * img.save(`prediction_${i}.png`); @@ -4553,7 +4553,7 @@ export class Qwen2VLForConditionalGeneration extends Qwen2VLPreTrainedModel { * text height position_ids: [3, 4, 5, 6, 7] * text width position_ids: [3, 4, 5, 6, 7] * Here we calculate the text start position_ids as the max vision position_ids plus 1. - * + * * @param {Tensor} input_ids Indices of input sequence tokens in the vocabulary. Tensor of shape `(batch_size, sequence_length)`. * @param {Tensor} image_grid_thw (Optional) The temporal, height and width of feature shape of each image in LLM. Tensor of shape `(num_images, 3)`. * @param {Tensor} video_grid_thw (Optional) The temporal, height and width of feature shape of each video in LLM. Tensor of shape `(num_videos, 3)`. @@ -4939,22 +4939,22 @@ export class VitMattePreTrainedModel extends PreTrainedModel { } /** * ViTMatte framework leveraging any vision backbone e.g. for ADE20k, CityScapes. - * + * * **Example:** Perform image matting with a `VitMatteForImageMatting` model. * ```javascript * import { AutoProcessor, VitMatteForImageMatting, RawImage } from '@huggingface/transformers'; - * + * * // Load processor and model * const processor = await AutoProcessor.from_pretrained('Xenova/vitmatte-small-distinctions-646'); * const model = await VitMatteForImageMatting.from_pretrained('Xenova/vitmatte-small-distinctions-646'); - * + * * // Load image and trimap * const image = await RawImage.fromURL('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/vitmatte_image.png'); * const trimap = await RawImage.fromURL('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/vitmatte_trimap.png'); - * + * * // Prepare image + trimap for the model * const inputs = await processor(image, trimap); - * + * * // Predict alpha matte * const { alphas } = await model(inputs); * // Tensor { @@ -4964,14 +4964,14 @@ export class VitMattePreTrainedModel extends PreTrainedModel { } * // data: Float32Array(614400) [ 0.9894027709960938, 0.9970508813858032, ... 
] * // } * ``` - * + * * You can visualize the alpha matte as follows: * ```javascript * import { Tensor, cat } from '@huggingface/transformers'; - * + * * // Visualize predicted alpha matte * const imageTensor = image.toTensor(); - * + * * // Convert float (0-1) alpha matte to uint8 (0-255) * const alphaChannel = alphas * .squeeze(0) @@ -4979,10 +4979,10 @@ export class VitMattePreTrainedModel extends PreTrainedModel { } * .clamp_(0, 255) * .round_() * .to('uint8'); - * + * * // Concatenate original image with predicted alpha * const imageData = cat([imageTensor, alphaChannel], 0); - * + * * // Save output image * const outputImage = RawImage.fromTensor(imageData); * outputImage.save('output.png'); @@ -5235,25 +5235,25 @@ export class Swin2SRModel extends Swin2SRPreTrainedModel { } /** * Swin2SR Model transformer with an upsampler head on top for image super resolution and restoration. - * + * * **Example:** Super-resolution w/ `Xenova/swin2SR-classical-sr-x2-64`. - * + * * ```javascript * import { AutoProcessor, Swin2SRForImageSuperResolution, RawImage } from '@huggingface/transformers'; - * + * * // Load processor and model * const model_id = 'Xenova/swin2SR-classical-sr-x2-64'; * const processor = await AutoProcessor.from_pretrained(model_id); * const model = await Swin2SRForImageSuperResolution.from_pretrained(model_id); - * + * * // Prepare model inputs * const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/butterfly.jpg'; * const image = await RawImage.fromURL(url); * const inputs = await processor(image); - * + * * // Run model * const outputs = await model(inputs); - * + * * // Convert Tensor to RawImage * const output = outputs.reconstruction.squeeze().clamp_(0, 1).mul_(255).round_().to('uint8'); * const outputImage = RawImage.fromTensor(output); @@ -5278,32 +5278,32 @@ export class DPTModel extends DPTPreTrainedModel { } /** * DPT Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2. - * + * * **Example:** Depth estimation w/ `Xenova/dpt-hybrid-midas`. 
* ```javascript * import { DPTForDepthEstimation, AutoProcessor, RawImage, interpolate_4d } from '@huggingface/transformers'; - * + * * // Load model and processor * const model_id = 'Xenova/dpt-hybrid-midas'; * const model = await DPTForDepthEstimation.from_pretrained(model_id); * const processor = await AutoProcessor.from_pretrained(model_id); - * + * * // Load image from URL * const url = 'http://images.cocodataset.org/val2017/000000039769.jpg'; * const image = await RawImage.read(url); - * + * * // Prepare image for the model * const inputs = await processor(image); - * + * * // Run model * const { predicted_depth } = await model(inputs); - * + * * // Interpolate to original size * const prediction = (await interpolate_4d(predicted_depth.unsqueeze(1), { * size: image.size.reverse(), * mode: 'bilinear', * })).squeeze(1); - * + * * // Visualize the prediction * const min = prediction.min().item(); * const max = prediction.max().item(); @@ -5358,28 +5358,28 @@ export class GLPNModel extends GLPNPreTrainedModel { } /** * import { GLPNForDepthEstimation, AutoProcessor, RawImage, interpolate_4d } from '@huggingface/transformers'; - * + * * // Load model and processor * const model_id = 'Xenova/glpn-kitti'; * const model = await GLPNForDepthEstimation.from_pretrained(model_id); * const processor = await AutoProcessor.from_pretrained(model_id); - * + * * // Load image from URL * const url = 'http://images.cocodataset.org/val2017/000000039769.jpg'; * const image = await RawImage.read(url); - * + * * // Prepare image for the model * const inputs = await processor(image); - * + * * // Run model * const { predicted_depth } = await model(inputs); - * + * * // Interpolate to original size * const prediction = (await interpolate_4d(predicted_depth.unsqueeze(1), { * size: image.size.reverse(), * mode: 'bilinear', * })).squeeze(1); - * + * * // Visualize the prediction * const min = prediction.min().item(); * const max = prediction.max().item(); @@ -5401,56 +5401,56 @@ export class DonutSwinPreTrainedModel extends PreTrainedModel { } /** * The bare Donut Swin Model transformer outputting raw hidden-states without any specific head on top. - * + * * **Example:** Step-by-step Document Parsing. 
- * + * * ```javascript * import { AutoProcessor, AutoTokenizer, AutoModelForVision2Seq, RawImage } from '@huggingface/transformers'; - * + * * // Choose model to use * const model_id = 'Xenova/donut-base-finetuned-cord-v2'; - * + * * // Prepare image inputs * const processor = await AutoProcessor.from_pretrained(model_id); * const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/receipt.png'; * const image = await RawImage.read(url); * const image_inputs = await processor(image); - * + * * // Prepare decoder inputs * const tokenizer = await AutoTokenizer.from_pretrained(model_id); * const task_prompt = ''; * const decoder_input_ids = tokenizer(task_prompt, { * add_special_tokens: false, * }).input_ids; - * + * * // Create the model * const model = await AutoModelForVision2Seq.from_pretrained(model_id); - * + * * // Run inference * const output = await model.generate(image_inputs.pixel_values, { * decoder_input_ids, * max_length: model.config.decoder.max_position_embeddings, * }); - * + * * // Decode output * const decoded = tokenizer.batch_decode(output)[0]; * // CINNAMON SUGAR 17,000 1 x 17,000 17,000 17,000 20,000 3,000 * ``` - * + * * **Example:** Step-by-step Document Visual Question Answering (DocVQA) - * + * * ```javascript * import { AutoProcessor, AutoTokenizer, AutoModelForVision2Seq, RawImage } from '@huggingface/transformers'; - * + * * // Choose model to use * const model_id = 'Xenova/donut-base-finetuned-docvqa'; - * + * * // Prepare image inputs * const processor = await AutoProcessor.from_pretrained(model_id); * const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/invoice.png'; * const image = await RawImage.read(url); * const image_inputs = await processor(image); - * + * * // Prepare decoder inputs * const tokenizer = await AutoTokenizer.from_pretrained(model_id); * const question = 'What is the invoice number?'; @@ -5458,16 +5458,16 @@ export class DonutSwinPreTrainedModel extends PreTrainedModel { } * const decoder_input_ids = tokenizer(task_prompt, { * add_special_tokens: false, * }).input_ids; - * + * * // Create the model * const model = await AutoModelForVision2Seq.from_pretrained(model_id); - * + * * // Run inference * const output = await model.generate(image_inputs.pixel_values, { * decoder_input_ids, * max_length: model.config.decoder.max_position_embeddings, * }); - * + * * // Decode output * const decoded = tokenizer.batch_decode(output)[0]; * // What is the invoice number? us-001 @@ -5600,21 +5600,21 @@ export class SamPreTrainedModel extends PreTrainedModel { } /** * Segment Anything Model (SAM) for generating segmentation masks, given an input image * and optional 2D location and bounding boxes. - * + * * **Example:** Perform mask generation w/ `Xenova/sam-vit-base`. 
* ```javascript * import { SamModel, AutoProcessor, RawImage } from '@huggingface/transformers'; - * + * * const model = await SamModel.from_pretrained('Xenova/sam-vit-base'); * const processor = await AutoProcessor.from_pretrained('Xenova/sam-vit-base'); - * + * * const img_url = 'https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png'; * const raw_image = await RawImage.read(img_url); * const input_points = [[[450, 600]]] // 2D localization of a window - * + * * const inputs = await processor(raw_image, { input_points }); * const outputs = await model(inputs); - * + * * const masks = await processor.post_process_masks(outputs.pred_masks, inputs.original_sizes, inputs.reshaped_input_sizes); * // [ * // Tensor { @@ -5648,7 +5648,7 @@ export class SamModel extends SamPreTrainedModel { async get_image_embeddings({ pixel_values }) { // in: // - pixel_values: tensor.float32[batch_size,3,1024,1024] - // + // // out: // - image_embeddings: tensor.float32[batch_size,256,64,64] // - image_positional_embeddings: tensor.float32[batch_size,256,64,64] @@ -5745,6 +5745,167 @@ export class SamImageSegmentationOutput extends ModelOutput { ////////////////////////////////////////////////// + +////////////////////////////////////////////////// +export class U2NetPreTrainedModel extends PreTrainedModel { } + +/** + * U-2-Net is a deep learning model designed for image segmentation tasks, + * particularly for generating detailed masks. It uses an architecture of nested + * U-blocks that allows the model to capture both high-level semantic features + * and fine-grained details. + * + * **Example:** Perform mask generation with `BritishWerewolf/U-2-Net`. + * ```javascript + * import { AutoModel, AutoProcessor, RawImage } from '@huggingface/transformers'; + * + * const img_url = 'https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png'; + * const image = await RawImage.read(img_url); + * + * const processor = await AutoProcessor.from_pretrained('BritishWerewolf/U-2-Net'); + * const processed = await processor(image); + * + * const model = await AutoModel.from_pretrained('BritishWerewolf/U-2-Net', { + * dtype: 'fp32', + * }); + * + * const output = await model({ input: processed.pixel_values }); + * // { + * // mask: Tensor { + * // dims: [ 1, 320, 320 ], + * // type: 'uint8', + * // data: Uint8Array(102400) [ ... ], + * // size: 102400 + * // } + * // } + * ``` + */ +export class U2NetModel extends U2NetPreTrainedModel { + constructor(config, sessions, configs) { + super(config, sessions, configs); + } + + static async from_pretrained(pretrained_model_name_or_path, { + progress_callback = null, + config = null, + cache_dir = null, + local_files_only = false, + revision = 'main', + model_file_name = null, + subfolder = 'onnx', + device = null, + dtype = null, + use_external_data_format = null, + session_options = {}, + } = {}) { + return super.from_pretrained(pretrained_model_name_or_path, { + progress_callback, + config, + cache_dir, + local_files_only, + revision, + model_file_name, + subfolder, + device, + dtype, + use_external_data_format, + session_options, + }); + } + + /** + * @typedef {{ input?: Tensor, input_image?: Tensor, [key: string]: any }} U2NetModelInput + * @param {U2NetModelInput} model_inputs Object containing the model inputs. + * @returns {Promise<Tensor>} The output of the model. + * @throws {Error} If the model session is not found. + * @throws {Error} If the model doesn't have an input Tensor.
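+ *
+ * **Example (illustrative sketch):** use the returned mask as an alpha channel to cut the
+ * subject out of the original image. This reuses `model`, `image`, and `processed` from the
+ * class example above; the output filename is a placeholder, and the mask is simply resized
+ * back to the input resolution.
+ * ```javascript
+ * import { RawImage, cat } from '@huggingface/transformers';
+ *
+ * const { mask } = await model({ input: processed.pixel_values });
+ *
+ * // Resize the 320x320 mask back to the original image size.
+ * const maskImage = await RawImage.fromTensor(mask).resize(image.width, image.height);
+ *
+ * // Concatenate the RGB channels with the mask as a fourth (alpha) channel.
+ * const rgba = cat([image.rgb().toTensor(), maskImage.toTensor()], 0);
+ * RawImage.fromTensor(rgba).save('cutout.png');
+ * ```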
*/ + async forward(model_inputs) { + // Check that 'model' session exists + if (!this.sessions.model) { + throw new Error("Model session not found. Ensure that the 'model' session is loaded correctly."); + } + + // @ts-ignore + const input_name = this.config.input_name; + // @ts-ignore + const output_name = this.config.output_composite; + + // The provided input should either have a key of `input` or be based on + // the config `input_name` key. + if (!("input" in model_inputs) && !(input_name in model_inputs)) { + throw new Error(`Model requires an input Tensor. Ensure that a Tensor is provided with a key of \`input\` or \`${input_name}\`.`); + } + + // We assume that the input will be called `input`, but will fallback to + // the value that the model expects. + const input = { + [input_name]: model_inputs['input'] ?? model_inputs[input_name], + }; + + const outputs = await sessionRun(this.sessions.model, input); + + // Post-process the mask output. + return await this.postProcessMask(outputs[output_name]); + } + + /** + * Post-processes the U2Net model's raw output. + * @param {Tensor} mask The raw mask output from the model. + * @returns {Promise<Tensor>} Binary mask with applied thresholding. + */ + async postProcessMask(mask) { + // Squeeze to remove the batch dimension. + mask = mask.squeeze(0); + mask = mask.sigmoid(); + + // Scale from the [0, 1] range to [0, 255] pixel values. + const maxValue = 255; + mask = mask.mul(maxValue); + + // Erode, dilate, then Gaussian blur the Tensor. + mask = mask.to('uint8'); + mask = await mask.morph('OPEN', 3, 'ELLIPSE'); + mask = await RawImage.fromTensor(mask) + .gaussianBlur(5, 2) + .then(i => i.toTensor()); + + // Perform binary thresholding. + const threshold = 0.54; // value from 0-1 + mask = mask.map(value => value < (threshold * maxValue) ? 0 : maxValue); + + return mask; + } + + /** + * Runs the model with the provided inputs. + * @param {Object} model_inputs Model inputs. + * @returns {Promise<U2NetImageSegmentationOutput>} Object containing segmentation outputs. + */ + async _call(model_inputs) { + return new U2NetImageSegmentationOutput(await super._call(model_inputs)); + } +} + + +/** + * Base class for the U-2-Net model's output. + */ +export class U2NetImageSegmentationOutput extends ModelOutput { + /** + * @param {Tensor} output The output of the model. + */ + constructor(output) { + super(); + this.mask = output; + } +} +////////////////////////////////////////////////// + + + ////////////////////////////////////////////////// // MarianMT models export class MarianPreTrainedModel extends PreTrainedModel { }; @@ -5769,17 +5930,17 @@ export class Wav2Vec2PreTrainedModel extends PreTrainedModel { }; /** * The bare Wav2Vec2 Model transformer outputting raw hidden-states without any specific head on top. - * + * * **Example:** Load and run a `Wav2Vec2Model` for feature extraction. - * + * * ```javascript * import { AutoProcessor, AutoModel, read_audio } from '@huggingface/transformers'; - * + * * // Read and preprocess audio * const processor = await AutoProcessor.from_pretrained('Xenova/mms-300m'); * const audio = await read_audio('https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac', 16000); * const inputs = await processor(audio); - * + * * // Run model with inputs * const model = await AutoModel.from_pretrained('Xenova/mms-300m'); * const output = await model(inputs); @@ -5844,22 +6005,22 @@ export class PyAnnoteModel extends PyAnnotePreTrainedModel { } /** * PyAnnote Model with a frame classification head on top for tasks like Speaker Diarization.
- * + * * **Example:** Load and run a `PyAnnoteForAudioFrameClassification` for speaker diarization. - * + * * ```javascript * import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@huggingface/transformers'; - * + * * // Load model and processor * const model_id = 'onnx-community/pyannote-segmentation-3.0'; * const model = await AutoModelForAudioFrameClassification.from_pretrained(model_id); * const processor = await AutoProcessor.from_pretrained(model_id); - * + * * // Read and preprocess audio * const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/mlk.wav'; * const audio = await read_audio(url, processor.feature_extractor.config.sampling_rate); * const inputs = await processor(audio); - * + * * // Run model with inputs * const { logits } = await model(inputs); * // { @@ -5870,7 +6031,7 @@ export class PyAnnoteModel extends PyAnnotePreTrainedModel { } * // size: 5369 * // } * // } - * + * * const result = processor.post_process_speaker_diarization(logits, audio.length); * // [ * // [ @@ -5879,7 +6040,7 @@ export class PyAnnoteModel extends PyAnnotePreTrainedModel { } * // ... * // ] * // ] - * + * * // Display result * console.table(result[0], ['start', 'end', 'id', 'confidence']); * // ┌─────────┬────────────────────┬────────────────────┬────┬─────────────────────┐ @@ -6052,17 +6213,17 @@ export class HubertPreTrainedModel extends PreTrainedModel { } /** * The bare Hubert Model transformer outputting raw hidden-states without any specific head on top. - * + * * **Example:** Load and run a `HubertModel` for feature extraction. - * + * * ```javascript * import { AutoProcessor, AutoModel, read_audio } from '@huggingface/transformers'; - * + * * // Read and preprocess audio * const processor = await AutoProcessor.from_pretrained('Xenova/hubert-base-ls960'); * const audio = await read_audio('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav', 16000); * const inputs = await processor(audio); - * + * * // Load and run model with inputs * const model = await AutoModel.from_pretrained('Xenova/hubert-base-ls960'); * const output = await model(inputs); @@ -6116,17 +6277,17 @@ export class WavLMPreTrainedModel extends PreTrainedModel { }; /** * The bare WavLM Model transformer outputting raw hidden-states without any specific head on top. - * + * * **Example:** Load and run a `WavLMModel` for feature extraction. - * + * * ```javascript * import { AutoProcessor, AutoModel, read_audio } from '@huggingface/transformers'; - * + * * // Read and preprocess audio * const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base'); * const audio = await read_audio('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav', 16000); * const inputs = await processor(audio); - * + * * // Run model with inputs * const model = await AutoModel.from_pretrained('Xenova/wavlm-base'); * const output = await model(inputs); @@ -6172,17 +6333,17 @@ export class WavLMForSequenceClassification extends WavLMPreTrainedModel { /** * WavLM Model with an XVector feature extraction head on top for tasks like Speaker Verification. - * + * * **Example:** Extract speaker embeddings with `WavLMForXVector`. 
* ```javascript * import { AutoProcessor, AutoModel, read_audio } from '@huggingface/transformers'; - * + * * // Read and preprocess audio * const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base-plus-sv'); * const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav'; * const audio = await read_audio(url, 16000); * const inputs = await processor(audio); - * + * * // Run model with inputs * const model = await AutoModel.from_pretrained('Xenova/wavlm-base-plus-sv'); * const outputs = await model(inputs); @@ -6215,17 +6376,17 @@ export class WavLMForXVector extends WavLMPreTrainedModel { /** * WavLM Model with a frame classification head on top for tasks like Speaker Diarization. - * + * * **Example:** Perform speaker diarization with `WavLMForAudioFrameClassification`. * ```javascript * import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@huggingface/transformers'; - * + * * // Read and preprocess audio * const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base-plus-sd'); * const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav'; * const audio = await read_audio(url, 16000); * const inputs = await processor(audio); - * + * * // Run model with inputs * const model = await AutoModelForAudioFrameClassification.from_pretrained('Xenova/wavlm-base-plus-sd'); * const { logits } = await model(inputs); @@ -6237,7 +6398,7 @@ export class WavLMForXVector extends WavLMPreTrainedModel { * // size: 1098 * // } * // } - * + * * const labels = logits[0].sigmoid().tolist().map( * frames => frames.map(speaker => speaker > 0.5 ? 1 : 0) * ); @@ -6278,20 +6439,20 @@ export class SpeechT5Model extends SpeechT5PreTrainedModel { }; /** * SpeechT5 Model with a speech encoder and a text decoder. - * + * * **Example:** Generate speech from text with `SpeechT5ForSpeechToText`. * ```javascript * import { AutoTokenizer, AutoProcessor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, Tensor } from '@huggingface/transformers'; - * + * * // Load the tokenizer and processor * const tokenizer = await AutoTokenizer.from_pretrained('Xenova/speecht5_tts'); * const processor = await AutoProcessor.from_pretrained('Xenova/speecht5_tts'); - * + * * // Load the models * // NOTE: We use the full-precision versions as they are more accurate * const model = await SpeechT5ForTextToSpeech.from_pretrained('Xenova/speecht5_tts', { dtype: 'fp32' }); * const vocoder = await SpeechT5HifiGan.from_pretrained('Xenova/speecht5_hifigan', { dtype: 'fp32' }); - * + * * // Load speaker embeddings from URL * const speaker_embeddings_data = new Float32Array( * await (await fetch('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/speaker_embeddings.bin')).arrayBuffer() @@ -6301,10 +6462,10 @@ export class SpeechT5Model extends SpeechT5PreTrainedModel { }; * speaker_embeddings_data, * [1, speaker_embeddings_data.length] * ) - * + * * // Run tokenization * const { input_ids } = tokenizer('Hello, my dog is cute'); - * + * * // Generate waveform * const { waveform } = await model.generate_speech(input_ids, speaker_embeddings, { vocoder }); * console.log(waveform) @@ -6421,7 +6582,7 @@ export class SpeechT5ForTextToSpeech extends SpeechT5PreTrainedModel { /** * HiFi-GAN vocoder. - * + * * See [SpeechT5ForSpeechToText](./models#module_models.SpeechT5ForSpeechToText) for example usage. 
*/ export class SpeechT5HifiGan extends PreTrainedModel { @@ -6489,20 +6650,20 @@ export class ClapModel extends ClapPreTrainedModel { } /** * CLAP Text Model with a projection layer on top (a linear layer on top of the pooled output). - * + * * **Example:** Compute text embeddings with `ClapTextModelWithProjection`. - * + * * ```javascript * import { AutoTokenizer, ClapTextModelWithProjection } from '@huggingface/transformers'; - * + * * // Load tokenizer and text model * const tokenizer = await AutoTokenizer.from_pretrained('Xenova/clap-htsat-unfused'); * const text_model = await ClapTextModelWithProjection.from_pretrained('Xenova/clap-htsat-unfused'); - * + * * // Run tokenization * const texts = ['a sound of a cat', 'a sound of a dog']; * const text_inputs = tokenizer(texts, { padding: true, truncation: true }); - * + * * // Compute embeddings * const { text_embeds } = await text_model(text_inputs); * // Tensor { @@ -6526,20 +6687,20 @@ export class ClapTextModelWithProjection extends ClapPreTrainedModel { /** * CLAP Audio Model with a projection layer on top (a linear layer on top of the pooled output). - * + * * **Example:** Compute audio embeddings with `ClapAudioModelWithProjection`. - * + * * ```javascript * import { AutoProcessor, ClapAudioModelWithProjection, read_audio } from '@huggingface/transformers'; - * + * * // Load processor and audio model * const processor = await AutoProcessor.from_pretrained('Xenova/clap-htsat-unfused'); * const audio_model = await ClapAudioModelWithProjection.from_pretrained('Xenova/clap-htsat-unfused'); - * + * * // Read audio and run processor * const audio = await read_audio('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cat_meow.wav'); * const audio_inputs = await processor(audio); - * + * * // Compute embeddings * const { audio_embeds } = await audio_model(audio_inputs); * // Tensor { @@ -6569,18 +6730,18 @@ export class VitsPreTrainedModel extends PreTrainedModel { } /** * The complete VITS model, for text-to-speech synthesis. - * + * * **Example:** Generate speech from text with `VitsModel`. * ```javascript * import { AutoTokenizer, VitsModel } from '@huggingface/transformers'; - * + * * // Load the tokenizer and model * const tokenizer = await AutoTokenizer.from_pretrained('Xenova/mms-tts-eng'); * const model = await VitsModel.from_pretrained('Xenova/mms-tts-eng'); - * + * * // Run tokenization * const inputs = tokenizer('I love transformers'); - * + * * // Generate waveform * const { waveform } = await model(inputs); * // Tensor { @@ -6678,21 +6839,21 @@ export class MusicgenForCausalLM extends MusicgenPreTrainedModel { } /** * The composite MusicGen model with a text encoder, audio encoder and Musicgen decoder, * for music generation tasks with one or both of text and audio prompts. - * + * * **Example:** Generate music from text with `Xenova/musicgen-small`. 
* ```javascript * import { AutoTokenizer, MusicgenForConditionalGeneration } from '@huggingface/transformers'; - * + * * // Load tokenizer and model * const tokenizer = await AutoTokenizer.from_pretrained('Xenova/musicgen-small'); * const model = await MusicgenForConditionalGeneration.from_pretrained( * 'Xenova/musicgen-small', { dtype: 'fp32' } * ); - * + * * // Prepare text input * const prompt = '80s pop track with bassy drums and synth'; * const inputs = tokenizer(prompt); - * + * * // Generate audio * const audio_values = await model.generate({ * ...inputs, @@ -6700,11 +6861,11 @@ export class MusicgenForCausalLM extends MusicgenPreTrainedModel { } * do_sample: true, * guidance_scale: 3, * }); - * + * * // (Optional) Write the output to a WAV file * import wavefile from 'wavefile'; * import fs from 'fs'; - * + * * const wav = new wavefile.WaveFile(); * wav.fromScratch(1, model.config.audio_encoder.sampling_rate, '32f', audio_values.data); * fs.writeFileSync('musicgen_out.wav', wav.toBuffer()); @@ -7119,7 +7280,7 @@ export class PretrainedMixin { static MODEL_CLASS_MAPPINGS = null; /** - * Whether to attempt to instantiate the base class (`PretrainedModel`) if + * Whether to attempt to instantiate the base class (`PretrainedModel`) if * the model type is not found in the mapping. */ static BASE_IF_FAIL = false; @@ -7252,6 +7413,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([ ['mobilenet_v2', ['MobileNetV2Model', MobileNetV2Model]], ['mobilenet_v3', ['MobileNetV3Model', MobileNetV3Model]], ['mobilenet_v4', ['MobileNetV4Model', MobileNetV4Model]], + ['u2net', ['U2NetModel', U2NetModel]], ['maskformer', ['MaskFormerModel', MaskFormerModel]], ['mgp-str', ['MgpstrForSceneTextRecognition', MgpstrForSceneTextRecognition]], @@ -7672,7 +7834,7 @@ for (const [name, model, type] of CUSTOM_MAPPING) { /** * Helper class which is used to instantiate pretrained models with the `from_pretrained` function. * The chosen model class is determined by the type specified in the model config. - * + * * @example * let model = await AutoModel.from_pretrained('Xenova/bert-base-uncased'); */ @@ -7686,7 +7848,7 @@ export class AutoModel extends PretrainedMixin { /** * Helper class which is used to instantiate pretrained sequence classification models with the `from_pretrained` function. * The chosen model class is determined by the type specified in the model config. - * + * * @example * let model = await AutoModelForSequenceClassification.from_pretrained('Xenova/distilbert-base-uncased-finetuned-sst-2-english'); */ @@ -7697,7 +7859,7 @@ export class AutoModelForSequenceClassification extends PretrainedMixin { /** * Helper class which is used to instantiate pretrained token classification models with the `from_pretrained` function. * The chosen model class is determined by the type specified in the model config. - * + * * @example * let model = await AutoModelForTokenClassification.from_pretrained('Xenova/distilbert-base-multilingual-cased-ner-hrl'); */ @@ -7708,7 +7870,7 @@ export class AutoModelForTokenClassification extends PretrainedMixin { /** * Helper class which is used to instantiate pretrained sequence-to-sequence models with the `from_pretrained` function. * The chosen model class is determined by the type specified in the model config. 
- * + * * @example * let model = await AutoModelForSeq2SeqLM.from_pretrained('Xenova/t5-small'); */ @@ -7719,7 +7881,7 @@ export class AutoModelForSeq2SeqLM extends PretrainedMixin { /** * Helper class which is used to instantiate pretrained sequence-to-sequence speech-to-text models with the `from_pretrained` function. * The chosen model class is determined by the type specified in the model config. - * + * * @example * let model = await AutoModelForSpeechSeq2Seq.from_pretrained('openai/whisper-tiny.en'); */ @@ -7730,7 +7892,7 @@ export class AutoModelForSpeechSeq2Seq extends PretrainedMixin { /** * Helper class which is used to instantiate pretrained sequence-to-sequence text-to-spectrogram models with the `from_pretrained` function. * The chosen model class is determined by the type specified in the model config. - * + * * @example * let model = await AutoModelForTextToSpectrogram.from_pretrained('microsoft/speecht5_tts'); */ @@ -7741,7 +7903,7 @@ export class AutoModelForTextToSpectrogram extends PretrainedMixin { /** * Helper class which is used to instantiate pretrained text-to-waveform models with the `from_pretrained` function. * The chosen model class is determined by the type specified in the model config. - * + * * @example * let model = await AutoModelForTextToSpectrogram.from_pretrained('facebook/mms-tts-eng'); */ @@ -7752,7 +7914,7 @@ export class AutoModelForTextToWaveform extends PretrainedMixin { /** * Helper class which is used to instantiate pretrained causal language models with the `from_pretrained` function. * The chosen model class is determined by the type specified in the model config. - * + * * @example * let model = await AutoModelForCausalLM.from_pretrained('Xenova/gpt2'); */ @@ -7763,7 +7925,7 @@ export class AutoModelForCausalLM extends PretrainedMixin { /** * Helper class which is used to instantiate pretrained masked language models with the `from_pretrained` function. * The chosen model class is determined by the type specified in the model config. - * + * * @example * let model = await AutoModelForMaskedLM.from_pretrained('Xenova/bert-base-uncased'); */ @@ -7774,7 +7936,7 @@ export class AutoModelForMaskedLM extends PretrainedMixin { /** * Helper class which is used to instantiate pretrained question answering models with the `from_pretrained` function. * The chosen model class is determined by the type specified in the model config. - * + * * @example * let model = await AutoModelForQuestionAnswering.from_pretrained('Xenova/distilbert-base-cased-distilled-squad'); */ @@ -7785,7 +7947,7 @@ export class AutoModelForQuestionAnswering extends PretrainedMixin { /** * Helper class which is used to instantiate pretrained vision-to-sequence models with the `from_pretrained` function. * The chosen model class is determined by the type specified in the model config. - * + * * @example * let model = await AutoModelForVision2Seq.from_pretrained('Xenova/vit-gpt2-image-captioning'); */ @@ -7796,7 +7958,7 @@ export class AutoModelForVision2Seq extends PretrainedMixin { /** * Helper class which is used to instantiate pretrained image classification models with the `from_pretrained` function. * The chosen model class is determined by the type specified in the model config. 
- * + * * @example * let model = await AutoModelForImageClassification.from_pretrained('Xenova/vit-base-patch16-224'); */ @@ -7807,7 +7969,7 @@ export class AutoModelForImageClassification extends PretrainedMixin { /** * Helper class which is used to instantiate pretrained image segmentation models with the `from_pretrained` function. * The chosen model class is determined by the type specified in the model config. - * + * * @example * let model = await AutoModelForImageSegmentation.from_pretrained('Xenova/detr-resnet-50-panoptic'); */ @@ -7818,7 +7980,7 @@ export class AutoModelForImageSegmentation extends PretrainedMixin { /** * Helper class which is used to instantiate pretrained image segmentation models with the `from_pretrained` function. * The chosen model class is determined by the type specified in the model config. - * + * * @example * let model = await AutoModelForSemanticSegmentation.from_pretrained('nvidia/segformer-b3-finetuned-cityscapes-1024-1024'); */ @@ -7829,7 +7991,7 @@ export class AutoModelForSemanticSegmentation extends PretrainedMixin { /** * Helper class which is used to instantiate pretrained universal image segmentation models with the `from_pretrained` function. * The chosen model class is determined by the type specified in the model config. - * + * * @example * let model = await AutoModelForUniversalSegmentation.from_pretrained('hf-internal-testing/tiny-random-MaskFormerForInstanceSegmentation'); */ @@ -7840,7 +8002,7 @@ export class AutoModelForUniversalSegmentation extends PretrainedMixin { /** * Helper class which is used to instantiate pretrained object detection models with the `from_pretrained` function. * The chosen model class is determined by the type specified in the model config. - * + * * @example * let model = await AutoModelForObjectDetection.from_pretrained('Xenova/detr-resnet-50'); */ @@ -7856,7 +8018,7 @@ export class AutoModelForZeroShotObjectDetection extends PretrainedMixin { /** * Helper class which is used to instantiate pretrained mask generation models with the `from_pretrained` function. * The chosen model class is determined by the type specified in the model config. 
- * + * * @example * let model = await AutoModelForMaskGeneration.from_pretrained('Xenova/sam-vit-base'); */ diff --git a/src/models/image_processors.js b/src/models/image_processors.js index 95f275893..b753f8eec 100644 --- a/src/models/image_processors.js +++ b/src/models/image_processors.js @@ -34,6 +34,7 @@ export * from './segformer/image_processing_segformer.js' export * from './siglip/image_processing_siglip.js' export * from './smolvlm/image_processing_smolvlm.js' export * from './swin2sr/image_processing_swin2sr.js' +export * from './u2net/image_processing_u2net.js' export * from './vit/image_processing_vit.js' export * from './vitmatte/image_processing_vitmatte.js' export * from './vitpose/image_processing_vitpose.js' diff --git a/src/models/processors.js b/src/models/processors.js index e64273123..0b6db00cf 100644 --- a/src/models/processors.js +++ b/src/models/processors.js @@ -13,6 +13,7 @@ export * from './qwen2_vl/processing_qwen2_vl.js'; export * from './sam/processing_sam.js'; export * from './smolvlm/processing_smolvlm.js'; export * from './speecht5/processing_speecht5.js'; +export * from './u2net/processing_u2net.js'; export * from './ultravox/processing_ultravox.js'; export * from './wav2vec2/processing_wav2vec2.js'; export * from './wav2vec2_with_lm/processing_wav2vec2_with_lm.js'; diff --git a/src/models/u2net/image_processing_u2net.js b/src/models/u2net/image_processing_u2net.js new file mode 100644 index 000000000..ff39de7f9 --- /dev/null +++ b/src/models/u2net/image_processing_u2net.js @@ -0,0 +1,52 @@ +import { ImageProcessor } from "../../base/image_processors_utils.js"; + +/** + * @typedef {Object} ImageProcessorConfig A configuration object used to create an image processor. + * @property {function} [progress_callback=null] If specified, this function will be called during model construction, to provide the user with progress updates. + * @property {number[]} [image_mean] The mean values for image normalization. + * @property {number[]} [image_std] The standard deviation values for image normalization. + * @property {boolean} [do_rescale] Whether to rescale the image pixel values to the [0,1] range. + * @property {number} [rescale_factor] The factor to use for rescaling the image pixel values. + * @property {boolean} [do_normalize] Whether to normalize the image pixel values. + * @property {boolean} [do_resize] Whether to resize the image. + * @property {number} [resample] What method to use for resampling. + * @property {number|Object} [size] The size to resize the image to. + * @property {number|Object} [image_size] The size to resize the image to (same as `size`). + * @property {boolean} [do_flip_channel_order=false] Whether to flip the color channels from RGB to BGR. + * Can be overridden by the `do_flip_channel_order` parameter in the `preprocess` method. + * @property {boolean} [do_center_crop] Whether to center crop the image to the specified `crop_size`. + * Can be overridden by `do_center_crop` in the `preprocess` method. + * @property {boolean} [do_thumbnail] Whether to resize the image using thumbnail method. + * @property {boolean} [keep_aspect_ratio] If `true`, the image is resized to the largest possible size such that the aspect ratio is preserved. + * Can be overridden by `keep_aspect_ratio` in `preprocess`. + * @property {number} [ensure_multiple_of] If `do_resize` is `true`, the image is resized to a size that is a multiple of this value. + * Can be overridden by `ensure_multiple_of` in `preprocess`. 
* + * @property {number[]} [mean] The mean values for image normalization (same as `image_mean`). + * @property {number[]} [std] The standard deviation values for image normalization (same as `image_std`). + */ +export class U2NetImageProcessor extends ImageProcessor { + /** + * Pad the image by a certain amount. + * @param {Float32Array} pixelData The pixel data to pad. + * @param {number[]} imgDims The dimensions of the image (height, width, channels). + * @param {{width:number; height:number}|number|'square'} padSize The dimensions of the padded image. + * @param {Object} options The options for padding. + * @param {'constant'|'symmetric'} [options.mode='constant'] The type of padding to add. + * @param {boolean} [options.center=true] Whether to center the image. + * @param {number|number[]} [options.constant_values=0] The constant value to use for padding. + * @returns {[Float32Array, number[]]} The padded pixel data and image dimensions. + */ + pad_image(pixelData, imgDims, padSize, { + mode = 'constant', + center = true, + constant_values = 0, + } = {}) { + return super.pad_image(pixelData, imgDims, padSize, { + mode, + center, + constant_values, + }); + } +} + diff --git a/src/models/u2net/processing_u2net.js b/src/models/u2net/processing_u2net.js new file mode 100644 index 000000000..36a2a010d --- /dev/null +++ b/src/models/u2net/processing_u2net.js @@ -0,0 +1,10 @@ +import { Processor } from "../../base/processing_utils.js"; +import { U2NetImageProcessor } from "./image_processing_u2net.js"; + +export class U2NetProcessor extends Processor { + static image_processor_class = U2NetImageProcessor + + async _call(...args) { + return await this.image_processor(...args); + } +} diff --git a/src/utils/image.js b/src/utils/image.js index 40f51625e..de7541aa3 100644 --- a/src/utils/image.js +++ b/src/utils/image.js @@ -345,8 +345,8 @@ export class RawImage { /** * Resize the image to the given dimensions. This method uses the canvas API to perform the resizing. - * @param {number} width The width of the new image. `null` or `-1` will preserve the aspect ratio. - * @param {number} height The height of the new image. `null` or `-1` will preserve the aspect ratio. + * @param {number | null} width The width of the new image. `null` or `-1` will preserve the aspect ratio. + * @param {number | null} height The height of the new image. `null` or `-1` will preserve the aspect ratio. * @param {Object} options Additional options for resizing. * @param {0|1|2|3|4|5|string} [options.resample] The resampling method to use. * @returns {Promise<RawImage>} `this` to support chaining. @@ -699,7 +699,7 @@ export class RawImage { /** * Split this image into individual bands. This method returns an array of individual image bands from an image. * For example, splitting an "RGB" image creates three new images each containing a copy of one of the original bands (red, green, blue). - * + * * Inspired by PIL's `Image.split()` [function](https://pillow.readthedocs.io/en/latest/reference/Image.html#PIL.Image.Image.split). * @returns {RawImage[]} An array containing bands. */ diff --git a/src/utils/tensor.js b/src/utils/tensor.js index caecac814..33902ca54 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -237,6 +237,6 @@ export class Tensor { /** * Return a new Tensor with a callback function applied to each element. - * @param {Function} callback - The function to apply to each element. It should take three arguments: + * @param {function(number, number, DataArray): number} callback - The function to apply to each element. It should take three arguments: * the current element, its index, and the tensor's data array. * @returns {Tensor} A new Tensor with the callback function applied to each element. */ @@ -247,6 +247,6 @@ export class Tensor { /** * Apply a callback function to each element of the tensor in place. - * @param {Function} callback - The function to apply to each element. It should take three arguments: + * @param {function(number, number, DataArray): number} callback - The function to apply to each element. It should take three arguments: * the current element, its index, and the tensor's data array. * @returns {Tensor} Returns `this`. */
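For reference, a minimal sketch of the `map`/`map_` callback signature documented in the hunks above. The tensor values and the 0.5 threshold are made up for illustration; the constructor usage mirrors the `new Tensor('float32', data, dims)` call from the SpeechT5 example earlier in `models.js`.

```javascript
import { Tensor } from '@huggingface/transformers';

// The callback receives (element, index, data) and returns the new element value.
const t = new Tensor('float32', Float32Array.from([0.2, 0.6, 0.9]), [3]);

// Out-of-place: returns a new Tensor (binary thresholding, as in U2NetModel.postProcessMask).
const thresholded = t.map((value, index, data) => value < 0.5 ? 0 : 1);

// In-place: mutates the tensor and returns `this`.
t.map_(value => value * 255);
```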