Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions fastdeploy/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
"""# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -38,6 +37,7 @@
from fastdeploy.engine.expert_service import ExpertService
from fastdeploy.entrypoints.chat_utils import load_chat_template
from fastdeploy.entrypoints.engine_client import EngineClient
from fastdeploy.entrypoints.openai.middleware import AuthenticationMiddleware
from fastdeploy.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
Expand Down Expand Up @@ -266,6 +266,12 @@ async def lifespan(app: FastAPI):
instrument(app)


# Enable bearer-token authentication only when at least one non-empty API key
# is configured. Keys supplied on the command line (args.api_key) take
# precedence; the FD_API_KEY environment value is used only when no CLI key
# was given.
env_api_key_func = environment_variables.get("FD_API_KEY")
env_tokens = env_api_key_func() if env_api_key_func else []
if tokens := [key for key in (args.api_key or env_tokens) if key]:
    # Pass ``tokens`` by keyword: keyword options are accepted by
    # ``add_middleware`` in every Starlette release, whereas positional
    # middleware arguments are only supported by newer ones.
    app.add_middleware(AuthenticationMiddleware, tokens=tokens)


@asynccontextmanager
async def connection_manager():
"""
Expand Down
55 changes: 55 additions & 0 deletions fastdeploy/entrypoints/openai/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import hashlib
import secrets
from collections.abc import Awaitable

from fastapi.responses import JSONResponse
from starlette.datastructures import URL, Headers
from starlette.types import ASGIApp, Receive, Scope, Send


class AuthenticationMiddleware:
    """
    Pure ASGI middleware that authenticates each request by checking
    that the Authorization Bearer token exists and equals any of the
    configured API keys.

    Notes
    -----
    There are two cases in which authentication is skipped:
        1. The HTTP method is OPTIONS.
        2. The request path doesn't start with /v1 (e.g. /health).
    """

    def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
        """
        Wrap ``app`` and remember the accepted API keys.

        Args:
            app: The downstream ASGI application.
            tokens: Accepted API keys in plain text.
        """
        self.app = app
        # Only the SHA-256 digests are stored, so every later comparison is
        # done on fixed-length byte strings.
        self.api_tokens = [hashlib.sha256(t.encode("utf-8")).digest() for t in tokens]

    def verify_token(self, headers: Headers) -> bool:
        """
        Check the ``Authorization`` header against the configured keys.

        Args:
            headers: Request headers; only ``Authorization`` is read.

        Returns:
            True if the header uses the ``Bearer`` scheme (case-insensitive)
            and its value matches one of the configured keys, else False.
        """
        authorization_header_value = headers.get("Authorization")
        if not authorization_header_value:
            return False

        scheme, _, param = authorization_header_value.partition(" ")
        if scheme.lower() != "bearer":
            return False

        param_hash = hashlib.sha256(param.encode("utf-8")).digest()

        # Every configured key is compared and the results are OR-ed together
        # instead of returning on the first hit; presumably so the running
        # time does not depend on which key (if any) matched.
        token_match = False
        for token_hash in self.api_tokens:
            token_match |= secrets.compare_digest(param_hash, token_hash)

        return token_match

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        """
        ASGI entry point: reject unauthenticated ``/v1`` requests with a 401
        JSON response and forward everything else to the wrapped application.
        """
        # Websocket scopes carry no "method" key, so use .get() rather than
        # indexing; indexing raised KeyError for every websocket connection.
        if scope["type"] not in ("http", "websocket") or scope.get("method") == "OPTIONS":
            # scope["type"] can be "lifespan" or "startup" for example,
            # in which case we don't need to do anything
            return self.app(scope, receive, send)
        root_path = scope.get("root_path", "")
        url_path = URL(scope=scope).path.removeprefix(root_path)
        headers = Headers(scope=scope)
        if url_path.startswith("/v1") and not self.verify_token(headers):
            response = JSONResponse(content={"error": "Unauthorized"}, status_code=401)
            return response(scope, receive, send)
        return self.app(scope, receive, send)
