Skip to content

Commit eca9d37

Browse files
committed
based on mengxr's suggestions
1 parent 937e54c commit eca9d37

File tree

1 file changed

+34
-15
lines changed

1 file changed

+34
-15
lines changed

examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ object MovieLensALS {
6464
.text(s"use Kryo serialization")
6565
.action((_, c) => c.copy(kryo = true))
6666
opt[Unit]("implicitPrefs")
67-
.text(s"use Implicit Preference")
67+
.text(s"use implicit preference")
6868
.action((_, c) => c.copy(implicitPrefs = true))
6969
arg[String]("<input>")
7070
.required()
@@ -93,6 +93,22 @@ object MovieLensALS {
9393
val ratings = sc.textFile(params.input).map { line =>
9494
val fields = line.split("::")
9595
if (params.implicitPrefs) {
96+
/**
97+
* MovieLens ratings are on a scale of 1-5:
98+
* 5: Must see
99+
* 4: Will enjoy
100+
* 3: It's okay
101+
* 2: Fairly bad
102+
* 1: Awful
103+
* So we should not recommend a movie if the predicted rating is less than 3.
104+
* To map ratings to confidence scores, we use
105+
* 5 -> 2.5, 4 -> 1.5, 3 -> 0.5, 2 -> -0.5, 1 -> -1.5. This mapping means unobserved
106+
* entries are generally between It's okay and Fairly bad.
107+
* The semantics of 0 in this expanded world of non-positive weights
108+
* are "the same as never having interacted at all".
109+
* It's possible that 0 values are ignored when constructing the sparse representation,
110+
* because the 0s are implicit. This would be a problem, at least a theoretical one.
111+
*/
96112
Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5)
97113
} else {
98114
Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble)
@@ -108,8 +124,14 @@ object MovieLensALS {
108124
val splits = ratings.randomSplit(Array(0.8, 0.2))
109125
val training = splits(0).cache()
110126
val test = if (params.implicitPrefs) {
111-
splits(1)
112-
.map(x => Rating(x.user, x.product, if(x.rating >= 0) 1.0 else 0.0))
127+
/**
128+
* 0 means "don't know" and positive values mean "confident that the prediction should be 1".
129+
* Negative values mean "confident that the prediction should be 0".
130+
* We have in this case used some kind of weighted RMSE. The weight is the absolute value of
131+
* the confidence. The error is the difference between prediction and either 1 or 0,
132+
* depending on whether r is positive or negative.
133+
*/
134+
splits(1).map(x => Rating(x.user, x.product, if(x.rating > 0) 1.0 else 0.0))
113135
} else {
114136
splits(1)
115137
}.cache()
@@ -127,25 +149,22 @@ object MovieLensALS {
127149
.setImplicitPrefs(params.implicitPrefs)
128150
.run(training)
129151

130-
val rmse = computeRmse(model, test, params)
152+
val rmse = computeRmse(model, test, params.implicitPrefs)
131153

132154
println(s"Test RMSE = $rmse.")
133155

134156
sc.stop()
135157
}
136158

137159
/** Compute RMSE (Root Mean Squared Error). */
138-
def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating], params: Params) = {
160+
def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating], implicitPrefs: Boolean) = {
161+
162+
def evalRating(r: Double) =
163+
if (!implicitPrefs) r else if (r > 1.0) 1.0 else if (r < 0.0) 0.0 else r
164+
139165
val predictions: RDD[Rating] = model.predict(data.map(x => (x.user, x.product)))
140-
val predictionsAndRatings = if (params.implicitPrefs) {
141-
predictions.map(x => (
142-
(x.user, x.product),
143-
if (x.rating > 1.0) 1.0 else if (x.rating < 0.0) 0.0 else x.rating
144-
)).join(data.map(x => ((x.user, x.product), x.rating)))
145-
} else {
146-
predictions.map(x => ((x.user, x.product), x.rating))
147-
.join(data.map(x => ((x.user, x.product), x.rating)))
148-
}
149-
math.sqrt(predictionsAndRatings.values.map(x => (x._1 - x._2) * (x._1 - x._2)).mean())
166+
val predictionsAndRatings = predictions.map(x => ((x.user, x.product), evalRating(x.rating)))
167+
.join(data.map(x => ((x.user, x.product), x.rating))).values
168+
math.sqrt(predictionsAndRatings.map(x => (x._1 - x._2) * (x._1 - x._2)).mean())
150169
}
151170
}

0 commit comments

Comments
 (0)