@@ -2060,6 +2060,46 @@ impl Tensor {
         Ok(Tensor { c_tensor: c_tensors[0] })
     }

+    pub fn f_internal_dyn_quant_matmul_4bit(
+        inp: &Tensor,
+        packed_weights: &Tensor,
+        block_size: i64,
+        in_features: i64,
+        out_features: i64,
+    ) -> Result<Tensor, TchError> {
+        let mut c_tensors = [std::ptr::null_mut(); 1];
+        unsafe_torch_err!(atg__dyn_quant_matmul_4bit(
+            c_tensors.as_mut_ptr(),
+            inp.c_tensor,
+            packed_weights.c_tensor,
+            block_size,
+            in_features,
+            out_features
+        ));
+        Ok(Tensor { c_tensor: c_tensors[0] })
+    }
+
+    pub fn f_internal_dyn_quant_pack_4bit_weight<T: Borrow<Tensor>>(
+        weights: &Tensor,
+        scales_zeros: &Tensor,
+        bias: Option<T>,
+        block_size: i64,
+        in_features: i64,
+        out_features: i64,
+    ) -> Result<Tensor, TchError> {
+        let mut c_tensors = [std::ptr::null_mut(); 1];
+        unsafe_torch_err!(atg__dyn_quant_pack_4bit_weight(
+            c_tensors.as_mut_ptr(),
+            weights.c_tensor,
+            scales_zeros.c_tensor,
+            bias.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor),
+            block_size,
+            in_features,
+            out_features
+        ));
+        Ok(Tensor { c_tensor: c_tensors[0] })
+    }
+
     pub fn f_internal_efficient_attention_backward<T: Borrow<Tensor>>(
         grad_out_: &Tensor,
         query: &Tensor,
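Usage note (sketch): the two 4-bit dynamic-quantization bindings added above are plain associated functions on Tensor. The snippet below only illustrates how they chain together; the uint8 packing layout, the scales/zero-points shape, and the block size are placeholder assumptions, and the underlying `_dyn_quant_pack_4bit_weight` / `_dyn_quant_matmul_4bit` kernels may reject them at runtime.

use tch::{Device, Kind, TchError, Tensor};

// Sketch only: shapes and dtypes are assumptions for illustration.
fn dyn_quant_4bit_sketch() -> Result<Tensor, TchError> {
    let (in_features, out_features, block_size) = (64i64, 32i64, 32i64);
    // Two 4-bit values per byte; per-block scales and zero points as f32.
    let weights = Tensor::zeros([out_features, in_features / 2], (Kind::Uint8, Device::Cpu));
    let scales_zeros =
        Tensor::zeros([out_features * (in_features / block_size), 2], (Kind::Float, Device::Cpu));
    // Pack the quantized weights once...
    let packed = Tensor::f_internal_dyn_quant_pack_4bit_weight(
        &weights, &scales_zeros, None::<Tensor>, block_size, in_features, out_features,
    )?;
    // ...then reuse the packed buffer for dynamically quantized matmuls.
    let inp = Tensor::zeros([1, in_features], (Kind::Float, Device::Cpu));
    Tensor::f_internal_dyn_quant_matmul_4bit(&inp, &packed, block_size, in_features, out_features)
}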
@@ -2900,8 +2940,8 @@ impl Tensor {
         max_k: i64,
         dropout_p: f64,
         is_causal: bool,
-        philox_seed: &Tensor,
-        philox_offset: &Tensor,
+        rng_state: &Tensor,
+        unused: &Tensor,
         scale: impl Into<Option<f64>>,
         window_size_left: impl Into<Option<i64>>,
         window_size_right: impl Into<Option<i64>>,
@@ -2924,8 +2964,8 @@ impl Tensor {
             max_k,
             dropout_p,
             if is_causal { 1 } else { 0 },
-            philox_seed.c_tensor,
-            philox_offset.c_tensor,
+            rng_state.c_tensor,
+            unused.c_tensor,
             scale.unwrap_or(std::f64::NAN),
             scale.is_none() as i8,
             window_size_left.unwrap_or(0i64),
@@ -5783,6 +5823,33 @@ impl Tensor {
         ))
     }

+    pub fn f_internal_scaled_grouped_mm<T: Borrow<Tensor>>(
+        &self,
+        mat2: &Tensor,
+        scale_a: &Tensor,
+        scale_b: &Tensor,
+        offs: Option<T>,
+        bias: Option<T>,
+        scale_result: Option<T>,
+        out_dtype: impl Into<Option<Kind>>,
+        use_fast_accum: bool,
+    ) -> Result<Tensor, TchError> {
+        let mut c_tensors = [std::ptr::null_mut(); 1];
+        unsafe_torch_err!(atg__scaled_grouped_mm(
+            c_tensors.as_mut_ptr(),
+            self.c_tensor,
+            mat2.c_tensor,
+            scale_a.c_tensor,
+            scale_b.c_tensor,
+            offs.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor),
+            bias.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor),
+            scale_result.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor),
+            out_dtype.into().map_or(-1, |s| s.c_int()),
+            if use_fast_accum { 1 } else { 0 }
+        ));
+        Ok(Tensor { c_tensor: c_tensors[0] })
+    }
+
     pub fn f_internal_scaled_mm<T: Borrow<Tensor>>(
         &self,
         mat2: &Tensor,
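Usage note (sketch): the new grouped scaled-matmul binding mirrors `f_internal_scaled_mm` but adds group offsets. The underlying `_scaled_grouped_mm` operator targets FP8 grouped GEMMs on recent CUDA hardware, so the snippet below only shows how the optional arguments map onto the Rust signature; the inputs are assumed to already be in whatever dtype and layout the backend requires, which is not guaranteed to run everywhere.

use tch::{Kind, TchError, Tensor};

// Sketch only: `mat_a`/`mat_b` and the scale tensors are assumed to be prepared
// for the backend (typically FP8 data with per-row scales on a CUDA device).
fn scaled_grouped_mm_sketch(
    mat_a: &Tensor,
    mat_b: &Tensor,
    scale_a: &Tensor,
    scale_b: &Tensor,
    offs: Option<&Tensor>,
) -> Result<Tensor, TchError> {
    mat_a.f_internal_scaled_grouped_mm(
        mat_b,
        scale_a,
        scale_b,
        offs,           // group offsets along the ragged dimension, if any
        None,           // bias
        None,           // scale_result
        Kind::BFloat16, // out_dtype
        false,          // use_fast_accum
    )
}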
@@ -35234,6 +35301,7 @@ impl Tensor {
         normalized: bool,
         onesided: bool,
         return_complex: bool,
+        align_to_window: bool,
     ) -> Result<Tensor, TchError> {
         let hop_length = hop_length.into();
         let win_length = win_length.into();
@@ -35249,7 +35317,8 @@ impl Tensor {
             window.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor),
             if normalized { 1 } else { 0 },
             if onesided { 1 } else { 0 },
-            if return_complex { 1 } else { 0 }
+            if return_complex { 1 } else { 0 },
+            if align_to_window { 1 } else { 0 }
         ));
         Ok(Tensor { c_tensor: c_tensors[0] })
     }
@@ -35265,6 +35334,7 @@ impl Tensor {
         normalized: bool,
         onesided: bool,
         return_complex: bool,
+        align_to_window: bool,
     ) -> Result<Tensor, TchError> {
         let hop_length = hop_length.into();
         let win_length = win_length.into();
@@ -35283,7 +35353,8 @@ impl Tensor {
             pad_mode.len() as i32,
             if normalized { 1 } else { 0 },
             if onesided { 1 } else { 0 },
-            if return_complex { 1 } else { 0 }
+            if return_complex { 1 } else { 0 },
+            if align_to_window { 1 } else { 0 }
         ));
         Ok(Tensor { c_tensor: c_tensors[0] })
     }
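Usage note (sketch): existing STFT call sites gain one trailing boolean. Assuming these hunks belong to the fallible `f_stft` / `f_stft_center` wrappers (the function names are outside the diff context) and using otherwise arbitrary illustrative parameters:

use tch::{Device, Kind, TchError, Tensor};

// Sketch only: parameter values are illustrative; the only change relative to
// earlier releases is the trailing `align_to_window` flag.
fn stft_sketch(signal: &Tensor) -> Result<Tensor, TchError> {
    let n_fft = 400i64;
    let window = Tensor::hann_window(n_fft, (Kind::Float, Device::Cpu));
    signal.f_stft(
        n_fft,
        160i64, // hop_length
        n_fft,  // win_length
        Some(&window),
        false, // normalized
        true,  // onesided
        true,  // return_complex
        false, // align_to_window (new in this change)
    )
}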