Skip to content

Commit 881e636

Browse files
committed
Add specific Java List support to JavaTypeInference
Remove specific Java List support from ScalaReflection. Remove implicit encoder for Java Lists. Add relevant tests to JavaDatasetSuite. Remove tests from ScalaReflectionSuite and DatasetPrimitiveSuite.
1 parent fbf92c5 commit 881e636

File tree

5 files changed

+73
-35
lines changed

5 files changed

+73
-35
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala

Lines changed: 5 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -267,16 +267,11 @@ object JavaTypeInference {
267267

268268
case c if listType.isAssignableFrom(typeToken) =>
269269
val et = elementType(typeToken)
270-
val array =
271-
Invoke(
272-
MapObjects(
273-
p => deserializerFor(et, Some(p)),
274-
getPath,
275-
inferDataType(et)._1),
276-
"array",
277-
ObjectType(classOf[Array[Any]]))
278-
279-
StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", array :: Nil)
270+
MapObjects(
271+
p => deserializerFor(et, Some(p)),
272+
getPath,
273+
inferDataType(et)._1,
274+
customCollectionCls = Some(c))
280275

281276
case _ if mapType.isAssignableFrom(typeToken) =>
282277
val (keyType, valueType) = mapKeyValueType(typeToken)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 7 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -307,7 +307,7 @@ object ScalaReflection extends ScalaReflection {
307307
Invoke(arrayData, primitiveMethod, arrayCls, returnNullable = false)
308308
}
309309

310-
case t if t <:< localTypeOf[Seq[_]] || t <:< localTypeOf[java.util.List[_]] =>
310+
case t if t <:< localTypeOf[Seq[_]] =>
311311
val TypeRef(_, _, Seq(elementType)) = t
312312
val Schema(dataType, elementNullable) = schemaFor(elementType)
313313
val className = getClassNameFromType(elementType)
@@ -324,14 +324,10 @@ object ScalaReflection extends ScalaReflection {
324324
}
325325
}
326326

