Skip to content

Commit 3bcc3a9

Browse files
committed
make AsyncSession and Session subclass SessionStartResponse for compat with autogen sdk docs
1 parent ab62392 commit 3bcc3a9

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

src/stagehand/resources/sessions_helpers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def create(
5959
extra_body=extra_body,
6060
timeout=timeout,
6161
)
62-
return Session(self._client, start_response.data.session_id)
62+
return Session(self._client, start_response.data.session_id, data=start_response.data, success=start_response.success)
6363

6464

6565
class AsyncSessionsResourceWithHelpers(AsyncSessionsResource):
@@ -86,7 +86,7 @@ async def create(
8686
extra_body: Body | None = None,
8787
timeout: float | httpx.Timeout | None | NotGiven = not_given,
8888
) -> AsyncSession:
89-
start_response = await self.start(
89+
start_response: SessionStartResponse = await self.start(
9090
model_name=model_name,
9191
act_timeout_ms=act_timeout_ms,
9292
browser=browser,
@@ -107,4 +107,4 @@ async def create(
107107
extra_body=extra_body,
108108
timeout=timeout,
109109
)
110-
return AsyncSession(self._client, start_response.data.session_id)
110+
return AsyncSession(self._client, start_response.data.session_id, data=start_response.data, success=start_response.success)

src/stagehand/session.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
session_navigate_params,
1717
)
1818
from ._types import Body, Omit, Query, Headers, NotGiven, omit, not_given
19+
from .types.session_start_response import SessionStartResponse, Data as SessionStartResponseData
1920
from .types.session_act_response import SessionActResponse
2021
from .types.session_end_response import SessionEndResponse
2122
from .types.session_execute_response import SessionExecuteResponse
@@ -36,12 +37,15 @@ def _with_default_frame_id(params: TSessionParams) -> TSessionParams:
3637
from ._client import Stagehand, AsyncStagehand
3738

3839

39-
class Session:
40+
class Session(SessionStartResponse):
4041
"""A Stagehand session bound to a specific `session_id`."""
4142

42-
def __init__(self, client: Stagehand, id: str) -> None:
43+
def __init__(self, client: Stagehand, id: str, data: SessionStartResponseData, success: bool) -> None:
4344
self._client = client
4445
self.id = id
46+
# in case user tries to use client.sessions.start(...) return value as a SessionStartResponse dataclass/dict
47+
super().__init__(data=data, success=success)
48+
4549

4650
def navigate(
4751
self,
@@ -158,12 +162,13 @@ def end(
158162
)
159163

160164

161-
class AsyncSession:
165+
class AsyncSession(SessionStartResponse):
162166
"""Async variant of `Session`."""
163167

164-
def __init__(self, client: AsyncStagehand, id: str) -> None:
168+
def __init__(self, client: AsyncStagehand, id: str, data: SessionStartResponseData, success: bool) -> None:
165169
self._client = client
166170
self.id = id
171+
super().__init__(data=data, success=success)
167172

168173
async def navigate(
169174
self,

0 commit comments

Comments
 (0)