Skip to content

Commit 179f62b

Browse files
ahmtoxfacebook-github-bot
authored andcommitted
Revert vulkan changes from D76646172 fixup patch
Summary: # Context Need these changes that were reverted in the weekend. Original stack of commits were unable to be merged into main due to an existing lintrunner issue blocking the merge. All the changes already went through [review](#11479) and approved. Differential Revision: D76737404
1 parent 56392aa commit 179f62b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+3169
-152
lines changed

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 103 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -56,52 +56,97 @@
5656
TYPE_MAPPINGS: Dict[str, Any] = {
5757
"IMAGE_T": {
5858
3: {
59+
"double": "image3D",
5960
"float": "image3D",
6061
"half": "image3D",
61-
"int": "iimage3D",
62-
"uint": "uimage3D",
62+
# integer dtypes
6363
"int8": "iimage3D",
6464
"uint8": "uimage3D",
65+
"int16": "iimage3D",
66+
"uint16": "uimage3D",
67+
"int32": "iimage3D",
68+
"uint32": "uimage3D",
69+
"int64": "iimage3D",
70+
"uint64": "uimage3D",
71+
# common dtype aliases
6572
"bool": "uimage3D",
73+
"int": "iimage3D",
74+
"uint": "uimage3D",
6675
},
6776
2: {
77+
"double": "image2D",
6878
"float": "image2D",
6979
"half": "image2D",
70-
"int": "iimage2D",
71-
"uint": "uimage2D",
80+
# integer dtypes
7281
"int8": "iimage2D",
7382
"uint8": "uimage2D",
83+
"int16": "iimage2D",
84+
"uint16": "uimage2D",
85+
"int32": "iimage2D",
86+
"uint32": "uimage2D",
87+
"int64": "iimage2D",
88+
"uint64": "uimage2D",
89+
# common dtype aliases
7490
"bool": "uimage2D",
91+
"int": "iimage2D",
92+
"uint": "uimage2D",
7593
},
7694
},
7795
"SAMPLER_T": {
7896
3: {
97+
"double": "sampler3D",
7998
"float": "sampler3D",
8099
"half": "sampler3D",
81-
"int": "isampler3D",
82-
"uint": "usampler3D",
100+
# integer dtypes
83101
"int8": "isampler3D",
84102
"uint8": "usampler3D",
103+
"int16": "isampler3D",
104+
"uint16": "usampler3D",
105+
"int32": "isampler3D",
106+
"uint32": "usampler3D",
107+
"int64": "isampler3D",
108+
"uint64": "usampler3D",
109+
# common dtype aliases
85110
"bool": "usampler3D",
111+
"int": "isampler3D",
112+
"uint": "usampler3D",
86113
},
87114
2: {
115+
"double": "sampler2D",
88116
"float": "sampler2D",
89117
"half": "sampler2D",
90-
"int": "isampler2D",
91-
"uint": "usampler2D",
118+
# integer dtypes
92119
"int8": "isampler2D",
93120
"uint8": "usampler2D",
121+
"int16": "isampler2D",
122+
"uint16": "usampler2D",
123+
"int32": "isampler2D",
124+
"uint32": "usampler2D",
125+
"int64": "isampler2D",
126+
"uint64": "usampler2D",
127+
# common dtype aliases
94128
"bool": "usampler2D",
129+
"int": "isampler2D",
130+
"uint": "usampler2D",
95131
},
96132
},
97133
"IMAGE_FORMAT": {
134+
"double": "rgba32f",
98135
"float": "rgba32f",
99136
"half": "rgba16f",
100-
"int": "rgba32i",
101-
"uint": "rgba32ui",
137+
# integer dtypes
102138
"int8": "rgba8i",
103139
"uint8": "rgba8ui",
140+
"int16": "rgba16i",
141+
"uint16": "rgba16ui",
142+
"int32": "rgba32i",
143+
"uint32": "rgba32ui",
144+
"int64": "rgba32i",
145+
"uint64": "rgba32ui",
146+
# common dtype aliases
104147
"bool": "rgba8ui",
148+
"int": "rgba32i",
149+
"uint": "rgba32ui",
105150
},
106151
}
107152

