Commit 7b38c31

Authored by jamesbiederbeck (Victor Biederbeck) and abetlen

feat: expose attention_type parameter in Llama.__init__ (abetlen#2143)

* feat: expose attention_type parameter in Llama.__init__
* docs: preserve attention_type in pickled state
* docs: update changelog for attention_type

Co-authored-by: Victor Biederbeck <victor@moria.hiddencove.xyz>
Co-authored-by: abetlen <abetlen@gmail.com>

1 parent ccc6bc0 commit 7b38c31

File tree

2 files changed: +5 −0 lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ## [Unreleased]

+- feat: Expose `attention_type` in `Llama.__init__` for non-causal embedding models by @jamesbiederbeck in #2143
 - fix(ci): Build Docker images from the checked-out source and sanitize branch tags by @abetlen in #2156
 - fix(ci): Fix the CUDA wheel workflow and keep release tags aligned with the built toolkit by @abetlen in #2155
 - fix(ci): Speed up release wheel builds by moving arm64 off QEMU and parallelizing riscv64 by @abetlen in #2154
```

llama_cpp/llama.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -81,6 +81,7 @@ def __init__(
             int
         ] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
         pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED,
+        attention_type: int = llama_cpp.LLAMA_ATTENTION_TYPE_UNSPECIFIED,
         rope_freq_base: float = 0.0,
         rope_freq_scale: float = 0.0,
         yarn_ext_factor: float = -1.0,
@@ -163,6 +164,7 @@ def __init__(
             n_threads_batch: Number of threads to use for batch processing
             rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
             pooling_type: Pooling type, from `enum llama_pooling_type`.
+            attention_type: Attention type, from `enum llama_attention_type`.
             rope_freq_base: RoPE base frequency, 0 = from model
             rope_freq_scale: RoPE frequency scaling factor, 0 = from model
             yarn_ext_factor: YaRN extrapolation mix factor, negative = from model
@@ -319,6 +321,7 @@ def __init__(
             else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
         )
         self.context_params.pooling_type = pooling_type
+        self.context_params.attention_type = attention_type
         self.context_params.rope_freq_base = (
             rope_freq_base if rope_freq_base != 0.0 else 0
         )
@@ -2100,6 +2103,7 @@ def __getstate__(self):
             n_threads_batch=self.context_params.n_threads_batch,
             rope_scaling_type=self.context_params.rope_scaling_type,
             pooling_type=self.context_params.pooling_type,
+            attention_type=self.context_params.attention_type,
             rope_freq_base=self.context_params.rope_freq_base,
             rope_freq_scale=self.context_params.rope_freq_scale,
             yarn_ext_factor=self.context_params.yarn_ext_factor,
```
