
Commit b9f3071

Java API for applySchema.

1 parent 1c9f33c commit b9f3071

11 files changed: +348 −37 lines changed


sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java

Lines changed: 13 additions & 4 deletions
@@ -134,11 +134,20 @@ public static StructField createStructField(String name, DataType dataType, bool
   }
 
   /**
-   * Creates a StructType with the given StructFields ({@code fields}).
+   * Creates a StructType with the given list of StructFields ({@code fields}).
    * @param fields
    * @return
    */
   public static StructType createStructType(List<StructField> fields) {
+    return createStructType(fields.toArray(new StructField[0]));
+  }
+
+  /**
+   * Creates a StructType with the given StructField array ({@code fields}).
+   * @param fields
+   * @return
+   */
+  public static StructType createStructType(StructField[] fields) {
     if (fields == null) {
       throw new IllegalArgumentException("fields should not be null.");
     }
@@ -151,11 +160,11 @@ public static StructType createStructType(List<StructField> fields) {
 
       distinctNames.add(field.getName());
     }
-    if (distinctNames.size() != fields.size()) {
-      throw new IllegalArgumentException(
-        "fields should have distinct names.");
+    if (distinctNames.size() != fields.length) {
+      throw new IllegalArgumentException("fields should have distinct names.");
     }
 
     return new StructType(fields);
   }
+
 }

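Both overloads can now be exercised from Java; the array overload lets callers skip building a List. A minimal sketch (the field names and types below are illustrative, not part of the commit):

    import java.util.Arrays;
    import java.util.List;
    import org.apache.spark.sql.api.java.types.DataType;
    import org.apache.spark.sql.api.java.types.StructField;
    import org.apache.spark.sql.api.java.types.StructType;

    // Two fields with distinct names; duplicate names throw
    // IllegalArgumentException, as does a null fields argument.
    StructField name = DataType.createStructField("name", DataType.StringType, false);
    StructField age  = DataType.createStructField("age", DataType.IntegerType, true);

    // List-based overload, which now just delegates to the array-based one.
    List<StructField> fieldList = Arrays.asList(name, age);
    StructType fromList = DataType.createStructType(fieldList);

    // Array-based overload added by this commit.
    StructType fromArray = DataType.createStructType(new StructField[] {name, age});
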
sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java

Lines changed: 3 additions & 3 deletions
@@ -22,13 +22,13 @@
 
 /**
  * The data type representing Rows.
- * A StructType object comprises a List of StructFields.
+ * A StructType object comprises an array of StructFields.
  */
 public class StructType extends DataType {
   private StructField[] fields;
 
-  protected StructType(List<StructField> fields) {
-    this.fields = fields.toArray(new StructField[0]);
+  protected StructType(StructField[] fields) {
+    this.fields = fields;
   }
 
   public StructField[] getFields() {

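The constructor stays protected, so user code still obtains a StructType only through the DataType factory methods; the fields come back out through the accessors shown in the context lines. Continuing the sketch above:

    StructType schema = DataType.createStructType(new StructField[] {name, age});
    for (StructField field : schema.getFields()) {
      System.out.println(field.getName() + " nullable=" + field.isNullable());
    }
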
sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 95 additions & 0 deletions
@@ -19,11 +19,13 @@ package org.apache.spark.sql
 
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
+import scala.collection.JavaConverters._
 
 import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.api.java.types.{DataType => JDataType, StructField => JStructField}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.dsl.ExpressionConversions
@@ -420,4 +422,97 @@ class SQLContext(@transient val sparkContext: SparkContext)
     new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema, rowRdd)))
   }
 
+  /**
+   * Returns the equivalent StructField in Scala for the given StructField in Java.
+   */
+  protected def asJavaStructField(scalaStructField: StructField): JStructField = {
+    org.apache.spark.sql.api.java.types.DataType.createStructField(
+      scalaStructField.name,
+      asJavaDataType(scalaStructField.dataType),
+      scalaStructField.nullable)
+  }
+
+  /**
+   * Returns the equivalent DataType in Java for the given DataType in Scala.
+   */
+  protected[sql] def asJavaDataType(scalaDataType: DataType): JDataType = scalaDataType match {
+    case StringType =>
+      org.apache.spark.sql.api.java.types.DataType.StringType
+    case BinaryType =>
+      org.apache.spark.sql.api.java.types.DataType.BinaryType
+    case BooleanType =>
+      org.apache.spark.sql.api.java.types.DataType.BooleanType
+    case TimestampType =>
+      org.apache.spark.sql.api.java.types.DataType.TimestampType
+    case DecimalType =>
+      org.apache.spark.sql.api.java.types.DataType.DecimalType
+    case DoubleType =>
+      org.apache.spark.sql.api.java.types.DataType.DoubleType
+    case FloatType =>
+      org.apache.spark.sql.api.java.types.DataType.FloatType
+    case ByteType =>
+      org.apache.spark.sql.api.java.types.DataType.ByteType
+    case IntegerType =>
+      org.apache.spark.sql.api.java.types.DataType.IntegerType
+    case LongType =>
+      org.apache.spark.sql.api.java.types.DataType.LongType
+    case ShortType =>
+      org.apache.spark.sql.api.java.types.DataType.ShortType
+
+    case arrayType: ArrayType =>
+      org.apache.spark.sql.api.java.types.DataType.createArrayType(
+        asJavaDataType(arrayType.elementType), arrayType.containsNull)
+    case mapType: MapType =>
+      org.apache.spark.sql.api.java.types.DataType.createMapType(
+        asJavaDataType(mapType.keyType), asJavaDataType(mapType.valueType))
+    case structType: StructType =>
+      org.apache.spark.sql.api.java.types.DataType.createStructType(
+        structType.fields.map(asJavaStructField).asJava)
+  }
+
+  /**
+   * Returns the equivalent StructField in Scala for the given StructField in Java.
+   */
+  protected def asScalaStructField(javaStructField: JStructField): StructField = {
+    StructField(
+      javaStructField.getName,
+      asScalaDataType(javaStructField.getDataType),
+      javaStructField.isNullable)
+  }
+
+  /**
+   * Returns the equivalent DataType in Scala for the given DataType in Java.
+   */
+  protected[sql] def asScalaDataType(javaDataType: JDataType): DataType = javaDataType match {
+    case stringType: org.apache.spark.sql.api.java.types.StringType =>
+      StringType
+    case binaryType: org.apache.spark.sql.api.java.types.BinaryType =>
+      BinaryType
+    case booleanType: org.apache.spark.sql.api.java.types.BooleanType =>
+      BooleanType
+    case timestampType: org.apache.spark.sql.api.java.types.TimestampType =>
+      TimestampType
+    case decimalType: org.apache.spark.sql.api.java.types.DecimalType =>
+      DecimalType
+    case doubleType: org.apache.spark.sql.api.java.types.DoubleType =>
+      DoubleType
+    case floatType: org.apache.spark.sql.api.java.types.FloatType =>
+      FloatType
+    case byteType: org.apache.spark.sql.api.java.types.ByteType =>
+      ByteType
+    case integerType: org.apache.spark.sql.api.java.types.IntegerType =>
+      IntegerType
+    case longType: org.apache.spark.sql.api.java.types.LongType =>
+      LongType
+    case shortType: org.apache.spark.sql.api.java.types.ShortType =>
+      ShortType
+
+    case arrayType: org.apache.spark.sql.api.java.types.ArrayType =>
+      ArrayType(asScalaDataType(arrayType.getElementType), arrayType.isContainsNull)
+    case mapType: org.apache.spark.sql.api.java.types.MapType =>
+      MapType(asScalaDataType(mapType.getKeyType), asScalaDataType(mapType.getValueType))
+    case structType: org.apache.spark.sql.api.java.types.StructType =>
+      StructType(structType.getFields.map(asScalaStructField))
+  }
+
 }

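These converters are protected[sql], so Java code never calls them directly; their effect is visible when a Java-side schema handed to the new JavaSQLContext methods (below) comes back out of JavaSchemaRDD.schema(). A sketch of that round trip, with illustrative JSON data:

    import java.util.Arrays;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.sql.api.java.JavaSQLContext;
    import org.apache.spark.sql.api.java.JavaSchemaRDD;
    import org.apache.spark.sql.api.java.types.DataType;
    import org.apache.spark.sql.api.java.types.StructField;
    import org.apache.spark.sql.api.java.types.StructType;

    JavaSparkContext sc = new JavaSparkContext("local", "applySchemaExample");
    JavaSQLContext sqlCtx = new JavaSQLContext(sc);

    JavaRDD<String> json = sc.parallelize(Arrays.asList(
        "{\"name\":\"alice\",\"age\":1}",
        "{\"name\":\"bob\",\"age\":2}"));

    StructType schema = DataType.createStructType(new StructField[] {
        DataType.createStructField("name", DataType.StringType, true),
        DataType.createStructField("age", DataType.IntegerType, true)
    });

    // The schema goes in through asScalaDataType and comes back out through
    // asJavaDataType, so it should return structurally identical.
    JavaSchemaRDD rdd = sqlCtx.jsonRDD(json, schema);
    StructType roundTripped = rdd.schema();
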
sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala

Lines changed: 5 additions & 0 deletions
@@ -119,6 +119,11 @@ class SchemaRDD(
   override protected def getDependencies: Seq[Dependency[_]] =
     List(new OneToOneDependency(queryExecution.toRdd))
 
+  /** Returns the schema of this SchemaRDD (represented by a [[StructType]]).
+   *
+   * @group schema
+   */
+  def schema: StructType = queryExecution.analyzed.schema
 
   // =======================================================================
   // Query DSL

sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala

Lines changed: 1 addition & 7 deletions
@@ -123,17 +123,11 @@ private[sql] trait SchemaRDDLike {
   def saveAsTable(tableName: String): Unit =
     sqlContext.executePlan(InsertIntoCreatedTable(None, tableName, logicalPlan)).toRdd
 
-  /** Returns the schema of this SchemaRDD (represented by a [[StructType]]).
-   *
-   * @group schema
-   */
-  def schema: StructType = queryExecution.analyzed.schema
-
   /** Returns the schema as a string in the tree format.
    *
    * @group schema
    */
-  def schemaString: String = schema.treeString
+  def schemaString: String = baseSchemaRDD.schema.treeString
 
   /** Prints out the schema.
    *

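With schemaString now routed through baseSchemaRDD.schema, the tree-format helpers behave the same whichever RDD flavor defines schema. JavaSchemaRDD mixes in SchemaRDDLike, so they should be callable from Java as well; continuing the sketch above:

    rdd.printSchema();                 // prints the schema in tree format
    String tree = rdd.schemaString();  // the same text as a String
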
sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala

Lines changed: 48 additions & 12 deletions
@@ -19,14 +19,17 @@ package org.apache.spark.sql.api.java
 
 import java.beans.Introspector
 
+import scala.collection.JavaConverters._
+
 import org.apache.hadoop.conf.Configuration
 
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
+import org.apache.spark.sql.api.java.types.{DataType => JDataType, StructType => JStructType}
+import org.apache.spark.sql.api.java.types.{StructField => JStructField}
 import org.apache.spark.sql.json.JsonRDD
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow}
-import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.parquet.ParquetRelation
 import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
 import org.apache.spark.util.Utils
@@ -95,6 +98,20 @@ class JavaSQLContext(val sqlContext: SQLContext) {
     new JavaSchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(schema, rowRdd)))
   }
 
+  /**
+   * :: DeveloperApi ::
+   * Creates a JavaSchemaRDD from an RDD containing Rows by applying a schema to this RDD.
+   * It is important to make sure that the structure of every Row of the provided RDD matches the
+   * provided schema. Otherwise, there will be runtime exception.
+   */
+  @DeveloperApi
+  def applySchema(rowRDD: JavaRDD[Row], schema: JStructType): JavaSchemaRDD = {
+    val scalaRowRDD = rowRDD.rdd.map(r => r.row)
+    val scalaSchema = sqlContext.asScalaDataType(schema).asInstanceOf[StructType]
+    val logicalPlan = SparkLogicalPlan(ExistingRdd(scalaSchema.toAttributes, scalaRowRDD))
+    new JavaSchemaRDD(sqlContext, logicalPlan)
+  }
+
   /**
    * Loads a parquet file, returning the result as a [[JavaSchemaRDD]].
    */
@@ -104,26 +121,45 @@ class JavaSQLContext(val sqlContext: SQLContext) {
       ParquetRelation(path, Some(sqlContext.sparkContext.hadoopConfiguration)))
 
   /**
-   * Loads a JSON file (one object per line), returning the result as a [[JavaSchemaRDD]].
+   * Loads a JSON file (one object per line), returning the result as a JavaSchemaRDD.
    * It goes through the entire dataset once to determine the schema.
-   *
-   * @group userf
    */
   def jsonFile(path: String): JavaSchemaRDD =
     jsonRDD(sqlContext.sparkContext.textFile(path))
 
+  /**
+   * :: Experimental ::
+   * Loads a JSON file (one object per line) and applies the given schema,
+   * returning the result as a JavaSchemaRDD.
+   */
+  @Experimental
+  def jsonFile(path: String, schema: JStructType): JavaSchemaRDD =
+    jsonRDD(sqlContext.sparkContext.textFile(path), schema)
+
   /**
    * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a
-   * [[JavaSchemaRDD]].
+   * [JavaSchemaRDD.
    * It goes through the entire dataset once to determine the schema.
-   *
-   * @group userf
   */
   def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = {
-    val schema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, 1.0))
-    val logicalPlan =
-      sqlContext.makeCustomRDDScan[String](json, schema, JsonRDD.jsonStringToRow(schema, _))
+    val appliedScalaSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))
+    val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema)
+    val logicalPlan = SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))
+    new JavaSchemaRDD(sqlContext, logicalPlan)
+  }
 
+  /**
+   * :: Experimental ::
+   * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema,
+   * returning the result as a JavaSchemaRDD.
+   */
+  @Experimental
+  def jsonRDD(json: JavaRDD[String], schema: JStructType): JavaSchemaRDD = {
+    val appliedScalaSchema =
+      Option(sqlContext.asScalaDataType(schema)).getOrElse(
+        JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[StructType]
+    val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema)
+    val logicalPlan = SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))
     new JavaSchemaRDD(sqlContext, logicalPlan)
   }
 

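A sketch of the new applySchema entry point from Java, reusing the sc and sqlCtx created in the earlier sketch. Row construction is the one assumption here: Row.create is how the Java Row API is populated in the finished Java API, but treat it as hypothetical at this exact commit.

    import java.util.Arrays;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.sql.api.java.JavaSchemaRDD;
    import org.apache.spark.sql.api.java.Row;
    import org.apache.spark.sql.api.java.types.DataType;
    import org.apache.spark.sql.api.java.types.StructField;
    import org.apache.spark.sql.api.java.types.StructType;

    StructType peopleSchema = DataType.createStructType(new StructField[] {
        DataType.createStructField("name", DataType.StringType, false),
        DataType.createStructField("age", DataType.IntegerType, false)
    });

    // Every Row must match the schema exactly; a mismatch surfaces only at
    // query time, as the applySchema scaladoc warns.
    JavaRDD<Row> rows = sc.parallelize(Arrays.asList(
        Row.create("alice", 30),   // Row.create assumed, see note above
        Row.create("bob", 25)));

    JavaSchemaRDD people = sqlCtx.applySchema(rows, peopleSchema);
    people.registerAsTable("people");
    JavaSchemaRDD adults = sqlCtx.sql("SELECT name FROM people WHERE age >= 18");
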
sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala

Lines changed: 5 additions & 0 deletions
@@ -22,6 +22,7 @@ import java.util.{List => JList}
 import org.apache.spark.Partitioner
 import org.apache.spark.api.java.{JavaRDDLike, JavaRDD}
 import org.apache.spark.api.java.function.{Function => JFunction}
+import org.apache.spark.sql.api.java.types.StructType
 import org.apache.spark.sql.{SQLContext, SchemaRDD, SchemaRDDLike}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.rdd.RDD
@@ -53,6 +54,10 @@ class JavaSchemaRDD(
 
   override def toString: String = baseSchemaRDD.toString
 
+  /** Returns the schema of this JavaSchemaRDD (represented by a StructType). */
+  def schema: StructType =
+    sqlContext.asJavaDataType(baseSchemaRDD.schema).asInstanceOf[StructType]
+
   // =======================================================================
   // Base RDD functions that do NOT change schema
   // =======================================================================

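One concrete use of the new schema() method: infer a schema once from a sample dataset and reuse it for a larger dataset of the same shape, skipping the full inference pass. sampleJson and fullJson are hypothetical JavaRDD<String> handles:

    // Inference scans the whole sample once (see jsonRDD above).
    JavaSchemaRDD inferred = sqlCtx.jsonRDD(sampleJson);
    StructType savedSchema = inferred.schema();

    // Applying the saved schema avoids re-inferring it over fullJson.
    JavaSchemaRDD applied = sqlCtx.jsonRDD(fullJson, savedSchema);
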