
feat: Free gpu space after each inference run #493


Open
wants to merge 24 commits into base: gpu
24 commits
ac641d0
new: gpu package publish workflow
joein May 1, 2024
9591396
fix: do not run windows and mac os tests on gpu branch
joein May 2, 2024
b3461c0
refactoring: alter workflow names
joein May 2, 2024
9b1af0c
alter workflow
joein May 2, 2024
ee5fa07
fix: workflow dispatch can only be triggered from the default branch
joein May 2, 2024
ad62c7d
sync publish with main
joein May 3, 2024
238e8ca
add eofl
joein May 3, 2024
0499ea5
new: gpu package (#224)
joein May 3, 2024
4c50015
fix: Minimize gpu memory fragmentation
hh-space-invader Feb 28, 2025
a18f735
nit
hh-space-invader Mar 4, 2025
b8d30b1
new: Add arena extend strategy
hh-space-invader Mar 4, 2025
a761dcf
change initial chunk size
hh-space-invader Mar 4, 2025
b82e4d0
a
hh-space-invader Mar 4, 2025
c212c1f
specify shrinkage as run options not session options
hh-space-invader Mar 4, 2025
13b7d6d
specify shrinkage as run options not session options
hh-space-invader Mar 4, 2025
f63333d
specify shrinkage as run options not session options
hh-space-invader Mar 4, 2025
5c46b17
specify shrinkage as run options not session options
hh-space-invader Mar 4, 2025
758d339
new: Shrink empty arena for multi gpu settings
hh-space-invader Mar 4, 2025
4037e14
chore: Remove print statement
hh-space-invader Mar 4, 2025
1c016a2
fix: Fix multi gpu settings
hh-space-invader Mar 5, 2025
5ea3bbc
docs: Add description for changes
hh-space-invader Mar 5, 2025
3d90072
new: Added experiment to benchmark fastembed on gpu
hh-space-invader Mar 7, 2025
4bc5fbe
fix: Fix passing cuda and providers in single gpu settings
hh-space-invader Mar 13, 2025
9e21de2
fix: Pass providers and cuda to multimodal models
hh-space-invader Mar 13, 2025
3 changes: 1 addition & 2 deletions .github/workflows/python-tests.yml
@@ -1,4 +1,5 @@
name: Tests
run-name: Tests (gpu)

on:
push:
@@ -21,8 +22,6 @@ jobs:
- '3.13.x'
os:
- ubuntu-latest
- macos-latest
- windows-latest

runs-on: ${{ matrix.os }}

493 changes: 493 additions & 0 deletions experiments/Throughput_Across_Models_GPU.ipynb

Large diffs are not rendered by default.

14 changes: 12 additions & 2 deletions fastembed/common/onnx_model.py
@@ -68,7 +68,15 @@ def _load_onnx_model(
if device_id is None:
onnx_providers = ["CUDAExecutionProvider"]
else:
onnx_providers = [("CUDAExecutionProvider", {"device_id": device_id})]
# kSameAsRequested: Allocates only the requested memory, avoiding over-allocation.
# more precise than 'kNextPowerOfTwo', which grows memory aggressively.
# source: https://onnxruntime.ai/docs/get-started/with-c.html#features:~:text=Memory%20arena%20shrinkage:
onnx_providers = [
(
"CUDAExecutionProvider",
{"device_id": device_id, "arena_extend_strategy": "kSameAsRequested"},
)
]
else:
onnx_providers = ["CPUExecutionProvider"]

@@ -132,5 +140,7 @@ def __init__(
def start(cls, model_name: str, cache_dir: str, **kwargs: Any) -> "EmbeddingWorker[T]":
return cls(model_name=model_name, cache_dir=cache_dir, **kwargs)

def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
def process(
self, items: Iterable[tuple[int, Any]], **kwargs: Any
) -> Iterable[tuple[int, Any]]:
raise NotImplementedError("Subclasses must implement this method")
19 changes: 17 additions & 2 deletions fastembed/common/utils.py
@@ -5,12 +5,12 @@
import unicodedata
from pathlib import Path
from itertools import islice
from typing import Iterable, Optional, TypeVar
from typing import Iterable, Optional, TypeVar, Sequence

import numpy as np
from numpy.typing import NDArray

from fastembed.common.types import NumpyArray
from fastembed.common.types import NumpyArray, OnnxProvider

T = TypeVar("T")

@@ -67,3 +67,18 @@ def get_all_punctuation() -> set[str]:

def remove_non_alphanumeric(text: str) -> str:
return re.sub(r"[^\w\s]", " ", text, flags=re.UNICODE)


def is_cuda_enabled(cuda: bool, providers: Optional[Sequence[OnnxProvider]]) -> bool:
"""
Check if CUDA is enabled based on the `cuda` and `providers` parameters
"""
if cuda:
return True
if not providers:
return False
if isinstance(providers, str):
return "CUDAExecutionProvider" in providers
return isinstance(providers, (list, tuple)) and any(
isinstance(p, str) and "CUDAExecutionProvider" in p for p in providers
)
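
The calls below are illustrative, not part of the diff; they assume the helper is importable from fastembed.common.utils as defined above:

from fastembed.common.utils import is_cuda_enabled

is_cuda_enabled(cuda=True, providers=None)                        # True
is_cuda_enabled(cuda=False, providers=None)                       # False
is_cuda_enabled(cuda=False, providers=["CPUExecutionProvider"])   # False
is_cuda_enabled(cuda=False, providers=["CUDAExecutionProvider"])  # True
# (name, options) tuples are not str instances, so a tuple entry alone
# does not enable CUDA for this check:
is_cuda_enabled(cuda=False, providers=[("CUDAExecutionProvider", {"device_id": 0})])  # False
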
29 changes: 24 additions & 5 deletions fastembed/image/onnx_image_model.py
@@ -6,13 +6,14 @@

import numpy as np
from PIL import Image
import onnxruntime as ort

from fastembed.image.transform.operators import Compose
from fastembed.common.types import NumpyArray
from fastembed.common import ImageInput, OnnxProvider
from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxOutputContext, T
from fastembed.common.preprocessor_utils import load_preprocessor
from fastembed.common.utils import iter_batch
from fastembed.common.utils import iter_batch, is_cuda_enabled
from fastembed.parallel_processor import ParallelWorkerPool

# Holds type of the embedding result
@@ -74,7 +75,21 @@ def onnx_embed(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputConte
encoded = np.array(self.processor(image_files))
onnx_input = self._build_onnx_input(encoded)
onnx_input = self._preprocess_onnx_input(onnx_input)
model_output = self.model.run(None, onnx_input) # type: ignore[union-attr]

run_options = ort.RunOptions()
providers = kwargs.get("providers", None)
cuda = kwargs.get("cuda", False)
if is_cuda_enabled(cuda, providers):
device_id = kwargs.get("device_id", None)
device_id = str(device_id if isinstance(device_id, int) else 0)
# enables memory arena shrinkage, freeing unused memory after each Run() cycle.
# helps prevent excessive memory retention, especially for dynamic workloads.
# source: https://onnxruntime.ai/docs/get-started/with-c.html#features:~:text=Memory%20arena%20shrinkage:
run_options.add_run_config_entry(
"memory.enable_memory_arena_shrinkage", f"gpu:{device_id}"
)

model_output = self.model.run(None, onnx_input, run_options) # type: ignore[union-attr]
embeddings = model_output[0].reshape(len(images), -1)
return OnnxOutputContext(model_output=embeddings)

@@ -104,7 +119,9 @@ def _embed_images(
self.load_onnx_model()

for batch in iter_batch(images, batch_size):
yield from self._post_process_onnx_output(self.onnx_embed(batch))
yield from self._post_process_onnx_output(
self.onnx_embed(batch, cuda=cuda, providers=providers)
)
else:
if parallel == 0:
parallel = os.cpu_count()
@@ -129,7 +146,9 @@ def _embed_images(


class ImageEmbeddingWorker(EmbeddingWorker[T]):
def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
def process(
self, items: Iterable[tuple[int, Any]], **kwargs: Any
) -> Iterable[tuple[int, Any]]:
for idx, batch in items:
embeddings = self.model.onnx_embed(batch)
embeddings = self.model.onnx_embed(batch, **kwargs)
yield idx, embeddings
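
In isolation, the shrinkage mechanism used in onnx_embed above reduces to the plain onnxruntime pattern below; session, inputs, and the device id are stand-ins, not names from this repository:

import onnxruntime as ort
from typing import Any

def run_with_arena_shrinkage(
    session: ort.InferenceSession, inputs: dict[str, Any], device_id: int = 0
) -> list[Any]:
    run_options = ort.RunOptions()
    # Tell ORT to shrink the CUDA arena for this device after the run finishes,
    # returning unused chunks instead of keeping them cached for later calls.
    run_options.add_run_config_entry(
        "memory.enable_memory_arena_shrinkage", f"gpu:{device_id}"
    )
    return session.run(None, inputs, run_options)
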
55 changes: 46 additions & 9 deletions fastembed/late_interaction_multimodal/onnx_multimodal_model.py
@@ -6,13 +6,14 @@

import numpy as np
from PIL import Image
import onnxruntime as ort
from tokenizers import Encoding, Tokenizer

from fastembed.common import OnnxProvider, ImageInput
from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxOutputContext, T
from fastembed.common.preprocessor_utils import load_tokenizer, load_preprocessor
from fastembed.common.types import NumpyArray
from fastembed.common.utils import iter_batch
from fastembed.common.utils import iter_batch, is_cuda_enabled
from fastembed.image.transform.operators import Compose
from fastembed.parallel_processor import ParallelWorkerPool

@@ -103,7 +104,21 @@ def onnx_embed_text(
)

onnx_input = self._preprocess_onnx_text_input(onnx_input, **kwargs)
model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore[union-attr]

run_options = ort.RunOptions()
providers = kwargs.get("providers", None)
cuda = kwargs.get("cuda", False)
if is_cuda_enabled(cuda, providers):
device_id = kwargs.get("device_id", None)
device_id = str(device_id if isinstance(device_id, int) else 0)
# enables memory arena shrinkage, freeing unused memory after each Run() cycle.
# helps prevent excessive memory retention, especially for dynamic workloads.
# source: https://onnxruntime.ai/docs/get-started/with-c.html#features:~:text=Memory%20arena%20shrinkage:
run_options.add_run_config_entry(
"memory.enable_memory_arena_shrinkage", f"gpu:{device_id}"
)

model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input, run_options) # type: ignore[union-attr]
return OnnxOutputContext(
model_output=model_output[0],
attention_mask=onnx_input.get("attention_mask", attention_mask),
@@ -136,7 +151,9 @@ def _embed_documents(
if not hasattr(self, "model") or self.model is None:
self.load_onnx_model()
for batch in iter_batch(documents, batch_size):
yield from self._post_process_onnx_text_output(self.onnx_embed_text(batch))
yield from self._post_process_onnx_text_output(
self.onnx_embed_text(batch, cuda=cuda, providers=providers)
)
else:
if parallel == 0:
parallel = os.cpu_count()
@@ -169,7 +186,21 @@ def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutpu
encoded = np.array(self.processor(image_files))
onnx_input = {"pixel_values": encoded}
onnx_input = self._preprocess_onnx_image_input(onnx_input, **kwargs)
model_output = self.model.run(None, onnx_input) # type: ignore[union-attr]

run_options = ort.RunOptions()
providers = kwargs.get("providers", None)
cuda = kwargs.get("cuda", False)
if is_cuda_enabled(cuda, providers):
device_id = kwargs.get("device_id", None)
device_id = str(device_id if isinstance(device_id, int) else 0)
# enables memory arena shrinkage, freeing unused memory after each Run() cycle.
# helps prevent excessive memory retention, especially for dynamic workloads.
# source: https://onnxruntime.ai/docs/get-started/with-c.html#features:~:text=Memory%20arena%20shrinkage:
run_options.add_run_config_entry(
"memory.enable_memory_arena_shrinkage", f"gpu:{device_id}"
)

model_output = self.model.run(None, onnx_input, run_options) # type: ignore[union-attr]
embeddings = model_output[0].reshape(len(images), -1)
return OnnxOutputContext(model_output=embeddings)

@@ -199,7 +230,9 @@ def _embed_images(
self.load_onnx_model()

for batch in iter_batch(images, batch_size):
yield from self._post_process_onnx_image_output(self.onnx_embed_image(batch))
yield from self._post_process_onnx_image_output(
self.onnx_embed_image(batch, cuda=cuda, providers=providers)
)
else:
if parallel == 0:
parallel = os.cpu_count()
@@ -241,9 +274,11 @@ def init_embedding(
) -> OnnxMultimodalModel:
raise NotImplementedError()

def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
def process(
self, items: Iterable[tuple[int, Any]], **kwargs: Any
) -> Iterable[tuple[int, Any]]:
for idx, batch in items:
onnx_output = self.model.onnx_embed_text(batch)
onnx_output = self.model.onnx_embed_text(batch, **kwargs)
yield idx, onnx_output


@@ -265,7 +300,9 @@ def init_embedding(
) -> OnnxMultimodalModel:
raise NotImplementedError()

def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
def process(
self, items: Iterable[tuple[int, Any]], **kwargs: Any
) -> Iterable[tuple[int, Any]]:
for idx, batch in items:
embeddings = self.model.onnx_embed_image(batch)
embeddings = self.model.onnx_embed_image(batch, **kwargs)
yield idx, embeddings
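
The same guard is now repeated in onnx_embed, onnx_embed_text, onnx_embed_image, and onnx_embed_pairs. A sketch of how it could be pulled into one helper; make_run_options is hypothetical and not part of this PR:

import onnxruntime as ort
from typing import Optional, Sequence

from fastembed.common.types import OnnxProvider
from fastembed.common.utils import is_cuda_enabled

def make_run_options(
    cuda: bool,
    providers: Optional[Sequence[OnnxProvider]],
    device_id: Optional[int] = None,
) -> ort.RunOptions:
    # Build RunOptions with arena shrinkage enabled whenever the run will
    # execute on the CUDAExecutionProvider.
    run_options = ort.RunOptions()
    if is_cuda_enabled(cuda, providers):
        device = str(device_id if isinstance(device_id, int) else 0)
        run_options.add_run_config_entry(
            "memory.enable_memory_arena_shrinkage", f"gpu:{device}"
        )
    return run_options

Each onnx_embed* call site above could then shrink to a single make_run_options(...) call before self.model.run.
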
6 changes: 4 additions & 2 deletions fastembed/parallel_processor.py
@@ -28,7 +28,9 @@ class Worker:
def start(cls, *args: Any, **kwargs: Any) -> "Worker":
raise NotImplementedError()

def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
def process(
self, items: Iterable[tuple[int, Any]], **kwargs: Any
) -> Iterable[tuple[int, Any]]:
raise NotImplementedError()


@@ -63,7 +65,7 @@ def input_queue_iterable() -> Iterable[Any]:
break
yield item

for processed_item in worker.process(input_queue_iterable()):
for processed_item in worker.process(input_queue_iterable(), **kwargs):
output_queue.put(processed_item)
except Exception as e: # pylint: disable=broad-except
logging.exception(e)
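
The updated signature means every Worker subclass now receives the pool's keyword arguments. A minimal sketch with a hypothetical DoublingWorker, only to show the pass-through:

from typing import Any, Iterable

from fastembed.parallel_processor import Worker

class DoublingWorker(Worker):
    @classmethod
    def start(cls, *args: Any, **kwargs: Any) -> "DoublingWorker":
        return cls()

    def process(
        self, items: Iterable[tuple[int, Any]], **kwargs: Any
    ) -> Iterable[tuple[int, Any]]:
        # kwargs arrive from the pool the same way cuda/providers reach the
        # embedding workers in the changes above.
        scale = kwargs.get("scale", 2)
        for idx, batch in items:
            yield idx, [value * scale for value in batch]
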
29 changes: 24 additions & 5 deletions fastembed/rerank/cross_encoder/onnx_text_model.py
@@ -4,6 +4,7 @@
from typing import Any, Iterable, Optional, Sequence, Type

import numpy as np
import onnxruntime as ort
from tokenizers import Encoding

from fastembed.common.onnx_model import (
@@ -14,7 +15,7 @@
)
from fastembed.common.types import NumpyArray
from fastembed.common.preprocessor_utils import load_tokenizer
from fastembed.common.utils import iter_batch
from fastembed.common.utils import iter_batch, is_cuda_enabled
from fastembed.parallel_processor import ParallelWorkerPool


@@ -71,7 +72,21 @@ def onnx_embed_pairs(self, pairs: list[tuple[str, str]], **kwargs: Any) -> OnnxO
tokenized_input = self.tokenize(pairs, **kwargs)
inputs = self._build_onnx_input(tokenized_input)
onnx_input = self._preprocess_onnx_input(inputs, **kwargs)
outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore[union-attr]

run_options = ort.RunOptions()
providers = kwargs.get("providers", None)
cuda = kwargs.get("cuda", False)
if is_cuda_enabled(cuda, providers):
device_id = kwargs.get("device_id", None)
device_id = str(device_id if isinstance(device_id, int) else 0)
# Enables memory arena shrinkage, freeing unused memory after each Run() cycle.
# Helps prevent excessive memory retention, especially for dynamic workloads.
# Source: https://onnxruntime.ai/docs/get-started/with-c.html#features:~:text=Memory%20arena%20shrinkage:
run_options.add_run_config_entry(
"memory.enable_memory_arena_shrinkage", f"gpu:{device_id}"
)

outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input, run_options) # type: ignore[union-attr]
relevant_output = outputs[0]
scores: NumpyArray = relevant_output[:, 0]
return OnnxOutputContext(model_output=scores)
@@ -110,7 +125,9 @@ def _rerank_pairs(
if not hasattr(self, "model") or self.model is None:
self.load_onnx_model()
for batch in iter_batch(pairs, batch_size):
yield from self._post_process_onnx_output(self.onnx_embed_pairs(batch, **kwargs))
yield from self._post_process_onnx_output(
self.onnx_embed_pairs(batch, cuda=cuda, providers=providers, **kwargs)
)
else:
if parallel == 0:
parallel = os.cpu_count()
@@ -163,7 +180,9 @@ def init_embedding(
) -> OnnxCrossEncoderModel:
raise NotImplementedError()

def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
def process(
self, items: Iterable[tuple[int, Any]], **kwargs: Any
) -> Iterable[tuple[int, Any]]:
for idx, batch in items:
onnx_output = self.model.onnx_embed_pairs(batch)
onnx_output = self.model.onnx_embed_pairs(batch, **kwargs)
yield idx, onnx_output
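
End to end, the flag reaches the reranker through the public wrapper. A hedged sketch, assuming fastembed's TextCrossEncoder class, its rerank(query, documents) signature, and the model name shown; all three are assumptions for illustration, not values taken from this diff:

from fastembed.rerank.cross_encoder import TextCrossEncoder

# Illustrative model name; any supported cross-encoder model works.
encoder = TextCrossEncoder("Xenova/ms-marco-MiniLM-L-6-v2", cuda=True)
scores = list(encoder.rerank("what is vector search?", ["document one", "document two"]))

With cuda=True, onnx_embed_pairs above builds the shrinking RunOptions, so the arena is released after each rerank call rather than retained between batches.
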
2 changes: 1 addition & 1 deletion fastembed/sparse/bm25.py
@@ -344,7 +344,7 @@ def start(cls, model_name: str, cache_dir: str, **kwargs: Any) -> "Bm25Worker":
return cls(model_name=model_name, cache_dir=cache_dir, **kwargs)

def process(
self, items: Iterable[tuple[int, Any]]
self, items: Iterable[tuple[int, Any]], **kwargs: Any
) -> Iterable[tuple[int, list[SparseEmbedding]]]:
for idx, batch in items:
onnx_output = self.model.raw_embed(batch)