
Add int tensor cast - WIP #3289


Draft · wants to merge 8 commits into base: main

4 changes: 4 additions & 0 deletions crates/burn-autodiff/src/ops/int_tensor.rs
@@ -313,6 +313,10 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
B::int_sort_with_indices(tensor, dim, descending)
}

fn int_cast(tensor: IntTensor<Self>, dtype: burn_tensor::IntDType) -> IntTensor<Self> {
B::int_cast(tensor, dtype)
}

fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
B::int_argsort(tensor, dim, descending)
}
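For context on how the new op is meant to be consumed, here is a minimal usage sketch. On the autodiff backend the backend-level `int_cast` simply delegates to the inner backend; the `Tensor`-level method name (`cast`) used below is an assumption, since only the backend trait implementation appears in this excerpt.

```rust
// Hypothetical usage sketch (not from this PR): re-encoding an int tensor's dtype.
// `Tensor::cast` for Int tensors is assumed here; the diff only adds backend `int_cast`.
use burn_tensor::{Int, IntDType, Tensor, backend::Backend};

fn shrink_indices<B: Backend>(indices: Tensor<B, 1, Int>) -> Tensor<B, 1, Int> {
    // Cast the integer elements to 32-bit, e.g. to halve memory before saving.
    indices.cast(IntDType::I32)
}
```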
18 changes: 17 additions & 1 deletion crates/burn-candle/src/ops/int_tensor.rs
@@ -1,6 +1,6 @@
use burn_common::future::DynFut;
use burn_tensor::{
Bool, Device, Distribution, ElementConversion, Shape, TensorData,
Bool, Device, Distribution, ElementConversion, IntDType, Shape, TensorData,
ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps},
};

@@ -369,6 +369,22 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
sign(tensor)
}

fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
let dtype = match dtype {
IntDType::I64 => candle_core::DType::I64,
IntDType::U8 => candle_core::DType::U8,
IntDType::U32 => candle_core::DType::U32,
_ => panic!("Candle does not support int dtype {dtype:?}"),
};

if tensor.tensor.dtype() == dtype {
tensor
} else {
CandleTensor::new(tensor.tensor.to_dtype(dtype).unwrap())
}
}

fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
unimplemented!("bitwise_and is not implemented for Candle IntTensor");
}
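The Candle implementation guards against a no-op: when the requested dtype already matches the tensor's, it is returned as-is rather than round-tripping through `to_dtype`. A small, self-contained sketch of the underlying candle-core calls this relies on (the tensor values and dtypes below are illustrative only):

```rust
// Sketch of the candle-core behavior the new op builds on: `to_dtype` re-encodes
// the elements, and `dtype()` lets the caller detect the no-op case up front.
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let t = Tensor::new(&[1i64, 2, 3], &Device::Cpu)?;
    assert_eq!(t.dtype(), DType::I64);

    // Converting to the same dtype would be wasted work, hence the early return in the diff.
    let cast = t.to_dtype(DType::U32)?;
    assert_eq!(cast.dtype(), DType::U32);
    Ok(())
}
```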
30 changes: 30 additions & 0 deletions crates/burn-cubecl/src/ops/int_ops.rs
@@ -357,6 +357,36 @@ where
kernel::flip::<R, I, BT>(tensor, axes)
}

fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
match (tensor.dtype, dtype) {
(DType::I8, IntDType::I8)
| (DType::I16, IntDType::I16)
| (DType::I32, IntDType::I32)
| (DType::I64, IntDType::I64)
| (DType::U32, IntDType::U32) => tensor,
(DType::I64, IntDType::I32) => kernel::cast::<R, i64, i32>(tensor),
(DType::I64, IntDType::I16) => kernel::cast::<R, i64, i16>(tensor),
(DType::I64, IntDType::I8) => kernel::cast::<R, i64, i8>(tensor),
(DType::I32, IntDType::I64) => kernel::cast::<R, i32, i64>(tensor),
(DType::I32, IntDType::I16) => kernel::cast::<R, i32, i16>(tensor),
(DType::I32, IntDType::I8) => kernel::cast::<R, i32, i8>(tensor),
(DType::I16, IntDType::I64) => kernel::cast::<R, i16, i64>(tensor),
(DType::I16, IntDType::I32) => kernel::cast::<R, i16, i32>(tensor),
(DType::I16, IntDType::I8) => kernel::cast::<R, i16, i8>(tensor),
(DType::I8, IntDType::I64) => kernel::cast::<R, i8, i64>(tensor),
(DType::I8, IntDType::I32) => kernel::cast::<R, i8, i32>(tensor),
(DType::I8, IntDType::I16) => kernel::cast::<R, i8, i16>(tensor),
(DType::I64, IntDType::U32) => kernel::cast::<R, i64, u32>(tensor),
(DType::I32, IntDType::U32) => kernel::cast::<R, i32, u32>(tensor),
(DType::I16, IntDType::U32) => kernel::cast::<R, i16, u32>(tensor),
(DType::I8, IntDType::U32) => kernel::cast::<R, i8, u32>(tensor),
(DType::U32, IntDType::I64) => kernel::cast::<R, u32, i64>(tensor),
(DType::U32, IntDType::I32) => kernel::cast::<R, u32, i32>(tensor),
(DType::U32, IntDType::I16) => kernel::cast::<R, u32, i16>(tensor),
(DType::U32, IntDType::I8) => kernel::cast::<R, u32, i8>(tensor),
// Remaining combinations (float/bool source dtypes, or unsupported targets) are
// not valid int tensor casts here.
_ => panic!("unsupported int cast from {:?} to {:?}", tensor.dtype, dtype),
}
}

fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
numeric::bitwise_and::<R, I>(lhs, rhs)
}
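The explicit table above enumerates every supported (source, target) pair and launches a typed cast kernel for each. As a point of reference, the element-level semantics one would expect from such casts (assuming they follow Rust's `as` conversion rules, which this excerpt does not confirm) look like this:

```rust
// Reference sketch of plain-Rust integer cast semantics; whether the cubecl kernels
// match these exactly is an assumption, not something shown in the diff.
fn main() {
    // Widening casts preserve the value.
    assert_eq!(42i16 as i64, 42);
    // Narrowing casts wrap modulo 2^bits: 300 = 256 + 44.
    assert_eq!(300i64 as i8, 44);
    // Signed-to-unsigned reinterprets the two's-complement bit pattern.
    assert_eq!(-1i32 as u32, u32::MAX);
}
```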
33 changes: 33 additions & 0 deletions crates/burn-fusion/src/ops/int.rs
@@ -1901,6 +1901,39 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
out
}

fn int_cast(tensor: IntTensor<Self>, dtype: burn_tensor::IntDType) -> IntTensor<Self> {
#[derive(new)]
struct CastOps<B: FusionBackend> {
desc: UnaryOpIr,
dtype: burn_tensor::IntDType,
_b: PhantomData<B>,
}
impl<B: FusionBackend> Operation<B::FusionRuntime> for CastOps<B> {
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
let tensor = handles.get_int_tensor::<B>(&self.desc.input);
let output: B::IntTensorPrimitive = B::int_cast(tensor, self.dtype);
handles.register_int_tensor::<B>(&self.desc.out.id, output);
}
}

let stream = tensor.stream;
let out = tensor
.client
.tensor_uninitialized(tensor.shape.clone(), dtype.into());

let desc = UnaryOpIr {
input: tensor.into_ir(),
out: out.to_ir_out(),
};
out.client.register(
vec![stream],
OperationIr::BaseInt(BaseOperationIr::Cast(desc.clone())),
CastOps::<B>::new(desc, dtype),
);

out
}

fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
binary_int_ops!(BitwiseAndOps, B::bitwise_and);

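The fusion backend does not execute the cast eagerly: it allocates an uninitialized output via `tensor_uninitialized`, describes the op as a `UnaryOpIr`, and registers a `CastOps` operation that only runs when the recorded stream executes. A stripped-down, self-contained sketch of that record-now/execute-later pattern (all types below are simplified stand-ins, not burn-fusion's):

```rust
// Minimal stand-in for the deferred-execution pattern used by the fusion backend:
// describe the op and its output handle now, run it later against stored tensors.
use std::collections::HashMap;

type TensorId = usize;
type Handles = HashMap<TensorId, Vec<i64>>;

trait Operation {
    fn execute(&self, handles: &mut Handles);
}

struct CastOp {
    input: TensorId,
    out: TensorId,
}

impl Operation for CastOp {
    fn execute(&self, handles: &mut Handles) {
        // Fetch the input handle and register the "cast" result under the output id.
        let data = handles[&self.input].iter().map(|&x| x as i32 as i64).collect();
        handles.insert(self.out, data);
    }
}

fn main() {
    let mut handles: Handles = HashMap::from([(0, vec![1, 2, 3])]);

    // Recorded when the user calls the op...
    let stream: Vec<Box<dyn Operation>> = vec![Box::new(CastOp { input: 0, out: 1 })];

    // ...executed when the runtime flushes the stream.
    for op in &stream {
        op.execute(&mut handles);
    }
    assert_eq!(handles[&1], vec![1, 2, 3]);
}
```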
4 changes: 2 additions & 2 deletions crates/burn-ndarray/src/backend.rs
@@ -1,5 +1,5 @@
use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement};
use crate::{NdArrayQTensor, NdArrayTensor, NdArrayTensorFloat};
use crate::{NdArrayQTensor, NdArrayTensor, NdArrayTensorFloat, NdArrayTensorInt};
use alloc::string::String;
use burn_common::stub::Mutex;
use burn_ir::{BackendIr, HandleKind, TensorHandle};
@@ -48,7 +48,7 @@ impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> Backend for
type FloatTensorPrimitive = NdArrayTensorFloat;
type FloatElem = E;

type IntTensorPrimitive = NdArrayTensor<I>;
type IntTensorPrimitive = NdArrayTensorInt;
type IntElem = I;

type BoolTensorPrimitive = NdArrayTensor<bool>;
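The backend now uses a new `NdArrayTensorInt` primitive instead of the single-element-type `NdArrayTensor<I>`, which is what allows the int dtype to change at runtime. The actual definition is not part of this excerpt; a plausible shape for such a dtype-erased wrapper would be an enum over the supported element types, roughly:

```rust
// Illustrative sketch only: one way a runtime-dtyped int tensor primitive could be
// modeled. `NdArrayTensor` is redefined here as a stand-in so the sketch compiles;
// the real `NdArrayTensorInt` in the PR may differ.
use ndarray::{ArcArray, IxDyn};

pub struct NdArrayTensor<E> {
    pub array: ArcArray<E, IxDyn>,
}

pub enum NdArrayTensorInt {
    I64(NdArrayTensor<i64>),
    I32(NdArrayTensor<i32>),
    I16(NdArrayTensor<i16>),
    I8(NdArrayTensor<i8>),
    U64(NdArrayTensor<u64>),
    U32(NdArrayTensor<u32>),
    U16(NdArrayTensor<u16>),
    U8(NdArrayTensor<u8>),
}

fn main() {
    // Construct a 32-bit variant and match on it to recover the concrete element type.
    let t = NdArrayTensorInt::I32(NdArrayTensor {
        array: ArcArray::from_shape_vec(IxDyn(&[3]), vec![1, 2, 3]).unwrap(),
    });
    if let NdArrayTensorInt::I32(inner) = t {
        assert_eq!(inner.array.len(), 3);
    }
}
```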
8 changes: 7 additions & 1 deletion crates/burn-ndarray/src/element.rs
@@ -18,7 +18,7 @@
}

/// An int element for ndarray backend.
pub trait IntNdArrayElement: NdArrayElement + Signed {}
pub trait IntNdArrayElement: NdArrayElement {}

/// A general element for ndarray backend.
pub trait NdArrayElement:
@@ -64,6 +64,12 @@ impl FloatNdArrayElement for f32 {}

impl IntNdArrayElement for i64 {}
impl IntNdArrayElement for i32 {}
impl IntNdArrayElement for i16 {}
impl IntNdArrayElement for i8 {}
impl IntNdArrayElement for u64 {}
impl IntNdArrayElement for u32 {}
impl IntNdArrayElement for u16 {}
impl IntNdArrayElement for u8 {}

macro_rules! make_elem {
(
4 changes: 2 additions & 2 deletions crates/burn-ndarray/src/ops/bool_tensor.rs
@@ -1,7 +1,7 @@
// Language
use alloc::vec;
use alloc::vec::Vec;
use burn_tensor::ops::{BoolTensorOps, FloatTensor, IntTensorOps};
use burn_tensor::ops::{BoolTensorOps, FloatTensor, IntTensor, IntTensorOps};
use burn_tensor::{ElementConversion, TensorMetadata};
use core::ops::Range;
use ndarray::IntoDimension;
@@ -42,7 +42,7 @@ impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BoolTensorOp
NdArrayOps::slice(tensor, ranges)
}

fn bool_into_int(tensor: NdArrayTensor<bool>) -> NdArrayTensor<I> {
fn bool_into_int(tensor: NdArrayTensor<bool>) -> IntTensor<Self> {
let shape = tensor.shape();
let values = tensor.array.into_iter().collect();
NdArray::<E, I>::int_from_data(