@@ -682,7 +682,9 @@ def close(self) -> None:
682682 self ._c_tensor_split = None
683683 self ._kv_overrides_array = None
684684
685- self ._stack .close ()
685+ if getattr (self , "_stack" , None ) is not None and hasattr (self ._stack , "close" ):
686+ self ._stack .close ()
687+ self ._stack = None
686688
687689 def __del__ (self ) -> None :
688690 self .close ()
@@ -789,30 +791,68 @@ def eval(self, tokens: Sequence[int]):
789791
790792 # Context Shift: Prevent OOM by discarding older tokens when context limit is reached.
791793 if self .n_tokens + n_eval > self ._n_ctx :
792- _n_keep = min (self .n_keep , self .n_tokens )
793- # Number of tokens after n_keep that may be discarded when shifting context
794- # defaults to half
795- _n_discard = (self .n_tokens - _n_keep ) // 2
794+ # 0. Check if the memory supports shifting
795+ if not self ._ctx .memory_can_shift ():
796+ raise RuntimeError (
797+ f"Llama.eval: Context Shift is explicitly disabled by the C++ backend "
798+ f"(n_pos_per_embd > 1 or incompatible M-RoPE). "
799+ f"You MUST increase n_ctx (currently { self ._n_ctx } ) to fit the dialogue."
800+ )
801+ # 1. Calculate the absolute minimum number of tokens we must discard to fit the new chunk.
802+ required_discard = (self .n_tokens + n_eval ) - self ._n_ctx
796803
797- if self .verbose :
798- model_type = "Hybrid/Recurrent/SWA" if self .is_hybrid else "Transformer"
799- print (f"Llama.eval: { model_type } context limit reached. Shifting context: "
800- f"discarding { _n_discard } tokens..." , file = sys .stderr )
804+ # 2. Sanity check: If the incoming chunk itself is larger than the entire context window,
805+ # shifting is physically impossible.
806+ if required_discard > self .n_tokens :
807+ raise RuntimeError (f"Llama.eval: Context shift failed. The incoming chunk ({ n_eval } tokens) "
808+ f"is larger than the entire context window ({ self ._n_ctx } )." )
809+
810+ # 3. Determine how many tokens to keep at the beginning (usually the System Prompt).
811+ _n_keep_desired = min (self .n_keep , self .n_tokens )
812+
813+ # Ensure that keeping these tokens doesn't prevent us from discarding the required amount.
814+ max_keep_allowed = max (0 , self .n_tokens - required_discard )
815+ _n_keep = min (_n_keep_desired , max_keep_allowed )
816+
817+ # 4. Calculate the final discard count. Default strategy is to discard half of the available
818+ # past tokens to minimize frequent shifting, but it must be at least `required_discard`.
819+ _n_discard = max (required_discard , (self .n_tokens - _n_keep ) // 2 )
820+
821+ # 5. Execute the shift only if there are tokens to discard.
822+ if _n_discard > 0 :
823+ if self .verbose :
824+ model_type = "Hybrid/Recurrent/SWA" if getattr (self , 'is_hybrid' , False ) else "Transformer"
825+ print (f"Llama.eval: { model_type } context limit reached. Shifting context: "
826+ f"keeping { _n_keep } , discarding { _n_discard } tokens..." , file = sys .stderr )
801827
802- # Use context memory methods, which handle both Attention KV removal and RNN pos shifting automatically
803- self ._ctx .memory_seq_rm (0 , _n_keep , _n_keep + _n_discard )
804- self ._ctx .memory_seq_add (0 , _n_keep + _n_discard , self .n_tokens , - _n_discard )
828+ try :
829+ # Remove the specified block of tokens from the physical KV cache
830+ self ._ctx .memory_seq_rm (0 , _n_keep , _n_keep + _n_discard )
831+
832+ # Shift the positional IDs of all subsequent tokens to the left to close the gap
833+ self ._ctx .memory_seq_add (0 , _n_keep + _n_discard , self .n_tokens , - _n_discard )
834+ except Exception as e :
835+ # Defense-in-depth: Catch any other recoverable backend errors
836+ raise RuntimeError (f"Llama.eval: Context Shift failed at the C++ level. Error: { str (e )} " ) from e
805837
806- remaining_len = self .n_tokens - (_n_keep + _n_discard )
807- if remaining_len > 0 :
808- self .input_ids [_n_keep : _n_keep + remaining_len ] = self .input_ids [_n_keep + _n_discard : self .n_tokens ]
838+ # 6. Synchronize the Python-side token tracking array (ledger)
839+ remaining_len = self .n_tokens - (_n_keep + _n_discard )
840+ if remaining_len > 0 :
841+ self .input_ids [_n_keep : _n_keep + remaining_len ] = self .input_ids [_n_keep + _n_discard : self .n_tokens ]
809842
810- self .n_tokens -= _n_discard
843+ # 7. Update the global token counter
844+ self .n_tokens -= _n_discard
811845
812846 # Adaptive batch downgrade limit initialization
813847 current_max_batch = self .n_batch
814848 last_ckpt_pos = self .n_tokens
815849
850+ # Adaptive Periodic Checkpointing for Hybrid Models
851+ # Following the "no more than three times" principle :)
852+ # when pre-filling very large blocks, reduce the save frequency to minimize I/O blocking.
853+ if self .is_hybrid and self ._hybrid_cache_mgr is not None :
854+ dynamic_interval = max (self .checkpoint_interval , n_eval // 3 ) # Maximum of 3 triggers
855+
816856 # If KV slots are full, `current_batch_size` will be halved.
817857 # A `while` loop allows us to correctly resume from the exact cut-off point.
818858 i = 0
@@ -900,15 +940,23 @@ def eval(self, tokens: Sequence[int]):
900940
901941 # Periodic Checkpoint: Save states for hybrid models to avoid massive rollbacks
902942 if self .is_hybrid and self ._hybrid_cache_mgr is not None :
903- if (self .n_tokens - last_ckpt_pos >= self .checkpoint_interval ) and (i < n_eval ):
943+ current_pos = self .n_tokens
944+ if (current_pos - last_ckpt_pos >= dynamic_interval ) and (i < n_eval ):
945+
904946 if self .verbose :
905- print (f"Llama.eval: [Periodic Checkpoint] Saving hybrid state at pos { self .n_tokens } ." , file = sys .stderr )
906- self ._hybrid_cache_mgr .save_checkpoint (
907- current_pos = self .n_tokens ,
908- tokens = self .input_ids [:self .n_tokens ].tolist (),
947+ print (f"Llama.eval: [Periodic Checkpoint] Saving hybrid state at pos { current_pos } "
948+ f"(checkpoint_interval({ dynamic_interval } ) reached, last={ last_ckpt_pos } )." , file = sys .stderr )
949+
950+ success = self ._hybrid_cache_mgr .save_checkpoint (
951+ current_pos = current_pos ,
952+ tokens = self .input_ids [:current_pos ].tolist (),
909953 seq_id = 0
910954 )
911- last_ckpt_pos = self .n_tokens
955+ if success :
956+ last_ckpt_pos = current_pos
957+ else :
958+ if self .verbose :
959+ print (f"Llama.eval: [Periodic Checkpoint] HybridCheckpoint save failed at pos { current_pos } , skipping update" , file = sys .stderr )
912960
913961 # Save the final logit if not in _logits_all mode
914962 if not self ._logits_all :
0 commit comments