
Large output errors when converting MobileNetV3 in FP16 #2625

@johan-sightic

Description

🐞Describing the bug

I trained the default torchvision implementation of MobileNetV3Large and converted it to CoreML. In fp32 both models give identical results, but when converted to fp16 the CoreML output deviates from the fp32 reference much more than the fp16 Torch model does. I created a self-contained example that shows the difference on constant grayscale images (pixel values 0 to 255), but the errors are equally large on my real data.

To Reproduce

import asyncio

import coremltools as ct
import matplotlib.pyplot as plt
import numpy as np
import torch
from coremltools.converters.mil.mil import types
from coremltools.models.ml_program.experimental.debugging_utils import MLModelComparator
from PIL import Image
from tabulate import tabulate
from torchvision.models import MobileNet_V3_Large_Weights, mobilenet_v3_large
from tqdm import tqdm

with torch.no_grad():
    model_torch32 = mobilenet_v3_large(weights=MobileNet_V3_Large_Weights.DEFAULT)
    model_torch32.eval()
    model_torch16 = mobilenet_v3_large(weights=MobileNet_V3_Large_Weights.DEFAULT)
    model_torch16.eval().to(torch.float16)

    model_coreml32 = ct.converters.convert(
        model=torch.jit.trace(model_torch32, torch.zeros((1, 3, 224, 224))),
        inputs=[ct.ImageType(name="image", shape=(1, 3, 224, 224), scale=1 / 255.0)],
        outputs=[ct.TensorType(name="y", dtype=types.fp32)],
        minimum_deployment_target=ct.target.iOS18,
        compute_units=ct.ComputeUnit.CPU_ONLY,
        compute_precision=ct.precision.FLOAT32,
    )
    model_coreml16 = ct.converters.convert(
        model=torch.jit.trace(model_torch32, torch.zeros((1, 3, 224, 224))),
        inputs=[ct.ImageType(name="image", shape=(1, 3, 224, 224), scale=1 / 255.0)],
        outputs=[ct.TensorType(name="y", dtype=types.fp16)],
        minimum_deployment_target=ct.target.iOS18,
        compute_units=ct.ComputeUnit.CPU_ONLY,
        compute_precision=ct.precision.FLOAT16,
    )

    # Compare results
    xs = list(range(0, 256))
    torch32_preds = []
    torch16_preds = []
    coreml32_preds = []
    coreml16_preds = []
    for x in tqdm(xs):
        x_torch = torch.full((1, 3, 224, 224), x, dtype=torch.float32) / 255.0
        torch32_preds.append(model_torch32(x_torch)[0, 0].item())
        torch16_preds.append(model_torch16(x_torch.to(torch.float16))[0, 0].item())

        x_coreml = Image.fromarray(np.full((224, 224, 3), x, dtype=np.uint8))
        coreml32_preds.append(model_coreml32.predict({"image": x_coreml})["y"][0, 0])
        coreml16_preds.append(model_coreml16.predict({"image": x_coreml})["y"][0, 0])

# Print error table
print(
    tabulate(
        [
            [
                "Torch ",
                np.mean(np.abs(np.array(torch32_preds) - np.array(torch16_preds))),
                np.mean((np.array(torch32_preds) - np.array(torch16_preds)) ** 2),
            ],
            [
                "CoreML",
                np.mean(np.abs(np.array(coreml32_preds) - np.array(coreml16_preds))),
                np.mean((np.array(coreml32_preds) - np.array(coreml16_preds)) ** 2),
            ],
        ],
        headers=["Model", "L1", "L2"],
    )
)


# plot
plt.plot(xs, torch32_preds, label="Torch32")
plt.plot(xs, torch16_preds, label="Torch16")
plt.plot(xs, coreml32_preds, label="CoreML32", linestyle="dashed")
plt.plot(xs, coreml16_preds, label="CoreML16", linestyle="dashed")
plt.legend()
plt.title("MobileNetV3: Torch vs CoreML")
plt.xlabel("Input pixel value")
plt.ylabel("Output value")
plt.savefig("torch_vs_coreml.png")
plt.close()

fp32 vs fp16 comparison for Torch and CoreML:

Model     L1 Error     L2 Error
-------  ---------  -----------
Torch    0.0164849  0.000468726
CoreML   0.0393638  0.00248957  <-- much larger

Comparison plot (torch_vs_coreml.png); the CoreML fp16 output is extremely noisy:
[plot: "MobileNetV3: Torch vs CoreML", output value vs. input pixel value for Torch32/Torch16/CoreML32/CoreML16]
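
One way to narrow this down (untested sketch; which ops are actually precision-sensitive is an open question, and the selector below is only an assumption) would be to convert with a mixed-precision op_selector so that selected op types stay in fp32 while everything else runs in fp16:

def keep_conv_and_linear_fp32(op):
    # Return True for ops that may run in fp16; keep conv/linear in fp32.
    # The choice of op types here is a guess for diagnosis, not a verified fix.
    return op.op_type not in ("conv", "linear")

model_coreml_mixed = ct.converters.convert(
    model=torch.jit.trace(model_torch32, torch.zeros((1, 3, 224, 224))),
    inputs=[ct.ImageType(name="image", shape=(1, 3, 224, 224), scale=1 / 255.0)],
    outputs=[ct.TensorType(name="y", dtype=types.fp32)],
    minimum_deployment_target=ct.target.iOS18,
    compute_units=ct.ComputeUnit.CPU_ONLY,
    compute_precision=ct.transform.FP16ComputePrecision(op_selector=keep_conv_and_linear_fp32),
)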

I also tried MLModelComparator to locate the source of the error, but it reports ordinary "linear" and "conv" ops as failing (depending on atol):

def compare_outputs(operation, reference_output, target_output):
    return np.allclose(reference_output, target_output, atol=1e-1)


comparator = MLModelComparator(
    reference_model=model_coreml32,
    target_model=model_coreml16,
    num_predict_intermediate_outputs=720,
)
failing_ops = asyncio.run(
    comparator.find_failing_ops(inputs={"image": x_coreml}, compare_outputs=compare_outputs)
)
print(failing_ops)
Analyzed operation: classifier_3_weight, type: const:   1%|▏                                     | 4/720 [00:00<02:18,  5.18it/s]
[type: "linear"
inputs {
  key: "x"
  value {
    arguments {
      name: "input_331"
    }
  }
}
inputs {
  key: "weight"
  value {
    arguments {
      name: "classifier_3_weight"
    }
  }
}
inputs {
  key: "bias"
  value {
    arguments {
      name: "classifier_3_bias"
    }
  }
}
outputs {
  name: "y"
  type {
    tensorType {
      dataType: FLOAT32
      rank: 2
      dimensions {
        constant {
          size: 1
        }
      }
      dimensions {
        constant {
          size: 1000
        }
      }
    }
  }
}
attributes {
  key: "name"
  value {
    type {
      tensorType {
        dataType: STRING
      }
    }
    immediateValue {
      tensor {
        strings {
          values: "linear_1"
        }
      }
    }
  }
}
]
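
For completeness: the compare_outputs above uses a purely absolute tolerance. A relative-tolerance variant (sketch only; the tolerance values are arbitrary assumptions) might avoid flagging ops just because their activations are large in magnitude:

def compare_outputs(operation, reference_output, target_output):
    # Combine relative and absolute tolerance so that large-magnitude
    # activations are not reported as failures purely due to fp16 rounding.
    return np.allclose(reference_output, target_output, rtol=1e-2, atol=1e-2)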

System environment

coremltools version: 9.0
OS: macOS 26.1
PyTorch version: 2.7.0
