@@ -547,16 +547,18 @@ object DecisionTree extends Serializable with Logging {
547547
548548 /**
549549 * Sequential search helper method to find bin for categorical feature in multiclass
550- * classification. Dummy value of 0 used since it is not used in future calculation
550+ * classification. The category is returned since each category can belong to multiple
551+ * splits. The actual left/right child allocation per split is performed in the
552+ * sequential phase of the bin aggregate operation.
551553 */
552- def sequentialBinSearchForCategoricalFeatureInBinaryClassification (): Int = {
554+ def sequentialBinSearchForCategoricalFeatureInMulticlassClassification (): Int = {
553555 labeledPoint.features(featureIndex).toInt
554556 }
555557
556558 /**
557559 * Sequential search helper method to find bin for categorical feature.
558560 */
559- def sequentialBinSearchForCategoricalFeatureInMultiClassClassification (): Int = {
561+ def sequentialBinSearchForCategoricalFeatureInBinaryClassification (): Int = {
560562 val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
561563 val numCategoricalBins = math.pow(2.0 , featureCategories - 1 ).toInt - 1
562564 var binIndex = 0
@@ -583,9 +585,9 @@ object DecisionTree extends Serializable with Logging {
583585 // Perform sequential search to find bin for categorical features.
584586 val binIndex = {
585587 if (isMulticlassClassification) {
586- sequentialBinSearchForCategoricalFeatureInBinaryClassification ()
588+ sequentialBinSearchForCategoricalFeatureInMulticlassClassification ()
587589 } else {
588- sequentialBinSearchForCategoricalFeatureInMultiClassClassification ()
590+ sequentialBinSearchForCategoricalFeatureInBinaryClassification ()
589591 }
590592 }
591593 if (binIndex == - 1 ){
@@ -684,7 +686,7 @@ object DecisionTree extends Serializable with Logging {
684686 * @return Array[Double] storing aggregate calculation of size
685687 * 2 * numSplits * numFeatures * numNodes for classification
686688 */
687- def binaryClassificationBinSeqOp (arr : Array [Double ], agg : Array [Double ]) = {
689+ def orderedClassificationBinSeqOp (arr : Array [Double ], agg : Array [Double ]) = {
688690 // Iterate over all nodes.
689691 var nodeIndex = 0
690692 while (nodeIndex < numNodes) {
@@ -716,7 +718,7 @@ object DecisionTree extends Serializable with Logging {
716718 * @return Array[Double] storing aggregate calculation of size
717719 * 2 * numClasses * numSplits * numFeatures * numNodes for classification
718720 */
719- def multiClassificationBinSeqOp (arr : Array [Double ], agg : Array [Double ]) = {
721+ def unorderedClassificationBinSeqOp (arr : Array [Double ], agg : Array [Double ]) = {
720722 // Iterate over all nodes.
721723 var nodeIndex = 0
722724 while (nodeIndex < numNodes) {
@@ -789,9 +791,9 @@ object DecisionTree extends Serializable with Logging {
789791 strategy.algo match {
790792 case Classification =>
791793 if (isMulticlassClassificationWithCategoricalFeatures) {
792- multiClassificationBinSeqOp (arr, agg)
794+ unorderedClassificationBinSeqOp (arr, agg)
793795 } else {
794- binaryClassificationBinSeqOp (arr, agg)
796+ orderedClassificationBinSeqOp (arr, agg)
795797 }
796798 case Regression => regressionBinSeqOp(arr, agg)
797799 }
0 commit comments