File tree Expand file tree Collapse file tree 1 file changed +18
-5
lines changed
Expand file tree Collapse file tree 1 file changed +18
-5
lines changed Original file line number Diff line number Diff line change @@ -998,6 +998,15 @@ def set_cache(self, cache: Optional[BaseLlamaCache]):
998998 """
999999 self .cache = cache
10001000
1001+ def set_seed (self , seed : int ):
1002+ """Set the random seed.
1003+
1004+ Args:
1005+ seed: The random seed.
1006+ """
1007+ assert self ._ctx .ctx is not None
1008+ llama_cpp .llama_set_rng_seed (self ._ctx .ctx , seed )
1009+
10011010 def reset (self ):
10021011 """Reset the model state."""
10031012 self .n_tokens = 0
@@ -1318,10 +1327,14 @@ def _create_completion(
13181327 completion_tokens : List [int ] = []
13191328 # Add blank space to start of prompt to match OG llama tokenizer
13201329 prompt_tokens : List [int ] = (
1321- self .tokenize (prompt .encode ("utf-8" ), special = True )
1322- if prompt != ""
1323- else [self .token_bos ()]
1324- ) if isinstance (prompt , str ) else prompt
1330+ (
1331+ self .tokenize (prompt .encode ("utf-8" ), special = True )
1332+ if prompt != ""
1333+ else [self .token_bos ()]
1334+ )
1335+ if isinstance (prompt , str )
1336+ else prompt
1337+ )
13251338 text : bytes = b""
13261339 returned_tokens : int = 0
13271340 stop = (
@@ -1374,7 +1387,7 @@ def _create_completion(
13741387 except KeyError :
13751388 if self .verbose :
13761389 print ("Llama._create_completion: cache miss" , file = sys .stderr )
1377-
1390+
13781391 if seed is not None :
13791392 self ._ctx .set_rng_seed (seed )
13801393
You can’t perform that action at this time.
0 commit comments