Skip to content

Commit 691a506

Browse files
authored
feat(ai): Add tag to log requests made to cloud models while in hybrid mode (#9469)
1 parent 65a553b commit 691a506

File tree

6 files changed

+52
-7
lines changed

6 files changed

+52
-7
lines changed

.changeset/nasty-squids-push.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'@firebase/ai': patch
3+
---
4+
5+
Internal: Add tag to log requests made to cloud while in hybrid mode.

packages/ai/src/api.ts

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,15 @@ export function getGenerativeModel(
156156
hybridParams.onDeviceParams
157157
);
158158

159-
return new GenerativeModel(ai, inCloudParams, requestOptions, chromeAdapter);
159+
const generativeModel = new GenerativeModel(
160+
ai,
161+
inCloudParams,
162+
requestOptions,
163+
chromeAdapter
164+
);
165+
166+
generativeModel._apiSettings.inferenceMode = hybridParams.mode;
167+
return generativeModel;
160168
}
161169

162170
/**

packages/ai/src/constants.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ export const PACKAGE_VERSION = version;
3232

3333
export const LANGUAGE_TAG = 'gl-js';
3434

35+
export const HYBRID_TAG = 'hybrid';
36+
3537
export const DEFAULT_FETCH_TIMEOUT_MS = 180 * 1000;
3638

3739
/**

packages/ai/src/requests/request.test.ts

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import {
3030
} from './request';
3131
import { ApiSettings } from '../types/internal';
3232
import { DEFAULT_API_VERSION } from '../constants';
33-
import { AIErrorCode } from '../types';
33+
import { AIErrorCode, InferenceMode } from '../types';
3434
import { AIError } from '../errors';
3535
import { getMockResponse } from '../../test-utils/mock-response';
3636
import { VertexAIBackend } from '../backend';
@@ -139,10 +139,26 @@ describe('request methods', () => {
139139
stream: true,
140140
singleRequestOptions: undefined
141141
});
142-
it('adds client headers', async () => {
142+
it('adds client headers (no hybrid)', async () => {
143143
const headers = await getHeaders(fakeUrl);
144144
expect(headers.get('x-goog-api-client')).to.match(
145-
/gl-js\/[0-9\.]+ fire\/[0-9\.]+/
145+
/gl-js\/[0-9\.]+ fire\/[0-9\.]+$/
146+
);
147+
});
148+
it('adds client headers (if hybrid)', async () => {
149+
const fakeUrlWithHybrid = new RequestURL({
150+
model: 'models/model-name',
151+
task: Task.GENERATE_CONTENT,
152+
apiSettings: {
153+
...fakeApiSettings,
154+
inferenceMode: InferenceMode.PREFER_ON_DEVICE
155+
},
156+
stream: true,
157+
singleRequestOptions: undefined
158+
});
159+
const headers = await getHeaders(fakeUrlWithHybrid);
160+
expect(headers.get('x-goog-api-client')).to.match(
161+
/gl-js\/[0-9\.]+ fire\/[0-9\.]+ hybrid$/
146162
);
147163
});
148164
it('adds api key', async () => {

packages/ai/src/requests/request.ts

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@ import { ApiSettings } from '../types/internal';
2121
import {
2222
DEFAULT_DOMAIN,
2323
DEFAULT_FETCH_TIMEOUT_MS,
24+
HYBRID_TAG,
2425
LANGUAGE_TAG,
2526
PACKAGE_VERSION
2627
} from '../constants';
2728
import { logger } from '../logger';
28-
import { BackendType } from '../public-types';
29+
import { BackendType, InferenceMode } from '../public-types';
2930

3031
export const TIMEOUT_EXPIRED_MESSAGE = 'Timeout has expired.';
3132
export const ABORT_ERROR_NAME = 'AbortError';
@@ -137,17 +138,28 @@ export class WebSocketUrl {
137138
/**
138139
* Log language and "fire/version" to x-goog-api-client
139140
*/
140-
function getClientHeaders(): string {
141+
function getClientHeaders(url: RequestURL): string {
141142
const loggingTags = [];
142143
loggingTags.push(`${LANGUAGE_TAG}/${PACKAGE_VERSION}`);
143144
loggingTags.push(`fire/${PACKAGE_VERSION}`);
145+
/**
146+
* No call would be made if ONLY_ON_DEVICE.
147+
* ONLY_IN_CLOUD does not indicate an intention to use hybrid.
148+
*/
149+
if (
150+
url.params.apiSettings.inferenceMode === InferenceMode.PREFER_ON_DEVICE ||
151+
url.params.apiSettings.inferenceMode === InferenceMode.PREFER_IN_CLOUD
152+
) {
153+
// No version
154+
loggingTags.push(HYBRID_TAG);
155+
}
144156
return loggingTags.join(' ');
145157
}
146158

147159
export async function getHeaders(url: RequestURL): Promise<Headers> {
148160
const headers = new Headers();
149161
headers.append('Content-Type', 'application/json');
150-
headers.append('x-goog-api-client', getClientHeaders());
162+
headers.append('x-goog-api-client', getClientHeaders(url));
151163
headers.append('x-goog-api-key', url.params.apiSettings.apiKey);
152164
if (url.params.apiSettings.automaticDataCollectionEnabled) {
153165
headers.append('X-Firebase-Appid', url.params.apiSettings.appId);

packages/ai/src/types/internal.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import { AppCheckTokenResult } from '@firebase/app-check-interop-types';
1919
import { FirebaseAuthTokenData } from '@firebase/auth-interop-types';
2020
import { Backend } from '../backend';
21+
import { InferenceMode } from './enums';
2122

2223
export * from './imagen/internal';
2324

@@ -33,4 +34,5 @@ export interface ApiSettings {
3334
backend: Backend;
3435
getAuthToken?: () => Promise<FirebaseAuthTokenData | null>;
3536
getAppCheckToken?: () => Promise<AppCheckTokenResult>;
37+
inferenceMode?: InferenceMode;
3638
}

0 commit comments

Comments
 (0)