diff --git a/README.md b/README.md
index 33e3191dd..4858f07ef 100644
--- a/README.md
+++ b/README.md
@@ -235,6 +235,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
| Task | ID | Description | Supported? |
|--------------------------|----|-------------|------------|
+| [Background Removal](https://huggingface.co/tasks/image-segmentation#background-removal) | `background-removal` | Isolating the main subject of an image by removing or making the background transparent. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.BackgroundRemovalPipeline)<br>[(models)](https://huggingface.co/models?other=background-removal&library=transformers.js) |
| [Depth Estimation](https://huggingface.co/tasks/depth-estimation) | `depth-estimation` | Predicting the depth of objects present in an image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.DepthEstimationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=depth-estimation&library=transformers.js) |
| [Image Classification](https://huggingface.co/tasks/image-classification) | `image-classification` | Assigning a label or class to an entire image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageClassificationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-classification&library=transformers.js) |
| [Image Segmentation](https://huggingface.co/tasks/image-segmentation) | `image-segmentation` | Divides an image into segments where each pixel is mapped to an object. This task has multiple variants such as instance segmentation, panoptic segmentation and semantic segmentation. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageSegmentationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-segmentation&library=transformers.js) |
diff --git a/docs/snippets/5_supported-tasks.snippet b/docs/snippets/5_supported-tasks.snippet
index f481fbf96..3e1a3ea32 100644
--- a/docs/snippets/5_supported-tasks.snippet
+++ b/docs/snippets/5_supported-tasks.snippet
@@ -22,6 +22,7 @@
| Task | ID | Description | Supported? |
|--------------------------|----|-------------|------------|
+| [Background Removal](https://huggingface.co/tasks/image-segmentation#background-removal) | `background-removal` | Isolating the main subject of an image by removing or making the background transparent. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.BackgroundRemovalPipeline)<br>[(models)](https://huggingface.co/models?other=background-removal&library=transformers.js) |
| [Depth Estimation](https://huggingface.co/tasks/depth-estimation) | `depth-estimation` | Predicting the depth of objects present in an image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.DepthEstimationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=depth-estimation&library=transformers.js) |
| [Image Classification](https://huggingface.co/tasks/image-classification) | `image-classification` | Assigning a label or class to an entire image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageClassificationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-classification&library=transformers.js) |
| [Image Segmentation](https://huggingface.co/tasks/image-segmentation) | `image-segmentation` | Divides an image into segments where each pixel is mapped to an object. This task has multiple variants such as instance segmentation, panoptic segmentation and semantic segmentation. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageSegmentationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-segmentation&library=transformers.js) |
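The README table and the docs snippet above are kept in sync; both document the same new task. For orientation, a minimal usage sketch, assuming the published `@huggingface/transformers` package (the `Xenova/modnet` default comes from the `SUPPORTED_TASKS` entry added later in this diff):

```javascript
import { pipeline } from '@huggingface/transformers';

// 'background-removal' resolves to the new BackgroundRemovalPipeline below.
const remover = await pipeline('background-removal'); // defaults to Xenova/modnet

// Any ImagePipelineInputs value works; this URL is the one from the pipeline docs.
const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/portrait-of-woman_small.jpg';
const [result] = await remover(url);

// `result` is a RawImage whose alpha channel masks out the background.
console.log(result.width, result.height, result.channels); // e.g. 360 450 4
```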
diff --git a/src/models.js b/src/models.js
index 976d1c000..83c822d57 100644
--- a/src/models.js
+++ b/src/models.js
@@ -5223,6 +5223,7 @@ export class SwinForImageClassification extends SwinPreTrainedModel {
return new SequenceClassifierOutput(await super._call(model_inputs));
}
}
+export class SwinForSemanticSegmentation extends SwinPreTrainedModel { }
//////////////////////////////////////////////////
//////////////////////////////////////////////////
@@ -6825,6 +6826,8 @@ export class MobileNetV1ForImageClassification extends MobileNetV1PreTrainedMode
return new SequenceClassifierOutput(await super._call(model_inputs));
}
}
+
+export class MobileNetV1ForSemanticSegmentation extends MobileNetV1PreTrainedModel { }
//////////////////////////////////////////////////
//////////////////////////////////////////////////
@@ -6848,6 +6851,7 @@ export class MobileNetV2ForImageClassification extends MobileNetV2PreTrainedMode
return new SequenceClassifierOutput(await super._call(model_inputs));
}
}
+export class MobileNetV2ForSemanticSegmentation extends MobileNetV2PreTrainedModel { }
//////////////////////////////////////////////////
//////////////////////////////////////////////////
@@ -6871,6 +6875,7 @@ export class MobileNetV3ForImageClassification extends MobileNetV3PreTrainedMode
return new SequenceClassifierOutput(await super._call(model_inputs));
}
}
+export class MobileNetV3ForSemanticSegmentation extends MobileNetV3PreTrainedModel { }
//////////////////////////////////////////////////
//////////////////////////////////////////////////
@@ -6894,6 +6899,7 @@ export class MobileNetV4ForImageClassification extends MobileNetV4PreTrainedMode
return new SequenceClassifierOutput(await super._call(model_inputs));
}
}
+export class MobileNetV4ForSemanticSegmentation extends MobileNetV4PreTrainedModel { }
//////////////////////////////////////////////////
//////////////////////////////////////////////////
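Reviewer note: the `...ForSemanticSegmentation` classes added in this file (Swin plus the four MobileNet variants) are deliberately empty — each one only ties a `model_type` to its pretrained base so that the `AutoModelForSemanticSegmentation` mapping added further down can resolve it. A sketch of direct use, where `your-org/mobilenet-semseg` is a hypothetical checkpoint id and the `logits` output name is assumed to follow the library's other semantic-segmentation heads:

```javascript
import { AutoModelForSemanticSegmentation, AutoProcessor, RawImage } from '@huggingface/transformers';

// Hypothetical repo id, for illustration only.
const model_id = 'your-org/mobilenet-semseg';
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await AutoModelForSemanticSegmentation.from_pretrained(model_id);

const image = await RawImage.fromURL('https://example.com/scene.jpg');
const inputs = await processor(image);
const { logits } = await model(inputs); // assumed shape: [batch, num_labels, h, w]
```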
@@ -7158,20 +7164,29 @@ export class PretrainedMixin {
if (!this.MODEL_CLASS_MAPPINGS) {
throw new Error("`MODEL_CLASS_MAPPINGS` not implemented for this type of `AutoClass`: " + this.name);
}
-
+ const model_type = options.config.model_type;
for (const MODEL_CLASS_MAPPING of this.MODEL_CLASS_MAPPINGS) {
- const modelInfo = MODEL_CLASS_MAPPING.get(options.config.model_type);
+ let modelInfo = MODEL_CLASS_MAPPING.get(model_type);
if (!modelInfo) {
- continue; // Item not found in this mapping
+ // As a fallback, we check if model_type is specified as the exact class
+ for (const cls of MODEL_CLASS_MAPPING.values()) {
+ if (cls[0] === model_type) {
+ modelInfo = cls;
+ break;
+ }
+ }
+ if (!modelInfo) continue; // Item not found in this mapping
}
return await modelInfo[1].from_pretrained(pretrained_model_name_or_path, options);
}
if (this.BASE_IF_FAIL) {
- console.warn(`Unknown model class "${options.config.model_type}", attempting to construct from base class.`);
+ if (!(CUSTOM_ARCHITECTURES.has(model_type))) {
+ console.warn(`Unknown model class "${model_type}", attempting to construct from base class.`);
+ }
return await PreTrainedModel.from_pretrained(pretrained_model_name_or_path, options);
} else {
- throw Error(`Unsupported model type: ${options.config.model_type}`)
+ throw Error(`Unsupported model type: ${model_type}`)
}
}
}
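The revised lookup above adds a fallback: when `model_type` is not a canonical short name (e.g. `'segformer'`), it is compared against the exact class names stored as the first element of each mapping value (e.g. `'SegformerForSemanticSegmentation'`). A standalone sketch of the two-step resolution, simplified from the code above:

```javascript
// Each mapping value is a [className, modelClass] pair, as in the diff above.
function resolveModelInfo(mapping, modelType) {
  // 1. Direct lookup by canonical short name, e.g. 'segformer'.
  const direct = mapping.get(modelType);
  if (direct) return direct;

  // 2. Fallback: match the exact class name, e.g. 'SegformerForSemanticSegmentation'.
  for (const value of mapping.values()) {
    if (value[0] === modelType) return value;
  }
  return null; // not found in this mapping
}
```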
@@ -7524,6 +7539,12 @@ const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = new Map([
const MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = new Map([
['segformer', ['SegformerForSemanticSegmentation', SegformerForSemanticSegmentation]],
['sapiens', ['SapiensForSemanticSegmentation', SapiensForSemanticSegmentation]],
+
+ ['swin', ['SwinForSemanticSegmentation', SwinForSemanticSegmentation]],
+ ['mobilenet_v1', ['MobileNetV1ForSemanticSegmentation', MobileNetV1ForSemanticSegmentation]],
+ ['mobilenet_v2', ['MobileNetV2ForSemanticSegmentation', MobileNetV2ForSemanticSegmentation]],
+ ['mobilenet_v3', ['MobileNetV3ForSemanticSegmentation', MobileNetV3ForSemanticSegmentation]],
+ ['mobilenet_v4', ['MobileNetV4ForSemanticSegmentation', MobileNetV4ForSemanticSegmentation]],
]);
const MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = new Map([
@@ -7668,6 +7689,19 @@ for (const [name, model, type] of CUSTOM_MAPPING) {
MODEL_NAME_TO_CLASS_MAPPING.set(name, model);
}
+const CUSTOM_ARCHITECTURES = new Map([
+ ['modnet', MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES],
+ ['birefnet', MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES],
+ ['isnet', MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES],
+ ['ben', MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES],
+]);
+for (const [name, mapping] of CUSTOM_ARCHITECTURES.entries()) {
+ mapping.set(name, ['PreTrainedModel', PreTrainedModel])
+ MODEL_TYPE_MAPPING.set(name, MODEL_TYPES.EncoderOnly);
+ MODEL_CLASS_TO_NAME_MAPPING.set(PreTrainedModel, name);
+ MODEL_NAME_TO_CLASS_MAPPING.set(name, PreTrainedModel);
+}
+
/**
* Helper class which is used to instantiate pretrained models with the `from_pretrained` function.
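Net effect of the `CUSTOM_ARCHITECTURES` registration above: each listed type (`'modnet'`, `'birefnet'`, `'isnet'`, `'ben'`) loads as a bare encoder-only `PreTrainedModel` through the image-segmentation mapping, and the earlier change suppresses the "Unknown model class" warning for exactly these types. A sketch (the only model id grounded in this diff is `Xenova/modnet`):

```javascript
import { AutoModelForImageSegmentation } from '@huggingface/transformers';

// 'modnet'-type configs now resolve without a dedicated model class and
// without the "Unknown model class" console warning.
const model = await AutoModelForImageSegmentation.from_pretrained('Xenova/modnet');
```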
diff --git a/src/pipelines.js b/src/pipelines.js
index 649b00a49..bf17f7017 100644
--- a/src/pipelines.js
+++ b/src/pipelines.js
@@ -2095,7 +2095,7 @@ export class ImageClassificationPipeline extends (/** @type {new (options: Image
/**
* @typedef {Object} ImageSegmentationPipelineOutput
- * @property {string} label The label of the segment.
+ * @property {string|null} label The label of the segment.
* @property {number|null} score The score of the segment.
* @property {RawImage} mask The mask of the segment.
*
@@ -2165,14 +2165,30 @@ export class ImageSegmentationPipeline extends (/** @type {new (options: ImagePi
const preparedImages = await prepareImages(images);
const imageSizes = preparedImages.map(x => [x.height, x.width]);
- const { pixel_values, pixel_mask } = await this.processor(preparedImages);
- const output = await this.model({ pixel_values, pixel_mask });
+ const inputs = await this.processor(preparedImages);
+
+ const { inputNames, outputNames } = this.model.sessions['model'];
+ if (!inputNames.includes('pixel_values')) {
+ if (inputNames.length !== 1) {
+ throw Error(`Expected a single input name, but got ${inputNames.length} inputs: ${inputNames}.`);
+ }
+
+ const newName = inputNames[0];
+ if (newName in inputs) {
+ throw Error(`Input name ${newName} already exists in the inputs.`);
+ }
+            // To ensure compatibility with certain background-removal models,
+            // we may need to map the processor's `pixel_values` output to the
+            // model's expected input name.
+ inputs[newName] = inputs.pixel_values;
+ }
+
+ const output = await this.model(inputs);
let fn = null;
if (subtask !== null) {
fn = this.subtasks_mapping[subtask];
- } else {
- for (let [task, func] of Object.entries(this.subtasks_mapping)) {
+ } else if (this.processor.image_processor) {
+ for (const [task, func] of Object.entries(this.subtasks_mapping)) {
if (func in this.processor.image_processor) {
fn = this.processor.image_processor[func].bind(this.processor.image_processor);
subtask = task;
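The `pixel_values` remapping above accommodates ONNX exports whose single graph input carries a nonstandard name, which is common among background-removal checkpoints. The same idea in isolation (a sketch; the `'input'` name in the comment is illustrative):

```javascript
// Remap a processor's pixel_values to a session's sole, differently named input.
function adaptInputs(inputs, inputNames) {
  if (inputNames.includes('pixel_values')) return inputs;
  if (inputNames.length !== 1) {
    throw Error(`Expected a single input name, but got ${inputNames.length}.`);
  }
  // e.g. inputNames[0] === 'input' for some MODNet-style exports.
  return { ...inputs, [inputNames[0]]: inputs.pixel_values };
}
```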
@@ -2186,7 +2202,23 @@ export class ImageSegmentationPipeline extends (/** @type {new (options: ImagePi
/** @type {ImageSegmentationPipelineOutput[]} */
const annotation = [];
- if (subtask === 'panoptic' || subtask === 'instance') {
+ if (!subtask) {
+ // Perform standard image segmentation
+ const result = output[outputNames[0]];
+ for (let i = 0; i < imageSizes.length; ++i) {
+ const size = imageSizes[i];
+ const item = result[i];
+ if (item.data.some(x => x < 0 || x > 1)) {
+ item.sigmoid_();
+ }
+ const mask = await RawImage.fromTensor(item.mul_(255).to('uint8')).resize(size[1], size[0]);
+ annotation.push({
+ label: null,
+ score: null,
+ mask
+ });
+ }
+ } else if (subtask === 'panoptic' || subtask === 'instance') {
const processed = fn(
output,
threshold,
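One detail worth calling out in the new `!subtask` branch: the output is heuristically treated as logits whenever any value falls outside [0, 1], in which case a sigmoid is applied before scaling to the 8-bit mask range. The normalization on its own (a sketch over a plain array rather than a Tensor):

```javascript
// Normalize raw per-pixel outputs to 0-255 mask bytes, mirroring the heuristic above.
function toMaskBytes(data) {
  const needsSigmoid = data.some((x) => x < 0 || x > 1); // out-of-range => logits
  return Uint8Array.from(data, (x) => {
    const p = needsSigmoid ? 1 / (1 + Math.exp(-x)) : x;
    return Math.round(p * 255);
  });
}

// toMaskBytes([0, 0.5, 1]) -> Uint8Array [0, 128, 255]   (already probabilities)
// toMaskBytes([-2, 0, 2])  -> Uint8Array [30, 128, 225]  (sigmoid applied)
```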
@@ -2242,6 +2274,63 @@ export class ImageSegmentationPipeline extends (/** @type {new (options: ImagePi
}
}
+
+/**
+ * @typedef {Object} BackgroundRemovalPipelineOptions Parameters specific to background removal pipelines.
+ *
+ * @callback BackgroundRemovalPipelineCallback Remove the background from the input images.
+ * @param {ImagePipelineInputs} images The input images.
+ * @param {BackgroundRemovalPipelineOptions} [options] The options to use for background removal.
+ * @returns {Promise<RawImage[]>} The images with the background removed.
+ *
+ * @typedef {ImagePipelineConstructorArgs & BackgroundRemovalPipelineCallback & Disposable} BackgroundRemovalPipelineType
+ */
+
+/**
+ * Background removal pipeline using certain `AutoModelForXXXSegmentation`.
+ * This pipeline removes the backgrounds of images.
+ *
+ * **Example:** Perform background removal with `Xenova/modnet`.
+ * ```javascript
+ * const segmenter = await pipeline('background-removal', 'Xenova/modnet');
+ * const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/portrait-of-woman_small.jpg';
+ * const output = await segmenter(url);
+ * // [
+ * // RawImage { data: Uint8ClampedArray(648000) [ ... ], width: 360, height: 450, channels: 4 }
+ * // ]
+ * ```
+ */
+export class BackgroundRemovalPipeline extends (/** @type {new (options: ImagePipelineConstructorArgs) => ImageSegmentationPipelineType} */ (ImageSegmentationPipeline)) {
+ /**
+ * Create a new BackgroundRemovalPipeline.
+ * @param {ImagePipelineConstructorArgs} options An object used to instantiate the pipeline.
+ */
+ constructor(options) {
+ super(options);
+ }
+
+ /** @type {BackgroundRemovalPipelineCallback} */
+ async _call(images, options = {}) {
+ const isBatched = Array.isArray(images);
+
+ if (isBatched && images.length !== 1) {
+ throw Error("Background removal pipeline currently only supports a batch size of 1.");
+ }
+
+ const preparedImages = await prepareImages(images);
+
+ // @ts-expect-error TS2339
+ const masks = await super._call(images, options);
+ const result = preparedImages.map((img, i) => {
+ const cloned = img.clone();
+ cloned.putAlpha(masks[i].mask);
+ return cloned;
+ });
+
+ return result;
+ }
+}
+
/**
* @typedef {Object} ZeroShotImageClassificationOutput
* @property {string} label The label identified by the model. It is one of the suggested `candidate_label`.
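Because `_call` writes the predicted mask into the alpha channel via `putAlpha`, the returned `RawImage`s can be persisted directly. A usage sketch (`RawImage#save` is assumed to be available in environments with file-system access):

```javascript
const remover = await pipeline('background-removal', 'Xenova/modnet');
const [cutout] = await remover('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/portrait-of-woman_small.jpg');

// Node.js: write the RGBA cutout as a PNG; in the browser, draw it to a canvas instead.
await cutout.save('cutout.png');
```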
@@ -2554,7 +2643,7 @@ export class ZeroShotObjectDetectionPipeline extends (/** @type {new (options: T
const output = await this.model({ ...text_inputs, pixel_values });
let result;
- if('post_process_grounded_object_detection' in this.processor) {
+ if ('post_process_grounded_object_detection' in this.processor) {
// @ts-ignore
const processed = this.processor.post_process_grounded_object_detection(
output,
@@ -3134,6 +3223,16 @@ const SUPPORTED_TASKS = Object.freeze({
},
"type": "multimodal",
},
+ "background-removal": {
+ // no tokenizer
+ "pipeline": BackgroundRemovalPipeline,
+ "model": [AutoModelForImageSegmentation, AutoModelForSemanticSegmentation, AutoModelForUniversalSegmentation],
+ "processor": AutoProcessor,
+ "default": {
+ "model": "Xenova/modnet",
+ },
+ "type": "image",
+ },
"zero-shot-image-classification": {
"tokenizer": AutoTokenizer,
diff --git a/tests/asset_cache.js b/tests/asset_cache.js
index 9d1182014..1dfcfae41 100644
--- a/tests/asset_cache.js
+++ b/tests/asset_cache.js
@@ -24,6 +24,7 @@ const TEST_IMAGES = Object.freeze({
book_cover: BASE_URL + "book-cover.png",
corgi: BASE_URL + "corgi.jpg",
man_on_car: BASE_URL + "young-man-standing-and-leaning-on-car.jpg",
+ portrait_of_woman: BASE_URL + "portrait-of-woman_small.jpg",
});
const TEST_AUDIOS = {
diff --git a/tests/pipelines/test_pipelines_background_removal.js b/tests/pipelines/test_pipelines_background_removal.js
new file mode 100644
index 000000000..a8fcfc9e7
--- /dev/null
+++ b/tests/pipelines/test_pipelines_background_removal.js
@@ -0,0 +1,70 @@
+import { pipeline, BackgroundRemovalPipeline, RawImage } from "../../src/transformers.js";
+
+import { MAX_MODEL_LOAD_TIME, MAX_TEST_EXECUTION_TIME, MAX_MODEL_DISPOSE_TIME, DEFAULT_MODEL_OPTIONS } from "../init.js";
+import { load_cached_image } from "../asset_cache.js";
+
+const PIPELINE_ID = "background-removal";
+
+export default () => {
+ describe("Background Removal", () => {
+ describe("Portrait Segmentation", () => {
+ const model_id = "Xenova/modnet";
+ /** @type {BackgroundRemovalPipeline} */
+ let pipe;
+ beforeAll(async () => {
+ pipe = await pipeline(PIPELINE_ID, model_id, DEFAULT_MODEL_OPTIONS);
+ }, MAX_MODEL_LOAD_TIME);
+
+ it("should be an instance of BackgroundRemovalPipeline", () => {
+ expect(pipe).toBeInstanceOf(BackgroundRemovalPipeline);
+ });
+
+ it(
+ "single",
+ async () => {
+ const image = await load_cached_image("portrait_of_woman");
+
+ const output = await pipe(image);
+ expect(output).toHaveLength(1);
+ expect(output[0]).toBeInstanceOf(RawImage);
+ expect(output[0].width).toEqual(image.width);
+ expect(output[0].height).toEqual(image.height);
+ expect(output[0].channels).toEqual(4); // With alpha channel
+ },
+ MAX_TEST_EXECUTION_TIME,
+ );
+
+ afterAll(async () => {
+ await pipe.dispose();
+ }, MAX_MODEL_DISPOSE_TIME);
+ });
+
+ describe("Selfie Segmentation", () => {
+ const model_id = "onnx-community/mediapipe_selfie_segmentation";
+      /** @type {BackgroundRemovalPipeline} */
+ let pipe;
+ beforeAll(async () => {
+ pipe = await pipeline(PIPELINE_ID, model_id, DEFAULT_MODEL_OPTIONS);
+ }, MAX_MODEL_LOAD_TIME);
+
+ it(
+ "single",
+ async () => {
+ const image = await load_cached_image("portrait_of_woman");
+
+ const output = await pipe(image);
+ expect(output).toHaveLength(1);
+ expect(output[0]).toBeInstanceOf(RawImage);
+ expect(output[0].width).toEqual(image.width);
+ expect(output[0].height).toEqual(image.height);
+ expect(output[0].channels).toEqual(4); // With alpha channel
+ },
+ MAX_TEST_EXECUTION_TIME,
+ );
+
+ afterAll(async () => {
+ await pipe.dispose();
+ }, MAX_MODEL_DISPOSE_TIME);
+ });
+ });
+};