Skip to content

Commit 4630347

Browse files
ssjiaSS-JIA
authored andcommitted
[ET-VK] Add per-operator dtype constraints to op_registry
Previously, dtype validation for Vulkan operators was scattered across individual node-checking functions (e.g., `check_to_copy_node` had inline float16/float32 checks). This made it difficult to see at a glance which dtypes each operator supports. This diff introduces a centralized dtype constraint system: - `utils.py`: Adds dtype set constants (`FP_T`, `INT_T`, `FP_INT32_T`, `FP_INT32_BOOL_T`, etc.) and a `DtypeSetList` wrapper class with broadcasting semantics. The `check_node_dtypes()` function validates tensor inputs/outputs against allowed dtype sets and returns descriptive error messages. - `op_registry.py`: Extends `OpFeatures` with `inputs_dtypes` and `outputs_dtypes` parameters. Each operator registration now explicitly declares its supported dtypes. Simplified node-checking functions like `check_to_copy_node` since dtype logic is now handled declaratively. - `vulkan_partitioner.py`: Calls `features.check_dtypes()` early in validation, failing fast with a clear skip reason if dtypes are invalid. This approach improves maintainability by making dtype support explicit and centralizing the validation logic. Authored with assistance from Claude. Differential Revision: [D92740295](https://our.internmc.facebook.com/intern/diff/D92740295/) ghstack-source-id: 339885885 Pull Request resolved: #17336
1 parent bc3b703 commit 4630347

File tree

3 files changed

+197
-39
lines changed

3 files changed

+197
-39
lines changed

0 commit comments

Comments
 (0)