Skip to content

Commit 4149e6d

Browse files
authored
[score boosting] Error on unexpected type (#6187)
* helper for getting payload value
* Error instead of silent default
* fix clippy
* fix openapi test
1 parent 1fff666 commit 4149e6d

3 files changed

Lines changed: 98 additions & 46 deletions

File tree

lib/segment/src/common/operation_error.rs

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@ use crate::utils::mem::Mem;
1414

1515
pub const PROCESS_CANCELLED_BY_SERVICE_MESSAGE: &str = "process cancelled by service";
1616

17-
#[derive(Error, Debug, Clone)]
17+
#[derive(Error, Debug, Clone, PartialEq)]
1818
#[error("{0}")]
1919
pub enum OperationError {
2020
#[error("Vector dimension error: expected dim: {expected_dim}, got {received_dim}")]
@@ -62,10 +62,13 @@ pub enum OperationError {
6262
"No appropriate index for faceting: `{key}`. Please create one to facet on this field. Check https://qdrant.tech/documentation/concepts/indexing/#payload-index to see which payload schemas support Match conditions"
6363
)]
6464
MissingMapIndexForFacet { key: String },
65-
#[error("The variable nor the default value for {field_name} is a {expected_type}")]
65+
#[error(
66+
"Expected {expected_type} value for {field_name} in the payload and/or in the formula defaults. Error: {description}"
67+
)]
6668
VariableTypeError {
6769
field_name: PayloadKeyType,
6870
expected_type: String,
71+
description: String,
6972
},
7073
#[error("The expression {expression} produced a non-finite number")]
7174
NonFiniteNumber { expression: String },

lib/segment/src/index/query_optimization/rescore_formula/formula_scorer.rs

Lines changed: 86 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,22 @@ pub struct FormulaScorer<'a> {
3232
defaults: HashMap<VariableId, Value>,
3333
}
3434

