
Commit bbd70a5

fix(Llama.generate): add explicit fallback context reset and expand generator docstrings
- Added a fallback `if reset:` block in `Llama.generate` to ensure the KV cache and hybrid cache manager are explicitly cleared when `reset=True` is passed and no prefix match is found. This prevents potential context poisoning from previous runs.
- Added comprehensive docstrings to the `generate` method for all newly integrated sampler parameters (e.g., XTC, Mirostat, DRY penalties, etc.).
- Added explicit verbose logging for cache resets, rollback events, and speculative decoding behavior to improve debuggability.

Signed-off-by: JamePeng <jame_peng@sina.com>
1 parent d422e82 commit bbd70a5
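
In practice, the patched path is exercised by generator-style use of `Llama.generate`. A minimal sketch (the model path and sampling values are illustrative assumptions, not part of this commit):

```python
from llama_cpp import Llama

# Illustrative model path; any local GGUF model works here.
llama = Llama(model_path="./models/model.gguf", verbose=True)

tokens = llama.tokenize(b"Q: Name the planets in the solar system. A:")

# reset=True: reuse any matching KV-cache prefix, otherwise clear the
# cache entirely (the fallback added by this commit).
for token in llama.generate(tokens, top_k=40, top_p=0.95, temp=0.7, reset=True):
    print(llama.detokenize([token]).decode("utf-8", errors="ignore"), end="")
    if token == llama.token_eos():
        break
```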

1 file changed: llama_cpp/llama.py (52 additions, 8 deletions)
@@ -1167,12 +1167,41 @@ def generate(
             ... print(llama.detokenize([token]))

         Args:
-            tokens: The prompt tokens.
-            top_k: The top-k sampling parameter.
-            top_p: The top-p sampling parameter.
-            temp: The temperature parameter.
-            repeat_penalty: The repeat penalty parameter.
-            reset: Whether to reset the model state.
+            tokens: The prompt tokens to evaluate.
+            top_k: Limit the next-token selection to the K most probable tokens (<= 0 to use the full vocab size).
+            top_p: Nucleus sampling. Limits selection to a cumulative probability of P.
+            min_p: Min-P sampling. Drops tokens whose probability is less than min_p relative to the most likely token.
+            typical_p: Locally typical sampling (1.0 = disabled).
+            temp: Temperature. Controls randomness (<= 0.0 = greedy sampling).
+            dynatemp_range: Range of dynamic temperature.
+            dynatemp_exponent: Exponent of dynamic temperature.
+            top_n_sigma: Limit selection to tokens within n * sigma of the max logit (-1.0 = disabled).
+            min_keep: Minimum number of tokens to keep for sampling.
+            penalty_last_n: Number of recent tokens to penalize (0 = disabled, -1 = context size).
+            repeat_penalty: General penalty for repeated tokens (1.0 = disabled).
+            frequency_penalty: Penalty based on the absolute frequency of a token in the context so far.
+            present_penalty: Flat penalty applied if a token is present anywhere in the context.
+            reset: If True, attempts to automatically match the KV cache prefix to avoid re-evaluation, clearing the cache when no prefix matches. If False, blindly appends tokens to the existing context.
+            mirostat_mode: Mirostat sampling mode (0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0).
+            mirostat_tau: Target cross-entropy (surprise) for Mirostat.
+            mirostat_eta: Learning rate for Mirostat.
+            xtc_threshold: Minimum probability threshold for XTC token removal.
+            xtc_probability: Chance of token removal in XTC sampling.
+            dry_multiplier: DRY (Don't Repeat Yourself) repetition penalty multiplier (0.0 = disabled).
+            dry_base: DRY repetition penalty base value.
+            dry_allowed_length: Maximum sequence length DRY allows without penalty.
+            dry_penalty_last_n: Number of tokens DRY scans for repetitions (0 = disabled, -1 = context size).
+            dry_seq_breakers: Array of sequence breakers for DRY sampling.
+            adaptive_target: Adaptive-p target probability (0.0 to 1.0, negative = disabled).
+            adaptive_decay: Adaptive-p decay rate (0.0 to 0.99).
+            use_infill: Activate the specialized fill-in-the-middle (infill) sampler.
+            ignore_eos: If True, ignore the end-of-sequence token.
+            logit_bias: Dictionary mapping token IDs to bias values.
+            logits_processor: List of custom Python callbacks that modify the logits in place.
+            stopping_criteria: List of custom callbacks that can halt generation dynamically.
+            grammar: Optional GBNF grammar to constrain sampling.
+            grammar_lazy: If True, activates grammar constraints only on specific trigger tokens.
+            seed: RNG seed for sampling. Overrides the instance seed.

         Yields:
             The generated tokens.
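
As a usage illustration of the expanded docstring, a sketch combining a few of the newly documented samplers (Min-P, XTC, and DRY); the model path and all values are illustrative assumptions, not part of this commit:

```python
from llama_cpp import Llama

llama = Llama(model_path="./models/model.gguf")  # illustrative path
tokens = llama.tokenize(b"Write a short poem about the sea.")

for token in llama.generate(
    tokens,
    temp=0.8,
    min_p=0.05,            # drop tokens far less likely than the top token
    mirostat_mode=0,       # 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0
    xtc_threshold=0.1,     # XTC: consider removing tokens above 10% probability...
    xtc_probability=0.5,   # ...on 50% of sampling steps
    dry_multiplier=0.8,    # enable the DRY repetition penalty
    dry_base=1.75,
    dry_allowed_length=2,
    reset=True,
):
    if token == llama.token_eos():
        break
```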
@@ -1259,12 +1288,14 @@ def generate(
                     f"remaining {len(tokens)} prompt tokens to eval",
                     file=sys.stderr,
                 )
-        else:
+        if reset:
             # No prefix matched at all. Completely clear the KV cache to prevent context poisoning.
             self.n_tokens = 0
             self._ctx.memory_clear(True)
             if self.is_hybrid and self._hybrid_cache_mgr is not None:
                 self._hybrid_cache_mgr.clear()
+            if self.verbose:
+                print("Llama.generate: Context reset requested or no prefix match. Cleared KV cache.", file=sys.stderr)

         # Reset mirostat sampling
         params = LlamaSamplingParams(
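
The policy change above is small but easy to misread in context. A pure-Python sketch of the reuse-or-clear decision (hypothetical helper, not code from this commit):

```python
def plan_cache_action(cached: list[int], prompt: list[int], reset: bool) -> tuple[int, bool]:
    """Return (n_reusable_tokens, must_clear) under the patched policy."""
    n_match = 0
    for a, b in zip(cached, prompt):
        if a != b:
            break
        n_match += 1
    if n_match > 0:
        return n_match, False  # reuse the matching prefix
    # Fallback added by this commit: with reset=True and no prefix match,
    # the KV cache is cleared outright to avoid context poisoning.
    return 0, reset

assert plan_cache_action([1, 2, 3], [1, 2, 9], reset=True) == (2, False)
assert plan_cache_action([7, 8], [1, 2, 9], reset=True) == (0, True)
```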
@@ -1315,6 +1346,7 @@ def generate(
             seed=seed if seed is not None else self._seed,
         )

+        # Register custom python-level logits processors if provided
         if logits_processor:
             def adapter(token_data_array: llama_cpp.llama_token_data_array):
                 if self._logits_all:
@@ -1336,6 +1368,7 @@ def adapter(token_data_array: llama_cpp.llama_token_data_array):
             if CommonSamplerType.CUSTOM not in params.samplers:
                 params.samplers.insert(3, CommonSamplerType.CUSTOM)

+        # Free previous sampling context to prevent memory leaks
         if getattr(self, "_sampling_ctx", None) is not None:
             self._sampling_ctx.close()
             self._sampling_ctx = None
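
The adapter registered above bridges C-level token data to Python callbacks. A minimal sketch of such a callback, assuming the upstream llama-cpp-python shape `(input_ids, scores) -> scores`; the banned token id is an arbitrary example:

```python
import numpy as np
from llama_cpp import LogitsProcessorList

def ban_token_1234(input_ids: np.ndarray, scores: np.ndarray) -> np.ndarray:
    # Illustrative: make token id 1234 unsampleable by forcing its logit to -inf.
    scores[1234] = -np.inf
    return scores

# Passed as: llama.generate(tokens, logits_processor=LogitsProcessorList([ban_token_1234]))
```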
@@ -1345,7 +1378,7 @@ def adapter(token_data_array: llama_cpp.llama_token_data_array):
         sample_idx = self.n_tokens + len(tokens) - 1
         tokens = list(tokens)

-        # Eval and sample
+        # Main evaluation and generation loop
         try:
             while True:
                 if len(tokens) > 0:
@@ -1376,12 +1409,15 @@ def adapter(token_data_array: llama_cpp.llama_token_data_array):
                 else:
                     # Standard evaluation or single-token generation step
                     self.eval(tokens)
+
+                # Sample loop
                 while sample_idx < self.n_tokens:
                     token = self._sampling_ctx.sample(self._ctx, idx=-1)
                     self._sampling_ctx.accept(token, False if grammar is None else True)

                     sample_idx += 1

+                    # Halt generation if custom stopping criteria are met
                     if stopping_criteria is not None:
                         if self._logits_all:
                             logits_idx = sample_idx - self.n_tokens
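
A matching sketch of a custom stopping criterion, assuming the upstream `(input_ids, logits) -> bool` callback shape; the length cap is an arbitrary example:

```python
import numpy as np
from llama_cpp import StoppingCriteriaList

def stop_when_long(input_ids: np.ndarray, logits: np.ndarray) -> bool:
    # Illustrative: halt generation once the context exceeds 256 tokens.
    return len(input_ids) > 256

# Passed as: llama.generate(tokens, stopping_criteria=StoppingCriteriaList([stop_when_long]))
```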
@@ -1399,13 +1435,17 @@ def adapter(token_data_array: llama_cpp.llama_token_data_array):
                         ):
                             return

+                    # Yield the generated token to the caller
                     tokens_or_none = yield token
+
                     tokens.clear()
                     tokens.append(token)

                     if tokens_or_none is not None:
                         tokens.extend(tokens_or_none)

+                    # Rollback check: a previously evaluated token (e.g. from speculative
+                    # decoding) mismatched the newly sampled token. We must roll back the KV cache.
                     if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]:
                         self.n_tokens = sample_idx
                         if self.is_hybrid:
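
The rollback rule can be isolated as a pure function. A sketch mirroring the check above (hypothetical helper, not code from this file):

```python
def rollback_point(input_ids: list[int], n_tokens: int, sample_idx: int, token: int) -> int:
    """Return the usable context length after sampling `token` at position
    `sample_idx`: a draft token that disagrees with the sampled token
    invalidates everything evaluated beyond that position."""
    if sample_idx < n_tokens and token != input_ids[sample_idx]:
        return sample_idx  # truncate: the KV cache past here is stale
    return n_tokens        # draft confirmed; keep the evaluated context

assert rollback_point([5, 6, 7, 8], n_tokens=4, sample_idx=2, token=9) == 2
assert rollback_point([5, 6, 7, 8], n_tokens=4, sample_idx=2, token=7) == 4
```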
@@ -1420,10 +1460,13 @@ def adapter(token_data_array: llama_cpp.llama_token_data_array):
                             self._ctx.memory_clear(True)
                             self.n_tokens = 0
                         else:
+                            if self.verbose:
+                                print(f"Llama.generate: Draft token rejected. Truncating context to {self.n_tokens}.", file=sys.stderr)
                             self._ctx.memory_seq_rm(0, self.n_tokens, -1)

                         break

+                # Speculative Decoding (Draft Model) logic
                 if self.draft_model is not None:
                     if self.is_hybrid:
                         if self.verbose:
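
For reference, the draft-model path instrumented above is reached by constructing the model with a draft model. A sketch assuming the upstream llama-cpp-python prompt-lookup API; the path and token count are illustrative:

```python
from llama_cpp import Llama
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding

llama = Llama(
    model_path="./models/model.gguf",  # illustrative path
    draft_model=LlamaPromptLookupDecoding(num_pred_tokens=10),  # draft 10 tokens per step
)
```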
@@ -1439,6 +1482,7 @@ def adapter(token_data_array: llama_cpp.llama_token_data_array):
                         ]
                     )
         finally:
+            # Ensure the final state is checkpointed for hybrid models when generation finishes or is interrupted
             if (
                 self.is_hybrid
                 and self._hybrid_cache_mgr is not None
