Skip to content

Commit d2318a4

Browse files
committed
chore: enhance hybrid cache logging and document M-RoPE token usage
- Added explanatory comments detailing why `n_tokens` is used instead of `chunk_n_pos` for M-RoPE models (to prevent the system from skipping evaluation).
- Added verbose logging for hybrid-cache clearance scenarios (when checkpoints are missing, restore fails, or `max_checkpoints` is 0).

Signed-off-by: JamePeng <jame_peng@sina.com>
1 parent bdc2d7c commit d2318a4

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

llama_cpp/llama_chat_format.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3185,6 +3185,10 @@ def _create_bitmap_func(idx: int, item: str):
31853185
self._mtmd_cpp.mtmd_input_chunk_type.MTMD_INPUT_CHUNK_TYPE_AUDIO
31863186
]:
31873187
# Extract media properties
3188+
# Note(JamePeng):
3189+
# The M-RoPE model uses `n_pos` instead of `n_tokens` (there is no difference between the two for non-M-RoPE models).
3190+
# However, I still keep `n_tokens`, because if `n_pos` were used, the underlying system would assume a full match and skip eval and sampling.
3191+
# chunk_n_pos = self._mtmd_cpp.mtmd_input_chunk_get_n_pos(chunk) # equals to max(t,h,w) for M-RoPE; equals to `n_tokens` otherwise
31883192
chunk_n_tokens = self._mtmd_cpp.mtmd_input_chunk_get_n_tokens(chunk)
31893193

31903194
if media_items_cur < media_items_count:
@@ -3318,10 +3322,14 @@ def __call__(
33183322
if self.verbose:
33193323
print(f"{self.log_prefix}(__call__): Successfully rolled back to checkpoint at pos {llama.n_tokens}.", file=sys.stderr)
33203324
else:
3325+
if self.verbose:
3326+
print(f"{self.log_prefix}(__call__): No suitable checkpoint found or restore failed. Clearing hybrid cache entirely.", file=sys.stderr)
33213327
llama._hybrid_cache_mgr.clear()
33223328
llama._ctx.memory_clear(True)
33233329
llama.n_tokens = 0
33243330
else:
3331+
if self.verbose:
3332+
print(f"{self.log_prefix}(__call__): Hybrid cache enabled but max_checkpoints is 0. Clearing cache entirely.", file=sys.stderr)
33253333
llama._hybrid_cache_mgr.clear()
33263334
llama._ctx.memory_clear(True)
33273335
llama.n_tokens = 0

0 commit comments

Comments (0)