35+
/// Supplies a short, human-readable name for a type.
///
/// In this file the name fills the `expected_type` field of
/// `OperationError::VariableTypeError` when a payload value cannot be
/// converted to the requested type.
pub trait FriendlyName {
36+
/// Returns the human-readable name of the implementing type.
fn friendly_name() -> &'static str;
37+
}
38+
39+
/// Score values are named "number" in type-mismatch errors.
impl FriendlyName for ScoreType {
40+
fn friendly_name() -> &'static str {
41+
"number"
42+
}
43+
}
44+
45+
/// Geo points are named "geo point" in type-mismatch errors.
impl FriendlyName for GeoPoint {
46+
fn friendly_name() -> &'static str {
47+
"geo point"
48+
}
49+
}
50+
3551
impl StructPayloadIndex {
3652
pub fn formula_scorer<'s, 'q>(
3753
&'s self,
@@ -97,24 +113,29 @@ impl FormulaScorer<'_> {
97113
.map(|score| score as ScoreType)
98114
})
99115
.unwrap_or(DEFAULT_SCORE)),
100-
VariableId::Payload(path) => Ok(self
101-
.payload_retrievers
102-
.get(path)
103-
.and_then(|retriever| retriever(point_id))
104-
.and_then(|value| value.as_f64())
105-
.or_else(|| {
106-
self.defaults
107-
.get(&VariableId::Payload(path.clone()))
108-
.and_then(|value| value.as_f64())
116+
VariableId::Payload(path) => {
117+
self.get_parsed_payload_value(path, point_id, |value| {
118+
value
119+
.as_f64()
120+
.map(|value| value as ScoreType)
121+
.ok_or("Value is not a number")
109122
})
110-
.map(|v| v as ScoreType)
111-
.unwrap_or(DEFAULT_SCORE)),
123+
}
112124
VariableId::Condition(id) => {
113125
let value = check_condition(&self.condition_checkers[*id], point_id);
114126
let score = if value { 1.0 } else { 0.0 };
115127
Ok(score)
116128
}
117129
},
130+
ParsedExpression::GeoDistance { origin, key } => {
131+
let value = self.get_parsed_payload_value(
132+
key,
133+
point_id,
134+
serde_json::from_value::<GeoPoint>,
135+
)?;
136+
137+
Ok(Haversine::distance((*origin).into(), value.into()) as ScoreType)
138+
}
118139
ParsedExpression::Mult(expressions) => {
119140
let mut product = 1.0;
120141
for expr in expressions {
@@ -216,26 +237,47 @@ impl FormulaScorer<'_> {
216237
let value = self.eval_expression(expr, point_id)?;
217238
Ok(value.abs())
218239
}
219-
ParsedExpression::GeoDistance { origin, key } => {
220-
let value: GeoPoint = self
221-
.payload_retrievers
222-
.get(key)
223-
.and_then(|retriever| retriever(point_id))
224-
.and_then(|value| serde_json::from_value(value).ok())
225-
.or_else(|| {
226-
self.defaults
227-
.get(&VariableId::Payload(key.clone()))
228-
.and_then(|value| serde_json::from_value(value.clone()).ok())
229-
})
230-
.ok_or_else(|| OperationError::VariableTypeError {
231-
field_name: key.clone(),
232-
expected_type: "geo point".into(),
233-
})?;
234-
235-
Ok(Haversine::distance((*origin).into(), value.into()) as ScoreType)
236-
}
237240
}
238241
}
242+
243+
/// Looks up the retriever registered for `json_path` and invokes it for `point_id`.
///
/// Returns `None` when no retriever is registered for the path, or when the
/// retriever itself yields no value for the point.
fn get_payload_value(&self, json_path: &JsonPath, point_id: PointOffsetType) -> Option<Value> {
244+
self.payload_retrievers
245+
.get(json_path)
246+
.and_then(|retriever| retriever(point_id))
247+
}
248+
249+
/// Tries to get a value from payload or from the defaults. Then tries to convert it to the desired type.
///
/// The payload value takes precedence: the defaults are consulted only when
/// the payload yields no value for `json_path`. A payload value that is
/// present but fails conversion produces an error rather than falling back
/// to the default.
///
/// # Errors
///
/// Returns `OperationError::VariableTypeError` when:
/// - `from_value` rejects the value found; `description` carries the
///   stringified conversion error.
/// - neither the payload nor the defaults contain a value for `json_path`.
250+
fn get_parsed_payload_value<T, F, E>(
251+
&self,
252+
json_path: &JsonPath,
253+
point_id: PointOffsetType,
254+
from_value: F,
255+
) -> OperationResult<T>
256+
where
257+
F: Fn(Value) -> Result<T, E>,
258+
E: ToString,
259+
T: FriendlyName,
260+
{
261+
self.get_payload_value(json_path, point_id)
262+
// Fall back to the configured default only if the payload had nothing.
.or_else(|| {
263+
self.defaults
264+
.get(&VariableId::Payload(json_path.clone()))
265+
.cloned()
266+
})
267+
// Convert whichever value was found, wrapping a conversion failure.
.map(|value| {
268+
from_value(value).map_err(|e| OperationError::VariableTypeError {
269+
field_name: json_path.clone(),
270+
expected_type: T::friendly_name().to_owned(),
271+
description: e.to_string(),
272+
})
273+
})
274+
// Option<Result<T, _>> -> Result<Option<T>, _>, so `?` surfaces conversion errors first.
.transpose()?
275+
// No value anywhere: report it as an error instead of defaulting silently.
.ok_or_else(|| OperationError::VariableTypeError {
276+
field_name: json_path.clone(),
277+
expected_type: T::friendly_name().to_owned(),
278+
description: "No value found in a payload nor defaults".to_string(),
279+
})
280+
}
239281
}
240282

