File tree Expand file tree Collapse file tree 1 file changed +14
-1
lines changed
mllib/src/main/scala/org/apache/spark/ml/param Expand file tree Collapse file tree 1 file changed +14
-1
lines changed Original file line number Diff line number Diff line change @@ -24,7 +24,8 @@ import scala.annotation.varargs
2424import scala .collection .mutable
2525
2626import org .apache .spark .annotation .AlphaComponent
27- import org .apache .spark .ml .util .Identifiable
27+ import org .apache .spark .ml .util .{SchemaUtils , Identifiable }
28+ import org .apache .spark .sql .types .{DataType , StructType }
2829
2930/**
3031 * :: AlphaComponent ::
@@ -380,6 +381,18 @@ trait Params extends Identifiable with Serializable {
380381 this
381382 }
382383
384+ /**
385+ * Check whether the given schema contains an input column.
386+ * @param colName Input column name
387+ * @param dataType Input column DataType
388+ */
389+ protected def checkInputColumn (schema : StructType , colName : String , dataType : DataType ): Unit = {
390+ val actualDataType = schema(colName).dataType
391+ SchemaUtils .checkColumnType(schema, colName, dataType)
392+ require(actualDataType.equals(dataType), s " Input column Name: $colName Description: ${getParam(colName)}" )
393+ }
394+
395+
383396 /**
384397 * Gets the default value of a parameter.
385398 */
You can’t perform that action at this time.
0 commit comments