@@ -69,6 +69,23 @@ extern "C" int (*LiteRtTopKOpenClSampler_UpdateConfig_Static)(
6969 const LiteRtTopKSampler_SamplerParameters* sampler_params, int batch_size,
7070 void * rand_gen_shared_ptr, char ** error_msg) = nullptr ;
7171
72+ extern " C" int (*LiteRtTopKOpenClSampler_CanHandleInput_Static)(
73+ LiteRtTopKSampler_Sampler* sampler) = nullptr ;
74+
75+ extern " C" int (*LiteRtTopKOpenClSampler_HandlesInput_Static)(
76+ LiteRtTopKSampler_Sampler* sampler) = nullptr ;
77+
78+ extern " C" int (
79+ *LiteRtTopKOpenClSampler_SetInputTensorsAndInferenceFunc_Static)(
80+ LiteRtTopKSampler_Sampler* sampler,
81+ LiteRtTensorBuffer absl_nullable ids_tensor,
82+ LiteRtTensorBuffer absl_nullable prev_input_positions_tensor,
83+ LiteRtTensorBuffer absl_nullable input_positions_tensor,
84+ LiteRtTensorBuffer absl_nullable prev_mask_tensor,
85+ LiteRtTensorBuffer absl_nullable mask_tensor,
86+ int (*run_inference_func)(void * arg), void * arg,
87+ char ** error_msg) = nullptr ;
88+
7289// WebGPU Sampler C API function pointers.
7390extern " C" int (*LiteRtTopKWebGpuSampler_Create_Static)(
7491 LiteRtEnvironment env, int batch_size, int sequence_size, int vocab_size,
@@ -361,7 +378,10 @@ class TopKOpenClCApiSampler : public TopKCApiSampler {
361378 " libLiteRtTopKOpenClSampler.so" , " LiteRtTopKOpenClSampler_Create" ,
362379 " LiteRtTopKOpenClSampler_Destroy" ,
363380 " LiteRtTopKOpenClSampler_SampleToIdAndScoreBuffer" ,
364- " LiteRtTopKOpenClSampler_UpdateConfig" );
381+ " LiteRtTopKOpenClSampler_UpdateConfig" ,
382+ " LiteRtTopKOpenClSampler_CanHandleInput" ,
383+ " LiteRtTopKOpenClSampler_HandlesInput" ,
384+ " LiteRtTopKOpenClSampler_SetInputTensorsAndInferenceFunc" );
365385 if (capi_or.ok ()) {
366386 capi = std::move (capi_or.value ());
367387 ABSL_LOG (INFO) << " Dynamically loaded LiteRtTopKOpenClSampler C API." ;
@@ -403,15 +423,22 @@ class TopKOpenClCApiSampler : public TopKCApiSampler {
403423 if (LiteRtTopKOpenClSampler_Create_Static == nullptr ||
404424 LiteRtTopKOpenClSampler_Destroy_Static == nullptr ||
405425 LiteRtTopKOpenClSampler_SampleToIdAndScoreBuffer_Static == nullptr ||
406- LiteRtTopKOpenClSampler_UpdateConfig_Static == nullptr ) {
426+ LiteRtTopKOpenClSampler_UpdateConfig_Static == nullptr ||
427+ LiteRtTopKOpenClSampler_CanHandleInput_Static == nullptr ||
428+ LiteRtTopKOpenClSampler_HandlesInput_Static == nullptr ||
429+ LiteRtTopKOpenClSampler_SetInputTensorsAndInferenceFunc_Static ==
430+ nullptr ) {
407431 return absl::UnavailableError (
408432 " Static LiteRtTopKOpenClSampler C API not available." );
409433 }
410434 return std::make_unique<TopKSamplerCApi>(
411435 /* lib=*/ std::nullopt , LiteRtTopKOpenClSampler_Create_Static,
412436 LiteRtTopKOpenClSampler_Destroy_Static,
413437 LiteRtTopKOpenClSampler_SampleToIdAndScoreBuffer_Static,
414- LiteRtTopKOpenClSampler_UpdateConfig_Static);
438+ LiteRtTopKOpenClSampler_UpdateConfig_Static,
439+ LiteRtTopKOpenClSampler_CanHandleInput_Static,
440+ LiteRtTopKOpenClSampler_HandlesInput_Static,
441+ LiteRtTopKOpenClSampler_SetInputTensorsAndInferenceFunc_Static);
415442 }
416443};
417444
0 commit comments