diff --git a/kernels/portable/cpu/op_round.cpp b/kernels/portable/cpu/op_round.cpp index c14d75dc31e..eb4559fbb96 100644 --- a/kernels/portable/cpu/op_round.cpp +++ b/kernels/portable/cpu/op_round.cpp @@ -43,14 +43,18 @@ Tensor& round_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { ET_KERNEL_CHECK( ctx, tensors_have_same_shape_and_dtype(in, out), InvalidArgument, out); - ET_KERNEL_CHECK(ctx, tensor_is_real_type(out), InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensor_is_realhbf16_type(out), + InvalidArgument, + out); ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); auto in_scalar_type = in.scalar_type(); - ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "round.out", CTYPE, [&] { + ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "round.out", CTYPE, [&] { apply_unary_map_fn( [in_scalar_type](const CTYPE val_in) { if (isIntegralType(in_scalar_type, /*includeBool=*/false)) { diff --git a/kernels/test/op_round_test.cpp b/kernels/test/op_round_test.cpp index 0b39aab22c2..71fda4a50d2 100644 --- a/kernels/test/op_round_test.cpp +++ b/kernels/test/op_round_test.cpp @@ -84,6 +84,14 @@ TEST_F(OpRoundTest, DoubleTensors) { test_round_execution_floats(); } +TEST_F(OpRoundTest, HalfTensors) { + test_round_execution_floats(); +} + +TEST_F(OpRoundTest, BFloat16Tensors) { + test_round_execution_floats(); +} + TEST_F(OpRoundTest, ByteTensors) { TensorFactory tf;