Skip to content

Commit 19c9b60

Browse files
committed
feat: Implement session resumption and context window compression for live sessions.
1 parent 76e6821 commit 19c9b60

6 files changed

Lines changed: 192 additions & 9 deletions

File tree

packages/ai/src/methods/live-session.test.ts

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ import {
2525
LiveServerContent,
2626
LiveServerGoingAwayNotice,
2727
LiveServerToolCall,
28-
LiveServerToolCallCancellation
28+
LiveServerToolCallCancellation,
29+
LiveSessionResumptionUpdate
2930
} from '../types';
3031
import { LiveSession } from './live-session';
3132
import { WebSocketHandler } from '../websocket';
@@ -83,7 +84,15 @@ describe('LiveSession', () => {
8384
beforeEach(() => {
8485
mockHandler = new MockWebSocketHandler();
8586
serverMessagesGenerator = mockHandler.listen();
86-
session = new LiveSession(mockHandler, serverMessagesGenerator);
87+
session = new LiveSession(
88+
mockHandler,
89+
serverMessagesGenerator,
90+
async (resumptionConfig) => {
91+
// mock reconnector that replaces the handler
92+
mockHandler = new MockWebSocketHandler();
93+
return mockHandler.listen();
94+
}
95+
);
8796
});
8897

8998
describe('send()', () => {
@@ -220,6 +229,27 @@ describe('LiveSession', () => {
220229
});
221230
});
222231

232+
describe('resumeSession()', () => {
233+
it('should close existing session and start a new one using reconnector', async () => {
234+
expect(session.isClosed).to.be.false;
235+
236+
const oldServerMessages = (session as any).serverMessages;
237+
await session.resumeSession({ handle: 'testHandle' });
238+
239+
expect(mockHandler.close).to.have.been.calledOnce;
240+
expect(session.isClosed).to.be.false;
241+
expect((session as any).serverMessages).to.not.equal(oldServerMessages);
242+
});
243+
244+
it('should throw if reconnector is not provided', async () => {
245+
const basicSession = new LiveSession(mockHandler, serverMessagesGenerator);
246+
await expect(basicSession.resumeSession()).to.be.rejectedWith(
247+
AIError,
248+
/resumeSession is not supported on this session/
249+
);
250+
});
251+
});
252+
223253
describe('receive()', () => {
224254
it('should correctly parse and transform all server message types', async () => {
225255
const receivePromise = (async () => {
@@ -242,14 +272,17 @@ describe('LiveSession', () => {
242272
mockHandler.simulateServerMessage({
243273
goAway: { timeLeft: '30s' }
244274
});
275+
mockHandler.simulateServerMessage({
276+
sessionResumptionUpdate: { newHandle: 'test', resumable: true, lastConsumedClientMessageIndex: 5 }
277+
});
245278
mockHandler.simulateServerMessage({
246279
serverContent: { turnComplete: true }
247280
});
248281
await new Promise<void>(r => setTimeout(() => r(), 10)); // Wait for the listener to process messages
249282
mockHandler.endStream();
250283

251284
const responses = await receivePromise;
252-
expect(responses).to.have.lengthOf(5);
285+
expect(responses).to.have.lengthOf(6);
253286
expect(responses[0]).to.deep.equal({
254287
type: LiveResponseType.SERVER_CONTENT,
255288
modelTurn: { parts: [{ text: 'response 1' }] }
@@ -267,6 +300,12 @@ describe('LiveSession', () => {
267300
timeLeft: 30
268301
} as LiveServerGoingAwayNotice);
269302
expect(responses[4]).to.deep.equal({
303+
type: LiveResponseType.SESSION_RESUMPTION_UPDATE,
304+
newHandle: 'test',
305+
resumable: true,
306+
lastConsumedClientMessageIndex: 5
307+
} as LiveSessionResumptionUpdate);
308+
expect(responses[5]).to.deep.equal({
270309
type: LiveResponseType.SERVER_CONTENT,
271310
turnComplete: true
272311
} as LiveServerContent);

packages/ai/src/methods/live-session.ts

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ import {
2424
LiveServerGoingAwayNotice,
2525
LiveServerToolCall,
2626
LiveServerToolCallCancellation,
27-
Part
27+
LiveSessionResumptionUpdate,
28+
Part,
29+
SessionResumptionConfig
2830
} from '../public-types';
2931
import { formatNewContent } from '../requests/request-helpers';
3032
import { AIError } from '../errors';
@@ -62,9 +64,36 @@ export class LiveSession {
6264
*/
6365
constructor(
6466
private webSocketHandler: WebSocketHandler,
65-
private serverMessages: AsyncGenerator<unknown>
67+
private serverMessages: AsyncGenerator<unknown>,
68+
private reconnector?: (sessionResumption?: SessionResumptionConfig) => Promise<AsyncGenerator<unknown>>
6669
) {}
6770

71+
/**
72+
* Resumes an existing live session with the server.
73+
*
74+
* This closes the current WebSocket connection and establishes a new one using
75+
* the same configuration (URI, headers, model, system instruction, tools, etc.)
76+
* as the original session.
77+
*
78+
* @param sessionResumption - The configuration for session resumption, such as the handle to the previous session state to restore.
79+
* @throws If this session was created without a reconnector, in which case resuming is not supported.
80+
*
81+
* @beta
82+
*/
83+
async resumeSession(
84+
sessionResumption?: SessionResumptionConfig
85+
): Promise<void> {
86+
if (!this.reconnector) {
87+
throw new AIError(
88+
AIErrorCode.UNSUPPORTED,
89+
'resumeSession is not supported on this session.'
90+
);
91+
}
92+
await this.close();
93+
this.isClosed = false;
94+
this.serverMessages = await this.reconnector(sessionResumption);
95+
}
96+
6897
/**
6998
* Sends content to the server.
7099
*
@@ -232,6 +261,7 @@ export class LiveSession {
232261
| LiveServerToolCall
233262
| LiveServerToolCallCancellation
234263
| LiveServerGoingAwayNotice
264+
| LiveSessionResumptionUpdate
235265
> {
236266
if (this.isClosed) {
237267
throw new AIError(
@@ -275,6 +305,18 @@ export class LiveSession {
275305
type: LiveResponseType.GOING_AWAY_NOTICE,
276306
timeLeft: parseDuration(notice.timeLeft)
277307
} as LiveServerGoingAwayNotice;
308+
} else if (LiveResponseType.SESSION_RESUMPTION_UPDATE in message) {
309+
yield {
310+
type: LiveResponseType.SESSION_RESUMPTION_UPDATE,
311+
...(
312+
message as {
313+
sessionResumptionUpdate: Omit<
314+
LiveSessionResumptionUpdate,
315+
'type'
316+
>;
317+
}
318+
).sessionResumptionUpdate
319+
} as LiveSessionResumptionUpdate;
278320
} else {
279321
logger.warn(
280322
`Received an unknown message type from the server: ${JSON.stringify(

packages/ai/src/models/live-generative-model.ts

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import {
2525
Content,
2626
LiveGenerationConfig,
2727
LiveModelParams,
28+
SessionResumptionConfig,
2829
Tool,
2930
ToolConfig
3031
} from '../public-types';
@@ -70,12 +71,28 @@ export class LiveGenerativeModel extends AIModel {
7071
/**
7172
* Starts a {@link LiveSession}.
7273
*
74+
* @param sessionResumption - Optional configuration for session resumption.
7375
* @returns A {@link LiveSession}.
7476
* @throws If the connection failed to be established with the server.
7577
*
7678
* @beta
7779
*/
78-
async connect(): Promise<LiveSession> {
80+
async connect(
81+
sessionResumption?: SessionResumptionConfig
82+
): Promise<LiveSession> {
83+
const serverMessages = await this._internalConnect(sessionResumption);
84+
return new LiveSession(
85+
this._webSocketHandler,
86+
serverMessages,
87+
async (resumptionConfig?: SessionResumptionConfig) => {
88+
return this._internalConnect(resumptionConfig);
89+
}
90+
);
91+
}
92+
93+
private async _internalConnect(
94+
sessionResumption?: SessionResumptionConfig
95+
): Promise<AsyncGenerator<unknown>> {
7996
const url = new WebSocketUrl(this._apiSettings);
8097
await this._webSocketHandler.connect(url.toString());
8198

@@ -102,7 +119,8 @@ export class LiveGenerativeModel extends AIModel {
102119
toolConfig: this.toolConfig,
103120
systemInstruction: this.systemInstruction,
104121
inputAudioTranscription,
105-
outputAudioTranscription
122+
outputAudioTranscription,
123+
sessionResumption
106124
}
107125
};
108126

@@ -125,7 +143,7 @@ export class LiveGenerativeModel extends AIModel {
125143
);
126144
}
127145

128-
return new LiveSession(this._webSocketHandler, serverMessages);
146+
return serverMessages;
129147
} catch (e) {
130148
// Ensure connection is closed on any setup error
131149
await this._webSocketHandler.close();

packages/ai/src/types/live-responses.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import {
2424
import {
2525
AudioTranscriptionConfig,
2626
LiveGenerationConfig,
27+
SessionResumptionConfig,
2728
Tool,
2829
ToolConfig
2930
} from './requests';
@@ -88,6 +89,7 @@ export interface _LiveClientSetup {
8889
systemInstruction?: string | Part | Content;
8990
inputAudioTranscription?: AudioTranscriptionConfig;
9091
outputAudioTranscription?: AudioTranscriptionConfig;
92+
sessionResumption?: SessionResumptionConfig;
9193
};
9294
}
9395

packages/ai/src/types/requests.ts

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,63 @@ export interface LiveGenerationConfig {
203203
* "How are you today?", the model may transcribe that output across three messages, broken up as "How a", "re yo", "u today?".
204204
*/
205205
outputAudioTranscription?: AudioTranscriptionConfig;
206+
/**
207+
* The context window compression configuration.
208+
*
209+
* @beta
210+
*/
211+
contextWindowCompression?: ContextWindowCompressionConfig;
212+
}
213+
214+
/**
215+
* Configures the sliding window context compression mechanism.
216+
*
217+
* The context window will be truncated by keeping only a suffix of it.
218+
*
219+
* @beta
220+
*/
221+
export interface SlidingWindow {
222+
/**
223+
* The session reduction target, i.e., the number of tokens to keep.
224+
*/
225+
targetTokens?: number;
226+
}
227+
228+
/**
229+
* Enables context window compression to manage the model's context window.
230+
*
231+
* This mechanism prevents the context from exceeding a given length.
232+
*
233+
* @beta
234+
*/
235+
export interface ContextWindowCompressionConfig {
236+
/**
237+
* The number of tokens (before running a turn) that triggers the context
238+
* window compression.
239+
*/
240+
triggerTokens?: number;
241+
242+
/**
243+
* The sliding window compression mechanism.
244+
*/
245+
slidingWindow?: SlidingWindow;
246+
}
247+
248+
/**
249+
* Configuration for the session resumption mechanism.
250+
*
251+
* When included in the session setup, the server will send
252+
* {@link LiveSessionResumptionUpdate} messages in the response stream.
253+
*
254+
* @beta
255+
*/
256+
export interface SessionResumptionConfig {
257+
/**
258+
* The session resumption handle of the previous session to restore.
259+
*
260+
* If not present, a new session will be started.
261+
*/
262+
handle?: string;
206263
}
207264

208265
/**

packages/ai/src/types/responses.ts

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,30 @@ export interface LiveServerGoingAwayNotice {
629629
timeLeft: number;
630630
}
631631

632+
/**
633+
* An update of the session resumption state.
634+
*
635+
* This message is only sent if {@link SessionResumptionConfig} was set in the
636+
* session setup.
637+
*
638+
* @beta
639+
*/
640+
export interface LiveSessionResumptionUpdate {
641+
type: 'sessionResumptionUpdate';
642+
/**
643+
* The new handle that represents the state that can be resumed. Empty if `resumable` is false.
644+
*/
645+
newHandle?: string;
646+
/**
647+
* Indicates if the session can be resumed at this point.
648+
*/
649+
resumable?: boolean;
650+
/**
651+
* The index of the last client message that is included in the state represented by this update.
652+
*/
653+
lastConsumedClientMessageIndex?: number;
654+
}
655+
632656
/**
633657
* The types of responses that can be returned by {@link LiveSession.receive}.
634658
*
@@ -638,7 +662,8 @@ export const LiveResponseType = {
638662
SERVER_CONTENT: 'serverContent',
639663
TOOL_CALL: 'toolCall',
640664
TOOL_CALL_CANCELLATION: 'toolCallCancellation',
641-
GOING_AWAY_NOTICE: 'goingAwayNotice'
665+
GOING_AWAY_NOTICE: 'goingAwayNotice',
666+
SESSION_RESUMPTION_UPDATE: 'sessionResumptionUpdate'
642667
};
643668

644669
/**

0 commit comments

Comments
 (0)