@@ -118,33 +163,47 @@ def define_variable(name: str) -> str:
118163
def buffer_scalar_type(dtype: str) -> str:
119164
if dtype == "half":
120165
return "float16_t"
121-
elif dtype[-1] == "8":
122-
return dtype + "_t"
166+
elif dtype == "float":
167+
return "float"
168+
elif dtype == "double":
169+
return "float64_t"
170+
# integer dtype alias conversion
123171
elif dtype == "bool":
124172
return "uint8_t"
173+
# we don't want to append _t for int32 or uint32 as int is already 32bit
174+
elif dtype == "int32" or dtype == "uint32":
175+
return "int" if dtype == "int32" else "uint"
176+
elif dtype[-1].isdigit():
177+
return dtype + "_t"
125178
return dtype
126179

127180

128181
def buffer_gvec_type(dtype: str, n: int) -> str:
129182
if n == 1:
130183
return buffer_scalar_type(dtype)
131184

132-
if dtype == "float":
133-
return f"vec{n}"
134-
if dtype == "uint":
135-
return f"uvec{n}"
136-
elif dtype == "half":
137-
return f"f16vec{n}"
138-
elif dtype == "int":
139-
return f"ivec{n}"
140-
elif dtype == "int8":
141-
return f"i8vec{n}"
142-
elif dtype == "uint8":
143-
return f"u8vec{n}"
144-
elif dtype == "bool":
145-
return f"u8vec{n}"
146-
147-
raise AssertionError(f"Invalid dtype: {dtype}")
185+
dtype_map = {
186+
"half": f"f16vec{n}",
187+
"float": f"vec{n}",
188+
"double": f"vec{n}", # No 64bit image format support in GLSL
189+
"int8": f"i8vec{n}",
190+
"uint8": f"u8vec{n}",
191+
"int16": f"i16vec{n}",
192+
"uint16": f"u16vec{n}",
193+
"int32": f"ivec{n}",
194+
"int": f"ivec{n}",
195+
"uint32": f"uvec{n}",
196+
"uint": f"uvec{n}",
197+
"int64": f"ivec{n}", # No 64bit image format support in GLSL
198+
"uint64": f"uvec{n}", # No 64bit image format support in GLSL
199+
"bool": f"u8vec{n}",
200+
}
201+
202+
vector_type = dtype_map.get(dtype)
203+
if vector_type is None:
204+
raise AssertionError(f"Invalid dtype: {dtype}")
205+
206+
return vector_type
148207

149208

150209
def texel_type(dtype: str) -> str:
@@ -365,15 +424,22 @@ def define_required_extensions(dtypes: Union[str, List[str]]):
365424
if dtype == "half":
366425
nbit = "16bit"
367426
glsl_type = "float16"
368-
elif dtype == "int16" or dtype == "uint16":
369-
nbit = "16bit"
370-
glsl_type = "int16"
371-
elif dtype == "int8" or dtype == "uint8" or dtype == "bool":
427+
elif dtype == "double":
428+
# We only need to allow float64_t type usage
429+
glsl_type = "float64"
430+
elif dtype in ["int8", "uint8", "bool"]:
372431
nbit = "8bit"
373432
glsl_type = "int8"
433+
elif dtype in ["int16", "uint16"]:
434+
nbit = "16bit"
435+
glsl_type = "int16"
436+
elif dtype in ["int64", "uint64"]:
437+
# We only need to allow int64_t and uint64_t type usage
438+
glsl_type = "int64"
374439

375-
if nbit is not None and glsl_type is not None:
440+
if nbit is not None:
376441
out_str += f"#extension GL_EXT_shader_{nbit}_storage : require\n"
442+
if glsl_type is not None:
377443
out_str += f"#extension GL_EXT_shader_explicit_arithmetic_types_{glsl_type} : require\n"
378444

