Skip to content

Commit d7a1e55

Browse files
committed
feat: add configurable maxPayloadSize for WebSocket (#4955)
1 parent a9d1848 commit d7a1e55

12 files changed

Lines changed: 648 additions & 65 deletions

File tree

docs/docs/api/Client.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ Returns: `Client`
2626
* **keepAliveTimeoutThreshold** `number | null` (optional) - Default: `2e3` - A number of milliseconds subtracted from server *keep-alive* hints when overriding `keepAliveTimeout` to account for timing inaccuracies caused by e.g. transport latency. Defaults to 2 seconds.
2727
* **maxHeaderSize** `number | null` (optional) - Default: `--max-http-header-size` or `16384` - The maximum length of request headers in bytes. Defaults to Node.js' --max-http-header-size or 16KiB.
2828
* **maxResponseSize** `number | null` (optional) - Default: `-1` - The maximum length of response body in bytes. Set to `-1` to disable.
29+
* **webSocket** `WebSocketOptions` (optional) - WebSocket-specific configuration options.
30+
* **maxPayloadSize** `number` (optional) - Default: `134217728` (128 MB) - Maximum allowed payload size in bytes for WebSocket messages. Applied to uncompressed messages, compressed frame payloads, and decompressed (permessage-deflate) messages. Set to 0 to disable the limit.
2931
* **pipelining** `number | null` (optional) - Default: `1` - The amount of concurrent requests to be sent over the single TCP/TLS connection according to [RFC7230](https://tools.ietf.org/html/rfc7230#section-6.3.2). Carefully consider your workload and environment before enabling concurrent requests as pipelining may reduce performance if used incorrectly. Pipelining is sensitive to network stack settings as well as head of line blocking caused by e.g. long running requests. Set to `0` to disable keep-alive connections.
3032
* **connect** `ConnectOptions | Function | null` (optional) - Default: `null`.
3133
* **strictContentLength** `Boolean` (optional) - Default: `true` - Whether to treat request content length mismatches as errors. If true, an error is thrown when the request content-length header doesn't match the length of the request body.

lib/dispatcher/agent.js

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ function defaultFactory (origin, opts) {
2424

2525
class Agent extends DispatcherBase {
2626
constructor ({ factory = defaultFactory, maxRedirections = 0, connect, ...options } = {}) {
27-
super()
2827

2928
if (typeof factory !== 'function') {
3029
throw new InvalidArgumentError('factory must be a function.')
@@ -38,6 +37,8 @@ class Agent extends DispatcherBase {
3837
throw new InvalidArgumentError('maxRedirections must be a positive number')
3938
}
4039

40+
super(options)
41+
4142
if (connect && typeof connect !== 'function') {
4243
connect = { ...connect }
4344
}

lib/dispatcher/client.js

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,10 @@ class Client extends DispatcherBase {
106106
autoSelectFamilyAttemptTimeout,
107107
// h2
108108
maxConcurrentStreams,
109-
allowH2
109+
allowH2,
110+
webSocket
110111
} = {}) {
111-
super()
112+
super({ webSocket })
112113

113114
if (keepAlive !== undefined) {
114115
throw new InvalidArgumentError('unsupported keepAlive, use pipelining=0 instead')

lib/dispatcher/dispatcher-base.js

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,23 @@ const { kDestroy, kClose, kClosed, kDestroyed, kDispatch, kInterceptors } = requ
1111
const kOnDestroyed = Symbol('onDestroyed')
1212
const kOnClosed = Symbol('onClosed')
1313
const kInterceptedDispatch = Symbol('Intercepted Dispatch')
14+
const kWebSocketOptions = Symbol('webSocketOptions')
1415

1516
class DispatcherBase extends Dispatcher {
16-
constructor () {
17+
constructor (opts) {
1718
super()
1819

1920
this[kDestroyed] = false
2021
this[kOnDestroyed] = null
2122
this[kClosed] = false
2223
this[kOnClosed] = []
24+
this[kWebSocketOptions] = opts?.webSocket ?? {}
25+
}
26+
27+
get webSocketOptions () {
28+
return {
29+
maxPayloadSize: this[kWebSocketOptions].maxPayloadSize ?? 128 * 1024 * 1024
30+
}
2331
}
2432

2533
get destroyed () {

lib/dispatcher/pool-base.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ const kRemoveClient = Symbol('remove client')
1919
const kStats = Symbol('stats')
2020

2121
class PoolBase extends DispatcherBase {
22-
constructor () {
23-
super()
22+
constructor (opts) {
23+
super(opts)
2424

2525
this[kQueue] = new FixedQueue()
2626
this[kClients] = []

lib/dispatcher/pool.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ class Pool extends PoolBase {
3737
allowH2,
3838
...options
3939
} = {}) {
40-
super()
41-
4240
if (connections != null && (!Number.isFinite(connections) || connections < 0)) {
4341
throw new InvalidArgumentError('invalid connections')
4442
}
@@ -63,6 +61,8 @@ class Pool extends PoolBase {
6361
})
6462
}
6563

64+
super(options)
65+
6666
this[kInterceptors] = options.interceptors?.Pool && Array.isArray(options.interceptors.Pool)
6767
? options.interceptors.Pool
6868
: []

lib/web/websocket/permessage-deflate.js

Lines changed: 13 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,40 +8,35 @@ const tail = Buffer.from([0x00, 0x00, 0xff, 0xff])
88
const kBuffer = Symbol('kBuffer')
99
const kLength = Symbol('kLength')
1010

11-
// Default maximum decompressed message size: 4 MB
12-
const kDefaultMaxDecompressedSize = 4 * 1024 * 1024
13-
1411
class PerMessageDeflate {
1512
/** @type {import('node:zlib').InflateRaw} */
1613
#inflate
1714

1815
#options = {}
1916

20-
/** @type {boolean} */
21-
#aborted = false
22-
23-
/** @type {Function|null} */
24-
#currentCallback = null
17+
#maxPayloadSize = 0
2518

2619
/**
2720
* @param {Map<string, string>} extensions
2821
*/
29-
constructor (extensions) {
22+
constructor (extensions, options) {
3023
this.#options.serverNoContextTakeover = extensions.has('server_no_context_takeover')
3124
this.#options.serverMaxWindowBits = extensions.get('server_max_window_bits')
25+
26+
this.#maxPayloadSize = options.maxPayloadSize
3227
}
3328

29+
/**
30+
* Decompress a compressed payload.
31+
* @param {Buffer} chunk Compressed data
32+
* @param {boolean} fin Final fragment flag
33+
* @param {Function} callback Callback function
34+
*/
3435
decompress (chunk, fin, callback) {
3536
// An endpoint uses the following algorithm to decompress a message.
3637
// 1. Append 4 octets of 0x00 0x00 0xff 0xff to the tail end of the
3738
// payload of the message.
3839
// 2. Decompress the resulting data using DEFLATE.
39-
40-
if (this.#aborted) {
41-
callback(new MessageSizeExceededError())
42-
return
43-
}
44-
4540
if (!this.#inflate) {
4641
let windowBits = Z_DEFAULT_WINDOWBITS
4742

@@ -64,23 +59,12 @@ class PerMessageDeflate {
6459
this.#inflate[kLength] = 0
6560

6661
this.#inflate.on('data', (data) => {
67-
if (this.#aborted) {
68-
return
69-
}
70-
7162
this.#inflate[kLength] += data.length
7263

73-
if (this.#inflate[kLength] > kDefaultMaxDecompressedSize) {
74-
this.#aborted = true
64+
if (this.#maxPayloadSize > 0 && this.#inflate[kLength] > this.#maxPayloadSize) {
65+
callback(new MessageSizeExceededError())
7566
this.#inflate.removeAllListeners()
76-
this.#inflate.destroy()
7767
this.#inflate = null
78-
79-
if (this.#currentCallback) {
80-
const cb = this.#currentCallback
81-
this.#currentCallback = null
82-
cb(new MessageSizeExceededError())
83-
}
8468
return
8569
}
8670

@@ -93,22 +77,20 @@ class PerMessageDeflate {
9377
})
9478
}
9579

96-
this.#currentCallback = callback
9780
this.#inflate.write(chunk)
9881
if (fin) {
9982
this.#inflate.write(tail)
10083
}
10184

10285
this.#inflate.flush(() => {
103-
if (this.#aborted || !this.#inflate) {
86+
if (!this.#inflate) {
10487
return
10588
}
10689

10790
const full = Buffer.concat(this.#inflate[kBuffer], this.#inflate[kLength])
10891

10992
this.#inflate[kBuffer].length = 0
11093
this.#inflate[kLength] = 0
111-
this.#currentCallback = null
11294

11395
callback(null, full)
11496
})

lib/web/websocket/receiver.js

Lines changed: 87 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ const {
1818
const { WebsocketFrameSend } = require('./frame')
1919
const { closeWebSocketConnection } = require('./connection')
2020
const { PerMessageDeflate } = require('./permessage-deflate')
21+
const { MessageSizeExceededError } = require('../../core/errors')
2122

2223
// This code was influenced by ws released under the MIT license.
2324
// Copyright (c) 2011 Einar Otto Stangvik <[email protected]>
@@ -26,6 +27,7 @@ const { PerMessageDeflate } = require('./permessage-deflate')
2627

2728
class ByteParser extends Writable {
2829
#buffers = []
30+
#fragmentsBytes = 0
2931
#byteOffset = 0
3032
#loop = false
3133

@@ -37,18 +39,23 @@ class ByteParser extends Writable {
3739
/** @type {Map<string, PerMessageDeflate>} */
3840
#extensions
3941

42+
/** @type {number} */
43+
#maxPayloadSize
44+
4045
/**
4146
* @param {import('./websocket').WebSocket} ws
4247
* @param {Map<string, string>|null} extensions
48+
* @param {{ maxPayloadSize?: number }} [options]
4349
*/
44-
constructor (ws, extensions) {
50+
constructor (ws, extensions, options = {}) {
4551
super()
4652

4753
this.ws = ws
4854
this.#extensions = extensions == null ? new Map() : extensions
55+
this.#maxPayloadSize = options.maxPayloadSize ?? 0
4956

5057
if (this.#extensions.has('permessage-deflate')) {
51-
this.#extensions.set('permessage-deflate', new PerMessageDeflate(extensions))
58+
this.#extensions.set('permessage-deflate', new PerMessageDeflate(extensions, options))
5259
}
5360
}
5461

@@ -64,6 +71,19 @@ class ByteParser extends Writable {
6471
this.run(callback)
6572
}
6673

74+
#validatePayloadLength () {
75+
if (
76+
this.#maxPayloadSize > 0 &&
77+
!isControlFrame(this.#info.opcode) &&
78+
this.#info.payloadLength > this.#maxPayloadSize
79+
) {
80+
failWebsocketConnection(this.ws, 'Payload size exceeds maximum allowed size')
81+
return false
82+
}
83+
84+
return true
85+
}
86+
6787
/**
6888
* Runs whenever a new chunk is received.
6989
* Callback is called whenever there are no more chunks buffering,
@@ -152,6 +172,10 @@ class ByteParser extends Writable {
152172
if (payloadLength <= 125) {
153173
this.#info.payloadLength = payloadLength
154174
this.#state = parserStates.READ_DATA
175+
176+
if (!this.#validatePayloadLength()) {
177+
return
178+
}
155179
} else if (payloadLength === 126) {
156180
this.#state = parserStates.PAYLOADLENGTH_16
157181
} else if (payloadLength === 127) {
@@ -176,6 +200,10 @@ class ByteParser extends Writable {
176200

177201
this.#info.payloadLength = buffer.readUInt16BE(0)
178202
this.#state = parserStates.READ_DATA
203+
204+
if (!this.#validatePayloadLength()) {
205+
return
206+
}
179207
} else if (this.#state === parserStates.PAYLOADLENGTH_64) {
180208
if (this.#byteOffset < 8) {
181209
return callback()
@@ -198,6 +226,10 @@ class ByteParser extends Writable {
198226

199227
this.#info.payloadLength = lower
200228
this.#state = parserStates.READ_DATA
229+
230+
if (!this.#validatePayloadLength()) {
231+
return
232+
}
201233
} else if (this.#state === parserStates.READ_DATA) {
202234
if (this.#byteOffset < this.#info.payloadLength) {
203235
return callback()
@@ -210,42 +242,53 @@ class ByteParser extends Writable {
210242
this.#state = parserStates.INFO
211243
} else {
212244
if (!this.#info.compressed) {
213-
this.#fragments.push(body)
245+
this.writeFragments(body)
246+
247+
if (this.#maxPayloadSize > 0 && this.#fragmentsBytes > this.#maxPayloadSize) {
248+
failWebsocketConnection(this.ws, new MessageSizeExceededError().message)
249+
return
250+
}
214251

215252
// If the frame is not fragmented, a message has been received.
216253
// If the frame is fragmented, it will terminate with a fin bit set
217254
// and an opcode of 0 (continuation), therefore we handle that when
218255
// parsing continuation frames, not here.
219256
if (!this.#info.fragmented && this.#info.fin) {
220-
const fullMessage = Buffer.concat(this.#fragments)
221-
websocketMessageReceived(this.ws, this.#info.binaryType, fullMessage)
222-
this.#fragments.length = 0
257+
websocketMessageReceived(this.ws, this.#info.binaryType, this.consumeFragments())
223258
}
224259

225260
this.#state = parserStates.INFO
226261
} else {
227-
this.#extensions.get('permessage-deflate').decompress(body, this.#info.fin, (error, data) => {
228-
if (error) {
229-
failWebsocketConnection(this.ws, error.message)
230-
return
231-
}
262+
this.#extensions.get('permessage-deflate').decompress(
263+
body,
264+
this.#info.fin,
265+
(error, data) => {
266+
if (error) {
267+
failWebsocketConnection(this.ws, error.message)
268+
return
269+
}
270+
271+
this.writeFragments(data)
272+
273+
if (this.#maxPayloadSize > 0 && this.#fragmentsBytes > this.#maxPayloadSize) {
274+
failWebsocketConnection(this.ws, new MessageSizeExceededError().message)
275+
return
276+
}
277+
278+
if (!this.#info.fin) {
279+
this.#state = parserStates.INFO
280+
this.#loop = true
281+
this.run(callback)
282+
return
283+
}
284+
285+
websocketMessageReceived(this.ws, this.#info.binaryType, this.consumeFragments())
232286

233-
this.#fragments.push(data)
234-
235-
if (!this.#info.fin) {
236-
this.#state = parserStates.INFO
237287
this.#loop = true
288+
this.#state = parserStates.INFO
238289
this.run(callback)
239-
return
240290
}
241-
242-
websocketMessageReceived(this.ws, this.#info.binaryType, Buffer.concat(this.#fragments))
243-
244-
this.#loop = true
245-
this.#state = parserStates.INFO
246-
this.#fragments.length = 0
247-
this.run(callback)
248-
})
291+
)
249292

250293
this.#loop = false
251294
break
@@ -297,6 +340,26 @@ class ByteParser extends Writable {
297340
return buffer
298341
}
299342

343+
writeFragments (fragment) {
344+
this.#fragmentsBytes += fragment.length
345+
this.#fragments.push(fragment)
346+
}
347+
348+
consumeFragments () {
349+
const fragments = this.#fragments
350+
351+
if (fragments.length === 1) {
352+
this.#fragmentsBytes = 0
353+
return fragments.shift()
354+
}
355+
356+
const output = Buffer.concat(fragments, this.#fragmentsBytes)
357+
this.#fragments = []
358+
this.#fragmentsBytes = 0
359+
360+
return output
361+
}
362+
300363
parseCloseBody (data) {
301364
assert(data.length !== 1)
302365

0 commit comments

Comments
 (0)