2 changes: 2 additions & 0 deletions fastdeploy/entrypoints/openai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,5 +239,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help="Workers silent for more than this many seconds are killed and restarted.Value is a positive number or 0. Setting it to 0 has the effect of infinite timeouts by disabling timeouts for all workers entirely.",
)

parser.add_argument("--api-key", type=str, action="append", help="API_KEY required for service authentication")

parser = EngineArgs.add_cli_args(parser)
return parser
2 changes: 2 additions & 0 deletions fastdeploy/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@
"FD_CACHE_PROC_EXIT_TIMEOUT": lambda: int(os.getenv("FD_CACHE_PROC_EXIT_TIMEOUT", "600")),
# Count for cache_transfer_manager process error
"FD_CACHE_PROC_ERROR_COUNT": lambda: int(os.getenv("FD_CACHE_PROC_ERROR_COUNT", "10")),
# API_KEY required for service authentication
"FD_API_KEY": lambda: [] if "FD_API_KEY" not in os.environ else os.environ["FD_API_KEY"].split(","),
# EPLB related
"FD_ENABLE_REDUNDANT_EXPERTS": lambda: int(os.getenv("FD_ENABLE_REDUNDANT_EXPERTS", "0")) == 1,
"FD_REDUNDANT_EXPERTS_NUM": lambda: int(os.getenv("FD_REDUNDANT_EXPERTS_NUM", "0")),
Expand Down
231 changes: 231 additions & 0 deletions tests/e2e/test_api_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
import os
import signal
import socket
import subprocess
import sys
import time
from typing import Optional

import pytest
import requests

# Ports used by the server under test. Each one can be overridden through the
# environment variable of the same name.
FD_API_PORT = int(os.environ.get("FD_API_PORT", 8188))
FD_ENGINE_QUEUE_PORT = int(os.environ.get("FD_ENGINE_QUEUE_PORT", 8133))
FD_METRICS_PORT = int(os.environ.get("FD_METRICS_PORT", 8233))
FD_CACHE_QUEUE_PORT = int(os.environ.get("FD_CACHE_QUEUE_PORT", 8333))

# Every port that is cleared of leftover processes around a server run.
PORTS_TO_CLEAN = [
    FD_API_PORT,
    FD_ENGINE_QUEUE_PORT,
    FD_METRICS_PORT,
    FD_CACHE_QUEUE_PORT,
]

# Handle of the server started by the current test, or None when none is running.
current_server_process: Optional[subprocess.Popen] = None


def is_port_open(host: str, port: int, timeout=1.0):
    """
    Probe a TCP endpoint.

    Attempts a TCP connection to ``(host, port)`` with the given timeout and
    closes it again right away. Returns True when the connection could be
    established and False when any exception occurred.
    """
    try:
        connection = socket.create_connection((host, port), timeout)
        connection.close()
    except Exception:
        return False
    return True


def kill_process_on_port(port: int):
    """
    Kill the processes that `lsof` reports for the given port.

    Runs ``lsof -i:<port> -t`` to collect process ids and sends SIGKILL to
    each of them, except the current process and its parent. This is a
    best-effort cleanup: if `lsof` reports nothing, is not installed, or a
    process has already exited by the time it is signalled, the function
    carries on silently.

    Args:
        port: The port whose processes should be killed.
    """
    try:
        # An argument list (no shell) is used so that nothing is interpolated
        # into a shell command line.
        output = subprocess.check_output(["lsof", f"-i:{port}", "-t"]).decode().strip()
    except (subprocess.CalledProcessError, FileNotFoundError):
        # Non-zero exit (e.g. nothing matched) or no lsof binary available:
        # there is nothing to kill.
        return

    protected_pids = {os.getpid(), os.getppid()}
    for line in output.splitlines():
        pid = int(line)
        if pid in protected_pids:
            print(f"Skip killing current process (pid={pid}) on port {port}")
            continue
        try:
            os.kill(pid, signal.SIGKILL)
        except ProcessLookupError:
            # The process exited between the lsof call and the kill.
            continue
        print(f"Killed process on port {port}, pid={pid}")


def clean_ports():
    """
    Free every port in PORTS_TO_CLEAN by killing the processes found on it,
    then pause for two seconds before returning.
    """
    for occupied_port in PORTS_TO_CLEAN:
        kill_process_on_port(occupied_port)
    time.sleep(2)


