Skip to content

Commit d043cd1

Browse files
committed
Add inference usage in errors as well
1 parent 3dc8b0a commit d043cd1

1 file changed

Lines changed: 35 additions & 17 deletions

File tree

src/actix/api/query_api.rs

Lines changed: 35 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,9 @@ use actix_web_validator::{Json, Path, Query};
33
use api::rest::models::{InferenceUsage, ModelUsage};
44
use api::rest::{QueryGroupsRequest, QueryRequest, QueryRequestBatch, QueryResponse};
55
use collection::operations::shard_selector_internal::ShardSelectorInternal;
6-
use collection::operations::universal_query::collection_query::CollectionQueryGroupsRequestWithUsage;
6+
use collection::operations::universal_query::collection_query::{
7+
CollectionQueryGroupsRequestWithUsage, CollectionQueryRequestWithUsage,
8+
};
79
use itertools::Itertools;
810
use storage::content_manager::collection_verification::{
911
check_strict_mode, check_strict_mode_batch,
@@ -159,7 +161,7 @@ async fn query_points_batch(
159161
shard_key,
160162
} = request_item;
161163

162-
let collection::operations::universal_query::collection_query::CollectionQueryRequestWithUsage { request, usage } =
164+
let CollectionQueryRequestWithUsage { request, usage } =
163165
convert_query_request_from_rest(internal, &inference_token).await?;
164166

165167
inference_usages.push(usage);
@@ -176,17 +178,17 @@ async fn query_points_batch(
176178
let mut total_usage = InferenceUsage::default();
177179
for inference_usage in inference_usages.iter_mut().flatten() {
178180
let usage = inference_usage;
179-
for (model, usage) in usage.models.iter() {
180-
total_usage
181-
.models
182-
.entry(model.clone())
183-
.and_modify(|e| {
184-
e.tokens += usage.tokens;
185-
})
186-
.or_insert_with(|| ModelUsage {
187-
tokens: usage.tokens,
188-
});
189-
}
181+
for (model, usage) in usage.models.iter() {
182+
total_usage
183+
.models
184+
.entry(model.clone())
185+
.and_modify(|e| {
186+
e.tokens += usage.tokens;
187+
})
188+
.or_insert_with(|| ModelUsage {
189+
tokens: usage.tokens,
190+
});
191+
}
190192
}
191193

192194
let inference_usage: Option<InferenceUsage> = {
@@ -204,7 +206,11 @@ async fn query_points_batch(
204206
&dispatcher,
205207
&access,
206208
)
207-
.await?;
209+
.await
210+
.map_err(|err| StorageError::InferenceError {
211+
description: err.to_string(),
212+
usage: inference_usage.clone(),
213+
})?;
208214

209215
let res = dispatcher
210216
.toc(&access, &pass)
@@ -216,7 +222,11 @@ async fn query_points_batch(
216222
params.timeout(),
217223
hw_measurement_acc,
218224
)
219-
.await?
225+
.await
226+
.map_err(|err| StorageError::InferenceError {
227+
description: err.to_string(),
228+
usage: inference_usage.clone(),
229+
})?
220230
.into_iter()
221231
.map(|response| QueryResponse {
222232
points: response
@@ -287,7 +297,11 @@ async fn query_points_groups(
287297
&dispatcher,
288298
&access,
289299
)
290-
.await?;
300+
.await
301+
.map_err(|err| StorageError::InferenceError {
302+
description: err.to_string(),
303+
usage: usage.clone(),
304+
})?;
291305

292306
let query_result = do_query_point_groups(
293307
dispatcher.toc(&access, &pass),
@@ -299,7 +313,11 @@ async fn query_points_groups(
299313
params.timeout(),
300314
hw_measurement_acc,
301315
)
302-
.await?;
316+
.await
317+
.map_err(|err| StorageError::InferenceError {
318+
description: err.to_string(),
319+
usage: usage.clone(),
320+
})?;
303321
Ok((query_result, usage))
304322
}
305323
.await;

0 commit comments

Comments (0)