File tree Expand file tree Collapse file tree 1 file changed +9
-7
lines changed
Expand file tree Collapse file tree 1 file changed +9
-7
lines changed Original file line number Diff line number Diff line change @@ -451,7 +451,7 @@ def free_lora_adapter():
451451 self .n_tokens = 0
452452 self .input_ids : npt .NDArray [np .intc ] = np .ndarray ((n_ctx ,), dtype = np .intc )
453453 self .scores : npt .NDArray [np .single ] = np .ndarray (
454- (n_ctx , self ._n_vocab ), dtype = np .single
454+ (n_ctx if logits_all == True else n_batch , self ._n_vocab ), dtype = np .single
455455 )
456456
457457 self ._mirostat_mu = ctypes .c_float (
@@ -648,12 +648,14 @@ def eval(self, tokens: Sequence[int]):
648648 )
649649 self .scores [n_past : n_past + n_tokens , :].reshape (- 1 )[::] = logits
650650 else :
651- rows = 1
652- cols = self ._n_vocab
653- logits = np .ctypeslib .as_array (
654- self ._ctx .get_logits (), shape = (rows * cols ,)
655- )
656- self .scores [n_past + n_tokens - 1 , :].reshape (- 1 )[::] = logits
651+ # rows = 1
652+ # cols = self._n_vocab
653+ # logits = np.ctypeslib.as_array(
654+ # self._ctx.get_logits(), shape=(rows * cols,)
655+ # )
656+ # self.scores[n_past + n_tokens - 1, :].reshape(-1)[::] = logits
657+ # NOTE: Now that sampling is done inside the sampler, logits are only needed for logprobs which requires logits_all
658+ pass
657659 # Update n_tokens
658660 self .n_tokens += n_tokens
659661
You can’t perform that action at this time.
0 commit comments