-
Notifications
You must be signed in to change notification settings - Fork 257
Header bug fix #1079
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Header bug fix #1079
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -7,47 +7,53 @@ | |||||
#pragma once | ||||||
#include <array> | ||||||
|
||||||
#include <stdint.h> | ||||||
#include <cassert> | ||||||
|
||||||
namespace torchao::ops { | ||||||
|
||||||
enum PackedWeightsFormat : unsigned short { | ||||||
enum class PackedWeightsFormat : uint32_t { | ||||||
unknown = 0, | ||||||
linear_8bit_act_xbit_weight_universal = 1 | ||||||
}; | ||||||
|
||||||
class PackedWeightsHeader { | ||||||
public: | ||||||
using params_type = std::array<unsigned short, 7>; | ||||||
using params_type = std::array<int, 14>; | ||||||
const static int magic = 6712; | ||||||
PackedWeightsFormat format; | ||||||
|
||||||
// 14 bytes of format specific params | ||||||
params_type params; | ||||||
|
||||||
PackedWeightsHeader( | ||||||
PackedWeightsFormat format = PackedWeightsFormat::unknown, | ||||||
params_type params = {0, 0, 0, 0, 0, 0, 0}) | ||||||
params_type params = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) | ||||||
: format{format}, params{params} {} | ||||||
|
||||||
inline static constexpr int size() { | ||||||
static_assert(sizeof(format) + sizeof(params) == 16); | ||||||
return 16; | ||||||
static_assert(sizeof(magic) + sizeof(format) + sizeof(params) == 64); | ||||||
return 64; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not use sizeof(XYZ)?
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought this made it clearer that the header is 64 bytes at a glance. I also checked and sizeof(PackedWeightsHeader) = 60, but the serialized size is still 64 bytes (I think because the static magic number doesn't count toward the size). |
||||||
} | ||||||
|
||||||
inline void write(void* packed_weights) const { | ||||||
auto header = (unsigned short*)(packed_weights); | ||||||
header[0] = (unsigned short)format; | ||||||
auto header = reinterpret_cast<int*>(packed_weights); | ||||||
header[0] = magic; | ||||||
header[1] = static_cast<int>(format); | ||||||
for (int i = 0; i < params.size(); i++) { | ||||||
header[i + 1] = params[i]; | ||||||
header[i + 2] = params[i]; | ||||||
} | ||||||
} | ||||||
|
||||||
static PackedWeightsHeader read(const void* packed_weights) { | ||||||
auto header = (unsigned short*)(packed_weights); | ||||||
auto header = reinterpret_cast<const int*>(packed_weights); | ||||||
assert(header[0] == PackedWeightsHeader::magic); | ||||||
params_type params; | ||||||
for (int i = 0; i < params.size(); i++) { | ||||||
params[i] = header[i + 1]; | ||||||
params[i] = header[i + 2]; | ||||||
} | ||||||
return PackedWeightsHeader((PackedWeightsFormat)header[0], params); | ||||||
return PackedWeightsHeader( | ||||||
static_cast<PackedWeightsFormat>(header[1]), params); | ||||||
} | ||||||
|
||||||
bool operator==(const PackedWeightsHeader& other) const { | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is wrong with this check?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Before it existed to make sure version could be packed into a char (since both version/weight_nbit were packed into an unsigned short). But the header now has more space, and version is now an int, so this bound check is no longer needed.