@@ -52,16 +52,18 @@ def mnl_interaction_dataset(choosers, alternatives, SAMPLE_SIZE,
5252 # SAMPLE_SIZE >= numalts. That may not happen often in
5353 # practical situations but it should be supported
5454 # because a) why not? and b) testing.
55+ alts_idx = np .arange (len (alternatives ))
5556 if SAMPLE_SIZE < numalts :
56- sample = np .random .choice (
57- alternatives .index .values , SAMPLE_SIZE * numchoosers )
57+ sample = np .random .choice (alts_idx , SAMPLE_SIZE * numchoosers )
5858 if chosenalts is not None :
5959 # replace the first row for each chooser with
6060 # the currently chosen alternative.
61- sample [::SAMPLE_SIZE ] = chosenalts
61+ # chosenalts -> integer position
62+ sample [::SAMPLE_SIZE ] = pd .Series (
63+ alts_idx , index = alternatives .index ).loc [chosenalts ].values
6264 else :
6365 assert chosenalts is None # if not sampling, must be simulating
64- sample = np .tile (alternatives . index . values , numchoosers )
66+ sample = np .tile (alts_idx , numchoosers )
6567
6668 if not choosers .index .is_unique :
6769 raise Exception (
@@ -72,7 +74,7 @@ def mnl_interaction_dataset(choosers, alternatives, SAMPLE_SIZE,
7274 "ERROR: alternatives index is not unique, "
7375 "sample will not work correctly" )
7476
75- alts_sample = alternatives .loc [ sample ]
77+ alts_sample = alternatives .take ( sample )
7678 assert len (alts_sample .index ) == SAMPLE_SIZE * len (choosers .index )
7779 alts_sample ['join_index' ] = np .repeat (choosers .index .values , SAMPLE_SIZE )
7880
@@ -84,4 +86,4 @@ def mnl_interaction_dataset(choosers, alternatives, SAMPLE_SIZE,
8486 chosen [:, 0 ] = 1
8587
8688 logger .debug ('finish: compute MNL interaction dataset' )
87- return sample , alts_sample , chosen
89+ return alternatives . index . values [ sample ] , alts_sample , chosen
0 commit comments