@@ -81,6 +81,7 @@ def __init__(
8181 int
8282 ] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
8383 pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED,
84+ attention_type: int = llama_cpp.LLAMA_ATTENTION_TYPE_UNSPECIFIED,
8485 rope_freq_base: float = 0.0,
8586 rope_freq_scale: float = 0.0,
8687 yarn_ext_factor: float = -1.0,
@@ -163,6 +164,7 @@ def __init__(
163164 n_threads_batch: Number of threads to use for batch processing
164165 rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
165166 pooling_type: Pooling type, from `enum llama_pooling_type`.
167+ attention_type: Attention type, from `enum llama_attention_type`.
166168 rope_freq_base: RoPE base frequency, 0 = from model
167169 rope_freq_scale: RoPE frequency scaling factor, 0 = from model
168170 yarn_ext_factor: YaRN extrapolation mix factor, negative = from model
@@ -319,6 +321,7 @@ def __init__(
319321 else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
320322 )
321323 self.context_params.pooling_type = pooling_type
324+ self.context_params.attention_type = attention_type
322325 self.context_params.rope_freq_base = (
323326 rope_freq_base if rope_freq_base != 0.0 else 0
324327 )
@@ -2100,6 +2103,7 @@ def __getstate__(self):
21002103 n_threads_batch=self.context_params.n_threads_batch,
21012104 rope_scaling_type=self.context_params.rope_scaling_type,
21022105 pooling_type=self.context_params.pooling_type,
2106+ attention_type=self.context_params.attention_type,
21032107 rope_freq_base=self.context_params.rope_freq_base,
21042108 rope_freq_scale=self.context_params.rope_freq_scale,
21052109 yarn_ext_factor=self.context_params.yarn_ext_factor,
0 commit comments