Skip to content

Commit 4399cbc

Browse files
author
Chris Warren-Smith
committed
LLAMA: added apis to penalty_freq and penalty_present
1 parent 54a7003 commit 4399cbc

5 files changed

Lines changed: 56 additions & 2 deletions

File tree

llama/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ target_include_directories(llm PRIVATE
130130
target_link_libraries(llm PRIVATE
131131
llama
132132
ggml
133+
# force dynamic libm
134+
-Wl,-Bdynamic,-lm
133135
)
134136

135137
# Include all static code into plugin

llama/llama-sb.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ Llama::Llama() :
3737
_vocab(nullptr),
3838
_penalty_last_n(0),
3939
_penalty_repeat(0),
40+
_penalty_freq(0.0f),
41+
_penalty_present(0.0f),
4042
_temperature(0),
4143
_top_p(0),
4244
_min_p(0),
@@ -65,6 +67,8 @@ Llama::Llama(Llama &&other) noexcept
6567
, _last_error(std::move(other._last_error))
6668
, _penalty_last_n(other._penalty_last_n)
6769
, _penalty_repeat(other._penalty_repeat)
70+
, _penalty_freq(other._penalty_freq)
71+
, _penalty_present(other._penalty_present)
6872
, _temperature(other._temperature)
6973
, _top_p(other._top_p)
7074
, _min_p(other._min_p)
@@ -92,6 +96,8 @@ void Llama::reset() {
9296
_last_error = "";
9397
_penalty_last_n = 64;
9498
_penalty_repeat = 1.1f;
99+
_penalty_freq = 0.0f;
100+
_penalty_present = 0.0f;
95101
_temperature = 0;
96102
_top_k = 0;
97103
_top_p = 1.0f;
@@ -155,7 +161,7 @@ bool Llama::configure_sampler() {
155161
llama_sampler_chain_add(chain, grammar);
156162
}
157163
if (_penalty_last_n != 0 && _penalty_repeat != 1.0f) {
158-
auto penalties = llama_sampler_init_penalties(_penalty_last_n, _penalty_repeat, 0.0f, 0.0f);
164+
auto penalties = llama_sampler_init_penalties(_penalty_last_n, _penalty_repeat, _penalty_freq, _penalty_present);
159165
llama_sampler_chain_add(chain, penalties);
160166
}
161167
if (_temperature <= 0.0f) {

llama/llama-sb.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ struct Llama {
6060
void clear_stops() { _stop_sequences.clear(); }
6161
void set_penalty_last_n(int32_t penalty_last_n) { _penalty_last_n = penalty_last_n; }
6262
void set_penalty_repeat(float penalty_repeat) { _penalty_repeat = penalty_repeat; }
63+
void set_penalty_freq(float penalty_freq) { _penalty_freq = penalty_freq; }
64+
void set_penalty_present(float penalty_present) { _penalty_present = penalty_present; }
6365
void set_max_tokens(int max_tokens) { _max_tokens = max_tokens; }
6466
void set_min_p(float min_p) { _min_p = min_p; }
6567
void set_temperature(float temperature) { _temperature = temperature; }
@@ -90,6 +92,8 @@ struct Llama {
9092
string _last_error;
9193
int32_t _penalty_last_n;
9294
float _penalty_repeat;
95+
float _penalty_freq;
96+
float _penalty_present;
9397
float _temperature;
9498
float _top_p;
9599
float _min_p;

llama/llama.cpp

llama/main.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,46 @@ static int cmd_llama_set_penalty_repeat(var_s *self, int argc, slib_par_t *arg,
104104
return result;
105105
}
106106

107+
//
108+
// llama.set_penalty_freq(0.8)
109+
//
110+
static int cmd_llama_set_penalty_freq(var_s *self, int argc, slib_par_t *arg, var_s *retval) {
111+
int result = 0;
112+
if (argc != 1) {
113+
error(retval, "llama.set_penalty_freq", 1, 1);
114+
} else {
115+
int id = get_llama_class_id(self, retval);
116+
if (id != -1) {
117+
Llama &llama = g_llama.at(id);
118+
auto value = get_param_num(argc, arg, 0, 0);
119+
llama.set_penalty_freq(value);
120+
v_setreal(map_add_var(self, "penalty_freq", 0), value);
121+
result = 1;
122+
}
123+
}
124+
return result;
125+
}
126+
127+
//
128+
// llama.set_penalty_present(0.8)
129+
//
130+
static int cmd_llama_set_penalty_present(var_s *self, int argc, slib_par_t *arg, var_s *retval) {
131+
int result = 0;
132+
if (argc != 1) {
133+
error(retval, "llama.set_penalty_present", 1, 1);
134+
} else {
135+
int id = get_llama_class_id(self, retval);
136+
if (id != -1) {
137+
Llama &llama = g_llama.at(id);
138+
auto value = get_param_num(argc, arg, 0, 0);
139+
llama.set_penalty_present(value);
140+
v_setreal(map_add_var(self, "penalty_present", 0), value);
141+
result = 1;
142+
}
143+
}
144+
return result;
145+
}
146+
107147
//
108148
// llama.set_penalty_last_n(0.8)
109149
//
@@ -404,6 +444,8 @@ static int cmd_create_llama(int argc, slib_par_t *params, var_t *retval) {
404444
v_create_callback(retval, "generate", cmd_llama_generate);
405445
v_create_callback(retval, "reset", cmd_llama_reset);
406446
v_create_callback(retval, "set_penalty_repeat", cmd_llama_set_penalty_repeat);
447+
v_create_callback(retval, "set_penalty_freq", cmd_llama_set_penalty_freq);
448+
v_create_callback(retval, "set_penalty_present", cmd_llama_set_penalty_present);
407449
v_create_callback(retval, "set_penalty_last_n", cmd_llama_set_penalty_last_n);
408450
v_create_callback(retval, "set_max_tokens", cmd_llama_set_max_tokens);
409451
v_create_callback(retval, "set_min_p", cmd_llama_set_min_p);

0 commit comments

Comments
 (0)