@@ -355,7 +355,9 @@ def get_embeddings_seq(self, seq_id: int):
355355 # Sampling functions - deprecated, use LlamaSampler instead
356356
def set_rng_seed(self, seed: int):
    """Deprecated stub that unconditionally raises ``NotImplementedError``.

    ``seed`` is accepted but never used; the error message directs callers
    to ``LlamaSampler`` instead.

    Raises:
        NotImplementedError: always.
    """
    message = "set_rng_seed is deprecated, use LlamaSampler instead"
    raise NotImplementedError(message)
359361
360362 def sample_repetition_penalties (
361363 self ,
@@ -366,30 +368,44 @@ def sample_repetition_penalties(
366368 penalty_freq : float ,
367369 penalty_present : float ,
368370 ):
369- raise NotImplementedError ("sample_repetition_penalties is deprecated, use LlamaSampler instead" )
371+ raise NotImplementedError (
372+ "sample_repetition_penalties is deprecated, use LlamaSampler instead"
373+ )
370374
def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
    """Deprecated stub that unconditionally raises ``NotImplementedError``.

    ``candidates`` is accepted but never used; the error message directs
    callers to ``LlamaSampler`` instead.

    Raises:
        NotImplementedError: always.
    """
    message = "sample_softmax is deprecated, use LlamaSampler instead"
    raise NotImplementedError(message)
373379
def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
    """Deprecated stub that unconditionally raises ``NotImplementedError``.

    All arguments are accepted but never used; the error message directs
    callers to ``LlamaSampler`` instead.

    Raises:
        NotImplementedError: always.
    """
    message = "sample_top_k is deprecated, use LlamaSampler instead"
    raise NotImplementedError(message)
376384
def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
    """Deprecated stub that unconditionally raises ``NotImplementedError``.

    All arguments are accepted but never used; the error message directs
    callers to ``LlamaSampler`` instead.

    Raises:
        NotImplementedError: always.
    """
    message = "sample_top_p is deprecated, use LlamaSampler instead"
    raise NotImplementedError(message)
379389
def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
    """Deprecated stub that unconditionally raises ``NotImplementedError``.

    All arguments are accepted but never used; the error message directs
    callers to ``LlamaSampler`` instead.

    Raises:
        NotImplementedError: always.
    """
    message = "sample_min_p is deprecated, use LlamaSampler instead"
    raise NotImplementedError(message)
382394
def sample_typical(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
    """Deprecated stub that unconditionally raises ``NotImplementedError``.

    All arguments are accepted but never used; the error message directs
    callers to ``LlamaSampler`` instead.

    Raises:
        NotImplementedError: always.
    """
    message = "sample_typical is deprecated, use LlamaSampler instead"
    raise NotImplementedError(message)
387401
def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
    """Deprecated stub that unconditionally raises ``NotImplementedError``.

    Both arguments are accepted but never used; the error message directs
    callers to ``LlamaSampler`` instead.

    Raises:
        NotImplementedError: always.
    """
    message = "sample_temp is deprecated, use LlamaSampler instead"
    raise NotImplementedError(message)
390404
def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
    """Deprecated stub that unconditionally raises ``NotImplementedError``.

    Both arguments are accepted but never used; the error message directs
    callers to ``LlamaSampler`` instead.

    Raises:
        NotImplementedError: always.
    """
    message = "sample_grammar is deprecated, use LlamaSampler instead"
    raise NotImplementedError(message)
393409
394410 def sample_token_mirostat (
395411 self ,
@@ -399,7 +415,9 @@ def sample_token_mirostat(
399415 m : int ,
400416 mu : llama_cpp .CtypesPointerOrRef [ctypes .c_float ],
401417 ) -> int :
402- raise NotImplementedError ("sample_token_mirostat is deprecated, use LlamaSampler instead" )
418+ raise NotImplementedError (
419+ "sample_token_mirostat is deprecated, use LlamaSampler instead"
420+ )
403421
404422 def sample_token_mirostat_v2 (
405423 self ,
@@ -408,17 +426,25 @@ def sample_token_mirostat_v2(
408426 eta : float ,
409427 mu : llama_cpp .CtypesPointerOrRef [ctypes .c_float ],
410428 ) -> int :
411- raise NotImplementedError ("sample_token_mirostat_v2 is deprecated, use LlamaSampler instead" )
429+ raise NotImplementedError (
430+ "sample_token_mirostat_v2 is deprecated, use LlamaSampler instead"
431+ )
412432
def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int:
    """Deprecated stub that unconditionally raises ``NotImplementedError``.

    ``candidates`` is accepted but never used and no value is ever returned;
    the error message directs callers to ``LlamaSampler`` instead.

    Raises:
        NotImplementedError: always.
    """
    message = "sample_token_greedy is deprecated, use LlamaSampler instead"
    raise NotImplementedError(message)
415437
def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
    """Deprecated stub that unconditionally raises ``NotImplementedError``.

    ``candidates`` is accepted but never used and no value is ever returned;
    the error message directs callers to ``LlamaSampler`` instead.

    Raises:
        NotImplementedError: always.
    """
    message = "sample_token is deprecated, use LlamaSampler instead"
    raise NotImplementedError(message)
418442
419443 # Grammar
def grammar_accept_token(self, grammar: LlamaGrammar, token: int):
    """Deprecated stub that unconditionally raises ``NotImplementedError``.

    Both arguments are accepted but never used; the error message directs
    callers to ``LlamaSampler`` instead.

    Raises:
        NotImplementedError: always.
    """
    message = "grammar_accept_token is deprecated, use LlamaSampler instead"
    raise NotImplementedError(message)
422448
423449 def reset_timings (self ):
424450 llama_cpp .llama_perf_context_reset (self .ctx )
@@ -602,16 +628,16 @@ def sample(
602628 logits_array : Optional [npt .NDArray [np .single ]] = None ,
603629 ):
604630 # This method is deprecated in favor of using LlamaSampler directly
605- raise NotImplementedError ("LlamaSamplingContext.sample is deprecated, use LlamaSampler instead" )
631+ raise NotImplementedError (
632+ "LlamaSamplingContext.sample is deprecated, use LlamaSampler instead"
633+ )
606634
def accept(self, ctx_main: LlamaContext, id: int, apply_grammar: bool):
    """Record ``id`` by appending it to ``self.prev``.

    ``ctx_main`` and ``apply_grammar`` are accepted but not used by this
    method. Nothing is returned.
    """
    history = self.prev
    history.append(id)
609637
610638
611639class CustomSampler :
612- def __init__ (
613- self , apply_func : Callable [[llama_cpp .llama_token_data_array ], None ]
614- ):
640+ def __init__ (self , apply_func : Callable [[llama_cpp .llama_token_data_array ], None ]):
615641 self .apply_func = apply_func
616642
617643 def apply_wrapper (
@@ -723,28 +749,28 @@ def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
723749 llama_cpp .llama_sampler_chain_add (self .sampler , sampler )
724750
def add_grammar_lazy_patterns(
    self,
    model: LlamaModel,
    grammar: LlamaGrammar,
    trigger_patterns: List[str],
    trigger_tokens: List[int],
):
    """Create a lazy-pattern grammar sampler and append it to this chain.

    Each string in ``trigger_patterns`` is UTF-8 encoded into a
    ``ctypes.c_char_p`` array and ``trigger_tokens`` is copied into a
    ``llama_cpp.llama_token`` array. Both arrays, together with their
    lengths, ``model.vocab`` and the UTF-8 encoded ``grammar._grammar`` /
    ``grammar._root`` strings, are passed to
    ``llama_cpp.llama_sampler_init_grammar_lazy_patterns``. The resulting
    sampler is added to ``self.sampler`` via
    ``llama_cpp.llama_sampler_chain_add``.
    """
    # Trigger patterns -> C array of UTF-8 encoded char pointers.
    pattern_count = len(trigger_patterns)
    c_patterns = (ctypes.c_char_p * pattern_count)(
        *(text.encode("utf-8") for text in trigger_patterns)
    )

    # Trigger tokens -> C array of llama_token.
    token_count = len(trigger_tokens)
    c_tokens = (llama_cpp.llama_token * token_count)(*trigger_tokens)

    grammar_sampler = llama_cpp.llama_sampler_init_grammar_lazy_patterns(
        model.vocab,
        grammar._grammar.encode("utf-8"),
        grammar._root.encode("utf-8"),
        c_patterns,
        pattern_count,
        c_tokens,
        token_count,
    )
    llama_cpp.llama_sampler_chain_add(self.sampler, grammar_sampler)
750776
@@ -771,13 +797,13 @@ def add_dry(
771797 dry_base : float ,
772798 dry_allowed_length : int ,
773799 dry_penalty_last_n : int ,
774- seq_breakers : List [str ]
800+ seq_breakers : List [str ],
775801 ):
776802 # Convert seq_breakers to C array
777803 breaker_ptrs = (ctypes .c_char_p * len (seq_breakers ))()
778804 for i , breaker in enumerate (seq_breakers ):
779805 breaker_ptrs [i ] = breaker .encode ("utf-8" )
780-
806+
781807 sampler = llama_cpp .llama_sampler_init_dry (
782808 model .vocab ,
783809 n_ctx_train ,
@@ -786,25 +812,19 @@ def add_dry(
786812 dry_allowed_length ,
787813 dry_penalty_last_n ,
788814 breaker_ptrs ,
789- len (seq_breakers )
815+ len (seq_breakers ),
790816 )
791817 llama_cpp .llama_sampler_chain_add (self .sampler , sampler )
792818
def add_logit_bias(self, n_vocab: int, logit_bias: Dict[int, float]):
    """Create a logit-bias sampler and append it to this chain.

    A ``llama_cpp.llama_logit_bias`` array with one entry per item of
    ``logit_bias`` is filled in dict iteration order: each entry's ``token``
    is the key and its ``bias`` is the value. The array, its length and
    ``n_vocab`` are passed to ``llama_cpp.llama_sampler_init_logit_bias``,
    and the resulting sampler is added to ``self.sampler`` via
    ``llama_cpp.llama_sampler_chain_add``.
    """
    entry_count = len(logit_bias)
    c_biases = (llama_cpp.llama_logit_bias * entry_count)()

    for index, (token_id, bias_value) in enumerate(logit_bias.items()):
        slot = c_biases[index]
        slot.token = token_id
        slot.bias = bias_value

    bias_sampler = llama_cpp.llama_sampler_init_logit_bias(
        n_vocab,
        entry_count,
        c_biases,
    )
    llama_cpp.llama_sampler_chain_add(self.sampler, bias_sampler)
810830
@@ -838,15 +858,17 @@ def reset(self):
838858 def clone (self ):
839859 # NOTE: Custom samplers cannot be cloned due to Python callback limitations
840860 if self .custom_samplers :
841- raise NotImplementedError ("Cannot clone LlamaSampler that contains custom samplers" )
842-
861+ raise NotImplementedError (
862+ "Cannot clone LlamaSampler that contains custom samplers"
863+ )
864+
843865 cloned_sampler = llama_cpp .llama_sampler_clone (self .sampler )
844866 # Create a new wrapper around the cloned sampler
845867 new_sampler = LlamaSampler .__new__ (LlamaSampler )
846868 new_sampler .sampler = cloned_sampler
847869 new_sampler .custom_samplers = []
848870 new_sampler ._exit_stack = ExitStack ()
849-
871+
850872 def free_sampler ():
851873 if new_sampler .sampler is not None :
852874 llama_cpp .llama_sampler_free (new_sampler .sampler )
0 commit comments