Skip to content

Commit f22259b

Browse files
authored
[serve] Add new DeploymentHandle type that returns DeploymentResponse (#37817)
Adds new DeploymentHandle that removes the need for "double await" and also unifies the sync & async handle types. See docstrings for detailed API information. The new API can be used either by setting `RAY_SERVE_ENABLE_NEW_HANDLE_API=1` at the cluster level or by using `handle.options(use_new_handle_api=True)`. The migration plan is as follows: - In Ray 2.7, we will be fully backwards-compatible and users will need to opt-in to use the new API (via environment variable or .options() flag). - The documentation and other docstrings will be updated to push users to use the new handle type and warn that the existing one will be deprecated. - Starting in Ray 2.8, we will print deprecation warnings when users use the existing handle API. - In Ray 3.0, the default will be changed to use the new handle.
1 parent f2a9ca8 commit f22259b

26 files changed

+1405
-293
lines changed

python/ray/_private/usage/usage_lib.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,15 @@ def put_pre_init_usage_stats():
269269
_put_pre_init_extra_usage_tags()
270270

271271

272+
def reset_global_state():
273+
global _recorded_library_usages, _recorded_extra_usage_tags
274+
275+
with _recorded_library_usages_lock:
276+
_recorded_library_usages = set()
277+
with _recorded_extra_usage_tags_lock:
278+
_recorded_extra_usage_tags = dict()
279+
280+
272281
ray._private.worker._post_init_hooks.append(put_pre_init_usage_stats)
273282

274283

python/ray/serve/BUILD

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ py_test(
171171

172172
py_test(
173173
name = "test_telemetry",
174-
size = "medium",
174+
size = "large",
175175
srcs = serve_tests_srcs,
176176
tags = ["exclusive", "team:serve"],
177177
deps = [":serve_lib"],
@@ -284,6 +284,25 @@ py_test(
284284
)
285285

286286

287+
py_test(
288+
name = "test_new_handle_api",
289+
size = "medium",
290+
srcs = serve_tests_srcs,
291+
tags = ["exclusive", "team:serve"],
292+
deps = [":serve_lib"],
293+
)
294+
295+
296+
py_test(
297+
name = "test_new_handle_api_set_via_env_var",
298+
size = "medium",
299+
srcs = serve_tests_srcs,
300+
tags = ["exclusive", "team:serve"],
301+
deps = [":serve_lib"],
302+
env = {"RAY_SERVE_ENABLE_NEW_HANDLE_API": "1"},
303+
)
304+
305+
287306
py_test(
288307
name = "test_kv_store",
289308
size = "small",

python/ray/serve/_private/client.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
CLIENT_POLLING_INTERVAL_S,
2222
CLIENT_CHECK_CREATION_POLLING_INTERVAL_S,
2323
MAX_CACHED_HANDLES,
24+
RAY_SERVE_ENABLE_NEW_HANDLE_API,
2425
SERVE_DEFAULT_APP_NAME,
2526
)
2627
from ray.serve._private.deploy_utils import get_deploy_args
@@ -31,7 +32,7 @@
3132
from ray.serve.generated.serve_pb2 import (
3233
DeploymentStatusInfo as DeploymentStatusInfoProto,
3334
)
34-
from ray.serve.handle import RayServeHandle, RayServeSyncHandle
35+
from ray.serve.handle import DeploymentHandle, RayServeHandle, RayServeSyncHandle
3536
from ray.serve.schema import ServeApplicationSchema, ServeDeploySchema
3637

3738
logger = logging.getLogger(__file__)
@@ -495,10 +496,24 @@ def get_handle(
495496
"exist."
496497
)
497498

498-
if sync:
499-
handle = RayServeSyncHandle(deployment_name, app_name)
499+
if RAY_SERVE_ENABLE_NEW_HANDLE_API:
500+
handle = DeploymentHandle(
501+
deployment_name,
502+
app_name,
503+
_is_for_sync_context=sync,
504+
)
505+
elif sync:
506+
handle = RayServeSyncHandle(
507+
deployment_name,
508+
app_name,
509+
_is_for_sync_context=sync,
510+
)
500511
else:
501-
handle = RayServeHandle(deployment_name, app_name)
512+
handle = RayServeHandle(
513+
deployment_name,
514+
app_name,
515+
_is_for_sync_context=sync,
516+
)
502517

503518
self.handle_cache[cache_key] = handle
504519
if cache_key in self._evicted_handle_keys:

python/ray/serve/_private/constants.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,11 @@
221221
or RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING
222222
)
223223

224+
# Feature flag to enable new handle API.
225+
RAY_SERVE_ENABLE_NEW_HANDLE_API = (
226+
os.environ.get("RAY_SERVE_ENABLE_NEW_HANDLE_API", "0") == "1"
227+
)
228+
224229
# Serve HTTP proxy callback import path.
225230
RAY_SERVE_HTTP_PROXY_CALLBACK_IMPORT_PATH = os.environ.get(
226231
"RAY_SERVE_HTTP_PROXY_CALLBACK_IMPORT_PATH", None

python/ray/serve/_private/deployment_function_node.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22

33
from ray.dag.dag_node import DAGNode
44
from ray.dag.format_utils import get_dag_node_str
5+
56
from ray.serve.deployment import Deployment, schema_to_deployment
67
from ray.serve.config import DeploymentConfig, ReplicaConfig
7-
from ray.serve.handle import RayServeHandle
8+
from ray.serve.handle import DeploymentHandle, RayServeHandle
89
from ray.serve.schema import DeploymentSchema
10+
from ray.serve._private.constants import RAY_SERVE_ENABLE_NEW_HANDLE_API
911

1012

1113
class DeploymentFunctionNode(DAGNode):
@@ -68,7 +70,14 @@ def __init__(
6870
_internal=True,
6971
)
7072

71-
self._deployment_handle = RayServeHandle(self._deployment.name, self._app_name)
73+
if RAY_SERVE_ENABLE_NEW_HANDLE_API:
74+
self._deployment_handle = DeploymentHandle(
75+
self._deployment.name, self._app_name
76+
)
77+
else:
78+
self._deployment_handle = RayServeHandle(
79+
self._deployment.name, self._app_name
80+
)
7281

7382
def _copy_impl(
7483
self,

python/ray/serve/_private/deployment_graph_build.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66

77
from ray.serve.deployment import Deployment, schema_to_deployment
88
from ray.serve.deployment_graph import RayServeDAGHandle
9-
from ray.serve._private.constants import SERVE_DEFAULT_APP_NAME
9+
from ray.serve._private.constants import (
10+
RAY_SERVE_ENABLE_NEW_HANDLE_API,
11+
SERVE_DEFAULT_APP_NAME,
12+
)
1013
from ray.serve._private.deployment_method_node import DeploymentMethodNode
1114
from ray.serve._private.deployment_node import DeploymentNode
1215
from ray.serve._private.deployment_function_node import DeploymentFunctionNode
@@ -17,7 +20,7 @@
1720
from ray.serve._private.deployment_function_executor_node import (
1821
DeploymentFunctionExecutorNode,
1922
)
20-
from ray.serve.handle import RayServeHandle
23+
from ray.serve.handle import DeploymentHandle, RayServeHandle
2124
from ray.serve.schema import DeploymentSchema
2225

2326

@@ -177,7 +180,10 @@ def replace_with_handle(node):
177180
if isinstance(node, DeploymentNode) or isinstance(
178181
node, DeploymentFunctionNode
179182
):
180-
return RayServeHandle(node._deployment.name, app_name)
183+
if RAY_SERVE_ENABLE_NEW_HANDLE_API:
184+
return DeploymentHandle(node._deployment.name, app_name)
185+
else:
186+
return RayServeHandle(node._deployment.name, app_name)
181187
elif isinstance(node, DeploymentExecutorNode):
182188
return node._deployment_handle
183189

python/ray/serve/_private/deployment_node.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from typing import Any, Dict, Optional, List, Tuple
22

33
from ray.dag import DAGNode
4-
from ray.serve.handle import RayServeHandle
5-
64
from ray.dag.constants import PARENT_CLASS_NODE_KEY
75
from ray.dag.format_utils import get_dag_node_str
8-
from ray.serve._private.deployment_method_node import DeploymentMethodNode
6+
97
from ray.serve.deployment import Deployment
8+
from ray.serve.handle import DeploymentHandle, RayServeHandle
9+
from ray.serve._private.constants import RAY_SERVE_ENABLE_NEW_HANDLE_API
10+
from ray.serve._private.deployment_method_node import DeploymentMethodNode
1011

1112

1213
class DeploymentNode(DAGNode):
@@ -30,9 +31,16 @@ def __init__(
3031
ray_actor_options,
3132
other_args_to_resolve=other_args_to_resolve,
3233
)
33-
self._deployment = deployment
3434
self._app_name = app_name
35-
self._deployment_handle = RayServeHandle(self._deployment.name, app_name)
35+
self._deployment = deployment
36+
if RAY_SERVE_ENABLE_NEW_HANDLE_API:
37+
self._deployment_handle = DeploymentHandle(
38+
self._deployment.name, self._app_name
39+
)
40+
else:
41+
self._deployment_handle = RayServeHandle(
42+
self._deployment.name, self._app_name
43+
)
3644

3745
def _copy_impl(
3846
self,

python/ray/serve/_private/http_proxy.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -501,21 +501,24 @@ async def _assign_request_with_timeout(
501501
`disconnected_task` is expected to be done if the client disconnects; in this
502502
case, we will abort assigning a replica and return `None`.
503503
"""
504-
assignment_task = handle.remote(
504+
result_gen = handle.remote(
505505
StreamingHTTPRequest(pickle.dumps(scope), self.self_actor_handle)
506506
)
507+
to_object_ref_gen = asyncio.ensure_future(
508+
result_gen._to_object_ref_gen(_record_telemetry=False)
509+
)
507510
done, _ = await asyncio.wait(
508-
[assignment_task, disconnected_task],
511+
[to_object_ref_gen, disconnected_task],
509512
return_when=FIRST_COMPLETED,
510513
timeout=timeout_s,
511514
)
512-
if assignment_task in done:
513-
return assignment_task.result()
515+
if to_object_ref_gen in done:
516+
return to_object_ref_gen.result()
514517
elif disconnected_task in done:
515-
assignment_task.cancel()
518+
result_gen.cancel()
516519
return None
517520
else:
518-
assignment_task.cancel()
521+
result_gen.cancel()
519522
raise TimeoutError()
520523

521524
@abstractmethod
@@ -650,10 +653,13 @@ async def send_request_to_replica_unary(
650653
# call might never arrive; if it does, it can only be `http.disconnect`.
651654
while retries < HTTP_REQUEST_MAX_RETRIES + 1:
652655
should_backoff = False
653-
assignment_task: asyncio.Task = handle.remote(request)
656+
result_ref = handle.remote(request)
654657
client_disconnection_task = loop.create_task(receive())
655658
done, _ = await asyncio.wait(
656-
[assignment_task, client_disconnection_task],
659+
[
660+
result_ref._to_object_ref(_record_telemetry=False),
661+
client_disconnection_task,
662+
],
657663
return_when=FIRST_COMPLETED,
658664
)
659665
if client_disconnection_task in done:
@@ -667,14 +673,11 @@ async def send_request_to_replica_unary(
667673
"request.",
668674
extra={"log_to_stderr": False},
669675
)
670-
# This will make the .result() to raise cancelled error.
671-
assignment_task.cancel()
676+
result_ref.cancel()
672677
else:
673678
client_disconnection_task.cancel()
674679

675680
try:
676-
object_ref = await assignment_task
677-
678681
# NOTE (shrekris-anyscale): when the gcs, Serve controller, and
679682
# some replicas crash simultaneously (e.g. if the head node crashes),
680683
# requests to the dead replicas hang until the gcs recovers.
@@ -683,7 +686,7 @@ async def send_request_to_replica_unary(
683686
# check if latency drops significantly. See
684687
# https://github.com/ray-project/ray/pull/29534 for more info.
685688
_, request_timed_out = await asyncio.wait(
686-
[object_ref], timeout=self.request_timeout_s
689+
[result_ref], timeout=self.request_timeout_s
687690
)
688691
if request_timed_out:
689692
logger.info(
@@ -694,7 +697,7 @@ async def send_request_to_replica_unary(
694697
)
695698
should_backoff = True
696699
else:
697-
result = await object_ref
700+
result = await result_ref
698701
break
699702
except asyncio.CancelledError:
700703
# Here because the client disconnected, we will return a custom
@@ -932,7 +935,6 @@ def __init__(self, app):
932935
self.app = app
933936

934937
async def __call__(self, scope, receive, send):
935-
936938
headers = MutableHeaders(scope=scope)
937939
if RAY_SERVE_REQUEST_ID_HEADER not in headers and "x-request-id" not in headers:
938940
# If X-Request-ID and RAY_SERVE_REQUEST_ID_HEADER are both not set, we

python/ray/serve/_private/proxy_router.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def update_routes(self, endpoints: Dict[EndpointTag, EndpointInfo]) -> None:
6363
RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING
6464
and not info.app_is_cross_language
6565
),
66+
use_new_handle_api=True,
6667
)
6768

6869
# Clean up any handles that are no longer used.
@@ -147,6 +148,7 @@ def update_routes(self, endpoints: Dict[EndpointTag, EndpointInfo]):
147148
RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING
148149
and not info.app_is_cross_language
149150
),
151+
use_new_handle_api=True,
150152
)
151153

152154
# Clean up any handles that are no longer used.

python/ray/serve/_private/router.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,11 @@ class Query:
9797
metadata: RequestMetadata
9898

9999
async def resolve_async_tasks(self):
100-
"""Find all unresolved asyncio.Task and gather them all at once."""
100+
"""Find all unresolved asyncio.Task and gather them all at once.
101+
102+
This is used for the old serve handle API and should be removed once that API
103+
is fully deprecated & removed.
104+
"""
101105
scanner = _PyObjScanner(source_type=asyncio.Task)
102106

103107
try:
@@ -111,6 +115,38 @@ async def resolve_async_tasks(self):
111115
# Make the scanner GC-able to avoid memory leaks.
112116
scanner.clear()
113117

118+
async def resolve_deployment_handle_results_to_object_refs(self):
119+
"""Replace DeploymentHandleResults with their resolved ObjectRefs.
120+
121+
DeploymentResponseGenerators are rejected (not currently supported).
122+
"""
123+
from ray.serve.handle import (
124+
_DeploymentResponseBase,
125+
DeploymentResponseGenerator,
126+
)
127+
128+
scanner = _PyObjScanner(source_type=_DeploymentResponseBase)
129+
130+
try:
131+
result_to_object_ref_coros = []
132+
results = scanner.find_nodes((self.args, self.kwargs))
133+
for result in results:
134+
result_to_object_ref_coros.append(result._to_object_ref())
135+
if isinstance(result, DeploymentResponseGenerator):
136+
raise RuntimeError(
137+
"Streaming deployment handle results cannot be passed to "
138+
"downstream handle calls. If you have a use case requiring "
139+
"this feature, please file a feature request on GitHub."
140+
)
141+
142+
if len(results) > 0:
143+
obj_refs = await asyncio.gather(*result_to_object_ref_coros)
144+
replacement_table = dict(zip(results, obj_refs))
145+
self.args, self.kwargs = scanner.replace_nodes(replacement_table)
146+
finally:
147+
# Make the scanner GC-able to avoid memory leaks.
148+
scanner.clear()
149+
114150
async def buffer_starlette_requests_and_warn(self):
115151
"""Buffer any `starlette.request.Requests` objects to make them serializable.
116152
@@ -1109,10 +1145,10 @@ async def assign_request(
11091145
metadata=request_meta,
11101146
)
11111147
await query.resolve_async_tasks()
1148+
await query.resolve_deployment_handle_results_to_object_refs()
11121149
await query.buffer_starlette_requests_and_warn()
1113-
result = await self._replica_scheduler.assign_replica(query)
11141150

1115-
return result
1151+
return await self._replica_scheduler.assign_replica(query)
11161152
finally:
11171153
# If the query is disconnected before assignment, this coroutine
11181154
# gets cancelled by the caller and an asyncio.CancelledError is

0 commit comments

Comments
 (0)