Skip to content

Commit f919124

Browse files
uros-dbcloud-fan
authored and committed
[SPARK-54232][GEO][CONNECT] Enable Arrow serialization for Geography and Geometry types
### What changes were proposed in this pull request? Introduce Arrow serialization/deserialization for `Geography` and `Geometry`. ### Why are the changes needed? Enable geospatial result set serialization in Arrow format for Spark Connect and Thrift Server. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added unit tests: - `GeographyConnectDataFrameSuite` - `GeometryConnectDataFrameSuite` - `ArrowEncoderSuite` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #52930 from uros-db/geo-arrow. Authored-by: Uros Bojanic <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent bab98af commit f919124

File tree

8 files changed

+487
-3
lines changed

8 files changed

+487
-3
lines changed
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql
19+
20+
import scala.collection.immutable.Seq
21+
22+
import org.apache.spark.{SparkIllegalArgumentException, SparkRuntimeException}
23+
import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession}
24+
import org.apache.spark.sql.types._
25+
26+
/**
 * Spark Connect end-to-end tests for `Geography` values flowing through the
 * DataFrame creation APIs and Arrow-based result collection.
 */
class GeographyConnectDataFrameSuite extends QueryTest with RemoteSparkSession {

  /** Decodes a hex string (two characters per byte) into raw WKB bytes. */
  private def hexToWkb(hex: String): Array[Byte] =
    hex.grouped(2).map(Integer.parseInt(_, 16).toByte).toArray

  // WKB for POINT (17 7).
  private val point1: Array[Byte] = hexToWkb("010100000000000000000031400000000000001C40")
  // Second point payload, used alongside point1 in multi-row tests.
  private val point2: Array[Byte] = hexToWkb("010100000000000000000035400000000000001E40")

  test("decode geography value: SRID schema does not match input SRID data schema") {
    // A value carrying SRID 0 encoded against the default GEOGRAPHY(4326) type must fail.
    val geography = Geography.fromWKB(point1, 0)
    val expectedParams = Map("type" -> "GEOGRAPHY", "valueSrid" -> "0", "typeSrid" -> "4326")

    val fromCreateDataFrame = intercept[SparkRuntimeException] {
      spark.createDataFrame(Seq((geography, 1))).collect()
    }
    checkError(
      exception = fromCreateDataFrame,
      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
      parameters = expectedParams)

    import testImplicits._
    val fromToDF = intercept[SparkRuntimeException] {
      Seq(geography).toDF().collect()
    }
    checkError(
      exception = fromToDF,
      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
      parameters = expectedParams)
  }

  test("decode geography value: mixed SRID schema is provided") {
    // GEOGRAPHY("ANY") accepts values with any supported SRID.
    val schema = StructType(Seq(StructField("col1", GeographyType("ANY"), nullable = false)))
    val expectedResult =
      Seq(Row(Geography.fromWKB(point1, 4326)), Row(Geography.fromWKB(point2, 4326)))

    val rows = java.util.Arrays.asList(
      Row(Geography.fromWKB(point1, 4326)),
      Row(Geography.fromWKB(point2, 4326)))
    checkAnswer(spark.createDataFrame(rows, schema), expectedResult)

    // Test that unsupported SRID with mixed schema will throw an error.
    val invalidRows = java.util.Arrays.asList(
      Row(Geography.fromWKB(point1, 1)),
      Row(Geography.fromWKB(point2, 4326)))
    checkError(
      exception = intercept[SparkIllegalArgumentException] {
        spark.createDataFrame(invalidRows, schema).collect()
      },
      condition = "ST_INVALID_SRID_VALUE",
      parameters = Map("srid" -> "1"))
  }

  test("createDataFrame APIs with Geography.fromWKB") {
    val geography1 = Geography.fromWKB(point1, 4326)
    val geography2 = Geography.fromWKB(point2)

    // Seq-of-tuples createDataFrame API, including a null geography.
    val dfFromSeq = spark.createDataFrame(Seq((geography1, 1), (geography2, 2), (null, 3)))
    checkAnswer(dfFromSeq, Seq(Row(geography1, 1), Row(geography2, 2), Row(null, 3)))

    // Java-list-plus-schema createDataFrame API.
    val schema = StructType(Seq(StructField("geography", GeographyType(4326), nullable = true)))
    val javaRows = java.util.Arrays.asList(Row(geography1), Row(geography2), Row(null))
    checkAnswer(
      spark.createDataFrame(javaRows, schema),
      Seq(Row(geography1), Row(geography2), Row(null)))

    // Implicit Seq.toDF() conversion.
    import testImplicits._
    checkAnswer(
      Seq(geography1, geography2, null).toDF(),
      Seq(Row(geography1), Row(geography2), Row(null)))
  }

