Skip to content

Commit 2e6c16b

Browse files
abrarsheikh and Future-Outlier
authored and committed
[1/n] expose outbound deployment ids from replica actor (ray-project#58345)
## Summary

Adds a new method to expose all downstream deployments that a replica calls into, enabling dependency graph construction.

## Motivation

Deployments call downstream deployments via handles in two ways:

1. **Stored handles**: Passed to `__init__()` and stored as attributes → `self.model.func.remote()`
2. **Dynamic handles**: Obtained at runtime via `serve.get_deployment_handle()` → `model.func.remote()`

Previously, there was no way to programmatically discover these dependencies from a running replica.

## Implementation

### Core Changes

- **`ReplicaActor.list_outbound_deployments()`**: Returns `List[DeploymentID]` of all downstream deployments
  - Scans the deployment constructor's init args/kwargs to find stored handles (including handles nested in dicts/lists)
  - Tracks dynamic handles created via `get_deployment_handle()` at runtime using a callback mechanism
- **Runtime tracking**: Modified `get_deployment_handle()` to register handles when called from within a replica via `ReplicaContext._handle_registration_callback`

Next PR: ray-project#58350

---------

Signed-off-by: abrar <abrar@anyscale.com>
Signed-off-by: Future-Outlier <eric901201@gmail.com>
1 parent 6cdd946 commit 2e6c16b

File tree

5 files changed

+267
-2
lines changed

5 files changed

+267
-2
lines changed

python/ray/serve/_private/replica.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
Generator,
2323
List,
2424
Optional,
25+
Set,
2526
Tuple,
2627
Union,
2728
)
@@ -37,6 +38,7 @@
3738
from ray._common.filters import CoreContextFilter
3839
from ray._common.utils import get_or_create_event_loop
3940
from ray.actor import ActorClass, ActorHandle
41+
from ray.dag.py_obj_scanner import _PyObjScanner
4042
from ray.remote_function import RemoteFunction
4143
from ray.serve import metrics
4244
from ray.serve._private.common import (
@@ -113,6 +115,7 @@
113115
DeploymentUnavailableError,
114116
RayServeException,
115117
)
118+
from ray.serve.handle import DeploymentHandle
116119
from ray.serve.schema import EncodingType, LoggingConfig
117120

118121
logger = logging.getLogger(SERVE_LOGGER_NAME)
@@ -542,6 +545,9 @@ def __init__(
542545
self._user_callable_initialized_lock = asyncio.Lock()
543546
self._initialization_latency: Optional[float] = None
544547

548+
# Track deployment handles created dynamically via get_deployment_handle()
549+
self._dynamically_created_handles: Set[DeploymentID] = set()
550+
545551
# Flipped to `True` when health checks pass and `False` when they fail. May be
546552
# used by replica subclass implementations.
547553
self._healthy = False
@@ -600,17 +606,26 @@ def get_metadata(self) -> ReplicaMetadata:
600606
route_patterns,
601607
)
602608

609+
def get_dynamically_created_handles(self) -> Set[DeploymentID]:
    """Return the IDs of deployments whose handles were created at runtime.

    The IDs are recorded by the registration callback that
    `_set_internal_replica_context` installs on the replica context.

    Returns:
        A snapshot copy of the recorded DeploymentIDs. A copy is returned
        rather than the internal set so that callers can iterate over or
        modify the result without being affected by, or interfering with,
        later registrations.
    """
    return set(self._dynamically_created_handles)
611+
603612
def _set_internal_replica_context(
    self,
    *,
    servable_object: Optional[Callable] = None,
    rank: Optional[int] = None,
):
    """Install the internal replica context for this replica.

    Args:
        servable_object: The user callable to expose through the replica
            context. May be None.
        rank: The rank to record in the replica context. May be None.
    """

    def register_handle_callback(deployment_id: DeploymentID) -> None:
        # Called through the replica context by get_deployment_handle() so
        # that handles obtained at runtime are remembered on this replica.
        self._dynamically_created_handles.add(deployment_id)

    ray.serve.context._set_internal_replica_context(
        replica_id=self._replica_id,
        servable_object=servable_object,
        _deployment_config=self._deployment_config,
        rank=rank,
        # Calculate world_size from deployment config instead of storing it.
        world_size=self._deployment_config.num_replicas,
        handle_registration_callback=register_handle_callback,
    )
