Skip to content

Header bug fix #1079

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,18 @@ torchao::ops::PackedWeightsHeader get_packed_weights_header_universal(
int nr,
int kr,
int version = 1) {
TORCHAO_CHECK(
version >= 0 && version < 256, "version must be between 0 and 255");
Comment on lines -20 to -21
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is wrong with this check?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before it existed to make sure version could be packed into a char (since both version/weight_nbit were packed into an unsigned short). But the header now has more space, and version is now an int, so this bound check is no longer needed.

TORCHAO_CHECK(
weight_nbit >= 1 && weight_nbit < 256,
"weight_nbit must be between 1 and 255");
return torchao::ops::PackedWeightsHeader(
torchao::ops::PackedWeightsFormat::linear_8bit_act_xbit_weight_universal,
{((static_cast<unsigned short>(version) << 8) |
static_cast<unsigned short>(weight_nbit)),
((static_cast<unsigned short>(has_weight_zeros) << 8) |
static_cast<unsigned short>(has_bias)),
static_cast<unsigned short>(nr),
static_cast<unsigned short>(kr),
{version,
weight_nbit,
has_weight_zeros,
has_bias,
nr,
kr,
0,
0,
0,
0,
0,
0,
0,
Expand Down
28 changes: 17 additions & 11 deletions torchao/experimental/ops/packed_weights_header.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,47 +7,53 @@
#pragma once
#include <array>

#include <stdint.h>
#include <cassert>

namespace torchao::ops {

enum PackedWeightsFormat : unsigned short {
enum class PackedWeightsFormat : uint32_t {
unknown = 0,
linear_8bit_act_xbit_weight_universal = 1
};

class PackedWeightsHeader {
public:
using params_type = std::array<unsigned short, 7>;
using params_type = std::array<int, 14>;
const static int magic = 6712;
PackedWeightsFormat format;

// 14 bytes of format specific params
params_type params;

PackedWeightsHeader(
PackedWeightsFormat format = PackedWeightsFormat::unknown,
params_type params = {0, 0, 0, 0, 0, 0, 0})
params_type params = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
: format{format}, params{params} {}

inline static constexpr int size() {
static_assert(sizeof(format) + sizeof(params) == 16);
return 16;
static_assert(sizeof(magic) + sizeof(format) + sizeof(params) == 64);
return 64;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use sizeof(XYZ)?

Suggested change
return 64;
return sizeof(PackedWeightsHeader);

Copy link
Contributor Author

@metascroy metascroy Oct 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought this made it clearer that the header is 64 bytes at a glance. I also checked and sizeof(PackedWeightsHeader) = 60, but the serialized size is still 64 bytes (I think because the static magic number doesn't count toward the size).

}

inline void write(void* packed_weights) const {
auto header = (unsigned short*)(packed_weights);
header[0] = (unsigned short)format;
auto header = reinterpret_cast<int*>(packed_weights);
header[0] = magic;
header[1] = static_cast<int>(format);
for (int i = 0; i < params.size(); i++) {
header[i + 1] = params[i];
header[i + 2] = params[i];
}
}

static PackedWeightsHeader read(const void* packed_weights) {
auto header = (unsigned short*)(packed_weights);
auto header = reinterpret_cast<const int*>(packed_weights);
assert(header[0] == PackedWeightsHeader::magic);
params_type params;
for (int i = 0; i < params.size(); i++) {
params[i] = header[i + 1];
params[i] = header[i + 2];
}
return PackedWeightsHeader((PackedWeightsFormat)header[0], params);
return PackedWeightsHeader(
static_cast<PackedWeightsFormat>(header[1]), params);
}

bool operator==(const PackedWeightsHeader& other) const {
Expand Down
Loading