Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 197 additions & 19 deletions crates/goose-cli/src/commands/configure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1778,6 +1778,195 @@ pub async fn handle_tetrate_auth() -> anyhow::Result<()> {
Ok(())
}

/// Attempts to fetch the list of available models from a custom provider
/// endpoint by standing up a throwaway declarative provider config.
///
/// * `provider_type` — one of `"openai_compatible"`, `"anthropic_compatible"`,
///   or `"ollama_compatible"`; anything else is rejected.
/// * `api_url` — the provider's base URL; scheme, host, optional port, and
///   path are re-assembled into the provider's `base_url`.
/// * `api_key` — optional key; `"not-required"` is substituted when absent so
///   keyless local providers still work.
///
/// Returns the non-empty model list on success. Errors on an invalid URL or
/// provider type, on provider failures, or when no models are returned.
async fn try_fetch_custom_provider_models(
    provider_type: &str,
    api_url: &str,
    api_key: Option<&str>,
) -> Result<Vec<String>, anyhow::Error> {
    use goose::config::declarative_providers::{DeclarativeProviderConfig, ProviderEngine};
    use goose::model::ModelConfig;
    use goose::providers::anthropic::AnthropicProvider;
    use goose::providers::base::{ModelInfo, Provider};
    use goose::providers::ollama::OllamaProvider;
    use goose::providers::openai::OpenAiProvider;

    // Resolve the engine up front so an invalid provider type fails before we
    // write anything to the secret store (the original leaked TEMP_FETCH_KEY
    // on this path).
    let engine = match provider_type {
        "openai_compatible" => ProviderEngine::OpenAI,
        "anthropic_compatible" => ProviderEngine::Anthropic,
        "ollama_compatible" => ProviderEngine::Ollama,
        _ => return Err(anyhow::anyhow!("Invalid provider type")),
    };

    // Rebuild "scheme://host[:port]" from the parsed URL; the path becomes the
    // provider's base path.
    let url = url::Url::parse(api_url)?;
    let host_str = url
        .host_str()
        .ok_or_else(|| anyhow::anyhow!("Invalid URL"))?;
    let host = match url.port() {
        Some(port) => format!("{}://{}:{}", url.scheme(), host_str, port),
        None => format!("{}://{}", url.scheme(), host_str),
    };
    let base_path = url.path().trim_start_matches('/').to_string();

    let key = api_key.unwrap_or("not-required").to_string();
    let config = Config::global();
    config.set_secret("TEMP_FETCH_KEY", &key)?;

    // Every exit path from here must clear TEMP_FETCH_KEY, so the fallible
    // work is collected into `result` instead of using early `?` returns
    // (which previously skipped the cleanup below).
    let result = async {
        let temp_config = DeclarativeProviderConfig {
            name: "temp_fetch".to_string(),
            engine,
            display_name: "Temporary".to_string(),
            description: None,
            api_key_env: "TEMP_FETCH_KEY".to_string(),
            base_url: format!("{}/{}", host, base_path),
            models: vec![ModelInfo::new("temp", 128000)],
            headers: None,
            timeout_seconds: Some(30),
            supports_streaming: None,
        };

        let model_config = ModelConfig::new("temp")?;
        let fetched = match temp_config.engine {
            ProviderEngine::OpenAI => {
                OpenAiProvider::from_custom_config(model_config, temp_config)?
                    .fetch_supported_models()
                    .await
            }
            ProviderEngine::Anthropic => {
                AnthropicProvider::from_custom_config(model_config, temp_config)?
                    .fetch_supported_models()
                    .await
            }
            ProviderEngine::Ollama => {
                OllamaProvider::from_custom_config(model_config, temp_config)?
                    .fetch_supported_models()
                    .await
            }
        };

        match fetched {
            Ok(Some(models)) if !models.is_empty() => Ok(models),
            Ok(_) => Err(anyhow::anyhow!("No models returned")),
            Err(e) => Err(anyhow::anyhow!("Provider error: {}", e)),
        }
    }
    .await;

    // Best-effort cleanup of the temporary secret on every path.
    let _ = config.delete_secret("TEMP_FETCH_KEY");

    result
}