def start_api_server(api_key_cli: Optional[list[str]] = None, api_key_env: Optional[str] = None):
    """
    Launch the API server as a subprocess and wait until its port accepts
    connections.

    The ports in PORTS_TO_CLEAN are cleared first. The server's stdout and
    stderr are written to ``server.log`` in the current directory.

    Args:
        api_key_cli: API keys passed on the command line, one ``--api-key``
            flag per entry. ``None`` passes no flag.
        api_key_env: Value of ``FD_API_KEY`` in the child's environment.
            ``None`` removes the variable from the child's environment.

    Returns:
        The ``subprocess.Popen`` handle of the running server. It is also
        stored in the module-level ``current_server_process``.

    Raises:
        RuntimeError: If the server process exits before its port opens, or
            the port does not open within 5 minutes.
    """
    global current_server_process
    clean_ports()

    env = os.environ.copy()
    if api_key_env is not None:
        env["FD_API_KEY"] = api_key_env
    else:
        env.pop("FD_API_KEY", None)
    base_path = os.getenv("MODEL_PATH")
    if base_path:
        model_path = os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle")
    else:
        model_path = "./ERNIE-4.5-0.3B-Paddle"
    log_path = "server.log"

    cmd = [
        sys.executable,
        "-m",
        "fastdeploy.entrypoints.openai.api_server",
        "--model",
        model_path,
        "--port",
        str(FD_API_PORT),
        "--tensor-parallel-size",
        "1",
        "--engine-worker-queue-port",
        str(FD_ENGINE_QUEUE_PORT),
        "--metrics-port",
        str(FD_METRICS_PORT),
        "--cache-queue-port",
        str(FD_CACHE_QUEUE_PORT),
        "--max-model-len",
        "32768",
        "--max-num-seqs",
        "128",
        "--quantization",
        "wint4",
        "--graph-optimization-config",
        '{"cudagraph_capture_sizes": [1], "use_cudagraph":true}',
    ]

    if api_key_cli is not None:
        for key in api_key_cli:
            cmd.extend(["--api-key", key])

    # start_new_session puts the server in its own process group so the whole
    # group can be signalled with os.killpg later.
    with open(log_path, "w") as logfile:
        process = subprocess.Popen(cmd, stdout=logfile, stderr=subprocess.STDOUT, start_new_session=True, env=env)

    # Poll once per second, for at most 300 attempts (5 minutes).
    for _ in range(300):
        if is_port_open("127.0.0.1", FD_API_PORT):
            print(f"API server started (port: {FD_API_PORT}, cli_key: {api_key_cli}, env_key: {api_key_env})")
            current_server_process = process
            return process
        if process.poll() is not None:
            # The server died during startup; waiting any longer cannot help.
            raise RuntimeError(
                f"API server exited with code {process.returncode} before opening port {FD_API_PORT} "
                f"(see {log_path})"
            )
        time.sleep(1)

    if process.poll() is None:
        os.killpg(process.pid, signal.SIGTERM)
    raise RuntimeError(f"API server failed to start in 5 minutes (port: {FD_API_PORT})")


def stop_api_server():
    """
    Stop the server recorded in ``current_server_process``, if it is still
    running, then reset that global and clear the test ports.

    The server's process group receives SIGTERM and is given up to 10
    seconds to exit. Any failure while stopping is printed, not raised.
    """
    global current_server_process
    server = current_server_process
    still_running = server is not None and server.poll() is None
    if still_running:
        try:
            os.killpg(server.pid, signal.SIGTERM)
            server.wait(timeout=10)
            print(f"API server stopped (pid: {server.pid})")
        except Exception as e:
            print(f"Failed to stop server: {e}")
    current_server_process = None
    clean_ports()


@pytest.fixture(scope="function", autouse=True)
def teardown_server():
    """
    Autouse fixture: once a test has finished, stop the API server and
    remove FD_API_KEY from this process's environment.
    """
    yield
    stop_api_server()
    if "FD_API_KEY" in os.environ:
        del os.environ["FD_API_KEY"]


