Skip to content

Commit 0eca749

Browse files
committed
Add decay expressions
1 parent 4b74671 commit 0eca749

10 files changed

Lines changed: 462 additions & 8 deletions

File tree

docs/grpc/docs.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@
138138
- [CountResult](#qdrant-CountResult)
139139
- [CreateFieldIndexCollection](#qdrant-CreateFieldIndexCollection)
140140
- [DatetimeRange](#qdrant-DatetimeRange)
141+
- [DecayParamsExpression](#qdrant-DecayParamsExpression)
141142
- [DeleteFieldIndexCollection](#qdrant-DeleteFieldIndexCollection)
142143
- [DeletePayloadPoints](#qdrant-DeletePayloadPoints)
143144
- [DeletePointVectors](#qdrant-DeletePointVectors)
@@ -2446,6 +2447,24 @@ The JSON representation for `Value` is a JSON value.
24462447

24472448

24482449

2450+
<a name="qdrant-DecayParamsExpression"></a>
2451+
2452+
### DecayParamsExpression
2453+
2454+
2455+
2456+
| Field | Type | Label | Description |
2457+
| ----- | ---- | ----- | ----------- |
2458+
| x | [Expression](#qdrant-Expression) | | The variable to decay |
2459+
| target | [Expression](#qdrant-Expression) | optional | The target value to start decaying from. Defaults to 0. |
2460+
| scale | [float](#float) | optional | The scale factor of the decay, in terms of `x`. Defaults to 1.0. Must be a non-zero positive number. |
2461+
| midpoint | [float](#float) | optional | The midpoint of the decay. Defaults to 0.5. Output will be this value when `|x - target| == scale`. |
2462+
2463+
2464+
2465+
2466+
2467+
24492468
<a name="qdrant-DeleteFieldIndexCollection"></a>
24502469

24512470
### DeleteFieldIndexCollection
@@ -2706,6 +2725,9 @@ The JSON representation for `Value` is a JSON value.
27062725
| exp | [Expression](#qdrant-Expression) | | Exponential |
27072726
| log10 | [Expression](#qdrant-Expression) | | Logarithm |
27082727
| ln | [Expression](#qdrant-Expression) | | Natural logarithm |
2728+
| exp_decay | [DecayParamsExpression](#qdrant-DecayParamsExpression) | | Exponential decay |
2729+
| gauss_decay | [DecayParamsExpression](#qdrant-DecayParamsExpression) | | Gaussian decay |
2730+
| lin_decay | [DecayParamsExpression](#qdrant-DecayParamsExpression) | | Linear decay |
27092731

27102732

27112733

docs/redoc/master/openapi.json

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14490,6 +14490,15 @@
1449014490
},
1449114491
{
1449214492
"$ref": "#/components/schemas/GeoDistance"
14493+
},
14494+
{
14495+
"$ref": "#/components/schemas/LinDecayExpression"
14496+
},
14497+
{
14498+
"$ref": "#/components/schemas/ExpDecayExpression"
14499+
},
14500+
{
14501+
"$ref": "#/components/schemas/GaussDecayExpression"
1449314502
}
1449414503
]
1449514504
},
@@ -14671,6 +14680,73 @@
1467114680
}
1467214681
}
1467314682
},
14683+
"LinDecayExpression": {
14684+
"type": "object",
14685+
"required": [
14686+
"lin_decay"
14687+
],
14688+
"properties": {
14689+
"lin_decay": {
14690+
"$ref": "#/components/schemas/DecayParamsExpression"
14691+
}
14692+
}
14693+
},
14694+
"DecayParamsExpression": {
14695+
"type": "object",
14696+
"required": [
14697+
"x"
14698+
],
14699+
"properties": {
14700+
"x": {
14701+
"$ref": "#/components/schemas/Expression"
14702+
},
14703+
"target": {
14704+
"description": "The target value to start decaying from. Defaults to 0.",
14705+
"anyOf": [
14706+
{
14707+
"$ref": "#/components/schemas/Expression"
14708+
},
14709+
{
14710+
"nullable": true
14711+
}
14712+
]
14713+
},
14714+
"scale": {
14715+
"description": "The scale factor of the decay, in terms of `x`. Defaults to 1.0. Must be a non-zero positive number.",
14716+
"type": "number",
14717+
"format": "float",
14718+
"nullable": true
14719+
},
14720+
"midpoint": {
14721+
"description": "The midpoint of the decay. Defaults to 0.5. Output will be this value when `|x - target| == scale`.",
14722+
"type": "number",
14723+
"format": "float",
14724+
"nullable": true
14725+
}
14726+
}
14727+
},
14728+
"ExpDecayExpression": {
14729+
"type": "object",
14730+
"required": [
14731+
"exp_decay"
14732+
],
14733+
"properties": {
14734+
"exp_decay": {
14735+
"$ref": "#/components/schemas/DecayParamsExpression"
14736+
}
14737+
}
14738+
},
14739+
"GaussDecayExpression": {
14740+
"type": "object",
14741+
"required": [
14742+
"gauss_decay"
14743+
],
14744+
"properties": {
14745+
"gauss_decay": {
14746+
"$ref": "#/components/schemas/DecayParamsExpression"
14747+
}
14748+
}
14749+
},
1467414750
"SampleQuery": {
1467514751
"type": "object",
1467614752
"required": [

lib/api/src/grpc/conversions.rs

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use segment::data_types::index::{
1313
};
1414
use segment::data_types::{facets as segment_facets, vectors as segment_vectors};
1515
use segment::index::query_optimization::rescore_formula::parsed_formula::{
16-
ParsedExpression, ParsedFormula,
16+
DecayKind, ParsedExpression, ParsedFormula,
1717
};
1818
use segment::types::{DateTimePayloadType, FloatPayloadType, default_quantization_ignore_value};
1919
use segment::vector_storage::query as segment_query;
@@ -50,7 +50,9 @@ use crate::grpc::qdrant::{
5050
TextIndexParams, TokenizerType, UpdateResult, UpdateResultInternal, ValuesCount,
5151
VectorsSelector, WithPayloadSelector, WithVectorsSelector, shard_key, with_vectors_selector,
5252
};
53-
use crate::grpc::{DivExpression, GeoDistance, MultExpression, PowExpression, SumExpression};
53+
use crate::grpc::{
54+
DecayParamsExpression, DivExpression, GeoDistance, MultExpression, PowExpression, SumExpression,
55+
};
5456
use crate::rest::models::{CollectionsResponse, VersionInfo};
5557
use crate::rest::schema as rest;
5658

@@ -2780,14 +2782,14 @@ impl Formula {
27802782
}
27812783

27822784
fn unparse_expression(
2783-
formula: ParsedExpression,
2785+
expression: ParsedExpression,
27842786
conditions: &Vec<segment::types::Condition>,
27852787
) -> Expression {
27862788
use segment::index::query_optimization::rescore_formula::parsed_formula::VariableId;
27872789

27882790
use super::expression::Variant;
27892791

2790-
let variant = match formula {
2792+
let variant = match expression {
27912793
ParsedExpression::Constant(c) => Variant::Constant(c),
27922794
ParsedExpression::Variable(variable_id) => match variable_id {
27932795
var_id @ VariableId::Score(_) => Variant::Variable(var_id.unparse()),
@@ -2841,9 +2843,66 @@ fn unparse_expression(
28412843
origin: Some(GeoPoint::from(origin)),
28422844
to: key.to_string(),
28432845
}),
2846+
ParsedExpression::Decay {
2847+
kind,
2848+
target,
2849+
lambda,
2850+
x,
2851+
} => {
2852+
let (midpoint, scale) = decay_lambda_to_params(lambda, kind);
2853+
let params = DecayParamsExpression {
2854+
x: Some(Box::new(unparse_expression(*x, conditions))),
2855+
target: target.map(|t| Box::new(unparse_expression(*t, conditions))),
2856+
midpoint: Some(midpoint),
2857+
scale: Some(scale),
2858+
};
2859+
match kind {
2860+
DecayKind::Lin => Variant::LinDecay(Box::new(params)),
2861+
DecayKind::Exp => Variant::ExpDecay(Box::new(params)),
2862+
DecayKind::Gauss => Variant::GaussDecay(Box::new(params)),
2863+
}
2864+
}
28442865
};
28452866

28462867
Expression {
28472868
variant: Some(variant),
28482869
}
28492870
}
2871+
2872+
/// Converts the already computed lambda value to parameters which will result in
2873+
/// the same lambda when used in a decay function on the peer node.
2874+
///
2875+
/// Returns a tuple of (midpoint, scale) parameters.
2876+
fn decay_lambda_to_params(lambda: f32, kind: DecayKind) -> (f32, f32) {
2877+
// We assume lambda is in the range (0, 1)
2878+
debug_assert!(0.0 < lambda && lambda < 1.0);
2879+
match kind {
2880+
// Linear lambda is (1.0 - midpoint) / scale,
2881+
// setting scale to 1.0 allows us to ignore the division,
2882+
// and only set the midpoint to some value.
2883+
//
2884+
// (1.0 - midpoint) / 1.0 = lambda
2885+
// 1.0 - midpoint = lambda
2886+
// midpoint = 1.0 - lambda
2887+
DecayKind::Lin => ((-lambda + 1.0), 1.0),
2888+
2889+
// Gauss lambda is scale^2 / ln(midpoint)
2890+
// setting midpoint to e allows us to ignore the division, since ln(e) = 1
2891+
// Then we set scale to sqrt(lambda)
2892+
//
2893+
// scale^2 / ln(e) = lambda
2894+
// scale^2 / 1.0 = lambda
2895+
// scale^2 = lambda
2896+
// scale = sqrt(lambda)
2897+
DecayKind::Gauss => (std::f32::consts::E, lambda.sqrt()),
2898+
2899+
// Exponential lambda is ln(midpoint) / scale
2900+
// setting midpoint to e allows us to ignore the division, since ln(e) = 1
2901+
// Then we set scale to 1 / lambda
2902+
//
2903+
// ln(e) / scale = lambda
2904+
// 1.0 / scale = lambda
2905+
// scale = 1.0 / lambda
2906+
DecayKind::Exp => (std::f32::consts::E, 1.0 / lambda),
2907+
}
2908+
}

lib/api/src/grpc/proto/points.proto

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,9 @@ message Expression {
593593
Expression exp = 12; // Exponential
594594
Expression log10 = 13; // Logarithm
595595
Expression ln = 14; // Natural logarithm
596+
DecayParamsExpression exp_decay = 15; // Exponential decay
597+
DecayParamsExpression gauss_decay = 16; // Gaussian decay
598+
DecayParamsExpression lin_decay = 17; // Linear decay
596599
}
597600
}
598601

@@ -620,6 +623,17 @@ message PowExpression {
620623
Expression exponent = 2;
621624
}
622625

626+
message DecayParamsExpression {
627+
// The variable to decay
628+
Expression x = 1;
629+
// The target value to start decaying from. Defaults to 0.
630+
optional Expression target = 2;
631+
// The scale factor of the decay, in terms of `x`. Defaults to 1.0. Must be a non-zero positive number.
632+
optional float scale = 3;
633+
// The midpoint of the decay. Defaults to 0.5. Output will be this value when `|x - target| == scale`.
634+
optional float midpoint = 4;
635+
}
636+
623637
message Query {
624638
oneof variant {
625639
VectorInput nearest = 1; // Find the nearest neighbors to this vector.

lib/api/src/grpc/qdrant.rs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5182,7 +5182,7 @@ pub struct Formula {
51825182
pub struct Expression {
51835183
#[prost(
51845184
oneof = "expression::Variant",
5185-
tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14"
5185+
tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17"
51865186
)]
51875187
pub variant: ::core::option::Option<expression::Variant>,
51885188
}
@@ -5233,6 +5233,15 @@ pub mod expression {
52335233
/// Natural logarithm
52345234
#[prost(message, tag = "14")]
52355235
Ln(::prost::alloc::boxed::Box<super::Expression>),
5236+
/// Exponential decay
5237+
#[prost(message, tag = "15")]
5238+
ExpDecay(::prost::alloc::boxed::Box<super::DecayParamsExpression>),
5239+
/// Gaussian decay
5240+
#[prost(message, tag = "16")]
5241+
GaussDecay(::prost::alloc::boxed::Box<super::DecayParamsExpression>),
5242+
/// Linear decay
5243+
#[prost(message, tag = "17")]
5244+
LinDecay(::prost::alloc::boxed::Box<super::DecayParamsExpression>),
52365245
}
52375246
}
52385247
#[derive(serde::Serialize)]
@@ -5281,6 +5290,23 @@ pub struct PowExpression {
52815290
#[derive(serde::Serialize)]
52825291
#[allow(clippy::derive_partial_eq_without_eq)]
52835292
#[derive(Clone, PartialEq, ::prost::Message)]
5293+
pub struct DecayParamsExpression {
5294+
/// The variable to decay
5295+
#[prost(message, optional, boxed, tag = "1")]
5296+
pub x: ::core::option::Option<::prost::alloc::boxed::Box<Expression>>,
5297+
/// The target value to start decaying from. Defaults to 0.
5298+
#[prost(message, optional, boxed, tag = "2")]
5299+
pub target: ::core::option::Option<::prost::alloc::boxed::Box<Expression>>,
5300+
/// The scale factor of the decay, in terms of `x`. Defaults to 1.0. Must be a non-zero positive number.
5301+
#[prost(float, optional, tag = "3")]
5302+
pub scale: ::core::option::Option<f32>,
5303+
/// The midpoint of the decay. Defaults to 0.5. Output will be this value when `|x - target| == scale`.
5304+
#[prost(float, optional, tag = "4")]
5305+
pub midpoint: ::core::option::Option<f32>,
5306+
}
5307+
#[derive(serde::Serialize)]
5308+
#[allow(clippy::derive_partial_eq_without_eq)]
5309+
#[derive(Clone, PartialEq, ::prost::Message)]
52845310
pub struct Query {
52855311
#[prost(oneof = "query::Variant", tags = "1, 2, 3, 4, 5, 6, 7, 8")]
52865312
pub variant: ::core::option::Option<query::Variant>,

lib/api/src/rest/schema.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,9 @@ pub enum Expression {
651651
Log10(Log10Expression),
652652
Ln(LnExpression),
653653
GeoDistance(GeoDistance),
654+
LinDecay(LinDecayExpression),
655+
ExpDecay(ExpDecayExpression),
656+
GaussDecay(GaussDecayExpression),
654657
}
655658

656659
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
@@ -716,6 +719,33 @@ pub struct LnExpression {
716719
pub ln: Box<Expression>,
717720
}
718721

722+
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
723+
pub struct LinDecayExpression {
724+
pub lin_decay: DecayParamsExpression,
725+
}
726+
727+
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
728+
pub struct ExpDecayExpression {
729+
pub exp_decay: DecayParamsExpression,
730+
}
731+
732+
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
733+
pub struct GaussDecayExpression {
734+
pub gauss_decay: DecayParamsExpression,
735+
}
736+
737+
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
738+
pub struct DecayParamsExpression {
739+
/// The variable to decay.
740+
pub x: Box<Expression>,
741+
/// The target value to start decaying from. Defaults to 0.
742+
pub target: Option<Box<Expression>>,
743+
/// The scale factor of the decay, in terms of `x`. Defaults to 1.0. Must be a non-zero positive number.
744+
pub scale: Option<f32>,
745+
/// The midpoint of the decay. Defaults to 0.5. Output will be this value when `|x - target| == scale`.
746+
pub midpoint: Option<f32>,
747+
}
748+
719749
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
720750
pub struct GeoDistance {
721751
pub geo_distance: GeoDistanceParams,

0 commit comments

Comments
 (0)