Skip to content

Commit 6ce29ed

Browse files
committed
feat: add scaleway inference provider
1 parent 490865e commit 6ce29ed

File tree

6 files changed

+107
-2
lines changed

6 files changed

+107
-2
lines changed

src/huggingface_hub/inference/_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ class InferenceClient:
130130
Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
131131
arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
132132
provider (`str`, *optional*):
133-
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`.
133+
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, `"sambanova"`, `"scaleway"` or `"together"`.
134134
Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
135135
If model is a URL or `base_url` is passed, then `provider` is not used.
136136
token (`str`, *optional*):

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ class AsyncInferenceClient:
118118
Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
119119
arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
120120
provider (`str`, *optional*):
121-
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`.
121+
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, `"sambanova"`, `"scaleway"` or `"together"`.
122122
Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
123123
If model is a URL or `base_url` is passed, then `provider` is not used.
124124
token (`str`, *optional*):

src/huggingface_hub/inference/_providers/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from .openai import OpenAIConversationalTask
3939
from .replicate import ReplicateImageToImageTask, ReplicateTask, ReplicateTextToImageTask, ReplicateTextToSpeechTask
4040
from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
41+
from .scaleway import ScalewayConversationalTask, ScalewayFeatureExtractionTask
4142
from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask
4243

4344

@@ -61,6 +62,7 @@
6162
"replicate",
6263
"sambanova",
6364
"together",
65+
"scaleway",
6466
]
6567

6668
PROVIDER_OR_POLICY_T = Union[PROVIDER_T, Literal["auto"]]
@@ -159,6 +161,10 @@
159161
"conversational": TogetherConversationalTask(),
160162
"text-generation": TogetherTextGenerationTask(),
161163
},
164+
"scaleway": {
165+
"conversational": ScalewayConversationalTask(),
166+
"feature-extraction": ScalewayFeatureExtractionTask(),
167+
},
162168
}
163169

164170

src/huggingface_hub/inference/_providers/_common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
"replicate": {},
3535
"sambanova": {},
3636
"together": {},
37+
"scaleway": {},
3738
}
3839

3940

src/huggingface_hub/inference/_providers/scaleway.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from typing import Any, Dict, Optional, Union
2+
3+
from huggingface_hub.inference._common import RequestParameters, _as_dict
4+
5+
from ._common import BaseConversationalTask, InferenceProviderMapping, TaskProviderHelper, filter_none
6+
7+
8+
class ScalewayConversationalTask(BaseConversationalTask):
    """Conversational (chat-completion) task helper for the Scaleway inference provider.

    Inherits all request/response handling from `BaseConversationalTask`; only the
    provider name and direct-call base URL are Scaleway-specific.
    """

    def __init__(self):
        # base_url is used for direct calls (provider API key); HF-routed calls
        # go through the router URL derived from the provider name.
        super().__init__(provider="scaleway", base_url="https://api.scaleway.ai")
11+
12+
13+
class ScalewayFeatureExtractionTask(TaskProviderHelper):
    """Feature-extraction (embeddings) task helper for the Scaleway inference provider.

    Builds OpenAI-compatible `/v1/embeddings` requests and unwraps the provider's
    response into a plain list of embedding vectors.
    """

    def __init__(self):
        super().__init__(provider="scaleway", base_url="https://api.scaleway.ai", task="feature-extraction")

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        # Same route whether the call is direct or routed through the HF router.
        return "/v1/embeddings"

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        # OpenAI-style embeddings payload. Caller-supplied parameters are merged
        # last (after dropping None values) so they take precedence on key clash.
        payload: Dict[str, Any] = {"input": inputs, "model": provider_mapping_info.provider_id}
        payload.update(filter_none(parameters))
        return payload

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        # Expected shape: {"data": [{"embedding": [...]}, ...]} -> list of vectors.
        vectors = []
        for item in _as_dict(response)["data"]:
            vectors.append(item["embedding"])
        return vectors

tests/test_inference_providers.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
ReplicateTextToSpeechTask,
5151
)
5252
from huggingface_hub.inference._providers.sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
53+
from huggingface_hub.inference._providers.scaleway import ScalewayConversationalTask, ScalewayFeatureExtractionTask
5354
from huggingface_hub.inference._providers.together import TogetherTextToImageTask
5455

5556
from .testing_utils import assert_in_logs
@@ -1077,6 +1078,75 @@ def test_prepare_url_conversational(self):
10771078
assert url == "https://api.novita.ai/v3/openai/chat/completions"
10781079

10791080

1081+
class TestScalewayProvider:
    """Unit tests for the Scaleway provider task helpers (URL building, payloads, parsing)."""

    def test_prepare_hf_url_conversational(self):
        # An HF token ("hf_" prefix) routes the request through the HF inference router.
        helper = ScalewayConversationalTask()
        url = helper._prepare_url("hf_token", "username/repo_name")
        assert url == "https://router.huggingface.co/scaleway/v1/chat/completions"

    def test_prepare_url_conversational(self):
        # A provider API key (no "hf_" prefix) calls Scaleway's API directly.
        helper = ScalewayConversationalTask()
        url = helper._prepare_url("scw_token", "username/repo_name")
        assert url == "https://api.scaleway.ai/v1/chat/completions"

    def test_prepare_payload_as_dict(self):
        # Payload must use the provider-side model id and pass sampling params through.
        helper = ScalewayConversationalTask()
        payload = helper._prepare_payload_as_dict(
            [
                {"role": "system", "content": "You are a helpful assistant"},
                {"role": "user", "content": "Hello!"},
            ],
            {
                "max_tokens": 512,
                "temperature": 0.15,
                "top_p": 1,
                "presence_penalty": 0,
                "stream": True,
            },
            InferenceProviderMapping(
                provider="scaleway",
                hf_model_id="meta-llama/Llama-3.1-8B-Instruct",
                providerId="meta-llama/llama-3.1-8B-Instruct",
                task="conversational",
                status="live",
            ),
        )
        assert payload == {
            "max_tokens": 512,
            "messages": [
                {"content": "You are a helpful assistant", "role": "system"},
                {"role": "user", "content": "Hello!"},
            ],
            "model": "meta-llama/llama-3.1-8B-Instruct",
            "presence_penalty": 0,
            "stream": True,
            "temperature": 0.15,
            "top_p": 1,
        }

    def test_prepare_url_feature_extraction(self):
        # Embeddings requests with an HF token also go through the router.
        helper = ScalewayFeatureExtractionTask()
        assert (
            helper._prepare_url("hf_token", "username/repo_name")
            == "https://router.huggingface.co/scaleway/v1/embeddings"
        )

    def test_prepare_payload_as_dict_feature_extraction(self):
        # OpenAI-style embeddings payload: "input" + provider model id + extra params.
        helper = ScalewayFeatureExtractionTask()
        payload = helper._prepare_payload_as_dict(
            "Example text to embed",
            {"truncate": True},
            InferenceProviderMapping(
                provider="scaleway",
                hf_model_id="username/repo_name",
                providerId="provider-id",
                task="feature-extraction",
                status="live",
            ),
        )
        assert payload == {"input": "Example text to embed", "model": "provider-id", "truncate": True}
1148+
1149+
10801150
class TestNscaleProvider:
10811151
def test_prepare_route_text_to_image(self):
10821152
helper = NscaleTextToImageTask()

0 commit comments

Comments
 (0)