Skip to content

Commit eb467da

Browse files
committed
sqrt, pow, exp, log, ln, abs
1 parent 8ec07b6 commit eb467da

7 files changed

Lines changed: 190 additions & 1 deletion

File tree

lib/api/src/grpc/proto/points.proto

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,12 @@ message Expression {
587587
SumExpression sum = 6; // Sum
588588
DivExpression div = 7; // Divide
589589
Expression neg = 8; // Negate
590+
Expression abs = 9; // Absolute value
591+
Expression sqrt = 10; // Square root
592+
PowExpression pow = 11; // Power
593+
Expression exp = 12; // Exponential
594+
Expression log10 = 13; // Logarithm
595+
Expression ln = 14; // Natural logarithm
590596
}
591597
}
592598

@@ -609,6 +615,11 @@ message DivExpression {
609615
optional float by_zero_default = 3;
610616
}
611617

618+
message PowExpression {
619+
Expression base = 1;
620+
Expression exponent = 2;
621+
}
622+
612623
message Query {
613624
oneof variant {
614625
VectorInput nearest = 1; // Find the nearest neighbors to this vector.

lib/api/src/grpc/qdrant.rs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5185,7 +5185,10 @@ pub struct Formula {
51855185
#[allow(clippy::derive_partial_eq_without_eq)]
51865186
#[derive(Clone, PartialEq, ::prost::Message)]
51875187
pub struct Expression {
5188-
#[prost(oneof = "expression::Variant", tags = "1, 2, 3, 4, 5, 6, 7, 8")]
5188+
#[prost(
5189+
oneof = "expression::Variant",
5190+
tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14"
5191+
)]
51895192
pub variant: ::core::option::Option<expression::Variant>,
51905193
}
51915194
/// Nested message and enum types in `Expression`.
@@ -5217,6 +5220,24 @@ pub mod expression {
52175220
/// Negate
52185221
#[prost(message, tag = "8")]
52195222
Neg(::prost::alloc::boxed::Box<super::Expression>),
5223+
/// Absolute value
5224+
#[prost(message, tag = "9")]
5225+
Abs(::prost::alloc::boxed::Box<super::Expression>),
5226+
/// Square root
5227+
#[prost(message, tag = "10")]
5228+
Sqrt(::prost::alloc::boxed::Box<super::Expression>),
5229+
/// Power
5230+
#[prost(message, tag = "11")]
5231+
Pow(::prost::alloc::boxed::Box<super::PowExpression>),
5232+
/// Exponential
5233+
#[prost(message, tag = "12")]
5234+
Exp(::prost::alloc::boxed::Box<super::Expression>),
5235+
/// Logarithm
5236+
#[prost(message, tag = "13")]
5237+
Log10(::prost::alloc::boxed::Box<super::Expression>),
5238+
/// Natural logarithm
5239+
#[prost(message, tag = "14")]
5240+
Ln(::prost::alloc::boxed::Box<super::Expression>),
52205241
}
52215242
}
52225243
#[derive(serde::Serialize)]
@@ -5256,6 +5277,15 @@ pub struct DivExpression {
52565277
#[derive(serde::Serialize)]
52575278
#[allow(clippy::derive_partial_eq_without_eq)]
52585279
#[derive(Clone, PartialEq, ::prost::Message)]
5280+
pub struct PowExpression {
5281+
#[prost(message, optional, boxed, tag = "1")]
5282+
pub base: ::core::option::Option<::prost::alloc::boxed::Box<Expression>>,
5283+
#[prost(message, optional, boxed, tag = "2")]
5284+
pub exponent: ::core::option::Option<::prost::alloc::boxed::Box<Expression>>,
5285+
}
5286+
#[derive(serde::Serialize)]
5287+
#[allow(clippy::derive_partial_eq_without_eq)]
5288+
#[derive(Clone, PartialEq, ::prost::Message)]
52595289
pub struct Query {
52605290
#[prost(oneof = "query::Variant", tags = "1, 2, 3, 4, 5, 6, 7, 8")]
52615291
pub variant: ::core::option::Option<query::Variant>,

lib/api/src/rest/schema.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,13 @@ pub enum Expression {
641641
Mult(MultExpression),
642642
Sum(SumExpression),
643643
Neg(NegExpression),
644+
Abs(AbsExpression),
644645
Div(DivExpression),
646+
Sqrt(SqrtExpression),
647+
Pow(PowExpression),
648+
Exp(ExpExpression),
649+
Log10(Log10Expression),
650+
Ln(LnExpression),
645651
GeoDistance(GeoDistance),
646652
}
647653

@@ -660,6 +666,11 @@ pub struct NegExpression {
660666
pub neg: Box<Expression>,
661667
}
662668

669+
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
670+
pub struct AbsExpression {
671+
pub abs: Box<Expression>,
672+
}
673+
663674
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
664675
pub struct DivExpression {
665676
pub div: DivParams,
@@ -672,6 +683,37 @@ pub struct DivParams {
672683
pub by_zero_default: Option<ScoreType>,
673684
}
674685

686+
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
687+
pub struct SqrtExpression {
688+
pub sqrt: Box<Expression>,
689+
}
690+
691+
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
692+
pub struct PowExpression {
693+
pub pow: PowParams,
694+
}
695+
696+
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
697+
pub struct PowParams {
698+
pub base: Box<Expression>,
699+
pub exponent: Box<Expression>,
700+
}
701+
702+
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
703+
pub struct ExpExpression {
704+
pub exp: Box<Expression>,
705+
}
706+
707+
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
708+
pub struct Log10Expression {
709+
pub log10: Box<Expression>,
710+
}
711+
712+
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
713+
pub struct LnExpression {
714+
pub ln: Box<Expression>,
715+
}
716+
675717
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
676718
pub struct GeoDistance {
677719
pub geo_distance: GeoDistanceParams,

lib/collection/src/operations/universal_query/formula.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@ pub enum ExpressionInternal {
2626
right: Box<ExpressionInternal>,
2727
by_zero_default: Option<ScoreType>,
2828
},
29+
Sqrt(Box<ExpressionInternal>),
30+
Pow {
31+
base: Box<ExpressionInternal>,
32+
exponent: Box<ExpressionInternal>,
33+
},
34+
Exp(Box<ExpressionInternal>),
35+
Log10(Box<ExpressionInternal>),
36+
Ln(Box<ExpressionInternal>),
37+
Abs(Box<ExpressionInternal>),
2938
GeoDistance {
3039
origin: GeoPoint,
3140
to: JsonPath,
@@ -74,6 +83,25 @@ impl From<rest::Expression> for ExpressionInternal {
7483
by_zero_default,
7584
}
7685
}
86+
rest::Expression::Sqrt(sqrt_expression) => {
87+
ExpressionInternal::Sqrt(Box::new(ExpressionInternal::from(*sqrt_expression.sqrt)))
88+
}
89+
rest::Expression::Pow(rest::PowExpression { pow }) => ExpressionInternal::Pow {
90+
base: Box::new(ExpressionInternal::from(*pow.base)),
91+
exponent: Box::new(ExpressionInternal::from(*pow.exponent)),
92+
},
93+
rest::Expression::Exp(rest::ExpExpression { exp: expr }) => {
94+
ExpressionInternal::Exp(Box::new(ExpressionInternal::from(*expr)))
95+
}
96+
rest::Expression::Log10(rest::Log10Expression { log10: expr }) => {
97+
ExpressionInternal::Log10(Box::new(ExpressionInternal::from(*expr)))
98+
}
99+
rest::Expression::Ln(rest::LnExpression { ln: expr }) => {
100+
ExpressionInternal::Ln(Box::new(ExpressionInternal::from(*expr)))
101+
}
102+
rest::Expression::Abs(rest::AbsExpression { abs: expr }) => {
103+
ExpressionInternal::Abs(Box::new(ExpressionInternal::from(*expr)))
104+
}
77105
rest::Expression::GeoDistance(GeoDistance {
78106
geo_distance: rest::GeoDistanceParams { origin, to },
79107
}) => ExpressionInternal::GeoDistance { origin, to },

lib/collection/src/operations/universal_query/shard_query.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,25 @@ impl ExpressionInternal {
246246
ExpressionInternal::GeoDistance { origin, to } => {
247247
ParsedExpression::new_geo_distance(origin, to)
248248
}
249+
ExpressionInternal::Sqrt(expression_internal) => ParsedExpression::Sqrt(Box::new(
250+
expression_internal.parse_and_convert(payload_vars, conditions)?,
251+
)),
252+
ExpressionInternal::Pow { base, exponent } => ParsedExpression::Pow {
253+
base: Box::new(base.parse_and_convert(payload_vars, conditions)?),
254+
exponent: Box::new(exponent.parse_and_convert(payload_vars, conditions)?),
255+
},
256+
ExpressionInternal::Exp(expression_internal) => ParsedExpression::Exp(Box::new(
257+
expression_internal.parse_and_convert(payload_vars, conditions)?,
258+
)),
259+
ExpressionInternal::Log10(expression_internal) => ParsedExpression::Log10(Box::new(
260+
expression_internal.parse_and_convert(payload_vars, conditions)?,
261+
)),
262+
ExpressionInternal::Ln(expression_internal) => ParsedExpression::Ln(Box::new(
263+
expression_internal.parse_and_convert(payload_vars, conditions)?,
264+
)),
265+
ExpressionInternal::Abs(expression_internal) => ParsedExpression::Abs(Box::new(
266+
expression_internal.parse_and_convert(payload_vars, conditions)?,
267+
)),
249268
};
250269

251270
Ok(expr)
@@ -516,6 +535,31 @@ impl TryFrom<grpc::Expression> for ExpressionInternal {
516535
Variant::Neg(expression) => {
517536
ExpressionInternal::Neg(Box::new((*expression).try_into()?))
518537
}
538+
Variant::Abs(expression) => {
539+
ExpressionInternal::Abs(Box::new((*expression).try_into()?))
540+
}
541+
Variant::Sqrt(expression) => {
542+
ExpressionInternal::Sqrt(Box::new((*expression).try_into()?))
543+
}
544+
Variant::Pow(pow_expression) => {
545+
let grpc::PowExpression { base, exponent } = *pow_expression;
546+
let raw_base =
547+
*base.ok_or_else(|| tonic::Status::invalid_argument("missing field: base"))?;
548+
let raw_exponent = *exponent
549+
.ok_or_else(|| tonic::Status::invalid_argument("missing field: exponent"))?;
550+
551+
ExpressionInternal::Pow {
552+
base: Box::new(raw_base.try_into()?),
553+
exponent: Box::new(raw_exponent.try_into()?),
554+
}
555+
}
556+
Variant::Exp(expression) => {
557+
ExpressionInternal::Exp(Box::new((*expression).try_into()?))
558+
}
559+
Variant::Log10(expression) => {
560+
ExpressionInternal::Log10(Box::new((*expression).try_into()?))
561+
}
562+
Variant::Ln(expression) => ExpressionInternal::Ln(Box::new((*expression).try_into()?)),
519563
};
520564

521565
Ok(expression)

lib/segment/src/index/query_optimization/rescore_formula/formula_scorer.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,31 @@ impl FormulaScorer<'_> {
153153
let value = self.eval_expression(expr, point_id)?;
154154
Ok(value.neg())
155155
}
156+
ParsedExpression::Sqrt(expr) => {
157+
let value = self.eval_expression(expr, point_id)?;
158+
Ok(value.sqrt())
159+
}
160+
ParsedExpression::Pow { base, exponent } => {
161+
let base_value = self.eval_expression(base, point_id)?;
162+
let exponent_value = self.eval_expression(exponent, point_id)?;
163+
Ok(base_value.powf(exponent_value))
164+
}
165+
ParsedExpression::Exp(parsed_expression) => {
166+
let value = self.eval_expression(parsed_expression, point_id)?;
167+
Ok(value.exp())
168+
}
169+
ParsedExpression::Log10(expr) => {
170+
let value = self.eval_expression(expr, point_id)?;
171+
Ok(value.log10())
172+
}
173+
ParsedExpression::Ln(expr) => {
174+
let value = self.eval_expression(expr, point_id)?;
175+
Ok(value.ln())
176+
}
177+
ParsedExpression::Abs(expr) => {
178+
let value = self.eval_expression(expr, point_id)?;
179+
Ok(value.abs())
180+
}
156181
ParsedExpression::GeoDistance { origin, key } => {
157182
let value: GeoPoint = self
158183
.payload_retrievers

lib/segment/src/index/query_optimization/rescore_formula/parsed_formula.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,15 @@ pub enum ParsedExpression {
4141
by_zero_default: ScoreType,
4242
},
4343
Neg(Box<ParsedExpression>),
44+
Sqrt(Box<ParsedExpression>),
45+
Pow {
46+
base: Box<ParsedExpression>,
47+
exponent: Box<ParsedExpression>,
48+
},
49+
Exp(Box<ParsedExpression>),
50+
Log10(Box<ParsedExpression>),
51+
Ln(Box<ParsedExpression>),
52+
Abs(Box<ParsedExpression>),
4453
GeoDistance {
4554
origin: GeoPoint,
4655
key: JsonPath,

0 commit comments

Comments
 (0)