Skip to content

Commit beea58f

Browse files
authored
[ET-VK] Moving PushConstantData class implementation a separate h and cpp file.
Differential Revision: D70102031 Pull Request resolved: #8648
1 parent 8786f86 commit beea58f

File tree

4 files changed

+95
-64
lines changed

4 files changed

+95
-64
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/containers/PushConstantData.h>
10+
11+
namespace vkcompute {
12+
13+
uint32_t PushConstantDataInfo::write(
14+
void* dst,
15+
const uint32_t dst_offset,
16+
const uint32_t max_dst_size) const {
17+
if (tensorUniformData != nullptr) {
18+
return tensorUniformData->write_attribute(
19+
dst, dst_offset, max_dst_size, payload_.attr);
20+
}
21+
22+
VK_CHECK_COND(
23+
(dst_offset + payload_.dataSize) <= max_dst_size,
24+
"Attempting to write push constant data outside data boundary.");
25+
memcpy((uint8_t*)dst + dst_offset, payload_.data, payload_.dataSize);
26+
return payload_.dataSize;
27+
}
28+
29+
} // namespace vkcompute
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/backends/vulkan/runtime/api/api.h>
12+
13+
namespace vkcompute {
14+
15+
class ComputeGraph;
16+
17+
constexpr uint32_t kMaxPushConstantSize = 128;
18+
/*
19+
* Represents a push constant data entry
20+
* Which is either shared pointer to a tensor's uniform data with an attribute
21+
* Or data with a maximum size of 16 bytes
22+
*/
23+
class PushConstantDataInfo {
24+
std::shared_ptr<api::vTensor::UniformData> tensorUniformData;
25+
union Payload {
26+
struct {
27+
api::vTensor::Attribute attr;
28+
};
29+
struct {
30+
uint8_t data[16];
31+
uint32_t dataSize;
32+
};
33+
};
34+
35+
Payload payload_;
36+
37+
public:
38+
explicit PushConstantDataInfo(
39+
const std::shared_ptr<api::vTensor::UniformData>& tensorUniformData,
40+
api::vTensor::Attribute attr)
41+
: tensorUniformData(tensorUniformData) {
42+
payload_.attr = attr;
43+
}
44+
45+
explicit PushConstantDataInfo(
46+
const void* data,
47+
uint32_t dataLen,
48+
uint32_t pushConstantLen = 0)
49+
: tensorUniformData(nullptr) {
50+
VK_CHECK_COND(
51+
dataLen <= 16, "Single push constant data size must be <= 16 bytes");
52+
payload_.dataSize = pushConstantLen ? pushConstantLen : dataLen;
53+
memcpy(payload_.data, data, dataLen);
54+
}
55+
56+
/*
57+
* Function writes push constant data to the destination buffer
58+
*/
59+
uint32_t write(
60+
void* dst,
61+
const uint32_t dst_offset,
62+
const uint32_t max_dst_size) const;
63+
};
64+
65+
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/DispatchNode.cpp

-16
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,6 @@
1414

1515
namespace vkcompute {
1616

17-
uint32_t PushConstantDataInfo::write(
18-
void* dst,
19-
const uint32_t dst_offset,
20-
const uint32_t max_dst_size) const {
21-
if (tensorUniformData != nullptr) {
22-
return tensorUniformData->write_attribute(
23-
dst, dst_offset, max_dst_size, payload_.attr);
24-
}
25-
26-
VK_CHECK_COND(
27-
(dst_offset + payload_.dataSize) <= max_dst_size,
28-
"Attempting to write push constant data outside data boundary.");
29-
memcpy((uint8_t*)dst + dst_offset, payload_.data, payload_.dataSize);
30-
return payload_.dataSize;
31-
}
32-
3317
DispatchNode::DispatchNode(
3418
ComputeGraph& graph,
3519
const vkapi::ShaderInfo& shader,

backends/vulkan/runtime/graph/ops/DispatchNode.h

+1-48
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <executorch/backends/vulkan/runtime/api/api.h>
1212

13+
#include <executorch/backends/vulkan/runtime/graph/containers/PushConstantData.h>
1314
#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
1415

1516
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
@@ -18,54 +19,6 @@ namespace vkcompute {
1819

1920
class ComputeGraph;
2021

21-
constexpr uint32_t kMaxPushConstantSize = 128;
22-
/*
23-
* Represents a push constant data entry
24-
* Which is either shared pointer to a tensor's uniform data with an attribute
25-
* Or data with a maximum size of 16 bytes
26-
*/
27-
class PushConstantDataInfo {
28-
std::shared_ptr<api::vTensor::UniformData> tensorUniformData;
29-
union Payload {
30-
struct {
31-
api::vTensor::Attribute attr;
32-
};
33-
struct {
34-
uint8_t data[16];
35-
uint32_t dataSize;
36-
};
37-
};
38-
39-
Payload payload_;
40-
41-
public:
42-
explicit PushConstantDataInfo(
43-
const std::shared_ptr<api::vTensor::UniformData>& tensorUniformData,
44-
api::vTensor::Attribute attr)
45-
: tensorUniformData(tensorUniformData) {
46-
payload_.attr = attr;
47-
}
48-
49-
explicit PushConstantDataInfo(
50-
const void* data,
51-
uint32_t dataLen,
52-
uint32_t pushConstantLen = 0)
53-
: tensorUniformData(nullptr) {
54-
VK_CHECK_COND(
55-
dataLen <= 16, "Single push constant data size must be <= 16 bytes");
56-
payload_.dataSize = pushConstantLen ? pushConstantLen : dataLen;
57-
memcpy(payload_.data, data, dataLen);
58-
}
59-
60-
/*
61-
* Function writes push constant data to the destination buffer
62-
*/
63-
uint32_t write(
64-
void* dst,
65-
const uint32_t dst_offset,
66-
const uint32_t max_dst_size) const;
67-
};
68-
6922
/*
7023
* Represents a single shader execution op in a ML model.
7124
*/

0 commit comments

Comments
 (0)