Android - How to integrate Image segmentation (Selfie) with WebRTC #653

@maitrungduc1410

Description

Hi guys,

I'm trying to build a demo of image segmentation (virtual background) with WebRTC on Web + Android.

The Web part is done and was pretty straightforward using Canvas.

When it comes to Android, things get much more complicated.

Currently I'm trying to use videoSource?.setVideoProcessor(backgroundProcessor) with a custom VirtualBackgroundProcessor class to handle it, but it doesn't work: the local stream still shows the raw camera feed. It does feel a bit laggy, so frames are probably going through MediaPipe, but something is wrong somewhere.
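
For reference, here is roughly how I attach the processor (a minimal sketch; peerConnectionFactory, eglBase and videoCapturer are placeholders for the existing setup):

val videoSource = peerConnectionFactory.createVideoSource(false)
val backgroundProcessor = VirtualBackgroundProcessor(applicationContext, eglBase)

// Attach the processor before the capturer starts pushing frames
videoSource.setVideoProcessor(backgroundProcessor)

videoCapturer.initialize(
    SurfaceTextureHelper.create("CaptureThread", eglBase.eglBaseContext),
    applicationContext,
    videoSource.capturerObserver
)
videoCapturer.startCapture(1280, 720, 30)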

Seeking your help, thank you 🙏

Here is my code (maitrungduc1410/WebRTC-Demo@960061e#diff-98114fe57333170b2414a799579f50fdadafd9c7fe003c7b118967b2581ac55a):

package com.example.myapplication.webrtc

import android.content.Context
import android.graphics.*
import android.util.Log
import com.example.myapplication.R
import com.google.mediapipe.framework.image.BitmapImageBuilder
import com.google.mediapipe.framework.image.ByteBufferExtractor
import com.google.mediapipe.tasks.core.BaseOptions
import com.google.mediapipe.tasks.vision.core.RunningMode
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter
import org.webrtc.*
import java.nio.ByteBuffer

/**
 * VirtualBackgroundProcessor applies virtual background effect using MediaPipe.
 * 
 * Key insight from MediaPipe selfie_segmenter model:
 * - Mask value 0 = person (foreground)
 * - Mask value 255 = background
 */
class VirtualBackgroundProcessor(
    context: Context,
    private val rootEglBase: EglBase
) : VideoProcessor {
    
    companion object {
        private const val TAG = "VirtualBgProcessor"
        private const val SEGMENTATION_WIDTH = 256  // Downscale for fast inference
        private const val INFERENCE_INTERVAL = 3  // Process every Nth frame for segmentation
    }
    
    private var imageSegmenter: ImageSegmenter? = null
    private var sink: VideoSink? = null
    private val backgroundBitmap: Bitmap = BitmapFactory.decodeResource(
        context.resources, 
        R.drawable.virtual_background
    )

    private val maskLock = Any()
    private var safeMaskBuffer: ByteBuffer? = null
    private var maskWidth = 0
    private var maskHeight = 0
    
    // Frame counter for skipping inference
    private var frameCount = 0
    
    // Pre-scaled background for faster compositing
    private var scaledBackgroundCache: Bitmap? = null
    private var cachedWidth = 0
    private var cachedHeight = 0

    init {
        try {
            val options = ImageSegmenter.ImageSegmenterOptions.builder()
                .setBaseOptions(
                    BaseOptions.builder()
                        .setModelAssetPath("selfie_segmenter.tflite")
                        .build()
                )
                .setRunningMode(RunningMode.LIVE_STREAM)
                .setOutputCategoryMask(true)
                .setOutputConfidenceMasks(false)
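                // LIVE_STREAM delivers the mask asynchronously on a MediaPipe worker
                // thread, so the listener must hand it off to the render path safely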
                .setResultListener { result, _ ->
                    try {
                        val maskImage = result.categoryMask().get()
                        val extracted = ByteBufferExtractor.extract(maskImage)
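                        // Copy the mask out now: the extracted buffer belongs to the
                        // MPImage and may not be valid after this callback returns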

                        synchronized(maskLock) {
                            if (safeMaskBuffer == null || safeMaskBuffer!!.capacity() < extracted.capacity()) {
                                safeMaskBuffer = ByteBuffer.allocateDirect(extracted.capacity())
                            }
                            safeMaskBuffer!!.clear()
                            extracted.rewind()
                            safeMaskBuffer!!.put(extracted)
                            safeMaskBuffer!!.rewind()
                            
                            maskWidth = maskImage.width
                            maskHeight = maskImage.height
                        }
                    } catch (e: Exception) {
                        Log.e(TAG, "Error in segmentation result", e)
                    }
                }
                .setErrorListener { error ->
                    Log.e(TAG, "MediaPipe error: ${error.message}", error)
                }
                .build()
            imageSegmenter = ImageSegmenter.createFromOptions(context, options)
            Log.d(TAG, "MediaPipe initialized successfully")
        } catch (e: Exception) {
            Log.e(TAG, "Failed to initialize MediaPipe", e)
        }
    }

    override fun onFrameCaptured(frame: VideoFrame) {
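        // Every frame must be forwarded to the sink (processed or passthrough);
        // if frames stop flowing, the local track renders nothing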
        try {
            frameCount++
            
            // Only run segmentation on every Nth frame to save CPU
            if (frameCount % INFERENCE_INTERVAL == 0) {
                runSegmentation(frame)
            }

            // Apply background if mask ready
            synchronized(maskLock) {
                if (safeMaskBuffer != null && safeMaskBuffer!!.capacity() > 0) {
                    val processedFrame = applyVirtualBackgroundFast(frame)
                    sink?.onFrame(processedFrame)
                    processedFrame.release()
                } else {
                    // No mask yet (first frames), pass through
                    sink?.onFrame(frame)
                }
            }
        } catch (e: Exception) {
            Log.e(TAG, "Error processing frame", e)
            sink?.onFrame(frame)
        }
    }
    
    private fun runSegmentation(frame: VideoFrame) {
        try {
            val aspectRatio = frame.buffer.height.toFloat() / frame.buffer.width
            // Keep the inference height even so the I420 chroma planes stay aligned
            val inferenceHeight = ((SEGMENTATION_WIDTH * aspectRatio).toInt() / 2) * 2
            
            val scaledBuffer = frame.buffer.cropAndScale(
                0, 0,
                frame.buffer.width,
                frame.buffer.height,
                SEGMENTATION_WIDTH,
                inferenceHeight
            )
            val i420Buffer = scaledBuffer.toI420()!!

            val inferenceBitmap = i420BufferToBitmap(i420Buffer)
            i420Buffer.release()
            scaledBuffer.release()

            // Trigger MediaPipe async segmentation
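            // segmentAsync requires a monotonically increasing timestamp in milliseconds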
            val mpImage = BitmapImageBuilder(inferenceBitmap).build()
            imageSegmenter?.segmentAsync(mpImage, frame.timestampNs / 1000000)
            inferenceBitmap.recycle()
        } catch (e: Exception) {
            Log.e(TAG, "Error in segmentation", e)
        }
    }

    private fun applyVirtualBackgroundFast(originalFrame: VideoFrame): VideoFrame {
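        // NOTE: this is an all-CPU path (I420 -> Bitmap -> Canvas -> I420), which is
        // workable for a demo; a GPU/texture-based path would be needed for smooth HD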
        // Convert to full-res bitmap for quality
        val fullI420 = originalFrame.buffer.toI420()!!
        val width = fullI420.width
        val height = fullI420.height

        val personBitmap = i420BufferToBitmap(fullI420)
        fullI420.release()

        // Output bitmap
        val outputBitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888)
        val canvas = Canvas(outputBitmap)

        // Use cached scaled background if size matches
        if (scaledBackgroundCache == null || cachedWidth != width || cachedHeight != height) {
            scaledBackgroundCache?.recycle()
            scaledBackgroundCache = Bitmap.createScaledBitmap(backgroundBitmap, width, height, true)
            cachedWidth = width
            cachedHeight = height
        }
        
        // Draw pre-scaled background (faster than scaling on-demand)
        canvas.drawBitmap(scaledBackgroundCache!!, 0f, 0f, null)

        // Get and scale mask
        val maskBitmap = synchronized(maskLock) {
            val rawMask = Bitmap.createBitmap(maskWidth, maskHeight, Bitmap.Config.ALPHA_8)
            safeMaskBuffer?.rewind()
            rawMask.copyPixelsFromBuffer(safeMaskBuffer!!)
            
            val scaledMask = Bitmap.createScaledBitmap(rawMask, width, height, true)
            rawMask.recycle()
            scaledMask
        }

        // Create person layer
        // CRITICAL: selfie_segmenter outputs 0=person, 255=background
        // We need to invert for proper masking
        val personLayer = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888)
        val layerCanvas = Canvas(personLayer)
        layerCanvas.drawBitmap(personBitmap, 0f, 0f, null)

        // Invert mask: 0 (person) -> 255 (opaque), 255 (bg) -> 0 (transparent)
        val invertedMask = invertMaskBitmap(maskBitmap)
        maskBitmap.recycle()

        // Apply mask to keep only person
        val paint = Paint(Paint.ANTI_ALIAS_FLAG)
        paint.xfermode = PorterDuffXfermode(PorterDuff.Mode.DST_IN)
        layerCanvas.drawBitmap(invertedMask, 0f, 0f, paint)
        invertedMask.recycle()

        // Composite person on background
        canvas.drawBitmap(personLayer, 0f, 0f, null)

        personLayer.recycle()
        personBitmap.recycle()

        // Convert back to VideoFrame
        val processedBuffer = bitmapToI420Buffer(outputBitmap)
        outputBitmap.recycle()
        
        return VideoFrame(processedBuffer, originalFrame.rotation, originalFrame.timestampNs)
    }

    private fun invertMaskBitmap(mask: Bitmap): Bitmap {
        val width = mask.width
        val height = mask.height
        val inverted = Bitmap.createBitmap(width, height, Bitmap.Config.ALPHA_8)
        
        val pixels = ByteBuffer.allocate(width * height)
        mask.copyPixelsToBuffer(pixels)
        pixels.rewind()
        
        val invertedPixels = ByteBuffer.allocate(width * height)
        for (i in 0 until width * height) {
            val value = pixels.get(i).toInt() and 0xFF
            // Invert: 0 (person) -> 255, 255 (bg) -> 0
            invertedPixels.put((255 - value).toByte())
        }
        invertedPixels.rewind()
        
        inverted.copyPixelsFromBuffer(invertedPixels)
        return inverted
    }

    private fun bitmapToI420Buffer(bitmap: Bitmap): VideoFrame.I420Buffer {
        val width = bitmap.width
        val height = bitmap.height
        
        val argbBuffer = ByteBuffer.allocateDirect(width * height * 4)
        bitmap.copyPixelsToBuffer(argbBuffer)
        argbBuffer.rewind()

        val i420Buffer = JavaI420Buffer.allocate(width, height)
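        // copyPixelsToBuffer above emits bytes in R,G,B,A order, which libyuv
        // (and YuvHelper) call "ABGR" in little-endian word order, hence ABGRToI420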
        
        YuvHelper.ABGRToI420(
            argbBuffer, width * 4,
            i420Buffer.dataY, i420Buffer.strideY,
            i420Buffer.dataU, i420Buffer.strideU,
            i420Buffer.dataV, i420Buffer.strideV,
            width, height
        )
        
        return i420Buffer
    }

    private fun i420BufferToBitmap(i420Buffer: VideoFrame.I420Buffer): Bitmap {
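        // Per-pixel BT.601 YUV -> RGB done in Kotlin; note this also runs at full
        // resolution from applyVirtualBackgroundFast, so it is the main hotspot here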
        val width = i420Buffer.width
        val height = i420Buffer.height
        val pixels = IntArray(width * height)

        val yBuf = i420Buffer.dataY
        val uBuf = i420Buffer.dataU
        val vBuf = i420Buffer.dataV

        for (i in 0 until height) {
            for (j in 0 until width) {
                val yIndex = i * i420Buffer.strideY + j
                // U and V planes may have different strides; index each with its own
                val uIndex = (i / 2) * i420Buffer.strideU + (j / 2)
                val vIndex = (i / 2) * i420Buffer.strideV + (j / 2)

                val y = (yBuf.get(yIndex).toInt() and 0xFF)
                val u = (uBuf.get(uIndex).toInt() and 0xFF) - 128
                val v = (vBuf.get(vIndex).toInt() and 0xFF) - 128

                val r = (y + 1.370705 * v).toInt().coerceIn(0, 255)
                val g = (y - 0.337633 * u - 0.698001 * v).toInt().coerceIn(0, 255)
                val b = (y + 1.732446 * u).toInt().coerceIn(0, 255)
                
                pixels[i * width + j] = Color.rgb(r, g, b)
            }
        }

        return Bitmap.createBitmap(pixels, width, height, Bitmap.Config.ARGB_8888)
    }

    override fun setSink(sink: VideoSink?) {
        this.sink = sink
    }

    override fun onCapturerStarted(success: Boolean) {
        Log.d(TAG, "Capturer started: $success")
    }

    override fun onCapturerStopped() {
        Log.d(TAG, "Capturer stopped")
    }
    
    fun cleanup() {
        imageSegmenter?.close()
        imageSegmenter = null
        if (!backgroundBitmap.isRecycled) {
            backgroundBitmap.recycle()
        }
        scaledBackgroundCache?.recycle()
        scaledBackgroundCache = null
        Log.d(TAG, "Cleanup complete")
    }
}
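
For completeness, this is how I tear things down (same placeholder names as the wiring sketch above):

videoCapturer.stopCapture()
videoSource.setVideoProcessor(null)
backgroundProcessor.cleanup()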
