Skip to content

Commit ad10cfd

Browse files
committed
fix(sampling): pass seed to sampling context and remove global mutation
- Add `seed` parameter to `generate` and `sample` method signatures.
- Pass the resolved seed directly to `LlamaSamplingParams` to ensure the underlying C++ sampler uses it.
- Remove thread-unsafe `self.set_seed()` calls in `_create_completion` to prevent global state pollution during concurrent requests.

Signed-off-by: JamePeng <jame_peng@sina.com>
1 parent fb3072d commit ad10cfd

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

llama_cpp/llama.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1019,6 +1019,7 @@ def sample(
10191019
grammar: Optional[LlamaGrammar] = None, # optional BNF-like grammar to constrain sampling
10201020
grammar_lazy: bool = False,
10211021
idx: Optional[int] = None,
1022+
seed: Optional[int] = None,
10221023
):
10231024
"""Sample a token from the model.
10241025
Returns:
@@ -1040,6 +1041,7 @@ def sample(
10401041
temp=temp,
10411042
top_n_sigma=top_n_sigma,
10421043
min_keep=min_keep,
1044+
seed=seed if seed is not None else self._seed,
10431045

10441046
# Dynamic Temp
10451047
dynatemp_range=dynatemp_range,
@@ -1146,7 +1148,8 @@ def generate(
11461148
logits_processor: Optional[LogitsProcessorList] = None,
11471149
stopping_criteria: Optional[StoppingCriteriaList] = None,
11481150
grammar: Optional[LlamaGrammar] = None,
1149-
grammar_lazy :bool = False,
1151+
grammar_lazy: bool = False,
1152+
seed: Optional[int] = None,
11501153
) -> Generator[int, Optional[Sequence[int]], None]:
11511154
"""Create a generator of tokens from a prompt.
11521155
@@ -1302,6 +1305,7 @@ def generate(
13021305
logit_bias=self._convert_logit_bias(logit_bias),
13031306
grammar=grammar._grammar if grammar else "",
13041307
grammar_lazy=grammar_lazy,
1308+
seed=seed if seed is not None else self._seed,
13051309
)
13061310

13071311
if logits_processor:
@@ -1635,7 +1639,6 @@ def _create_completion(
16351639
dynatemp_exponent: float = 1.0,
16361640
min_keep: int = 0,
16371641
stream: bool = False,
1638-
seed: Optional[int] = None,
16391642
mirostat_mode: int = 0,
16401643
mirostat_tau: float = 5.0,
16411644
mirostat_eta: float = 0.1,
@@ -1654,7 +1657,8 @@ def _create_completion(
16541657
logit_bias: Optional[Dict[int, float]] = None,
16551658
logits_processor: Optional[LogitsProcessorList] = None,
16561659
grammar: Optional[LlamaGrammar] = None,
1657-
grammar_lazy: bool = False
1660+
grammar_lazy: bool = False,
1661+
seed: Optional[int] = None,
16581662
) -> Union[
16591663
Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse]
16601664
]:
@@ -1798,11 +1802,6 @@ def _create_completion(
17981802
if self.verbose:
17991803
print("Llama._create_completion: cache miss", file=sys.stderr)
18001804

1801-
if seed is not None:
1802-
self.set_seed(seed)
1803-
else:
1804-
self.set_seed(random.Random(self._seed).randint(0, 2 ** 32))
1805-
18061805
finish_reason = "length"
18071806
multibyte_fix = 0
18081807
for token in self.generate(
@@ -1838,6 +1837,7 @@ def _create_completion(
18381837
logits_processor=logits_processor,
18391838
grammar=grammar,
18401839
grammar_lazy=grammar_lazy,
1840+
seed=seed if seed is not None else self._seed,
18411841
):
18421842
if llama_cpp.llama_token_is_eog(self._model.vocab, token):
18431843
text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)

0 commit comments

Comments (0)