@@ -1292,6 +1292,7 @@ def _create_completion(
12921292 repeat_penalty : float = 1.1 ,
12931293 top_k : int = 40 ,
12941294 stream : bool = False ,
1295+ seed : Optional [int ] = None ,
12951296 tfs_z : float = 1.0 ,
12961297 mirostat_mode : int = 0 ,
12971298 mirostat_tau : float = 5.0 ,
@@ -1367,6 +1368,9 @@ def _create_completion(
13671368 except KeyError :
13681369 if self .verbose :
13691370 print ("Llama._create_completion: cache miss" , file = sys .stderr )
1371+
1372+ if seed is not None :
1373+ self ._ctx .set_rng_seed (seed )
13701374
13711375 finish_reason = "length"
13721376 multibyte_fix = 0
@@ -1750,6 +1754,7 @@ def create_completion(
17501754 repeat_penalty : float = 1.1 ,
17511755 top_k : int = 40 ,
17521756 stream : bool = False ,
1757+ seed : Optional [int ] = None ,
17531758 tfs_z : float = 1.0 ,
17541759 mirostat_mode : int = 0 ,
17551760 mirostat_tau : float = 5.0 ,
@@ -1795,6 +1800,7 @@ def create_completion(
17951800 repeat_penalty = repeat_penalty ,
17961801 top_k = top_k ,
17971802 stream = stream ,
1803+ seed = seed ,
17981804 tfs_z = tfs_z ,
17991805 mirostat_mode = mirostat_mode ,
18001806 mirostat_tau = mirostat_tau ,
@@ -1825,6 +1831,7 @@ def __call__(
18251831 repeat_penalty : float = 1.1 ,
18261832 top_k : int = 40 ,
18271833 stream : bool = False ,
1834+ seed : Optional [int ] = None ,
18281835 tfs_z : float = 1.0 ,
18291836 mirostat_mode : int = 0 ,
18301837 mirostat_tau : float = 5.0 ,
@@ -1870,6 +1877,7 @@ def __call__(
18701877 repeat_penalty = repeat_penalty ,
18711878 top_k = top_k ,
18721879 stream = stream ,
1880+ seed = seed ,
18731881 tfs_z = tfs_z ,
18741882 mirostat_mode = mirostat_mode ,
18751883 mirostat_tau = mirostat_tau ,
@@ -1892,6 +1900,7 @@ def create_chat_completion(
18921900 top_k : int = 40 ,
18931901 stream : bool = False ,
18941902 stop : Optional [Union [str , List [str ]]] = [],
1903+ seed : Optional [int ] = None ,
18951904 max_tokens : int = 256 ,
18961905 presence_penalty : float = 0.0 ,
18971906 frequency_penalty : float = 0.0 ,
@@ -1936,6 +1945,7 @@ def create_chat_completion(
19361945 top_k = top_k ,
19371946 stream = stream ,
19381947 stop = stop ,
1948+ seed = seed ,
19391949 max_tokens = max_tokens ,
19401950 presence_penalty = presence_penalty ,
19411951 frequency_penalty = frequency_penalty ,
0 commit comments