diff --git a/README.md b/README.md
index 33e3191dd..4858f07ef 100644
--- a/README.md
+++ b/README.md
@@ -235,6 +235,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 
 | Task                     | ID | Description | Supported? |
 |--------------------------|----|-------------|------------|
+| [Background Removal](https://huggingface.co/tasks/image-segmentation#background-removal) | `background-removal` | Isolating the main subject of an image by removing or making the background transparent. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.BackgroundRemovalPipeline)<br>[(models)](https://huggingface.co/models?other=background-removal&library=transformers.js) |
 | [Depth Estimation](https://huggingface.co/tasks/depth-estimation) | `depth-estimation` | Predicting the depth of objects present in an image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.DepthEstimationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=depth-estimation&library=transformers.js) |
 | [Image Classification](https://huggingface.co/tasks/image-classification) | `image-classification` | Assigning a label or class to an entire image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageClassificationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-classification&library=transformers.js) |
 | [Image Segmentation](https://huggingface.co/tasks/image-segmentation) | `image-segmentation` | Divides an image into segments where each pixel is mapped to an object. This task has multiple variants such as instance segmentation, panoptic segmentation and semantic segmentation. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageSegmentationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-segmentation&library=transformers.js) |
diff --git a/docs/snippets/5_supported-tasks.snippet b/docs/snippets/5_supported-tasks.snippet
index f481fbf96..3e1a3ea32 100644
--- a/docs/snippets/5_supported-tasks.snippet
+++ b/docs/snippets/5_supported-tasks.snippet
@@ -22,6 +22,7 @@
 
 | Task                     | ID | Description | Supported? |
 |--------------------------|----|-------------|------------|
+| [Background Removal](https://huggingface.co/tasks/image-segmentation#background-removal) | `background-removal` | Isolating the main subject of an image by removing or making the background transparent. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.BackgroundRemovalPipeline)<br>[(models)](https://huggingface.co/models?other=background-removal&library=transformers.js) |
 | [Depth Estimation](https://huggingface.co/tasks/depth-estimation) | `depth-estimation` | Predicting the depth of objects present in an image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.DepthEstimationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=depth-estimation&library=transformers.js) |
 | [Image Classification](https://huggingface.co/tasks/image-classification) | `image-classification` | Assigning a label or class to an entire image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageClassificationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-classification&library=transformers.js) |
 | [Image Segmentation](https://huggingface.co/tasks/image-segmentation) | `image-segmentation` | Divides an image into segments where each pixel is mapped to an object. This task has multiple variants such as instance segmentation, panoptic segmentation and semantic segmentation. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageSegmentationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-segmentation&library=transformers.js) |
diff --git a/src/models.js b/src/models.js
index 976d1c000..83c822d57 100644
--- a/src/models.js
+++ b/src/models.js
@@ -5223,6 +5223,7 @@ export class SwinForImageClassification extends SwinPreTrainedModel {
         return new SequenceClassifierOutput(await super._call(model_inputs));
     }
 }
+export class SwinForSemanticSegmentation extends SwinPreTrainedModel { }
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////
@@ -6825,6 +6826,8 @@ export class MobileNetV1ForImageClassification extends MobileNetV1PreTrainedMode
         return new SequenceClassifierOutput(await super._call(model_inputs));
     }
 }
+
+export class MobileNetV1ForSemanticSegmentation extends MobileNetV1PreTrainedModel { }
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////
@@ -6848,6 +6851,7 @@ export class MobileNetV2ForImageClassification extends MobileNetV2PreTrainedMode
         return new SequenceClassifierOutput(await super._call(model_inputs));
     }
 }
+export class MobileNetV2ForSemanticSegmentation extends MobileNetV2PreTrainedModel { }
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////
@@ -6871,6 +6875,7 @@ export class MobileNetV3ForImageClassification extends MobileNetV3PreTrainedMode
         return new SequenceClassifierOutput(await super._call(model_inputs));
     }
 }
+export class MobileNetV3ForSemanticSegmentation extends MobileNetV3PreTrainedModel { }
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////
@@ -6894,6 +6899,7 @@ export class MobileNetV4ForImageClassification extends MobileNetV4PreTrainedMode
         return new SequenceClassifierOutput(await super._call(model_inputs));
     }
 }
+export class MobileNetV4ForSemanticSegmentation extends MobileNetV4PreTrainedModel { }
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////
@@ -7158,20 +7164,29 @@ export class PretrainedMixin {
         if (!this.MODEL_CLASS_MAPPINGS) {
             throw new Error("`MODEL_CLASS_MAPPINGS` not implemented for this type of `AutoClass`: " + this.name);
         }
-
+        const model_type = options.config.model_type;
         for (const MODEL_CLASS_MAPPING of this.MODEL_CLASS_MAPPINGS) {
-            const modelInfo = MODEL_CLASS_MAPPING.get(options.config.model_type);
+            let modelInfo = MODEL_CLASS_MAPPING.get(model_type);
             if (!modelInfo) {
-                continue; // Item not found in this mapping
+                // As a fallback, we check if model_type is specified as the exact class
+                for (const cls of MODEL_CLASS_MAPPING.values()) {
+                    if (cls[0] === model_type) {
+                        modelInfo = cls;
+                        break;
+                    }
+                }
+                if (!modelInfo) continue; // Item not found in this mapping
             }
             return await modelInfo[1].from_pretrained(pretrained_model_name_or_path, options);
         }
 
         if (this.BASE_IF_FAIL) {
-            console.warn(`Unknown model class "${options.config.model_type}", attempting to construct from base class.`);
+            if (!(CUSTOM_ARCHITECTURES.has(model_type))) {
+                console.warn(`Unknown model class "${model_type}", attempting to construct from base class.`);
+            }
             return await PreTrainedModel.from_pretrained(pretrained_model_name_or_path, options);
         } else {
-            throw Error(`Unsupported model type: ${options.config.model_type}`)
+            throw Error(`Unsupported model type: ${model_type}`)
         }
     }
 }
@@ -7524,6 +7539,12 @@ const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = new Map([
 const MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = new Map([
     ['segformer', ['SegformerForSemanticSegmentation', SegformerForSemanticSegmentation]],
     ['sapiens', ['SapiensForSemanticSegmentation', SapiensForSemanticSegmentation]],
+
+    ['swin', ['SwinForSemanticSegmentation', SwinForSemanticSegmentation]],
+    ['mobilenet_v1', ['MobileNetV1ForSemanticSegmentation', MobileNetV1ForSemanticSegmentation]],
+    ['mobilenet_v2', ['MobileNetV2ForSemanticSegmentation', MobileNetV2ForSemanticSegmentation]],
+    ['mobilenet_v3', ['MobileNetV3ForSemanticSegmentation', MobileNetV3ForSemanticSegmentation]],
+    ['mobilenet_v4', ['MobileNetV4ForSemanticSegmentation', MobileNetV4ForSemanticSegmentation]],
 ]);
 
 const MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = new Map([
@@ -7668,6 +7689,19 @@ for (const [name, model, type] of CUSTOM_MAPPING) {
     MODEL_NAME_TO_CLASS_MAPPING.set(name, model);
 }
 
+const CUSTOM_ARCHITECTURES = new Map([
+    ['modnet', MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES],
+    ['birefnet', MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES],
+    ['isnet', MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES],
+    ['ben', MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES],
+]);
+for (const [name, mapping] of CUSTOM_ARCHITECTURES.entries()) {
+    mapping.set(name, ['PreTrainedModel', PreTrainedModel])
+    MODEL_TYPE_MAPPING.set(name, MODEL_TYPES.EncoderOnly);
+    MODEL_CLASS_TO_NAME_MAPPING.set(PreTrainedModel, name);
+    MODEL_NAME_TO_CLASS_MAPPING.set(name, PreTrainedModel);
+}
+
 /**
  * Helper class which is used to instantiate pretrained models with the `from_pretrained` function.
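For context, a minimal sketch of what the `CUSTOM_ARCHITECTURES` registration above enables: custom background-removal architectures such as `modnet` now resolve to the base `PreTrainedModel` without the "Unknown model class" warning. The tensor names `input`/`output` below are assumptions about the `Xenova/modnet` ONNX export, not something this diff guarantees; the pipeline code below discovers the input name dynamically.

```javascript
import { AutoModel, AutoProcessor, RawImage } from '@huggingface/transformers';

// `modnet` is listed in CUSTOM_ARCHITECTURES, so AutoModel silently falls
// back to the base PreTrainedModel instead of warning about an unknown class.
const model = await AutoModel.from_pretrained('Xenova/modnet');
const processor = await AutoProcessor.from_pretrained('Xenova/modnet');

const image = await RawImage.fromURL('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/portrait-of-woman_small.jpg');
const { pixel_values } = await processor(image);

// NOTE: 'input'/'output' are assumed tensor names for this particular export.
const { output } = await model({ input: pixel_values });
```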
diff --git a/src/pipelines.js b/src/pipelines.js
index 649b00a49..bf17f7017 100644
--- a/src/pipelines.js
+++ b/src/pipelines.js
@@ -2095,7 +2095,7 @@ export class ImageClassificationPipeline extends (/** @type {new (options: Image
 
 /**
  * @typedef {Object} ImageSegmentationPipelineOutput
- * @property {string} label The label of the segment.
+ * @property {string|null} label The label of the segment.
  * @property {number|null} score The score of the segment.
  * @property {RawImage} mask The mask of the segment.
  */
@@ -2165,14 +2165,30 @@ export class ImageSegmentationPipeline extends (/** @type {new (options: ImagePi
         const preparedImages = await prepareImages(images);
         const imageSizes = preparedImages.map(x => [x.height, x.width]);
 
-        const { pixel_values, pixel_mask } = await this.processor(preparedImages);
-        const output = await this.model({ pixel_values, pixel_mask });
+        const inputs = await this.processor(preparedImages);
+
+        const { inputNames, outputNames } = this.model.sessions['model'];
+        if (!inputNames.includes('pixel_values')) {
+            if (inputNames.length !== 1) {
+                throw Error(`Expected a single input name, but got ${inputNames.length} inputs: ${inputNames}.`);
+            }
+
+            const newName = inputNames[0];
+            if (newName in inputs) {
+                throw Error(`Input name ${newName} already exists in the inputs.`);
+            }
+            // To ensure compatibility with certain background-removal models,
+            // we may need to perform a mapping of input to output names
+            inputs[newName] = inputs.pixel_values;
+        }
+
+        const output = await this.model(inputs);
 
         let fn = null;
         if (subtask !== null) {
             fn = this.subtasks_mapping[subtask];
-        } else {
-            for (let [task, func] of Object.entries(this.subtasks_mapping)) {
+        } else if (this.processor.image_processor) {
+            for (const [task, func] of Object.entries(this.subtasks_mapping)) {
                 if (func in this.processor.image_processor) {
                     fn = this.processor.image_processor[func].bind(this.processor.image_processor);
                     subtask = task;
@@ -2186,7 +2202,23 @@ export class ImageSegmentationPipeline extends (/** @type {new (options: ImagePi
 
         /** @type {ImageSegmentationPipelineOutput[]} */
         const annotation = [];
-        if (subtask === 'panoptic' || subtask === 'instance') {
+        if (!subtask) {
+            // Perform standard image segmentation
+            const result = output[outputNames[0]];
+            for (let i = 0; i < imageSizes.length; ++i) {
+                const size = imageSizes[i];
+                const item = result[i];
+                if (item.data.some(x => x < 0 || x > 1)) {
+                    item.sigmoid_();
+                }
+                const mask = await RawImage.fromTensor(item.mul_(255).to('uint8')).resize(size[1], size[0]);
+                annotation.push({
+                    label: null,
+                    score: null,
+                    mask
+                });
+            }
+        } else if (subtask === 'panoptic' || subtask === 'instance') {
             const processed = fn(
                 output,
                 threshold,
@@ -2242,6 +2274,63 @@ export class ImageSegmentationPipeline extends (/** @type {new (options: ImagePi
     }
 }
 
+
+/**
+ * @typedef {Object} BackgroundRemovalPipelineOptions Parameters specific to background removal pipelines.
+ *
+ * @callback BackgroundRemovalPipelineCallback Segment the input images.
+ * @param {ImagePipelineInputs} images The input images.
+ * @param {BackgroundRemovalPipelineOptions} [options] The options to use for background removal.
+ * @returns {Promise<RawImage[]>} The images with the background removed.
+ *
+ * @typedef {ImagePipelineConstructorArgs & BackgroundRemovalPipelineCallback & Disposable} BackgroundRemovalPipelineType
+ */
+
+/**
+ * Background removal pipeline using certain `AutoModelForXXXSegmentation`.
+ * This pipeline removes the backgrounds of images.
+ *
+ * **Example:** Perform background removal with `Xenova/modnet`.
+ * ```javascript
+ * const segmenter = await pipeline('background-removal', 'Xenova/modnet');
+ * const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/portrait-of-woman_small.jpg';
+ * const output = await segmenter(url);
+ * // [
+ * //   RawImage { data: Uint8ClampedArray(648000) [ ... ], width: 360, height: 450, channels: 4 }
+ * // ]
+ * ```
+ */
+export class BackgroundRemovalPipeline extends (/** @type {new (options: ImagePipelineConstructorArgs) => ImageSegmentationPipelineType} */ (ImageSegmentationPipeline)) {
+    /**
+     * Create a new BackgroundRemovalPipeline.
+     * @param {ImagePipelineConstructorArgs} options An object used to instantiate the pipeline.
+     */
+    constructor(options) {
+        super(options);
+    }
+
+    /** @type {BackgroundRemovalPipelineCallback} */
+    async _call(images, options = {}) {
+        const isBatched = Array.isArray(images);
+
+        if (isBatched && images.length !== 1) {
+            throw Error("Background removal pipeline currently only supports a batch size of 1.");
+        }
+
+        const preparedImages = await prepareImages(images);
+
+        // @ts-expect-error TS2339
+        const masks = await super._call(images, options);
+
+        const result = preparedImages.map((img, i) => {
+            const cloned = img.clone();
+            cloned.putAlpha(masks[i].mask);
+            return cloned;
+        });
+
+        return result;
+    }
+}
 
 /**
  * @typedef {Object} ZeroShotImageClassificationOutput
  * @property {string} label The label identified by the model. It is one of the suggested `candidate_label`.
@@ -2554,7 +2643,7 @@ export class ZeroShotObjectDetectionPipeline extends (/** @type {new (options: T
             const output = await this.model({ ...text_inputs, pixel_values });
 
             let result;
-            if('post_process_grounded_object_detection' in this.processor) {
+            if ('post_process_grounded_object_detection' in this.processor) {
                 // @ts-ignore
                 const processed = this.processor.post_process_grounded_object_detection(
                     output,
@@ -3134,6 +3223,16 @@ const SUPPORTED_TASKS = Object.freeze({
         },
         "type": "multimodal",
     },
+    "background-removal": {
+        // no tokenizer
+        "pipeline": BackgroundRemovalPipeline,
+        "model": [AutoModelForImageSegmentation, AutoModelForSemanticSegmentation, AutoModelForUniversalSegmentation],
+        "processor": AutoProcessor,
+        "default": {
+            "model": "Xenova/modnet",
+        },
+        "type": "image",
+    },
     "zero-shot-image-classification": {
         "tokenizer": AutoTokenizer,
diff --git a/tests/asset_cache.js b/tests/asset_cache.js
index 9d1182014..1dfcfae41 100644
--- a/tests/asset_cache.js
+++ b/tests/asset_cache.js
@@ -24,6 +24,7 @@ const TEST_IMAGES = Object.freeze({
   book_cover: BASE_URL + "book-cover.png",
   corgi: BASE_URL + "corgi.jpg",
   man_on_car: BASE_URL + "young-man-standing-and-leaning-on-car.jpg",
+  portrait_of_woman: BASE_URL + "portrait-of-woman_small.jpg",
 });
 
 const TEST_AUDIOS = {
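Taken together, the changes above make the following end-to-end usage possible. This is a sketch based on the JSDoc example and the `Xenova/modnet` default registered in `SUPPORTED_TASKS`; the `save()` call assumes a Node.js environment.

```javascript
import { pipeline } from '@huggingface/transformers';

// 'Xenova/modnet' is the registered default for the new task,
// so the model argument could be omitted entirely.
const remover = await pipeline('background-removal', 'Xenova/modnet');

const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/portrait-of-woman_small.jpg';
const [result] = await remover(url); // a RawImage with 4 channels (RGBA)

await result.save('output.png'); // writing to disk assumes Node.js
```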
diff --git a/tests/pipelines/test_pipelines_background_removal.js b/tests/pipelines/test_pipelines_background_removal.js
new file mode 100644
index 000000000..a8fcfc9e7
--- /dev/null
+++ b/tests/pipelines/test_pipelines_background_removal.js
@@ -0,0 +1,70 @@
+import { pipeline, BackgroundRemovalPipeline, RawImage } from "../../src/transformers.js";
+
+import { MAX_MODEL_LOAD_TIME, MAX_TEST_EXECUTION_TIME, MAX_MODEL_DISPOSE_TIME, DEFAULT_MODEL_OPTIONS } from "../init.js";
+import { load_cached_image } from "../asset_cache.js";
+
+const PIPELINE_ID = "background-removal";
+
+export default () => {
+  describe("Background Removal", () => {
+    describe("Portrait Segmentation", () => {
+      const model_id = "Xenova/modnet";
+      /** @type {BackgroundRemovalPipeline} */
+      let pipe;
+      beforeAll(async () => {
+        pipe = await pipeline(PIPELINE_ID, model_id, DEFAULT_MODEL_OPTIONS);
+      }, MAX_MODEL_LOAD_TIME);
+
+      it("should be an instance of BackgroundRemovalPipeline", () => {
+        expect(pipe).toBeInstanceOf(BackgroundRemovalPipeline);
+      });
+
+      it(
+        "single",
+        async () => {
+          const image = await load_cached_image("portrait_of_woman");
+
+          const output = await pipe(image);
+          expect(output).toHaveLength(1);
+          expect(output[0]).toBeInstanceOf(RawImage);
+          expect(output[0].width).toEqual(image.width);
+          expect(output[0].height).toEqual(image.height);
+          expect(output[0].channels).toEqual(4); // With alpha channel
+        },
+        MAX_TEST_EXECUTION_TIME,
+      );
+
+      afterAll(async () => {
+        await pipe.dispose();
+      }, MAX_MODEL_DISPOSE_TIME);
+    });
+
+    describe("Selfie Segmentation", () => {
+      const model_id = "onnx-community/mediapipe_selfie_segmentation";
+      /** @type {BackgroundRemovalPipeline} */
+      let pipe;
+      beforeAll(async () => {
+        pipe = await pipeline(PIPELINE_ID, model_id, DEFAULT_MODEL_OPTIONS);
+      }, MAX_MODEL_LOAD_TIME);
+
+      it(
+        "single",
+        async () => {
+          const image = await load_cached_image("portrait_of_woman");
+
+          const output = await pipe(image);
+          expect(output).toHaveLength(1);
+          expect(output[0]).toBeInstanceOf(RawImage);
+          expect(output[0].width).toEqual(image.width);
+          expect(output[0].height).toEqual(image.height);
+          expect(output[0].channels).toEqual(4); // With alpha channel
+        },
+        MAX_TEST_EXECUTION_TIME,
+      );
+
+      afterAll(async () => {
+        await pipe.dispose();
+      }, MAX_MODEL_DISPOSE_TIME);
+    });
+  });
+};
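As a closing note, the new pipeline is a thin wrapper: it runs the parent `ImageSegmentationPipeline` (which, for these models, takes the new `!subtask` branch and returns a single unlabeled mask) and then writes that mask into the image's alpha channel. A rough sketch of the equivalent manual steps, assuming `putAlpha` behaves as it is used in the diff above:

```javascript
import { pipeline, RawImage } from '@huggingface/transformers';

const segmenter = await pipeline('image-segmentation', 'Xenova/modnet');

const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/portrait-of-woman_small.jpg';
const image = await RawImage.fromURL(url);

// The standard-segmentation branch yields [{ label: null, score: null, mask }]
const [{ mask }] = await segmenter(image);

// What BackgroundRemovalPipeline._call does per image:
const composited = image.clone();
composited.putAlpha(mask); // now RGBA, with background pixels transparent
```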