Skip to content

Commit ec544ff

Browse files
authored
[score boosting] support datetime expressions (#6196)
* use f64 as internal expression score * allow datetime as expression * add test cases * fix schema, generate openapi and grpc docs * fmt * clippy * fix test precision * graceful handling of too large numbers, more tests * rename to DatetimeExpression for consistency * use datetime's timestamp as seconds * capture variable name when parsing * homogenize DateTime into Datetime * use interface distinction between datetime strings and payload keys * fix rebase * fix after rebase
1 parent dce4a94 commit ec544ff

13 files changed

Lines changed: 424 additions & 195 deletions

File tree

docs/grpc/docs.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2715,6 +2715,8 @@ The JSON representation for `Value` is a JSON value.
27152715
| variable | [string](#string) | | Payload key or reference to score. |
27162716
| condition | [Condition](#qdrant-Condition) | | Payload condition. If true, becomes 1.0; otherwise 0.0 |
27172717
| geo_distance | [GeoDistance](#qdrant-GeoDistance) | | Geographic distance in meters |
2718+
| datetime | [string](#string) | | Date-time constant |
2719+
| datetime_key | [string](#string) | | Payload key with date-time values |
27182720
| mult | [MultExpression](#qdrant-MultExpression) | | Multiply |
27192721
| sum | [SumExpression](#qdrant-SumExpression) | | Sum |
27202722
| div | [DivExpression](#qdrant-DivExpression) | | Divide |

docs/redoc/master/openapi.json

Lines changed: 58 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14470,6 +14470,15 @@
1447014470
{
1447114471
"$ref": "#/components/schemas/Condition"
1447214472
},
14473+
{
14474+
"$ref": "#/components/schemas/GeoDistance"
14475+
},
14476+
{
14477+
"$ref": "#/components/schemas/DatetimeExpression"
14478+
},
14479+
{
14480+
"$ref": "#/components/schemas/DatetimeKeyExpression"
14481+
},
1447314482
{
1447414483
"$ref": "#/components/schemas/MultExpression"
1447514484
},
@@ -14500,9 +14509,6 @@
1450014509
{
1450114510
"$ref": "#/components/schemas/LnExpression"
1450214511
},
14503-
{
14504-
"$ref": "#/components/schemas/GeoDistance"
14505-
},
1450614512
{
1450714513
"$ref": "#/components/schemas/LinDecayExpression"
1450814514
},
@@ -14514,6 +14520,55 @@
1451414520
}
1451514521
]
1451614522
},
14523+
"GeoDistance": {
14524+
"type": "object",
14525+
"required": [
14526+
"geo_distance"
14527+
],
14528+
"properties": {
14529+
"geo_distance": {
14530+
"$ref": "#/components/schemas/GeoDistanceParams"
14531+
}
14532+
}
14533+
},
14534+
"GeoDistanceParams": {
14535+
"type": "object",
14536+
"required": [
14537+
"origin",
14538+
"to"
14539+
],
14540+
"properties": {
14541+
"origin": {
14542+
"$ref": "#/components/schemas/GeoPoint"
14543+
},
14544+
"to": {
14545+
"description": "Payload field with the destination geo point",
14546+
"type": "string"
14547+
}
14548+
}
14549+
},
14550+
"DatetimeExpression": {
14551+
"type": "object",
14552+
"required": [
14553+
"datetime"
14554+
],
14555+
"properties": {
14556+
"datetime": {
14557+
"type": "string"
14558+
}
14559+
}
14560+
},
14561+
"DatetimeKeyExpression": {
14562+
"type": "object",
14563+
"required": [
14564+
"datetime_key"
14565+
],
14566+
"properties": {
14567+
"datetime_key": {
14568+
"type": "string"
14569+
}
14570+
}
14571+
},
1451714572
"MultExpression": {
1451814573
"type": "object",
1451914574
"required": [
@@ -14665,33 +14720,6 @@
1466514720
}
1466614721
}
1466714722
},
14668-
"GeoDistance": {
14669-
"type": "object",
14670-
"required": [
14671-
"geo_distance"
14672-
],
14673-
"properties": {
14674-
"geo_distance": {
14675-
"$ref": "#/components/schemas/GeoDistanceParams"
14676-
}
14677-
}
14678-
},
14679-
"GeoDistanceParams": {
14680-
"type": "object",
14681-
"required": [
14682-
"origin",
14683-
"to"
14684-
],
14685-
"properties": {
14686-
"origin": {
14687-
"$ref": "#/components/schemas/GeoPoint"
14688-
},
14689-
"to": {
14690-
"description": "Payload field with the destination geo point",
14691-
"type": "string"
14692-
}
14693-
}
14694-
},
1469514723
"LinDecayExpression": {
1469614724
"type": "object",
1469714725
"required": [

lib/api/src/grpc/conversions.rs

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use std::time::Instant;
55
use chrono::{NaiveDateTime, Timelike};
66
use common::counter::hardware_accumulator::HwMeasurementAcc;
77
use common::counter::hardware_data::HardwareData;
8+
use common::types::ScoreType;
89
use itertools::Itertools;
910
use segment::common::operation_error::OperationError;
1011
use segment::data_types::index::{
@@ -13,7 +14,7 @@ use segment::data_types::index::{
1314
};
1415
use segment::data_types::{facets as segment_facets, vectors as segment_vectors};
1516
use segment::index::query_optimization::rescore_formula::parsed_formula::{
16-
DecayKind, ParsedExpression, ParsedFormula,
17+
DatetimeExpression, DecayKind, ParsedExpression, ParsedFormula,
1718
};
1819
use segment::types::{DateTimePayloadType, FloatPayloadType, default_quantization_ignore_value};
1920
use segment::vector_storage::query as segment_query;
@@ -2796,14 +2797,26 @@ fn unparse_expression(
27962797
use super::expression::Variant;
27972798

27982799
let variant = match expression {
2799-
ParsedExpression::Constant(c) => Variant::Constant(c),
2800+
ParsedExpression::Constant(c) => Variant::Constant(c as ScoreType),
28002801
ParsedExpression::Variable(variable_id) => match variable_id {
28012802
var_id @ VariableId::Score(_) => Variant::Variable(var_id.unparse()),
28022803
var_id @ VariableId::Payload(_) => Variant::Variable(var_id.unparse()),
28032804
VariableId::Condition(cond_idx) => {
28042805
Variant::Condition(Condition::from(conditions[cond_idx].clone()))
28052806
}
28062807
},
2808+
ParsedExpression::GeoDistance { origin, key } => Variant::GeoDistance(GeoDistance {
2809+
origin: Some(GeoPoint::from(origin)),
2810+
to: key.to_string(),
2811+
}),
2812+
ParsedExpression::Datetime(dt_expr) => match dt_expr {
2813+
DatetimeExpression::Constant(date_time_wrapper) => {
2814+
Variant::Datetime(date_time_wrapper.to_string())
2815+
}
2816+
DatetimeExpression::PayloadVariable(json_path) => {
2817+
Variant::DatetimeKey(json_path.to_string())
2818+
}
2819+
},
28072820
ParsedExpression::Mult(exprs) => Variant::Mult(MultExpression {
28082821
mult: exprs
28092822
.into_iter()
@@ -2826,7 +2839,7 @@ fn unparse_expression(
28262839
} => Variant::Div(Box::new(DivExpression {
28272840
left: Some(Box::new(unparse_expression(*left, conditions))),
28282841
right: Some(Box::new(unparse_expression(*right, conditions))),
2829-
by_zero_default,
2842+
by_zero_default: by_zero_default.map(|v| v as f32),
28302843
})),
28312844
ParsedExpression::Sqrt(expr) => {
28322845
Variant::Sqrt(Box::new(unparse_expression(*expr, conditions)))
@@ -2845,10 +2858,6 @@ fn unparse_expression(
28452858
ParsedExpression::Abs(expr) => {
28462859
Variant::Abs(Box::new(unparse_expression(*expr, conditions)))
28472860
}
2848-
ParsedExpression::GeoDistance { origin, key } => Variant::GeoDistance(GeoDistance {
2849-
origin: Some(GeoPoint::from(origin)),
2850-
to: key.to_string(),
2851-
}),
28522861
ParsedExpression::Decay {
28532862
kind,
28542863
target,

lib/api/src/grpc/proto/points.proto

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -583,19 +583,21 @@ message Expression {
583583
string variable = 2; // Payload key or reference to score.
584584
Condition condition = 3; // Payload condition. If true, becomes 1.0; otherwise 0.0
585585
GeoDistance geo_distance = 4; // Geographic distance in meters
586-
MultExpression mult = 5; // Multiply
587-
SumExpression sum = 6; // Sum
588-
DivExpression div = 7; // Divide
589-
Expression neg = 8; // Negate
590-
Expression abs = 9; // Absolute value
591-
Expression sqrt = 10; // Square root
592-
PowExpression pow = 11; // Power
593-
Expression exp = 12; // Exponential
594-
Expression log10 = 13; // Logarithm
595-
Expression ln = 14; // Natural logarithm
596-
DecayParamsExpression exp_decay = 15; // Exponential decay
597-
DecayParamsExpression gauss_decay = 16; // Gaussian decay
598-
DecayParamsExpression lin_decay = 17; // Linear decay
586+
string datetime = 5; // Date-time constant
587+
string datetime_key = 6; // Payload key with date-time values
588+
MultExpression mult = 7; // Multiply
589+
SumExpression sum = 8; // Sum
590+
DivExpression div = 9; // Divide
591+
Expression neg = 10; // Negate
592+
Expression abs = 11; // Absolute value
593+
Expression sqrt = 12; // Square root
594+
PowExpression pow = 13; // Power
595+
Expression exp = 14; // Exponential
596+
Expression log10 = 15; // Logarithm
597+
Expression ln = 16; // Natural logarithm
598+
DecayParamsExpression exp_decay = 17; // Exponential decay
599+
DecayParamsExpression gauss_decay = 18; // Gaussian decay
600+
DecayParamsExpression lin_decay = 19; // Linear decay
599601
}
600602
}
601603

lib/api/src/grpc/qdrant.rs

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5182,7 +5182,7 @@ pub struct Formula {
51825182
pub struct Expression {
51835183
#[prost(
51845184
oneof = "expression::Variant",
5185-
tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17"
5185+
tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19"
51865186
)]
51875187
pub variant: ::core::option::Option<expression::Variant>,
51885188
}
@@ -5203,44 +5203,50 @@ pub mod expression {
52035203
/// Geographic distance in meters
52045204
#[prost(message, tag = "4")]
52055205
GeoDistance(super::GeoDistance),
5206+
/// Date-time constant
5207+
#[prost(string, tag = "5")]
5208+
Datetime(::prost::alloc::string::String),
5209+
/// Payload key with date-time values
5210+
#[prost(string, tag = "6")]
5211+
DatetimeKey(::prost::alloc::string::String),
52065212
/// Multiply
5207-
#[prost(message, tag = "5")]
5213+
#[prost(message, tag = "7")]
52085214
Mult(super::MultExpression),
52095215
/// Sum
5210-
#[prost(message, tag = "6")]
5216+
#[prost(message, tag = "8")]
52115217
Sum(super::SumExpression),
52125218
/// Divide
5213-
#[prost(message, tag = "7")]
5219+
#[prost(message, tag = "9")]
52145220
Div(::prost::alloc::boxed::Box<super::DivExpression>),
52155221
/// Negate
5216-
#[prost(message, tag = "8")]
5222+
#[prost(message, tag = "10")]
52175223
Neg(::prost::alloc::boxed::Box<super::Expression>),
52185224
/// Absolute value
5219-
#[prost(message, tag = "9")]
5225+
#[prost(message, tag = "11")]
52205226
Abs(::prost::alloc::boxed::Box<super::Expression>),
52215227
/// Square root
5222-
#[prost(message, tag = "10")]
5228+
#[prost(message, tag = "12")]
52235229
Sqrt(::prost::alloc::boxed::Box<super::Expression>),
52245230
/// Power
5225-
#[prost(message, tag = "11")]
5231+
#[prost(message, tag = "13")]
52265232
Pow(::prost::alloc::boxed::Box<super::PowExpression>),
52275233
/// Exponential
5228-
#[prost(message, tag = "12")]
5234+
#[prost(message, tag = "14")]
52295235
Exp(::prost::alloc::boxed::Box<super::Expression>),
52305236
/// Logarithm
5231-
#[prost(message, tag = "13")]
5237+
#[prost(message, tag = "15")]
52325238
Log10(::prost::alloc::boxed::Box<super::Expression>),
52335239
/// Natural logarithm
5234-
#[prost(message, tag = "14")]
5240+
#[prost(message, tag = "16")]
52355241
Ln(::prost::alloc::boxed::Box<super::Expression>),
52365242
/// Exponential decay
5237-
#[prost(message, tag = "15")]
5243+
#[prost(message, tag = "17")]
52385244
ExpDecay(::prost::alloc::boxed::Box<super::DecayParamsExpression>),
52395245
/// Gaussian decay
5240-
#[prost(message, tag = "16")]
5246+
#[prost(message, tag = "18")]
52415247
GaussDecay(::prost::alloc::boxed::Box<super::DecayParamsExpression>),
52425248
/// Linear decay
5243-
#[prost(message, tag = "17")]
5249+
#[prost(message, tag = "19")]
52445250
LinDecay(::prost::alloc::boxed::Box<super::DecayParamsExpression>),
52455251
}
52465252
}

lib/api/src/rest/schema.rs

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,9 @@ pub enum Expression {
640640
Constant(f32),
641641
Variable(String),
642642
Condition(Box<Condition>),
643+
GeoDistance(GeoDistance),
644+
Datetime(DatetimeExpression),
645+
DatetimeKey(DatetimeKeyExpression),
643646
Mult(MultExpression),
644647
Sum(SumExpression),
645648
Neg(NegExpression),
@@ -650,12 +653,34 @@ pub enum Expression {
650653
Exp(ExpExpression),
651654
Log10(Log10Expression),
652655
Ln(LnExpression),
653-
GeoDistance(GeoDistance),
654656
LinDecay(LinDecayExpression),
655657
ExpDecay(ExpDecayExpression),
656658
GaussDecay(GaussDecayExpression),
657659
}
658660

661+
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
662+
pub struct GeoDistance {
663+
pub geo_distance: GeoDistanceParams,
664+
}
665+
666+
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
667+
pub struct GeoDistanceParams {
668+
/// The origin geo point to measure from
669+
pub origin: GeoPoint,
670+
/// Payload field with the destination geo point
671+
pub to: JsonPath,
672+
}
673+
674+
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
675+
pub struct DatetimeExpression {
676+
pub datetime: String,
677+
}
678+
679+
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
680+
pub struct DatetimeKeyExpression {
681+
pub datetime_key: String,
682+
}
683+
659684
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
660685
pub struct MultExpression {
661686
pub mult: Vec<Expression>,
@@ -746,19 +771,6 @@ pub struct DecayParamsExpression {
746771
pub midpoint: Option<f32>,
747772
}
748773

749-
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
750-
pub struct GeoDistance {
751-
pub geo_distance: GeoDistanceParams,
752-
}
753-
754-
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
755-
pub struct GeoDistanceParams {
756-
/// The origin geo point to measure from
757-
pub origin: GeoPoint,
758-
/// Payload field with the destination geo point
759-
pub to: JsonPath,
760-
}
761-
762774
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
763775
#[serde(rename_all = "snake_case")]
764776
pub enum Sample {

0 commit comments

Comments
 (0)