Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions backends/vulkan/custom_ops_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import executorch.backends.vulkan.patterns as vk_patterns
import torch.library

from torch._subclasses.fake_tensor import FakeTensor

namespace = "et_vk"
Expand Down Expand Up @@ -259,7 +258,7 @@ def linear_q4gsw(
weights, [1, group_size], weight_scales, weight_zeros, torch.int8, -8, 7
)

out = torch.nn.functional.linear(x, weights)
out = torch.nn.functional.linear(x, weights, bias)
return out


Expand All @@ -273,7 +272,7 @@ def linear_dq8ca_q4gsw(
group_size: int,
bias: Optional[torch.Tensor] = None,
):
return linear_q4gsw(x, weights, weight_scales, group_size)
return linear_q4gsw(x, weights, weight_scales, group_size, bias)


name = "linear_q4gsw"
Expand Down
54 changes: 46 additions & 8 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def update_features_impl(op: OpKey):
torch.ops.aten.sym_size.int,
operator.add,
operator.sub,
operator.floordiv,
operator.mul,
operator.lt,
operator.gt,
operator.ge,
Expand Down Expand Up @@ -279,6 +281,26 @@ def register_bitwise_and():
)


@update_features(exir_ops.edge.aten.bitwise_not.default)
def register_bitwise_not():
return OpFeatures(
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.BOOL_T,
supports_resize=True,
supports_highdim=True,
)


@update_features(exir_ops.edge.aten.logical_and.default)
def register_logical_and():
return OpFeatures(
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.BOOL_T,
supports_resize=True,
supports_highdim=True,
)


# =============================================================================
# BinaryScalarOp.cpp
# =============================================================================
Expand All @@ -301,16 +323,22 @@ def register_pow_tensor_scalar():

@update_features(exir_ops.edge.aten._to_copy.default)
def register_to_copy():
def check_to_copy_node(node: torch.fx.Node) -> bool:
# Only single-arg _to_copy is supported
return len(node.args) == 1
def pick_to_copy_storage(
node: torch.fx.Node,
) -> Tuple[utils.TensorRepSet, utils.TensorRepSet]:
in_dtype = node.args[0].meta["val"].dtype # type: ignore[union-attr]
out_dtype = node.meta["val"].dtype
fp_types = {torch.float16, torch.float32}
if in_dtype in fp_types and out_dtype in fp_types:
return utils.ANY_STORAGE, utils.ANY_STORAGE
return utils.CONTIGUOUS_BUFFER, utils.CONTIGUOUS_BUFFER

return OpFeatures(
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.FP_INT_T,
outputs_dtypes=utils.FP_INT_T,
inputs_dtypes=utils.FP_INT_BOOL_T,
outputs_dtypes=utils.FP_INT_BOOL_T,
supports_resize=True,
are_node_inputs_supported_fn=check_to_copy_node,
pick_io_storage_fn=pick_to_copy_storage,
)


