Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 51 additions & 16 deletions native/core/src/execution/shuffle/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,25 +444,18 @@ pub(crate) fn append_field(
// Appending value into struct field builder of Arrow struct builder.
let field_builder = struct_builder.field_builder::<StructBuilder>(idx).unwrap();

if row.is_null_row() {
// The row is null.
let nested_row = if row.is_null_row() || row.is_null_at(idx) {
// The row is null, or the field in the row is null, i.e., a null nested row.
// Append a null value to the row builder.
field_builder.append_null();
SparkUnsafeRow::default()
} else {
let is_null = row.is_null_at(idx);
field_builder.append(true);
row.get_struct(idx, fields.len())
};

let nested_row = if is_null {
// The field in the row is null, i.e., a null nested row.
// Append a null value to the row builder.
field_builder.append_null();
SparkUnsafeRow::default()
} else {
field_builder.append(true);
row.get_struct(idx, fields.len())
};

for (field_idx, field) in fields.into_iter().enumerate() {
append_field(field.data_type(), field_builder, &nested_row, field_idx)?;
}
for (field_idx, field) in fields.into_iter().enumerate() {
append_field(field.data_type(), field_builder, &nested_row, field_idx)?;
}
}
DataType::Map(field, _) => {
Expand Down Expand Up @@ -3302,3 +3295,45 @@ fn make_batch(arrays: Vec<ArrayRef>, row_count: usize) -> Result<RecordBatch, Ar
let options = RecordBatchOptions::new().with_row_count(Option::from(row_count));
RecordBatch::try_new_with_options(schema, arrays, &options)
}

#[cfg(test)]
mod test {
use arrow::datatypes::Fields;

use super::*;

#[test]
fn test_append_null_row_to_struct_builder() {
    // Nested struct schema: two nullable boolean fields.
    let struct_type = DataType::Struct(Fields::from(vec![
        Field::new("a", DataType::Boolean, true),
        Field::new("b", DataType::Boolean, true),
    ]));
    let outer_fields = Fields::from(vec![Field::new("st", struct_type.clone(), true)]);
    let mut builder = StructBuilder::from_fields(outer_fields, 1);

    // A default SparkUnsafeRow stands in for a null row here; appending it
    // must not panic and must yield a null struct entry.
    let null_row = SparkUnsafeRow::default();
    append_field(&struct_type, &mut builder, &null_row, 0).expect("append field");
    builder.append_null();

    let built = builder.finish();
    assert_eq!(built.len(), 1);
    assert!(built.is_null(0));
}

#[test]
#[cfg_attr(miri, ignore)] // Unaligned memory access in SparkUnsafeRow, tracked in #1849
fn test_append_null_struct_field_to_struct_builder() {
    // Nested struct schema: two nullable boolean fields.
    let data_type = DataType::Struct(Fields::from(vec![
        Field::new("a", DataType::Boolean, true),
        Field::new("b", DataType::Boolean, true),
    ]));
    let fields = Fields::from(vec![Field::new("st", data_type.clone(), true)]);
    let mut struct_builder = StructBuilder::from_fields(fields, 1);

    // Back a one-field row with an 8-byte zeroed buffer. This exercises the
    // path where the row itself is non-null but its nested struct field is
    // null — NOTE(review): the exact null encoding is SparkUnsafeRow-specific;
    // confirmed only by the test name and assertions below.
    let mut row = SparkUnsafeRow::new_with_num_fields(1);
    let data = [0; 8];
    row.point_to_slice(&data);

    // Must not panic (previously appending a null nested field could leave the
    // builder in an inconsistent state); the outer entry is appended as null.
    append_field(&data_type, &mut struct_builder, &row, 0).expect("append field");
    struct_builder.append_null();

    let struct_array = struct_builder.finish();
    assert_eq!(struct_array.len(), 1);
    assert!(struct_array.is_null(0));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@

package org.apache.comet.exec

import java.nio.file.Files
import java.nio.file.Paths

import scala.reflect.runtime.universe._
import scala.util.Random

Expand Down Expand Up @@ -820,6 +823,33 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
}
}

test("columnar shuffle on null struct fields") {
  withTempDir { dir =>
    // A single empty JSON record: every schema field resolves to null.
    val jsonPath = Paths.get(dir.toString, "test.json")
    Files.write(jsonPath, "{}\n".getBytes)

    // Nested struct schema: metaData.format.provider, all levels nullable.
    val providerField = StructField("provider", StringType, nullable = true)
    val formatField =
      StructField("format", StructType(Array(providerField)), nullable = true)
    val metaDataField =
      StructField("metaData", StructType(Array(formatField)), nullable = true)
    val readSchema = StructType(Array(metaDataField))

    // Repartitioning forces a shuffle of rows that contain null struct fields.
    val df = spark.read
      .format("json")
      .schema(readSchema)
      .load(jsonPath.toString)
      .repartition(2)
    assert(df.count() == 1)
    val firstRow = df.collect()(0)
    assert(firstRow.getAs[org.apache.spark.sql.Row]("metaData") == null)
  }
}

/**
* Checks that `df` produces the same answer as Spark does, and has the `expectedNum` Comet
* exchange operators.
Expand Down
Loading