Skip to content

Commit 54da2cb

Browse files
committed
update docs
1 parent 995e88f commit 54da2cb

File tree

2 files changed

+60
-54
lines changed

2 files changed

+60
-54
lines changed

mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
6666
/**
6767
* Param for the power in the variance function of the Tweedie distribution which provides
6868
* the relationship between the variance and mean of the distribution.
69-
* Used only for the Tweedie family.
69+
* Only applicable for the Tweedie family.
7070
* (see <a href="https://en.wikipedia.org/wiki/Tweedie_distribution">
7171
* Tweedie Distribution (Wikipedia)</a>)
7272
* Supported values: 0 and [1, Inf).
@@ -79,7 +79,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
7979
final val variancePower: DoubleParam = new DoubleParam(this, "variancePower",
8080
"The power in the variance function of the Tweedie distribution which characterizes " +
8181
"the relationship between the variance and mean of the distribution. " +
82-
"Used only for the Tweedie family. Supported values: 0 and [1, Inf).",
82+
"Only applicable for the Tweedie family. Supported values: 0 and [1, Inf).",
8383
(x: Double) => x >= 1.0 || x == 0.0)
8484

8585
/** @group getParam */
@@ -106,17 +106,15 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
106106
def getLink: String = $(link)
107107

108108
/**
109-
* Param for the index in the power link function. This is used to specify the link function
110-
* in the Tweedie family.
109+
* Param for the index in the power link function. Only applicable for the Tweedie family.
111110
* Note that link power 0, 1, -1 or 0.5 corresponds to the Log, Identity, Inverse or Sqrt
112111
* link, respectively.
113112
*
114113
* @group param
115114
*/
116115
@Since("2.2.0")
117116
final val linkPower: DoubleParam = new DoubleParam(this, "linkPower",
118-
"The index in the power link function. This is used to specify the link function in the " +
119-
"Tweedie family.")
117+
"The index in the power link function. Only applicable for the Tweedie family.")
120118

