
Add op (linspace) | feat(torchlib) #838


Merged
merged 4 commits into microsoft:main on Jul 8, 2023

Conversation

Contributor

@fatcat-z fatcat-z commented Jul 7, 2023

Add linspace op into torch_lib functions.
Add tests as well.

reason="op 'Range' doesn't support float16.",
).skip(
matcher=lambda sample: len(sample.args) > 1 and sample.args[1] == 1,
reason="aten::linspace with steps=1 is not supported by its definition.",
Collaborator

Curious why we see this case in the test?

Contributor Author

By the definition of aten::linspace here, steps should not equal 1 because steps - 1 is used as a divisor. But PyTorch does something in its backend to make it work, so the test includes this case.

Unfortunately, I tried a couple of solutions to simulate this in the ONNX graph but failed, so this case is skipped here.
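
For context, a minimal plain-Python sketch (illustrative only, not the exported ONNX graph) of why steps == 1 is special: the step size divides by steps - 1, so that case has to be handled separately, which PyTorch does internally.

def linspace_naive(start, end, steps):
    # Illustrative only: mirrors the formula behind aten::linspace.
    if steps == 1:
        # (end - start) / (steps - 1) would divide by zero here;
        # PyTorch special-cases this and returns just [start].
        return [start]
    step = (end - start) / (steps - 1)
    return [start + i * step for i in range(steps)]

print(linspace_naive(0.0, 10.0, 5))  # [0.0, 2.5, 5.0, 7.5, 10.0]
print(linspace_naive(0.0, 10.0, 1))  # [0.0]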

Contributor Author

In addition, we probably need to improve the experience of the skip() method.

Per our assumptions, skip() should only skip the inputs that match the matcher. But from the results, it looks like all of the inputs were skipped, even though some tests still passed.
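
For illustration, a minimal sketch (plain Python, not the actual ops test framework) of the expected skip() semantics, reusing the matcher from the diff above: only the samples the matcher identifies should be excluded.

# Mirrors the matcher from the diff above: sample.args[1] is steps.
def matches_skip(sample_args):
    return len(sample_args) > 1 and sample_args[1] == 1

samples = [(10.0, 5), (10.0, 1), (10.0,)]
kept = [s for s in samples if not matches_skip(s)]
print(kept)  # [(10.0, 5), (10.0,)] -- only the steps == 1 sample is excluded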

Collaborator

from the results, it looks like all of the inputs were skipped, even though some tests still passed.

That's interesting. Could you provide more info? Which tests should not have been skipped?

@@ -3507,10 +3507,39 @@ def aten_linear_backward(
raise NotImplementedError()


def aten_linspace(start: float, end: float, steps: int) -> TensorType:
@torch_op("aten::linspace", trace_only=True)
def aten_linspace(start: TRealUnlessFloat16OrInt8, end: TRealUnlessFloat16OrInt8, steps: int, dtype: int = -1) -> TensorType:
Collaborator

I am now pondering whether we should split the dtype variation out to avoid trace_only and leverage the dispatcher.

Contributor Author

@fatcat-z fatcat-z Jul 7, 2023

Agreed. This should be a general solution for other ops as well: let the dispatcher handle the logic of choosing the correct overload.

Collaborator

Sounds good. We can do a follow-up pass to clean this up. @titaiwangms FYI - what are your thoughts?

Contributor

Not sure I follow.

The scenario is:

  1. Let the dispatcher do the work: two scripted overloads, one with dtype and one without.
  2. Use trace_only: only one overload, so there is nothing to dispatch.

In this case, the dispatcher doesn't have to do anything.

Contributor Author

Not sure I follow.

The scenario is:

  1. Let the dispatcher do the work: two scripted overloads, one with dtype and one without.
  2. Use trace_only: only one overload, so there is nothing to dispatch.

In this case, the dispatcher doesn't have to do anything.

The scenario should be:

  1. torch_lib provides 2 scripted overloads, one with dtype and the other without. No trace_only here.
  2. The dispatcher decides which one to call for export, depending on what it got from the FX graph, as sketched below.
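
For illustration, a hedged sketch of the decision in step 2 (the helper name and keyword handling are hypothetical, not the actual exporter dispatcher API):

# Hypothetical: choose between the two scripted torchlib overloads based on
# whether the FX node carries an explicit dtype keyword argument.
def choose_linspace_overload(fx_node_kwargs):
    if fx_node_kwargs.get("dtype") is not None:
        return "aten_linspace_dtype"  # overload that casts to the requested dtype
    return "aten_linspace"  # default overload without a dtype parameter

print(choose_linspace_overload({}))            # aten_linspace
print(choose_linspace_overload({"dtype": 7}))  # aten_linspace_dtype (7 == ONNX INT64)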

Contributor

SG! I remember we already have those ops! The dispatcher can do that.

@fatcat-z fatcat-z mentioned this pull request Jul 7, 2023

codecov bot commented Jul 7, 2023

Codecov Report

Merging #838 (0eb740f) into main (4af3495) will increase coverage by 0.01%.
The diff coverage is 85.00%.

@@            Coverage Diff             @@
##             main     #838      +/-   ##
==========================================
+ Coverage   76.31%   76.33%   +0.01%     
==========================================
  Files         113      113              
  Lines       13338    13356      +18     
  Branches     1320     1322       +2     
==========================================
+ Hits        10179    10195      +16     
- Misses       2833     2834       +1     
- Partials      326      327       +1     
Impacted Files                                           Coverage Δ
...ipt/tests/function_libs/torch_lib/ops_test_data.py    96.72% <ø> (ø)
onnxscript/function_libs/torch_lib/ops/core.py           76.72% <85.00%> (+0.08%) ⬆️

Collaborator

@justinchuby justinchuby left a comment

Thanks!

def aten_linspace(start: float, end: float, steps: int) -> TensorType:
@torch_op("aten::linspace", trace_only=True)
def aten_linspace(
start: TRealUnlessFloat16OrInt8, end: TRealUnlessFloat16OrInt8, steps: int, dtype: int = -1
Contributor

start, end, and the return type can be 3 distinct TypeVars having the same constraints as the underlying types of TRealUnlessFloat16OrInt8. @justinchuby, what is the current guideline?
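
For illustration, a minimal sketch of three independently bound type vars sharing one constraint set. This assumes the tensor types are importable from onnxscript at the top level (as FLOAT is in the snippet further below), and the constraint list is illustrative rather than the exact members of TRealUnlessFloat16OrInt8.

from typing import TypeVar, Union

from onnxscript import DOUBLE, FLOAT, INT16, INT32, INT64

# Three type vars with the same bound, so start, end, and the return value
# may each bind to a different concrete tensor type in the generated opschema.
_LinspaceElem = Union[DOUBLE, FLOAT, INT16, INT32, INT64]
TStart = TypeVar("TStart", bound=_LinspaceElem)
TEnd = TypeVar("TEnd", bound=_LinspaceElem)
TReturn = TypeVar("TReturn", bound=_LinspaceElem)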

Collaborator

Do you mean they can take different types?

Contributor

With the help of a prototype type constraint deducer, I observe the following. Note that I created 3 variants to fold away the if condition, since the deducer does not work with trace-only functions.

"""
NOTE: I'm not sure about this one, the deducer takes INT64 as type for `steps`.
T1: {'tensor(int8)', 'tensor(double)', 'tensor(float)', 'tensor(float16)', 'tensor(uint16)', 'tensor(uint64)', 'tensor(int16)', 'tensor(string)', 'tensor(uint32)', 'tensor(bfloat16)', 'tensor(bool)', 'tensor(uint8)', 'tensor(int32)', 'tensor(int64)'}
T2: {'tensor(int64)'}
start: T1
end: T2
return_val: T2
"""
@torch_op("aten::linspace")
def aten_linspace_variance_1(
    start: TensorType, end: TensorType, steps: int
) -> TensorType:
    zero = op.CastLike(0.0, steps)
    one = op.CastLike(1.0, steps)

    range_tensor = op.Range(zero, steps, one)

    start = op.CastLike(start, end)
    step = op.Div(
        op.Sub(end, start),
        op.Sub(steps, one),
    )

    return op.Add(op.Mul(range_tensor, step), start)


"""
T1: {'tensor(uint16)', 'tensor(int16)', 'tensor(float16)', 'tensor(float)', 'tensor(uint8)', 'tensor(bool)', 'tensor(bfloat16)', 'tensor(string)', 'tensor(uint32)', 'tensor(int64)', 'tensor(int32)', 'tensor(uint64)', 'tensor(int8)', 'tensor(double)'}
T2: {'tensor(uint16)', 'tensor(int16)', 'tensor(float16)', 'tensor(float)', 'tensor(uint8)', 'tensor(bool)', 'tensor(bfloat16)', 'tensor(string)', 'tensor(uint32)', 'tensor(int64)', 'tensor(int32)', 'tensor(uint64)', 'tensor(int8)', 'tensor(double)'}
T3: {'tensor(int16)', 'tensor(float)', 'tensor(int64)', 'tensor(int32)', 'tensor(double)'}
start: T1
end: T2
return_val: T3
"""
@torch_op("aten::linspace")
def aten_linspace_variance_2(
    start: TensorType, end: TensorType, steps: int, dtype: int
) -> TensorType:
    zero = op.Cast(0, to=dtype)
    one = op.Cast(1, to=dtype)
    start = op.Cast(start, to=dtype)
    end = op.Cast(end, to=dtype)
    steps = op.Cast(steps, to=dtype)

    range_tensor = op.Range(zero, steps, one)

    start = op.CastLike(start, end)
    step = op.Div(
        op.Sub(end, start),
        op.Sub(steps, one),
    )

    return op.Add(op.Mul(range_tensor, step), start)

from onnxscript import FLOAT


"""
T1: {'tensor(int16)', 'tensor(uint64)', 'tensor(string)', 'tensor(int32)', 'tensor(uint16)', 'tensor(uint32)', 'tensor(double)', 'tensor(bfloat16)', 'tensor(float16)', 'tensor(int8)', 'tensor(uint8)', 'tensor(bool)', 'tensor(int64)', 'tensor(float)'}
T2: {'tensor(int16)', 'tensor(uint64)', 'tensor(string)', 'tensor(int32)', 'tensor(uint16)', 'tensor(uint32)', 'tensor(double)', 'tensor(bfloat16)', 'tensor(float16)', 'tensor(int8)', 'tensor(uint8)', 'tensor(bool)', 'tensor(int64)', 'tensor(float)'}
T3: {'tensor(float)'}
start: T1
end: T2
return_val: T3
"""
@torch_op("aten::linspace")
def aten_linspace_variance_3(
    start: TensorType, end: TensorType, steps: int,
) -> TensorType:
    zero = op.Cast(0.0, to=FLOAT.dtype)
    one = op.Cast(1.0, to=FLOAT.dtype)
    start = op.Cast(start, to=FLOAT.dtype)
    end = op.Cast(end, to=FLOAT.dtype)
    steps = op.Cast(steps, to=FLOAT.dtype)

    range_tensor = op.Range(zero, steps, one)

    start = op.CastLike(start, end)
    step = op.Div(
        op.Sub(end, start),
        op.Sub(steps, one),
    )

    return op.Add(op.Mul(range_tensor, step), start)

Collaborator

I see. So we can support different input types, in which case we just create three type vars with the same constraint. However, we can also decide that it’s not practical and they just need to be the same typevar, as that’s what torch gives us.

Contributor

it’s not practical and they just need to be the same typevar as that’s what torch gives us

I think differently typed start and end are legal in torch. The return type can also be completely different from them due to dtype. So we should use 3 type vars to generate an accurate opschema.

Also, assuming the deducer is correct, the 1st variant restricts end and the return value to INT64 only, hence we might want to insert steps = op.CastLike(steps, end) at the beginning.

"""
T1: {'tensor(int32)', 'tensor(uint64)', 'tensor(uint16)', 'tensor(int8)', 'tensor(bfloat16)', 'tensor(string)', 'tensor(int64)', 'tensor(float16)', 'tensor(uint8)', 'tensor(bool)', 'tensor(double)', 'tensor(uint32)', 'tensor(int16)', 'tensor(float)'}
T2: {'tensor(int32)', 'tensor(int64)', 'tensor(double)', 'tensor(int16)', 'tensor(float)'}
start: T1
end: T2
return_val: T2
"""
@torch_op("aten::linspace")
def aten_linspace_variance_1(
    start: TensorType, end: TensorType, steps: int
) -> TensorType:
    steps = op.CastLike(steps, end)
    zero = op.CastLike(0.0, steps)
    one = op.CastLike(1.0, steps)

    range_tensor = op.Range(zero, steps, one)

    start = op.CastLike(start, end)
    step = op.Div(
        op.Sub(end, start),
        op.Sub(steps, one),
    )

    return op.Add(op.Mul(range_tensor, step), start)

Yet the behavior doesn't match the torch docs anyway:
dtype ([torch.dtype](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype), optional) – the data type to perform the computation in. Default: if None, uses the global default dtype (see torch.get_default_dtype()) when both start and end are real, and corresponding complex dtype when either is complex.

So my comment for this is non-blocking.

Contributor Author

When I debugged the tests of this op, I did see some test data where start is FLOAT and end is INT, so they can have different types in a real scenario.

If we adopt a new solution for such type differences, we will need to go through all of the ops that use the 'Range' op.

Collaborator

Created #841

@justinchuby justinchuby merged commit bcee0ec into microsoft:main Jul 8, 2023