@pytest.fixture(scope="function")
def api_url():
    """
    URL of the chat-completions endpoint of the server under test.

    The loopback address is used as the destination: 0.0.0.0 is a bind
    address rather than a portable connect target.
    """
    return f"http://127.0.0.1:{FD_API_PORT}/v1/chat/completions"


@pytest.fixture
def common_headers():
    """Request headers with a JSON content type and no Authorization entry."""
    headers = {"Content-Type": "application/json"}
    return headers


@pytest.fixture
def valid_auth_headers():
    """
    Request headers whose Authorization value is a template; callers fill in
    the ``{api_key}`` placeholder with ``str.format``.
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer {api_key}",
    }
    return headers


@pytest.fixture
def test_payload():
    """Minimal chat-completions request body containing one user message."""
    user_message = {"role": "user", "content": "hello"}
    return {"messages": [user_message], "temperature": 0.9, "max_tokens": 100}


def test_api_key_cli_only(api_url, common_headers, valid_auth_headers, test_payload):
    """
    With keys supplied only via ``--api-key``: a request without a token and
    a request with a wrong token get 401, and each configured key gets 200.
    """
    test_api_key = ["cli_test_key_123", "cli_test_key_456"]
    start_api_server(api_key_cli=test_api_key)

    def post_with_key(key):
        # Fill the Authorization template with ``key`` and send the request.
        headers = dict(valid_auth_headers)
        headers["Authorization"] = headers["Authorization"].format(api_key=key)
        return requests.post(api_url, json=test_payload, headers=headers)

    # No Authorization header at all.
    anonymous = requests.post(api_url, json=test_payload, headers=common_headers)
    assert anonymous.status_code == 401
    body = anonymous.json()
    assert "error" in body
    assert "unauthorized" in body["error"].lower()

    # A token that is not configured.
    assert post_with_key("wrong_key").status_code == 401

    # Each configured key, in order.
    for key in test_api_key:
        assert post_with_key(key).status_code == 200


def test_api_key_env_only(api_url, common_headers, valid_auth_headers, test_payload):
    """
    With keys supplied only via a comma-separated ``FD_API_KEY``: a request
    without a token gets 401, and each listed key gets 200.
    """
    test_api_key = "env_test_key_456,env_test_key_789"
    start_api_server(api_key_env=test_api_key)

    anonymous = requests.post(api_url, json=test_payload, headers=common_headers)
    assert anonymous.status_code == 401

    for key in ("env_test_key_456", "env_test_key_789"):
        headers = dict(valid_auth_headers)
        headers["Authorization"] = headers["Authorization"].format(api_key=key)
        authorized = requests.post(api_url, json=test_payload, headers=headers)
        assert authorized.status_code == 200


def test_api_key_cli_priority_over_env(api_url, valid_auth_headers, test_payload):
    """
    With both a CLI key and an environment key set, only the CLI key is
    accepted: the environment key gets 401 and the CLI key gets 200.
    """
    cli_key = ["cli_priority_key_789"]
    env_key = "env_low_priority_key_000"
    start_api_server(api_key_cli=cli_key, api_key_env=env_key)

    def status_for(key):
        # Status code of a request authenticated with ``key``.
        headers = dict(valid_auth_headers)
        headers["Authorization"] = headers["Authorization"].format(api_key=key)
        return requests.post(api_url, json=test_payload, headers=headers).status_code

    assert status_for(env_key) == 401
    assert status_for(cli_key[0]) == 200


def test_api_key_not_set(api_url, common_headers, valid_auth_headers, test_payload):
    """
    With no key configured anywhere, requests succeed both without an
    Authorization header and with an arbitrary bearer token.
    """
    start_api_server(api_key_cli=None, api_key_env=None)

    anonymous = requests.post(api_url, json=test_payload, headers=common_headers)
    assert anonymous.status_code == 200

    headers_with_token = dict(valid_auth_headers)
    headers_with_token["Authorization"] = headers_with_token["Authorization"].format(api_key="some_api_key")
    with_token = requests.post(api_url, json=test_payload, headers=headers_with_token)
    assert with_token.status_code == 200
Loading
Loading