@@ -671,58 +671,45 @@ class ALSSuite
671671 .setItemCol(" item" )
672672 }
673673
674- test(" recommendForAllUsers with k < num_items" ) {
675- val topItems = getALSModel.recommendForAllUsers(2 )
676- assert(topItems.count() == 3 )
677- assert(topItems.columns.contains(" user" ))
678-
679- val expected = Map (
680- 0 -> Array ((3 , 54f ), (4 , 44f )),
681- 1 -> Array ((3 , 39f ), (5 , 33f )),
682- 2 -> Array ((3 , 51f ), (5 , 45f ))
683- )
684- checkRecommendations(topItems, expected, " item" )
685- }
686-
687- test(" recommendForAllUsers with k = num_items" ) {
688- val topItems = getALSModel.recommendForAllUsers(4 )
689- assert(topItems.count() == 3 )
690- assert(topItems.columns.contains(" user" ))
691-
674+ test(" recommendForAllUsers with k <, = and > num_items" ) {
675+ val model = getALSModel
676+ val numUsers = model.userFactors.count
677+ val numItems = model.itemFactors.count
692678 val expected = Map (
693679 0 -> Array ((3 , 54f ), (4 , 44f ), (5 , 42f ), (6 , 28f )),
694680 1 -> Array ((3 , 39f ), (5 , 33f ), (4 , 26f ), (6 , 16f )),
695681 2 -> Array ((3 , 51f ), (5 , 45f ), (4 , 30f ), (6 , 18f ))
696682 )
697- checkRecommendations(topItems, expected, " item" )
698- }
699683
700- test(" recommendForAllItems with k < num_users" ) {
701- val topUsers = getALSModel.recommendForAllItems(2 )
702- assert(topUsers.count() == 4 )
703- assert(topUsers.columns.contains(" item" ))
704-
705- val expected = Map (
706- 3 -> Array ((0 , 54f ), (2 , 51f )),
707- 4 -> Array ((0 , 44f ), (2 , 30f )),
708- 5 -> Array ((2 , 45f ), (0 , 42f )),
709- 6 -> Array ((0 , 28f ), (2 , 18f ))
710- )
711- checkRecommendations(topUsers, expected, " user" )
684+ Seq (2 , 4 , 6 ).foreach { k =>
685+ val n = math.min(k, numItems).toInt
686+ val expectedUpToN = expected.mapValues(_.slice(0 , n))
687+ val topItems = model.recommendForAllUsers(k)
688+ assert(topItems.count() == numUsers)
689+ assert(topItems.columns.contains(" user" ))
690+ checkRecommendations(topItems, expectedUpToN, " item" )
691+ }
712692 }
713693
714- test(" recommendForAllItems with k = num_users" ) {
715- val topUsers = getALSModel.recommendForAllItems(3 )
716- assert(topUsers.count() == 4 )
717- assert(topUsers.columns.contains(" item" ))
718-
694+ test(" recommendForAllItems with k <, = and > num_users" ) {
695+ val model = getALSModel
696+ val numUsers = model.userFactors.count
697+ val numItems = model.itemFactors.count
719698 val expected = Map (
720699 3 -> Array ((0 , 54f ), (2 , 51f ), (1 , 39f )),
721700 4 -> Array ((0 , 44f ), (2 , 30f ), (1 , 26f )),
722701 5 -> Array ((2 , 45f ), (0 , 42f ), (1 , 33f )),
723702 6 -> Array ((0 , 28f ), (2 , 18f ), (1 , 16f ))
724703 )
725- checkRecommendations(topUsers, expected, " user" )
704+
705+ Seq (2 , 3 , 4 ).foreach { k =>
706+ val n = math.min(k, numUsers).toInt
707+ val expectedUpToN = expected.mapValues(_.slice(0 , n))
708+ val topUsers = getALSModel.recommendForAllItems(k)
709+ assert(topUsers.count() == numItems)
710+ assert(topUsers.columns.contains(" item" ))
711+ checkRecommendations(topUsers, expectedUpToN, " user" )
712+ }
726713 }
727714
728715 private def checkRecommendations (
0 commit comments