diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json
index 13a309763b58..023697768983 100644
--- a/.github/trigger_files/beam_PostCommit_Python.json
+++ b/.github/trigger_files/beam_PostCommit_Python.json
@@ -1,5 +1,5 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to run.",
-  "modification": 11
+  "modification": 12
 }
diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py
index a69a15516e8d..c1f94259aca6 100644
--- a/sdks/python/apache_beam/ml/inference/vllm_inference.py
+++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py
@@ -39,6 +39,8 @@
 from openai import OpenAI
 
 try:
+  # VLLM logging config breaks beam logging.
+  os.environ["VLLM_CONFIGURE_LOGGING"] = "0"
   import vllm  # pylint: disable=unused-import
   logging.info('vllm module successfully imported.')
 except ModuleNotFoundError:
@@ -127,7 +129,9 @@ def start_server(self, retries=3):
       ]
       for k, v in self._vllm_server_kwargs.items():
         server_cmd.append(f'--{k}')
-        server_cmd.append(v)
+        # Only add values for commands with value part.
+        if v is not None:
+          server_cmd.append(v)
       self._server_process, self._server_port = start_process(server_cmd)
 
     self.check_connectivity(retries)
@@ -138,27 +142,27 @@ def get_server_port(self) -> int:
     return self._server_port
 
   def check_connectivity(self, retries=3):
-    client = getVLLMClient(self._server_port)
-    while self._server_process.poll() is None:
-      try:
-        models = client.models.list().data
-        logging.info('models: %s' % models)
-        if len(models) > 0:
-          self._server_started = True
-          return
-      except:  # pylint: disable=bare-except
-        pass
-      # Sleep while bringing up the process
-      time.sleep(5)
-
-    if retries == 0:
-      self._server_started = False
-      raise Exception(
-          "Failed to start vLLM server, polling process exited with code " +
-          "%s. Next time a request is tried, the server will be restarted" %
-          self._server_process.poll())
-    else:
-      self.start_server(retries - 1)
+    with getVLLMClient(self._server_port) as client:
+      while self._server_process.poll() is None:
+        try:
+          models = client.models.list().data
+          logging.info('models: %s' % models)
+          if len(models) > 0:
+            self._server_started = True
+            return
+        except:  # pylint: disable=bare-except
+          pass
+        # Sleep while bringing up the process
+        time.sleep(5)
+
+      if retries == 0:
+        self._server_started = False
+        raise Exception(
+            "Failed to start vLLM server, polling process exited with code " +
+            "%s. Next time a request is tried, the server will be restarted" %
+            self._server_process.poll())
+      else:
+        self.start_server(retries - 1)
 
 
 class VLLMCompletionsModelHandler(ModelHandler[str,
@@ -200,27 +204,21 @@ async def _async_run_inference(
       model: _VLLMModelServer,
       inference_args: Optional[dict[str, Any]] = None
   ) -> Iterable[PredictionResult]:
-    client = getAsyncVLLMClient(model.get_server_port())
     inference_args = inference_args or {}
-    async_predictions = []
-    for prompt in batch:
-      try:
-        completion = client.completions.create(
-            model=self._model_name, prompt=prompt, **inference_args)
-        async_predictions.append(completion)
-      except Exception as e:
-        model.check_connectivity()
-        raise e
-    predictions = []
-    for p in async_predictions:
+    async with getAsyncVLLMClient(model.get_server_port()) as client:
       try:
-        predictions.append(await p)
+        async_predictions = [
+            client.completions.create(
+                model=self._model_name, prompt=prompt, **inference_args)
+            for prompt in batch
+        ]
+        responses = await asyncio.gather(*async_predictions)
       except Exception as e:
         model.check_connectivity()
         raise e
 
-    return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+    return [PredictionResult(x, y) for x, y in zip(batch, responses)]
 
   def run_inference(
       self,
@@ -301,25 +299,19 @@ async def _async_run_inference(
       model: _VLLMModelServer,
       inference_args: Optional[dict[str, Any]] = None
   ) -> Iterable[PredictionResult]:
-    client = getAsyncVLLMClient(model.get_server_port())
     inference_args = inference_args or {}
-    async_predictions = []
-    for messages in batch:
-      formatted = []
-      for message in messages:
-        formatted.append({"role": message.role, "content": message.content})
-      try:
-        completion = client.chat.completions.create(
-            model=self._model_name, messages=formatted, **inference_args)
-        async_predictions.append(completion)
-      except Exception as e:
-        model.check_connectivity()
-        raise e
-    predictions = []
-    for p in async_predictions:
+    async with getAsyncVLLMClient(model.get_server_port()) as client:
       try:
-        predictions.append(await p)
+        async_predictions = [
+            client.chat.completions.create(
+                model=self._model_name,
+                messages=[{
+                    "role": message.role, "content": message.content
+                } for message in messages],
+                **inference_args) for messages in batch
+        ]
+        predictions = await asyncio.gather(*async_predictions)
       except Exception as e:
         model.check_connectivity()
         raise e
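
For context on the async rework above: the new `_async_run_inference` bodies open the OpenAI-compatible client with `async with`, build one request coroutine per batch element, and await them together via `asyncio.gather`, so a failed request surfaces once instead of during a second await loop. The snippet below is a minimal, standalone sketch of that fan-out/gather shape under assumed stand-ins; `fake_client`, `create`, and `run_batch` are illustrative only and are not Beam, vLLM, or OpenAI APIs.

```python
import asyncio
from contextlib import asynccontextmanager


@asynccontextmanager
async def fake_client():
  # Stand-in for an async client factory; yields a request function and
  # cleans up on exit, mirroring the `async with` usage in the diff.
  async def create(prompt: str) -> str:
    await asyncio.sleep(0.01)  # pretend to call a server
    return f'completion for: {prompt}'

  try:
    yield create
  finally:
    pass  # a real client would release HTTP connections here


async def run_batch(batch):
  async with fake_client() as create:
    # Fan out one coroutine per prompt, then await them concurrently,
    # the same gather-based flow the reworked handlers use.
    pending = [create(prompt) for prompt in batch]
    responses = await asyncio.gather(*pending)
  return list(zip(batch, responses))


if __name__ == '__main__':
  for prompt, response in asyncio.run(run_batch(['hello', 'world'])):
    print(prompt, '->', response)
```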