File tree Expand file tree Collapse file tree 1 file changed +7
-6
lines changed
backends/candle/src/layers Expand file tree Collapse file tree 1 file changed +7
-6
lines changed Original file line number Diff line number Diff line change 11// SPDX-License-Identifier: MIT or Apache-2.0
22// First Published under RadixMLP and https://github.com/michaelfeil/candle-index-select-cu by Michael Feil
33
4- use candle:: { Result , Tensor } ;
4+ use candle:: { DType , Result , Tensor } ;
55#[ cfg( feature = "cuda" ) ]
66use candle_index_select_cu;
77
88#[ inline]
99#[ allow( dead_code) ]
1010pub fn index_select ( tensor : & Tensor , ids : & Tensor , dim : usize ) -> Result < Tensor > {
11- #[ cfg( not( feature = "cuda" ) ) ]
12- {
13- tensor. index_select ( ids, dim)
14- }
15- #[ cfg( feature = "cuda" ) ]
11+ if cfg ! ( feature = "cuda" )
12+ && matches ! ( tensor. dtype( ) , DType :: F16 | DType :: F32 )
13+ && matches ! ( ids. dtype( ) , DType :: U32 )
1614 {
15+ // NOTE: `candle-index-select-cu` supports f16/f32 data and u32 indices
1716 candle_index_select_cu:: index_select ( tensor, ids, dim)
17+ } else {
18+ tensor. index_select ( ids, dim)
1819 }
1920}
You can’t perform that action at this time.
0 commit comments