Skip to content

Commit ce21031

Browse files
SS-JIA (ssjia)
authored and committed
[ET-VK][qlinear] Add bias support to q4gsw and dq8ca_q4gsw quantized linear ops
Pull Request resolved: #18061 Wire bias through the q4gsw and dq8ca_q4gsw quantized linear operators. Add add_bias_to_out_tile() helper in the output tile computation header and call it from all three shader variants (tiled, coop, dq8ca_tiled). Remove the bias guard in the pattern matcher to allow biased linear layers. ghstack-source-id: 353546681 exported-using-ghexport Differential Revision: [D95970172](https://our.internmc.facebook.com/intern/diff/D95970172/)
1 parent d996fb6 commit ce21031

File tree

7 files changed

+49
-18
lines changed

7 files changed

+49
-18
lines changed

backends/vulkan/custom_ops_lib.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import executorch.backends.vulkan.patterns as vk_patterns
1010
import torch.library
11-
1211
from torch._subclasses.fake_tensor import FakeTensor
1312

1413
namespace = "et_vk"
@@ -259,7 +258,7 @@ def linear_q4gsw(
259258
weights, [1, group_size], weight_scales, weight_zeros, torch.int8, -8, 7
260259
)
261260

262-
out = torch.nn.functional.linear(x, weights)
261+
out = torch.nn.functional.linear(x, weights, bias)
263262
return out
264263

265264

@@ -273,7 +272,7 @@ def linear_dq8ca_q4gsw(
273272
group_size: int,
274273
bias: Optional[torch.Tensor] = None,
275274
):
276-
return linear_q4gsw(x, weights, weight_scales, group_size)
275+
return linear_q4gsw(x, weights, weight_scales, group_size, bias)
277276

278277

279278
name = "linear_q4gsw"

backends/vulkan/patterns/quantized_linear.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,28 +5,22 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import operator
8-
98
from typing import Optional
109

1110
import executorch.backends.vulkan.utils as utils
12-
1311
import torch
1412
import torch.nn.functional as F
15-
1613
from executorch.backends.transforms.utils import (
1714
create_constant_placeholder,
1815
get_param_tensor,
1916
)
20-
2117
from executorch.backends.vulkan.patterns.pattern_registry import (
2218
PatternMatch,
2319
register_pattern_detector,
2420
register_pattern_replacement,
2521
)
26-
2722
from executorch.exir import ExportedProgram
2823
from executorch.exir.dialects._ops import ops as exir_ops
29-
3024
from torch.export.graph_signature import InputKind
3125

3226

@@ -398,6 +392,12 @@ def make_linear_q4gsw_op(
398392
force_update=True,
399393
)
400394

395+
# Pad bias to multiple of 4 if present
396+
if match.bias_node is not None:
397+
bias_tensor = get_param_tensor(ep, match.bias_node)
398+
if bias_tensor is not None:
399+
utils.align_width_and_update_state_dict(ep, match.bias_node, bias_tensor)
400+
401401
with graph_module.graph.inserting_before(match.output_node):
402402
linear_q4gsw_node = graph_module.graph.create_node(
403403
"call_function",
@@ -407,6 +407,7 @@ def make_linear_q4gsw_op(
407407
match.weight_node,
408408
match.weight_scales_node,
409409
group_size,
410+
match.bias_node,
410411
),
411412
)
412413

@@ -445,6 +446,12 @@ def make_linear_dq8ca_q4gsw_op(
445446
force_update=True,
446447
)
447448

449+
# Pad bias to multiple of 4 if present
450+
if match.bias_node is not None:
451+
bias_tensor = get_param_tensor(ep, match.bias_node)
452+
if bias_tensor is not None:
453+
utils.align_width_and_update_state_dict(ep, match.bias_node, bias_tensor)
454+
448455
first_graph_node = list(graph_module.graph.nodes)[0]
449456
with graph_module.graph.inserting_before(first_graph_node):
450457
weight_tensor_name = utils.get_tensor_name(ep, match.weight_node)
@@ -474,6 +481,7 @@ def make_linear_dq8ca_q4gsw_op(
474481
weight_sums_node,
475482
match.weight_scales_node,
476483
group_size,
484+
match.bias_node,
477485
),
478486
)
479487

