Skip to content

Commit 264d7a7

Browse files
committed
alt api key
1 parent e12fb6f commit 264d7a7

14 files changed

Lines changed: 151 additions & 42 deletions

File tree

lib/collection/src/shards/channel_service.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,25 @@ pub struct ChannelService {
2727
pub current_rest_port: u16,
2828
/// Instance wide API key if configured, must be used with care.
2929
pub api_key: Option<String>,
30+
31+
/// Alternative API key, works the same as `api_key`. Intended for rolling key updates.
32+
pub alt_api_key: Option<String>,
3033
}
3134

3235
impl ChannelService {
3336
/// Construct a new channel service with the given REST port.
34-
pub fn new(current_rest_port: u16, api_key: Option<String>) -> Self {
37+
pub fn new(
38+
current_rest_port: u16,
39+
api_key: Option<String>,
40+
alt_api_key: Option<String>,
41+
) -> Self {
3542
Self {
3643
id_to_address: Default::default(),
3744
id_to_metadata: Default::default(),
3845
channel_pool: Default::default(),
3946
current_rest_port,
4047
api_key,
48+
alt_api_key,
4149
}
4250
}
4351

@@ -222,6 +230,7 @@ impl Default for ChannelService {
222230
channel_pool: Default::default(),
223231
current_rest_port: 6333,
224232
api_key: None,
233+
alt_api_key: None,
225234
}
226235
}
227236
}

