
Support for annotating Optional inputs #280

@justinchuby

Description

We would like the following `aten::clamp` implementation, whose `min_` and `max_` inputs are annotated as `Optional`, to work:

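```python
from typing import Optional

# `torch_op`, `TReal`, and `op` are assumed to come from onnxscript's ATen
# function library and opset modules; those imports are omitted here.


@torch_op("aten::clamp")
def aten_clamp(
    self: TReal, min_: Optional[TReal] = None, max_: Optional[TReal] = None
) -> TReal:
    # clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor

    if op.OptionalHasElement(min_):
        min_ = op.OptionalGetElement(min_)
        min_clamp = op.CastLike(min_, self)
    else:
        # TODO: get the correct min number for the dtype
        # CastLike keeps the constant's dtype consistent with `self` for Max/Min
        min_clamp = op.CastLike(op.Constant(value_float=float("-inf")), self)

    if op.OptionalHasElement(max_):
        max_ = op.OptionalGetElement(max_)
        max_clamp = op.CastLike(max_, self)
    else:
        max_clamp = op.CastLike(op.Constant(value_float=float("inf")), self)

    # Enforce the lower bound first and the upper bound last, so the upper
    # bound wins when min_ > max_ (matching torch.clamp)
    clamped = op.Min(op.Max(self, min_clamp), max_clamp)
    return clamped
```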
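For comparison, the element-wise semantics the graph above is meant to reproduce can be sketched in plain Python. This is a minimal sketch, not part of onnxscript: `clamp_reference` is a hypothetical helper, and it assumes a missing bound behaves like an infinite one, mirroring the `-inf`/`+inf` constants in the `else` branches.

```python
import math
from typing import List, Optional, Sequence


def clamp_reference(
    values: Sequence[float],
    min_: Optional[float] = None,
    max_: Optional[float] = None,
) -> List[float]:
    """Hypothetical pure-Python reference for clamp(self, min=None, max=None)."""
    # An omitted bound is treated as unbounded on that side.
    lower = -math.inf if min_ is None else min_
    upper = math.inf if max_ is None else max_
    # Lower bound first, upper bound last: when min_ > max_, every element
    # ends up equal to max_, as documented for torch.clamp.
    return [min(max(v, lower), upper) for v in values]


print(clamp_reference([-2.0, 0.5, 3.0], min_=0.0))            # [0.0, 0.5, 3.0]
print(clamp_reference([-2.0, 0.5, 3.0], max_=1.0))            # [-2.0, 0.5, 1.0]
print(clamp_reference([-2.0, 0.5, 3.0], min_=2.0, max_=1.0))  # [1.0, 1.0, 1.0]
```

Outputs like these can serve as expected values when testing the exported model with one, both, or neither bound supplied.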

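ONNX Runtime rejects the resulting model:

```
E onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : This is an invalid model. In Node, ("", OptionalHasElement, "", -1) : () -> ("output0",) , Error Node () has input size 0 not in range [min=1, max=1].
```

In other words, the graph contains an `OptionalHasElement` node with no inputs (presumably because an omitted `Optional` argument produced no input value), while the schema ONNX Runtime validates it against requires exactly one input.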

Metadata

Labels

enhancement (New feature or request)

Status

Done