Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/ray/serve/_private/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,8 @@
os.environ.get("RAY_SERVE_DIRECT_INGRESS_PORT_RETRY_COUNT", "100")
)
# The minimum drain period for an HTTP proxy.
# If RAY_SERVE_FORCE_STOP_UNHEALTHY_REPLICAS is set to 1,
# then the minimum draining period is 0.
RAY_SERVE_DIRECT_INGRESS_MIN_DRAINING_PERIOD_S = float(
os.environ.get("RAY_SERVE_DIRECT_INGRESS_MIN_DRAINING_PERIOD_S", "30")
)
Expand Down
67 changes: 28 additions & 39 deletions python/ray/serve/_private/deployment_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
from ray.serve._private.constants import (
DEFAULT_LATENCY_BUCKET_MS,
MAX_PER_REPLICA_RETRY_COUNT,
RAY_SERVE_DIRECT_INGRESS_MIN_DRAINING_PERIOD_S,
RAY_SERVE_ENABLE_DIRECT_INGRESS,
RAY_SERVE_ENABLE_TASK_EVENTS,
RAY_SERVE_FAIL_ON_RANK_ERROR,
RAY_SERVE_FORCE_STOP_UNHEALTHY_REPLICAS,
Expand Down Expand Up @@ -251,11 +253,6 @@ class DeploymentStateUpdateResult:

ALL_REPLICA_STATES = list(ReplicaState)
_SCALING_LOG_ENABLED = os.environ.get("SERVE_ENABLE_SCALING_LOG", "0") != "0"
# Feature flag to disable forcibly shutting down replicas.
RAY_SERVE_DISABLE_SHUTTING_DOWN_INGRESS_REPLICAS_FORCEFULLY = (
os.environ.get("RAY_SERVE_DISABLE_SHUTTING_DOWN_INGRESS_REPLICAS_FORCEFULLY", "0")
== "1"
)


def print_verbose_scaling_log():
Expand Down Expand Up @@ -758,7 +755,7 @@ def reconfigure(self, version: DeploymentVersion, rank: ReplicaRank) -> bool:
self._rank = rank
return updating

def recover(self) -> bool:
def recover(self, ingress: bool = False) -> bool:
"""Recover replica version from a live replica actor.

When controller dies, the deployment state loses the info on the version that's
Expand All @@ -767,12 +764,17 @@ def recover(self) -> bool:

Also confirm that actor is allocated and initialized before marking as running.

Returns: False if the replica actor is no longer alive; the
Args:
ingress: Whether this replica is an ingress replica.

Returns:
False if the replica actor is no longer alive; the
actor could have been killed in the time between when the
controller fetching all Serve actors in the cluster and when
the controller tries to recover it. Otherwise, return True.
"""
logger.info(f"Recovering {self.replica_id}.")
self._ingress = ingress
try:
self._actor_handle = ray.get_actor(
self._actor_name, namespace=SERVE_NAMESPACE
Expand Down Expand Up @@ -1189,20 +1191,8 @@ def get_routing_stats(self) -> Dict[str, Any]:

return self._routing_stats

def force_stop(self, log_shutdown_message: bool = False):
def force_stop(self):
"""Force the actor to exit without shutting down gracefully."""
if (
self._ingress
and RAY_SERVE_DISABLE_SHUTTING_DOWN_INGRESS_REPLICAS_FORCEFULLY
):
if log_shutdown_message:
logger.info(
f"{self.replica_id} did not shut down because it had not finished draining requests. "
"Going to wait until the draining is complete. You can force-stop the replica by "
"setting RAY_SERVE_DISABLE_SHUTTING_DOWN_INGRESS_REPLICAS_FORCEFULLY to 0."
)
return

try:
ray.kill(ray.get_actor(self._actor_name, namespace=SERVE_NAMESPACE))
except ValueError:
Expand Down Expand Up @@ -1235,7 +1225,6 @@ def __init__(
)
self._multiplexed_model_ids: List[str] = []
self._routing_stats: Dict[str, Any] = {}
self._logged_shutdown_message = False

def get_running_replica_info(
self, cluster_node_info_cache: ClusterNodeInfoCache
Expand Down Expand Up @@ -1366,7 +1355,6 @@ def start(
deployment_info, assign_rank_callback=assign_rank_callback
)
self._start_time = time.time()
self._logged_shutdown_message = False
self.update_actor_details(start_time_s=self._start_time)
return replica_scheduling_request

Expand All @@ -1383,16 +1371,20 @@ def reconfigure(
"""
return self._actor.reconfigure(version, rank=rank)

def recover(self) -> bool:
def recover(self, deployment_info: DeploymentInfo) -> bool:
"""
Recover states in DeploymentReplica instance by fetching running actor
status

Returns: False if the replica is no longer alive at the time
when this method is called.
Args:
deployment_info: The deployment info for this replica.

Returns:
True if the replica actor is alive and recovered successfully.
False if the replica actor is no longer alive.
"""
# If replica is no longer alive
if not self._actor.recover():
if not self._actor.recover(ingress=deployment_info.ingress):
return False

self._start_time = time.time()
Expand Down Expand Up @@ -1442,6 +1434,11 @@ def stop(self, graceful: bool = True) -> None:
timeout_s = self._actor.graceful_stop()
if not graceful:
timeout_s = 0
elif self._actor._ingress and RAY_SERVE_ENABLE_DIRECT_INGRESS:
# In direct ingress mode, ensure we wait at least
# RAY_SERVE_DIRECT_INGRESS_MIN_DRAINING_PERIOD_S to give external
# load balancers (e.g., ALB) time to deregister the replica.
timeout_s = max(timeout_s, RAY_SERVE_DIRECT_INGRESS_MIN_DRAINING_PERIOD_S)
self._shutdown_deadline = time.time() + timeout_s

def check_stopped(self) -> bool:
Expand All @@ -1451,19 +1448,11 @@ def check_stopped(self) -> bool:

timeout_passed = time.time() >= self._shutdown_deadline
if timeout_passed:
if (
not self._logged_shutdown_message
and not RAY_SERVE_DISABLE_SHUTTING_DOWN_INGRESS_REPLICAS_FORCEFULLY
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Repeated log messages when replica shutdown times out

Low Severity

The removal of _logged_shutdown_message tracking causes the "did not shut down after grace period, force-killing it" log message to be printed repeatedly. When check_stopped() is called in the control loop and the timeout has passed, it logs and calls force_stop() every iteration until the actor fully terminates. The old code used _logged_shutdown_message to ensure the message was only logged once. This can cause significant log spam during shutdown of stuck replicas.

Fix in Cursor Fix in Web

):
logger.info(
f"{self.replica_id} did not shut down after grace "
"period, force-killing it. "
)

self._actor.force_stop(
log_shutdown_message=not self._logged_shutdown_message
logger.info(
f"{self.replica_id} did not shut down after grace "
"period, force-killing it."
)
self._logged_shutdown_message = True
self._actor.force_stop()
return False

def check_health(self) -> bool:
Expand Down Expand Up @@ -2342,7 +2331,7 @@ def recover_current_state_from_replica_actor_names(
)
# If replica is no longer alive, simply don't add it to the
# deployment state manager to track.
if not new_deployment_replica.recover():
if not new_deployment_replica.recover(self._target_state.info):
logger.warning(f"{replica_id} died before controller could recover it.")
continue

Expand Down
41 changes: 13 additions & 28 deletions python/ray/serve/_private/replica.py
Original file line number Diff line number Diff line change
Expand Up @@ -2311,38 +2311,23 @@ async def call_asgi():
raise asyncio.CancelledError

async def perform_graceful_shutdown(self):
if not RAY_SERVE_ENABLE_DIRECT_INGRESS or not self._ingress:
# if direct ingress is not enabled or the replica is not an ingress replica,
# we can just call the super method to perform the graceful shutdown.
await super().perform_graceful_shutdown()
return

# set the shutting down flag to True to signal ALBs with failing health checks
# to stop sending traffic to this replica.
self._shutting_down = True

# If the replica was never initialized it never served traffic, so we
# can skip the wait period.
if self._user_callable_initialized:
# in order to gracefully shutdown the replica, we need to wait for the
# requests to drain and for PROXY_MIN_DRAINING_PERIOD_S to pass.
# this is necessary because we want to give ALB time to update its
# target group to remove the replica from it and to mark this replica
# as unhealthy.
# TODO(abrar): the code below assumes that once ALB marks a replica target
# as unhealthy, it will not send traffic to it. This is not true because
# ALB can send traffic to a replica if all targets are unhealthy.
# The correct way to handle is this we start the cooldown period since
# the last request finished and wait for the cooldown period to pass.
if (
RAY_SERVE_ENABLE_DIRECT_INGRESS
and self._ingress
and self._user_callable_initialized
):
# In direct ingress mode, we need to wait at least
# RAY_SERVE_DIRECT_INGRESS_MIN_DRAINING_PERIOD_S to give external load
# balancers (e.g., ALB) time to deregister the replica, in addition to
# waiting for requests to drain.
await asyncio.gather(
asyncio.sleep(RAY_SERVE_DIRECT_INGRESS_MIN_DRAINING_PERIOD_S),
self._drain_ongoing_requests(),
)
logger.info(
f"Replica {self._replica_id} successfully drained ongoing requests."
super().perform_graceful_shutdown(),
)
else:
await super().perform_graceful_shutdown()
Comment on lines +2314 to +2328
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The call to super().perform_graceful_shutdown() is duplicated in the if and else blocks. This can be refactored to reduce code duplication and improve maintainability.

Suggested change
if (
RAY_SERVE_ENABLE_DIRECT_INGRESS
and self._ingress
and self._user_callable_initialized
):
# In direct ingress mode, we need to wait at least
# RAY_SERVE_DIRECT_INGRESS_MIN_DRAINING_PERIOD_S to give external load
# balancers (e.g., ALB) time to deregister the replica, in addition to
# waiting for requests to drain.
await asyncio.gather(
asyncio.sleep(RAY_SERVE_DIRECT_INGRESS_MIN_DRAINING_PERIOD_S),
self._drain_ongoing_requests(),
)
logger.info(
f"Replica {self._replica_id} successfully drained ongoing requests."
super().perform_graceful_shutdown(),
)
else:
await super().perform_graceful_shutdown()
tasks = [super().perform_graceful_shutdown()]
if (
RAY_SERVE_ENABLE_DIRECT_INGRESS
and self._ingress
and self._user_callable_initialized
):
# In direct ingress mode, we need to wait at least
# RAY_SERVE_DIRECT_INGRESS_MIN_DRAINING_PERIOD_S to give external load
# balancers (e.g., ALB) time to deregister the replica, in addition to
# waiting for requests to drain.
tasks.append(
asyncio.sleep(RAY_SERVE_DIRECT_INGRESS_MIN_DRAINING_PERIOD_S)
)
await asyncio.gather(*tasks)


await self.shutdown()
# Cancel direct ingress HTTP/gRPC server tasks if they exist.
if self._direct_ingress_http_server_task:
self._direct_ingress_http_server_task.cancel()
if self._direct_ingress_grpc_server_task:
Expand Down
88 changes: 86 additions & 2 deletions python/ray/serve/tests/test_direct_ingress.py
Original file line number Diff line number Diff line change
Expand Up @@ -2037,8 +2037,9 @@ def test_shutdown_replica_only_after_draining_requests(
"""Test that the replica is shutdown correctly when the deployment is shutdown."""
signal = SignalActor.remote()

# Increase graceful_shutdown_timeout_s to ensure replicas aren't force-killed
# before requests complete when RAY_SERVE_DISABLE_SHUTTING_DOWN_INGRESS_REPLICAS_FORCEFULLY=0
# In direct ingress mode, graceful_shutdown_timeout_s is automatically bumped to
# max(graceful_shutdown_timeout_s, RAY_SERVE_DIRECT_INGRESS_MIN_DRAINING_PERIOD_S)
# to give external load balancers time to deregister the replica.
@serve.deployment(name="replica-shutdown-deployment", graceful_shutdown_timeout_s=5)
class ReplicaShutdownTest:
async def __call__(self):
Expand Down Expand Up @@ -2439,5 +2440,88 @@ class App:
assert isinstance(deployment_config, DeploymentConfig)


def test_stuck_requests_are_force_killed(_skip_if_ff_not_enabled, serve_instance):
    """Verify that replicas with permanently stuck requests are force-killed.

    This test is really slow because it waits for the replica's listen ports to
    be released from the TIME_WAIT state. The ports linger in TIME_WAIT because
    the replicas are force-killed and the OS does not release them immediately.
    """
    import socket

    def _can_bind_to_port(port):
        """Check if we can bind to the port (not just if nothing is listening)."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            sock.bind(("0.0.0.0", port))
            sock.close()
            return True
        except OSError:
            # Bind failed — the port is still held (e.g. TIME_WAIT or in use).
            sock.close()
            return False

    signal = SignalActor.remote()

    # Short graceful_shutdown_timeout_s so the controller force-kills the
    # replica quickly once the grace period elapses with requests still stuck.
    @serve.deployment(
        name="stuck-requests-deployment",
        graceful_shutdown_timeout_s=1,
    )
    class StuckRequestsTest:
        async def __call__(self):
            # This request will never complete - it waits forever
            await signal.wait.remote()
            return "ok"

    serve.run(
        StuckRequestsTest.bind(),
        name="stuck-requests-deployment",
        route_prefix="/stuck-requests-deployment",
    )

    # Collect all ports used by the application before deleting it
    http_ports = get_http_ports(route_prefix="/stuck-requests-deployment")
    grpc_ports = get_grpc_ports(route_prefix="/stuck-requests-deployment")

    http_url = get_application_url("HTTP", app_name="stuck-requests-deployment")

    with ThreadPoolExecutor() as executor:
        # Send requests that will hang forever (signal is never sent)
        futures = [executor.submit(httpx.get, http_url, timeout=60) for _ in range(2)]

        # Wait for requests to be received by the replica
        wait_for_condition(
            lambda: ray.get(signal.cur_num_waiters.remote()) == 2, timeout=10
        )

        # Delete the deployment - requests are still stuck
        serve.delete("stuck-requests-deployment", _blocking=False)

        # Verify the application is eventually deleted (replica was force-killed).
        # This should complete within graceful_shutdown_timeout_s (1s here) plus
        # a buffer for the controller's reconciliation loop.
        wait_for_condition(
            lambda: "stuck-requests-deployment" not in serve.status().applications,
            timeout=10,
        )

        # The stuck requests should fail (connection closed or similar)
        for future in futures:
            try:
                result = future.result(timeout=5)
                # If we get a response, it should be an error (not 200)
                assert result.status_code != 200
            except Exception:
                # Expected - request failed due to force-kill
                pass

    # Wait until all ports can be bound (not just until nothing is listening).
    # This ensures the ports are fully released from TIME_WAIT state.
    def all_ports_can_be_bound():
        for port in http_ports + grpc_ports:
            if not _can_bind_to_port(port):
                return False
        return True

    # TIME_WAIT can last up to 60s on Linux, so use a generous timeout
    wait_for_condition(all_ports_can_be_bound, timeout=120)


if __name__ == "__main__":
sys.exit(pytest.main(["-v", "-s", __file__]))
4 changes: 3 additions & 1 deletion python/ray/serve/tests/unit/test_deployment_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def __init__(
self._docs_path = None
self._rank = replica_rank_context.get(replica_id.unique_id, None)
self._assign_rank_callback = None
self._ingress = False

@property
def is_cross_language(self) -> bool:
Expand Down Expand Up @@ -280,10 +281,11 @@ def reconfigure(
replica_rank_context[self._replica_id.unique_id] = rank
return updating

def recover(self):
def recover(self, ingress: bool = False):
if self.replica_id in dead_replicas_context:
return False

self._ingress = ingress
self.recovering = True
self.started = False
self._rank = replica_rank_context.get(self._replica_id.unique_id, None)
Expand Down