121119
/** @group getParam */
122120
@Since("2.2.0")
@@ -148,12 +146,15 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
148146
schema: StructType,
149147
fitting: Boolean,
150148
featuresDataType: DataType): StructType = {
151-
if ($(family) == "tweedie") {
149+
if ($(family).toLowerCase == "tweedie") {
152150
if (isSet(link)) {
153151
logWarning("When family is tweedie, use param linkPower to specify link function. " +
154152
"Setting param link will take no effect.")
155153
}
156154
} else {
155+
if (isSet(variancePower)) {
156+
logWarning("When family is not tweedie, setting param variancePower will take no effect.")
157+
}
157158
if (isSet(linkPower)) {
158159
logWarning("When family is not tweedie, use param link to specify link function. " +
159160
"Setting param linkPower will take no effect.")
@@ -381,8 +382,8 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
381382
override def load(path: String): GeneralizedLinearRegression = super.load(path)
382383

383384
/**
384-
* Set of family and link pairs that GeneralizedLinearRegression supports.
385-
* Tweedie family is specified through linkPower.
385+
* Set of family (except for tweedie) and link pairs that GeneralizedLinearRegression supports.
386+
* The link function of the Tweedie family is specified through param linkPower.
386387
*/
387388
private[regression] lazy val supportedFamilyAndLinkPairs = Set(
388389
Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse,
@@ -453,8 +454,9 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
453454
*/
454455
def apply(params: GeneralizedLinearRegressionBase): FamilyAndLink = {
455456
val familyObj = Family.fromParams(params)
456-
val linkObj = if ((params.getFamily != "tweedie" && params.isDefined(params.link)) ||
457-
(params.getFamily == "tweedie" && params.isDefined(params.linkPower))) {
457+
val linkObj = if ((params.getFamily.toLowerCase != "tweedie" &&
458+
params.isSet(params.link)) || (params.getFamily.toLowerCase == "tweedie" &&
459+
params.isSet(params.linkPower))) {
458460
Link.fromParams(params)
459461
} else {
460462
familyObj.defaultLink
@@ -503,9 +505,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
503505
private[regression] object Family {
504506

505507
/**
506-
* Gets the [[Family]] object based on family and variancePower.
507-
* 1) retrieve object based on family name
508-
* 2) if family name is tweedie, retrieve object based on variancePower
508+
* Gets the [[Family]] object based on param family and variancePower.
509+
* If param family is set with "gaussian", "binomial", "poisson" or "gamma",
510+
* return the corresponding object directly; otherwise, construct a Tweedie object
511+
* according to variancePower.
509512
*
510513
* @param params the parameter map containing family name and variance power
511514
*/
@@ -779,11 +782,11 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
779782
private[regression] object Link {
780783

781784
/**
782-
* Gets the [[Link]] object based on link or linkPower.
783-
* 1) if family is "tweedie", retrieve object using linkPower
784-
* 2) otherwise, retrieve object based on link name
785+
* Gets the [[Link]] object based on param family, link and linkPower.
786+
* If param family is set with "tweedie", return or construct link function object
787+
* according to linkPower; otherwise, return link function object according to link.
785788
*
786-
* @param params the parameter map containing link and link power
789+
* @param params the parameter map containing family, link and linkPower
787790
*/
788791
def fromParams(params: GeneralizedLinearRegressionBase): Link = {
789792
if (params.getFamily.toLowerCase == "tweedie") {
@@ -1244,7 +1247,8 @@ class GeneralizedLinearRegressionSummary private[regression] (
12441247
*/
12451248
@Since("2.0.0")
12461249
lazy val dispersion: Double = if (
1247-
model.getFamily == Binomial.name || model.getFamily == Poisson.name) {
1250+
model.getFamily.toLowerCase == Binomial.name ||
1251+
model.getFamily.toLowerCase == Poisson.name) {
12481252
1.0
12491253
} else {
12501254
val rss = pearsonResiduals.agg(sum(pow(col("pearsonResiduals"), 2.0))).first().getDouble(0)
@@ -1347,7 +1351,8 @@ class GeneralizedLinearRegressionTrainingSummary private[regression] (
13471351
@Since("2.0.0")
13481352
lazy val pValues: Array[Double] = {
13491353
if (isNormalSolver) {
1350-
if (model.getFamily == Binomial.name || model.getFamily == Poisson.name) {
1354+
if (model.getFamily.toLowerCase == Binomial.name ||
1355+
model.getFamily.toLowerCase == Poisson.name) {
13511356
tValues.map { x => 2.0 * (1.0 - dist.Gaussian(0.0, 1.0).cdf(math.abs(x))) }
13521357
} else {
13531358
tValues.map { x =>

mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -637,33 +637,33 @@ class GeneralizedLinearRegressionSuite
637637
import GeneralizedLinearRegression._
638638

639639
var idx = 0
640-
for (fitIntercept <- Seq(false, true); linkPower <- Seq(0.0, 1.0, -1.0)) {
641-
for (variancePower <- Seq(1.6, 2.5)) {
642-
val trainer = new GeneralizedLinearRegression().setFamily("tweedie")
643-
.setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
644-
.setVariancePower(variancePower).setLinkPower(linkPower)
645-
val model = trainer.fit(datasetTweedie)
646-
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
647-
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with tweedie family, " +
648-
s"linkPower = $linkPower, fitIntercept = $fitIntercept " +
649-
s"and variancePower = $variancePower.")
650-
651-
val familyLink = FamilyAndLink(trainer)
652-
model.transform(datasetTweedie).select("features", "prediction", "linkPrediction").collect()
653-
.foreach {
654-
case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
655-
val eta = BLAS.dot(features, model.coefficients) + model.intercept
656-
val prediction2 = familyLink.fitted(eta)
657-
val linkPrediction2 = eta
658-
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
659-
s"tweedie family, linkPower = $linkPower, fitIntercept = $fitIntercept " +
660-
s"and variancePower = $variancePower.")
661-
assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
662-
s"GLM with tweedie family, linkPower = $linkPower, fitIntercept = $fitIntercept " +
663-
s"and variancePower = $variancePower.")
664-
}
665-
idx += 1
666-
}
640+
for (fitIntercept <- Seq(false, true);
641+
linkPower <- Seq(0.0, 1.0, -1.0);
642+
variancePower <- Seq(1.6, 2.5)) {
643+
val trainer = new GeneralizedLinearRegression().setFamily("tweedie")
644+
.setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
645+
.setVariancePower(variancePower).setLinkPower(linkPower)
646+
val model = trainer.fit(datasetTweedie)
647+
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
648+
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with tweedie family, " +
649+
s"linkPower = $linkPower, fitIntercept = $fitIntercept " +
650+
s"and variancePower = $variancePower.")
651+
652+
val familyLink = FamilyAndLink(trainer)
653+
model.transform(datasetTweedie).select("features", "prediction", "linkPrediction").collect()
654+
.foreach {
655+
case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
656+
val eta = BLAS.dot(features, model.coefficients) + model.intercept
657+
val prediction2 = familyLink.fitted(eta)
658+
val linkPrediction2 = eta
659+
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
660+
s"tweedie family, linkPower = $linkPower, fitIntercept = $fitIntercept " +
661+
s"and variancePower = $variancePower.")
662+
assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
663+
s"GLM with tweedie family, linkPower = $linkPower, fitIntercept = $fitIntercept " +
664+
s"and variancePower = $variancePower.")
665+
}
666+
idx += 1
667667
}
668668
}
669669

@@ -1228,8 +1228,9 @@ class GeneralizedLinearRegressionSuite
12281228
1.0, 3.0, 2.0, 1.0,
12291229
0.0, 4.0, 3.0, 3.0), 4, 4, byrow = TRUE))
12301230
1231-
f <- glm(V1 ~ -1 + V3 + V4, data = df, weights = V2,
1231+
model <- glm(V1 ~ -1 + V3 + V4, data = df, weights = V2,
12321232
family = tweedie(var.power = 1.6, link.power = 0))
1233+
summary(model)
12331234
12341235
Deviance Residuals:
12351236
1 2 3 4
@@ -1249,14 +1250,14 @@ class GeneralizedLinearRegressionSuite
12491250
Number of Fisher Scoring iterations: 11
12501251
12511252
residuals(model, type="pearson")
1252-
1 2 3 4
1253-
0.01873881 -0.01312994 0.04190280 -0.10332690
1253+
1 2 3 4
1254+
0.7383616 -0.0509458 2.2348337 -1.4552090
12541255
residuals(model, type="working")
1255-
1 2 3 4
1256-
0.018067789 -0.003326304 0.038720616 -0.824070943
1256+
1 2 3 4
1257+
0.83354150 -0.04103552 1.55676369 -1.00000000
12571258
residuals(model, type="response")
1258-
1 2 3 4
1259-
0.018067789 -0.003326304 0.038720616 -0.824070943
1259+
1 2 3 4
1260+
0.45460738 -0.02139574 0.60888055 -0.20392801
12601261
*/
12611262
val datasetWithWeight = Seq(
12621263
Instance(1.0, 1.0, Vectors.dense(0.0, 5.0)),

0 commit comments

Comments
 (0)