@@ -538,6 +546,7 @@ def make_linear_q8ta_q8csw_custom_op(
538546
match.weight_node,
539547
weight_sums_node,
540548
match.weight_scales_node,
549+
match.bias_node,
541550
),
542551
)
543552

@@ -637,7 +646,6 @@ def replace_quantized_linear_patterns(
637646
assert weight_zeros_tensor is not None
638647

639648
# Route to appropriate custom op.
640-
# q8ta_linear supports bias, so check it first before the bias guard.
641649
if (
642650
match.is_input_static_per_tensor_quantized()
643651
and match.is_weight_perchannel_quantized()
@@ -646,10 +654,6 @@ def replace_quantized_linear_patterns(
646654
make_q8ta_linear_custom_op(ep, graph_module, match, weight_tensor)
647655
return
648656

649-
# Remaining ops do not support bias
650-
if match.bias_node is not None:
651-
return
652-
653657
if (
654658
match.is_weight_only_quantized()
655659
and match.is_weight_pergroup_quantized()

backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,5 +144,11 @@ void main() {
144144
group_size);
145145
}
146146

147+
if (apply_bias > 0) {
148+
FPPerOutChannelParams bias_tile;
149+
load_bias_tile(bias_tile, n4);
150+
add_bias_to_out_tile(out_tile, bias_tile);
151+
}
152+
147153
write_output_tile_with_checks(out_tile, n4, m, N4, M);
148154
}

backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,16 @@ void apply_weight_scales_and_biases(
7373
}
7474
}
7575

76+
void add_bias_to_out_tile(
77+
inout FPOutTile tile,
78+
const FPPerOutChannelParams bias) {
79+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
80+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
81+
tile.data[m][n4] = tile.data[m][n4] + bias.data[n4];
82+
}
83+
}
84+
}
85+
7686
void accumulate_out_tile_with_out_tile(
7787
inout FPOutTile accum,
7888
const FPOutTile other) {

backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,11 @@ void main() {
142142
// Only the first thread will write out result
143143
if (lid == 0) {
144144
out_tile = partial_sums[0];
145+
if (apply_bias > 0) {
146+
FPPerOutChannelParams bias_tile;
147+
load_bias_tile(bias_tile, n4);
148+
add_bias_to_out_tile(out_tile, bias_tile);
149+
}
145150
write_output_tile_with_checks(out_tile, n4, 0, N4, 1);
146151
}
147152
}

backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.glsl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,5 +110,11 @@ void main() {
110110
}
111111
}
112112

113+
if (apply_bias > 0) {
114+
FPPerOutChannelParams bias_tile;
115+
load_bias_tile(bias_tile, n4);
116+
add_bias_to_out_tile(out_tile, bias_tile);
117+
}
118+
113119
write_output_tile_with_checks(out_tile, n4, m, N4, M);
114120
}

backends/vulkan/test/custom_ops/q4gsw_linear.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ TestCase create_test_case_from_config(
148148
input_dtype,
149149
storage_type,
150150
utils::kWidthPacked,
151-
DataGenType::ZEROS);
151+
config.has_bias ? DataGenType::RANDOM : DataGenType::ZEROS);
152152
bias.set_constant(true);
153153
if (!config.has_bias) {
154154
bias.set_none(true);
@@ -237,9 +237,10 @@ std::vector<TestCase> generate_quantized_linear_test_cases() {
237237
{32, 64, 32, 16},
238238
{32, 128, 64, 32},
239239
{32, 256, 128, 64},
240-
// No bias tests
241-
{32, 128, 64, 32, false},
242-
{32, 256, 128, 64, false},
240+
// With bias
241+
{4, 64, 32, 16, true},
242+
{4, 128, 64, 32, true},
243+
{32, 128, 64, 32, true},
243244
// Performance test cases
244245
{1, 2048, 2048, 128},
245246
{128, 2048, 2048, 128},

0 commit comments

Comments (0)