@@ -32,6 +32,22 @@ pub struct FormulaScorer<'a> {
3232 defaults : HashMap < VariableId , Value > ,
3333}
3434
35+ pub trait FriendlyName {
36+ fn friendly_name ( ) -> & ' static str ;
37+ }
38+
39+ impl FriendlyName for ScoreType {
40+ fn friendly_name ( ) -> & ' static str {
41+ "number"
42+ }
43+ }
44+
45+ impl FriendlyName for GeoPoint {
46+ fn friendly_name ( ) -> & ' static str {
47+ "geo point"
48+ }
49+ }
50+
3551impl StructPayloadIndex {
3652 pub fn formula_scorer < ' s , ' q > (
3753 & ' s self ,
@@ -97,38 +113,29 @@ impl FormulaScorer<'_> {
97113 . map ( |score| score as ScoreType )
98114 } )
99115 . unwrap_or ( DEFAULT_SCORE ) ) ,
100- VariableId :: Payload ( path) => Ok ( self
101- . get_payload_value ( path, point_id)
102- . and_then ( |value| value. as_f64 ( ) )
103- . or_else ( || {
104- self . defaults
105- . get ( & VariableId :: Payload ( path. clone ( ) ) )
106- . and_then ( |value| value. as_f64 ( ) )
116+ VariableId :: Payload ( path) => {
117+ self . get_parsed_payload_value ( path, point_id, |value| {
118+ value
119+ . as_f64 ( )
120+ . map ( |value| value as ScoreType )
121+ . ok_or_else ( || "Value is not a number" )
107122 } )
108- . map ( |v| v as ScoreType )
109- . unwrap_or ( DEFAULT_SCORE ) ) ,
123+ }
110124 VariableId :: Condition ( id) => {
111125 let value = check_condition ( & self . condition_checkers [ * id] , point_id) ;
112126 let score = if value { 1.0 } else { 0.0 } ;
113127 Ok ( score)
114128 }
115129 } ,
116130 ParsedExpression :: GeoDistance { origin, key } => {
117- let value: GeoPoint = self
118- . get_payload_value ( key, point_id)
119- . and_then ( |value| serde_json:: from_value :: < GeoPoint > ( value) . ok ( ) )
120- . or_else ( || {
121- self . defaults
122- . get ( & VariableId :: Payload ( key. clone ( ) ) )
123- . and_then ( |value| serde_json:: from_value ( value. clone ( ) ) . ok ( ) )
124- } )
125- . ok_or_else ( || OperationError :: VariableTypeError {
126- field_name : key. clone ( ) ,
127- expected_type : "geo point" . into ( ) ,
128- } ) ?;
131+ let value = self . get_parsed_payload_value (
132+ key,
133+ point_id,
134+ serde_json:: from_value :: < GeoPoint > ,
135+ ) ?;
129136
130137 Ok ( Haversine :: distance ( ( * origin) . into ( ) , value. into ( ) ) as ScoreType )
131- } ,
138+ }
132139 ParsedExpression :: Mult ( expressions) => {
133140 let mut product = 1.0 ;
134141 for expr in expressions {
@@ -238,6 +245,39 @@ impl FormulaScorer<'_> {
238245 . get ( json_path)
239246 . and_then ( |retriever| retriever ( point_id) )
240247 }
248+
249+ /// Tries to get a value from payload or from the defaults. Then tries to convert it to the desired type.
250+ fn get_parsed_payload_value < T , F , E > (
251+ & self ,
252+ json_path : & JsonPath ,
253+ point_id : PointOffsetType ,
254+ from_value : F ,
255+ ) -> OperationResult < T >
256+ where
257+ F : Fn ( Value ) -> Result < T , E > ,
258+ E : ToString ,
259+ T : FriendlyName ,
260+ {
261+ self . get_payload_value ( json_path, point_id)
262+ . or_else ( || {
263+ self . defaults
264+ . get ( & VariableId :: Payload ( json_path. clone ( ) ) )
265+ . cloned ( )
266+ } )
267+ . map ( |value| {
268+ from_value ( value) . map_err ( |e| OperationError :: VariableTypeError {
269+ field_name : json_path. clone ( ) ,
270+ expected_type : T :: friendly_name ( ) . to_owned ( ) ,
271+ description : e. to_string ( ) ,
272+ } )
273+ } )
274+ . transpose ( ) ?
275+ . ok_or_else ( || OperationError :: VariableTypeError {
276+ field_name : json_path. clone ( ) ,
277+ expected_type : T :: friendly_name ( ) . to_owned ( ) ,
278+ description : "No value found in a payload nor defaults" . to_string ( ) ,
279+ } )
280+ }
241281}
242282
243283#[ cfg( test) ]
@@ -324,7 +364,7 @@ mod tests {
324364 GeoPoint { lat: 25.717877679163667 , lon: -100.43383200156751 } , JsonPath :: new( GEO_FIELD_NAME )
325365 ) , 21926.494 ) ]
326366 #[ should_panic(
327- expected = r#"VariableTypeError { field_name: JsonPath { first_key: "number", rest: [] }, expected_type: "geo point" } "#
367+ expected = r#"VariableTypeError { field_name: JsonPath { first_key: "number", rest: [] }, expected_type: "geo point", "#
328368 ) ]
329369 #[ case( ParsedExpression :: new_geo_distance( GeoPoint { lat: 25.717877679163667 , lon: -100.43383200156751 } , JsonPath :: new( FIELD_NAME ) ) , 0.0 ) ]
330370 #[ should_panic( expected = r#"NonFiniteNumber { expression: "-1^0.4 = NaN" }"# ) ]
@@ -357,23 +397,30 @@ mod tests {
357397 // Default values
358398 #[ rstest]
359399 // Defined default score
360- #[ case( ParsedExpression :: new_score_id( 3 ) , 1.5 ) ]
400+ #[ case( ParsedExpression :: new_score_id( 3 ) , Ok ( 1.5 ) ) ]
361401 // score idx not defined
362- #[ case( ParsedExpression :: new_score_id( 10 ) , DEFAULT_SCORE ) ]
402+ #[ case( ParsedExpression :: new_score_id( 10 ) , Ok ( DEFAULT_SCORE ) ) ]
363403 // missing value in payload
364404 #[ case(
365405 ParsedExpression :: new_payload_id( JsonPath :: new( NO_VALUE_FIELD_NAME ) ) ,
366- 85.0
406+ Ok ( 85.0 )
367407 ) ]
368408 // missing value and no default value provided
369409 #[ case(
370410 ParsedExpression :: new_payload_id( JsonPath :: new( "missing_field" ) ) ,
371- DEFAULT_SCORE
411+ Err ( OperationError :: VariableTypeError {
412+ field_name: JsonPath :: new( "missing_field" ) ,
413+ expected_type: ScoreType :: friendly_name( ) . to_string( ) ,
414+ description: "No value found in a payload nor defaults" . to_string( ) ,
415+ } )
372416 ) ]
373417 // geo distance with default value
374- #[ case( ParsedExpression :: new_geo_distance( GeoPoint { lat: 25.717877679163667 , lon: -100.43383200156751 } , JsonPath :: new( NO_VALUE_GEO_POINT ) ) , 90951.3 ) ]
418+ #[ case( ParsedExpression :: new_geo_distance( GeoPoint { lat: 25.717877679163667 , lon: -100.43383200156751 } , JsonPath :: new( NO_VALUE_GEO_POINT ) ) , Ok ( 90951.3 ) ) ]
375419 #[ test]
376- fn test_default_values ( #[ case] expr : ParsedExpression , #[ case] expected : ScoreType ) {
420+ fn test_default_values (
421+ #[ case] expr : ParsedExpression ,
422+ #[ case] expected : OperationResult < ScoreType > ,
423+ ) {
377424 let defaults = [
378425 ( VariableId :: Score ( 3 ) , json ! ( 1.5 ) ) ,
379426 (
@@ -392,6 +439,6 @@ mod tests {
392439
393440 let scorer = scorer_fixture. borrow_dependent ( ) ;
394441
395- assert_eq ! ( scorer. eval_expression( & expr, 0 ) . unwrap ( ) , expected) ;
442+ assert_eq ! ( scorer. eval_expression( & expr, 0 ) , expected) ;
396443 }
397444}
0 commit comments