  test("encode geography type") {
    // POINT (17 7)
    val wkb = "010100000000000000000031400000000000001C40"
    val df = spark.sql(s"SELECT ST_GeogFromWKB(X'$wkb')")
    val expectedGeog = Geography.fromWKB(hexToWkb(wkb), 4326)
    checkAnswer(df, Seq(Row(expectedGeog)))
  }
}
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql
19+
20+
import scala.collection.immutable.Seq
21+
22+
import org.apache.spark.{SparkIllegalArgumentException, SparkRuntimeException}
23+
import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession}
24+
import org.apache.spark.sql.types._
25+
26+
/**
 * Spark Connect end-to-end tests for `Geometry` values flowing through the
 * DataFrame creation APIs and Arrow-based result collection.
 */
class GeometryConnectDataFrameSuite extends QueryTest with RemoteSparkSession {

  /** Decodes a hex string (two characters per byte) into raw WKB bytes. */
  private def hexToWkb(hex: String): Array[Byte] =
    hex.grouped(2).map(Integer.parseInt(_, 16).toByte).toArray

  // WKB for POINT (17 7).
  private val point1: Array[Byte] = hexToWkb("010100000000000000000031400000000000001C40")
  // Second point payload, used alongside point1 in multi-row tests.
  private val point2: Array[Byte] = hexToWkb("010100000000000000000035400000000000001E40")

  test("decode geometry value: SRID schema does not match input SRID data schema") {
    // A value carrying SRID 4326 encoded against the default GEOMETRY(0) type must fail.
    val geometry = Geometry.fromWKB(point1, 4326)
    val expectedParams = Map("type" -> "GEOMETRY", "valueSrid" -> "4326", "typeSrid" -> "0")

    val fromCreateDataFrame = intercept[SparkRuntimeException] {
      spark.createDataFrame(Seq((geometry, 1))).collect()
    }
    checkError(
      exception = fromCreateDataFrame,
      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
      parameters = expectedParams)

    import testImplicits._
    val fromToDF = intercept[SparkRuntimeException] {
      Seq(geometry).toDF().collect()
    }
    checkError(
      exception = fromToDF,
      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
      parameters = expectedParams)
  }

  test("decode geometry value: mixed SRID schema is provided") {
    // GEOMETRY("ANY") accepts values with any supported SRID.
    val schema = StructType(Seq(StructField("col1", GeometryType("ANY"), nullable = false)))
    val expectedResult =
      Seq(Row(Geometry.fromWKB(point1, 0)), Row(Geometry.fromWKB(point2, 4326)))

    val rows = java.util.Arrays.asList(
      Row(Geometry.fromWKB(point1, 0)),
      Row(Geometry.fromWKB(point2, 4326)))
    checkAnswer(spark.createDataFrame(rows, schema), expectedResult)

    // Test that unsupported SRID with mixed schema will throw an error.
    val invalidRows = java.util.Arrays.asList(
      Row(Geometry.fromWKB(point1, 1)),
      Row(Geometry.fromWKB(point2, 4326)))
    checkError(
      exception = intercept[SparkIllegalArgumentException] {
        spark.createDataFrame(invalidRows, schema).collect()
      },
      condition = "ST_INVALID_SRID_VALUE",
      parameters = Map("srid" -> "1"))
  }

  test("createDataFrame APIs with Geometry.fromWKB") {
    val geometry1 = Geometry.fromWKB(point1, 0)
    val geometry2 = Geometry.fromWKB(point2, 0)

    // Seq-of-tuples createDataFrame API, including a null geometry.
    val dfFromSeq = spark.createDataFrame(Seq((geometry1, 1), (geometry2, 2), (null, 3)))
    checkAnswer(dfFromSeq, Seq(Row(geometry1, 1), Row(geometry2, 2), Row(null, 3)))

    // Java-list-plus-schema createDataFrame API, with SRID 4326 values to match
    // the declared GEOMETRY(4326) column type.
    val geometry3 = Geometry.fromWKB(point1, 4326)
    val geometry4 = Geometry.fromWKB(point2, 4326)
    val schema = StructType(Seq(StructField("geometry", GeometryType(4326), nullable = true)))
    val javaRows = java.util.Arrays.asList(Row(geometry3), Row(geometry4), Row(null))
    checkAnswer(
      spark.createDataFrame(javaRows, schema),
      Seq(Row(geometry3), Row(geometry4), Row(null)))

    // Implicit Seq.toDF() conversion.
    import testImplicits._
    checkAnswer(
      Seq(geometry1, geometry2, null).toDF(),
      Seq(Row(geometry1), Row(geometry2), Row(null)))
  }

