@@ -1019,6 +1019,7 @@ def sample(
10191019 grammar : Optional [LlamaGrammar ] = None , # optional BNF-like grammar to constrain sampling
10201020 grammar_lazy : bool = False ,
10211021 idx : Optional [int ] = None ,
1022+ seed : Optional [int ] = None ,
10221023 ):
10231024 """Sample a token from the model.
10241025 Returns:
@@ -1040,6 +1041,7 @@ def sample(
10401041 temp = temp ,
10411042 top_n_sigma = top_n_sigma ,
10421043 min_keep = min_keep ,
1044+ seed = seed if seed is not None else self ._seed ,
10431045
10441046 # Dynamic Temp
10451047 dynatemp_range = dynatemp_range ,
@@ -1146,7 +1148,8 @@ def generate(
11461148 logits_processor : Optional [LogitsProcessorList ] = None ,
11471149 stopping_criteria : Optional [StoppingCriteriaList ] = None ,
11481150 grammar : Optional [LlamaGrammar ] = None ,
1149- grammar_lazy :bool = False ,
1151+ grammar_lazy : bool = False ,
1152+ seed : Optional [int ] = None ,
11501153 ) -> Generator [int , Optional [Sequence [int ]], None ]:
11511154 """Create a generator of tokens from a prompt.
11521155
@@ -1302,6 +1305,7 @@ def generate(
13021305 logit_bias = self ._convert_logit_bias (logit_bias ),
13031306 grammar = grammar ._grammar if grammar else "" ,
13041307 grammar_lazy = grammar_lazy ,
1308+ seed = seed if seed is not None else self ._seed ,
13051309 )
13061310
13071311 if logits_processor :
@@ -1635,7 +1639,6 @@ def _create_completion(
16351639 dynatemp_exponent : float = 1.0 ,
16361640 min_keep : int = 0 ,
16371641 stream : bool = False ,
1638- seed : Optional [int ] = None ,
16391642 mirostat_mode : int = 0 ,
16401643 mirostat_tau : float = 5.0 ,
16411644 mirostat_eta : float = 0.1 ,
@@ -1654,7 +1657,8 @@ def _create_completion(
16541657 logit_bias : Optional [Dict [int , float ]] = None ,
16551658 logits_processor : Optional [LogitsProcessorList ] = None ,
16561659 grammar : Optional [LlamaGrammar ] = None ,
1657- grammar_lazy : bool = False
1660+ grammar_lazy : bool = False ,
1661+ seed : Optional [int ] = None ,
16581662 ) -> Union [
16591663 Iterator [CreateCompletionResponse ], Iterator [CreateCompletionStreamResponse ]
16601664 ]:
@@ -1798,11 +1802,6 @@ def _create_completion(
17981802 if self .verbose :
17991803 print ("Llama._create_completion: cache miss" , file = sys .stderr )
18001804
1801- if seed is not None :
1802- self .set_seed (seed )
1803- else :
1804- self .set_seed (random .Random (self ._seed ).randint (0 , 2 ** 32 ))
1805-
18061805 finish_reason = "length"
18071806 multibyte_fix = 0
18081807 for token in self .generate (
@@ -1838,6 +1837,7 @@ def _create_completion(
18381837 logits_processor = logits_processor ,
18391838 grammar = grammar ,
18401839 grammar_lazy = grammar_lazy ,
1840+ seed = seed if seed is not None else self ._seed ,
18411841 ):
18421842 if llama_cpp .llama_token_is_eog (self ._model .vocab , token ):
18431843 text = self .detokenize (completion_tokens , prev_tokens = prompt_tokens )
0 commit comments