/// Tries to fetch a custom provider's model list, first without an API key
/// and then — if the user agrees — once more with a key they supply.
///
/// Returns `(models, api_key)`: `models` is `None` when both attempts failed
/// and the user must enter models manually; `api_key` is the key the user
/// typed for the retry, or an empty string when no retry happened.
fn fetch_models_with_retry(
    provider_type: &str,
    api_url: &str,
) -> anyhow::Result<(Option<Vec<String>>, String)> {
    // Build one runtime and reuse it for the retry; propagate creation
    // failure via `?` instead of panicking (the original called
    // Runtime::new().unwrap() twice).
    let runtime = tokio::runtime::Runtime::new()?;

    let spin = spinner();
    spin.start("Attempting to fetch available models...");

    let mut models: Option<Vec<String>> = None;
    let mut api_key = String::new();

    match runtime.block_on(try_fetch_custom_provider_models(
        provider_type,
        api_url,
        None,
    )) {
        Ok(fetched_models) => {
            spin.stop(style("Models fetched successfully").green());
            models = Some(fetched_models);
        }
        Err(e) => {
            spin.stop(style(format!("Could not fetch models: {}", e)).yellow());
            let _ = cliclack::log::info("You may need to provide an API key to fetch models");

            let should_retry =
                cliclack::confirm("Would you like to provide an API key and try again?")
                    .initial_value(true)
                    .interact()?;

            if should_retry {
                api_key = cliclack::password("API key:").mask('▪').interact()?;

                let spin = spinner();
                spin.start("Retrying with API key...");

                match runtime.block_on(try_fetch_custom_provider_models(
                    provider_type,
                    api_url,
                    Some(&api_key),
                )) {
                    Ok(fetched_models) => {
                        spin.stop(style("Models fetched successfully").green());
                        models = Some(fetched_models);
                    }
                    Err(e) => {
                        spin.stop(
                            style(format!("Still could not fetch models: {}", e)).yellow(),
                        );
                        let _ = cliclack::log::warning("You will need to enter models manually");
                    }
                }
            }
        }
    }

    Ok((models, api_key))
}

/// Prompts for a comma-separated list of model names and returns the
/// trimmed, non-empty entries. Rejects blank input.
fn prompt_models_manually() -> anyhow::Result<Vec<String>> {
    let models_input: String = cliclack::input("Available models (separate with commas):")
        .placeholder("model-a, model-b, model-c")
        .validate(|input: &String| {
            if input.trim().is_empty() {
                Err("Please enter at least one model name")
            } else {
                Ok(())
            }
        })
        .interact()?;
    Ok(models_input
        .split(',')
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty())
        .collect())
}

/// Lets the user pick one of the fetched models or fall back to manual
/// entry. With no fetched models it goes straight to the manual prompt.
/// (The manual prompt was previously duplicated in both branches; it now
/// lives in `prompt_models_manually`.)
fn select_models_interactive(fetched_models: Option<Vec<String>>) -> anyhow::Result<Vec<String>> {
    let Some(fetched_models) = fetched_models else {
        return prompt_models_manually();
    };

    // Sentinel entry that routes to manual input; fetched models follow it.
    let mut items = vec![(
        "__manual__".to_string(),
        "Enter models manually...".to_string(),
        "Manually specify model names",
    )];
    for model in &fetched_models {
        items.push((model.clone(), model.clone(), ""));
    }

    let selection = cliclack::select("Select a model or enter manually:")
        .items(
            &items
                .iter()
                .map(|(k, v, d)| (k.as_str(), v.as_str(), *d))
                .collect::<Vec<_>>(),
        )
        .interact()?;

    if selection == "__manual__" {
        prompt_models_manually()
    } else {
        Ok(vec![selection.to_string()])
    }
}

