Skip to content

Commit c373daa

Browse files
committed
Test a provider-oriented welcome screen
1 parent 038f830 commit c373daa

31 files changed

+455
-34
lines changed

packages/cloud/src/CloudService.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,9 @@ export class CloudService extends EventEmitter<CloudServiceEvents> implements Di
178178

179179
// AuthService
180180

181-
public async login(landingPageSlug?: string): Promise<void> {
181+
public async login(landingPageSlug?: string, useProviderSignup: boolean = false): Promise<void> {
182182
this.ensureInitialized()
183-
return this.authService!.login(landingPageSlug)
183+
return this.authService!.login(landingPageSlug, useProviderSignup)
184184
}
185185

186186
public async logout(): Promise<void> {
@@ -245,9 +245,10 @@ export class CloudService extends EventEmitter<CloudServiceEvents> implements Di
245245
code: string | null,
246246
state: string | null,
247247
organizationId?: string | null,
248+
providerModel?: string | null,
248249
): Promise<void> {
249250
this.ensureInitialized()
250-
return this.authService!.handleCallback(code, state, organizationId)
251+
return this.authService!.handleCallback(code, state, organizationId, providerModel)
251252
}
252253

253254
public async switchOrganization(organizationId: string | null): Promise<void> {

packages/cloud/src/StaticTokenAuthService.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ export class StaticTokenAuthService extends EventEmitter<AuthServiceEvents> impl
4747
this.emit("user-info", { userInfo: this.userInfo })
4848
}
4949

50-
public async login(): Promise<void> {
50+
public async login(_landingPageSlug?: string, _useProviderSignup?: boolean): Promise<void> {
5151
throw new Error("Authentication methods are disabled in StaticTokenAuthService")
5252
}
5353

@@ -59,6 +59,7 @@ export class StaticTokenAuthService extends EventEmitter<AuthServiceEvents> impl
5959
_code: string | null,
6060
_state: string | null,
6161
_organizationId?: string | null,
62+
_providerModel?: string | null,
6263
): Promise<void> {
6364
throw new Error("Authentication methods are disabled in StaticTokenAuthService")
6465
}

packages/cloud/src/WebAuthService.ts

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,9 @@ export class WebAuthService extends EventEmitter<AuthServiceEvents> implements A
252252
* and opening the browser to the authorization URL.
253253
*
254254
* @param landingPageSlug Optional slug of a specific landing page (e.g., "supernova", "special-offer", etc.)
255+
* @param useProviderSignup If true, uses provider signup flow (/extension/provider-sign-up). If false, uses standard sign-in (/extension/sign-in). Defaults to false.
255256
*/
256-
public async login(landingPageSlug?: string): Promise<void> {
257+
public async login(landingPageSlug?: string, useProviderSignup: boolean = false): Promise<void> {
257258
try {
258259
const vscode = await importVscode()
259260

@@ -272,10 +273,12 @@ export class WebAuthService extends EventEmitter<AuthServiceEvents> implements A
272273
auth_redirect: `${vscode.env.uriScheme}://${publisher}.${name}`,
273274
})
274275

275-
// Use landing page URL if slug is provided, otherwise use default sign-in URL
276+
// Use landing page URL if slug is provided, otherwise use provider sign-up or sign-in URL based on parameter
276277
const url = landingPageSlug
277278
? `${getRooCodeApiUrl()}/l/${landingPageSlug}?${params.toString()}`
278-
: `${getRooCodeApiUrl()}/extension/sign-in?${params.toString()}`
279+
: useProviderSignup
280+
? `${getRooCodeApiUrl()}/extension/provider-sign-up?${params.toString()}`
281+
: `${getRooCodeApiUrl()}/extension/sign-in?${params.toString()}`
279282

280283
await vscode.env.openExternal(vscode.Uri.parse(url))
281284
} catch (error) {
@@ -294,11 +297,13 @@ export class WebAuthService extends EventEmitter<AuthServiceEvents> implements A
294297
* @param code The authorization code from the callback
295298
* @param state The state parameter from the callback
296299
* @param organizationId The organization ID from the callback (null for personal accounts)
300+
* @param providerModel The model ID selected during signup (optional)
297301
*/
298302
public async handleCallback(
299303
code: string | null,
300304
state: string | null,
301305
organizationId?: string | null,
306+
providerModel?: string | null,
302307
): Promise<void> {
303308
if (!code || !state) {
304309
const vscode = await importVscode()
@@ -326,6 +331,12 @@ export class WebAuthService extends EventEmitter<AuthServiceEvents> implements A
326331

327332
await this.storeCredentials(credentials)
328333

334+
// Store the provider model if provided
335+
if (providerModel) {
336+
await this.context.globalState.update("roo-provider-model", providerModel)
337+
this.log(`[auth] Stored provider model: ${providerModel}`)
338+
}
339+
329340
const vscode = await importVscode()
330341

331342
if (vscode) {

packages/cloud/src/__tests__/CloudService.test.ts

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,12 +296,22 @@ describe("CloudService", () => {
296296

297297
it("should delegate handleAuthCallback to AuthService", async () => {
298298
await cloudService.handleAuthCallback("code", "state")
299-
expect(mockAuthService.handleCallback).toHaveBeenCalledWith("code", "state", undefined)
299+
expect(mockAuthService.handleCallback).toHaveBeenCalledWith("code", "state", undefined, undefined)
300300
})
301301

302302
it("should delegate handleAuthCallback with organizationId to AuthService", async () => {
303303
await cloudService.handleAuthCallback("code", "state", "org_123")
304-
expect(mockAuthService.handleCallback).toHaveBeenCalledWith("code", "state", "org_123")
304+
expect(mockAuthService.handleCallback).toHaveBeenCalledWith("code", "state", "org_123", undefined)
305+
})
306+
307+
it("should delegate handleAuthCallback with providerModel to AuthService", async () => {
308+
await cloudService.handleAuthCallback("code", "state", "org_123", "xai/grok-code-fast-1")
309+
expect(mockAuthService.handleCallback).toHaveBeenCalledWith(
310+
"code",
311+
"state",
312+
"org_123",
313+
"xai/grok-code-fast-1",
314+
)
305315
})
306316

307317
it("should return stored organization ID from AuthService", () => {

packages/cloud/src/__tests__/WebAuthService.spec.ts

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ describe("WebAuthService", () => {
261261
)
262262
})
263263

264-
it("should use package.json values for redirect URI", async () => {
264+
it("should use package.json values for redirect URI with default sign-in endpoint", async () => {
265265
const mockOpenExternal = vi.fn()
266266
const vscode = await import("vscode")
267267
vi.mocked(vscode.env.openExternal).mockImplementation(mockOpenExternal)
@@ -281,6 +281,26 @@ describe("WebAuthService", () => {
281281
expect(calledUri.toString()).toBe(expectedUrl)
282282
})
283283

284+
it("should use provider signup URL when useProviderSignup is true", async () => {
285+
const mockOpenExternal = vi.fn()
286+
const vscode = await import("vscode")
287+
vi.mocked(vscode.env.openExternal).mockImplementation(mockOpenExternal)
288+
289+
await authService.login(undefined, true)
290+
291+
const expectedUrl =
292+
"https://api.test.com/extension/provider-sign-up?state=746573742d72616e646f6d2d6279746573&auth_redirect=vscode%3A%2F%2FRooVeterinaryInc.roo-cline"
293+
expect(mockOpenExternal).toHaveBeenCalledWith(
294+
expect.objectContaining({
295+
toString: expect.any(Function),
296+
}),
297+
)
298+
299+
// Verify the actual URL
300+
const calledUri = mockOpenExternal.mock.calls[0]?.[0]
301+
expect(calledUri.toString()).toBe(expectedUrl)
302+
})
303+
284304
it("should handle errors during login", async () => {
285305
vi.mocked(crypto.randomBytes).mockImplementation(() => {
286306
throw new Error("Crypto error")
@@ -351,6 +371,33 @@ describe("WebAuthService", () => {
351371
expect(mockShowInfo).toHaveBeenCalledWith("Successfully authenticated with Roo Code Cloud")
352372
})
353373

374+
it("should store provider model when provided in callback", async () => {
375+
const storedState = "valid-state"
376+
mockContext.globalState.get.mockReturnValue(storedState)
377+
378+
// Mock successful Clerk sign-in response
379+
const mockResponse = {
380+
ok: true,
381+
json: () =>
382+
Promise.resolve({
383+
response: { created_session_id: "session-123" },
384+
}),
385+
headers: {
386+
get: (header: string) => (header === "authorization" ? "Bearer token-123" : null),
387+
},
388+
}
389+
mockFetch.mockResolvedValue(mockResponse)
390+
391+
const vscode = await import("vscode")
392+
const mockShowInfo = vi.fn()
393+
vi.mocked(vscode.window.showInformationMessage).mockImplementation(mockShowInfo)
394+
395+
await authService.handleCallback("auth-code", storedState, null, "xai/grok-code-fast-1")
396+
397+
expect(mockContext.globalState.update).toHaveBeenCalledWith("roo-provider-model", "xai/grok-code-fast-1")
398+
expect(mockLog).toHaveBeenCalledWith("[auth] Stored provider model: xai/grok-code-fast-1")
399+
})
400+
354401
it("should handle Clerk API errors", async () => {
355402
const storedState = "valid-state"
356403
mockContext.globalState.get.mockReturnValue(storedState)

packages/types/src/cloud.ts

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,14 @@ export interface AuthService extends EventEmitter<AuthServiceEvents> {
239239
broadcast(): void
240240

241241
// Authentication methods
242-
login(landingPageSlug?: string): Promise<void>
242+
login(landingPageSlug?: string, useProviderSignup?: boolean): Promise<void>
243243
logout(): Promise<void>
244-
handleCallback(code: string | null, state: string | null, organizationId?: string | null): Promise<void>
244+
handleCallback(
245+
code: string | null,
246+
state: string | null,
247+
organizationId?: string | null,
248+
providerModel?: string | null,
249+
): Promise<void>
245250
switchOrganization(organizationId: string | null): Promise<void>
246251

247252
// State methods

src/activate/handleUri.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,13 @@ export const handleUri = async (uri: vscode.Uri) => {
4040
const code = query.get("code")
4141
const state = query.get("state")
4242
const organizationId = query.get("organizationId")
43+
const providerModel = query.get("provider_model")
4344

4445
await CloudService.instance.handleAuthCallback(
4546
code,
4647
state,
4748
organizationId === "null" ? null : organizationId,
49+
providerModel,
4850
)
4951
break
5052
}

src/core/webview/webviewMessageHandler.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2146,7 +2146,8 @@ export const webviewMessageHandler = async (
21462146
case "rooCloudSignIn": {
21472147
try {
21482148
TelemetryService.instance.captureEvent(TelemetryEventName.AUTHENTICATION_INITIATED)
2149-
await CloudService.instance.login()
2149+
// Use provider signup flow if useProviderSignup is explicitly true
2150+
await CloudService.instance.login(undefined, message.useProviderSignup ?? false)
21502151
} catch (error) {
21512152
provider.log(`AuthService#login failed: ${error}`)
21522153
vscode.window.showErrorMessage("Sign in failed.")

src/extension.ts

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,31 @@ export async function activate(context: vscode.ExtensionContext) {
168168

169169
if (data.state === "active-session" || data.state === "logged-out") {
170170
await handleRooModelsCache()
171+
172+
// Apply stored provider model to API configuration if present
173+
if (data.state === "active-session") {
174+
try {
175+
const storedModel = context.globalState.get<string>("roo-provider-model")
176+
if (storedModel) {
177+
cloudLogger(`[authStateChangedHandler] Applying stored provider model: ${storedModel}`)
178+
// Get the current API configuration name
179+
const currentConfigName =
180+
provider.contextProxy.getGlobalState("currentApiConfigName") || "default"
181+
// Update it with the stored model using upsertProviderProfile
182+
await provider.upsertProviderProfile(currentConfigName, {
183+
apiProvider: "roo",
184+
apiModelId: storedModel,
185+
})
186+
// Clear the stored model after applying
187+
await context.globalState.update("roo-provider-model", undefined)
188+
cloudLogger(`[authStateChangedHandler] Applied and cleared stored provider model`)
189+
}
190+
} catch (error) {
191+
cloudLogger(
192+
`[authStateChangedHandler] Failed to apply stored provider model: ${error instanceof Error ? error.message : String(error)}`,
193+
)
194+
}
195+
}
171196
}
172197
}
173198

src/shared/WebviewMessage.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ export interface WebviewMessage {
214214
upsellId?: string // For dismissUpsell
215215
list?: string[] // For dismissedUpsells response
216216
organizationId?: string | null // For organization switching
217+
useProviderSignup?: boolean // For rooCloudSignIn to use provider signup flow
217218
codeIndexSettings?: {
218219
// Global state settings
219220
codebaseIndexEnabled: boolean

0 commit comments

Comments
 (0)