Skip to content

Commit db6ee16

Browse files
authored
Merge branch 'JamePeng:main' into main
2 parents 3698316 + 5bf6b6a commit db6ee16

File tree

3 files changed: +505 -255 lines changed

llama_cpp/_internals.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,9 @@ def close(self):
8585
self.model = None
8686
self.vocab = None
8787

88-
self._exit_stack.close()
88+
if getattr(self, "_exit_stack", None) is not None and hasattr(self._exit_stack, "close"):
89+
self._exit_stack.close()
90+
self._exit_stack = None
8991

9092
def __del__(self):
9193
self.close()
@@ -386,8 +388,11 @@ def close(self):
386388
except Exception:
387389
pass
388390
self.ctx = None
391+
self.params = None
389392

390-
self._exit_stack.close()
393+
if getattr(self, "_exit_stack", None) is not None and hasattr(self._exit_stack, "close"):
394+
self._exit_stack.close()
395+
self._exit_stack = None
391396

392397
def __del__(self):
393398
self.close()
@@ -442,6 +447,9 @@ def memory_seq_pos_max(self, seq_id: int) -> int:
442447
def memory_seq_pos_min(self, seq_id: int) -> int:
443448
return llama_cpp.llama_memory_seq_pos_min(self.get_memory(), seq_id)
444449

450+
def memory_can_shift(self) -> bool:
451+
return llama_cpp.llama_memory_can_shift(self.get_memory())
452+
445453
# // State / sessions API
446454

447455
def get_state_size(self) -> int:
@@ -659,7 +667,9 @@ def close(self):
659667
pass
660668
self.batch = None
661669

662-
self._exit_stack.close()
670+
if getattr(self, "_exit_stack", None) is not None and hasattr(self._exit_stack, "close"):
671+
self._exit_stack.close()
672+
self._exit_stack = None
663673

664674
def __del__(self):
665675
self.close()

llama_cpp/llama.py

Lines changed: 70 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,9 @@ def close(self) -> None:
682682
self._c_tensor_split = None
683683
self._kv_overrides_array = None
684684

685-
self._stack.close()
685+
if getattr(self, "_stack", None) is not None and hasattr(self._stack, "close"):
686+
self._stack.close()
687+
self._stack = None
686688

687689
def __del__(self) -> None:
688690
self.close()
@@ -789,30 +791,68 @@ def eval(self, tokens: Sequence[int]):
789791

790792
# Context Shift: Prevent OOM by discarding older tokens when context limit is reached.
791793
if self.n_tokens + n_eval > self._n_ctx:
792-
_n_keep = min(self.n_keep, self.n_tokens)
793-
# Number of tokens after n_keep that may be discarded when shifting context
794-
# defaults to half
795-
_n_discard = (self.n_tokens - _n_keep) // 2
794+
# 0. Check if the memory supports shifting
795+
if not self._ctx.memory_can_shift():
796+
raise RuntimeError(
797+
f"Llama.eval: Context Shift is explicitly disabled by the C++ backend "
798+
f"(n_pos_per_embd > 1 or incompatible M-RoPE). "
799+
f"You MUST increase n_ctx (currently {self._n_ctx}) to fit the dialogue."
800+
)
801+
# 1. Calculate the absolute minimum number of tokens we must discard to fit the new chunk.
802+
required_discard = (self.n_tokens + n_eval) - self._n_ctx
796803

