@@ -66,7 +66,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
6666 /**
6767 * Param for the power in the variance function of the Tweedie distribution which provides
6868 * the relationship between the variance and mean of the distribution.
69- * Used only for the Tweedie family.
69+ * Only applicable for the Tweedie family.
7070 * (see <a href="https://en.wikipedia.org/wiki/Tweedie_distribution">
7171 * Tweedie Distribution (Wikipedia)</a>)
7272 * Supported values: 0 and [1, Inf).
@@ -79,7 +79,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
7979 final val variancePower : DoubleParam = new DoubleParam (this , " variancePower" ,
8080 " The power in the variance function of the Tweedie distribution which characterizes " +
8181 " the relationship between the variance and mean of the distribution. " +
82- " Used only for the Tweedie family. Supported values: 0 and [1, Inf)." ,
82+ " Only applicable for the Tweedie family. Supported values: 0 and [1, Inf)." ,
8383 (x : Double ) => x >= 1.0 || x == 0.0 )
8484
8585 /** @group getParam */
@@ -106,17 +106,15 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
106106 def getLink : String = $(link)
107107
108108 /**
109- * Param for the index in the power link function. This is used to specify the link function
110- * in the Tweedie family.
109+ * Param for the index in the power link function. Only applicable for the Tweedie family.
111110 * Note that link power 0, 1, -1 or 0.5 corresponds to the Log, Identity, Inverse or Sqrt
112111 * link, respectively.
113112 *
114113 * @group param
115114 */
116115 @ Since (" 2.2.0" )
117116 final val linkPower : DoubleParam = new DoubleParam (this , " linkPower" ,
118- " The index in the power link function. This is used to specify the link function in the " +
119- " Tweedie family." )
117+ " The index in the power link function. Only applicable for the Tweedie family." )
120118
121119 /** @group getParam */
122120 @ Since (" 2.2.0" )
@@ -148,12 +146,15 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
148146 schema : StructType ,
149147 fitting : Boolean ,
150148 featuresDataType : DataType ): StructType = {
151- if ($(family) == " tweedie" ) {
149+ if ($(family).toLowerCase == " tweedie" ) {
152150 if (isSet(link)) {
153151 logWarning(" When family is tweedie, use param linkPower to specify link function. " +
154152 " Setting param link will take no effect." )
155153 }
156154 } else {
155+ if (isSet(variancePower)) {
156+ logWarning(" When family is not tweedie, setting param variancePower will take no effect." )
157+ }
157158 if (isSet(linkPower)) {
158159 logWarning(" When family is not tweedie, use param link to specify link function. " +
159160 " Setting param linkPower will take no effect." )
@@ -381,8 +382,8 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
381382 override def load (path : String ): GeneralizedLinearRegression = super .load(path)
382383
383384 /**
384- * Set of family and link pairs that GeneralizedLinearRegression supports.
385- * Tweedie family is specified through linkPower.
385+ * Set of family (except for tweedie) and link pairs that GeneralizedLinearRegression supports.
386+ * The link function of the Tweedie family is specified through param linkPower.
386387 */
387388 private [regression] lazy val supportedFamilyAndLinkPairs = Set (
388389 Gaussian -> Identity , Gaussian -> Log , Gaussian -> Inverse ,
@@ -453,8 +454,9 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
453454 */
454455 def apply (params : GeneralizedLinearRegressionBase ): FamilyAndLink = {
455456 val familyObj = Family .fromParams(params)
456- val linkObj = if ((params.getFamily != " tweedie" && params.isDefined(params.link)) ||
457- (params.getFamily == " tweedie" && params.isDefined(params.linkPower))) {
457+ val linkObj = if ((params.getFamily.toLowerCase != " tweedie" &&
458+ params.isSet(params.link)) || (params.getFamily.toLowerCase == " tweedie" &&
459+ params.isSet(params.linkPower))) {
458460 Link .fromParams(params)
459461 } else {
460462 familyObj.defaultLink
@@ -503,9 +505,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
503505 private [regression] object Family {
504506
505507 /**
506- * Gets the [[Family ]] object based on family and variancePower.
507- * 1) retrieve object based on family name
508- * 2) if family name is tweedie, retrieve object based on variancePower
508+ * Gets the [[Family ]] object based on param family and variancePower.
509+ * If param family is set with "gaussian", "binomial", "poisson" or "gamma",
510+ * return the corresponding object directly; otherwise, construct a Tweedie object
511+ * according to variancePower.
509512 *
510513 * @param params the parameter map containing family name and variance power
511514 */
@@ -779,11 +782,11 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
779782 private [regression] object Link {
780783
781784 /**
782- * Gets the [[Link ]] object based on link or linkPower.
783- * 1) if family is "tweedie", retrieve object using linkPower
784- * 2) otherwise, retrieve object based on link name
785+ * Gets the [[Link ]] object based on param family, link and linkPower.
786+ * If param family is set with "tweedie", return or construct link function object
787+ * according to linkPower; otherwise, return link function object according to link.
785788 *
786- * @param params the parameter map containing link and link power
789+ * @param params the parameter map containing family, link and linkPower
787790 */
788791 def fromParams (params : GeneralizedLinearRegressionBase ): Link = {
789792 if (params.getFamily.toLowerCase == " tweedie" ) {
@@ -1244,7 +1247,8 @@ class GeneralizedLinearRegressionSummary private[regression] (
12441247 */
12451248 @ Since (" 2.0.0" )
12461249 lazy val dispersion : Double = if (
1247- model.getFamily == Binomial .name || model.getFamily == Poisson .name) {
1250+ model.getFamily.toLowerCase == Binomial .name ||
1251+ model.getFamily.toLowerCase == Poisson .name) {
12481252 1.0
12491253 } else {
12501254 val rss = pearsonResiduals.agg(sum(pow(col(" pearsonResiduals" ), 2.0 ))).first().getDouble(0 )
@@ -1347,7 +1351,8 @@ class GeneralizedLinearRegressionTrainingSummary private[regression] (
13471351 @ Since (" 2.0.0" )
13481352 lazy val pValues : Array [Double ] = {
13491353 if (isNormalSolver) {
1350- if (model.getFamily == Binomial .name || model.getFamily == Poisson .name) {
1354+ if (model.getFamily.toLowerCase == Binomial .name ||
1355+ model.getFamily.toLowerCase == Poisson .name) {
13511356 tValues.map { x => 2.0 * (1.0 - dist.Gaussian (0.0 , 1.0 ).cdf(math.abs(x))) }
13521357 } else {
13531358 tValues.map { x =>
0 commit comments