Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 157 additions & 6 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -873,14 +873,17 @@ impl Agent {
let reply_span = tracing::Span::current();
self.reset_retry_attempts().await;

let working_dir = session.working_dir.clone();
let provider = self.provider().await?;
let session_id = session_config.id.clone();
let working_dir = session.working_dir.clone();
tokio::spawn(async move {
if let Err(e) = SessionManager::maybe_update_name(&session_id, provider).await {
warn!("Failed to generate session description: {}", e);
}
});
let naming_handle = tokio::spawn(crate::session_context::with_session_id(
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alexhancock so this code and the test below makes sure the session rename task is also grouped by the session.id, which ack is not something people always care about, but if you think systemically there's probably no value in intentionally orphaning it. I noticed your ACP pr, if it top-levels the session ID stuff, will end up needing a similar line in spawn. #5657

Some(session_id.clone()),
async move {
if let Err(e) = SessionManager::maybe_update_name(&session_id, provider).await {
warn!("Failed to generate session description: {}", e);
}
},
));

Ok(Box::pin(async_stream::try_stream! {
let _ = reply_span.enter();
Expand Down Expand Up @@ -1220,6 +1223,8 @@ impl Agent {

tokio::task::yield_now().await;
}

let _ = naming_handle.await;
}))
}

Expand Down Expand Up @@ -1528,7 +1533,13 @@ impl Agent {
#[cfg(test)]
mod tests {
use super::*;
use crate::model::ModelConfig;
use crate::providers::base::ProviderUsage;
use crate::recipe::Response;
use crate::session::session_manager::SessionType;
use crate::session::SessionManager;
use async_trait::async_trait;
use test_case::test_case;

#[tokio::test]
async fn test_add_final_output_tool() -> Result<()> {
Expand Down Expand Up @@ -1587,4 +1598,144 @@ mod tests {

Ok(())
}

enum NamingBehavior {
Success,
Error,
Panic,
}

#[derive(Clone)]
struct MockNamingProvider {
model_config: ModelConfig,
behavior: Arc<std::sync::Mutex<NamingBehavior>>,
captured_session_id: Arc<std::sync::Mutex<Option<String>>>,
}

impl MockNamingProvider {
fn new(behavior: NamingBehavior) -> Self {
Self {
model_config: ModelConfig::new_or_fail("test-model"),
behavior: Arc::new(std::sync::Mutex::new(behavior)),
captured_session_id: Arc::new(std::sync::Mutex::new(None)),
}
}

fn get_captured_session_id(&self) -> Option<String> {
self.captured_session_id.lock().unwrap().clone()
}
}

#[async_trait]
impl Provider for MockNamingProvider {
fn metadata() -> crate::providers::base::ProviderMetadata {
crate::providers::base::ProviderMetadata::empty()
}

fn get_name(&self) -> &str {
"mock"
}

async fn complete_with_model(
&self,
_model_config: &ModelConfig,
_system: &str,
_messages: &[Message],
_tools: &[Tool],
) -> anyhow::Result<(Message, ProviderUsage), ProviderError> {
Ok((
Message::assistant().with_text("Response"),
ProviderUsage::new("mock".to_string(), Default::default()),
))
}

async fn generate_session_name(
&self,
_messages: &Conversation,
) -> Result<String, ProviderError> {
*self.captured_session_id.lock().unwrap() =
crate::session_context::current_session_id();

let behavior = self.behavior.lock().unwrap();
match *behavior {
NamingBehavior::Success => Ok("Generated Name".to_string()),
NamingBehavior::Error => {
Err(ProviderError::RequestFailed("naming failed".to_string()))
}
NamingBehavior::Panic => panic!("naming panicked"),
}
}

fn get_model_config(&self) -> ModelConfig {
self.model_config.clone()
}
}

// Verifies session ID is visible in generate_session_name when user hasn't provided a name.
// When user provides name, maybe_update_name early-returns and session ID isn't captured.
#[test_case(NamingBehavior::Success, None, "Generated Name", true)]
#[test_case(NamingBehavior::Error, None, "initial", true)]
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

error and panic cases here are just to make sure our ignoring the same doesn't accidentally crash the critical code. we didn't have a problem so far, but it is easy to create one, so these cases help IMHO

#[test_case(NamingBehavior::Panic, None, "initial", true)]
#[test_case(
NamingBehavior::Success,
Some("my-custom-name"),
"my-custom-name",
false
)]
#[tokio::test]
async fn test_session_naming(
behavior: NamingBehavior,
user_provided_name: Option<&str>,
expected_name: &str,
should_capture_session_id: bool,
) {
let provider = Arc::new(MockNamingProvider::new(behavior));
let agent = Agent::new();
agent.update_provider(provider.clone()).await.unwrap();

let session = SessionManager::create_session(
std::env::current_dir().unwrap(),
user_provided_name.unwrap_or("initial").to_string(),
SessionType::User,
)
.await
.unwrap();

if let Some(name) = user_provided_name {
SessionManager::update_session(&session.id)
.user_provided_name(name)
.apply()
.await
.unwrap();
}

let stream = agent
.reply(
Message::user().with_text("test"),
SessionConfig {
id: session.id.clone(),
schedule_id: None,
max_turns: Some(1),
retry_config: None,
},
None,
)
.await
.unwrap();
tokio::pin!(stream);
while stream.next().await.is_some() {}

let session = SessionManager::get_session(&session.id, false)
.await
.unwrap();
assert_eq!(session.name, expected_name);

if should_capture_session_id {
assert_eq!(provider.get_captured_session_id(), Some(session.id.clone()));
} else {
assert_eq!(provider.get_captured_session_id(), None);
}

SessionManager::delete_session(&session.id).await.unwrap();
}
}
83 changes: 83 additions & 0 deletions crates/goose/tests/session_id_propagation_test.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use goose::agents::{Agent, SessionConfig};
use goose::conversation::message::Message;
use goose::model::ModelConfig;
use goose::providers::api_client::{ApiClient, AuthMethod};
use goose::providers::base::Provider;
use goose::providers::openai::OpenAiProvider;
use goose::session::session_manager::SessionType;
use goose::session::SessionManager;
use goose::session_context;
use goose::session_context::SESSION_ID_HEADER;
use serde_json::json;
Expand Down Expand Up @@ -153,3 +156,83 @@ async fn test_different_sessions_have_different_ids() {
]
);
}