lib/collection/src/shards/transfer/snapshot.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,14 +236,21 @@ pub(super) async fn transfer_snapshot(
236236

237237
// Recover shard snapshot on remote
238238
log::trace!("Transferring and recovering shard {shard_id} snapshot on peer {remote_peer_id}");
239+
240+
// Since we are providing access to local instance, any of the API keys can be used
241+
let local_api_key = channel_service
242+
.api_key
243+
.as_deref()
244+
.or(channel_service.alt_api_key.as_deref());
245+
239246
remote_shard
240247
.recover_shard_snapshot_from_url(
241248
collection_id,
242249
shard_id,
243250
&shard_download_url,
244251
SnapshotPriority::ShardTransfer,
245252
// Provide API key here so the remote can access our snapshot
246-
channel_service.api_key.as_deref(),
253+
local_api_key,
247254
)
248255
.await
249256
.map_err(|err| {

lib/collection/tests/integration/common/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ pub async fn new_local_collection(
101101
Default::default(),
102102
CollectionShardDistribution::all_local(Some(config.params.shard_number.into()), 0),
103103
None,
104-
ChannelService::new(REST_PORT, None),
104+
ChannelService::new(REST_PORT, None, None),
105105
dummy_on_replica_failure(),
106106
dummy_request_shard_transfer(),
107107
dummy_abort_shard_transfer(),
@@ -136,7 +136,7 @@ pub async fn load_local_collection(
136136
path,
137137
snapshots_path,
138138
Default::default(),
139-
ChannelService::new(REST_PORT, None),
139+
ChannelService::new(REST_PORT, None, None),
140140
dummy_on_replica_failure(),
141141
dummy_request_shard_transfer(),
142142
dummy_abort_shard_transfer(),

lib/collection/tests/integration/continuous_snapshot_test.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ async fn test_continuous_snapshot() {
8080
Arc::new(storage_config),
8181
shard_distribution,
8282
None,
83-
ChannelService::new(REST_PORT, None),
83+
ChannelService::new(REST_PORT, None, None),
8484
dummy_on_replica_failure(),
8585
dummy_request_shard_transfer(),
8686
dummy_abort_shard_transfer(),

lib/collection/tests/integration/snapshot_recovery_test.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ async fn _test_snapshot_and_recover_collection(node_type: NodeType) {
7777
Arc::new(storage_config),
7878
shard_distribution,
7979
None,
80-
ChannelService::new(REST_PORT, None),
80+
ChannelService::new(REST_PORT, None, None),
8181
dummy_on_replica_failure(),
8282
dummy_request_shard_transfer(),
8383
dummy_abort_shard_transfer(),
@@ -137,7 +137,7 @@ async fn _test_snapshot_and_recover_collection(node_type: NodeType) {
137137
recover_dir.path(),
138138
snapshots_path.path(),
139139
Default::default(),
140-
ChannelService::new(REST_PORT, None),
140+
ChannelService::new(REST_PORT, None, None),
141141
dummy_on_replica_failure(),
142142
dummy_request_shard_transfer(),
143143
dummy_abort_shard_transfer(),

lib/storage/tests/integration/alias_tests.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ fn test_alias_operation() {
9191
update_runtime,
9292
general_runtime,
9393
ResourceBudget::default(),
94-
ChannelService::new(6333, None),
94+
ChannelService::new(6333, None, None),
9595
0,
9696
Some(propose_operation_sender),
9797
));

src/common/auth/mod.rs

Lines changed: 56 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::sync::Arc;
33
use collection::operations::shard_selector_internal::ShardSelectorInternal;
44
use collection::operations::types::ScrollRequestInternal;
55
use common::counter::hardware_accumulator::HwMeasurementAcc;
6+
use itertools::Itertools;
67
use segment::types::{WithPayloadInterface, WithVector};
78
use storage::content_manager::errors::StorageError;
89
use storage::content_manager::toc::TableOfContent;
@@ -24,12 +25,18 @@ pub struct AuthKeys {
2425
/// A key allowing Read or Write operations
2526
read_write: Option<String>,
2627

28+
/// Alternative to `read_write` key
29+
alt_read_write: Option<String>,
30+
2731
/// A key allowing Read operations
2832
read_only: Option<String>,
2933

3034
/// A JWT parser, based on the read_write key
3135
jwt_parser: Option<JwtParser>,
3236

37+
/// Alternative JWT parser, based on the alt_read_write key
38+
alt_jwt_parser: Option<JwtParser>,
39+
3340
/// Table of content, needed to do stateful validation of JWT
3441
toc: Arc<TableOfContent>,
3542
}
@@ -42,14 +49,20 @@ pub enum AuthError {
4249
}
4350

4451
impl AuthKeys {
45-
fn get_jwt_parser(service_config: &ServiceConfig) -> Option<JwtParser> {
52+
fn get_jwt_parser(service_config: &ServiceConfig) -> (Option<JwtParser>, Option<JwtParser>) {
4653
if service_config.jwt_rbac.unwrap_or_default() {
47-
service_config
48-
.api_key
49-
.as_ref()
50-
.map(|secret| JwtParser::new(secret))
54+
(
55+
service_config
56+
.api_key
57+
.as_ref()
58+
.map(|secret| JwtParser::new(secret)),
59+
service_config
60+
.alt_api_key
61+
.as_ref()
62+
.map(|secret| JwtParser::new(secret)),
63+
)
5164
} else {
52-
None
65+
(None, None)
5366
}
5467
}
5568

@@ -59,15 +72,22 @@ impl AuthKeys {
5972
pub fn try_create(service_config: &ServiceConfig, toc: Arc<TableOfContent>) -> Option<Self> {
6073
match (
6174
service_config.api_key.clone(),
75+
service_config.alt_api_key.clone(),
6276
service_config.read_only_api_key.clone(),
6377
) {
64-
(None, None) => None,
65-
(read_write, read_only) => Some(Self {
66-
read_write,
67-
read_only,
68-
jwt_parser: Self::get_jwt_parser(service_config),
69-
toc,
70-
}),
78+
(None, None, None) => None,
79+
(read_write, alt_read_write, read_only) => {
80+
let (jwt_parser, alt_jwt_parser) = Self::get_jwt_parser(service_config);
81+
82+
Some(Self {
83+
read_write,
84+
alt_read_write,
85+
read_only,
86+
jwt_parser,
87+
alt_jwt_parser,
88+
toc,
89+
})
90+
}
7191
}
7292
}
7393

@@ -98,13 +118,20 @@ impl AuthKeys {
98118
));
99119
}
100120

101-
if let Some(claims) = self.jwt_parser.as_ref().and_then(|p| p.decode(key)) {
121+
let (claims, errors): (Vec<_>, Vec<_>) =
122+
[self.jwt_parser.as_ref(), self.alt_jwt_parser.as_ref()]
123+
.into_iter()
124+
.flatten()
125+
.filter_map(|p| p.decode(key))
126+
.partition_result();
127+
128+
if let Some(claims) = claims.into_iter().next() {
102129
let Claims {
103130
sub,
104131
exp: _, // already validated on decoding
105132
access,
106133
value_exists,
107-
} = claims?;
134+
} = claims;
108135

109136
if let Some(value_exists) = value_exists {
110137
self.validate_value_exists(&value_exists).await?;
@@ -113,6 +140,12 @@ impl AuthKeys {
113140
return Ok((access, InferenceToken(sub)));
114141
}
115142

143+
// JWT parser exists, but can't decode the token
144+
if let Some(error) = errors.into_iter().next() {
145+
return Err(error);
146+
}
147+
148+
// No JWT parser configured
116149
Err(AuthError::Unauthorized(
117150
"Invalid API key or JWT".to_string(),
118151
))
@@ -167,8 +200,14 @@ impl AuthKeys {
167200
/// Check if a key is allowed to write
168201
#[inline]
169202
fn can_write(&self, key: &str) -> bool {
170-
self.read_write
203+
let can_write = self
204+
.read_write
205+
.as_ref()
206+
.is_some_and(|rw_key| ct_eq(rw_key, key));
207+
let alt_can_write = self
208+
.alt_read_write
171209
.as_ref()
172-
.is_some_and(|rw_key| ct_eq(rw_key, key))
210+
.is_some_and(|alt_rw_key| ct_eq(alt_rw_key, key));
211+
can_write || alt_can_write
173212
}
174213
}

src/consensus.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1486,7 +1486,7 @@ mod tests {
14861486
update_runtime,
14871487
general_runtime,
14881488
ResourceBudget::default(),
1489-
ChannelService::new(settings.service.http_port, None),
1489+
ChannelService::new(settings.service.http_port, None, None),
14901490
persistent_state.this_peer_id(),
14911491
Some(operation_sender.clone()),
14921492
);
@@ -1511,7 +1511,7 @@ mod tests {
15111511
6335,
15121512
ConsensusConfig::default(),
15131513
None,
1514-
ChannelService::new(settings.service.http_port, None),
1514+
ChannelService::new(settings.service.http_port, None, None),
15151515
handle.clone(),
15161516
false,
15171517
)

src/main.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,8 +348,11 @@ fn main() -> anyhow::Result<()> {
348348

349349
// Channel service is used to manage connections between peers.
350350
// It allocates required number of channels and manages proper reconnection handling
351-
let mut channel_service =
352-
ChannelService::new(settings.service.http_port, settings.service.api_key.clone());
351+
let mut channel_service = ChannelService::new(
352+
settings.service.http_port,
353+
settings.service.api_key.clone(),
354+
settings.service.alt_api_key.clone(),
355+
);
353356

354357
if is_distributed_deployment {
355358
// We only need channel_service if the cluster is enabled.

src/settings.rs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ pub struct ServiceConfig {
3535
#[serde(default)]
3636
pub verify_https_client_certificate: bool,
3737
pub api_key: Option<String>,
38+
39+
/// Same as `api_key`, can be used for rolling key rotation.
40+
pub alt_api_key: Option<String>,
41+
3842
pub read_only_api_key: Option<String>,
3943
#[serde(default)]
4044
pub jwt_rbac: Option<bool>,
@@ -308,14 +312,25 @@ impl Settings {
308312
// Using HMAC-SHA256, recommended secret size is 32 bytes
309313
const JWT_RECOMMENDED_SECRET_LENGTH: usize = 256 / 8;
310314

315+
let all_keys_are_empty = self.service.api_key.clone().unwrap_or_default().is_empty()
316+
&& self
317+
.service
318+
.alt_api_key
319+
.clone()
320+
.unwrap_or_default()
321+
.is_empty();
322+
323+
let any_api_key_is_short = self.service.api_key.clone().unwrap_or_default().len()
324+
< JWT_RECOMMENDED_SECRET_LENGTH
325+
|| self.service.alt_api_key.clone().unwrap_or_default().len()
326+
< JWT_RECOMMENDED_SECRET_LENGTH;
327+
311328
// Log if JWT RBAC is enabled but no API key is set
312329
if self.service.jwt_rbac.unwrap_or_default() {
313-
if self.service.api_key.clone().unwrap_or_default().is_empty() {
330+
if all_keys_are_empty {
314331
log::warn!("JWT RBAC configured but no API key set, JWT RBAC is not enabled")
315332
// Log if JWT RBAC is enabled and an API key is set, but it is shorter than the recommended size for a JWT secret
316-
} else if self.service.api_key.clone().unwrap_or_default().len()
317-
< JWT_RECOMMENDED_SECRET_LENGTH
318-
{
333+
} else if any_api_key_is_short {
319334
log::warn!(
320335
"It is highly recommended to use an API key of {JWT_RECOMMENDED_SECRET_LENGTH} bytes when JWT RBAC is enabled",
321336
)

0 commit comments

Comments
 (0)