fn add_provider() -> anyhow::Result<()> {
let provider_type = cliclack::select("What type of API is this?")
.item(
Expand Down Expand Up @@ -1819,27 +2008,16 @@ fn add_provider() -> anyhow::Result<()> {
})
.interact()?;

let api_key: String = cliclack::password("API key:")
.allow_empty()
.mask('▪')
.interact()?;
let (fetched_models, mut api_key) = fetch_models_with_retry(provider_type, &api_url)?;

let models_input: String = cliclack::input("Available models (seperate with commas):")
.placeholder("model-a, model-b, model-c")
.validate(|input: &String| {
if input.trim().is_empty() {
Err("Please enter at least one model name")
} else {
Ok(())
}
})
.interact()?;
if api_key.is_empty() && fetched_models.is_some() {
api_key = cliclack::password("API key (optional, press Enter to skip):")
.allow_empty()
.mask('▪')
.interact()?;
}

let models: Vec<String> = models_input
.split(',')
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect();
let models = select_models_interactive(fetched_models)?;

let supports_streaming = cliclack::confirm("Does this provider support streaming responses?")
.initial_value(true)
Expand Down
2 changes: 2 additions & 0 deletions crates/goose-server/src/openapi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ derive_utoipa!(Icon as IconSchema);
super::routes::config_management::read_all_config,
super::routes::config_management::providers,
super::routes::config_management::get_provider_models,
super::routes::config_management::fetch_custom_provider_models,
super::routes::config_management::upsert_permissions,
super::routes::config_management::create_custom_provider,
super::routes::config_management::get_custom_provider,
Expand Down Expand Up @@ -394,6 +395,7 @@ derive_utoipa!(Icon as IconSchema);
super::routes::config_management::ToolPermission,
super::routes::config_management::UpsertPermissionsQuery,
super::routes::config_management::UpdateCustomProviderRequest,
super::routes::config_management::FetchModelsRequest,
super::routes::reply::PermissionConfirmationRequest,
super::routes::reply::ChatRequest,
super::routes::session::ImportSessionRequest,
Expand Down
130 changes: 130 additions & 0 deletions crates/goose-server/src/routes/config_management.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ pub struct UpdateCustomProviderRequest {
pub supports_streaming: Option<bool>,
}

#[derive(Deserialize, ToSchema)]
pub struct FetchModelsRequest {
pub engine: String,
pub api_url: String,
pub api_key: Option<String>,
}

#[derive(Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct MaskedSecret {
Expand Down Expand Up @@ -378,6 +385,125 @@ pub async fn get_provider_models(
}
}