Expand Down Expand Up @@ -705,7 +733,7 @@ def register_reduce_cpp_ops():
)
def register_argreduce_cpp_ops():
return OpFeatures(
inputs_storage=utils.ANY_TEXTURE,
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.FP_T,
outputs_dtypes=utils.INT_T,
supports_resize=True,
Expand Down Expand Up @@ -1336,6 +1364,7 @@ def register_scalar_tensor():
return OpFeatures(
inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
inputs_dtypes=utils.FP_INT_T,
supports_resize=True,
)


Expand Down Expand Up @@ -1390,11 +1419,20 @@ def register_repeat():

@update_features(exir_ops.edge.aten.embedding.default)
def register_embedding():
def check_embedding_weight_size(node: torch.fx.Node) -> bool:
weight = node.args[0]
if isinstance(weight, torch.fx.Node) and utils.is_tensor_node(weight):
numel = weight.meta["val"].numel()
if numel > utils.DEFAULT_BUFFER_LIMIT:
return False
return True

return OpFeatures(
inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=[utils.FP_T, utils.INT_T],
supports_prepacking=True,
supports_resize=True,
are_node_inputs_supported_fn=check_embedding_weight_size,
)


Expand Down
26 changes: 15 additions & 11 deletions backends/vulkan/patterns/quantized_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,22 @@
# LICENSE file in the root directory of this source tree.

import operator

from typing import Optional

import executorch.backends.vulkan.utils as utils

import torch
import torch.nn.functional as F

from executorch.backends.transforms.utils import (
create_constant_placeholder,
get_param_tensor,
)

from executorch.backends.vulkan.patterns.pattern_registry import (
PatternMatch,
register_pattern_detector,
register_pattern_replacement,
)

from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops

from torch.export.graph_signature import InputKind


Expand Down Expand Up @@ -398,6 +392,12 @@ def make_linear_q4gsw_op(
force_update=True,
)

# Pad bias to multiple of 4 if present
if match.bias_node is not None:
bias_tensor = get_param_tensor(ep, match.bias_node)
if bias_tensor is not None:
utils.align_width_and_update_state_dict(ep, match.bias_node, bias_tensor)

with graph_module.graph.inserting_before(match.output_node):
linear_q4gsw_node = graph_module.graph.create_node(
"call_function",
Expand All @@ -407,6 +407,7 @@ def make_linear_q4gsw_op(
match.weight_node,
match.weight_scales_node,
group_size,
match.bias_node,
),
)

Expand Down Expand Up @@ -445,6 +446,12 @@ def make_linear_dq8ca_q4gsw_op(
force_update=True,
)

# Pad bias to multiple of 4 if present
if match.bias_node is not None:
bias_tensor = get_param_tensor(ep, match.bias_node)
if bias_tensor is not None:
utils.align_width_and_update_state_dict(ep, match.bias_node, bias_tensor)

first_graph_node = list(graph_module.graph.nodes)[0]
with graph_module.graph.inserting_before(first_graph_node):
weight_tensor_name = utils.get_tensor_name(ep, match.weight_node)
Expand Down Expand Up @@ -474,6 +481,7 @@ def make_linear_dq8ca_q4gsw_op(
weight_sums_node,
match.weight_scales_node,
group_size,
match.bias_node,
),
)

Expand Down Expand Up @@ -538,6 +546,7 @@ def make_linear_q8ta_q8csw_custom_op(
match.weight_node,
weight_sums_node,
match.weight_scales_node,
match.bias_node,
),
)

