Skip to content

array_has UDF performance is slow for smaller number of needles #14533

@cetra3

Description

@cetra3

Describe the bug

When using `array_has`, performance is quite slow when there is only a single needle, or a small number of needles, to check for.

To Reproduce

Here's an example:

DataFusion CLI v44.0.0
> CREATE TABLE test AS (SELECT substr(md5(i)::text, 1, 32) as haystack FROM generate_series(1, 100000) t(i));
0 row(s) fetched. 
Elapsed 0.015 seconds.

> SELECT * FROM test limit 1;
+----------------------------------+
| haystack                         |
+----------------------------------+
| 7f4b18de3cfeb9b4ac78c381ee2ad278 |
+----------------------------------+
1 row(s) fetched. 
Elapsed 0.005 seconds.
> SELECT count(*) FROM test WHERE haystack = '7f4b18de3cfeb9b4ac78c381ee2ad278';
+----------+
| count(*) |
+----------+
| 1        |
+----------+
1 row(s) fetched. 
Elapsed 0.001 seconds.

> SELECT count(*) FROM test WHERE haystack IN ('7f4b18de3cfeb9b4ac78c381ee2ad278');
+----------+
| count(*) |
+----------+
| 1        |
+----------+
1 row(s) fetched. 
Elapsed 0.002 seconds.

> SELECT count(*) FROM test WHERE haystack = ANY(['7f4b18de3cfeb9b4ac78c381ee2ad278']);
+----------+
| count(*) |
+----------+
| 1        |
+----------+
1 row(s) fetched. 
Elapsed 0.031 seconds.

> SELECT count(*) FROM test WHERE array_has(['7f4b18de3cfeb9b4ac78c381ee2ad278'], haystack);
+----------+
| count(*) |
+----------+
| 1        |
+----------+
1 row(s) fetched. 
Elapsed 0.032 seconds.

Expected behavior

I'd expect `array_has` to be performant in this case, comparable to the equivalent `=` or `IN` filters shown above.

Additional context

I've written an optimization (below) that rewrites this call into a binary expression, a chain of `=` comparisons joined with `OR`, and applies it only when the array has at most 10 values. I'm wondering whether there is a path forward to integrate this behaviour into the UDF itself, or whether it should exist as a separate optimization.

use datafusion::prelude::Expr;

/// Maximum number of `make_array` elements for which [`array_breakout`] will
/// expand an `array_has` call into a chain of equality checks.
const ARRAY_ARG_LIMIT: usize = 10;

/// Rewrites `array_has(make_array(v1, ..., vn), column)` into
/// `v1 = column OR ... OR vn = column`.
///
/// For example, `array_has(['val'], trace_id)` becomes `'val' = trace_id`.
///
/// The rewrite only applies when all of the following hold:
/// - `expr` is a two-argument call to a scalar function named `array_has`;
/// - its first argument is a `make_array` call with between 1 and
///   [`ARRAY_ARG_LIMIT`] elements;
/// - its second argument is a plain column reference.
///
/// Returns `None` when any condition does not hold. Callers can therefore
/// tell "rewritten" apart from "left unchanged".
pub fn array_breakout(expr: &Expr) -> Option<Expr> {
    let Expr::ScalarFunction(scalar_fn) = expr else {
        return None;
    };
    if scalar_fn.name() != "array_has" || scalar_fn.args.len() != 2 {
        return None;
    }

    let array_exprs = args_from_make_array(&scalar_fn.args[0]);
    let needle = &scalar_fn.args[1];
    if !matches!(needle, Expr::Column(_)) || array_exprs.len() > ARRAY_ARG_LIMIT {
        return None;
    }

    // An empty slice means the first argument was not a `make_array` call, or
    // had no elements. There is nothing to expand, so `reduce` yields `None`.
    // The previous code returned `Some(expr.clone())` here, signalling a
    // rewrite that never happened; this version reports "no rewrite" instead.
    array_exprs
        .iter()
        .map(|val| val.clone().eq(needle.clone()))
        .reduce(Expr::or)
}

/// Returns the element expressions of a `make_array(...)` call.
///
/// Returns an empty slice when `expr` is not a scalar function named
/// `make_array`, such as a column reference or any other function. Callers
/// therefore see "not a literal array" and "empty array" the same way.
fn args_from_make_array(expr: &Expr) -> &[Expr] {
    match expr {
        // `expr` is a reference, so default binding modes already bind
        // `scalar_fn` by reference. An explicit `ref` would be redundant, and
        // the Rust 2024 edition rejects it.
        Expr::ScalarFunction(scalar_fn) if scalar_fn.name() == "make_array" => &scalar_fn.args,
        _ => &[],
    }
}

#[cfg(test)]
mod tests {
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::{common::DFSchema, prelude::SessionContext, sql::unparser::Unparser};

    use super::array_breakout;

    /// Verifies two things:
    /// - `array_has` over a small literal array is expanded into `OR`-ed
    ///   equality checks;
    /// - an array with more elements than the limit is not rewritten (`None`).
    #[test]
    fn test_array_breakout() {
        // A single nullable Utf8 column that the SQL snippets below refer to.
        let arrow_schema = Schema::new(vec![Field::new("trace_id", DataType::Utf8, true)]);
        let df_schema = DFSchema::try_from(arrow_schema).unwrap();

        let ctx = SessionContext::new();
        let unparser = Unparser::default().with_pretty(true);

        // Each case pairs input SQL with the expected rewritten SQL, or with
        // `None` when no rewrite should happen.
        let cases = [
            ("array_has(['0001'], trace_id)", Some("'0001' = trace_id")),
            (
                "array_has(['0001', '0002'], trace_id)",
                Some("'0001' = trace_id OR '0002' = trace_id"),
            ),
            (
                "array_has(['0001', '0002', '0003', '0004', '0005', '0006', '0007', '0008', '0009', '0010', '0011'], trace_id)",
                None,
            ),
        ];

        for (sql, expected) in cases {
            let parsed = ctx.parse_sql_expr(sql, &df_schema).unwrap();
            let rewritten = array_breakout(&parsed);

            match expected {
                None => assert!(rewritten.is_none()),
                Some(expected_sql) => {
                    let actual_sql = unparser
                        .expr_to_sql(&rewritten.unwrap())
                        .unwrap()
                        .to_string();
                    assert_eq!(actual_sql, expected_sql);
                }
            }
        }
    }
}

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions