1414 Tuple ,
1515)
1616
17- import llama_cpp .llama
18- import llama_cpp ._internals as _internals
19- import llama_cpp .llama_cpp as llama_cpp
17+ import llama_cpp .llama as llama_core
18+ import llama_cpp .llama_cpp as llama_cpp_lib
2019
2120from .llama_types import *
2221
@@ -39,7 +38,7 @@ def _find_longest_prefix_key(
3938 pass
4039
4140 @abstractmethod
42- def __getitem__ (self , key : Sequence [int ]) -> "llama_cpp.llama .LlamaState" :
41+ def __getitem__ (self , key : Sequence [int ]) -> "llama_core .LlamaState" :
4342 raise NotImplementedError
4443
4544 @abstractmethod
@@ -48,7 +47,7 @@ def __contains__(self, key: Sequence[int]) -> bool:
4847
4948 @abstractmethod
5049 def __setitem__ (
51- self , key : Sequence [int ], value : "llama_cpp.llama .LlamaState"
50+ self , key : Sequence [int ], value : "llama_core .LlamaState"
5251 ) -> None :
5352 raise NotImplementedError
5453
@@ -73,18 +72,18 @@ def _find_longest_prefix_key(
7372 min_len = 0
7473 min_key : Optional [Tuple [int , ...]] = None
7574 for k in self .cache .iterkeys (): # type: ignore
76- prefix_len = llama_cpp . llama .Llama .longest_token_prefix (k , key )
75+ prefix_len = llama_core .Llama .longest_token_prefix (k , key )
7776 if prefix_len > min_len :
7877 min_len = prefix_len
7978 min_key = k # type: ignore
8079 return min_key
8180
82- def __getitem__ (self , key : Sequence [int ]) -> "llama_cpp.llama .LlamaState" :
81+ def __getitem__ (self , key : Sequence [int ]) -> "llama_core .LlamaState" :
8382 key = tuple (key )
8483 _key = self ._find_longest_prefix_key (key )
8584 if _key is None :
8685 raise KeyError ("Key not found" )
87- value : "llama_cpp.llama .LlamaState" = self .cache .pop (_key ) # type: ignore
86+ value : "llama_core .LlamaState" = self .cache .pop (_key ) # type: ignore
8887 # NOTE: The push() call below puts an integer as key in cache, which breaks
8988 # Llama.longest_token_prefix(k, key) above, since k is not a tuple of ints/tokens
9089 # self.cache.push(_key, side="front") # type: ignore
@@ -93,7 +92,7 @@ def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
9392 def __contains__ (self , key : Sequence [int ]) -> bool :
9493 return self ._find_longest_prefix_key (tuple (key )) is not None
9594
96- def __setitem__ (self , key : Sequence [int ], value : "llama_cpp.llama .LlamaState" ):
95+ def __setitem__ (self , key : Sequence [int ], value : "llama_core .LlamaState" ):
9796 print ("LlamaDiskCache.__setitem__: called" , file = sys .stderr )
9897 key = tuple (key )
9998 if key in self .cache :
@@ -114,7 +113,7 @@ def __init__(self, capacity_bytes: int = (2 << 30)):
114113 super ().__init__ (capacity_bytes )
115114 self .capacity_bytes = capacity_bytes
116115 self .cache_state : OrderedDict [
117- Tuple [int , ...], "llama_cpp.llama .LlamaState"
116+ Tuple [int , ...], "llama_core .LlamaState"
118117 ] = OrderedDict ()
119118
120119 @property
@@ -128,7 +127,7 @@ def _find_longest_prefix_key(
128127 min_len = 0
129128 min_key = None
130129 keys = (
131- (k , llama_cpp . llama .Llama .longest_token_prefix (k , key ))
130+ (k , llama_core .Llama .longest_token_prefix (k , key ))
132131 for k in self .cache_state .keys ()
133132 )
134133 for k , prefix_len in keys :
@@ -137,7 +136,7 @@ def _find_longest_prefix_key(
137136 min_key = k
138137 return min_key
139138
140- def __getitem__ (self , key : Sequence [int ]) -> "llama_cpp.llama .LlamaState" :
139+ def __getitem__ (self , key : Sequence [int ]) -> "llama_core .LlamaState" :
141140 key = tuple (key )
142141 _key = self ._find_longest_prefix_key (key )
143142 if _key is None :
@@ -149,7 +148,7 @@ def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
149148 def __contains__ (self , key : Sequence [int ]) -> bool :
150149 return self ._find_longest_prefix_key (tuple (key )) is not None
151150
152- def __setitem__ (self , key : Sequence [int ], value : "llama_cpp.llama .LlamaState" ):
151+ def __setitem__ (self , key : Sequence [int ], value : "llama_core .LlamaState" ):
153152 key = tuple (key )
154153 if key in self .cache_state :
155154 del self .cache_state [key ]
@@ -164,7 +163,7 @@ def __init__(self):
164163 # Child nodes: {token_id: TrieNode}
165164 self .children : Dict [int , "TrieNode" ] = {}
166165 # Stores the LlamaState if this node marks the end of a cached sequence.
167- self .state : Optional ["llama_cpp.llama .LlamaState" ] = None
166+ self .state : Optional ["llama_core .LlamaState" ] = None
168167
169168
170169class LlamaTrieCache (BaseLlamaCache ):
@@ -228,7 +227,7 @@ def _find_longest_prefix_node(
228227
229228 return longest_prefix_node , longest_prefix_key
230229
231- def __getitem__ (self , key : Sequence [int ]) -> "llama_cpp.llama .LlamaState" :
230+ def __getitem__ (self , key : Sequence [int ]) -> "llama_core .LlamaState" :
232231 """
233232 Retrieves the state for the longest matching prefix in O(K) time.
234233 Updates the LRU status.
@@ -282,7 +281,7 @@ def _prune(self, key: Tuple[int, ...]):
282281 # Node is still in use, stop pruning
283282 break
284283
285- def __setitem__ (self , key : Sequence [int ], value : "llama_cpp.llama .LlamaState" ):
284+ def __setitem__ (self , key : Sequence [int ], value : "llama_core .LlamaState" ):
286285 """
287286 Adds a (key, state) pair to the cache in O(K) time.
288287 Handles LRU updates and eviction.
@@ -334,7 +333,7 @@ class HybridCheckpointCache(BaseLlamaCache):
334333 Manager for RNN state snapshots (Checkpoints) tailored for Hybrid/Recurrent models.
335334 Provides rollback capabilities for models that cannot physically truncate KV cache.
336335 """
337- def __init__ (self , ctx : llama_cpp .llama_context_p , max_checkpoints : int = 16 , verbose : bool = False ):
336+ def __init__ (self , ctx : llama_cpp_lib .llama_context_p , max_checkpoints : int = 16 , verbose : bool = False ):
338337 if ctx is None :
339338 raise ValueError ("HybridCheckpointCache(__init__): Failed to create HybridCheckpointCache with model context" )
340339 self ._ctx = ctx
@@ -343,10 +342,10 @@ def __init__(self, ctx: llama_cpp.llama_context_p, max_checkpoints: int = 16, ve
343342 self ._current_size = 0
344343
345344 # Cache C-type API function pointers for performance
346- self ._get_size_ext = llama_cpp .llama_state_seq_get_size_ext
347- self ._get_data_ext = llama_cpp .llama_state_seq_get_data_ext
348- self ._set_data_ext = llama_cpp .llama_state_seq_set_data_ext
349- self ._flag_partial = llama_cpp .LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY
345+ self ._get_size_ext = llama_cpp_lib .llama_state_seq_get_size_ext
346+ self ._get_data_ext = llama_cpp_lib .llama_state_seq_get_data_ext
347+ self ._set_data_ext = llama_cpp_lib .llama_state_seq_set_data_ext
348+ self ._flag_partial = llama_cpp_lib .LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY
350349
351350 self .verbose = verbose
352351
0 commit comments