#[tokio::test]
async fn test_session_id_propagation_in_rename_task() {
let mock_server = MockServer::start().await;
let capture = HeaderCapture::new();
let capture_clone = capture.clone();

Mock::given(method("POST"))
.and(path("/v1/chat/completions"))
.respond_with(move |req: &Request| {
capture_clone.capture_session_header(req);
ResponseTemplate::new(200).set_body_json(json!({
"choices": [{
"finish_reason": "stop",
"index": 0,
"message": {
"content": "Test response",
"role": "assistant"
}
}],
"created": 1755133833,
"id": "chatcmpl-test",
"model": "gpt-5-nano",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 8,
"total_tokens": 18
}
}))
})
.mount(&mock_server)
.await;

let api_client = ApiClient::new(
mock_server.uri(),
AuthMethod::BearerToken("test-key".to_string()),
)
.unwrap();
let model = ModelConfig::new_or_fail("gpt-5-nano");
let provider = Arc::new(OpenAiProvider::new(api_client, model));

let agent = Agent::new();
agent.update_provider(provider).await.unwrap();

let session = SessionManager::create_session(
std::env::current_dir().unwrap(),
"initial".to_string(),
SessionType::User,
)
.await
.unwrap();

session_context::with_session_id(Some(session.id.clone()), async {
let stream = agent
.reply(
Message::user().with_text("test"),
SessionConfig {
id: session.id.clone(),
schedule_id: None,
max_turns: Some(1),
retry_config: None,
},
None,
)
.await
.unwrap();

use futures::StreamExt;
tokio::pin!(stream);
while stream.next().await.is_some() {}
})
.await;

let captured = capture.get_captured();
assert_eq!(captured.len(), 2);
assert_eq!(captured[0], Some(session.id.clone()));
assert_eq!(captured[1], Some(session.id.clone()));

SessionManager::delete_session(&session.id).await.unwrap();
}