  test("encode geometry type") {
    // POINT (17 7)
    val wkb = "010100000000000000000031400000000000001C40"
    val df = spark.sql(s"SELECT ST_GeomFromWKB(X'$wkb')")
    val expectedGeom = Geometry.fromWKB(hexToWkb(wkb), 0)
    checkAnswer(df, Seq(Row(expectedGeom)))
  }
}

sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.util.SparkIntervalUtils._
4444
import org.apache.spark.sql.connect.client.CloseableIterator
4545
import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum
4646
import org.apache.spark.sql.connect.test.ConnectFunSuite
47-
import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StringType, StructType, UserDefinedType, YearMonthIntervalType}
47+
import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, Geography, Geometry, IntegerType, Metadata, SQLUserDefinedType, StringType, StructType, UserDefinedType, YearMonthIntervalType}
4848
import org.apache.spark.unsafe.types.VariantVal
4949
import org.apache.spark.util.{MaybeNull, SparkStringUtils}
5050

@@ -263,6 +263,102 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll {
263263
assert(inspector.numBatches == 1)
264264
}
265265

266+
test("geography round trip") {
  // Decodes a hex-encoded WKB string into raw bytes.
  def hexToWkb(hex: String): Array[Byte] =
    hex.grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
  val wkbPoint1 = hexToWkb("010100000000000000000031400000000000001C40")
  val wkbPoint2 = hexToWkb("010100000000000000000035400000000000001E40")

  // Flat schema: a single GEOGRAPHY(4326) column, with periodic nulls.
  val geographyEncoder = toRowEncoder(new StructType().add("g", "geography(4326)"))
  roundTripAndCheckIdentical(geographyEncoder) { () =>
    val nullEvery7 = MaybeNull(7)
    Iterator.tabulate(101)(_ => Row(nullEvery7(Geography.fromWKB(wkbPoint1, 4326))))
  }

  // Nested schema: geography values inside a struct, an array, and a map.
  val nestedGeographyEncoder = toRowEncoder(
    new StructType()
      .add(
        "s",
        new StructType()
          .add("i1", "int")
          .add("g0", "geography(4326)")
          .add("i2", "int")
          .add("g4326", "geography(4326)"))
      .add("a", "array<geography(4326)>")
      .add("m", "map<string, geography(ANY)>"))

  roundTripAndCheckIdentical(nestedGeographyEncoder) { () =>
    val nullEvery5 = MaybeNull(5)
    val nullEvery7 = MaybeNull(7)
    val nullEvery11 = MaybeNull(11)
    val nullEvery13 = MaybeNull(13)
    val nullEvery17 = MaybeNull(17)
    Iterator.tabulate(100) { i =>
      // The struct, array, and map parts are built in the same order as the
      // original Row arguments, since MaybeNull instances may be stateful.
      val structPart = nullEvery5(
        Row(
          i,
          nullEvery7(Geography.fromWKB(wkbPoint1)),
          i + 1,
          nullEvery11(Geography.fromWKB(wkbPoint2, 4326))))
      val arrayPart = nullEvery7((0 until 10).map(_ => Geography.fromWKB(wkbPoint2, 0)))
      val mapPart =
        nullEvery13(Map(i.toString -> nullEvery17(Geography.fromWKB(wkbPoint1, 4326))))
      Row(structPart, arrayPart, mapPart)
    }
  }
}
313+
314+
test("geometry round trip") {
  // Decodes a hex-encoded WKB string into raw bytes.
  def hexToWkb(hex: String): Array[Byte] =
    hex.grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
  val wkbPoint1 = hexToWkb("010100000000000000000031400000000000001C40")
  val wkbPoint2 = hexToWkb("010100000000000000000035400000000000001E40")

  // Flat schema: a single GEOMETRY(0) column, with periodic nulls.
  val geometryEncoder = toRowEncoder(new StructType().add("g", "geometry(0)"))
  roundTripAndCheckIdentical(geometryEncoder) { () =>
    val nullEvery7 = MaybeNull(7)
    Iterator.tabulate(101)(_ => Row(nullEvery7(Geometry.fromWKB(wkbPoint1, 0))))
  }

  // Nested schema: geometry values inside a struct, an array, and a map.
  val nestedGeometryEncoder = toRowEncoder(
    new StructType()
      .add(
        "s",
        new StructType()
          .add("i1", "int")
          .add("g0", "geometry(0)")
          .add("i2", "int")
          .add("g4326", "geometry(4326)"))
      .add("a", "array<geometry(0)>")
      .add("m", "map<string, geometry(ANY)>"))

  roundTripAndCheckIdentical(nestedGeometryEncoder) { () =>
    val nullEvery5 = MaybeNull(5)
    val nullEvery7 = MaybeNull(7)
    val nullEvery11 = MaybeNull(11)
    val nullEvery13 = MaybeNull(13)
    val nullEvery17 = MaybeNull(17)
    Iterator.tabulate(100) { i =>
      // The struct, array, and map parts are built in the same order as the
      // original Row arguments, since MaybeNull instances may be stateful.
      val structPart = nullEvery5(
        Row(
          i,
          nullEvery7(Geometry.fromWKB(wkbPoint1, 0)),
          i + 1,
          nullEvery11(Geometry.fromWKB(wkbPoint2, 4326))))
      val arrayPart = nullEvery7((0 until 10).map(_ => Geometry.fromWKB(wkbPoint2, 0)))
      val mapPart =
        nullEvery13(Map(i.toString -> nullEvery17(Geometry.fromWKB(wkbPoint1, 4326))))
      Row(structPart, arrayPart, mapPart)
    }
  }
}
361+
266362
test("variant round trip") {
267363
val variantEncoder = toRowEncoder(new StructType().add("v", "variant"))
268364
roundTripAndCheckIdentical(variantEncoder) { () =>

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,14 @@ object ArrowDeserializers {
341341
}
342342
}
343343

344+
case (_: GeometryEncoder, StructVectors(struct, vectors)) =>
345+
val gdser = new GeometryArrowSerDe
346+
gdser.createDeserializer(struct, vectors, timeZoneId)
347+
348+
case (_: GeographyEncoder, StructVectors(struct, vectors)) =>
349+
val gdser = new GeographyArrowSerDe
350+
gdser.createDeserializer(struct, vectors, timeZoneId)
351+
344352
case (VariantEncoder, StructVectors(struct, vectors)) =>
345353
assert(vectors.exists(_.getName == "value"))
346354
assert(

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,22 @@ private[arrow] object ArrowEncoderUtils {
4141
/** Fails fast when an encoder encounters a collection class it cannot handle. */
def unsupportedCollectionType(cls: Class[_]): Nothing = {
  val message = s"Unsupported collection type: $cls"
  throw new RuntimeException(message)
}
44+
45+
/**
 * Asserts that every expected vector is present in `vectors` by name, and that each one
 * carries the expected (key, value) metadata entry on its Arrow field.
 *
 * `expectedVectors` and `expectedMetadata` are matched up positionally, so they must have
 * the same length; a plain `zip` would otherwise silently drop the unmatched tail and
 * skip those metadata checks.
 */
def assertMetadataPresent(
    vectors: Seq[FieldVector],
    expectedVectors: Seq[String],
    expectedMetadata: Seq[(String, String)]): Unit = {
  assert(
    expectedVectors.length == expectedMetadata.length,
    s"expectedVectors (${expectedVectors.length}) and expectedMetadata " +
      s"(${expectedMetadata.length}) must have the same length")

  // Presence check first, for a clearer failure when a vector is missing entirely.
  expectedVectors.foreach { vectorName =>
    assert(vectors.exists(_.getName == vectorName))
  }

  // Metadata check: the named vector must carry the expected key/value entry.
  expectedVectors.zip(expectedMetadata).foreach { case (vectorName, (key, value)) =>
    assert(
      vectors.exists { field =>
        val metadata = field.getField.getMetadata
        field.getName == vectorName && metadata.containsKey(key) && metadata.get(key) == value
      })
  }
}
4460
}
4561

4662
private[arrow] object StructVectors {

0 commit comments

Comments
 (0)