Skip to content

Commit f36467b

Browse files
fix: improve REST datetime_range RFC3339 error message
1 parent e2cc142 commit f36467b

1 file changed

Lines changed: 162 additions & 8 deletions

File tree

lib/segment/src/types.rs

Lines changed: 162 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,13 @@ impl<'de> Deserialize<'de> for DateTimePayloadType {
8686
where
8787
D: Deserializer<'de>,
8888
{
89-
let str_datetime = <&str>::deserialize(deserializer)?;
90-
let parse_result = DateTimePayloadType::from_str(str_datetime).ok();
91-
match parse_result {
92-
Some(datetime) => Ok(datetime),
93-
None => Err(serde::de::Error::custom(format!(
94-
"'{str_datetime}' is not in a supported date/time format, please use RFC 3339"
89+
let str_datetime: Cow<'de, str> = Cow::deserialize(deserializer)?;
90+
91+
match DateTimePayloadType::from_str(str_datetime.as_ref()) {
92+
Ok(datetime) => Ok(datetime),
93+
Err(_) => Err(serde::de::Error::custom(format!(
94+
"'{}' does not match accepted datetime format (RFC3339). Example: 2014-01-01T00:00:00Z",
95+
str_datetime
9596
))),
9697
}
9798
}
@@ -2592,7 +2593,7 @@ impl From<Vec<IntPayloadType>> for MatchExcept {
25922593
}
25932594
}
25942595

2595-
#[derive(Copy, Clone, Debug, Eq, PartialEq, Deserialize, Serialize, JsonSchema)]
2596+
#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, JsonSchema)]
25962597
#[serde(untagged)]
25972598
pub enum RangeInterface {
25982599
Float(Range<OrderedFloat<FloatPayloadType>>),
@@ -2620,6 +2621,45 @@ impl Hash for RangeInterface {
26202621
}
26212622
}
26222623

2624+
#[derive(serde::Deserialize)]
2625+
#[serde(untagged)]
2626+
enum RangeInterfaceUntagged {
2627+
Float(Range<OrderedFloatPayloadType>),
2628+
DateTime(Range<DateTimePayloadType>),
2629+
}
2630+
2631+
impl<'de> serde::Deserialize<'de> for RangeInterface {
2632+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
2633+
where
2634+
D: serde::Deserializer<'de>,
2635+
{
2636+
let value = serde_json::Value::deserialize(deserializer)?;
2637+
2638+
// If any range bound is a string -> treat as datetime range
2639+
if let Some(obj) = value.as_object() {
2640+
let keys = ["lt", "gt", "lte", "gte"];
2641+
let has_string_bound = keys
2642+
.iter()
2643+
.any(|k| obj.get(*k).map(|v| v.is_string()).unwrap_or(false));
2644+
2645+
if has_string_bound {
2646+
return serde_json::from_value::<Range<DateTimePayloadType>>(value)
2647+
.map(RangeInterface::DateTime)
2648+
.map_err(serde::de::Error::custom);
2649+
}
2650+
}
2651+
2652+
// Fallback to existing untagged behavior
2653+
let parsed = serde_json::from_value::<RangeInterfaceUntagged>(value)
2654+
.map_err(serde::de::Error::custom)?;
2655+
2656+
Ok(match parsed {
2657+
RangeInterfaceUntagged::Float(r) => RangeInterface::Float(r),
2658+
RangeInterfaceUntagged::DateTime(r) => RangeInterface::DateTime(r),
2659+
})
2660+
}
2661+
}
2662+
26232663
type OrderedFloatPayloadType = OrderedFloat<FloatPayloadType>;
26242664

26252665
/// Range filter request
@@ -3208,7 +3248,7 @@ impl NestedCondition {
32083248
}
32093249
}
32103250

3211-
#[derive(Clone, Debug, Deserialize, Serialize, JsonSchema, PartialEq, Eq, Hash)]
3251+
#[derive(Clone, Debug, Serialize, JsonSchema, PartialEq, Eq, Hash)]
32123252
#[serde(untagged)]
32133253
#[serde(
32143254
expecting = "Expected some form of condition, which can be a field condition (like {\"key\": ..., \"match\": ... }), or some other mentioned in the documentation: https://qdrant.tech/documentation/concepts/filtering/#filtering-conditions"
@@ -3234,6 +3274,102 @@ pub enum Condition {
32343274
CustomIdChecker(CustomIdChecker),
32353275
}
32363276

3277+
#[derive(Deserialize)]
3278+
#[serde(untagged)]
3279+
#[serde(
3280+
expecting = "Expected some form of condition, which can be a field condition (like {\"key\": ..., \"match\": ... }), or some other mentioned in the documentation: https://qdrant.tech/documentation/concepts/filtering/#filtering-conditions"
3281+
)]
3282+
#[allow(clippy::large_enum_variant)]
3283+
#[allow(dead_code)]
3284+
enum ConditionUntagged {
3285+
Field(FieldCondition),
3286+
IsEmpty(IsEmptyCondition),
3287+
IsNull(IsNullCondition),
3288+
HasId(HasIdCondition),
3289+
HasVector(HasVectorCondition),
3290+
Nested(NestedCondition),
3291+
Filter(Filter),
3292+
3293+
#[serde(skip)]
3294+
CustomIdChecker(CustomIdChecker),
3295+
}
3296+
3297+
impl From<ConditionUntagged> for Condition {
3298+
fn from(condition: ConditionUntagged) -> Self {
3299+
match condition {
3300+
ConditionUntagged::Field(condition) => Condition::Field(condition),
3301+
ConditionUntagged::IsEmpty(condition) => Condition::IsEmpty(condition),
3302+
ConditionUntagged::IsNull(condition) => Condition::IsNull(condition),
3303+
ConditionUntagged::HasId(condition) => Condition::HasId(condition),
3304+
ConditionUntagged::HasVector(condition) => Condition::HasVector(condition),
3305+
ConditionUntagged::Nested(condition) => Condition::Nested(condition),
3306+
ConditionUntagged::Filter(condition) => Condition::Filter(condition),
3307+
ConditionUntagged::CustomIdChecker(condition) => Condition::CustomIdChecker(condition),
3308+
}
3309+
}
3310+
}
3311+
3312+
impl<'de> serde::Deserialize<'de> for Condition {
3313+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
3314+
where
3315+
D: serde::Deserializer<'de>,
3316+
{
3317+
let value = serde_json::Value::deserialize(deserializer)?;
3318+
let Some(obj) = value.as_object() else {
3319+
return serde_json::from_value::<ConditionUntagged>(value)
3320+
.map(Condition::from)
3321+
.map_err(serde::de::Error::custom);
3322+
};
3323+
3324+
// IMPORTANT: avoid untagged swallowing datetime parse error
3325+
if obj.contains_key("key") {
3326+
return serde_json::from_value::<FieldCondition>(value)
3327+
.map(Condition::Field)
3328+
.map_err(serde::de::Error::custom);
3329+
}
3330+
3331+
if obj.contains_key("is_empty") {
3332+
return serde_json::from_value::<IsEmptyCondition>(value)
3333+
.map(Condition::IsEmpty)
3334+
.map_err(serde::de::Error::custom);
3335+
}
3336+
3337+
if obj.contains_key("is_null") {
3338+
return serde_json::from_value::<IsNullCondition>(value)
3339+
.map(Condition::IsNull)
3340+
.map_err(serde::de::Error::custom);
3341+
}
3342+
3343+
if obj.contains_key("has_id") {
3344+
return serde_json::from_value::<HasIdCondition>(value)
3345+
.map(Condition::HasId)
3346+
.map_err(serde::de::Error::custom);
3347+
}
3348+
3349+
if obj.contains_key("has_vector") {
3350+
return serde_json::from_value::<HasVectorCondition>(value)
3351+
.map(Condition::HasVector)
3352+
.map_err(serde::de::Error::custom);
3353+
}
3354+
3355+
if obj.contains_key("nested") {
3356+
return serde_json::from_value::<NestedCondition>(value)
3357+
.map(Condition::Nested)
3358+
.map_err(serde::de::Error::custom);
3359+
}
3360+
3361+
if obj.contains_key("filter") {
3362+
return serde_json::from_value::<Filter>(value)
3363+
.map(Condition::Filter)
3364+
.map_err(serde::de::Error::custom);
3365+
}
3366+
3367+
serde_json::from_value::<ConditionUntagged>(value)
3368+
.map(Condition::from)
3369+
.map_err(serde::de::Error::custom)
3370+
}
3371+
}
3372+
32373373
impl Condition {
32383374
pub fn new_custom(checker: Arc<dyn CustomIdCheckerCondition + Send + Sync + 'static>) -> Self {
32393375
Condition::CustomIdChecker(CustomIdChecker(checker))
@@ -3929,6 +4065,24 @@ mod tests {
39294065
assert_eq!(datetime.timestamp(), datetime_no_z.timestamp());
39304066
}
39314067

4068+
/// A field condition whose range bound is a malformed datetime string must
/// fail with an error that names RFC3339, echoes the offending input, and
/// includes an example.
#[test]
fn test_invalid_datetime_range_returns_clear_rfc3339_error() {
    let payload = r#"{
"key": "created_at",
"range": {
"gte": "2014-01-01T00:00:00BAD"
}
}"#;

    let parsed = serde_json::from_str::<Condition>(payload);
    let message = parsed.unwrap_err().to_string();

    // Checked in the same order as the individual assertions they replace.
    for expected in ["RFC3339", "2014-01-01T00:00:00BAD", "Example"] {
        assert!(message.contains(expected), "err was: {}", message);
    }
}
4085+
39324086
#[test]
39334087
fn test_datetime_wrapper_transcoding() {
39344088
let expected = DateTimeWrapper(chrono::Utc::now());

0 commit comments

Comments
 (0)