Skip to content

Commit 76feb80

Browse files
fengwuyaocopybara-github
authored andcommitted
Add input handling and inference triggering to TopKOpenClSampler.
LiteRT-LM-PiperOrigin-RevId: 888955323
1 parent d9341af commit 76feb80

File tree

1 file changed

+30
-3
lines changed

1 file changed

+30
-3
lines changed

runtime/components/sampler_factory.cc

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,23 @@ extern "C" int (*LiteRtTopKOpenClSampler_UpdateConfig_Static)(
6969
const LiteRtTopKSampler_SamplerParameters* sampler_params, int batch_size,
7070
void* rand_gen_shared_ptr, char** error_msg) = nullptr;
7171

72+
extern "C" int (*LiteRtTopKOpenClSampler_CanHandleInput_Static)(
73+
LiteRtTopKSampler_Sampler* sampler) = nullptr;
74+
75+
extern "C" int (*LiteRtTopKOpenClSampler_HandlesInput_Static)(
76+
LiteRtTopKSampler_Sampler* sampler) = nullptr;
77+
78+
extern "C" int (
79+
*LiteRtTopKOpenClSampler_SetInputTensorsAndInferenceFunc_Static)(
80+
LiteRtTopKSampler_Sampler* sampler,
81+
LiteRtTensorBuffer absl_nullable ids_tensor,
82+
LiteRtTensorBuffer absl_nullable prev_input_positions_tensor,
83+
LiteRtTensorBuffer absl_nullable input_positions_tensor,
84+
LiteRtTensorBuffer absl_nullable prev_mask_tensor,
85+
LiteRtTensorBuffer absl_nullable mask_tensor,
86+
int (*run_inference_func)(void* arg), void* arg,
87+
char** error_msg) = nullptr;
88+
7289
// WebGPU Sampler C API function pointers.
7390
extern "C" int (*LiteRtTopKWebGpuSampler_Create_Static)(
7491
LiteRtEnvironment env, int batch_size, int sequence_size, int vocab_size,
@@ -361,7 +378,10 @@ class TopKOpenClCApiSampler : public TopKCApiSampler {
361378
"libLiteRtTopKOpenClSampler.so", "LiteRtTopKOpenClSampler_Create",
362379
"LiteRtTopKOpenClSampler_Destroy",
363380
"LiteRtTopKOpenClSampler_SampleToIdAndScoreBuffer",
364-
"LiteRtTopKOpenClSampler_UpdateConfig");
381+
"LiteRtTopKOpenClSampler_UpdateConfig",
382+
"LiteRtTopKOpenClSampler_CanHandleInput",
383+
"LiteRtTopKOpenClSampler_HandlesInput",
384+
"LiteRtTopKOpenClSampler_SetInputTensorsAndInferenceFunc");
365385
if (capi_or.ok()) {
366386
capi = std::move(capi_or.value());
367387
ABSL_LOG(INFO) << "Dynamically loaded LiteRtTopKOpenClSampler C API.";
@@ -403,15 +423,22 @@ class TopKOpenClCApiSampler : public TopKCApiSampler {
403423
if (LiteRtTopKOpenClSampler_Create_Static == nullptr ||
404424
LiteRtTopKOpenClSampler_Destroy_Static == nullptr ||
405425
LiteRtTopKOpenClSampler_SampleToIdAndScoreBuffer_Static == nullptr ||
406-
LiteRtTopKOpenClSampler_UpdateConfig_Static == nullptr) {
426+
LiteRtTopKOpenClSampler_UpdateConfig_Static == nullptr ||
427+
LiteRtTopKOpenClSampler_CanHandleInput_Static == nullptr ||
428+
LiteRtTopKOpenClSampler_HandlesInput_Static == nullptr ||
429+
LiteRtTopKOpenClSampler_SetInputTensorsAndInferenceFunc_Static ==
430+
nullptr) {
407431
return absl::UnavailableError(
408432
"Static LiteRtTopKOpenClSampler C API not available.");
409433
}
410434
return std::make_unique<TopKSamplerCApi>(
411435
/*lib=*/std::nullopt, LiteRtTopKOpenClSampler_Create_Static,
412436
LiteRtTopKOpenClSampler_Destroy_Static,
413437
LiteRtTopKOpenClSampler_SampleToIdAndScoreBuffer_Static,
414-
LiteRtTopKOpenClSampler_UpdateConfig_Static);
438+
LiteRtTopKOpenClSampler_UpdateConfig_Static,
439+
LiteRtTopKOpenClSampler_CanHandleInput_Static,
440+
LiteRtTopKOpenClSampler_HandlesInput_Static,
441+
LiteRtTopKOpenClSampler_SetInputTensorsAndInferenceFunc_Static);
415442
}
416443
};
417444

0 commit comments

Comments
 (0)