## Description
- An ATen op can be dispatched to multiple onnx-script implementations. For example, `aten::scaled_dot_product_attention` maps to both:

  ```python
  def aten_scaled_dot_product_attention(
      query: TFloat,
      key: TFloat,
      value: TFloat,
      attn_mask: Optional[TFloat] = None,
      dropout_p: float = 0.0,
      is_causal: bool = False,
      scale: Optional[float] = None,
  ): ...


  def aten_scaled_dot_product_attention_bool_mask(
      query: TFloat,
      key: TFloat,
      value: TFloat,
      attn_mask: Optional[BOOL] = None,
      dropout_p: float = 0.0,
      is_causal: bool = False,
      scale: Optional[float] = None,
  ): ...
  ```

  The exporter will need to dispatch to one of these implementations based on the dtype of `attn_mask` (see the dispatcher sketch after this list).
- (Easy) Multiple ATen op overloads may map to a single onnx-script implementation. Both `add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor` and `add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor` map to a single

  ```python
  def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: ...
  ```
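A minimal sketch of how this dispatch could be organized, assuming the functions above and a hypothetical registry layout (the registry structure and helper names here are illustrative assumptions, not the exporter's actual API):

```python
from typing import Optional

import torch

# Hypothetical registry: an ATen op maps to its candidate onnx-script
# implementations, and several ATen overloads may share a single entry
# (add.Tensor and add.Scalar both point at aten_add).
_REGISTRY = {
    "aten::scaled_dot_product_attention": [
        aten_scaled_dot_product_attention,            # attn_mask: TFloat
        aten_scaled_dot_product_attention_bool_mask,  # attn_mask: BOOL
    ],
    "aten::add.Tensor": [aten_add],
    "aten::add.Scalar": [aten_add],
}


def dispatch_sdpa(attn_mask: Optional[torch.Tensor]):
    """Pick a scaled_dot_product_attention implementation by mask dtype."""
    float_impl, bool_impl = _REGISTRY["aten::scaled_dot_product_attention"]
    if attn_mask is not None and attn_mask.dtype == torch.bool:
        return bool_impl
    return float_impl
```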
## Type promotion

### PyTorch
- `torch_implicit_integral_promotion`: Sometimes we don't need to `Cast` the output back: `Cast(int to float) -> Cos -> output`. See [ONNX] Add supported ops into test_fx_op_consistency - 1st batch pytorch/pytorch#100265 (comment).
- `torch_implicit_type_promotion`: ONNX does not have automatic type promotion. We can either:
  - run an fx pass to create explicit type promotion nodes before converting to ONNX, or
  - create rules to promote types in the exporter during conversion (a sketch follows this list). The exporter can consult the `type_constraints` in a function's `OpSchema` to determine whether it needs to cast inputs to the same type.
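A rough sketch of the second option: compute the promoted dtype with `torch.promote_types` (a real PyTorch API) and cast any mismatched input, so the implicit promotion becomes explicit `Cast` nodes in the graph. The `graph.insert_cast` helper is a hypothetical stand-in for whatever the exporter uses to emit a `Cast`:

```python
import torch


def promote_inputs(graph, lhs, rhs):
    """Hypothetical exporter rule: cast both inputs of a binary op to
    the dtype PyTorch would implicitly promote to."""
    promoted = torch.promote_types(lhs.dtype, rhs.dtype)
    if lhs.dtype != promoted:
        lhs = graph.insert_cast(lhs, to=promoted)  # hypothetical helper
    if rhs.dtype != promoted:
        rhs = graph.insert_cast(rhs, to=promoted)
    return lhs, rhs
```

For example, `torch.promote_types(torch.int32, torch.float32)` is `torch.float32`, matching what `int_tensor + float_tensor` produces in eager mode.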
### ONNX
- ONNX tends to support fewer types than PyTorch. When an ONNX function is selected as the dispatch target but does not support the input type, the exporter needs to insert Cast nodes before and after the function, e.g. `Cast -> <the torchlib function> -> Cast`, to mirror the ATen op's behavior.
- The exporter may not be able to dispatch to a function with an exact type match, but it can still make a successful dispatch under case (1). For example, we may have two functions supporting (a) `INT64` and (b) `FLOAT32`, and the input is `BFLOAT16`. The dispatcher should be able to choose (b) as the target and use the process in (1) to insert the proper Cast nodes (see the sketch below).
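A sketch of that fallback, assuming a hypothetical nearest-type table and the same hypothetical `graph.insert_cast` helper (the real dispatcher's ranking may differ):

```python
import torch

# Assumed preference order when the input dtype has no exact match:
# e.g. for BFLOAT16, fall back to FLOAT32 first.
_NEAREST = {
    torch.bfloat16: [torch.float32, torch.float16],
    torch.float16: [torch.float32],
}


def dispatch_with_casts(graph, candidates, x):
    """candidates: supported input dtype -> onnx-script function."""
    if x.dtype in candidates:
        return candidates[x.dtype](x)  # exact match, no casts needed
    for dtype in _NEAREST.get(x.dtype, []):
        if dtype in candidates:
            # Cast -> <the torchlib function> -> Cast, as in case (1).
            x_cast = graph.insert_cast(x, to=dtype)   # hypothetical
            y = candidates[dtype](x_cast)
            return graph.insert_cast(y, to=x.dtype)   # cast back
    raise NotImplementedError(f"No dispatch target for {x.dtype}")
```

With `candidates = {torch.int64: fn_a, torch.float32: fn_b}` and a `BFLOAT16` input, this picks `fn_b` and wraps it in the two `Cast` nodes.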