-
Notifications
You must be signed in to change notification settings - Fork 3.3k
fix: propagate session ID to LLM request headers in rename task #5624
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -873,14 +873,17 @@ impl Agent { | |
| let reply_span = tracing::Span::current(); | ||
| self.reset_retry_attempts().await; | ||
|
|
||
| let working_dir = session.working_dir.clone(); | ||
codefromthecrypt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| let provider = self.provider().await?; | ||
| let session_id = session_config.id.clone(); | ||
| let working_dir = session.working_dir.clone(); | ||
| tokio::spawn(async move { | ||
| if let Err(e) = SessionManager::maybe_update_name(&session_id, provider).await { | ||
| warn!("Failed to generate session description: {}", e); | ||
| } | ||
| }); | ||
| let naming_handle = tokio::spawn(crate::session_context::with_session_id( | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @alexhancock so this code and the test below makes sure the session rename task is also grouped by the session.id, which ack is not something people always care about, but if you think systemically there's probably no value in intentionally orphaning it. I noticed your ACP pr, if it top-levels the session ID stuff, will end up needing a similar line in spawn. #5657 |
||
| Some(session_id.clone()), | ||
| async move { | ||
| if let Err(e) = SessionManager::maybe_update_name(&session_id, provider).await { | ||
| warn!("Failed to generate session description: {}", e); | ||
| } | ||
| }, | ||
| )); | ||
|
|
||
| Ok(Box::pin(async_stream::try_stream! { | ||
| let _ = reply_span.enter(); | ||
|
|
@@ -1220,6 +1223,8 @@ impl Agent { | |
|
|
||
| tokio::task::yield_now().await; | ||
| } | ||
|
|
||
| let _ = naming_handle.await; | ||
| })) | ||
| } | ||
|
|
||
|
|
@@ -1528,7 +1533,13 @@ impl Agent { | |
| #[cfg(test)] | ||
| mod tests { | ||
| use super::*; | ||
| use crate::model::ModelConfig; | ||
| use crate::providers::base::ProviderUsage; | ||
| use crate::recipe::Response; | ||
| use crate::session::session_manager::SessionType; | ||
| use crate::session::SessionManager; | ||
| use async_trait::async_trait; | ||
| use test_case::test_case; | ||
|
|
||
| #[tokio::test] | ||
| async fn test_add_final_output_tool() -> Result<()> { | ||
|
|
@@ -1587,4 +1598,144 @@ mod tests { | |
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| enum NamingBehavior { | ||
| Success, | ||
| Error, | ||
| Panic, | ||
| } | ||
|
|
||
| #[derive(Clone)] | ||
| struct MockNamingProvider { | ||
| model_config: ModelConfig, | ||
| behavior: Arc<std::sync::Mutex<NamingBehavior>>, | ||
| captured_session_id: Arc<std::sync::Mutex<Option<String>>>, | ||
| } | ||
|
|
||
| impl MockNamingProvider { | ||
| fn new(behavior: NamingBehavior) -> Self { | ||
| Self { | ||
| model_config: ModelConfig::new_or_fail("test-model"), | ||
| behavior: Arc::new(std::sync::Mutex::new(behavior)), | ||
| captured_session_id: Arc::new(std::sync::Mutex::new(None)), | ||
| } | ||
| } | ||
|
|
||
| fn get_captured_session_id(&self) -> Option<String> { | ||
| self.captured_session_id.lock().unwrap().clone() | ||
| } | ||
| } | ||
|
|
||
| #[async_trait] | ||
| impl Provider for MockNamingProvider { | ||
| fn metadata() -> crate::providers::base::ProviderMetadata { | ||
| crate::providers::base::ProviderMetadata::empty() | ||
| } | ||
|
|
||
| fn get_name(&self) -> &str { | ||
| "mock" | ||
| } | ||
|
|
||
| async fn complete_with_model( | ||
| &self, | ||
| _model_config: &ModelConfig, | ||
| _system: &str, | ||
| _messages: &[Message], | ||
| _tools: &[Tool], | ||
| ) -> anyhow::Result<(Message, ProviderUsage), ProviderError> { | ||
| Ok(( | ||
| Message::assistant().with_text("Response"), | ||
| ProviderUsage::new("mock".to_string(), Default::default()), | ||
| )) | ||
| } | ||
|
|
||
| async fn generate_session_name( | ||
| &self, | ||
| _messages: &Conversation, | ||
| ) -> Result<String, ProviderError> { | ||
| *self.captured_session_id.lock().unwrap() = | ||
| crate::session_context::current_session_id(); | ||
|
|
||
| let behavior = self.behavior.lock().unwrap(); | ||
| match *behavior { | ||
| NamingBehavior::Success => Ok("Generated Name".to_string()), | ||
| NamingBehavior::Error => { | ||
| Err(ProviderError::RequestFailed("naming failed".to_string())) | ||
| } | ||
| NamingBehavior::Panic => panic!("naming panicked"), | ||
| } | ||
| } | ||
|
|
||
| fn get_model_config(&self) -> ModelConfig { | ||
| self.model_config.clone() | ||
| } | ||
| } | ||
|
|
||
| // Verifies session ID is visible in generate_session_name when user hasn't provided a name. | ||
| // When user provides name, maybe_update_name early-returns and session ID isn't captured. | ||
| #[test_case(NamingBehavior::Success, None, "Generated Name", true)] | ||
| #[test_case(NamingBehavior::Error, None, "initial", true)] | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. error and panic cases here are just to make sure our ignoring the same doesn't accidentally crash the critical code. we didn't have a problem so far, but it is easy to create one, so these cases help IMHO |
||
| #[test_case(NamingBehavior::Panic, None, "initial", true)] | ||
| #[test_case( | ||
| NamingBehavior::Success, | ||
| Some("my-custom-name"), | ||
| "my-custom-name", | ||
| false | ||
| )] | ||
| #[tokio::test] | ||
| async fn test_session_naming( | ||
| behavior: NamingBehavior, | ||
| user_provided_name: Option<&str>, | ||
| expected_name: &str, | ||
| should_capture_session_id: bool, | ||
| ) { | ||
| let provider = Arc::new(MockNamingProvider::new(behavior)); | ||
| let agent = Agent::new(); | ||
| agent.update_provider(provider.clone()).await.unwrap(); | ||
|
|
||
| let session = SessionManager::create_session( | ||
| std::env::current_dir().unwrap(), | ||
| user_provided_name.unwrap_or("initial").to_string(), | ||
| SessionType::User, | ||
| ) | ||
| .await | ||
| .unwrap(); | ||
|
|
||
| if let Some(name) = user_provided_name { | ||
| SessionManager::update_session(&session.id) | ||
| .user_provided_name(name) | ||
| .apply() | ||
| .await | ||
| .unwrap(); | ||
| } | ||
|
|
||
| let stream = agent | ||
| .reply( | ||
| Message::user().with_text("test"), | ||
| SessionConfig { | ||
| id: session.id.clone(), | ||
| schedule_id: None, | ||
| max_turns: Some(1), | ||
| retry_config: None, | ||
| }, | ||
| None, | ||
| ) | ||
| .await | ||
| .unwrap(); | ||
| tokio::pin!(stream); | ||
| while stream.next().await.is_some() {} | ||
|
|
||
| let session = SessionManager::get_session(&session.id, false) | ||
| .await | ||
| .unwrap(); | ||
| assert_eq!(session.name, expected_name); | ||
|
|
||
| if should_capture_session_id { | ||
| assert_eq!(provider.get_captured_session_id(), Some(session.id.clone())); | ||
| } else { | ||
| assert_eq!(provider.get_captured_session_id(), None); | ||
| } | ||
|
|
||
| SessionManager::delete_session(&session.id).await.unwrap(); | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.