Skip to content

Commit b83102a

Browse files
committed
Update index_select to exclude CUDA + BF16
1 parent df60701 commit b83102a

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed
Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
// SPDX-License-Identifier: MIT or Apache-2.0
22
// First Published under RadixMLP and https://github.com/michaelfeil/candle-index-select-cu by Michael Feil
33

4-
use candle::{Result, Tensor};
4+
use candle::{DType, Result, Tensor};
55
#[cfg(feature = "cuda")]
66
use candle_index_select_cu;
77

88
#[inline]
99
#[allow(dead_code)]
1010
pub fn index_select(tensor: &Tensor, ids: &Tensor, dim: usize) -> Result<Tensor> {
11-
#[cfg(not(feature = "cuda"))]
12-
{
13-
tensor.index_select(ids, dim)
14-
}
15-
#[cfg(feature = "cuda")]
11+
if cfg!(feature = "cuda")
12+
&& matches!(tensor.dtype(), DType::F16 | DType::F32)
13+
&& matches!(ids.dtype(), DType::U32)
1614
{
15+
// NOTE: `candle-index-select-cu` supports f16/f32 data and u32 indices
1716
candle_index_select_cu::index_select(tensor, ids, dim)
17+
} else {
18+
tensor.index_select(ids, dim)
1819
}
1920
}

0 commit comments

Comments
 (0)