Skip to content

Commit 04fa11d

Browse files
metascroy authored and facebook-github-bot committed
Header bug fix (pytorch#1079)
Summary: A last-minute change introduced a compile error in the header; this commit fixes it. It also enlarges the header to 64 bytes and adds a magic number at the start to make the format safer in the future. Differential Revision: D64370707
1 parent 48bc81c commit 04fa11d

File tree

2 files changed

+26
-22
lines changed

2 files changed

+26
-22
lines changed

torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h

Lines changed: 10 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -17,19 +17,18 @@ torchao::ops::PackedWeightsHeader get_packed_weights_header_universal(
1717
int nr,
1818
int kr,
1919
int version = 1) {
20-
TORCHAO_CHECK(
21-
version >= 0 && version < 256, "version must be between 0 and 255");
22-
TORCHAO_CHECK(
23-
weight_nbit >= 1 && weight_nbit < 256,
24-
"weight_nbit must be between 1 and 255");
2520
return torchao::ops::PackedWeightsHeader(
2621
torchao::ops::PackedWeightsFormat::linear_8bit_act_xbit_weight_universal,
27-
{((static_cast<unsigned short>(version) << 8) |
28-
static_cast<unsigned short>(weight_nbit)),
29-
((static_cast<unsigned short>(has_weight_zeros) << 8) |
30-
static_cast<unsigned short>(has_bias)),
31-
static_cast<unsigned short>(nr),
32-
static_cast<unsigned short>(kr),
22+
{version,
23+
weight_nbit,
24+
has_weight_zeros,
25+
has_bias,
26+
nr,
27+
kr,
28+
0,
29+
0,
30+
0,
31+
0,
3332
0,
3433
0,
3534
0,

torchao/experimental/ops/packed_weights_header.h

Lines changed: 16 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -7,47 +7,52 @@
77
#pragma once
88
#include <array>
99

10+
#include <stdint.h>
1011
#include <cassert>
12+
1113
namespace torchao::ops {
1214

13-
enum PackedWeightsFormat : unsigned short {
15+
enum class PackedWeightsFormat : uint32_t {
1416
unknown = 0,
1517
linear_8bit_act_xbit_weight_universal = 1
1618
};
1719

1820
class PackedWeightsHeader {
1921
public:
20-
using params_type = std::array<unsigned short, 7>;
22+
using params_type = std::array<int, 14>;
23+
const static int magic = 6712;
2124
PackedWeightsFormat format;
2225

2326
// 14 ints (56 bytes) of format-specific params
2427
params_type params;
2528

2629
PackedWeightsHeader(
2730
PackedWeightsFormat format = PackedWeightsFormat::unknown,
28-
params_type params = {0, 0, 0, 0, 0, 0, 0})
31+
params_type params = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
2932
: format{format}, params{params} {}
3033

3134
inline static constexpr int size() {
32-
static_assert(sizeof(format) + sizeof(params) == 16);
33-
return 16;
35+
static_assert(sizeof(magic) + sizeof(format) + sizeof(params) == 64);
36+
return 64;
3437
}
3538

3639
inline void write(void* packed_weights) const {
37-
auto header = (unsigned short*)(packed_weights);
38-
header[0] = (unsigned short)format;
40+
auto header = (int*)(packed_weights);
41+
header[0] = magic;
42+
header[1] = (int)format;
3943
for (int i = 0; i < params.size(); i++) {
40-
header[i + 1] = params[i];
44+
header[i + 2] = params[i];
4145
}
4246
}
4347

4448
static PackedWeightsHeader read(const void* packed_weights) {
45-
auto header = (unsigned short*)(packed_weights);
49+
auto header = reinterpret_cast<const int*>(packed_weights);
50+
assert(header[0] == PackedWeightsHeader::magic);
4651
params_type params;
4752
for (int i = 0; i < params.size(); i++) {
48-
params[i] = header[i + 1];
53+
params[i] = header[i + 2];
4954
}
50-
return PackedWeightsHeader((PackedWeightsFormat)header[0], params);
55+
return PackedWeightsHeader((PackedWeightsFormat)header[1], params);
5156
}
5257

5358
bool operator==(const PackedWeightsHeader& other) const {

0 commit comments

Comments
 (0)