Skip to content

Commit 21c0ac5

Browse files
author
Nick Pentreath
committed
Consolidate test cases and add k > num recommendables case
1 parent 6b9e49d commit 21c0ac5

File tree

1 file changed

+25
-38
lines changed

1 file changed

+25
-38
lines changed

mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala

Lines changed: 25 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -671,58 +671,45 @@ class ALSSuite
671671
.setItemCol("item")
672672
}
673673

674-
test("recommendForAllUsers with k < num_items") {
675-
val topItems = getALSModel.recommendForAllUsers(2)
676-
assert(topItems.count() == 3)
677-
assert(topItems.columns.contains("user"))
678-
679-
val expected = Map(
680-
0 -> Array((3, 54f), (4, 44f)),
681-
1 -> Array((3, 39f), (5, 33f)),
682-
2 -> Array((3, 51f), (5, 45f))
683-
)
684-
checkRecommendations(topItems, expected, "item")
685-
}
686-
687-
test("recommendForAllUsers with k = num_items") {
688-
val topItems = getALSModel.recommendForAllUsers(4)
689-
assert(topItems.count() == 3)
690-
assert(topItems.columns.contains("user"))
691-
674+
test("recommendForAllUsers with k <, = and > num_items") {
675+
val model = getALSModel
676+
val numUsers = model.userFactors.count
677+
val numItems = model.itemFactors.count
692678
val expected = Map(
693679
0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)),
694680
1 -> Array((3, 39f), (5, 33f), (4, 26f), (6, 16f)),
695681
2 -> Array((3, 51f), (5, 45f), (4, 30f), (6, 18f))
696682
)
697-
checkRecommendations(topItems, expected, "item")
698-
}
699683

700-
test("recommendForAllItems with k < num_users") {
701-
val topUsers = getALSModel.recommendForAllItems(2)
702-
assert(topUsers.count() == 4)
703-
assert(topUsers.columns.contains("item"))
704-
705-
val expected = Map(
706-
3 -> Array((0, 54f), (2, 51f)),
707-
4 -> Array((0, 44f), (2, 30f)),
708-
5 -> Array((2, 45f), (0, 42f)),
709-
6 -> Array((0, 28f), (2, 18f))
710-
)
711-
checkRecommendations(topUsers, expected, "user")
684+
Seq(2, 4, 6).foreach { k =>
685+
val n = math.min(k, numItems).toInt
686+
val expectedUpToN = expected.mapValues(_.slice(0, n))
687+
val topItems = model.recommendForAllUsers(k)
688+
assert(topItems.count() == numUsers)
689+
assert(topItems.columns.contains("user"))
690+
checkRecommendations(topItems, expectedUpToN, "item")
691+
}
712692
}
713693

714-
test("recommendForAllItems with k = num_users") {
715-
val topUsers = getALSModel.recommendForAllItems(3)
716-
assert(topUsers.count() == 4)
717-
assert(topUsers.columns.contains("item"))
718-
694+
test("recommendForAllItems with k <, = and > num_users") {
695+
val model = getALSModel
696+
val numUsers = model.userFactors.count
697+
val numItems = model.itemFactors.count
719698
val expected = Map(
720699
3 -> Array((0, 54f), (2, 51f), (1, 39f)),
721700
4 -> Array((0, 44f), (2, 30f), (1, 26f)),
722701
5 -> Array((2, 45f), (0, 42f), (1, 33f)),
723702
6 -> Array((0, 28f), (2, 18f), (1, 16f))
724703
)
725-
checkRecommendations(topUsers, expected, "user")
704+
705+
Seq(2, 3, 4).foreach { k =>
706+
val n = math.min(k, numUsers).toInt
707+
val expectedUpToN = expected.mapValues(_.slice(0, n))
708+
val topUsers = getALSModel.recommendForAllItems(k)
709+
assert(topUsers.count() == numItems)
710+
assert(topUsers.columns.contains("item"))
711+
checkRecommendations(topUsers, expectedUpToN, "user")
712+
}
726713
}
727714

728715
private def checkRecommendations(

0 commit comments

Comments
 (0)