Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ for specific CPU features which avoids the runtime overhead and works in a `no_s
| ------------ | ------------------ | ----- |
| `x86`/`x86_64` | `f16c` | This supports conversion to/from `f16` only (including vector SIMD) and does not support any `bf16` or arithmetic operations. |
| `aarch64` | `fp16` | This supports all operations on `f16` only. |
| `loongarch64` | `lsx` | This supports conversion to/from `f16` only (including vector SIMD) and does not support any `bf16` or arithmetic operations. |

### More Documentation

Expand Down
74 changes: 73 additions & 1 deletion src/binary16/arch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@ mod x86;
#[cfg(target_arch = "aarch64")]
mod aarch64;

#[cfg(target_arch = "loongarch64")]
mod loongarch64;

macro_rules! convert_fn {
(if x86_feature("f16c") { $f16c:expr }
else if aarch64_feature("fp16") { $aarch64:expr }
else if loongarch64_feature("lsx") { $loongarch64:expr }
else { $fallback:expr }) => {
cfg_if::cfg_if! {
// Use intrinsics directly when a compile target or using no_std
Expand All @@ -25,7 +29,12 @@ macro_rules! convert_fn {
target_feature = "fp16"
))] {
$aarch64

}
else if #[cfg(all(
target_arch = "loongarch64",
target_feature = "lsx"
))] {
$loongarch64
}

// Use CPU feature detection if using std
Expand All @@ -51,6 +60,17 @@ macro_rules! convert_fn {
$fallback
}
}
else if #[cfg(all(
feature = "std",
target_arch = "loongarch64",
))] {
use std::arch::is_loongarch_feature_detected;
if is_loongarch_feature_detected!("lsx") {
$loongarch64
} else {
$fallback
}
}

// Fallback to software
else {
Expand All @@ -67,6 +87,8 @@ pub(crate) fn f32_to_f16(f: f32) -> u16 {
unsafe { x86::f32_to_f16_x86_f16c(f) }
} else if aarch64_feature("fp16") {
unsafe { aarch64::f32_to_f16_fp16(f) }
} else if loongarch64_feature("lsx") {
unsafe { loongarch64::f32_to_f16_lsx(f) }
} else {
f32_to_f16_fallback(f)
}
Expand All @@ -80,6 +102,8 @@ pub(crate) fn f64_to_f16(f: f64) -> u16 {
unsafe { x86::f32_to_f16_x86_f16c(f as f32) }
} else if aarch64_feature("fp16") {
unsafe { aarch64::f64_to_f16_fp16(f) }
} else if loongarch64_feature("lsx") {
f64_to_f16_fallback(f)
} else {
f64_to_f16_fallback(f)
}
Expand All @@ -93,6 +117,8 @@ pub(crate) fn f16_to_f32(i: u16) -> f32 {
unsafe { x86::f16_to_f32_x86_f16c(i) }
} else if aarch64_feature("fp16") {
unsafe { aarch64::f16_to_f32_fp16(i) }
} else if loongarch64_feature("lsx") {
unsafe { loongarch64::f16_to_f32_lsx(i) }
} else {
f16_to_f32_fallback(i)
}
Expand All @@ -106,6 +132,8 @@ pub(crate) fn f16_to_f64(i: u16) -> f64 {
unsafe { x86::f16_to_f32_x86_f16c(i) as f64 }
} else if aarch64_feature("fp16") {
unsafe { aarch64::f16_to_f64_fp16(i) }
} else if loongarch64_feature("lsx") {
unsafe { loongarch64::f16_to_f32_lsx(i) as f64 }
} else {
f16_to_f64_fallback(i)
}
Expand All @@ -119,6 +147,8 @@ pub(crate) fn f32x4_to_f16x4(f: &[f32; 4]) -> [u16; 4] {
unsafe { x86::f32x4_to_f16x4_x86_f16c(f) }
} else if aarch64_feature("fp16") {
unsafe { aarch64::f32x4_to_f16x4_fp16(f) }
} else if loongarch64_feature("lsx") {
unsafe { loongarch64::f32x4_to_f16x4_lsx(f) }
} else {
f32x4_to_f16x4_fallback(f)
}
Expand All @@ -132,6 +162,8 @@ pub(crate) fn f16x4_to_f32x4(i: &[u16; 4]) -> [f32; 4] {
unsafe { x86::f16x4_to_f32x4_x86_f16c(i) }
} else if aarch64_feature("fp16") {
unsafe { aarch64::f16x4_to_f32x4_fp16(i) }
} else if loongarch64_feature("lsx") {
unsafe { loongarch64::f16x4_to_f32x4_lsx(i) }
} else {
f16x4_to_f32x4_fallback(i)
}
Expand All @@ -145,6 +177,8 @@ pub(crate) fn f64x4_to_f16x4(f: &[f64; 4]) -> [u16; 4] {
unsafe { x86::f64x4_to_f16x4_x86_f16c(f) }
} else if aarch64_feature("fp16") {
unsafe { aarch64::f64x4_to_f16x4_fp16(f) }
} else if loongarch64_feature("lsx") {
unsafe { loongarch64::f64x4_to_f16x4_lsx(f) }
} else {
f64x4_to_f16x4_fallback(f)
}
Expand All @@ -158,6 +192,8 @@ pub(crate) fn f16x4_to_f64x4(i: &[u16; 4]) -> [f64; 4] {
unsafe { x86::f16x4_to_f64x4_x86_f16c(i) }
} else if aarch64_feature("fp16") {
unsafe { aarch64::f16x4_to_f64x4_fp16(i) }
} else if loongarch64_feature("lsx") {
unsafe { loongarch64::f16x4_to_f64x4_lsx(i) }
} else {
f16x4_to_f64x4_fallback(i)
}
Expand All @@ -176,6 +212,13 @@ pub(crate) fn f32x8_to_f16x8(f: &[f32; 8]) -> [u16; 8] {
aarch64::f32x4_to_f16x4_fp16);
result
}
} else if loongarch64_feature("lsx") {
{
let mut result = [0u16; 8];
convert_chunked_slice_4(f.as_slice(), result.as_mut_slice(),
loongarch64::f32x4_to_f16x4_lsx);
result
}
} else {
f32x8_to_f16x8_fallback(f)
}
Expand All @@ -194,6 +237,13 @@ pub(crate) fn f16x8_to_f32x8(i: &[u16; 8]) -> [f32; 8] {
aarch64::f16x4_to_f32x4_fp16);
result
}
} else if loongarch64_feature("lsx") {
{
let mut result = [0f32; 8];
convert_chunked_slice_4(i.as_slice(), result.as_mut_slice(),
loongarch64::f16x4_to_f32x4_lsx);
result
}
} else {
f16x8_to_f32x8_fallback(i)
}
Expand All @@ -212,6 +262,13 @@ pub(crate) fn f64x8_to_f16x8(f: &[f64; 8]) -> [u16; 8] {
aarch64::f64x4_to_f16x4_fp16);
result
}
} else if loongarch64_feature("lsx") {
{
let mut result = [0u16; 8];
convert_chunked_slice_4(f.as_slice(), result.as_mut_slice(),
loongarch64::f64x4_to_f16x4_lsx);
result
}
} else {
f64x8_to_f16x8_fallback(f)
}
Expand All @@ -230,6 +287,13 @@ pub(crate) fn f16x8_to_f64x8(i: &[u16; 8]) -> [f64; 8] {
aarch64::f16x4_to_f64x4_fp16);
result
}
} else if loongarch64_feature("lsx") {
{
let mut result = [0f64; 8];
convert_chunked_slice_4(i.as_slice(), result.as_mut_slice(),
loongarch64::f16x4_to_f64x4_lsx);
result
}
} else {
f16x8_to_f64x8_fallback(i)
}
Expand All @@ -244,6 +308,8 @@ pub(crate) fn f32_to_f16_slice(src: &[f32], dst: &mut [u16]) {
x86::f32x4_to_f16x4_x86_f16c)
} else if aarch64_feature("fp16") {
convert_chunked_slice_4(src, dst, aarch64::f32x4_to_f16x4_fp16)
} else if loongarch64_feature("lsx") {
convert_chunked_slice_4(src, dst, loongarch64::f32x4_to_f16x4_lsx)
} else {
slice_fallback(src, dst, f32_to_f16_fallback)
}
Expand All @@ -258,6 +324,8 @@ pub(crate) fn f16_to_f32_slice(src: &[u16], dst: &mut [f32]) {
x86::f16x4_to_f32x4_x86_f16c)
} else if aarch64_feature("fp16") {
convert_chunked_slice_4(src, dst, aarch64::f16x4_to_f32x4_fp16)
} else if loongarch64_feature("lsx") {
convert_chunked_slice_4(src, dst, loongarch64::f16x4_to_f32x4_lsx)
} else {
slice_fallback(src, dst, f16_to_f32_fallback)
}
Expand All @@ -272,6 +340,8 @@ pub(crate) fn f64_to_f16_slice(src: &[f64], dst: &mut [u16]) {
x86::f64x4_to_f16x4_x86_f16c)
} else if aarch64_feature("fp16") {
convert_chunked_slice_4(src, dst, aarch64::f64x4_to_f16x4_fp16)
} else if loongarch64_feature("lsx") {
convert_chunked_slice_4(src, dst, loongarch64::f64x4_to_f16x4_lsx)
} else {
slice_fallback(src, dst, f64_to_f16_fallback)
}
Expand All @@ -286,6 +356,8 @@ pub(crate) fn f16_to_f64_slice(src: &[u16], dst: &mut [f64]) {
x86::f16x4_to_f64x4_x86_f16c)
} else if aarch64_feature("fp16") {
convert_chunked_slice_4(src, dst, aarch64::f16x4_to_f64x4_fp16)
} else if loongarch64_feature("lsx") {
convert_chunked_slice_4(src, dst, loongarch64::f16x4_to_f64x4_lsx)
} else {
slice_fallback(src, dst, f16_to_f64_fallback)
}
Expand Down
63 changes: 63 additions & 0 deletions src/binary16/arch/loongarch64.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use core::{mem::MaybeUninit, ptr};

#[cfg(target_arch = "loongarch64")]
use core::arch::loongarch64::{lsx_vfcvt_h_s, lsx_vfcvtl_s_h, m128, m128i};

/////////////// loongarch64 lsx/lasx ////////////////

#[target_feature(enable = "lsx")]
#[inline]
pub(super) unsafe fn f16_to_f32_lsx(i: u16) -> f32 {
let mut vec = MaybeUninit::<m128i>::zeroed();
vec.as_mut_ptr().cast::<u16>().write(i);
let retval = lsx_vfcvtl_s_h(vec.assume_init());
*(&retval as *const m128).cast()
}

#[target_feature(enable = "lsx")]
#[inline]
pub(super) unsafe fn f32_to_f16_lsx(f: f32) -> u16 {
let mut vec = MaybeUninit::<m128>::zeroed();
vec.as_mut_ptr().cast::<f32>().write(f);
let retval = lsx_vfcvt_h_s(vec.assume_init(), vec.assume_init());
*(&retval as *const m128i).cast()
}

#[target_feature(enable = "lsx")]
#[inline]
pub(super) unsafe fn f16x4_to_f32x4_lsx(v: &[u16; 4]) -> [f32; 4] {
let mut vec = MaybeUninit::<m128i>::zeroed();
ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
let retval = lsx_vfcvtl_s_h(vec.assume_init());
*(&retval as *const m128).cast()
}

#[target_feature(enable = "lsx")]
#[inline]
pub(super) unsafe fn f32x4_to_f16x4_lsx(v: &[f32; 4]) -> [u16; 4] {
let mut vec = MaybeUninit::<m128>::uninit();
ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
let retval = lsx_vfcvt_h_s(vec.assume_init(), vec.assume_init());
*(&retval as *const m128i).cast()
}

#[target_feature(enable = "lsx")]
#[inline]
pub(super) unsafe fn f16x4_to_f64x4_lsx(v: &[u16; 4]) -> [f64; 4] {
let array = f16x4_to_f32x4_lsx(v);
// Let compiler vectorize this regular cast for now.
[
array[0] as f64,
array[1] as f64,
array[2] as f64,
array[3] as f64,
]
}

#[target_feature(enable = "lsx")]
#[inline]
pub(super) unsafe fn f64x4_to_f16x4_lsx(v: &[f64; 4]) -> [u16; 4] {
// Let compiler vectorize this regular cast for now.
let v = [v[0] as f32, v[1] as f32, v[2] as f32, v[3] as f32];
f32x4_to_f16x4_lsx(&v)
}
9 changes: 9 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
//! | ------------ | ------------------ | ----- |
//! | `x86`/`x86_64` | `f16c` | This supports conversion to/from [`struct@f16`] only (including vector SIMD) and does not support any [`struct@bf16`] or arithmetic operations. |
//! | `aarch64` | `fp16` | This supports all operations on [`struct@f16`] only. |
//! | `loongarch64` | `lsx` | This supports conversion to/from [`struct@f16`] only (including vector SIMD) and does not support any [`struct@bf16`] or arithmetic operations. |
//!
//! # Cargo Features
//!
Expand Down Expand Up @@ -214,6 +215,14 @@
future_incompatible
)]
#![cfg_attr(not(target_arch = "spirv"), warn(missing_debug_implementations))]
#![cfg_attr(
target_arch = "loongarch64",
feature(
stdarch_loongarch,
stdarch_loongarch_feature_detection,
loongarch_target_feature
)
)]
#![allow(clippy::verbose_bit_mask, clippy::cast_lossless, unexpected_cfgs)]
#![cfg_attr(not(feature = "std"), no_std)]
#![doc(html_root_url = "https://docs.rs/half/2.6.0")]
Expand Down
Loading