🐞Describing the bug
I trained the default torchvision implementation of MobileNetV3Large and converted it to CoreML. In fp32 both models give identical results, but when converted to fp16 the CoreML output deviates far more from the fp32 reference than the fp16 Torch model does. I created a self-contained example that shows the difference on uniform grayscale images (pixel values 0 to 255); the errors are equally large on my real data.
To Reproduce
import asyncio
import coremltools as ct
import matplotlib.pyplot as plt
import numpy as np
import torch
from coremltools.converters.mil.mil import types
from coremltools.models.ml_program.experimental.debugging_utils import MLModelComparator
from PIL import Image
from tabulate import tabulate
from torchvision.models import MobileNet_V3_Large_Weights, mobilenet_v3_large
from tqdm import tqdm
with torch.no_grad():
    model_torch32 = mobilenet_v3_large(weights=MobileNet_V3_Large_Weights.DEFAULT)
    model_torch32.eval()

    model_torch16 = mobilenet_v3_large(weights=MobileNet_V3_Large_Weights.DEFAULT)
    model_torch16.eval().to(torch.float16)

    model_coreml32 = ct.converters.convert(
        model=torch.jit.trace(model_torch32, torch.zeros((1, 3, 224, 224))),
        inputs=[ct.ImageType(name="image", shape=(1, 3, 224, 224), scale=1 / 255.0)],
        outputs=[ct.TensorType(name="y", dtype=types.fp32)],
        minimum_deployment_target=ct.target.iOS18,
        compute_units=ct.ComputeUnit.CPU_ONLY,
        compute_precision=ct.precision.FLOAT32,
    )

    model_coreml16 = ct.converters.convert(
        model=torch.jit.trace(model_torch32, torch.zeros((1, 3, 224, 224))),
        inputs=[ct.ImageType(name="image", shape=(1, 3, 224, 224), scale=1 / 255.0)],
        outputs=[ct.TensorType(name="y", dtype=types.fp16)],
        minimum_deployment_target=ct.target.iOS18,
        compute_units=ct.ComputeUnit.CPU_ONLY,
        compute_precision=ct.precision.FLOAT16,
    )

    # Compare results
    xs = list(range(0, 256))
    torch32_preds = []
    torch16_preds = []
    coreml32_preds = []
    coreml16_preds = []
    for x in tqdm(xs):
        x_torch = torch.full((1, 3, 224, 224), x, dtype=torch.float32) / 255.0
        torch32_preds.append(model_torch32(x_torch)[0, 0].item())
        torch16_preds.append(model_torch16(x_torch.to(torch.float16))[0, 0].item())
        x_coreml = Image.fromarray(np.full((224, 224, 3), x, dtype=np.uint8))
        coreml32_preds.append(model_coreml32.predict({"image": x_coreml})["y"][0, 0])
        coreml16_preds.append(model_coreml16.predict({"image": x_coreml})["y"][0, 0])

# Print error table
print(
    tabulate(
        [
            [
                "Torch ",
                np.mean(np.abs(np.array(torch32_preds) - np.array(torch16_preds))),
                np.mean((np.array(torch32_preds) - np.array(torch16_preds)) ** 2),
            ],
            [
                "CoreML",
                np.mean(np.abs(np.array(coreml32_preds) - np.array(coreml16_preds))),
                np.mean((np.array(coreml32_preds) - np.array(coreml16_preds)) ** 2),
            ],
        ],
        headers=["Model", "L1", "L2"],
    )
)

# plot
plt.plot(xs, torch32_preds, label="Torch32")
plt.plot(xs, torch16_preds, label="Torch16")
plt.plot(xs, coreml32_preds, label="CoreML32", linestyle="dashed")
plt.plot(xs, coreml16_preds, label="CoreML16", linestyle="dashed")
plt.legend()
plt.title("MobileNetV3: Torch vs CoreML")
plt.xlabel("Input pixel value")
plt.ylabel("Output value")
plt.savefig("torch_vs_coreml.png")
plt.close()

fp32 vs fp16 comparison for Torch and CoreML:
Model      L1 Error    L2 Error
-------    ---------   -----------
Torch      0.0164849   0.000468726
CoreML     0.0393638   0.00248957   <-- much larger
Comparison plot; the CoreML fp16 output is extremely noisy:

[torch_vs_coreml.png]
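For reference, here is a variant of the fp16 conversion above (an untested sketch, not part of my measurements) that takes the ImageType/uint8 preprocessing out of the comparison by feeding a plain fp32 TensorType input:

# Sketch (untested): same fp16 conversion, but with a plain fp32 tensor input
# instead of ImageType, so the uint8 image path is not part of the comparison.
model_coreml16_tensor = ct.converters.convert(
    model=torch.jit.trace(model_torch32, torch.zeros((1, 3, 224, 224))),
    inputs=[ct.TensorType(name="image", shape=(1, 3, 224, 224), dtype=np.float32)],
    outputs=[ct.TensorType(name="y", dtype=types.fp16)],
    minimum_deployment_target=ct.target.iOS18,
    compute_units=ct.ComputeUnit.CPU_ONLY,
    compute_precision=ct.precision.FLOAT16,
)

x_np = np.full((1, 3, 224, 224), 128 / 255.0, dtype=np.float32)
print(model_coreml16_tensor.predict({"image": x_np})["y"][0, 0])

If the noise persists with a tensor input, the image preprocessing path can be ruled out.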
I also tried MLModelComparator to track down where the error comes from, but it reports failures on ordinary "linear" and "conv" ops (depending on the atol):
def compare_outputs(operation, reference_output, target_output):
    return np.allclose(reference_output, target_output, atol=1e-1)

comparator = MLModelComparator(
    reference_model=model_coreml32,
    target_model=model_coreml16,
    num_predict_intermediate_outputs=720,
)
failing_ops = asyncio.run(
    comparator.find_failing_ops(inputs={"image": x_coreml}, compare_outputs=compare_outputs)
)
print(failing_ops)

Analyzed operation: classifier_3_weight, type: const: 1%|▏ | 4/720 [00:00<02:18, 5.18it/s]
[type: "linear"
inputs {
key: "x"
value {
arguments {
name: "input_331"
}
}
}
inputs {
key: "weight"
value {
arguments {
name: "classifier_3_weight"
}
}
}
inputs {
key: "bias"
value {
arguments {
name: "classifier_3_bias"
}
}
}
outputs {
name: "y"
type {
tensorType {
dataType: FLOAT32
rank: 2
dimensions {
constant {
size: 1
}
}
dimensions {
constant {
size: 1000
}
}
}
}
}
attributes {
key: "name"
value {
type {
tensorType {
dataType: STRING
}
}
immediateValue {
tensor {
strings {
values: "linear_1"
}
}
}
}
}
]
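If it helps with triage: a workaround I would try next is the selective fp16 transform, which keeps specific ops in fp32 while the rest of the network is converted to fp16. This is only a sketch; the op-type filter below is a guess based on the linear op the comparator flagged, and I have not verified it fixes the noise.

# Sketch (untested): cast everything to fp16 except ops that look precision-sensitive.
# Returning True from the selector means "cast this op to fp16".
def cast_to_fp16(op):
    # "linear" is taken from the op flagged by MLModelComparator above.
    return op.op_type not in ("linear",)

model_coreml16_mixed = ct.converters.convert(
    model=torch.jit.trace(model_torch32, torch.zeros((1, 3, 224, 224))),
    inputs=[ct.ImageType(name="image", shape=(1, 3, 224, 224), scale=1 / 255.0)],
    outputs=[ct.TensorType(name="y", dtype=types.fp32)],
    minimum_deployment_target=ct.target.iOS18,
    compute_units=ct.ComputeUnit.CPU_ONLY,
    compute_precision=ct.transform.FP16ComputePrecision(cast_to_fp16),
)

That would not explain the underlying precision loss, but it would at least confirm whether a small set of ops is responsible.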
System environment
coremltools version: 9.0
OS: macOS 26.1
PyTorch version: 2.7.0