Commit 5b51bb8

Support sym round and ceil
Differential Revision: D65382714
Pull Request resolved: #6699
1 parent 427b36d commit 5b51bb8
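This commit teaches the export/lowering stack to handle round() and math.ceil() on symbolic scalars, alongside the existing trunc support. For reference, Python's round() rounds halfway values to the nearest even integer (banker's rounding), while C's round() rounds them away from zero; math.ceil() always rounds toward positive infinity. A quick illustration of the semantics the new kernels must match:

    import math

    # Python rounds halfway cases to the even neighbor, so 0.5 and 1.5
    # do not round the same way; C's round() would give 1 and 2 here.
    print(round(0.5), round(1.5), round(2.5))  # 0 2 2
    print(math.ceil(0.25), math.ceil(-1.5))    # 1 -1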

5 files changed (+107, -4)


exir/pass_base.py (+5, -1)
@@ -318,7 +318,11 @@ def call_function(
         if target == operator.getitem:
             value, key = args
             return self.callback.call_getitem(value, key, meta)
-        elif getattr(target, "__module__", None) in {"_operator", "math"}:
+        elif getattr(target, "__module__", None) in {
+            "_operator",
+            "builtins",
+            "math",
+        }:
             assert callable(target)
             return self.callback.call_sym(target, args, meta)
         elif target in _TORCH_SYM_OPS:
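The membership check above dispatches on the target callable's defining module. builtins.round reports "builtins" rather than "_operator" or "math", which is why the set gains a third entry. A quick check:

    import builtins
    import math
    import operator

    # The pass routes a call_function node to call_sym based on the
    # target's __module__; round lives in "builtins".
    print(operator.add.__module__)    # _operator
    print(math.ceil.__module__)       # math
    print(builtins.round.__module__)  # builtins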

exir/passes/__init__.py (+1, -1)
@@ -339,7 +339,7 @@ def get_submodule(node: torch.fx.Node) -> torch.fx.GraphModule:
                 self.call(get_submodule(node.args[0]))
                 self.call(get_submodule(node.args[1]))
                 continue
-            elif getattr(target, "__module__", None) == "_operator":
+            elif getattr(target, "__module__", None) in ("builtins", "_operator"):
                 continue
             elif target in to_out_var_skiplist:
                 continue
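The same reasoning applies in this pass: sym targets such as round have no out-variant kernel and must be skipped rather than converted. A minimal sketch of the check, with is_sym_target as a hypothetical helper name for illustration:

    def is_sym_target(target) -> bool:
        # Targets defined in builtins or _operator are handled as sym ops
        # elsewhere, so the out-variant conversion skips them.
        return getattr(target, "__module__", None) in ("builtins", "_operator")

    assert is_sym_target(round)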

exir/passes/executorch_prim_ops_registry.py (+15, -2)
@@ -4,9 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import builtins
 import math
 import operator
-from typing import Dict, Set, Union
+from typing import Any, Dict, Set, Union

 # necessary to ensure the ops are registered
 import torch
@@ -94,12 +95,24 @@ def neg(a: _SymScalar) -> _SymScalar:
     return -a # pyre-ignore


+@bind_pattern_to_op(executorch_prims_lib, "ceil.Scalar(Scalar a) -> Scalar")
+def ceil(a: _SymScalar) -> _SymScalar:
+    return math.ceil(a) # pyre-ignore
+
+
+@bind_pattern_to_op(executorch_prims_lib, "round.Scalar(Scalar a) -> Scalar")
+def builtin_round(a: _SymScalar) -> _SymScalar:
+    return round(a) # pyre-ignore
+
+
 @bind_pattern_to_op(executorch_prims_lib, "trunc.Scalar(Scalar a) -> Scalar")
 def trunc(a: _SymScalar) -> _SymScalar:
     return math.trunc(a) # pyre-ignore


-_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[OpOverload, OpOverload] = {
+_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[Any, OpOverload] = {
+    builtins.round: ops.backend.executorch_prim.round.Scalar,
+    math.ceil: ops.backend.executorch_prim.ceil.Scalar,
     math.trunc: ops.backend.executorch_prim.trunc.Scalar,
     operator.sub: ops.backend.executorch_prim.sub.Scalar,
     operator.mul: ops.backend.executorch_prim.mul.Scalar,
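Note the key type widens from OpOverload to Any: builtins.round and math.ceil are plain Python callables, not torch OpOverloads, yet they must map to the same kind of executorch_prim target. A small demonstration of why such keys work (the string values stand in for the real OpOverload objects):

    import builtins
    import math

    # Plain callables are hashable, so they serve directly as dict keys.
    mapping = {
        builtins.round: "executorch_prim::round.Scalar",
        math.ceil: "executorch_prim::ceil.Scalar",
    }
    assert mapping[round] is mapping[builtins.round]  # round is builtins.round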

kernels/prim_ops/register_prim_ops.cpp (+45)
@@ -303,6 +303,51 @@ static Kernel prim_ops[] = {
           }
         }),

+    // ceil.Scalar(Scalar a) -> Scalar
+    Kernel(
+        "executorch_prim::ceil.Scalar",
+        [](KernelRuntimeContext& context, EValue** stack) {
+          (void)context;
+          EValue& a = *stack[0];
+          EValue& out = *stack[1];
+          if (a.isDouble()) {
+            out = EValue(static_cast<int64_t>(ceil(a.toDouble())));
+          } else {
+            ET_CHECK_MSG(false, "Unsupported DType %zu", (size_t)a.tag);
+          }
+        }),
+
+    // round.Scalar(Scalar a) -> Scalar
+    Kernel(
+        "executorch_prim::round.Scalar",
+        [](KernelRuntimeContext& context, EValue** stack) {
+          (void)context;
+          EValue& a = *stack[0];
+          EValue& out = *stack[1];
+          if (a.isDouble()) {
+            // Round half to even to match Python round(). Need an explicit
+            // implementation as not all platforms support fenv rounding modes.
+            // See
+            // https://codeyarns.com/tech/2018-08-17-how-to-round-half-to-even.html
+            const auto val = a.toDouble();
+            const auto r = round(val);
+            const auto d = r - val;
+            auto res = 0.0;
+
+            if (std::abs(d) != 0.5) {
+              res = r;
+            } else if (fmod(r, 2.0) == 0.0) {
+              res = r;
+            } else {
+              res = val - d;
+            }
+
+            out = EValue(static_cast<int64_t>(res));
+          } else {
+            ET_CHECK_MSG(false, "Unsupported DType %zu", (size_t)a.tag);
+          }
+        }),
+
     // trunc.Scalar(Scalar a) -> Scalar
     Kernel(
         "executorch_prim::trunc.Scalar",

kernels/prim_ops/test/prim_ops_test.cpp (+41)
@@ -503,6 +503,47 @@ TEST_F(RegisterPrimOpsTest, TestETViewEmpty) {
       getOpsFn("executorch_prim::et_view.default")(context, bad_stack), "");
 }

+TEST_F(RegisterPrimOpsTest, TestCeil) {
+  std::array<double, 10> inputs = {
+      0.0, 0.25, 0.5, 0.75, 1.0, 1.75, -0.5, -1.0, -1.5, 9.999999};
+  std::array<int64_t, 10> expected = {0, 1, 1, 1, 1, 2, 0, -1, -1, 10};
+
+  for (auto i = 0; i < inputs.size(); i++) {
+    EValue values[2];
+    values[0] = EValue(inputs[i]);
+    values[1] = EValue(0.0);
+
+    EValue* stack[2];
+    for (size_t j = 0; j < 2; j++) {
+      stack[j] = &values[j];
+    }
+
+    getOpsFn("executorch_prim::ceil.Scalar")(context, stack);
+    EXPECT_EQ(stack[1]->toInt(), expected[i]);
+  }
+}
+
+TEST_F(RegisterPrimOpsTest, TestRound) {
+  // Note that Python uses round-to-even for halfway values.
+  std::array<double, 10> inputs = {
+      0.0, 0.25, 0.5, 0.75, 1.0, 1.5, -0.5, -1.0, -1.5, 9.999999};
+  std::array<int64_t, 10> expected = {0, 0, 0, 1, 1, 2, 0, -1, -2, 10};
+
+  for (auto i = 0; i < inputs.size(); i++) {
+    EValue values[2];
+    values[0] = EValue(inputs[i]);
+    values[1] = EValue(0.0);
+
+    EValue* stack[2];
+    for (size_t j = 0; j < 2; j++) {
+      stack[j] = &values[j];
+    }
+
+    getOpsFn("executorch_prim::round.Scalar")(context, stack);
+    EXPECT_EQ(stack[1]->toInt(), expected[i]);
+  }
+}
+
 TEST_F(RegisterPrimOpsTest, TestTrunc) {
   std::array<double, 10> inputs = {
       0.0, 0.25, 0.5, 0.75, 1.0, 1.75, -0.5, -1.0, -1.5, 9.999999};
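The calling convention in these tests follows the prim-op stack ABI visible in the kernels above: stack[0] holds the input Scalar and stack[1] receives the result. The expected arrays can be reproduced directly from Python, which defines the reference behavior:

    import math

    ceil_inputs = [0.0, 0.25, 0.5, 0.75, 1.0, 1.75, -0.5, -1.0, -1.5, 9.999999]
    round_inputs = [0.0, 0.25, 0.5, 0.75, 1.0, 1.5, -0.5, -1.0, -1.5, 9.999999]

    # Matches `expected` in TestCeil: [0, 1, 1, 1, 1, 2, 0, -1, -1, 10]
    print([math.ceil(v) for v in ceil_inputs])
    # Matches `expected` in TestRound: [0, 0, 0, 1, 1, 2, 0, -1, -2, 10]
    print([round(v) for v in round_inputs])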
