Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check for pause more often #519

Merged
merged 1 commit into from
Feb 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions browser_use/agent/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,12 @@ def set_tool_calling_method(self, tool_calling_method: Optional[str]) -> Optiona
def add_new_task(self, new_task: str) -> None:
self.message_manager.add_new_task(new_task)

def _check_if_stopped_or_paused(self) -> bool:
if self._stopped or self._paused:
logger.debug('Agent paused after getting state')
raise InterruptedError
return False

@time_execution_async('--step')
async def step(self, step_info: Optional[AgentStepInfo] = None) -> None:
"""Execute one step of the task"""
Expand All @@ -246,13 +252,13 @@ async def step(self, step_info: Optional[AgentStepInfo] = None) -> None:
try:
state = await self.browser_context.get_state()

if self._stopped or self._paused:
logger.debug('Agent paused after getting state')
raise InterruptedError
self._check_if_stopped_or_paused()

self.message_manager.add_state_message(state, self._last_result, step_info, self.use_vision)
input_messages = self.message_manager.get_messages()

self._check_if_stopped_or_paused()

try:
model_output = await self.get_next_action(input_messages)

Expand All @@ -262,9 +268,7 @@ async def step(self, step_info: Optional[AgentStepInfo] = None) -> None:
self._save_conversation(input_messages, model_output)
self.message_manager._remove_last_state_message() # we dont want the whole state in the chat history

if self._stopped or self._paused:
logger.debug('Agent paused after getting next action')
raise InterruptedError
self._check_if_stopped_or_paused()

self.message_manager.add_model_output(model_output)
except Exception as e:
Expand All @@ -277,6 +281,7 @@ async def step(self, step_info: Optional[AgentStepInfo] = None) -> None:
self.browser_context,
page_extraction_llm=self.page_extraction_llm,
sensitive_data=self.sensitive_data,
check_break_if_paused=lambda: self._check_if_stopped_or_paused(),
)
self._last_result = result

Expand All @@ -287,6 +292,11 @@ async def step(self, step_info: Optional[AgentStepInfo] = None) -> None:

except InterruptedError:
logger.debug('Agent paused')
self._last_result = [
ActionResult(
error='The agent was paused - now continuing actions might need to be repeated', include_in_memory=True
)
]
return
except Exception as e:
result = await self._handle_step_error(e)
Expand Down Expand Up @@ -488,6 +498,7 @@ async def run(self, max_steps: int = 100) -> AgentHistoryList:
self.browser_context,
check_for_new_elements=False,
page_extraction_llm=self.page_extraction_llm,
check_break_if_paused=lambda: self._check_if_stopped_or_paused(),
)
self._last_result = result

Expand Down
10 changes: 9 additions & 1 deletion browser_use/controller/service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import json
import logging
from typing import Dict, Optional, Type
from typing import Callable, Dict, Optional, Type

from langchain_core.prompts import PromptTemplate
from pydantic import BaseModel
Expand Down Expand Up @@ -444,6 +444,7 @@ async def multi_act(
self,
actions: list[ActionModel],
browser_context: BrowserContext,
check_break_if_paused: Callable[[], bool],
check_for_new_elements: bool = True,
page_extraction_llm: Optional[BaseChatModel] = None,
sensitive_data: Optional[Dict[str, str]] = None,
Expand All @@ -454,9 +455,14 @@ async def multi_act(
session = await browser_context.get_session()
cached_selector_map = session.cached_state.selector_map
cached_path_hashes = set(e.hash.branch_path_hash for e in cached_selector_map.values())

check_break_if_paused()

await browser_context.remove_highlights()

for i, action in enumerate(actions):
check_break_if_paused()

if action.get_index() is not None and i != 0:
new_state = await browser_context.get_state()
new_path_hashes = set(e.hash.branch_path_hash for e in new_state.selector_map.values())
Expand All @@ -465,6 +471,8 @@ async def multi_act(
logger.info(f'Something new appeared after action {i} / {len(actions)}')
break

check_break_if_paused()

results.append(await self.act(action, browser_context, page_extraction_llm, sensitive_data))

logger.debug(f'Executed action {i + 1} / {len(actions)}')
Expand Down
3 changes: 2 additions & 1 deletion examples/features/pause_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ class AgentController:
def __init__(self):
llm = ChatOpenAI(model='gpt-4o')
self.agent = Agent(
task="Go to wikipedia.org and search for 'Python programming language', then read the first paragraph", llm=llm
task='open in one action https://www.google.com, https://www.wikipedia.org, https://www.youtube.com, https://www.github.com, https://amazon.com',
llm=llm,
)
self.running = False

Expand Down