Add op (linspace) | feat(torchlib) #838
Conversation
reason="op 'Range' doesn't support float16.", | ||
).skip( | ||
matcher=lambda sample: len(sample.args) > 1 and sample.args[1] == 1, | ||
reason="aten::linspace with steps=1 is not supported by its definition.", |
Curious why we see this case in the test?
By the definition of aten::linspace here, steps should not equal 1 because steps - 1 is used as a divisor. But PyTorch does something in its backend to make that case work, so the test includes it.
Unfortunately, I tried a couple of ways to simulate this in the ONNX graph but failed, so this test case is skipped here.
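For context, the Range-based decomposition divides by steps - 1, so steps == 1 is degenerate. A plain-Python sketch just to illustrate the failure mode (not the ONNX implementation):

# Plain-Python sketch of the Range-based decomposition used for linspace.
# steps == 1 makes the divisor zero, which is why the case is excluded here,
# while torch.linspace(start, end, 1) simply returns tensor([start]).
def linspace_sketch(start, end, steps):
    step = (end - start) / (steps - 1)  # ZeroDivisionError when steps == 1
    return [start + i * step for i in range(steps)]

print(linspace_sketch(0.0, 10.0, 5))  # [0.0, 2.5, 5.0, 7.5, 10.0]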
In addition, we probably need to improve the experience of the skip() method.
Per our assumption, skip() should only skip the inputs that match the matcher. But from the results, it looks like all of the inputs were skipped, while some tests still passed.
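To make that assumption concrete, roughly the per-sample semantics we expected (a sketch, not the actual test-infrastructure code; the names are illustrative):

# Illustrative sketch of matcher-based skipping: only samples for which the
# matcher returns True should be skipped; everything else should still run.
def should_skip_sample(sample, matcher) -> bool:
    return matcher(sample)

class FakeSample:  # stand-in for an OpInfo sample, just for this sketch
    def __init__(self, args):
        self.args = args

matcher = lambda sample: len(sample.args) > 1 and sample.args[1] == 1
print(should_skip_sample(FakeSample(args=(10.0, 1)), matcher))   # True  -> skip
print(should_skip_sample(FakeSample(args=(10.0, 50)), matcher))  # False -> run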
from the results, it looks like all of the inputs were skipped, while some tests still passed.
That's interesting. Could you provide more info? Which tests should not have been skipped?
@@ -3507,10 +3507,39 @@ def aten_linear_backward(
     raise NotImplementedError()


-def aten_linspace(start: float, end: float, steps: int) -> TensorType:
+@torch_op("aten::linspace", trace_only=True)
+def aten_linspace(start: TRealUnlessFloat16OrInt8, end: TRealUnlessFloat16OrInt8, steps: int, dtype: int = -1) -> TensorType:
I am now pondering whether we should split the dtype variation out to avoid trace_only and leverage the dispatcher.
Agree, this should be a general solution for other ops as well: let the dispatcher handle the logic of choosing the correct overload.
Sounds good. We can do a follow-up pass to clean this up. @titaiwangms fyi - what are your thoughts?
Not sure I follow. The scenario is:
- Let the dispatcher do the work: two scripted overloads, one with dtype and one without.
- Use trace_only: only one overload, so there is nothing to dispatch.
In this case, the dispatcher doesn't have to do anything.
Not sure I follow. The scenario is:
- Let the dispatcher do the work: two scripted overloads, one with dtype and one without.
- Use trace_only: only one overload, so there is nothing to dispatch.
In this case, the dispatcher doesn't have to do anything.

The scenario should be:
1. torch_lib provides two scripted overloads, one with dtype and one without; no trace_only here.
2. The dispatcher decides which overload to call for export, depending on what it gets from the FX graph (roughly as sketched below).
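A minimal sketch of that dispatch decision (hypothetical; the function name and the way the FX kwargs are inspected are assumptions, not the actual exporter/dispatcher API):

# Hypothetical helper, for illustration only: pick between the two scripted
# overloads based on whether the FX node carries an explicit dtype argument.
def choose_linspace_overload(fx_node_kwargs: dict) -> str:
    if fx_node_kwargs.get("dtype") is not None:
        return "aten_linspace_dtype"  # overload that Casts inputs to the requested dtype
    return "aten_linspace"            # overload that computes in the default dtype

# Example: a node exported from torch.linspace(0, 10, 5, dtype=torch.float64)
print(choose_linspace_overload({"dtype": "float64"}))  # aten_linspace_dtype
print(choose_linspace_overload({}))                    # aten_linspace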
SG! I remember we already have those ops! The dispatcher can do that.
Codecov Report
@@            Coverage Diff             @@
##             main     #838      +/-   ##
==========================================
+ Coverage   76.31%   76.33%   +0.01%
==========================================
  Files         113      113
  Lines       13338    13356      +18
  Branches     1320     1322       +2
==========================================
+ Hits        10179    10195      +16
- Misses       2833     2834       +1
- Partials      326      327       +1
Thanks!
-def aten_linspace(start: float, end: float, steps: int) -> TensorType:
+@torch_op("aten::linspace", trace_only=True)
+def aten_linspace(
+    start: TRealUnlessFloat16OrInt8, end: TRealUnlessFloat16OrInt8, steps: int, dtype: int = -1
start, end, and the return type can be 3 distinct TypeVars having the same constraints as the underlying types of TRealUnlessFloat16OrInt8. @justinchuby what is the current guideline?
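Something along these lines is what I mean (a hypothetical sketch; the typevar names are made up and the constraint set is assumed to mirror TRealUnlessFloat16OrInt8):

# Hypothetical sketch only: three distinct type vars that share one constraint
# set, so start, end, and the result may differ in the generated opschema.
from typing import TypeVar, Union

from onnxscript import BFLOAT16, DOUBLE, FLOAT, INT16, INT32, INT64

# Assumed element types, mirroring the underlying types of TRealUnlessFloat16OrInt8.
_RealUnlessFloat16OrInt8 = Union[BFLOAT16, DOUBLE, FLOAT, INT16, INT32, INT64]

TLinspaceStart = TypeVar("TLinspaceStart", bound=_RealUnlessFloat16OrInt8)
TLinspaceEnd = TypeVar("TLinspaceEnd", bound=_RealUnlessFloat16OrInt8)
TLinspaceResult = TypeVar("TLinspaceResult", bound=_RealUnlessFloat16OrInt8)

# The signature could then be annotated as
#   def aten_linspace(start: TLinspaceStart, end: TLinspaceEnd, steps: int, dtype: int = -1) -> TLinspaceResult:
# so the three are no longer forced to the same type.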
Do you mean they can take different types?
With the help of a prototype type constraint deducer, I observe the following. Note that I created 3 variants to fold away the if condition, since the deducer does not work with trace-only functions.
"""
NOTE: I'm not sure about this one, the deducer takes INT64 as type for `steps`.
T1: {'tensor(int8)', 'tensor(double)', 'tensor(float)', 'tensor(float16)', 'tensor(uint16)', 'tensor(uint64)', 'tensor(int16)', 'tensor(string)', 'tensor(uint32)', 'tensor(bfloat16)', 'tensor(bool)', 'tensor(uint8)', 'tensor(int32)', 'tensor(int64)'}
T2: {'tensor(int64)'}
start: T1
end: T2
return_val: T2
"""
@torch_op("aten::linspace")
def aten_linspace_variance_1(
start: TensorType, end: TensorType, steps: int
) -> TensorType:
zero = op.CastLike(0.0, steps)
one = op.CastLike(1.0, steps)
range_tensor = op.Range(zero, steps, one)
start = op.CastLike(start, end)
step = op.Div(
op.Sub(end, start),
op.Sub(steps, one),
)
return op.Add(op.Mul(range_tensor, step), start)
"""
T1: {'tensor(uint16)', 'tensor(int16)', 'tensor(float16)', 'tensor(float)', 'tensor(uint8)', 'tensor(bool)', 'tensor(bfloat16)', 'tensor(string)', 'tensor(uint32)', 'tensor(int64)', 'tensor(int32)', 'tensor(uint64)', 'tensor(int8)', 'tensor(double)'}
T2: {'tensor(uint16)', 'tensor(int16)', 'tensor(float16)', 'tensor(float)', 'tensor(uint8)', 'tensor(bool)', 'tensor(bfloat16)', 'tensor(string)', 'tensor(uint32)', 'tensor(int64)', 'tensor(int32)', 'tensor(uint64)', 'tensor(int8)', 'tensor(double)'}
T3: {'tensor(int16)', 'tensor(float)', 'tensor(int64)', 'tensor(int32)', 'tensor(double)'}
start: T1
end: T2
return_val: T3
"""
@torch_op("aten::linspace")
def aten_linspace_variance_2(
start: TensorType, end: TensorType, steps: int, dtype: int
) -> TensorType:
zero = op.Cast(0, to=dtype)
one = op.Cast(1, to=dtype)
start = op.Cast(start, to=dtype)
end = op.Cast(end, to=dtype)
steps = op.Cast(steps, to=dtype)
range_tensor = op.Range(zero, steps, one)
start = op.CastLike(start, end)
step = op.Div(
op.Sub(end, start),
op.Sub(steps, one),
)
return op.Add(op.Mul(range_tensor, step), start)
from onnxscript import FLOAT
"""
T1: {'tensor(int16)', 'tensor(uint64)', 'tensor(string)', 'tensor(int32)', 'tensor(uint16)', 'tensor(uint32)', 'tensor(double)', 'tensor(bfloat16)', 'tensor(float16)', 'tensor(int8)', 'tensor(uint8)', 'tensor(bool)', 'tensor(int64)', 'tensor(float)'}
T2: {'tensor(int16)', 'tensor(uint64)', 'tensor(string)', 'tensor(int32)', 'tensor(uint16)', 'tensor(uint32)', 'tensor(double)', 'tensor(bfloat16)', 'tensor(float16)', 'tensor(int8)', 'tensor(uint8)', 'tensor(bool)', 'tensor(int64)', 'tensor(float)'}
T3: {'tensor(float)'}
start: T1
end: T2
return_val: T3
"""
@torch_op("aten::linspace")
def aten_linspace_variance_3(
start: TensorType, end: TensorType, steps: int,
) -> TensorType:
zero = op.Cast(0.0, to=FLOAT.dtype)
one = op.Cast(1.0, to=FLOAT.dtype)
start = op.Cast(start, to=FLOAT.dtype)
end = op.Cast(end, to=FLOAT.dtype)
steps = op.Cast(steps, to=FLOAT.dtype)
range_tensor = op.Range(zero, steps, one)
start = op.CastLike(start, end)
step = op.Div(
op.Sub(end, start),
op.Sub(steps, one),
)
return op.Add(op.Mul(range_tensor, step), start)
I see. So we can support different input types, in which case we just create three type vars with the same constraint. However, we can also decide that it’s not practical and they just need to be the same typevar as that’s what torch gives us
it’s not practical and they just need to be the same typevar as that’s what torch gives us

I think differently typed start and end are legal in torch. The return type can also be completely different from both, due to dtype. So we should use 3 type vars to generate an accurate opschema.
Also, assuming the deducer is correct, the 1st variant restricts end and the return value to INT64 only, hence we might want to insert steps = op.CastLike(steps, end) at the beginning.
"""
T1: {'tensor(int32)', 'tensor(uint64)', 'tensor(uint16)', 'tensor(int8)', 'tensor(bfloat16)', 'tensor(string)', 'tensor(int64)', 'tensor(float16)', 'tensor(uint8)', 'tensor(bool)', 'tensor(double)', 'tensor(uint32)', 'tensor(int16)', 'tensor(float)'}
T2: {'tensor(int32)', 'tensor(int64)', 'tensor(double)', 'tensor(int16)', 'tensor(float)'}
start: T1
end: T2
return_val: T2
"""
@torch_op("aten::linspace")
def aten_linspace_variance_1(
start: TensorType, end: TensorType, steps: int
) -> TensorType:
steps = op.CastLike(steps, end)
zero = op.CastLike(0.0, steps)
one = op.CastLike(1.0, steps)
range_tensor = op.Range(zero, steps, one)
start = op.CastLike(start, end)
step = op.Div(
op.Sub(end, start),
op.Sub(steps, one),
)
return op.Add(op.Mul(range_tensor, step), start)
Yet the behavior doesn't match the torch doc anyway:
dtype ([torch.dtype](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype), optional) – the data type to perform the computation in. Default: if None, uses the global default dtype (see torch.get_default_dtype()) when both start and end are real, and corresponding complex dtype when either is complex.
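For instance (a quick illustration of that default using public torch APIs; outputs noted in comments):

import torch

# With dtype=None, the computation uses the global default dtype (float32 unless
# changed via torch.set_default_dtype), even when start and end are integers.
print(torch.linspace(0, 10, steps=5).dtype)                       # torch.float32
print(torch.linspace(0, 10, steps=5, dtype=torch.float64).dtype)  # torch.float64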
So my comment for this is non-blocking.
When I debugged the tests of this op, I did see some test data where start is FLOAT and end is INT, so they can have different types in real scenarios.
If we adopt a new solution for such type differences, we will need to go through all of the ops that use the 'Range' op.
Created #841
Add linspace op into torch_lib functions.
Add tests as well.