
PyTorch FX exporter dispatching and type promotion scenarios #563

Closed
@justinchuby

Description

  1. An ATen op can be dispatched to multiple onnx-script implementations:

aten::scaled_dot_product_attention ->

def aten_scaled_dot_product_attention(
    query: TFloat,
    key: TFloat,
    value: TFloat,
    attn_mask: Optional[TFloat] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
): ...

def aten_scaled_dot_product_attention_bool_mask(
    query: TFloat,
    key: TFloat,
    value: TFloat,
    attn_mask: Optional[BOOL] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
): ...

The exporter will need to dispatch to one of the implementations based on the dtype of attn_mask.
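A minimal sketch of what such a dispatcher could look like. The function names mirror the two implementations above, but their bodies, the `dispatch_sdpa` helper, and the string-based dtype representation are all illustrative assumptions, not the actual exporter API:

```python
from typing import Callable, Optional


def aten_scaled_dot_product_attention(*args, **kwargs):
    """Stand-in for the float-mask torchlib implementation."""
    return "float_mask_impl"


def aten_scaled_dot_product_attention_bool_mask(*args, **kwargs):
    """Stand-in for the bool-mask torchlib implementation."""
    return "bool_mask_impl"


def dispatch_sdpa(attn_mask_dtype: Optional[str]) -> Callable:
    """Pick the overload whose attn_mask type constraint matches the input dtype."""
    if attn_mask_dtype == "bool":
        return aten_scaled_dot_product_attention_bool_mask
    # Float masks (and no mask at all) go to the default implementation.
    return aten_scaled_dot_product_attention
```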

  2. (Easy) Multiple ATen op overloads may map to a single onnx-script implementation:

Both add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor and add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor map to a single

def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: ...
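The many-to-one case can be expressed as a simple registry where both overload names resolve to the same function. The dict name and shape here are assumptions for illustration, not the real registration mechanism:

```python
def aten_add(self, other, alpha: float = 1.0):
    """Simplified stand-in for the torchlib aten_add implementation."""
    return self + alpha * other


# Both ATen overloads point at the same torchlib function.
ATEN_TO_TORCHLIB = {
    "aten::add.Tensor": aten_add,
    "aten::add.Scalar": aten_add,
}
```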

Type promotion

PyTorch

  • torch_implicit_integral_promotion: Sometimes we don't need to Cast the output back: Cast(int to float) -> Cos -> output. See pytorch/pytorch#100265 (comment).
  • torch_implicit_type_promotion: ONNX does not have automatic type promotion.
    • We can either run an fx pass to create explicit type promotion nodes before converting to ONNX, or
    • Create rules to promote types in the exporter during conversion.
    • The exporter can consult the type_constraints in a function's OpSchema to determine whether it needs to cast inputs to the same type.
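The rule-based option can be sketched as below. The promotion order is a toy subset of PyTorch's promotion lattice, and `cast_to` is a hypothetical placeholder for emitting an ONNX Cast node; neither is the real exporter code:

```python
# Toy subset of PyTorch's type-promotion lattice, lowest to highest rank.
PROMOTION_ORDER = ["bool", "int32", "int64", "float16", "float32", "float64"]


def promote_types(a: str, b: str) -> str:
    """Return the higher-ranked of two dtypes in the simplified lattice."""
    return max(a, b, key=PROMOTION_ORDER.index)


def cast_to(value, dtype: str):
    """Placeholder for inserting an ONNX Cast(to=dtype) node."""
    return (value, dtype)


def promote_inputs(x, x_dtype, y, y_dtype):
    """Cast both inputs to their common promoted dtype, as an fx pass or
    the exporter would, before emitting the ONNX op."""
    target = promote_types(x_dtype, y_dtype)
    promoted = [
        cast_to(v, target) if d != target else (v, d)
        for v, d in ((x, x_dtype), (y, y_dtype))
    ]
    return promoted, target
```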

ONNX

  • ONNX tends to support fewer types than PyTorch. When an ONNX function is selected as the dispatch target but does not support the input type, the exporter needs to insert cast nodes before and after the function, e.g. Cast -> <the torchlib function> -> Cast, to mirror the aten op's behavior.
  • The exporter may not find a function with an exact type match but can still make a successful dispatch under case (1). For example, suppose two functions support (a) INT64 and (b) FLOAT32, and the input is BFLOAT16. The dispatcher should choose (b) as the target and use the process in (1) to insert the proper cast nodes.
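The cast-wrapping idea can be sketched as follows. The helper name is hypothetical, `supported` models a function's OpSchema type constraint, and ONNX nodes are represented as plain tuples purely for illustration:

```python
def wrap_with_casts(fn, input_dtype: str, supported: set):
    """If the input dtype is unsupported, route it through a supported
    surrogate dtype: Cast in -> run the torchlib function -> Cast back."""
    if input_dtype in supported:
        return fn
    # Prefer FLOAT32 as the surrogate (e.g. for a BFLOAT16 input);
    # otherwise fall back to any supported dtype.
    surrogate = "float32" if "float32" in supported else next(iter(supported))

    def wrapped(x):
        x = ("Cast", x, surrogate)       # Cast input up to the surrogate
        y = fn(x)                        # the torchlib function
        return ("Cast", y, input_dtype)  # Cast output back to the original

    return wrapped
```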

cc @titaiwangms @thiagocrepaldi @BowenBao

Metadata

Labels

module: torchlib (Related to the torch/aten function lib in development), topic: discussion (For discussion)