797-
if self.verbose:
798-
model_type = "Hybrid/Recurrent/SWA" if self.is_hybrid else "Transformer"
799-
print(f"Llama.eval: {model_type} context limit reached. Shifting context: "
800-
f"discarding {_n_discard} tokens...", file=sys.stderr)
804+
# 2. Sanity check: If the incoming chunk itself is larger than the entire context window,
805+
# shifting is physically impossible.
806+
if required_discard > self.n_tokens:
807+
raise RuntimeError(f"Llama.eval: Context shift failed. The incoming chunk ({n_eval} tokens) "
808+
f"is larger than the entire context window ({self._n_ctx}).")
809+
810+
# 3. Determine how many tokens to keep at the beginning (usually the System Prompt).
811+
_n_keep_desired = min(self.n_keep, self.n_tokens)
812+
813+
# Ensure that keeping these tokens doesn't prevent us from discarding the required amount.
814+
max_keep_allowed = max(0, self.n_tokens - required_discard)
815+
_n_keep = min(_n_keep_desired, max_keep_allowed)
816+
817+
# 4. Calculate the final discard count. Default strategy is to discard half of the available
818+
# past tokens to minimize frequent shifting, but it must be at least `required_discard`.
819+
_n_discard = max(required_discard, (self.n_tokens - _n_keep) // 2)
820+
821+
# 5. Execute the shift only if there are tokens to discard.
822+
if _n_discard > 0:
823+
if self.verbose:
824+
model_type = "Hybrid/Recurrent/SWA" if getattr(self, 'is_hybrid', False) else "Transformer"
825+
print(f"Llama.eval: {model_type} context limit reached. Shifting context: "
826+
f"keeping {_n_keep}, discarding {_n_discard} tokens...", file=sys.stderr)
801827

802-
# Use the context memory methods, which handle both attention KV removal and RNN position shifting automatically
803-
self._ctx.memory_seq_rm(0, _n_keep, _n_keep + _n_discard)
804-
self._ctx.memory_seq_add(0, _n_keep + _n_discard, self.n_tokens, -_n_discard)
828+
try:
829+
# Remove the specified block of tokens from the physical KV cache
830+
self._ctx.memory_seq_rm(0, _n_keep, _n_keep + _n_discard)
831+
832+
# Shift the positional IDs of all subsequent tokens to the left to close the gap
833+
self._ctx.memory_seq_add(0, _n_keep + _n_discard, self.n_tokens, -_n_discard)
834+
except Exception as e:
835+
# Defense-in-depth: Catch any other recoverable backend errors
836+
raise RuntimeError(f"Llama.eval: Context Shift failed at the C++ level. Error: {str(e)}") from e
805837

806-
remaining_len = self.n_tokens - (_n_keep + _n_discard)
807-
if remaining_len > 0:
808-
self.input_ids[_n_keep : _n_keep + remaining_len] = self.input_ids[_n_keep + _n_discard : self.n_tokens]
838+
# 6. Synchronize the Python-side token tracking array (ledger)
839+
remaining_len = self.n_tokens - (_n_keep + _n_discard)
840+
if remaining_len > 0:
841+
self.input_ids[_n_keep : _n_keep + remaining_len] = self.input_ids[_n_keep + _n_discard : self.n_tokens]
809842

810-
self.n_tokens -= _n_discard
843+
# 7. Update the global token counter
844+
self.n_tokens -= _n_discard
811845

812846
# Adaptive batch downgrade limit initialization
813847
current_max_batch = self.n_batch
814848
last_ckpt_pos = self.n_tokens
815849

850+
# Adaptive Periodic Checkpointing for Hybrid Models
851+
# Keep to the "no more than three times" principle:
852+
# when pre-filling very large blocks, stretch the checkpoint interval so that saves happen less often and I/O blocking is minimized.
853+
if self.is_hybrid and self._hybrid_cache_mgr is not None:
854+
dynamic_interval = max(self.checkpoint_interval, n_eval // 3) # Maximum of 3 triggers
855+
816856
# If KV slots are full, `current_batch_size` will be halved.
817857
# A `while` loop allows us to correctly resume from the exact cut-off point.
818858
i = 0
@@ -900,15 +940,23 @@ def eval(self, tokens: Sequence[int]):
900940

901941
# Periodic Checkpoint: Save states for hybrid models to avoid massive rollbacks
902942
if self.is_hybrid and self._hybrid_cache_mgr is not None:
903-
if (self.n_tokens - last_ckpt_pos >= self.checkpoint_interval) and (i < n_eval):
943+
current_pos = self.n_tokens
944+
if (current_pos - last_ckpt_pos >= dynamic_interval) and (i < n_eval):
945+
904946
if self.verbose:
905-
print(f"Llama.eval: [Periodic Checkpoint] Saving hybrid state at pos {self.n_tokens}.", file=sys.stderr)
906-
self._hybrid_cache_mgr.save_checkpoint(
907-
current_pos=self.n_tokens,
908-
tokens=self.input_ids[:self.n_tokens].tolist(),
947+
print(f"Llama.eval: [Periodic Checkpoint] Saving hybrid state at pos {current_pos} "
948+
f"(checkpoint_interval({dynamic_interval}) reached, last={last_ckpt_pos}).", file=sys.stderr)
949+
950+
success = self._hybrid_cache_mgr.save_checkpoint(
951+
current_pos=current_pos,
952+
tokens=self.input_ids[:current_pos].tolist(),
909953
seq_id=0
910954
)
911-
last_ckpt_pos = self.n_tokens
955+
if success:
956+
last_ckpt_pos = current_pos
957+
else:
958+
if self.verbose:
959+
print(f"Llama.eval: [Periodic Checkpoint] HybridCheckpoint save failed at pos {current_pos}, skipping update", file=sys.stderr)
912960

913961
# Save the final logit if not in _logits_all mode
914962
if not self._logits_all:

0 commit comments

Comments
 (0)