Skip to content

Commit 3ae5e04

Browse files
author
Chris Warren-Smith
committed
LLM: plugin module - initial commit
1 parent 9e2c60e commit 3ae5e04

5 files changed

Lines changed: 170 additions & 53 deletions

File tree

llama/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ set(GGML_BUILD_TESTS OFF CACHE BOOL "" FORCE)
3333
set(GGML_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
3434

3535
# CPU-only flags
36+
set(GGML_OPENMP OFF CACHE BOOL "" FORCE)
3637
set(GGML_CUDA OFF CACHE BOOL "" FORCE)
3738
set(GGML_METAL OFF CACHE BOOL "" FORCE)
3839
set(GGML_OPENCL OFF CACHE BOOL "" FORCE)
@@ -114,6 +115,9 @@ set_target_properties(llm_test PROPERTIES
114115
# Android native library
115116
# ------------------------------------------------------------------
116117
if (ANDROID)
118+
set(GGML_LLAMAFILE OFF CACHE BOOL "" FORCE)
119+
set(GGML_BLAS OFF CACHE BOOL "" FORCE)
120+
117121
# CMake sets ANDROID when using the Android toolchain
118122
# Re‑use the same source files for the Android .so
119123
add_library(llm_android SHARED

llama/llama-sb.cpp

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,17 @@ Llama::Llama() :
1515
_sampler(nullptr),
1616
_vocab(nullptr),
1717
_temperature(0),
18-
_n_ctx(0) {
18+
_top_k(0),
19+
_top_p(1.0f),
20+
_min_p(0.0f),
21+
_max_tokens(150),
22+
_log_level(GGML_LOG_LEVEL_NONE) {
23+
llama_log_set([](enum ggml_log_level level, const char * text, void *user_data) {
24+
Llama *llama = (Llama *)user_data;
25+
if (level > llama->_log_level) {
26+
fprintf(stderr, "LLAMA: %s", text);
27+
}
28+
}, this);
1929
}
2030

2131
Llama::~Llama() {
@@ -42,62 +52,54 @@ const string Llama::build_chat_prompt(const string &user_msg) {
4252
return _chat_prompt;
4353
}
4454

45-
bool Llama::construct(string model_path, int n_ctx, bool disable_log) {
46-
if (disable_log) {
47-
// only print errors
48-
llama_log_set([](enum ggml_log_level level, const char * text, void * /* user_data */) {
49-
if (level >= GGML_LOG_LEVEL_ERROR && text[0] != '.' && text[0] != '\n') {
50-
fprintf(stderr, "%s", text);
51-
}
52-
}, nullptr);
53-
}
54-
55+
bool Llama::construct(string model_path, int n_ctx, int n_batch) {
5556
ggml_backend_load_all();
5657

5758
llama_model_params mparams = llama_model_default_params();
58-
mparams.n_gpu_layers = 99;
59+
mparams.n_gpu_layers = 0;
5960

6061
_model = llama_model_load_from_file(model_path.c_str(), mparams);
6162
if (!_model) {
6263
_last_error = "failed to load model";
6364
} else {
6465
llama_context_params cparams = llama_context_default_params();
6566
cparams.n_ctx = n_ctx;
66-
cparams.n_batch = n_ctx;
67+
cparams.n_batch = n_batch;
6768
cparams.no_perf = true;
68-
6969
_ctx = llama_init_from_model(_model, cparams);
7070
if (!_ctx) {
7171
_last_error = "failed to create context";
7272
} else {
7373
_vocab = llama_model_get_vocab(_model);
74+
75+
auto sparams = llama_sampler_chain_default_params();
76+
sparams.no_perf = false;
77+
_sampler = llama_sampler_chain_init(sparams);
7478
}
7579
}
7680
return _last_error.empty();
7781
}
7882

79-
void Llama::configure_sampler(float temperature) {
80-
if (temperature != _temperature || _sampler == nullptr) {
81-
if (_sampler != nullptr) {
82-
llama_sampler_free(_sampler);
83+
void Llama::configure_sampler() {
84+
llama_sampler_reset(_sampler);
85+
if (_temperature <= 0.0f) {
86+
llama_sampler_chain_add(_sampler, llama_sampler_init_greedy());
87+
} else {
88+
llama_sampler_chain_add(_sampler, llama_sampler_init_temp(_temperature));
89+
if (_top_k > 0) {
90+
llama_sampler_chain_add(_sampler, llama_sampler_init_top_k(_top_k));
8391
}
84-
auto sparams = llama_sampler_chain_default_params();
85-
sparams.no_perf = false;
86-
_sampler = llama_sampler_chain_init(sparams);
87-
_temperature = temperature;
88-
89-
// llama_sampler_chain_reset(sampler);
90-
if (temperature <= 0.0f) {
91-
llama_sampler_chain_add(_sampler, llama_sampler_init_greedy());
92-
} else {
93-
llama_sampler_chain_add(_sampler, llama_sampler_init_min_p(0.05f, 1));
94-
llama_sampler_chain_add(_sampler, llama_sampler_init_temp(temperature));
95-
llama_sampler_chain_add(_sampler, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
92+
if (_top_p < 1.0f) {
93+
llama_sampler_chain_add(_sampler, llama_sampler_init_top_p(_top_p, 1));
94+
}
95+
if (_min_p > 0.0f) {
96+
llama_sampler_chain_add(_sampler, llama_sampler_init_min_p(_min_p, 1));
9697
}
98+
llama_sampler_chain_add(_sampler, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
9799
}
98100
}
99101

100-
string Llama::generate(const string &prompt, int max_tokens, float temperature) {
102+
string Llama::generate(const string &prompt) {
101103
string out;
102104

103105
// find the number of tokens in the prompt
@@ -111,7 +113,7 @@ string Llama::generate(const string &prompt, int max_tokens, float temperature)
111113
}
112114

113115
// initialize the sampler
114-
configure_sampler(temperature);
116+
configure_sampler();
115117

116118
// prepare a batch for the prompt
117119
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
@@ -129,7 +131,7 @@ string Llama::generate(const string &prompt, int max_tokens, float temperature)
129131
batch = llama_batch_get_one(&decoder_start_token_id, 1);
130132
}
131133

132-
for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + max_tokens;) {
134+
for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + _max_tokens;) {
133135
// evaluate the current batch with the transformer model
134136
if (llama_decode(_ctx, batch)) {
135137
_last_error = "failed to eval";

llama/llama-sb.h

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,32 @@ struct Llama {
1616
explicit Llama();
1717
~Llama();
1818

19+
// init
20+
bool construct(string model_path, int n_ctx, int n_batch);
21+
22+
// generation
23+
string generate(const string &prompt);
24+
25+
// generation parameters
26+
void set_max_tokens(int max_tokens) { _max_tokens = max_tokens; }
27+
void set_min_p(float min_p) { _min_p = min_p; }
28+
void set_temperature(float temperature) { _temperature = temperature; }
29+
void set_top_k(int top_k) { _top_k = top_k; }
30+
void set_top_p(float top_p) { _top_p = top_p; }
31+
32+
// messages
1933
void append_response(const string &response);
34+
void append_user_message(const string &user_msg);
35+
const string& get_chat_history() const;
2036
const string build_chat_prompt(const string &user_msg);
21-
bool construct(string model_path, int n_ctx, bool disable_log);
22-
string generate(const string &prompt, int max_tokens, float temperature);
37+
38+
// error handling
2339
const char *last_error() { return _last_error.c_str(); }
40+
void set_log_level(int level) { _log_level = level; }
2441
void reset();
2542

2643
private:
27-
void configure_sampler(float temperature);
44+
void configure_sampler();
2845

2946
llama_model *_model;
3047
llama_context *_ctx;
@@ -33,5 +50,9 @@ struct Llama {
3350
string _chat_prompt;
3451
string _last_error;
3552
float _temperature;
36-
int _n_ctx;
53+
float _top_p;
54+
float _min_p;
55+
int _top_k;
56+
int _max_tokens;
57+
int _log_level;
3758
};

llama/main.cpp

Lines changed: 104 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -50,26 +50,111 @@ static string expand_path(const char *path) {
5050
return result;
5151
}
5252

53+
//
54+
// llama.set_max_tokens(50)
55+
//
56+
static int cmd_llama_set_max_tokens(var_s *self, int argc, slib_par_t *arg, var_s *retval) {
57+
int result = 0;
58+
if (argc != 1) {
59+
error(retval, "llama.set_max_tokens", 1, 1);
60+
} else {
61+
int id = get_class_id(self, retval);
62+
if (id != -1) {
63+
Llama &llama = g_map.at(id);
64+
llama.set_max_tokens(get_param_int(argc, arg, 0, 0));
65+
result = 1;
66+
}
67+
}
68+
return result;
69+
}
70+
71+
//
72+
// llama.set_min_p(0.5)
73+
//
74+
static int cmd_llama_set_min_p(var_s *self, int argc, slib_par_t *arg, var_s *retval) {
75+
int result = 0;
76+
if (argc != 1) {
77+
error(retval, "llama.set_min_p", 1, 1);
78+
} else {
79+
int id = get_class_id(self, retval);
80+
if (id != -1) {
81+
Llama &llama = g_map.at(id);
82+
llama.set_min_p(get_param_num(argc, arg, 0, 0));
83+
result = 1;
84+
}
85+
}
86+
return result;
87+
}
88+
89+
//
90+
// llama.set_temperature(0.8)
91+
//
92+
static int cmd_llama_set_temperature(var_s *self, int argc, slib_par_t *arg, var_s *retval) {
93+
int result = 0;
94+
if (argc != 1) {
95+
error(retval, "llama.set_temperature", 1, 1);
96+
} else {
97+
int id = get_class_id(self, retval);
98+
if (id != -1) {
99+
Llama &llama = g_map.at(id);
100+
llama.set_temperature(get_param_num(argc, arg, 0, 0));
101+
result = 1;
102+
}
103+
}
104+
return result;
105+
}
106+
107+
//
108+
// llama.set_top_k(10)
109+
//
110+
static int cmd_llama_set_top_k(var_s *self, int argc, slib_par_t *arg, var_s *retval) {
111+
int result = 0;
112+
if (argc != 1) {
113+
error(retval, "llama.set_top_k", 1, 1);
114+
} else {
115+
int id = get_class_id(self, retval);
116+
if (id != -1) {
117+
Llama &llama = g_map.at(id);
118+
llama.set_top_k(get_param_int(argc, arg, 0, 0));
119+
result = 1;
120+
}
121+
}
122+
return result;
123+
}
124+
125+
static int cmd_llama_set_top_p(var_s *self, int argc, slib_par_t *arg, var_s *retval) {
126+
int result = 0;
127+
if (argc != 1) {
128+
error(retval, "llama.set_top_p", 1, 1);
129+
} else {
130+
int id = get_class_id(self, retval);
131+
if (id != -1) {
132+
Llama &llama = g_map.at(id);
133+
llama.set_top_p(get_param_num(argc, arg, 0, 0));
134+
result = 1;
135+
}
136+
}
137+
return result;
138+
}
139+
53140
//
54141
// print llama.chat("Hello")
55142
//
56143
static int cmd_llama_chat(var_s *self, int argc, slib_par_t *arg, var_s *retval) {
57144
int result = 0;
58-
if (argc < 1) {
59-
error(retval, "llama.chat", 1, 3);
145+
if (argc != 1) {
146+
error(retval, "llama.chat", 1, 1);
60147
} else {
61148
int id = get_class_id(self, retval);
62149
if (id != -1) {
63150
Llama &llama = g_map.at(id);
64151
auto prompt = get_param_str(argc, arg, 0, "");
65-
int max_tokens = get_param_int(argc, arg, 1, 32);
66-
var_num_t temperature = get_param_num(argc, arg, 2, 0.8f);
67152

68153
// build accumulated prompt
69154
string updated_prompt = llama.build_chat_prompt(prompt);
70155

71156
// run generation WITHOUT clearing cache
72-
string response = llama.generate(updated_prompt, max_tokens, temperature);
157+
string response = llama.generate(updated_prompt);
73158

74159
// append assistant reply to history
75160
llama.append_response(response);
@@ -100,20 +185,18 @@ static int cmd_llama_reset(var_s *self, int argc, slib_par_t *arg, var_s *retval
100185
}
101186

102187
//
103-
// print llama.generate("please generate as simple program in BASIC to draw a cat", 1024, 0.8)
188+
// print llama.generate("please generate a simple program in BASIC to draw a cat")
104189
//
105190
static int cmd_llama_generate(var_s *self, int argc, slib_par_t *arg, var_s *retval) {
106191
int result = 0;
107-
if (argc < 1) {
108-
error(retval, "llama.generate", 1, 3);
192+
if (argc != 1) {
193+
error(retval, "llama.generate", 1, 1);
109194
} else {
110195
int id = get_class_id(self, retval);
111196
if (id != -1) {
112197
Llama &llama = g_map.at(id);
113198
auto prompt = get_param_str(argc, arg, 0, "");
114-
int max_tokens = get_param_int(argc, arg, 1, 32);
115-
var_num_t temperature = get_param_num(argc, arg, 2, 0.8f);
116-
string response = llama.generate(prompt, max_tokens, temperature);
199+
string response = llama.generate(prompt);
117200
v_setstr(retval, response.c_str());
118201
result = 1;
119202
}
@@ -124,12 +207,19 @@ static int cmd_llama_generate(var_s *self, int argc, slib_par_t *arg, var_s *ret
124207
static int cmd_create_llama(int argc, slib_par_t *params, var_t *retval) {
125208
int result;
126209
auto model = expand_path(get_param_str(argc, params, 0, ""));
127-
int n_ctx = get_param_int(argc, params, 0, 2048);
128-
int disable_log = get_param_int(argc, params, 1, 1);
210+
auto n_ctx = get_param_int(argc, params, 0, 2048);
211+
auto n_batch = get_param_int(argc, params, 1, 1024);
212+
auto temperature = get_param_num(argc, params, 2, 0.25);
129213
int id = ++g_nextId;
130214
Llama &llama = g_map[id];
131-
if (llama.construct(model, n_ctx, disable_log)) {
215+
if (llama.construct(model, n_ctx, n_batch)) {
216+
llama.set_temperature(temperature);
132217
map_init_id(retval, id, CLASS_ID);
218+
v_create_callback(retval, "set_max_tokens", cmd_llama_set_max_tokens);
219+
v_create_callback(retval, "set_min_p", cmd_llama_set_min_p);
220+
v_create_callback(retval, "set_temperature", cmd_llama_set_temperature);
221+
v_create_callback(retval, "set_top_k", cmd_llama_set_top_k);
222+
v_create_callback(retval, "set_top_p", cmd_llama_set_top_p);
133223
v_create_callback(retval, "chat", cmd_llama_chat);
134224
v_create_callback(retval, "generate", cmd_llama_generate);
135225
v_create_callback(retval, "reset", cmd_llama_reset);

llama/test_main.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ int main(int argc, char ** argv) {
5656
}
5757

5858
Llama llama;
59-
if (llama.construct(model_path, 1024, true)) {
60-
string out = llama. generate(prompt, n_predict, 0.8f);
59+
if (llama.construct(model_path, 1024, 1024)) {
60+
string out = llama.generate(prompt);
6161
printf("\033[33m");
6262
printf(out.c_str());
6363
printf("\n\033[0m");

0 commit comments

Comments
 (0)