379445
return out_str
@@ -629,6 +695,10 @@ def generateVariantCombinations(
629695

630696
elif "VALUE" in value:
631697
suffix = value.get("SUFFIX", value["VALUE"])
698+
if value["VALUE"] in ["int", "uint"]:
699+
raise ValueError(
700+
f"Use int32 or uint32 instead of {value['VALUE']}"
701+
)
632702
param_values.append((param_name, suffix, value["VALUE"]))
633703

634704
else:

backends/vulkan/runtime/graph/ops/glsl/arange.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
arange:
88
parameter_names_with_default_values:
99
NDIM: 3
10-
DTYPE: int
10+
DTYPE: int32
1111
STORAGE: texture3d
1212
PACKING: C_packed
1313
generate_variant_forall:
1414
DTYPE:
1515
- VALUE: half
1616
- VALUE: float
17-
- VALUE: int
17+
- VALUE: int32
1818
shader_variants:
1919
- NAME: arange

backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@ avg_pool2d:
1313
DTYPE:
1414
- VALUE: half
1515
- VALUE: float
16-
- VALUE: int
16+
- VALUE: int32
1717
shader_variants:
1818
- NAME: avg_pool2d

backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ binary_op:
1717
DTYPE:
1818
- VALUE: half
1919
- VALUE: float
20-
- VALUE: int
20+
- VALUE: int32
2121
shader_variants:
2222
- NAME: binary_add
2323
- NAME: binary_sub

backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ buffer_to_buffer:
1212
DTYPE:
1313
- VALUE: half
1414
- VALUE: float
15-
- VALUE: int
15+
- VALUE: double
1616
- VALUE: int8
1717
- VALUE: uint8
18+
- VALUE: int32
1819
shader_variants:
1920
- NAME: buffer_to_buffer

backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@ buffer_to_nchw:
1313
DTYPE:
1414
- VALUE: half
1515
- VALUE: float
16-
- VALUE: int
16+
- VALUE: double
1717
- VALUE: int8
1818
- VALUE: uint8
19+
- VALUE: int32
1920
shader_variants:
2021
- NAME: buffer_to_nchw
2122
- NAME: buffer_to_nchw_no_pc

backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ copy_channel_offset:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
shader_variants:
1212
- NAME: copy_channel_offset

backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ copy_offset:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
- VALUE: int8
1212
- VALUE: uint8
1313
STORAGE:

backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ copy_packed_dim_offset:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
shader_variants:
1212
- NAME: copy_packed_dim_offset

backends/vulkan/runtime/graph/ops/glsl/embedding.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ embedding:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
shader_variants:
1212
- NAME: embedding

backends/vulkan/runtime/graph/ops/glsl/flip.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@ flip:
66
DTYPE:
77
- VALUE: half
88
- VALUE: float
9-
- VALUE: int
9+
- VALUE: double
1010
- VALUE: int8
1111
- VALUE: uint8
12+
- VALUE: int32
1213
shader_variants:
1314
- NAME: flip

backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@ image_to_nchw:
1414
DTYPE:
1515
- VALUE: half
1616
- VALUE: float
17-
- VALUE: int
17+
- VALUE: double
1818
- VALUE: int8
1919
- VALUE: uint8
20+
- VALUE: int32
2021
shader_variants:
2122
- NAME: image_to_nchw_texture3d
2223
- NAME: image_to_nchw_texture2d

backends/vulkan/runtime/graph/ops/glsl/index_select.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ index_select:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
shader_variants:
1212
- NAME: index_select

backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ index_select_channel:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
shader_variants:
1212
- NAME: index_select_channel

backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@ nchw_to_buffer:
1313
DTYPE:
1414
- VALUE: half
1515
- VALUE: float
16-
- VALUE: int
16+
- VALUE: double
1717
- VALUE: int8
1818
- VALUE: uint8
19+
- VALUE: int32
1920
shader_variants:
2021
- NAME: nchw_to_buffer
2122
- NAME: nchw_to_buffer_no_pc

backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,5 +87,9 @@ void main() {
8787
return;
8888
}
8989

90-
write_texel(t_out, lpos_to_pos(lpos, axis_map), read_texel(tidx));
90+
$if DTYPE == "double" and DTYPE == "int64":
91+
VEC4_T texel = read_texel(tidx);
92+
write_texel(t_out, lpos_to_pos(lpos, axis_map), texel);
93+
$else:
94+
write_texel(t_out, lpos_to_pos(lpos, axis_map), read_texel(tidx));
9195
}

backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@ nchw_to_image:
1414
DTYPE:
1515
- VALUE: half
1616
- VALUE: float
17-
- VALUE: int
17+
- VALUE: double
1818
- VALUE: int8
1919
- VALUE: uint8
20+
- VALUE: int32
2021
shader_variants:
2122
- NAME: nchw_to_image_texture3d
2223
- NAME: nchw_to_image_texture2d

backends/vulkan/runtime/graph/ops/glsl/no_op.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ no_op:
1212
DTYPE:
1313
- VALUE: half
1414
- VALUE: float
15-
- VALUE: int
15+
- VALUE: int32
1616
- VALUE: int8
1717
- VALUE: uint8
1818
STORAGE:

backends/vulkan/runtime/graph/ops/glsl/permute.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ permute:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
shader_variants:
1212
- NAME: permute

0 commit comments

Comments
 (0)