Skip to content

[Bug] Ingestion with Arrow Flight SQL panics when the input stream is empty or fallible #7329

@niebayes

Description

@niebayes

Describe the bug

Arrow Flight SQL provides an ingestion API for bulk loading. Specifically, a Flight SQL server implements `do_put_statement_ingest`, and a Flight SQL client calls `FlightSqlServiceClient::execute_ingest` to ingest a stream of record batches.

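Ingestion works for a non-empty, infallible stream. However, the server panics if the stream is fallible or empty.
The panic occurs at the following line: when the server receives no `FlightData` messages, `peek()` returns `None` and the `unwrap()` panics.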

let cmd = Pin::new(request.get_mut()).peek().await.unwrap().clone()?;

To Reproduce

I have written a test to reproduce the bug:

#[cfg(test)]
mod tests {
    //! Reproduction for Flight SQL bulk ingestion (`do_put_statement_ingest` on
    //! the server, `FlightSqlServiceClient::execute_ingest` on the client) with a
    //! valid stream, a fallible stream, and an empty stream.

    use std::net::SocketAddr;
    use std::str::FromStr;
    use std::sync::Arc;

    use arrow::array::{Int32Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use arrow_flight::decode::FlightRecordBatchStream;
    use arrow_flight::error::FlightError;
    use arrow_flight::sql::CommandStatementIngest;
    use arrow_flight::sql::SqlInfo;
    use arrow_flight::sql::client::FlightSqlServiceClient;
    use arrow_flight::sql::server::PeekableFlightDataStream;
    use arrow_flight::{flight_service_server::FlightServiceServer, sql::server::FlightSqlService};
    use futures::TryStreamExt;
    use tokio::net::TcpListener;
    use tokio::sync::oneshot;
    use tonic::transport::Endpoint;
    use tonic::{Request, Status, transport::Server};

    /// Minimal Flight SQL service that only implements bulk ingestion.
    #[derive(Clone)]
    struct DummyFlightSqlServer;

    #[tonic::async_trait]
    impl FlightSqlService for DummyFlightSqlServer {
        type FlightService = DummyFlightSqlServer;

        /// Decodes every record batch in the uploaded stream and reports the
        /// total number of rows as the affected-row count.
        async fn do_put_statement_ingest(
            &self,
            _ticket: CommandStatementIngest,
            request: Request<PeekableFlightDataStream>,
        ) -> Result<i64, Status> {
            let stream = FlightRecordBatchStream::new_from_flight_data(
                request.into_inner().map_err(FlightError::from),
            );
            let batches = stream
                .try_collect::<Vec<_>>()
                .await
                .map_err(|e| Status::unknown(e.to_string()))?;
            let affected_rows = batches.iter().map(|b| b.num_rows()).sum::<usize>();
            // Checked conversion instead of a silently-truncating `as` cast.
            i64::try_from(affected_rows).map_err(|e| Status::internal(e.to_string()))
        }

        async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {}
    }

    /// Builds the ingestion command (target table `t`) used by every request.
    fn ingest_command() -> CommandStatementIngest {
        CommandStatementIngest {
            catalog: None,
            schema: None,
            table: "t".into(),
            ..Default::default()
        }
    }

    /// Builds a three-row batch with non-nullable `id: Int32` and `name: Utf8` columns.
    fn sample_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
            ],
        )
        .unwrap()
    }

    #[tokio::test]
    async fn test_flight_sql_ingest() {
        // 1. Bind an ephemeral port before starting the server. Incoming
        //    connections wait in the listener's backlog until the server accepts
        //    them, so there is no fixed port to collide on and no startup sleep.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr: SocketAddr = listener.local_addr().unwrap();
        let incoming = futures::stream::unfold(listener, |listener| async move {
            let conn = listener.accept().await.map(|(stream, _peer)| stream);
            Some((conn, listener))
        });

        // 2. Serve the dummy Flight SQL service until the shutdown signal fires.
        let (tx, rx) = oneshot::channel::<()>();
        let server_handle = tokio::spawn(async move {
            Server::builder()
                .add_service(FlightServiceServer::new(DummyFlightSqlServer))
                .serve_with_incoming_shutdown(incoming, async {
                    rx.await.ok();
                })
                .await
                .unwrap();
        });

        // 3. Connect a Flight SQL client to the bound address.
        let channel = Endpoint::from_str(&format!("http://{addr}"))
            .unwrap()
            .connect()
            .await
            .unwrap();
        let mut client = FlightSqlServiceClient::new(channel);

        // 4. A valid, non-empty stream is ingested and all of its rows are counted.
        let affected_rows = client
            .execute_ingest(
                ingest_command(),
                futures::stream::iter(vec![Ok(sample_batch())]),
            )
            .await
            .unwrap();
        assert_eq!(affected_rows, 3);

        // 5. A stream whose only item is an error must surface as an error on the
        //    client, not as a reported success.
        let failing = futures::stream::iter(vec![Err(FlightError::ProtocolError("error".into()))]);
        let err = client
            .execute_ingest(ingest_command(), failing)
            .await
            .expect_err("ingesting a fallible stream must fail");
        println!("ingest error (fallible stream): {err}");

        // 6. An empty stream sends no messages, so there is no command for the
        //    server to dispatch; the client must receive an error.
        let err = client
            .execute_ingest(ingest_command(), futures::stream::empty())
            .await
            .expect_err("ingesting an empty stream must fail");
        println!("ingest error (empty stream): {err}");

        // 7. After rejecting bad input, the server must still accept valid ingestions.
        let affected_rows = client
            .execute_ingest(
                ingest_command(),
                futures::stream::iter(vec![Ok(sample_batch())]),
            )
            .await
            .unwrap();
        assert_eq!(affected_rows, 3);

        // 8. Close the client connection first so graceful shutdown is not left
        //    waiting on it, then stop the server and check it exited cleanly.
        drop(client);
        tx.send(()).unwrap();
        server_handle.await.unwrap();
    }
}

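Expected behavior

When the ingestion stream is empty or fails, the server should return an error `Status` to the client instead of panicking.

Additional context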

Metadata

Labels

arrow (Changes to the arrow crate), arrow-flight (Changes to the arrow-flight crate), bug

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions