
Commit 251aa74

yipjustin authored and facebook-github-bot committed
reconcile Dim4D and NchwDim into DimIndex (#3489)
Summary:
Pull Request resolved: #3489

1. Adapt SS-JIA's idea of representing `Dim4D` as a "negative index", and rename it to `DimIndex`.
2. Merge `NchwDim`'s functionality into `Dim4D`.
3. Clean up `dim_at` calls to assume only `DimIndex` as input.
4. Further clean up remaining uses of `uint` and convert them to `int`.

ghstack-source-id: 225521662
Reviewed By: SS-JIA
Differential Revision: D56778340
fbshipit-source-id: 68e1fd7d74b59fb89299123263009be096dd5c18
1 parent bd6cbc4 commit 251aa74
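
For readers skimming the diff, the "negative index" idea from the summary can be sketched in a few lines of C++. The names DimIndex, kWidth4D, kHeight4D, kChannel4D, kBatch4D, dim_at, and normalize_to_dim_index all appear in the hunks below, but the enum values and signatures in this sketch are illustrative assumptions, not the upstream definitions.

// Sketch only: assumed values and signatures, not the actual upstream code.
#include <cstdint>
#include <vector>

// A DimIndex counts backwards from the innermost dimension, so the same
// constant addresses W/H/C/N regardless of the tensor's actual rank.
enum DimIndex : int32_t {
  kWidth4D = -1,   // innermost dimension (W in NCHW)
  kHeight4D = -2,  // H
  kChannel4D = -3, // C
  kBatch4D = -4,   // outermost dimension (N)
};

// Assumed dim_at: look up a size by DimIndex, treating dimensions that a
// lower-rank tensor does not have as 1.
inline int64_t dim_at(const std::vector<int64_t>& sizes, DimIndex dim_index) {
  const int64_t idx = static_cast<int64_t>(sizes.size()) + dim_index;
  return idx >= 0 ? sizes[idx] : 1;
}

// Assumed normalize_to_dim_index: turn a non-negative dim (already normalized
// against the tensor's rank) into the negative DimIndex form.
inline DimIndex normalize_to_dim_index(int64_t rank, int64_t dim) {
  return static_cast<DimIndex>(dim - rank);
}

In this sketch, dim_at(sizes, kChannel4D) on a rank-3 tensor of sizes {8, 32, 32} returns 8, and dim_at(sizes, kBatch4D) returns 1 for the missing batch dimension. Note that the hunks below use both a runtime form, dim_at(sizes, kChannel4D), and a template form, dim_at<kChannel4D>(sizes); only the runtime form is sketched here.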

File tree

10 files changed: +134 additions, -135 deletions


backends/vulkan/runtime/graph/ops/glsl/permute.glsl

Lines changed: 2 additions & 2 deletions

@@ -29,9 +29,9 @@ layout(set = 0, binding = 3) uniform PRECISION restrict Sizes {
 
 layout(set = 0, binding = 4) uniform PRECISION restrict Block {
   // output dims
-  uvec4 out_ndims;
+  ivec4 out_ndims;
   // x = output channels aligned to 4, y = input channels aligned to 4
-  uvec2 ch_info;
+  ivec2 ch_info;
 };
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

backends/vulkan/runtime/graph/ops/impl/Cat.cpp

Lines changed: 7 additions & 7 deletions

@@ -31,10 +31,10 @@ void add_cat_default_node(
   int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
   vTensorPtr t_out = graph.get_tensor(out);
 
-  NchwDim nchw_dim = normalize_to_nchw_dim(*t_out, dim);
+  DimIndex dim_index = normalize_to_dim_index(*t_out, dim);
 
   // TODO: Find ways to factor out the similar code for width, height, and batch
-  if (nchw_dim == DimWidth) {
+  if (dim_index == kWidth4D) {
     api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
     api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);
 
@@ -46,7 +46,7 @@ void add_cat_default_node(
       dst_offset.data[0] += range.data[0];
     }
 
-  } else if (nchw_dim == DimHeight) {
+  } else if (dim_index == kHeight4D) {
     api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
     api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);
 
@@ -57,7 +57,7 @@ void add_cat_default_node(
           graph, input_ref, range, src_offset, dst_offset, out);
       dst_offset.data[1] += range.data[1];
     }
-  } else if (nchw_dim == DimBatch) {
+  } else if (dim_index == kBatch4D) {
     api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
     api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);
 
@@ -68,19 +68,19 @@ void add_cat_default_node(
           graph, input_ref, range, src_offset, dst_offset, out);
       dst_offset.data[2] += range.data[2];
     }
-  } else if (nchw_dim == DimChannel) {
+  } else if (dim_index == kChannel4D) {
     int32_t src_offset = 0;
     int32_t dst_offset = 0;
 
     for (ValueRef input_ref : *input_list) {
       vTensorPtr t_in = graph.get_tensor(input_ref);
-      int32_t range = dim_at<Dim4D::Channel>(t_in->sizes());
+      int32_t range = dim_at(t_in->sizes(), kChannel4D);
       add_copy_channel_offset_node(
           graph, input_ref, range, src_offset, dst_offset, out);
       dst_offset += range;
     }
   } else {
-    VK_THROW("Unexpected value of nchw_dim=", nchw_dim);
+    VK_THROW("Unexpected value of dim_index=", dim_index);
   }
 }
 

backends/vulkan/runtime/graph/ops/impl/Copy.cpp

Lines changed: 8 additions & 9 deletions

@@ -92,23 +92,23 @@ void add_copy_channel_offset_node(
   VK_CHECK_COND(t_out->dim() >= 3, "Dst dim should be at least 3");
 
   VK_CHECK_COND(
-      dim_at<Dim4D::Channel>(in_sizes) >= src_channel_offset + channel_range,
+      dim_at<kChannel4D>(in_sizes) >= src_channel_offset + channel_range,
       "Src channel (",
       src_channel_offset,
       ") and range (",
      channel_range,
      ") should be less than or equal to input tensor's channel size (",
-      dim_at<Dim4D::Channel>(in_sizes),
+      dim_at<kChannel4D>(in_sizes),
      ")");
 
   VK_CHECK_COND(
-      dim_at<Dim4D::Channel>(out_sizes) >= dst_channel_offset + channel_range,
+      dim_at<kChannel4D>(out_sizes) >= dst_channel_offset + channel_range,
      "Dst channel (",
      dst_channel_offset,
      ") and range (",
      channel_range,
      ") should be less than or equal to input tensor's channel size (",
-      dim_at<Dim4D::Channel>(out_sizes),
+      dim_at<kChannel4D>(out_sizes),
      ")");
 
   VK_CHECK_COND(channel_range >= 0, "Channel range must be non-negative");
@@ -121,11 +121,10 @@ void add_copy_channel_offset_node(
   kernel_name.reserve(kShaderNameReserve);
   add_dtype_suffix(kernel_name, *t_out);
 
-  int32_t out_channels = dim_at<Dim4D::Channel>(out_sizes);
+  int32_t out_channels = dim_at<kChannel4D>(out_sizes);
 
   // Copy one batch at a time.
-  for (int batch_idx = 0; batch_idx < dim_at<Dim4D::Batch>(in_sizes);
-       batch_idx++) {
+  for (int batch_idx = 0; batch_idx < dim_at<kBatch4D>(in_sizes); batch_idx++) {
     // Mapping the tensor NCHW coordinates into texture XYZ coordinates
     int32_t dst_first_z = dst_channel_offset / 4;
     int32_t dst_last_z = (dst_channel_offset + channel_range - 1) / 4;
@@ -139,8 +138,8 @@ void add_copy_channel_offset_node(
         0, 0, dst_first_z + batch_idx * api::utils::div_up(out_channels, 4)};
 
     uvec3 global_size{
-        dim_at<Dim4D::Width>(in_sizes),
-        dim_at<Dim4D::Height>(in_sizes),
+        api::utils::safe_downcast<uint32_t>(dim_at<kWidth4D>(in_sizes)),
+        api::utils::safe_downcast<uint32_t>(dim_at<kHeight4D>(in_sizes)),
         api::utils::safe_downcast<uint32_t>(dst_last_z - dst_first_z + 1)};
 
     uvec3 local_size = adaptive_work_group_size(global_size);

backends/vulkan/runtime/graph/ops/impl/Permute.cpp

Lines changed: 10 additions & 9 deletions

@@ -17,8 +17,9 @@
 
 namespace vkcompute {
 
+using api::utils::ivec2;
 using api::utils::ivec3;
-using api::utils::uvec2;
+using api::utils::ivec4;
 using api::utils::uvec4;
 
 namespace {
@@ -53,7 +54,7 @@ void add_permute_node(
 
   check_args(*t_in, permute_dims, *t_out);
 
-  uvec4 out_dims{0u, 1u, 2u, 3u};
+  ivec4 out_dims{0, 1, 2, 3};
 
   int64_t out_dim = t_out->dim();
   std::vector<bool> seen(out_dim);
@@ -63,22 +64,22 @@ void add_permute_node(
         !seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
     seen[permute_dim] = true;
 
-    out_dims.data[(4u - out_dim) + i] = permute_dim + (4u - out_dim);
+    out_dims.data[(4u - out_dim) + i] = permute_dim + (4 - out_dim);
   }
 
   std::string kernel_name = "permute";
   kernel_name.reserve(kShaderNameReserve);
   add_dtype_suffix(kernel_name, *t_out);
 
-  uint32_t out_channels = dim_at<Dim4D::Channel>(t_out->sizes());
-  uint32_t in_channels = dim_at<Dim4D::Channel>(t_in->sizes());
+  int32_t out_channels = dim_at<kChannel4D>(t_out->sizes());
+  int32_t in_channels = dim_at<kChannel4D>(t_in->sizes());
 
-  uint32_t out_c_aligned = api::utils::align_up(out_channels, 4u);
-  uint32_t in_c_aligned = api::utils::align_up(in_channels, 4u);
+  int32_t out_c_aligned = api::utils::align_up(out_channels, 4);
+  int32_t in_c_aligned = api::utils::align_up(in_channels, 4);
 
   const struct Block final {
-    uvec4 out_ndims;
-    uvec2 ch_info;
+    ivec4 out_ndims;
+    ivec2 ch_info;
   } params{
       out_dims,
      {out_c_aligned, in_c_aligned},
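
As a worked example of the out_dims packing in the hunk above (illustrative values): for a rank-3 output with permute_dims = {2, 0, 1}, out_dim = 3, so entry i is written at slot (4 - 3) + i and each permute_dim is shifted by (4 - 3). The loop then produces out_dims = {0, 3, 1, 2}; the untouched leading slot corresponds to the implicit batch dimension of the padded 4-D NCHW view.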

backends/vulkan/runtime/graph/ops/impl/Repeat.cpp

Lines changed: 24 additions & 28 deletions

@@ -32,23 +32,23 @@ void check_args(
       "Input tensor dim size must be not greater than the repeat argument's size");
 
   VK_CHECK_COND(
-      dim_at<Dim4D::Width>(in.sizes()) * dim_at<Dim4D::Width>(repeats) ==
-          dim_at<Dim4D::Width>(out.sizes()),
+      dim_at<kWidth4D>(in.sizes()) * dim_at<kWidth4D>(repeats) ==
+          dim_at<kWidth4D>(out.sizes()),
      "Output's width doesn't match input's width * repeat count");
 
   VK_CHECK_COND(
-      dim_at<Dim4D::Height>(in.sizes()) * dim_at<Dim4D::Height>(repeats) ==
-          dim_at<Dim4D::Height>(out.sizes()),
+      dim_at<kHeight4D>(in.sizes()) * dim_at<kHeight4D>(repeats) ==
+          dim_at<kHeight4D>(out.sizes()),
      "Output's height doesn't match input's height * repeat count");
 
   VK_CHECK_COND(
-      dim_at<Dim4D::Channel>(in.sizes()) * dim_at<Dim4D::Channel>(repeats) ==
-          dim_at<Dim4D::Channel>(out.sizes()),
+      dim_at<kChannel4D>(in.sizes()) * dim_at<kChannel4D>(repeats) ==
+          dim_at<kChannel4D>(out.sizes()),
      "Output's channel doesn't match input's channel * repeat count");
 
   VK_CHECK_COND(
-      dim_at<Dim4D::Batch>(in.sizes()) * dim_at<Dim4D::Batch>(repeats) ==
-          dim_at<Dim4D::Batch>(out.sizes()),
+      dim_at<kBatch4D>(in.sizes()) * dim_at<kBatch4D>(repeats) ==
+          dim_at<kBatch4D>(out.sizes()),
      "Output's batch doesn't match input's batch * repeat count");
 }
 
@@ -70,13 +70,13 @@ void add_repeat_channel_node(
   const std::vector<int64_t>& in_sizes = t_in->sizes();
 
   int32_t in_width =
-      api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Width>(in_sizes));
+      api::utils::safe_downcast<int32_t>(dim_at<kWidth4D>(in_sizes));
   int32_t in_height =
-      api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Height>(in_sizes));
+      api::utils::safe_downcast<int32_t>(dim_at<kHeight4D>(in_sizes));
   int32_t in_channel =
-      api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Channel>(in_sizes));
+      api::utils::safe_downcast<int32_t>(dim_at<kChannel4D>(in_sizes));
   int32_t in_batch =
-      api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Batch>(in_sizes));
+      api::utils::safe_downcast<int32_t>(dim_at<kBatch4D>(in_sizes));
 
   int32_t out_channel = repeat_channel * in_channel;
 
@@ -142,11 +142,11 @@ void add_repeat_node(
   // dimension, we copy over the input texure to the output. In subsequent
   // dimensions, we read and write from the same tensor.
 
-  if (int64_t channel_repeat = dim_at<Dim4D::Channel>(repeats);
+  if (int64_t channel_repeat = dim_at<kChannel4D>(repeats);
       channel_repeat == 1) {
     // If no repeat, short-cut to a direct copy
-    api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
-    api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);
+    api::utils::ivec3 src_offset{0, 0, 0};
+    api::utils::ivec3 dst_offset{0, 0, 0};
 
     add_copy_offset_node(graph, in, running_range, src_offset, dst_offset, out);
 
@@ -156,12 +156,11 @@ void add_repeat_node(
 
   // TODO: refactor width, height, and batch into a common helper function.
   // Width
-  if (int64_t width_repeat = dim_at<Dim4D::Width>(repeats); width_repeat > 1) {
-    api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
+  if (int64_t width_repeat = dim_at<kWidth4D>(repeats); width_repeat > 1) {
+    api::utils::ivec3 src_offset{0, 0, 0};
 
     for (int i = 1; i < width_repeat; ++i) {
-      api::utils::ivec3 dst_offset = api::utils::make_ivec3(
-          {i * dim_at<Dim4D::Width>(in_sizes), 0, 0}, false);
+      api::utils::ivec3 dst_offset{i * dim_at<kWidth4D>(in_sizes), 0, 0};
 
       add_copy_offset_node(
           graph, out, running_range, src_offset, dst_offset, out);
@@ -171,13 +170,11 @@ void add_repeat_node(
   }
 
   // Height
-  if (int64_t height_repeat = dim_at<Dim4D::Height>(repeats);
-      height_repeat > 1) {
-    api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
+  if (int64_t height_repeat = dim_at<kHeight4D>(repeats); height_repeat > 1) {
+    api::utils::ivec3 src_offset{0, 0, 0};
 
     for (int i = 1; i < height_repeat; ++i) {
-      api::utils::ivec3 dst_offset = api::utils::make_ivec3(
-          {0, i * dim_at<Dim4D::Height>(in_sizes), 0}, false);
+      api::utils::ivec3 dst_offset = {0, i * dim_at<kHeight4D>(in_sizes), 0};
 
       add_copy_offset_node(
           graph, out, running_range, src_offset, dst_offset, out);
@@ -187,12 +184,11 @@ void add_repeat_node(
   }
 
   // Batch
-  if (int64_t batch_repeat = dim_at<Dim4D::Batch>(repeats); batch_repeat > 1) {
-    api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
+  if (int64_t batch_repeat = dim_at<kBatch4D>(repeats); batch_repeat > 1) {
+    api::utils::ivec3 src_offset{0, 0, 0};
 
     for (int i = 1; i < batch_repeat; ++i) {
-      api::utils::ivec3 dst_offset =
-          api::utils::make_ivec3({0, 0, i * running_range.data[2]}, false);
+      api::utils::ivec3 dst_offset = {0, 0, i * running_range.data[2]};
 
       add_copy_offset_node(
          graph, out, running_range, src_offset, dst_offset, out);

backends/vulkan/runtime/graph/ops/impl/Slice.cpp

Lines changed: 6 additions & 7 deletions

@@ -43,8 +43,7 @@ void add_slice_tensor_out_node(
 
   dim = normalize(dim, t_in->dim());
 
-  // Create a dim value as in the underlying dim is 4-dimension.
-  int64_t nchw_dim = dim + (4 - t_in->dim());
+  DimIndex dim_index = normalize_to_dim_index(*t_in, dim);
 
   std::optional<int64_t> opt_start =
       graph.extract_optional_scalar<int64_t>(opt_start_ref);
@@ -61,7 +60,7 @@ void add_slice_tensor_out_node(
   VK_CHECK_COND((0 <= start) && (start < in_sizes[dim]));
   VK_CHECK_COND((0 <= end) && (end <= in_sizes[dim]));
 
-  if (nchw_dim == 1) {
+  if (dim_index == kChannel4D) {
     // slice by channel
     std::string kernel_name = "slice_channel";
     kernel_name.reserve(kShaderNameReserve);
@@ -93,17 +92,17 @@ void add_slice_tensor_out_node(
    // GPU's coordinate is in x, y, z
    int64_t gpu_dim = -1;
    int64_t stride = 1;
-    if (nchw_dim == 3) {
+    if (dim_index == kWidth4D) {
      gpu_dim = 0; // width: x dimension in gpu
      VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
-    } else if (nchw_dim == 2) {
+    } else if (dim_index == kHeight4D) {
      gpu_dim = 1; // height: y dimension
      VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
-    } else if (nchw_dim == 0) {
+    } else if (dim_index == kBatch4D) {
      gpu_dim = 2; // batch: z dimension
 
      // Due to channel packing, each batch value is span over stride planes
-      int64_t n_channels = dim_at<Dim4D::Channel>(in_sizes);
+      int64_t n_channels = dim_at(in_sizes, kChannel4D);
      stride = api::utils::div_up<int64_t>(n_channels, 4ll);
    } else {
      VK_THROW("Unexpected ncwh_dim!");
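
To make the channel-packing stride in the hunk above concrete (illustrative numbers): with n_channels = 10, stride = div_up(10, 4) = 3, so each batch value spans 3 packed texture planes along z, which is what the stride variable captures when slicing along the batch dimension.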

backends/vulkan/runtime/graph/ops/impl/Split.cpp

Lines changed: 9 additions & 9 deletions

@@ -29,7 +29,7 @@ void add_split_with_sizes_default_node(
 
   ValueListPtr out_list = graph.get_value_list(out_list_ref);
 
-  NchwDim nchw_dim = normalize_to_nchw_dim(*t_in, dim);
+  DimIndex dim_index = normalize_to_dim_index(*t_in, dim);
 
   VK_CHECK_COND(out_list->size() == split_sizes.size());
 
@@ -39,10 +39,10 @@ void add_split_with_sizes_default_node(
 
     vTensorPtr t_out = graph.get_tensor(out_ref);
     VK_CHECK_COND(check_memory_layout_is(*t_out, api::kChannelsPacked));
-    VK_CHECK_COND(dim_at(*t_out, nchw_dim) == split_size);
+    VK_CHECK_COND(dim_at(*t_out, dim_index) == split_size);
   }
 
-  if (nchw_dim == DimWidth) {
+  if (dim_index == kWidth4D) {
     api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
     api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);
 
@@ -55,7 +55,7 @@ void add_split_with_sizes_default_node(
 
       src_offset.data[0] += range.data[0];
     }
-  } else if (nchw_dim == DimHeight) {
+  } else if (dim_index == kHeight4D) {
     api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
     api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);
 
@@ -66,7 +66,7 @@ void add_split_with_sizes_default_node(
 
      src_offset.data[1] += range.data[1];
    }
-  } else if (nchw_dim == DimBatch) {
+  } else if (dim_index == kBatch4D) {
     api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
     api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);
 
@@ -77,13 +77,13 @@ void add_split_with_sizes_default_node(
 
      src_offset.data[2] += range.data[2];
    }
-  } else if (nchw_dim == DimChannel) {
+  } else if (dim_index == kChannel4D) {
     int32_t src_offset = 0;
     int32_t dst_offset = 0;
 
     for (ValueRef out_ref : *out_list) {
       vTensorPtr t_out = graph.get_tensor(out_ref);
-      int32_t range = dim_at<Dim4D::Channel>(t_out->sizes());
+      int32_t range = dim_at<kChannel4D>(t_out->sizes());
       add_copy_channel_offset_node(
           graph, in, range, src_offset, dst_offset, out_ref);
       src_offset += range;
@@ -122,8 +122,8 @@ void add_split_tensor_node(
   int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
 
   vTensorPtr t_in = graph.get_tensor(in);
-  NchwDim nchw_dim = normalize_to_nchw_dim(*t_in, dim);
-  int64_t size = dim_at(*t_in, nchw_dim);
+  DimIndex dim_index = normalize_to_dim_index(*t_in, dim);
+  int64_t size = dim_at(*t_in, dim_index);
   std::vector<int64_t> split_sizes(size / split_size, split_size);
 
   add_split_with_sizes_default_node(graph, in, split_sizes, dim, out);
