Skip to content

[core] Change u2 values packing in byte #31181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 103 additions & 98 deletions src/core/dev_api/openvino/core/type/element_iterator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ constexpr bool is_byte_type(Type_t et) {
return !is_bit_type(et) && !is_split_bit_type(et) && !is_nibble_type(et) && et != string;
}

/**
* @brief Checks if type is packed in byte from LSB -> MSB order.
*
* Every element type other than u1 is reported as LSB-first packed. The function does not verify that the
* type is actually a sub-byte type, so it is only meaningful for types stored with several values per byte.
*
* @param et Element type to check.
* @return True if type is packed LSB first, false otherwise (only for u1).
*/
constexpr bool is_lsb_packed(Type_t et) {
return et != u1;
}

/**
* @brief Gets bit width of ov::element::Type_t.
*
Expand Down Expand Up @@ -139,16 +149,17 @@ class BitProxy {};
* @tparam ET OpenVINO element type.
*/
template <class T, Type_t ET>
class BitProxy<T, ET, typename std::enable_if<is_bit_type(ET) || is_nibble_type(ET)>::type> {
class BitProxy<T, ET, std::enable_if_t<is_bit_type(ET) || is_nibble_type(ET)>> {
private:
template <Type_t, class>
friend class Iterator; //!< Iterator class is friend to access private members to manipulate pointer.

using Bits = typename std::conditional<std::is_const<T>::value, const uint8_t, uint8_t>::type;
using Bits = std::conditional_t<std::is_const_v<T>, const uint8_t, uint8_t>;

static constexpr size_t m_bits = bit_width<ET>(); //!< Number of bit for single value.
static constexpr size_t m_num_values = 8 / m_bits; //!< Number values in byte.
static constexpr size_t m_shift_init = is_nibble_type(ET) ? 0 : 8 - m_bits; //!< Initial value for bit shift.
static constexpr size_t m_bits = bit_width<ET>(); //!< Number of bit for single value.
static constexpr size_t m_num_values = 8 / m_bits; //!< Number values in byte.
static constexpr size_t m_shift_init = is_lsb_packed(ET) ? 0 : 8 - m_bits; //!< Initial value for bit shift.
static constexpr size_t m_shift_last = is_lsb_packed(ET) ? 8 - m_bits : 0; //!< Last value for bit shift.

Bits* m_ptr; //!< Pointer to T as Bits used to get value from bits.
size_t m_bit_shift; //!< Current bit shift to get value.
Expand All @@ -167,7 +178,7 @@ class BitProxy<T, ET, typename std::enable_if<is_bit_type(ET) || is_nibble_type(
}

public:
using value_type = typename std::decay<T>::type; //!< Fundamental type of bound to BitProxy.
using value_type = std::decay_t<T>; //!< Fundamental type of bound to BitProxy.

/**
* @brief Compare proxy value with other provided value.
Expand Down Expand Up @@ -196,7 +207,7 @@ class BitProxy<T, ET, typename std::enable_if<is_bit_type(ET) || is_nibble_type(
*
* @return Value of BitProxy.
*/
template <Type_t ETT = ET, typename std::enable_if<ETT != i4 && ETT != f4e2m1>::type* = nullptr>
template <Type_t ETT = ET, std::enable_if_t<ETT != i4 && ETT != f4e2m1>* = nullptr>
operator value_type() const {
return static_cast<value_type>(get_bit_value());
}
Expand All @@ -206,7 +217,7 @@ class BitProxy<T, ET, typename std::enable_if<is_bit_type(ET) || is_nibble_type(
*
* @return Value of BitProxy.
*/
template <Type_t ETT = ET, typename std::enable_if<ETT == f4e2m1>::type* = nullptr>
template <Type_t ETT = ET, std::enable_if_t<ETT == f4e2m1>* = nullptr>
operator value_type() const {
return value_type::from_bits(get_bit_value());
}
Expand All @@ -219,12 +230,12 @@ class BitProxy<T, ET, typename std::enable_if<is_bit_type(ET) || is_nibble_type(
*
* @return Converted NF4 value to float.
*/
template <Type_t ETT = ET, typename std::enable_if<ETT == nf4>::type* = nullptr>
template <Type_t ETT = ET, std::enable_if_t<ETT == nf4>* = nullptr>
operator float() const {
return ConvertNF4::dequantize(get_bit_value());
}

template <Type_t ETT = ET, typename std::enable_if<ETT == f4e2m1>::type* = nullptr>
template <Type_t ETT = ET, std::enable_if_t<ETT == f4e2m1>* = nullptr>
operator float() const {
return static_cast<float>(float4_e2m1::from_bits(get_bit_value()));
}
Expand All @@ -234,7 +245,7 @@ class BitProxy<T, ET, typename std::enable_if<is_bit_type(ET) || is_nibble_type(
*
* @return Value of BitProxy.
*/
template <Type_t ETT = ET, typename std::enable_if<ETT == i4>::type* = nullptr>
template <Type_t ETT = ET, std::enable_if_t<ETT == i4>* = nullptr>
operator value_type() const {
constexpr auto value_mask = util::make_n_bit_mask(m_bits);
constexpr auto value_msb_mask = (1U << (m_bits - 1U));
Expand All @@ -252,7 +263,7 @@ class BitProxy<T, ET, typename std::enable_if<is_bit_type(ET) || is_nibble_type(
* @brief Sets current ProxyBit to value.
* @param v Value to be set.
*/
template <Type_t ETT = ET, typename std::enable_if<ETT != f4e2m1>::type* = nullptr>
template <Type_t ETT = ET, std::enable_if_t<ETT != f4e2m1>* = nullptr>
BitProxy<T, ET>& operator=(const value_type v) {
constexpr auto value_mask = util::make_n_bit_mask(m_bits);
set_bit_value(static_cast<uint8_t>(v) & value_mask);
Expand All @@ -263,7 +274,7 @@ class BitProxy<T, ET, typename std::enable_if<is_bit_type(ET) || is_nibble_type(
* @brief Sets current ProxyBit to value (f4e2m1 specialization).
* @param v Value to be set.
*/
template <Type_t ETT = ET, typename std::enable_if<ETT == f4e2m1>::type* = nullptr>
template <Type_t ETT = ET, std::enable_if_t<ETT == f4e2m1>* = nullptr>
BitProxy<T, ET>& operator=(const value_type v) {
set_bit_value(v.to_bits());
return *this;
Expand All @@ -273,7 +284,7 @@ class BitProxy<T, ET, typename std::enable_if<is_bit_type(ET) || is_nibble_type(
* @brief Sets current NF4 value from float using quantization.
* @param v Value to be set.
*/
template <Type_t ETT = ET, typename std::enable_if<ETT == nf4>::type* = nullptr>
template <Type_t ETT = ET, std::enable_if_t<ETT == nf4>* = nullptr>
BitProxy<T, ET>& operator=(const float v) {
set_bit_value(ConvertNF4::quantize(v));
return *this;
Expand All @@ -289,7 +300,7 @@ class BitProxy<T, ET, typename std::enable_if<is_bit_type(ET) || is_nibble_type(
* @tparam ET OpenVINO element type.
*/
template <class T, Type_t ET>
class BitProxy<T, ET, typename std::enable_if<is_split_bit_type(ET)>::type> {
class BitProxy<T, ET, std::enable_if_t<is_split_bit_type(ET)>> {
private:
template <Type_t, class>
friend class Iterator; //!< Iterator class is friend to access private members to manipulate pointer.
Expand All @@ -314,7 +325,7 @@ class BitProxy<T, ET, typename std::enable_if<is_split_bit_type(ET)>::type> {
constexpr BitProxy(T* ptr) noexcept : m_ptr{ptr}, m_bit_shift{m_shift_init} {}

public:
using value_type = typename std::decay<T>::type; //!< Fundamental type of sub-byte.
using value_type = std::decay_t<T>; //!< Fundamental type of sub-byte.

/**
* @brief Compare proxy value is equal than rhs.
Expand Down Expand Up @@ -411,41 +422,38 @@ class Iterator {
using iterator_category = std::bidirectional_iterator_tag;
using difference_type = std::ptrdiff_t;
using value_type = T;
using reference = typename std::conditional<std::is_const<T>::value, const proxy_type&, proxy_type&>::type;
using pointer = typename std::conditional<std::is_const<T>::value, const proxy_type*, proxy_type*>::type;
using reference = std::conditional_t<std::is_const<T>::value, const proxy_type&, proxy_type&>;
using pointer = std::conditional_t<std::is_const<T>::value, const proxy_type*, proxy_type*>;

static_assert(std::is_same<typename std::decay<T>::type, ov::fundamental_type_for<ET>>::value,
static_assert(std::is_same_v<std::decay_t<T>, ov::fundamental_type_for<ET>>,
"Iterator value_type must be same as fundamental type of ET");

constexpr Iterator(T* ptr) noexcept : m_et_ptr{ptr} {}

template <class U>
constexpr Iterator(U* ptr) noexcept : m_et_ptr{reinterpret_cast<T*>(ptr)} {
static_assert(std::is_same<typename std::decay<U>::type, int8_t>::value,
static_assert(std::is_same_v<typename std::decay_t<U>, int8_t>,
"User type must be int8_t as base type for LP ET");
}

// Iteration operators
template <Type_t ETT = ET>
typename std::enable_if<is_bit_type(ETT), Iterator<ET, T>>::type& operator++() {
m_et_ptr.m_bit_shift -= m_et_ptr.m_bits;
m_et_ptr.m_bit_shift = m_et_ptr.m_bit_shift % (m_et_ptr.m_num_values * m_et_ptr.m_bits);
m_et_ptr.m_ptr += static_cast<std::ptrdiff_t>(m_et_ptr.m_bit_shift == m_et_ptr.m_shift_init);
return *this;
}

template <Type_t ETT = ET>
typename std::enable_if<is_nibble_type(ETT), Iterator<ET, T>>::type& operator++() {
m_et_ptr.m_bit_shift ^= m_et_ptr.m_bits;
m_et_ptr.m_ptr += static_cast<std::ptrdiff_t>(m_et_ptr.m_bit_shift == m_et_ptr.m_shift_init);
return *this;
}

template <Type_t ETT = ET>
typename std::enable_if<is_split_bit_type(ETT), Iterator<ET, T>>::type& operator++() {
--m_et_ptr.m_bit_shift;
m_et_ptr.m_bit_shift = m_et_ptr.m_bit_shift % m_et_ptr.m_num_values;
m_et_ptr.m_ptr += (m_et_ptr.m_bit_shift == m_et_ptr.m_shift_init) ? 3 : 0;
/**
* @brief Pre-increment, moves the iterator to the next element.
*
* The stepping rule is selected at compile time from ET: nibble types, split-bit types, or the remaining
* bit types (which are further split by is_lsb_packed).
*
* @return Reference to this iterator after the move.
*/
Iterator<ET, T>& operator++() {
if constexpr (is_nibble_type(ET)) {
// XOR with m_bits flips the shift between 0 and m_bits.
m_et_ptr.m_bit_shift ^= m_et_ptr.m_bits;
// Step to the next byte only when the shift is back at its initial position.
m_et_ptr.m_ptr += static_cast<std::ptrdiff_t>(m_et_ptr.m_bit_shift == m_et_ptr.m_shift_init);
} else if constexpr (is_split_bit_type(ET)) {
// NOTE(review): decrementing from 0 presumably relies on unsigned wrap-around before the modulo below;
// confirm against the split-bit BitProxy member types.
--m_et_ptr.m_bit_shift;
m_et_ptr.m_bit_shift = m_et_ptr.m_bit_shift % m_et_ptr.m_num_values;
// Pointer moves by 3 once the shift returns to its initial value
// (presumably the size in bytes of one split-bit storage unit - TODO confirm).
m_et_ptr.m_ptr += (m_et_ptr.m_bit_shift == m_et_ptr.m_shift_init) ? 3 : 0;
} else {
// LSB-first types walk the shift upwards, other types walk it downwards.
if constexpr (is_lsb_packed(ET)) {
m_et_ptr.m_bit_shift += m_et_ptr.m_bits;
} else {
m_et_ptr.m_bit_shift -= m_et_ptr.m_bits;
}
// Modulo folds the shift back into the byte; for the downward walk this depends on the unsigned
// wrap-around of m_bit_shift when it goes below zero.
m_et_ptr.m_bit_shift = m_et_ptr.m_bit_shift % (m_et_ptr.m_num_values * m_et_ptr.m_bits);
// A shift equal to the initial one means the previous byte is exhausted, so move to the next byte.
m_et_ptr.m_ptr += static_cast<std::ptrdiff_t>(m_et_ptr.m_bit_shift == m_et_ptr.m_shift_init);
}
return *this;
}

Expand All @@ -455,25 +463,25 @@ class Iterator {
return old;
}

template <Type_t ETT = ET>
typename std::enable_if<is_bit_type(ETT), Iterator<ET, T>>::type& operator+=(const difference_type& n) {
const auto advance = n + (m_et_ptr.m_shift_init - m_et_ptr.m_bit_shift) / m_et_ptr.m_bits;
m_et_ptr.m_bit_shift = m_et_ptr.m_shift_init - (advance % m_et_ptr.m_num_values) * m_et_ptr.m_bits;
m_et_ptr.m_ptr += advance / m_et_ptr.m_num_values;
return *this;
}

template <Type_t ETT = ET>
typename std::enable_if<is_nibble_type(ETT), Iterator<ET, T>>::type& operator+=(const difference_type& n) {
m_et_ptr.m_ptr += n / m_et_ptr.m_num_values;
return (n % m_et_ptr.m_num_values) ? ++*this : *this;
}

template <Type_t ETT = ET>
typename std::enable_if<is_split_bit_type(ETT), Iterator<ET, T>>::type& operator+=(const difference_type& n) {
const auto advance = n + m_et_ptr.m_shift_init - m_et_ptr.m_bit_shift;
m_et_ptr.m_bit_shift = m_et_ptr.m_shift_init - (advance % m_et_ptr.m_num_values);
m_et_ptr.m_ptr += 3 * (advance / m_et_ptr.m_num_values);
/**
* @brief Advances the iterator by n elements.
*
* The stepping rule is selected at compile time from ET: nibble types, split-bit types, LSB-first bit types
* and finally the remaining (not LSB-first) bit types.
*
* NOTE(review): n is a signed difference_type mixed with unsigned proxy members in the arithmetic below;
* presumably callers pass n >= 0 - confirm before relying on negative values.
*
* @param n Number of elements to move forward.
* @return Reference to this iterator after the move.
*/
Iterator<ET, T>& operator+=(const difference_type& n) {
if constexpr (is_nibble_type(ET)) {
// Move by whole bytes first, then do a single element step if a remainder is left.
m_et_ptr.m_ptr += n / m_et_ptr.m_num_values;
if (n % m_et_ptr.m_num_values) {
++*this;
}
} else if constexpr (is_split_bit_type(ET)) {
// advance = n plus the distance of the current shift from the initial shift.
const auto advance = n + m_et_ptr.m_shift_init - m_et_ptr.m_bit_shift;
m_et_ptr.m_bit_shift = m_et_ptr.m_shift_init - (advance % m_et_ptr.m_num_values);
// Pointer moves in steps of 3 per m_num_values elements
// (presumably the size in bytes of one split-bit storage unit - TODO confirm).
m_et_ptr.m_ptr += 3 * (advance / m_et_ptr.m_num_values);
} else if constexpr (is_lsb_packed(ET)) {
// advance = n plus the index of the current element in its byte (shift grows from 0).
const auto advance = n + m_et_ptr.m_bit_shift / m_et_ptr.m_bits;
m_et_ptr.m_bit_shift = (advance % m_et_ptr.m_num_values) * m_et_ptr.m_bits;
m_et_ptr.m_ptr += advance / m_et_ptr.m_num_values;
} else {
// advance = n plus the index of the current element in its byte (shift falls from m_shift_init).
const auto advance = n + (m_et_ptr.m_shift_init - m_et_ptr.m_bit_shift) / m_et_ptr.m_bits;
m_et_ptr.m_bit_shift = m_et_ptr.m_shift_init - (advance % m_et_ptr.m_num_values) * m_et_ptr.m_bits;
m_et_ptr.m_ptr += advance / m_et_ptr.m_num_values;
}
return *this;
}

Expand All @@ -483,26 +491,23 @@ class Iterator {
return tmp;
}

template <Type_t ETT = ET>
typename std::enable_if<is_bit_type(ETT), Iterator<ET, T>>::type& operator--() {
m_et_ptr.m_bit_shift += m_et_ptr.m_bits;
m_et_ptr.m_bit_shift = m_et_ptr.m_bit_shift % (m_et_ptr.m_num_values * m_et_ptr.m_bits);
m_et_ptr.m_ptr -= static_cast<std::ptrdiff_t>(m_et_ptr.m_bit_shift == 0);
return *this;
}

template <Type_t ETT = ET>
typename std::enable_if<is_nibble_type(ETT), Iterator<ET, T>>::type& operator--() {
m_et_ptr.m_bit_shift ^= m_et_ptr.m_bits;
m_et_ptr.m_ptr -= static_cast<std::ptrdiff_t>(m_et_ptr.m_bit_shift == 4);
return *this;
}

template <Type_t ETT = ET>
typename std::enable_if<is_split_bit_type(ETT), Iterator<ET, T>>::type& operator--() {
++m_et_ptr.m_bit_shift;
m_et_ptr.m_bit_shift = m_et_ptr.m_bit_shift % m_et_ptr.m_num_values;
m_et_ptr.m_ptr -= m_et_ptr.m_bit_shift == 0 ? 3 : 0;
/**
* @brief Pre-decrement, moves the iterator to the previous element.
*
* The stepping rule is selected at compile time from ET: nibble types, split-bit types, or the remaining
* bit types (which are further split by is_lsb_packed).
*
* @return Reference to this iterator after the move.
*/
Iterator<ET, T>& operator--() {
if constexpr (is_nibble_type(ET)) {
// XOR with m_bits flips the shift between 0 and m_bits.
m_et_ptr.m_bit_shift ^= m_et_ptr.m_bits;
// Step back one byte when the shift lands on the literal 4
// (presumably the upper nibble, i.e. m_bits == 4 for every nibble type - TODO confirm).
m_et_ptr.m_ptr -= static_cast<std::ptrdiff_t>(m_et_ptr.m_bit_shift == 4);
} else if constexpr (is_split_bit_type(ET)) {
++m_et_ptr.m_bit_shift;
m_et_ptr.m_bit_shift = m_et_ptr.m_bit_shift % m_et_ptr.m_num_values;
// Pointer moves back by 3 once the shift wraps to 0
// (presumably the size in bytes of one split-bit storage unit - TODO confirm).
m_et_ptr.m_ptr -= m_et_ptr.m_bit_shift == 0 ? 3 : 0;
} else {
// Mirror of operator++: LSB-first types walk the shift downwards, other types walk it upwards.
if constexpr (is_lsb_packed(ET)) {
m_et_ptr.m_bit_shift -= m_et_ptr.m_bits;
} else {
m_et_ptr.m_bit_shift += m_et_ptr.m_bits;
}
// Modulo folds the shift back into the byte; for the downward walk this depends on the unsigned
// wrap-around of m_bit_shift when it goes below zero.
m_et_ptr.m_bit_shift = m_et_ptr.m_bit_shift % (m_et_ptr.m_num_values * m_et_ptr.m_bits);
// A shift equal to the last position means the iterator entered the previous byte.
m_et_ptr.m_ptr -= static_cast<std::ptrdiff_t>(m_et_ptr.m_bit_shift == m_et_ptr.m_shift_last);
}
return *this;
}

Expand All @@ -512,25 +517,25 @@ class Iterator {
return old;
}

template <Type_t ETT = ET>
typename std::enable_if<is_bit_type(ETT), Iterator<ET, T>>::type& operator-=(const difference_type& n) {
const auto advance = m_et_ptr.m_bit_shift / m_et_ptr.m_bits + n;
m_et_ptr.m_bit_shift = (advance % m_et_ptr.m_num_values) * m_et_ptr.m_bits;
m_et_ptr.m_ptr -= advance / m_et_ptr.m_num_values;
return *this;
}

template <Type_t ETT = ET>
typename std::enable_if<is_nibble_type(ETT), Iterator<ET, T>>::type& operator-=(const difference_type& n) {
m_et_ptr.m_ptr -= n / m_et_ptr.m_num_values;
return (n % m_et_ptr.m_num_values) ? --*this : *this;
}

template <Type_t ETT = ET>
typename std::enable_if<is_split_bit_type(ETT), Iterator<ET, T>>::type& operator-=(const difference_type& n) {
const auto advance = m_et_ptr.m_bit_shift + n;
m_et_ptr.m_bit_shift = advance % m_et_ptr.m_num_values;
m_et_ptr.m_ptr -= 3 * (advance / m_et_ptr.m_num_values);
/**
* @brief Moves the iterator back by n elements.
*
* The stepping rule is selected at compile time from ET: nibble types, split-bit types, LSB-first bit types
* and finally the remaining (not LSB-first) bit types.
*
* NOTE(review): n is a signed difference_type mixed with unsigned proxy members in the arithmetic below;
* presumably callers pass n >= 0 - confirm before relying on negative values.
*
* @param n Number of elements to move backward.
* @return Reference to this iterator after the move.
*/
Iterator<ET, T>& operator-=(const difference_type& n) {
if constexpr (is_nibble_type(ET)) {
// Move by whole bytes first, then do a single element step back if a remainder is left.
m_et_ptr.m_ptr -= n / m_et_ptr.m_num_values;
if (n % m_et_ptr.m_num_values) {
--*this;
}
} else if constexpr (is_split_bit_type(ET)) {
const auto advance = m_et_ptr.m_bit_shift + n;
m_et_ptr.m_bit_shift = advance % m_et_ptr.m_num_values;
// Pointer moves in steps of 3 per m_num_values elements
// (presumably the size in bytes of one split-bit storage unit - TODO confirm).
m_et_ptr.m_ptr -= 3 * (advance / m_et_ptr.m_num_values);
} else if constexpr (is_lsb_packed(ET)) {
// advance = n plus the distance of the current element from the last slot of its byte.
const auto advance = n + (m_et_ptr.m_shift_last - m_et_ptr.m_bit_shift) / m_et_ptr.m_bits;
m_et_ptr.m_bit_shift = m_et_ptr.m_shift_last - (advance % m_et_ptr.m_num_values) * m_et_ptr.m_bits;
m_et_ptr.m_ptr -= advance / m_et_ptr.m_num_values;
} else {
// advance = n plus the current shift expressed in elements (shift grows when moving backward).
const auto advance = m_et_ptr.m_bit_shift / m_et_ptr.m_bits + n;
m_et_ptr.m_bit_shift = (advance % m_et_ptr.m_num_values) * m_et_ptr.m_bits;
m_et_ptr.m_ptr -= advance / m_et_ptr.m_num_values;
}
return *this;
}

Expand Down
4 changes: 3 additions & 1 deletion src/core/src/op/constant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,9 @@ void Constant::set_unused_bits(void* buffer) const {
if (element::is_bit_type(m_element_type)) {
constexpr size_t storage_unit_byte_size = 1;
const auto not_aligned_elements = num_elements % (8 / m_element_type.bitwidth());
const uint8_t not_used_bits_mask = 0xff >> (m_element_type.bitwidth() * not_aligned_elements);
const uint8_t not_used_bits_mask = element::is_lsb_packed(m_element_type)
? 0xff << (m_element_type.bitwidth() * not_aligned_elements)
: 0xff >> (m_element_type.bitwidth() * not_aligned_elements);
reinterpret_cast<uint8_t*>(buffer)[byte_size - storage_unit_byte_size] &= ~not_used_bits_mask;
} else if (element::is_nibble_type(m_element_type) && (num_elements % 2)) {
constexpr size_t storage_unit_byte_size = 1;
Expand Down
Loading
Loading