241283
#[cfg(test)]
@@ -322,7 +364,7 @@ mod tests {
322364
GeoPoint { lat: 25.717877679163667, lon: -100.43383200156751 }, JsonPath::new(GEO_FIELD_NAME)
323365
), 21926.494)]
324366
#[should_panic(
325-
expected = r#"VariableTypeError { field_name: JsonPath { first_key: "number", rest: [] }, expected_type: "geo point" }"#
367+
expected = r#"VariableTypeError { field_name: JsonPath { first_key: "number", rest: [] }, expected_type: "geo point", "#
326368
)]
327369
#[case(ParsedExpression::new_geo_distance(GeoPoint { lat: 25.717877679163667, lon: -100.43383200156751 }, JsonPath::new(FIELD_NAME)), 0.0)]
328370
#[should_panic(expected = r#"NonFiniteNumber { expression: "-1^0.4 = NaN" }"#)]
@@ -355,23 +397,30 @@ mod tests {
355397
// Default values
356398
#[rstest]
357399
// Defined default score
358-
#[case(ParsedExpression::new_score_id(3), 1.5)]
400+
#[case(ParsedExpression::new_score_id(3), Ok(1.5))]
359401
// score idx not defined
360-
#[case(ParsedExpression::new_score_id(10), DEFAULT_SCORE)]
402+
#[case(ParsedExpression::new_score_id(10), Ok(DEFAULT_SCORE))]
361403
// missing value in payload
362404
#[case(
363405
ParsedExpression::new_payload_id(JsonPath::new(NO_VALUE_FIELD_NAME)),
364-
85.0
406+
Ok(85.0)
365407
)]
366408
// missing value and no default value provided
367409
#[case(
368410
ParsedExpression::new_payload_id(JsonPath::new("missing_field")),
369-
DEFAULT_SCORE
411+
Err(OperationError::VariableTypeError {
412+
field_name: JsonPath::new("missing_field"),
413+
expected_type: ScoreType::friendly_name().to_string(),
414+
description: "No value found in a payload nor defaults".to_string(),
415+
})
370416
)]
371417
// geo distance with default value
372-
#[case(ParsedExpression::new_geo_distance(GeoPoint { lat: 25.717877679163667, lon: -100.43383200156751 }, JsonPath::new(NO_VALUE_GEO_POINT)), 90951.3)]
418+
#[case(ParsedExpression::new_geo_distance(GeoPoint { lat: 25.717877679163667, lon: -100.43383200156751 }, JsonPath::new(NO_VALUE_GEO_POINT)), Ok(90951.3))]
373419
#[test]
374-
fn test_default_values(#[case] expr: ParsedExpression, #[case] expected: ScoreType) {
420+
fn test_default_values(
421+
#[case] expr: ParsedExpression,
422+
#[case] expected: OperationResult<ScoreType>,
423+
) {
375424
let defaults = [
376425
(VariableId::Score(3), json!(1.5)),
377426
(
@@ -390,6 +439,6 @@ mod tests {
390439

391440
let scorer = scorer_fixture.borrow_dependent();
392441

393-
assert_eq!(scorer.eval_expression(&expr, 0).unwrap(), expected);
442+
assert_eq!(scorer.eval_expression(&expr, 0), expected);
394443
}
395444
}

tests/openapi/test_query_formula.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@ def test_formula(collection_name, formula, expecting):
6161

6262
query = {
6363
"prefetch": {"query": point_id},
64-
"query": {"formula": formula},
64+
"query": {"formula": formula, "defaults": {"price": 0.0}},
6565
"with_payload": True,
6666
}
6767

@@ -77,9 +77,9 @@ def test_formula(collection_name, formula, expecting):
7777
# Assert that the response is in descending order
7878
points = response.json()["result"]["points"]
7979
scores = [point.get("score") for point in points]
80-
assert all(
81-
scores[i] >= scores[i + 1] for i in range(len(scores) - 1)
82-
), "Results should be ordered by score descending"
80+
assert all(scores[i] >= scores[i + 1] for i in range(len(scores) - 1)), (
81+
"Results should be ordered by score descending"
82+
)
8383

8484
# Sanity check that the evaluation was correct
8585
for point in points:
@@ -100,9 +100,9 @@ def test_formula(collection_name, formula, expecting):
100100
point_score = point.get("score")
101101

102102
# Compare with actual score within floating point precision
103-
assert isclose(
104-
point_score, expected_score, rel_tol=1e-5
105-
), f"Expected score {expected_score}, got {point_score}. Point: {point}"
103+
assert isclose(point_score, expected_score, rel_tol=1e-5), (
104+
f"Expected score {expected_score}, got {point_score}. Point: {point}"
105+
)
106106

107107
# Assert that the response contains all points
108108
assert len(points) == len(orig_scores), "Response should contain all points"

0 commit comments

Comments
 (0)