#[utoipa::path(
    post,
    path = "/config/providers/fetch-models",
    request_body = FetchModelsRequest,
    responses(
        (status = 200, description = "Models fetched successfully", body = [String]),
        (status = 400, description = "Invalid request or authentication error"),
        (status = 500, description = "Internal server error")
    )
)]
/// Fetches the model list for a not-yet-saved custom provider by building a
/// temporary declarative provider config from the request's engine, URL, and
/// optional API key. Returns 400 on bad URL/engine or auth failure, 429 on
/// rate limiting, 500 otherwise.
pub async fn fetch_custom_provider_models(
    Json(request): Json<FetchModelsRequest>,
) -> Result<Json<Vec<String>>, StatusCode> {
    use goose::config::declarative_providers::{DeclarativeProviderConfig, ProviderEngine};
    use goose::providers::anthropic::AnthropicProvider;
    use goose::providers::base::{ModelInfo, Provider};
    use goose::providers::ollama::OllamaProvider;
    use goose::providers::openai::OpenAiProvider;
    use reqwest::Url;

    let url = Url::parse(&request.api_url).map_err(|_| StatusCode::BAD_REQUEST)?;

    // Rebuild "scheme://host[:port]" from the parsed URL.
    let host_str = url.host_str().ok_or(StatusCode::BAD_REQUEST)?;
    let host = match url.port() {
        Some(port) => format!("{}://{}:{}", url.scheme(), host_str, port),
        None => format!("{}://{}", url.scheme(), host_str),
    };

    let base_path = url.path().trim_start_matches('/').to_string();
    let api_key = request
        .api_key
        .unwrap_or_else(|| "not-required".to_string());

    // OpenAI-compatible engines expect the chat/completions suffix; append it
    // when the caller supplied a path that does not already include it.
    let base_url = if base_path.is_empty() {
        host.clone()
    } else {
        let url_with_path = format!("{}/{}", host, base_path);
        if request.engine == "openai_compatible" && !base_path.contains("chat/completions") {
            format!("{}/chat/completions", url_with_path)
        } else {
            url_with_path
        }
    };

    let temp_config = DeclarativeProviderConfig {
        name: "temp_fetch".to_string(),
        engine: match request.engine.as_str() {
            "openai_compatible" => ProviderEngine::OpenAI,
            "anthropic_compatible" => ProviderEngine::Anthropic,
            "ollama_compatible" => ProviderEngine::Ollama,
            _ => return Err(StatusCode::BAD_REQUEST),
        },
        display_name: "Temporary".to_string(),
        description: None,
        api_key_env: "TEMP_API_KEY".to_string(),
        base_url,
        models: vec![ModelInfo::new("temp", 128000)],
        headers: None,
        timeout_seconds: Some(30),
        supports_streaming: None,
    };

    // Temporarily set the API key in config.
    // NOTE(review): TEMP_API_KEY is one global secret, so concurrent requests
    // to this endpoint can clobber each other's key — confirm whether this
    // handler needs serialization or a per-request key name.
    let config = Config::global();
    config
        .set_secret("TEMP_API_KEY", &api_key)
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    // Run the remaining fallible work inside one block so the temporary
    // secret is deleted on every exit path (the original early `?` returns
    // between set_secret and delete_secret leaked it on error).
    let outcome: Result<Json<Vec<String>>, StatusCode> = async move {
        let model_config =
            ModelConfig::new("temp").map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
        let result = match temp_config.engine {
            ProviderEngine::OpenAI => {
                let provider = OpenAiProvider::from_custom_config(model_config, temp_config)
                    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
                provider.fetch_supported_models().await
            }
            ProviderEngine::Anthropic => {
                let provider = AnthropicProvider::from_custom_config(model_config, temp_config)
                    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
                provider.fetch_supported_models().await
            }
            ProviderEngine::Ollama => {
                let provider = OllamaProvider::from_custom_config(model_config, temp_config)
                    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
                provider.fetch_supported_models().await
            }
        };

        match result {
            Ok(Some(models)) => Ok(Json(models)),
            // No list returned is treated as an empty (but successful) result.
            Ok(None) => Ok(Json(Vec::new())),
            Err(provider_error) => {
                use goose::providers::errors::ProviderError;
                // Map provider failures onto HTTP status codes.
                let status_code = match provider_error {
                    ProviderError::Authentication(_) => StatusCode::BAD_REQUEST,
                    ProviderError::UsageError(_) => StatusCode::BAD_REQUEST,
                    ProviderError::RateLimitExceeded { .. } => StatusCode::TOO_MANY_REQUESTS,
                    _ => StatusCode::INTERNAL_SERVER_ERROR,
                };

                tracing::warn!(
                    "Failed to fetch models for custom provider: {}",
                    provider_error
                );
                Err(status_code)
            }
        }
    }
    .await;

    // Best-effort cleanup, regardless of how the fetch went.
    let _ = config.delete_secret("TEMP_API_KEY");

    outcome
}

#[derive(Serialize, ToSchema)]
pub struct PricingData {
pub provider: String,
Expand Down Expand Up @@ -744,6 +870,10 @@ pub fn routes(state: Arc<AppState>) -> Router {
.route("/config/extensions", post(add_extension))
.route("/config/extensions/{name}", delete(remove_extension))
.route("/config/providers", get(providers))
.route(
"/config/providers/fetch-models",
post(fetch_custom_provider_models),
)
.route("/config/providers/{name}/models", get(get_provider_models))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need a new handler for this? we have /config/providers/{name}/models here that fetches the models for a particular provider. to make this work I think all you need to do is either assume that the custom models have model fetch in the same location as openai (which should be straightforward), or add a field to custom providers where the user can specify the path to fetch the models (which would be more general)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getting models for custom providers used to work in older versions already, something changed in that area with the custom provider rewrite.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's a good point

.route("/config/pricing", post(get_pricing))
.route("/config/init", post(init_config))
Expand Down
Loading