615630

616631
def _configure_logger_and_profilers(
@@ -1204,6 +1219,45 @@ def get_num_ongoing_requests(self) -> int:
12041219
"""
12051220
return self._replica_impl.get_num_ongoing_requests()
12061221

1222+
def list_outbound_deployments(self) -> List[DeploymentID]:
    """List all outbound deployment IDs this replica calls into.

    This includes:
    - Handles created via get_deployment_handle()
    - Handles passed as init args/kwargs to the deployment constructor

    This is used to determine which deployments are reachable from this
    replica. The list of DeploymentIDs can change over time, as new handles
    can be created at runtime. Also, it's not guaranteed that the list of
    DeploymentIDs is identical across replicas, because it depends on user
    code.

    Returns:
        A list of DeploymentIDs that this replica calls into, without
        duplicates and in no particular order.
    """
    # Snapshot the dynamically created handles in a single step instead of
    # iterating the replica's set element by element: the registration
    # callback can add to that set while this method is running, and
    # iterating a set that changes size raises RuntimeError.
    seen_deployment_ids: Set[DeploymentID] = set(
        self._replica_impl.get_dynamically_created_handles()
    )

    # Get the init args/kwargs (handles might be passed as init args).
    user_callable_wrapper = self._replica_impl._user_callable_wrapper
    init_args = user_callable_wrapper._init_args
    init_kwargs = user_callable_wrapper._init_kwargs

    # Use _PyObjScanner to find all DeploymentHandle objects in the init
    # args and kwargs.
    scanner = _PyObjScanner(source_type=DeploymentHandle)
    try:
        handles = scanner.find_nodes((init_args, init_kwargs))
        seen_deployment_ids.update(handle.deployment_id for handle in handles)
    finally:
        # Always clear the scanner, even if scanning raised.
        scanner.clear()

    return list(seen_deployment_ids)
1260+
12071261
async def is_allocated(self) -> str:
12081262
"""poke the replica to check whether it's alive.
12091263

python/ray/serve/api.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1065,4 +1065,15 @@ async def __call__(self, val: int) -> int:
10651065
if _record_telemetry:
10661066
ServeUsageTag.SERVE_GET_DEPLOYMENT_HANDLE_API_USED.record("1")
10671067

1068-
return client.get_handle(deployment_name, app_name, check_exists=_check_exists)
1068+
handle: DeploymentHandle = client.get_handle(
1069+
deployment_name, app_name, check_exists=_check_exists
1070+
)
1071+
1072+
# Track handle creation if called from within a replica
1073+
if (
1074+
internal_replica_context is not None
1075+
and internal_replica_context._handle_registration_callback is not None
1076+
):
1077+
internal_replica_context._handle_registration_callback(handle.deployment_id)
1078+
1079+
return handle

python/ray/serve/context.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import ray
1414
from ray.exceptions import RayActorError
1515
from ray.serve._private.client import ServeControllerClient
16-
from ray.serve._private.common import ReplicaID
16+
from ray.serve._private.common import DeploymentID, ReplicaID
1717
from ray.serve._private.config import DeploymentConfig
1818
from ray.serve._private.constants import (
1919
SERVE_CONTROLLER_NAME,
@@ -50,6 +50,7 @@ class ReplicaContext:
5050
_deployment_config: DeploymentConfig
5151
rank: int
5252
world_size: int
53+
_handle_registration_callback: Optional[Callable[[DeploymentID], None]] = None
5354

5455
@property
5556
def app_name(self) -> str:
@@ -114,6 +115,7 @@ def _set_internal_replica_context(
114115
_deployment_config: DeploymentConfig,
115116
rank: int,
116117
world_size: int,
118+
handle_registration_callback: Optional[Callable[[str, str], None]] = None,
117119
):
118120
global _INTERNAL_REPLICA_CONTEXT
119121
_INTERNAL_REPLICA_CONTEXT = ReplicaContext(
@@ -122,6 +124,7 @@ def _set_internal_replica_context(
122124
_deployment_config=_deployment_config,
123125
rank=rank,
124126
world_size=world_size,
127+
_handle_registration_callback=handle_registration_callback,
125128
)
126129

127130

python/ray/serve/tests/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ py_test_module_list(
119119
"test_http_headers.py",
120120
"test_http_routes.py",
121121
"test_https_proxy.py",
122+
"test_list_outbound_deployments.py",
122123
"test_max_replicas_per_node.py",
123124
"test_multiplex.py",
124125
"test_proxy.py",
python/ray/serve/tests/test_list_outbound_deployments.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import sys
2+
from typing import List
3+
4+
import pytest
5+
6+
import ray
7+
from ray import serve
8+
from ray.serve._private.common import DeploymentID
9+
from ray.serve._private.constants import SERVE_NAMESPACE
10+
from ray.serve.handle import DeploymentHandle
11+
12+
13+
@serve.deployment
class DownstreamA:
    """Leaf deployment: doubles its input when called."""

    def __call__(self, x: int) -> int:
        doubled = x * 2
        return doubled
17+
18+
19+
@serve.deployment
class DownstreamB:
    """Leaf deployment: exposes `process`, which adds 10 to its input."""

    def process(self, x: int) -> int:
        shifted = x + 10
        return shifted
23+
24+
25+
@serve.deployment
class UpstreamWithStoredHandles:
    """Upstream deployment that keeps its two downstream handles as attributes."""

    def __init__(self, handle_a: DeploymentHandle, handle_b: DeploymentHandle):
        self.handle_a, self.handle_b = handle_a, handle_b

    async def __call__(self, x: int) -> int:
        # Call the downstreams one after the other and sum their answers.
        from_a = await self.handle_a.remote(x)
        from_b = await self.handle_b.process.remote(x)
        return from_a + from_b
35+
36+
37+
@serve.deployment
class UpstreamWithNestedHandles:
    """Upstream deployment whose handles arrive nested inside a dict and a list."""

    def __init__(self, handles_dict: dict, handles_list: list):
        # handles_dict looks like {"a": handle_a, "b": handle_b};
        # handles_list looks like [handle_a, handle_b].
        self.handles = handles_dict
        self.handle_list = handles_list

    async def __call__(self, x: int) -> int:
        from_a = await self.handles["a"].remote(x)
        from_b = await self.handles["b"].process.remote(x)
        return from_a + from_b
47+
48+
49+
@serve.deployment
class DynamicDeployment:
    """Upstream deployment that looks up its downstream handles on every call."""

    async def __call__(self, x: int, app_name1: str, app_name2: str) -> int:
        # Handles are obtained at request time rather than injected via
        # __init__, so they only exist after a request has been served.
        downstream_a = serve.get_deployment_handle("DownstreamA", app_name=app_name1)
        downstream_b = serve.get_deployment_handle("DownstreamB", app_name=app_name2)
        from_a = await downstream_a.remote(x)
        from_b = await downstream_b.process.remote(x)
        return from_a + from_b
57+
58+
59+
def get_replica_actor_handle(deployment_name: str, app_name: str):
    """Return the actor handle of a replica of the given deployment.

    The first named actor whose name starts with
    ``SERVE_REPLICA::{app_name}#{deployment_name}#`` is used.

    Raises:
        RuntimeError: If no such actor exists; the message lists the names
            of the actors containing "SERVE" to help diagnose the miss.
    """
    actors = ray.util.list_named_actors(all_namespaces=True)
    name_prefix = f"SERVE_REPLICA::{app_name}#{deployment_name}#"
    matching_names = (
        entry["name"] for entry in actors if entry["name"].startswith(name_prefix)
    )
    replica_actor_name = next(matching_names, None)

    if replica_actor_name is not None:
        return ray.get_actor(replica_actor_name, namespace=SERVE_NAMESPACE)

    # Debug: collect all serve actor names to help diagnose
    all_actors = [entry["name"] for entry in actors if "SERVE" in entry["name"]]
    raise RuntimeError(
        f"Could not find replica actor for {deployment_name} in app {app_name}. "
        f"Available serve actors: {all_actors}"
    )
77+
78+
79+
@pytest.mark.asyncio
class TestListOutboundDeployments:
    """Tests for the replica actor's list_outbound_deployments() method."""

    async def test_stored_handles_in_init(self, serve_instance):
        """Handles passed to __init__ and stored as attributes are listed."""
        app_name = "test_stored_handles"

        # Bind both downstream deployments into the upstream constructor.
        app = UpstreamWithStoredHandles.bind(DownstreamA.bind(), DownstreamB.bind())
        serve.run(app, name=app_name)

        replica = get_replica_actor_handle("UpstreamWithStoredHandles", app_name)
        outbound: List[DeploymentID] = ray.get(
            replica.list_outbound_deployments.remote()
        )

        # Both downstream deployments are reported, and nothing else.
        names = {deployment_id.name for deployment_id in outbound}
        assert "DownstreamA" in names
        assert "DownstreamB" in names
        assert len(outbound) == 2

        # Every reported deployment belongs to the app under test.
        assert all(deployment_id.app_name == app_name for deployment_id in outbound)

    async def test_nested_handles_in_dict_and_list(self, serve_instance):
        """Handles nested inside a dict and a list are listed."""
        app_name = "test_nested_handles"

        # The same two bound deployments appear in both containers.
        bound_a = DownstreamA.bind()
        bound_b = DownstreamB.bind()
        app = UpstreamWithNestedHandles.bind(
            {"a": bound_a, "b": bound_b},
            [bound_a, bound_b],
        )
        serve.run(app, name=app_name)

        replica = get_replica_actor_handle("UpstreamWithNestedHandles", app_name)
        outbound: List[DeploymentID] = ray.get(
            replica.list_outbound_deployments.remote()
        )

        # The handles are found despite being inside nested structures.
        names = {deployment_id.name for deployment_id in outbound}
        assert "DownstreamA" in names
        assert "DownstreamB" in names

        # No duplicates, even though each handle is in both the dict and list.
        assert len(outbound) == 2

    async def test_no_handles(self, serve_instance):
        """A deployment without outbound handles reports an empty list."""
        app_name = "test_no_handles"

        serve.run(DownstreamA.bind(), name=app_name)

        replica = get_replica_actor_handle("DownstreamA", app_name)
        outbound: List[DeploymentID] = ray.get(
            replica.list_outbound_deployments.remote()
        )

        assert len(outbound) == 0

    async def test_dynamic_handles(self, serve_instance):
        """Handles obtained via get_deployment_handle() at runtime are listed."""
        serve.run(DownstreamA.bind(), name="app1", route_prefix="/app1")
        serve.run(DownstreamB.bind(), name="app2", route_prefix="/app2")
        handle = serve.run(DynamicDeployment.bind(), name="app3", route_prefix="/app3")

        # Send requests so that the dynamic handles actually get created.
        # x=1: DownstreamA returns 1*2=2, DownstreamB returns 1+10=11, total=13
        results = [await handle.remote(1, "app1", "app2") for _ in range(10)]
        assert all(result == 13 for result in results)

        replica = get_replica_actor_handle("DynamicDeployment", "app3")
        outbound: List[DeploymentID] = ray.get(
            replica.list_outbound_deployments.remote()
        )

        # The dynamically created handles are included.
        names = {deployment_id.name for deployment_id in outbound}
        assert "DownstreamA" in names
        assert "DownstreamB" in names
        assert len(outbound) == 2

        # Each downstream is reported under the app it was deployed in.
        app_names = {deployment_id.app_name for deployment_id in outbound}
        assert "app1" in app_names
        assert "app2" in app_names
193+
194+
195+
if __name__ == "__main__":
    # Run this file's tests with pytest and exit with pytest's return value.
    exit_code = pytest.main(["-v", "-s", __file__])
    sys.exit(exit_code)

0 commit comments

Comments
 (0)