|
7 | 7 | #pragma once
|
8 | 8 | #include <array>
|
9 | 9 |
|
| 10 | +#include <stdint.h> |
10 | 11 | #include <cassert>
|
| 12 | + |
11 | 13 | namespace torchao::ops {
|
12 | 14 |
|
/// Identifies the layout of a packed-weights buffer.
///
/// Scoped enum with a fixed `uint32_t` underlying type.
/// `PackedWeightsHeader::write` stores the value as the second int of the
/// serialized header and `PackedWeightsHeader::read` restores it from there.
enum class PackedWeightsFormat : uint32_t {
  // Default value used when no format has been specified.
  unknown = 0,
  linear_8bit_act_xbit_weight_universal = 1
};
|
17 | 19 |
|
18 | 20 | class PackedWeightsHeader {
|
19 | 21 | public:
|
20 |
| - using params_type = std::array<unsigned short, 7>; |
| 22 | + using params_type = std::array<int, 14>; |
| 23 | + const static int magic = 6712; |
21 | 24 | PackedWeightsFormat format;
|
22 | 25 |
|
23 | 26 | // 14 bytes of format specific params
|
24 | 27 | params_type params;
|
25 | 28 |
|
26 | 29 | PackedWeightsHeader(
|
27 | 30 | PackedWeightsFormat format = PackedWeightsFormat::unknown,
|
28 |
| - params_type params = {0, 0, 0, 0, 0, 0, 0}) |
| 31 | + params_type params = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) |
29 | 32 | : format{format}, params{params} {}
|
30 | 33 |
|
31 | 34 | inline static constexpr int size() {
|
32 |
| - static_assert(sizeof(format) + sizeof(params) == 16); |
33 |
| - return 16; |
| 35 | + static_assert(sizeof(magic) + sizeof(format) + sizeof(params) == 64); |
| 36 | + return 64; |
34 | 37 | }
|
35 | 38 |
|
36 | 39 | inline void write(void* packed_weights) const {
|
37 |
| - auto header = (unsigned short*)(packed_weights); |
38 |
| - header[0] = (unsigned short)format; |
| 40 | + auto header = (int*)(packed_weights); |
| 41 | + header[0] = magic; |
| 42 | + header[1] = (int)format; |
39 | 43 | for (int i = 0; i < params.size(); i++) {
|
40 |
| - header[i + 1] = params[i]; |
| 44 | + header[i + 2] = params[i]; |
41 | 45 | }
|
42 | 46 | }
|
43 | 47 |
|
44 | 48 | static PackedWeightsHeader read(const void* packed_weights) {
|
45 |
| - auto header = (unsigned short*)(packed_weights); |
| 49 | + auto header = reinterpret_cast<const int*>(packed_weights); |
| 50 | + assert(header[0] == PackedWeightsHeader::magic); |
46 | 51 | params_type params;
|
47 | 52 | for (int i = 0; i < params.size(); i++) {
|
48 |
| - params[i] = header[i + 1]; |
| 53 | + params[i] = header[i + 2]; |
49 | 54 | }
|
50 |
| - return PackedWeightsHeader((PackedWeightsFormat)header[0], params); |
| 55 | + return PackedWeightsHeader((PackedWeightsFormat)header[1], params); |
51 | 56 | }
|
52 | 57 |
|
53 | 58 | bool operator==(const PackedWeightsHeader& other) const {
|
|
0 commit comments