Expand Down Expand Up @@ -637,7 +646,6 @@ def replace_quantized_linear_patterns(
assert weight_zeros_tensor is not None

# Route to appropriate custom op.
# q8ta_linear supports bias, so check it first before the bias guard.
if (
match.is_input_static_per_tensor_quantized()
and match.is_weight_perchannel_quantized()
Expand All @@ -646,10 +654,6 @@ def replace_quantized_linear_patterns(
make_q8ta_linear_custom_op(ep, graph_module, match, weight_tensor)
return

# Remaining ops do not support bias
if match.bias_node is not None:
return

if (
match.is_weight_only_quantized()
and match.is_weight_pergroup_quantized()
Expand Down
9 changes: 3 additions & 6 deletions backends/vulkan/runtime/graph/ops/glsl/embedding_legacy.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_in", "int", STORAGE)}
${layout_declare_tensor(B, "r", "t_weight", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_weight", DTYPE, "texture2d")}
${layout_declare_ubo(B, "ivec4", "sizes")}

#include "indexing_utils.h"
Expand All @@ -30,9 +30,6 @@ const lowp int packed_dim = unhash_packed_dim(out_layout);
${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 in_axis_map = unhash_axis_map(in_layout);

${layout_declare_spec_const(C, "int", "weight_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 weight_axis_map = unhash_axis_map(weight_layout);

void main() {
const ivec3 out_lpos = ivec3(gl_GlobalInvocationID);
const ivec4 out_tidx = lpos_to_tidx(out_lpos, sizes, out_axis_map.w, packed_dim);
Expand All @@ -48,8 +45,8 @@ void main() {
const int in_texel_elem = load_texel_lpos(t_in, in_lpos, in_axis_map)[out_tidx.w % 4];

// Read weight tensor for embedding, it is height-packed.
const ivec3 weight_lpos = ivec3(out_tidx.x, in_texel_elem / 4, 0);
out_texel[i] = load_texel_lpos(t_weight, weight_lpos, weight_axis_map)[in_texel_elem % 4];
const ivec2 weight_pos = ivec2(out_tidx.x, in_texel_elem / 4);
out_texel[i] = texelFetch(t_weight, weight_pos, 0)[in_texel_elem % 4];
}

write_texel_lpos(t_out, out_lpos, out_texel, out_axis_map);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,5 +144,11 @@ void main() {
group_size);
}

if (apply_bias > 0) {
FPPerOutChannelParams bias_tile;
load_bias_tile(bias_tile, n4);
add_bias_to_out_tile(out_tile, bias_tile);
}

write_output_tile_with_checks(out_tile, n4, m, N4, M);
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,16 @@ void apply_weight_scales_and_biases(
}
}

void add_bias_to_out_tile(
inout FPOutTile tile,
const FPPerOutChannelParams bias) {
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
tile.data[m][n4] = tile.data[m][n4] + bias.data[n4];
}
}
}

void accumulate_out_tile_with_out_tile(
inout FPOutTile accum,
const FPOutTile other) {
Expand Down
5 changes: 5 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ void main() {
// Only the first thread will write out result
if (lid == 0) {
out_tile = partial_sums[0];
if (apply_bias > 0) {
FPPerOutChannelParams bias_tile;
load_bias_tile(bias_tile, n4);
add_bias_to_out_tile(out_tile, bias_tile);
}
write_output_tile_with_checks(out_tile, n4, 0, N4, 1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -110,5 +110,11 @@ void main() {
}
}

if (apply_bias > 0) {
FPPerOutChannelParams bias_tile;
load_bias_tile(bias_tile, n4);
add_bias_to_out_tile(out_tile, bias_tile);
}

write_output_tile_with_checks(out_tile, n4, m, N4, M);
}
3 changes: 3 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,6 @@ unary_op:
OPERATOR: leaky_relu(X, A)
- NAME: round
OPERATOR: round(X)
- NAME: bitwise_not_uint8
OPERATOR: 1 - X
DTYPE: uint8
1 change: 1 addition & 0 deletions backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ REGISTER_OPERATORS {
VK_REGISTER_OP(aten.gt.Tensor, gt);
VK_REGISTER_OP(aten.ge.Tensor, ge);
VK_REGISTER_OP(aten.bitwise_and.Tensor, bitwise_and);
VK_REGISTER_OP(aten.logical_and.default, bitwise_and);
}

} // namespace vkcompute
4 changes: 1 addition & 3 deletions backends/vulkan/runtime/graph/ops/impl/Embedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,7 @@ void add_embedding_legacy_node(
// Push Constants
{},
// Specialization Constants
{graph.hashed_layout_of(out),
graph.hashed_layout_of(in),
graph.hashed_layout_of(weight)},
{graph.hashed_layout_of(out), graph.hashed_layout_of(in)},
// Resize Args
{},
// Resizing Logic
Expand Down
42 changes: 2 additions & 40 deletions backends/vulkan/runtime/graph/ops/impl/Split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,52 +9,13 @@
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

#include <executorch/backends/vulkan/runtime/utils/StorageUtils.h>

namespace vkcompute {

using utils::GPUMemoryLayout;
using utils::StorageType;

void resize_split_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& resize_args) {
(void)resize_args;
const ValueRef input = args.at(0).refs.at(0);
const ValueRef split_sizes_ref = args.at(1).refs.at(0);
const ValueRef dim_ref = args.at(2).refs.at(0);
const ValueRef out_list_ref = args.at(3).refs.at(0);

const ValueListPtr out_list = graph->get_value_list(out_list_ref);
const std::vector<int64_t> split_sizes =
*(graph->get_int_list(split_sizes_ref));
const int64_t dim = graph->extract_scalar<int64_t>(dim_ref);

const int64_t input_ndim = graph->dim_of(input);
const DimIndex dim_index = dim < 0 ? static_cast<DimIndex>(dim)
: static_cast<DimIndex>(dim - input_ndim);

std::vector<int64_t> input_sizes = graph->sizes_of(input);

for (int split_idx = 0; split_idx < split_sizes.size(); split_idx++) {
const int64_t split_size = split_sizes.at(split_idx);
const ValueRef out_ref = out_list->at(split_idx);

std::vector<int64_t> out_sizes = input_sizes;
out_sizes.at(dim_index) = split_size;

graph->virtual_resize(out_ref, out_sizes);
}
}

void add_split_node(
ComputeGraph& graph,
const ValueRef input,
Expand Down Expand Up @@ -125,7 +86,8 @@ void split_with_sizes_copy_default(
ValueRef out_list_ref = args[3];

int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
std::vector<int64_t> split_sizes = *(graph.get_int_list(split_sizes_ref));
std::vector<int64_t> split_sizes =
graph.extract_int_or_symint_list(split_sizes_ref);

add_split_with_sizes_node(graph, input, split_sizes, dim, out_list_ref);
}
Expand Down
Loading
Loading