@@ -32,6 +32,22 @@ pub struct FormulaScorer<'a> {
3232 defaults : HashMap < VariableId , Value > ,
3333}
3434
35+ pub trait FriendlyName {
36+ fn friendly_name ( ) -> & ' static str ;
37+ }
38+
39+ impl FriendlyName for ScoreType {
40+ fn friendly_name ( ) -> & ' static str {
41+ "number"
42+ }
43+ }
44+
45+ impl FriendlyName for GeoPoint {
46+ fn friendly_name ( ) -> & ' static str {
47+ "geo point"
48+ }
49+ }
50+
3551impl StructPayloadIndex {
3652 pub fn formula_scorer < ' s , ' q > (
3753 & ' s self ,
@@ -97,24 +113,29 @@ impl FormulaScorer<'_> {
97113 . map ( |score| score as ScoreType )
98114 } )
99115 . unwrap_or ( DEFAULT_SCORE ) ) ,
100- VariableId :: Payload ( path) => Ok ( self
101- . payload_retrievers
102- . get ( path)
103- . and_then ( |retriever| retriever ( point_id) )
104- . and_then ( |value| value. as_f64 ( ) )
105- . or_else ( || {
106- self . defaults
107- . get ( & VariableId :: Payload ( path. clone ( ) ) )
108- . and_then ( |value| value. as_f64 ( ) )
116+ VariableId :: Payload ( path) => {
117+ self . get_parsed_payload_value ( path, point_id, |value| {
118+ value
119+ . as_f64 ( )
120+ . map ( |value| value as ScoreType )
121+ . ok_or ( "Value is not a number" )
109122 } )
110- . map ( |v| v as ScoreType )
111- . unwrap_or ( DEFAULT_SCORE ) ) ,
123+ }
112124 VariableId :: Condition ( id) => {
113125 let value = check_condition ( & self . condition_checkers [ * id] , point_id) ;
114126 let score = if value { 1.0 } else { 0.0 } ;
115127 Ok ( score)
116128 }
117129 } ,
130+ ParsedExpression :: GeoDistance { origin, key } => {
131+ let value = self . get_parsed_payload_value (
132+ key,
133+ point_id,
134+ serde_json:: from_value :: < GeoPoint > ,
135+ ) ?;
136+
137+ Ok ( Haversine :: distance ( ( * origin) . into ( ) , value. into ( ) ) as ScoreType )
138+ }
118139 ParsedExpression :: Mult ( expressions) => {
119140 let mut product = 1.0 ;
120141 for expr in expressions {
@@ -216,26 +237,47 @@ impl FormulaScorer<'_> {
216237 let value = self . eval_expression ( expr, point_id) ?;
217238 Ok ( value. abs ( ) )
218239 }
219- ParsedExpression :: GeoDistance { origin, key } => {
220- let value: GeoPoint = self
221- . payload_retrievers
222- . get ( key)
223- . and_then ( |retriever| retriever ( point_id) )
224- . and_then ( |value| serde_json:: from_value ( value) . ok ( ) )
225- . or_else ( || {
226- self . defaults
227- . get ( & VariableId :: Payload ( key. clone ( ) ) )
228- . and_then ( |value| serde_json:: from_value ( value. clone ( ) ) . ok ( ) )
229- } )
230- . ok_or_else ( || OperationError :: VariableTypeError {
231- field_name : key. clone ( ) ,
232- expected_type : "geo point" . into ( ) ,
233- } ) ?;
234-
235- Ok ( Haversine :: distance ( ( * origin) . into ( ) , value. into ( ) ) as ScoreType )
236- }
237240 }
238241 }
242+
243+ fn get_payload_value ( & self , json_path : & JsonPath , point_id : PointOffsetType ) -> Option < Value > {
244+ self . payload_retrievers
245+ . get ( json_path)
246+ . and_then ( |retriever| retriever ( point_id) )
247+ }
248+
249+ /// Tries to get a value from payload or from the defaults. Then tries to convert it to the desired type.
250+ fn get_parsed_payload_value < T , F , E > (
251+ & self ,
252+ json_path : & JsonPath ,
253+ point_id : PointOffsetType ,
254+ from_value : F ,
255+ ) -> OperationResult < T >
256+ where
257+ F : Fn ( Value ) -> Result < T , E > ,
258+ E : ToString ,
259+ T : FriendlyName ,
260+ {
261+ self . get_payload_value ( json_path, point_id)
262+ . or_else ( || {
263+ self . defaults
264+ . get ( & VariableId :: Payload ( json_path. clone ( ) ) )
265+ . cloned ( )
266+ } )
267+ . map ( |value| {
268+ from_value ( value) . map_err ( |e| OperationError :: VariableTypeError {
269+ field_name : json_path. clone ( ) ,
270+ expected_type : T :: friendly_name ( ) . to_owned ( ) ,
271+ description : e. to_string ( ) ,
272+ } )
273+ } )
274+ . transpose ( ) ?
275+ . ok_or_else ( || OperationError :: VariableTypeError {
276+ field_name : json_path. clone ( ) ,
277+ expected_type : T :: friendly_name ( ) . to_owned ( ) ,
278+ description : "No value found in a payload nor defaults" . to_string ( ) ,
279+ } )
280+ }
239281}
240282
241283#[ cfg( test) ]
@@ -322,7 +364,7 @@ mod tests {
322364 GeoPoint { lat: 25.717877679163667 , lon: -100.43383200156751 } , JsonPath :: new( GEO_FIELD_NAME )
323365 ) , 21926.494 ) ]
324366 #[ should_panic(
325- expected = r#"VariableTypeError { field_name: JsonPath { first_key: "number", rest: [] }, expected_type: "geo point" } "#
367+ expected = r#"VariableTypeError { field_name: JsonPath { first_key: "number", rest: [] }, expected_type: "geo point", "#
326368 ) ]
327369 #[ case( ParsedExpression :: new_geo_distance( GeoPoint { lat: 25.717877679163667 , lon: -100.43383200156751 } , JsonPath :: new( FIELD_NAME ) ) , 0.0 ) ]
328370 #[ should_panic( expected = r#"NonFiniteNumber { expression: "-1^0.4 = NaN" }"# ) ]
@@ -355,23 +397,30 @@ mod tests {
355397 // Default values
356398 #[ rstest]
357399 // Defined default score
358- #[ case( ParsedExpression :: new_score_id( 3 ) , 1.5 ) ]
400+ #[ case( ParsedExpression :: new_score_id( 3 ) , Ok ( 1.5 ) ) ]
359401 // score idx not defined
360- #[ case( ParsedExpression :: new_score_id( 10 ) , DEFAULT_SCORE ) ]
402+ #[ case( ParsedExpression :: new_score_id( 10 ) , Ok ( DEFAULT_SCORE ) ) ]
361403 // missing value in payload
362404 #[ case(
363405 ParsedExpression :: new_payload_id( JsonPath :: new( NO_VALUE_FIELD_NAME ) ) ,
364- 85.0
406+ Ok ( 85.0 )
365407 ) ]
366408 // missing value and no default value provided
367409 #[ case(
368410 ParsedExpression :: new_payload_id( JsonPath :: new( "missing_field" ) ) ,
369- DEFAULT_SCORE
411+ Err ( OperationError :: VariableTypeError {
412+ field_name: JsonPath :: new( "missing_field" ) ,
413+ expected_type: ScoreType :: friendly_name( ) . to_string( ) ,
414+ description: "No value found in a payload nor defaults" . to_string( ) ,
415+ } )
370416 ) ]
371417 // geo distance with default value
372- #[ case( ParsedExpression :: new_geo_distance( GeoPoint { lat: 25.717877679163667 , lon: -100.43383200156751 } , JsonPath :: new( NO_VALUE_GEO_POINT ) ) , 90951.3 ) ]
418+ #[ case( ParsedExpression :: new_geo_distance( GeoPoint { lat: 25.717877679163667 , lon: -100.43383200156751 } , JsonPath :: new( NO_VALUE_GEO_POINT ) ) , Ok ( 90951.3 ) ) ]
373419 #[ test]
374- fn test_default_values ( #[ case] expr : ParsedExpression , #[ case] expected : ScoreType ) {
420+ fn test_default_values (
421+ #[ case] expr : ParsedExpression ,
422+ #[ case] expected : OperationResult < ScoreType > ,
423+ ) {
375424 let defaults = [
376425 ( VariableId :: Score ( 3 ) , json ! ( 1.5 ) ) ,
377426 (
@@ -390,6 +439,6 @@ mod tests {
390439
391440 let scorer = scorer_fixture. borrow_dependent ( ) ;
392441
393- assert_eq ! ( scorer. eval_expression( & expr, 0 ) . unwrap ( ) , expected) ;
442+ assert_eq ! ( scorer. eval_expression( & expr, 0 ) , expected) ;
394443 }
395444}
0 commit comments