Commit 4630347
[ET-VK] Add per-operator dtype constraints to op_registry
Previously, dtype validation for Vulkan operators was scattered across
individual node-checking functions (e.g., `check_to_copy_node` had inline
float16/float32 checks). This made it difficult to see at a glance which
dtypes each operator supports.
This diff introduces a centralized dtype constraint system:
- `utils.py`: Adds dtype set constants (`FP_T`, `INT_T`, `FP_INT32_T`,
`FP_INT32_BOOL_T`, etc.) and a `DtypeSetList` wrapper class with
broadcasting semantics. The `check_node_dtypes()` function validates
tensor inputs/outputs against allowed dtype sets and returns descriptive
error messages.
- `op_registry.py`: Extends `OpFeatures` with `inputs_dtypes` and
`outputs_dtypes` parameters. Each operator registration now explicitly
declares its supported dtypes. Simplified node-checking functions like
`check_to_copy_node` since dtype logic is now handled declaratively.
- `vulkan_partitioner.py`: Calls `features.check_dtypes()` early in
validation, failing fast with a clear skip reason if dtypes are invalid.
This approach improves maintainability by making dtype support explicit and
centralizing the validation logic.
Authored with assistance from Claude.
Differential Revision: [D92740295](https://our.internmc.facebook.com/intern/diff/D92740295/)
ghstack-source-id: 339885885
Pull Request resolved: #173361 parent bc3b703 commit 4630347
File tree
3 files changed
+197
-39
lines changed- backends/vulkan
- partitioner
3 files changed
+197
-39
lines changed
0 commit comments