327-
val cls = if (t <:< localTypeOf[java.util.List[_]]) {
328-
mirror.runtimeClass(t.typeSymbol.asClass)
329-
} else {
330-
val companion = t.normalize.typeSymbol.companionSymbol.typeSignature
331-
companion.declaration(newTermName("newBuilder")) match {
332-
case NoSymbol => classOf[Seq[_]]
333-
case _ => mirror.runtimeClass(t.typeSymbol.asClass)
334-
}
327+
val companion = t.normalize.typeSymbol.companionSymbol.typeSignature
328+
val cls = companion.declaration(newTermName("newBuilder")) match {
329+
case NoSymbol => classOf[Seq[_]]
330+
case _ => mirror.runtimeClass(t.typeSymbol.asClass)
335331
}
336332
UnresolvedMapObjects(mapFunction, getPath, Some(cls))
337333

@@ -498,7 +494,7 @@ object ScalaReflection extends ScalaReflection {
498494
// Since List[_] also belongs to localTypeOf[Product], we put this case before
499495
// "case t if definedByConstructorParams(t)" to make sure it will match to the
500496
// case "localTypeOf[Seq[_]]"
501-
case t if t <:< localTypeOf[Seq[_]] || t <:< localTypeOf[java.util.List[_]] =>
497+
case t if t <:< localTypeOf[Seq[_]] =>
502498
val TypeRef(_, _, Seq(elementType)) = t
503499
toCatalystArray(inputObject, elementType)
504500

@@ -716,7 +712,7 @@ object ScalaReflection extends ScalaReflection {
716712
val TypeRef(_, _, Seq(elementType)) = t
717713
val Schema(dataType, nullable) = schemaFor(elementType)
718714
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
719-
case t if t <:< localTypeOf[Seq[_]] || t <:< localTypeOf[java.util.List[_]] =>
715+
case t if t <:< localTypeOf[Seq[_]] =>
720716
val TypeRef(_, _, Seq(elementType)) = t
721717
val Schema(dataType, nullable) = schemaFor(elementType)
722718
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala

Lines changed: 0 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -314,16 +314,6 @@ class ScalaReflectionSuite extends SparkFunSuite {
314314
assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))
315315
}
316316

317-
test("serialize and deserialize arbitrary java list types") {
318-
import java.util.ArrayList
319-
val arrayListSerializer = serializerFor[ArrayList[Int]](BoundReference(
320-
0, ObjectType(classOf[ArrayList[Int]]), nullable = false))
321-
assert(arrayListSerializer.dataType.head.dataType ==
322-
ArrayType(IntegerType, containsNull = false))
323-
val arrayListDeserializer = deserializerFor[ArrayList[Int]]
324-
assert(arrayListDeserializer.dataType == ObjectType(classOf[ArrayList[_]]))
325-
}
326-
327317
private val dataTypeForComplexData = dataTypeFor[ComplexData]
328318
private val typeOfComplexData = typeOf[ComplexData]
329319

sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -166,10 +166,6 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits {
166166
/** @since 2.2.0 */
167167
implicit def newSequenceEncoder[T <: Seq[_] : TypeTag]: Encoder[T] = ExpressionEncoder()
168168

169-
/** @since 2.2.0 */
170-
implicit def newJavaListEncoder[T <: java.util.List[_] : TypeTag]: Encoder[T] =
171-
ExpressionEncoder()
172-
173169
// Arrays
174170

175171
/** @since 1.6.1 */

sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java

Lines changed: 61 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1399,4 +1399,65 @@ public void testSerializeNull() {
13991399
ds1.map((MapFunction<NestedSmallBean, NestedSmallBean>) b -> b, encoder);
14001400
Assert.assertEquals(beans, ds2.collectAsList());
14011401
}
1402+
1403+
@Test
1404+
public void testSpecificLists() {
1405+
SpecificListsBean bean = new SpecificListsBean();
1406+
ArrayList<Integer> arrayList = new ArrayList<>();
1407+
arrayList.add(1);
1408+
bean.setArrayList(arrayList);
1409+
LinkedList<Integer> linkedList = new LinkedList<>();
1410+
linkedList.add(1);
1411+
bean.setLinkedList(linkedList);
1412+
bean.setList(Collections.singletonList(1));
1413+
List<SpecificListsBean> beans = Collections.singletonList(bean);
1414+
Dataset<SpecificListsBean> dataset =
1415+
spark.createDataset(beans, Encoders.bean(SpecificListsBean.class));
1416+
Assert.assertEquals(beans, dataset.collectAsList());
1417+
}
1418+
1419+
public static class SpecificListsBean implements Serializable {
1420+
private ArrayList<Integer> arrayList;
1421+
private LinkedList<Integer> linkedList;
1422+
private List<Integer> list;
1423+
1424+
public ArrayList<Integer> getArrayList() {
1425+
return arrayList;
1426+
}
1427+
1428+
public void setArrayList(ArrayList<Integer> arrayList) {
1429+
this.arrayList = arrayList;
1430+
}
1431+
1432+
public LinkedList<Integer> getLinkedList() {
1433+
return linkedList;
1434+
}
1435+
1436+
public void setLinkedList(LinkedList<Integer> linkedList) {
1437+
this.linkedList = linkedList;
1438+
}
1439+
1440+
public List<Integer> getList() {
1441+
return list;
1442+
}
1443+
1444+
public void setList(List<Integer> list) {
1445+
this.list = list;
1446+
}
1447+
1448+
@Override
1449+
public boolean equals(Object o) {
1450+
if (this == o) return true;
1451+
if (o == null || getClass() != o.getClass()) return false;
1452+
SpecificListsBean that = (SpecificListsBean) o;
1453+
return Objects.equal(arrayList, that.arrayList) &&
1454+
Objects.equal(linkedList, that.linkedList) &&
1455+
Objects.equal(list, that.list);
1456+
}
1457+
1458+
@Override
1459+
public int hashCode() {
1460+
return Objects.hashCode(arrayList, linkedList, list);
1461+
}
1462+
}
14021463
}

0 commit comments

Comments (0)