@@ -1167,12 +1167,41 @@ def generate(
             ...     print(llama.detokenize([token]))

         Args:
-            tokens: The prompt tokens.
-            top_k: The top-k sampling parameter.
-            top_p: The top-p sampling parameter.
-            temp: The temperature parameter.
-            repeat_penalty: The repeat penalty parameter.
-            reset: Whether to reset the model state.
+            tokens: The prompt tokens to evaluate.
+            top_k: Limit the next-token selection to the K most probable tokens (<= 0 to use the full vocabulary).
+            top_p: Nucleus sampling. Limits selection to the tokens within a cumulative probability of P.
+            min_p: Minimum-P sampling. Drops tokens whose probability is less than min_p relative to the most likely token.
+            typical_p: Locally typical sampling. (1.0 = disabled)
+            temp: Temperature. Controls randomness (<= 0.0 to sample greedily, 0.0 to not output probabilities).
+            dynatemp_range: Range of dynamic temperature.
+            dynatemp_exponent: Exponent of dynamic temperature.
+            top_n_sigma: Limit selection to tokens within n * sigma of the max logit. (-1.0 = disabled)
+            min_keep: Minimum number of tokens to keep when sampling.
+            penalty_last_n: Number of recent tokens to penalize (0 = disable penalty, -1 = context size).
+            repeat_penalty: General penalty for repeated tokens. (1.0 = disabled)
+            frequency_penalty: Penalty proportional to how often a token already appears in the penalized context window.
+            present_penalty: Flat penalty applied if a token is already present anywhere in the penalized context window.
+            reset: If True, reset the model state before evaluation, reusing any matching KV cache prefix to avoid re-evaluating it. If False, append the tokens to the existing context without clearing it.
+            mirostat_mode: Mirostat sampling mode (0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0).
+            mirostat_tau: Target cross-entropy (surprisal) for Mirostat.
+            mirostat_eta: Learning rate for Mirostat.
+            xtc_threshold: Minimum probability threshold for XTC token removal.
+            xtc_probability: Probability of applying XTC token removal at each sampling step.
+            dry_multiplier: DRY (Don't Repeat Yourself) repetition penalty multiplier (0.0 = disabled).
+            dry_base: DRY repetition penalty base value.
+            dry_allowed_length: Maximum repeated sequence length that DRY leaves unpenalized.
+            dry_penalty_last_n: Number of tokens DRY scans for repetitions (0 = disabled, -1 = context size).
+            dry_seq_breakers: Array of sequence breakers for DRY sampling.
+            adaptive_target: Adaptive-p target probability (0.0 to 1.0, negative = disabled).
+            adaptive_decay: Adaptive-p decay rate (0.0 to 0.99).
+            use_infill: If True, use the specialized fill-in-the-middle (infill) sampler.
+            ignore_eos: If True, ignore the end-of-sequence token and keep generating.
+            logit_bias: Dictionary mapping token IDs to their bias values.
+            logits_processor: List of custom Python callbacks that modify the logits in place.
+            stopping_criteria: List of custom callbacks that can halt generation dynamically.
+            grammar: Optional GBNF grammar used to constrain the sampled output.
+            grammar_lazy: If True, activate grammar constraints only on specific trigger tokens.
+            seed: RNG seed for sampling. Overrides the instance seed if provided.

         Yields:
             The generated tokens.
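
To make the expanded signature concrete, here is a minimal usage sketch. The model path, prompt, and sampling values are placeholders, and the keyword arguments are taken from the docstring above rather than from any particular release:

# Hypothetical usage sketch of the extended generate() signature.
# Model path, prompt, and sampling values are illustrative placeholders.
from llama_cpp import Llama

llama = Llama(model_path="models/model.gguf", verbose=False)
prompt_tokens = llama.tokenize(b"The quick brown fox")

generated = []
for token in llama.generate(
    prompt_tokens,
    top_k=40,
    top_p=0.95,
    min_p=0.05,
    temp=0.8,
    repeat_penalty=1.1,
    logit_bias={llama.token_eos(): -10.0},  # discourage an early end-of-sequence
    seed=42,
):
    if token == llama.token_eos():
        break
    generated.append(token)
    if len(generated) >= 32:
        break

print(llama.detokenize(generated).decode("utf-8", errors="ignore"))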
@@ -1259,12 +1288,14 @@ def generate(
12591288 f"remaining { len (tokens )} prompt tokens to eval" ,
12601289 file = sys .stderr ,
12611290 )
1262- else :
1291+ if reset :
12631292 # No prefix matched at all. Completely clear the KV cache to prevent context poisoning.
12641293 self .n_tokens = 0
12651294 self ._ctx .memory_clear (True )
12661295 if self .is_hybrid and self ._hybrid_cache_mgr is not None :
12671296 self ._hybrid_cache_mgr .clear ()
1297+ if self .verbose :
1298+ print ("Llama.generate: Context reset requested or no prefix match. Cleared KV cache." , file = sys .stderr )
12681299
12691300 # Reset mirostat sampling
12701301 params = LlamaSamplingParams (
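
The "if reset:" branch above only runs when no usable prefix was found in the existing KV cache; otherwise the matched prefix is kept and only the remaining prompt tokens are re-evaluated. The following standalone sketch illustrates the longest-prefix idea; it is not this repository's exact implementation:

# Standalone sketch of KV-cache prefix matching; illustrates the idea only,
# not this repository's exact implementation.
from typing import List, Sequence

def longest_prefix_len(cached: Sequence[int], new_tokens: Sequence[int]) -> int:
    """Number of leading tokens shared by the cached context and the new prompt."""
    n = 0
    for a, b in zip(cached, new_tokens):
        if a != b:
            break
        n += 1
    return n

cached_ids: List[int] = [1, 15043, 29892, 3186]          # tokens already in the KV cache
new_prompt: List[int] = [1, 15043, 29892, 3186, 29991]   # new request sharing a prefix

matched = longest_prefix_len(cached_ids, new_prompt)
if matched > 0:
    to_eval = new_prompt[matched:]   # re-evaluate only the unmatched suffix
else:
    to_eval = list(new_prompt)       # no shared prefix: the cache would be cleared, as above

print(matched, to_eval)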
@@ -1315,6 +1346,7 @@ def generate(
             seed=seed if seed is not None else self._seed,
         )

+        # Register custom Python-level logits processors if provided
         if logits_processor:
             def adapter(token_data_array: llama_cpp.llama_token_data_array):
                 if self._logits_all:
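
As an illustration of the logits_processor hook registered above, a callback might look like the sketch below. It assumes the upstream (input_ids, scores) -> scores convention and a placeholder token id; the adapter in the diff appears to be what bridges such callbacks into the C-level sampler.

# Sketch of a custom logits processor; assumes the upstream
# (input_ids, scores) -> scores calling convention. The banned id is a placeholder.
import numpy as np

BANNED_TOKEN_ID = 12345  # hypothetical token to suppress

def ban_token(input_ids: np.ndarray, scores: np.ndarray) -> np.ndarray:
    # Setting a logit to -inf removes the token from consideration.
    scores[BANNED_TOKEN_ID] = -np.inf
    return scores

# Quick self-check with dummy data:
dummy_scores = np.zeros(32000, dtype=np.float32)
dummy_scores = ban_token(np.array([], dtype=np.intc), dummy_scores)
assert dummy_scores[BANNED_TOKEN_ID] == -np.inf

# e.g. generate(..., logits_processor=[ban_token]) per the docstring above,
# or wrapped in llama_cpp.LogitsProcessorList when using the upstream helper.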
@@ -1336,6 +1368,7 @@ def adapter(token_data_array: llama_cpp.llama_token_data_array):
             if CommonSamplerType.CUSTOM not in params.samplers:
                 params.samplers.insert(3, CommonSamplerType.CUSTOM)

+        # Free previous sampling context to prevent memory leaks
         if getattr(self, "_sampling_ctx", None) is not None:
             self._sampling_ctx.close()
             self._sampling_ctx = None
@@ -1345,7 +1378,7 @@ def adapter(token_data_array: llama_cpp.llama_token_data_array):
         sample_idx = self.n_tokens + len(tokens) - 1
         tokens = list(tokens)

-        # Eval and sample
+        # Main evaluation and generation loop
         try:
             while True:
                 if len(tokens) > 0:
@@ -1376,12 +1409,15 @@ def adapter(token_data_array: llama_cpp.llama_token_data_array):
                 else:
                     # Standard evaluation or single-token generation step
                     self.eval(tokens)
+
+                # Sample loop
                 while sample_idx < self.n_tokens:
                     token = self._sampling_ctx.sample(self._ctx, idx=-1)
                     self._sampling_ctx.accept(token, False if grammar is None else True)

                     sample_idx += 1

+                    # Halt generation if custom stopping criteria are met
                     if stopping_criteria is not None:
                         if self._logits_all:
                             logits_idx = sample_idx - self.n_tokens
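
The stopping-criteria check above gives callers a per-token veto over generation. A sketch of such a callback, assuming the upstream (input_ids, logits) -> bool convention (the token budget is arbitrary):

# Sketch of a stopping criterion: halt once the context exceeds a token budget.
# Assumes the upstream (input_ids, logits) -> bool calling convention.
import numpy as np

MAX_TOTAL_TOKENS = 256  # arbitrary illustration budget

def stop_after_budget(input_ids: np.ndarray, logits: np.ndarray) -> bool:
    return len(input_ids) >= MAX_TOTAL_TOKENS

# Quick self-check with dummy data:
assert stop_after_budget(np.arange(10), np.zeros(5)) is False
# e.g. generate(..., stopping_criteria=[stop_after_budget]) per the docstring above,
# or wrapped in llama_cpp.StoppingCriteriaList when using the upstream helper.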
@@ -1399,13 +1435,17 @@ def adapter(token_data_array: llama_cpp.llama_token_data_array):
                     ):
                         return

+                    # Yield the generated token to the caller
                     tokens_or_none = yield token
+
                     tokens.clear()
                     tokens.append(token)

                     if tokens_or_none is not None:
                         tokens.extend(tokens_or_none)

+                    # Rollback check: a previously evaluated token (e.g. from speculative decoding)
+                    # no longer matches the newly sampled token, so roll back the KV cache.
                     if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]:
                         self.n_tokens = sample_idx
                         if self.is_hybrid:
@@ -1420,10 +1460,13 @@ def adapter(token_data_array: llama_cpp.llama_token_data_array):
                             self._ctx.memory_clear(True)
                             self.n_tokens = 0
                         else:
+                            if self.verbose:
+                                print(f"Llama.generate: Draft token rejected. Truncating context to {self.n_tokens}.", file=sys.stderr)
                             self._ctx.memory_seq_rm(0, self.n_tokens, -1)

                         break

+                # Speculative Decoding (Draft Model) logic
                 if self.draft_model is not None:
                     if self.is_hybrid:
                         if self.verbose:
@@ -1439,6 +1482,7 @@ def adapter(token_data_array: llama_cpp.llama_token_data_array):
                     ]
                 )
         finally:
+            # Ensure the final state is checkpointed for hybrid models when generation finishes or is interrupted
             if (
                 self.is_hybrid
                 and self._hybrid_cache_mgr is not None