Skip to content

Commit 1a0862e

Browse files
authored
Write rate limit batch update per point count (#6152)
1 parent f4d4d40 commit 1a0862e

2 files changed

Lines changed: 72 additions & 17 deletions

File tree

lib/collection/src/shards/replica_set/update.rs

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use super::{ReplicaSetState, ReplicaState, ShardReplicaSet, clock_set};
1010
use crate::operations::point_ops::WriteOrdering;
1111
use crate::operations::types::{CollectionError, CollectionResult, UpdateResult, UpdateStatus};
1212
use crate::operations::{ClockTag, CollectionUpdateOperations, OperationWithClockTag};
13-
use crate::shards::shard::PeerId;
13+
use crate::shards::shard::{PeerId, Shard};
1414
use crate::shards::shard_trait::ShardOperation as _;
1515

1616
/// Maximum number of attempts for applying an update with a new clock.
@@ -54,8 +54,7 @@ impl ShardReplicaSet {
5454
let result = match state {
5555
ReplicaState::Active => {
5656
// Rate limit update operations on Active replica
57-
// TODO(ratelimits) determine cost of update based on operation
58-
self.check_write_rate_limiter(1, &hw_measurement)?;
57+
self.check_operation_write_rate_limiter(&hw_measurement, local, &operation)?;
5958
local.get().update(operation, wait, hw_measurement).await
6059
}
6160

@@ -296,19 +295,11 @@ impl ShardReplicaSet {
296295

297296
if self.peer_is_active(this_peer_id) {
298297
// Check write rate limiter before proceeding if replica active
299-
// TODO(ratelimits) determine cost of update based on operation
300-
301-
self.check_write_rate_limiter_lazy(&hw_measurement_acc, || {
302-
let mut ratelimiter_cost = 1;
303-
304-
// Estimate the cost based on affected points if filter is available.
305-
match local.estimate_request_cardinality(&operation.operation) {
306-
Ok(est) => ratelimiter_cost = 1.max(est.exp),
307-
Err(err) => log::error!("Estimating cardinality: {err:?}"),
308-
}
309-
310-
ratelimiter_cost
311-
})?;
298+
self.check_operation_write_rate_limiter(
299+
&hw_measurement_acc,
300+
local,
301+
&operation,
302+
)?;
312303
}
313304

314305
let operation = operation.clone();
@@ -541,6 +532,29 @@ impl ShardReplicaSet {
541532
Ok(Some(res))
542533
}
543534

535+
/// Check write rate limiter for the operation
536+
///
537+
/// Lazily compute the cost of the operation and check against the write rate limiter
538+
fn check_operation_write_rate_limiter(
539+
&self,
540+
hw_measurement: &HwMeasurementAcc,
541+
local: &Shard,
542+
operation: &OperationWithClockTag,
543+
) -> CollectionResult<()> {
544+
self.check_write_rate_limiter_lazy(hw_measurement, || {
545+
let mut ratelimiter_cost = 1;
546+
547+
// Estimate the cost based on affected points if filter is available.
548+
match local.estimate_request_cardinality(&operation.operation) {
549+
Ok(est) => ratelimiter_cost = 1.max(est.exp),
550+
Err(err) => log::error!("Estimating cardinality: {err:?}"),
551+
}
552+
553+
ratelimiter_cost
554+
})?;
555+
Ok(())
556+
}
557+
544558
/// Whether to send updates to the given peer
545559
///
546560
/// A peer in dead state, or a locally disabled peer, will not accept updates.

tests/openapi/test_strictmode.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1091,7 +1091,7 @@ def test_strict_mode_write_rate_limiting(collection_name):
10911091
assert response.ok, "Rate limiting should be disabled now"
10921092

10931093

1094-
def test_strict_mode_write_rate_limiting_update_op(collection_name):
1094+
def test_strict_mode_write_rate_limiting_filtered_update_op(collection_name):
10951095
set_strict_mode(collection_name, {
10961096
"enabled": True,
10971097
"write_rate_limit": 7,
@@ -1131,6 +1131,47 @@ def test_strict_mode_write_rate_limiting_update_op(collection_name):
11311131
assert response.status_code == 429
11321132
assert "Rate limiting exceeded: Write rate limit exceeded: Operation requires 5 tokens but only" in response.json()['status']['error']
11331133

1134+
def test_strict_mode_write_rate_limiting_batch_update_op(collection_name):
1135+
def upsert_points(ids: list[int]):
1136+
length = len(ids)
1137+
payloads = [{} for _ in range(length)]
1138+
vectors = [[1, 2, 3, 5] for _ in range(length)]
1139+
return request_with_validation(
1140+
api='/collections/{collection_name}/points/batch',
1141+
method="POST",
1142+
path_params={'collection_name': collection_name},
1143+
body={
1144+
"operations": [
1145+
{
1146+
"upsert": {
1147+
"batch": {
1148+
"ids": ids,
1149+
"payloads": payloads,
1150+
"vectors": vectors
1151+
}
1152+
}
1153+
}
1154+
]
1155+
}
1156+
)
1157+
1158+
set_strict_mode(collection_name, {
1159+
"enabled": True,
1160+
"write_rate_limit": 10,
1161+
})
1162+
1163+
# validate that updates with 11 points will never be allowed
1164+
response = upsert_points([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
1165+
assert response.status_code == 429
1166+
assert "Rate limiting exceeded: Write rate limit exceeded, request larger than than rate limiter capacity, please try to split your request" in response.json()['status']['error']
1167+
1168+
# validate that updates with 10 points is allowed because there are enough tokens for each point
1169+
upsert_points([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).raise_for_status()
1170+
1171+
# doing it again fails because we already consumed 10 tokens
1172+
response = upsert_points([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
1173+
assert response.status_code == 429
1174+
assert "Rate limiting exceeded: Write rate limit exceeded: Operation requires 10 tokens but only 0.0 were available. Retry after 60s" in response.json()['status']['error']
11341175

11351176
def test_filter_many_conditions(collection_name):
11361177
def search_request(condition_count: int):

0 commit comments

Comments
 (0)