
Add background removal pipeline #1216


Merged: 3 commits, Mar 6, 2025
1 change: 1 addition & 0 deletions README.md
@@ -235,6 +235,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te

| Task | ID | Description | Supported? |
|--------------------------|----|-------------|------------|
| [Background Removal](https://huggingface.co/tasks/image-segmentation#background-removal) | `background-removal` | Isolating the main subject of an image by removing or making the background transparent. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.BackgroundRemovalPipeline)<br>[(models)](https://huggingface.co/models?other=background-removal&library=transformers.js) |
| [Depth Estimation](https://huggingface.co/tasks/depth-estimation) | `depth-estimation` | Predicting the depth of objects present in an image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.DepthEstimationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=depth-estimation&library=transformers.js) |
| [Image Classification](https://huggingface.co/tasks/image-classification) | `image-classification` | Assigning a label or class to an entire image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageClassificationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-classification&library=transformers.js) |
| [Image Segmentation](https://huggingface.co/tasks/image-segmentation) | `image-segmentation` | Divides an image into segments where each pixel is mapped to an object. This task has multiple variants such as instance segmentation, panoptic segmentation and semantic segmentation. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageSegmentationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-segmentation&library=transformers.js) |
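For orientation, the new task row can be exercised end to end. A minimal sketch, assuming the `@huggingface/transformers` package name, and using the default model and sample image that appear elsewhere in this PR:

```javascript
import { pipeline } from '@huggingface/transformers';

// Load the new background-removal pipeline (Xenova/modnet is the default
// model registered in SUPPORTED_TASKS below).
const remover = await pipeline('background-removal', 'Xenova/modnet');

// The sample image from the pipeline's docstring; any image URL or path works.
const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/portrait-of-woman_small.jpg';

// Returns an array of RawImage objects with an alpha channel applied.
const [cutout] = await remover(url);
console.log(cutout.channels); // 4
```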
1 change: 1 addition & 0 deletions docs/snippets/5_supported-tasks.snippet
@@ -22,6 +22,7 @@

| Task | ID | Description | Supported? |
|--------------------------|----|-------------|------------|
| [Background Removal](https://huggingface.co/tasks/image-segmentation#background-removal) | `background-removal` | Isolating the main subject of an image by removing or making the background transparent. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.BackgroundRemovalPipeline)<br>[(models)](https://huggingface.co/models?other=background-removal&library=transformers.js) |
| [Depth Estimation](https://huggingface.co/tasks/depth-estimation) | `depth-estimation` | Predicting the depth of objects present in an image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.DepthEstimationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=depth-estimation&library=transformers.js) |
| [Image Classification](https://huggingface.co/tasks/image-classification) | `image-classification` | Assigning a label or class to an entire image. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageClassificationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-classification&library=transformers.js) |
| [Image Segmentation](https://huggingface.co/tasks/image-segmentation) | `image-segmentation` | Divides an image into segments where each pixel is mapped to an object. This task has multiple variants such as instance segmentation, panoptic segmentation and semantic segmentation. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ImageSegmentationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=image-segmentation&library=transformers.js) |
44 changes: 39 additions & 5 deletions src/models.js
@@ -5223,6 +5223,7 @@ export class SwinForImageClassification extends SwinPreTrainedModel {
return new SequenceClassifierOutput(await super._call(model_inputs));
}
}
export class SwinForSemanticSegmentation extends SwinPreTrainedModel { }
//////////////////////////////////////////////////

//////////////////////////////////////////////////
@@ -6825,6 +6826,8 @@ export class MobileNetV1ForImageClassification extends MobileNetV1PreTrainedModel {
return new SequenceClassifierOutput(await super._call(model_inputs));
}
}

export class MobileNetV1ForSemanticSegmentation extends MobileNetV1PreTrainedModel { }
//////////////////////////////////////////////////

//////////////////////////////////////////////////
@@ -6848,6 +6851,7 @@ export class MobileNetV2ForImageClassification extends MobileNetV2PreTrainedModel {
return new SequenceClassifierOutput(await super._call(model_inputs));
}
}
export class MobileNetV2ForSemanticSegmentation extends MobileNetV2PreTrainedModel { }
//////////////////////////////////////////////////

//////////////////////////////////////////////////
@@ -6871,6 +6875,7 @@ export class MobileNetV3ForImageClassification extends MobileNetV3PreTrainedModel {
return new SequenceClassifierOutput(await super._call(model_inputs));
}
}
export class MobileNetV3ForSemanticSegmentation extends MobileNetV3PreTrainedModel { }
//////////////////////////////////////////////////

//////////////////////////////////////////////////
@@ -6894,6 +6899,7 @@ export class MobileNetV4ForImageClassification extends MobileNetV4PreTrainedModel {
return new SequenceClassifierOutput(await super._call(model_inputs));
}
}
export class MobileNetV4ForSemanticSegmentation extends MobileNetV4PreTrainedModel { }
//////////////////////////////////////////////////

//////////////////////////////////////////////////
@@ -7158,20 +7164,29 @@ export class PretrainedMixin {
if (!this.MODEL_CLASS_MAPPINGS) {
throw new Error("`MODEL_CLASS_MAPPINGS` not implemented for this type of `AutoClass`: " + this.name);
}

const model_type = options.config.model_type;
for (const MODEL_CLASS_MAPPING of this.MODEL_CLASS_MAPPINGS) {
let modelInfo = MODEL_CLASS_MAPPING.get(model_type);
if (!modelInfo) {
// As a fallback, we check if model_type is specified as the exact class
for (const cls of MODEL_CLASS_MAPPING.values()) {
if (cls[0] === model_type) {
modelInfo = cls;
break;
}
}
if (!modelInfo) continue; // Item not found in this mapping
}
return await modelInfo[1].from_pretrained(pretrained_model_name_or_path, options);
}

if (this.BASE_IF_FAIL) {
if (!(CUSTOM_ARCHITECTURES.has(model_type))) {
console.warn(`Unknown model class "${model_type}", attempting to construct from base class.`);
}
return await PreTrainedModel.from_pretrained(pretrained_model_name_or_path, options);
} else {
throw Error(`Unsupported model type: ${model_type}`)
}
}
}
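The key change above is the second lookup pass: when `model_type` (e.g. `modnet`) is not a key of a mapping, the mixin also checks whether it matches a stored architecture name, i.e. the first element of a mapping entry. A self-contained sketch of that logic with a toy `Map` (not the library's exported API):

```javascript
// Two-pass lookup: by mapping key first, then by exact architecture name.
function resolveModelInfo(mapping, modelType) {
  let modelInfo = mapping.get(modelType);
  if (!modelInfo) {
    for (const entry of mapping.values()) {
      if (entry[0] === modelType) { // entry = [architectureName, modelClass]
        modelInfo = entry;
        break;
      }
    }
  }
  return modelInfo ?? null;
}

const toyMapping = new Map([
  ['segformer', ['SegformerForSemanticSegmentation', class {}]],
]);
resolveModelInfo(toyMapping, 'segformer');                        // hit on the key
resolveModelInfo(toyMapping, 'SegformerForSemanticSegmentation'); // hit via fallback
resolveModelInfo(toyMapping, 'unknown');                          // null
```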
@@ -7524,6 +7539,12 @@ const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = new Map([
const MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = new Map([
['segformer', ['SegformerForSemanticSegmentation', SegformerForSemanticSegmentation]],
['sapiens', ['SapiensForSemanticSegmentation', SapiensForSemanticSegmentation]],

['swin', ['SwinForSemanticSegmentation', SwinForSemanticSegmentation]],
['mobilenet_v1', ['MobileNetV1ForSemanticSegmentation', MobileNetV1ForSemanticSegmentation]],
['mobilenet_v2', ['MobileNetV2ForSemanticSegmentation', MobileNetV2ForSemanticSegmentation]],
['mobilenet_v3', ['MobileNetV3ForSemanticSegmentation', MobileNetV3ForSemanticSegmentation]],
['mobilenet_v4', ['MobileNetV4ForSemanticSegmentation', MobileNetV4ForSemanticSegmentation]],
]);
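With these entries, checkpoints whose config declares one of the new `model_type` values resolve through `AutoModelForSemanticSegmentation`. A hedged sketch (the model id is a placeholder, not a checkpoint this PR ships):

```javascript
import { AutoModelForSemanticSegmentation } from '@huggingface/transformers';

// 'your-org/mobilenet_v2-segmentation' is illustrative; any model whose config
// has model_type 'mobilenet_v2' now maps to MobileNetV2ForSemanticSegmentation.
const model = await AutoModelForSemanticSegmentation.from_pretrained(
  'your-org/mobilenet_v2-segmentation',
);
```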

const MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = new Map([
@@ -7668,6 +7689,19 @@ for (const [name, model, type] of CUSTOM_MAPPING) {
MODEL_NAME_TO_CLASS_MAPPING.set(name, model);
}

const CUSTOM_ARCHITECTURES = new Map([
['modnet', MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES],
['birefnet', MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES],
['isnet', MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES],
['ben', MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES],
]);
for (const [name, mapping] of CUSTOM_ARCHITECTURES.entries()) {
mapping.set(name, ['PreTrainedModel', PreTrainedModel])
MODEL_TYPE_MAPPING.set(name, MODEL_TYPES.EncoderOnly);
MODEL_CLASS_TO_NAME_MAPPING.set(PreTrainedModel, name);
MODEL_NAME_TO_CLASS_MAPPING.set(name, PreTrainedModel);
}
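In effect, each custom background-removal architecture is registered as a plain `PreTrainedModel` under the image-segmentation mapping, which is what lets `PretrainedMixin` load it (and why the "unknown model class" warning above is skipped for exactly these types). Relying only on the code shown here:

```javascript
// Illustrative lookups after the registration loop has run:
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES.get('modnet');
// -> ['PreTrainedModel', PreTrainedModel]
MODEL_TYPE_MAPPING.get('birefnet');
// -> MODEL_TYPES.EncoderOnly
// AutoModelForImageSegmentation can therefore load these checkpoints even
// though no dedicated *ForImageSegmentation class exists for them.
```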


/**
* Helper class which is used to instantiate pretrained models with the `from_pretrained` function.
113 changes: 106 additions & 7 deletions src/pipelines.js
@@ -2095,7 +2095,7 @@ export class ImageClassificationPipeline extends (/** @type {new (options: Image

/**
* @typedef {Object} ImageSegmentationPipelineOutput
* @property {string|null} label The label of the segment.
* @property {number|null} score The score of the segment.
* @property {RawImage} mask The mask of the segment.
*
@@ -2165,14 +2165,30 @@ export class ImageSegmentationPipeline extends (/** @type {new (options: ImagePi
const preparedImages = await prepareImages(images);
const imageSizes = preparedImages.map(x => [x.height, x.width]);

const inputs = await this.processor(preparedImages);

const { inputNames, outputNames } = this.model.sessions['model'];
if (!inputNames.includes('pixel_values')) {
if (inputNames.length !== 1) {
throw Error(`Expected a single input name, but got ${inputNames.length} inputs: ${inputNames}.`);
}

const newName = inputNames[0];
if (newName in inputs) {
throw Error(`Input name ${newName} already exists in the inputs.`);
}
// To ensure compatibility with certain background-removal models,
// we may need to remap our standard input name to the one the model expects.
inputs[newName] = inputs.pixel_values;
}

const output = await this.model(inputs);

let fn = null;
if (subtask !== null) {
fn = this.subtasks_mapping[subtask];
} else if (this.processor.image_processor) {
for (const [task, func] of Object.entries(this.subtasks_mapping)) {
if (func in this.processor.image_processor) {
fn = this.processor.image_processor[func].bind(this.processor.image_processor);
subtask = task;
@@ -2186,7 +2202,23 @@

/** @type {ImageSegmentationPipelineOutput[]} */
const annotation = [];
if (!subtask) {
// Perform standard image segmentation
const result = output[outputNames[0]];
for (let i = 0; i < imageSizes.length; ++i) {
const size = imageSizes[i];
const item = result[i];
if (item.data.some(x => x < 0 || x > 1)) {
item.sigmoid_();
}
const mask = await RawImage.fromTensor(item.mul_(255).to('uint8')).resize(size[1], size[0]);
annotation.push({
label: null,
score: null,
mask
});
}
} else if (subtask === 'panoptic' || subtask === 'instance') {
const processed = fn(
output,
threshold,
@@ -2242,6 +2274,63 @@
}
}
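The new subtask-less branch is the path background-removal models take: the first model output is treated as one soft mask per image. A condensed sketch of that post-processing, assuming `RawImage` is in scope and `item` is the per-image tensor from above:

```javascript
// Convert a raw mask tensor to a RawImage mask at the original resolution.
async function tensorToMask(item, [height, width]) {
  // Only squash through a sigmoid when values fall outside [0, 1],
  // i.e. when the model returned logits rather than probabilities.
  if (item.data.some((x) => x < 0 || x > 1)) {
    item.sigmoid_();
  }
  // Scale to 0-255, cast to uint8, and resize back to the input size.
  return await RawImage.fromTensor(item.mul_(255).to('uint8')).resize(width, height);
}
```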


/**
* @typedef {Object} BackgroundRemovalPipelineOptions Parameters specific to background removal pipelines.
*
* @callback BackgroundRemovalPipelineCallback Remove the background from the input images.
* @param {ImagePipelineInputs} images The input images.
* @param {BackgroundRemovalPipelineOptions} [options] The options to use for background removal.
* @returns {Promise<RawImage[]>} The images with the background removed.
*
* @typedef {ImagePipelineConstructorArgs & BackgroundRemovalPipelineCallback & Disposable} BackgroundRemovalPipelineType
*/

/**
* Background removal pipeline using certain `AutoModelForXXXSegmentation`.
* This pipeline removes the backgrounds of images.
*
* **Example:** Perform background removal with `Xenova/modnet`.
* ```javascript
* const segmenter = await pipeline('background-removal', 'Xenova/modnet');
* const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/portrait-of-woman_small.jpg';
* const output = await segmenter(url);
* // [
* // RawImage { data: Uint8ClampedArray(648000) [ ... ], width: 360, height: 450, channels: 4 }
* // ]
* ```
*/
export class BackgroundRemovalPipeline extends (/** @type {new (options: ImagePipelineConstructorArgs) => ImageSegmentationPipelineType} */ (ImageSegmentationPipeline)) {
/**
* Create a new BackgroundRemovalPipeline.
* @param {ImagePipelineConstructorArgs} options An object used to instantiate the pipeline.
*/
constructor(options) {
super(options);
}

/** @type {BackgroundRemovalPipelineCallback} */
async _call(images, options = {}) {
const isBatched = Array.isArray(images);

if (isBatched && images.length !== 1) {
throw Error("Background removal pipeline currently only supports a batch size of 1.");
}

const preparedImages = await prepareImages(images);

// @ts-expect-error TS2339
const masks = await super._call(images, options);
const result = preparedImages.map((img, i) => {
const cloned = img.clone();
cloned.putAlpha(masks[i].mask);
return cloned;
});

return result;
}
}
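Since `_call` clones each input and writes the predicted mask into its alpha channel via `putAlpha`, saving in an alpha-capable format preserves the cutout. A small Node.js usage sketch (`RawImage#save` is assumed to behave as in the library's other image examples):

```javascript
const remover = await pipeline('background-removal'); // defaults to Xenova/modnet
const [cutout] = await remover('portrait.jpg');       // local path, for illustration
await cutout.save('portrait-cutout.png');             // PNG keeps the alpha channel
```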

/**
* @typedef {Object} ZeroShotImageClassificationOutput
* @property {string} label The label identified by the model. It is one of the suggested `candidate_label`.
@@ -2554,7 +2643,7 @@ export class ZeroShotObjectDetectionPipeline extends (/** @type {new (options: T
const output = await this.model({ ...text_inputs, pixel_values });

let result;
if ('post_process_grounded_object_detection' in this.processor) {
// @ts-ignore
const processed = this.processor.post_process_grounded_object_detection(
output,
@@ -3134,6 +3223,16 @@ const SUPPORTED_TASKS = Object.freeze({
},
"type": "multimodal",
},
"background-removal": {
// no tokenizer
"pipeline": BackgroundRemovalPipeline,
"model": [AutoModelForImageSegmentation, AutoModelForSemanticSegmentation, AutoModelForUniversalSegmentation],
"processor": AutoProcessor,
"default": {
"model": "Xenova/modnet",
},
"type": "image",
},

"zero-shot-image-classification": {
"tokenizer": AutoTokenizer,
1 change: 1 addition & 0 deletions tests/asset_cache.js
@@ -24,6 +24,7 @@ const TEST_IMAGES = Object.freeze({
book_cover: BASE_URL + "book-cover.png",
corgi: BASE_URL + "corgi.jpg",
man_on_car: BASE_URL + "young-man-standing-and-leaning-on-car.jpg",
portrait_of_woman: BASE_URL + "portrait-of-woman_small.jpg",
});

const TEST_AUDIOS = {
70 changes: 70 additions & 0 deletions tests/pipelines/test_pipelines_background_removal.js
@@ -0,0 +1,70 @@
import { pipeline, BackgroundRemovalPipeline, RawImage } from "../../src/transformers.js";

import { MAX_MODEL_LOAD_TIME, MAX_TEST_EXECUTION_TIME, MAX_MODEL_DISPOSE_TIME, DEFAULT_MODEL_OPTIONS } from "../init.js";
import { load_cached_image } from "../asset_cache.js";

const PIPELINE_ID = "background-removal";

export default () => {
describe("Background Removal", () => {
describe("Portrait Segmentation", () => {
const model_id = "Xenova/modnet";
/** @type {BackgroundRemovalPipeline} */
let pipe;
beforeAll(async () => {
pipe = await pipeline(PIPELINE_ID, model_id, DEFAULT_MODEL_OPTIONS);
}, MAX_MODEL_LOAD_TIME);

it("should be an instance of BackgroundRemovalPipeline", () => {
expect(pipe).toBeInstanceOf(BackgroundRemovalPipeline);
});

it(
"single",
async () => {
const image = await load_cached_image("portrait_of_woman");

const output = await pipe(image);
expect(output).toHaveLength(1);
expect(output[0]).toBeInstanceOf(RawImage);
expect(output[0].width).toEqual(image.width);
expect(output[0].height).toEqual(image.height);
expect(output[0].channels).toEqual(4); // With alpha channel
},
MAX_TEST_EXECUTION_TIME,
);

afterAll(async () => {
await pipe.dispose();
}, MAX_MODEL_DISPOSE_TIME);
});

describe("Selfie Segmentation", () => {
const model_id = "onnx-community/mediapipe_selfie_segmentation";
/** @type {BackgroundRemovalPipeline} */
let pipe;
beforeAll(async () => {
pipe = await pipeline(PIPELINE_ID, model_id, DEFAULT_MODEL_OPTIONS);
}, MAX_MODEL_LOAD_TIME);

it(
"single",
async () => {
const image = await load_cached_image("portrait_of_woman");

const output = await pipe(image);
expect(output).toHaveLength(1);
expect(output[0]).toBeInstanceOf(RawImage);
expect(output[0].width).toEqual(image.width);
expect(output[0].height).toEqual(image.height);
expect(output[0].channels).toEqual(4); // With alpha channel
},
MAX_TEST_EXECUTION_TIME,
);

afterAll(async () => {
await pipe.dispose();
}, MAX_MODEL_DISPOSE_TIME);
});
});
};