Commit f423fbf

Author: Chris Warren-Smith
Commit message: LLAMA: implement nitro agent (work in progress)
Parent: 4399cbc

6 files changed: 45 additions & 35 deletions

llama/llama-sb.cpp

Lines changed: 22 additions & 2 deletions
```diff
@@ -45,6 +45,7 @@ Llama::Llama() :
   _top_k(0),
   _max_tokens(0),
   _log_level(GGML_LOG_LEVEL_CONT),
+  _n_past(0),
   _seed(LLAMA_DEFAULT_SEED) {
   llama_log_set([](enum ggml_log_level level, const char * text, void *user_data) {
     Llama *llama = (Llama *)user_data;
@@ -75,6 +76,7 @@ Llama::Llama(Llama &&other) noexcept
   , _top_k(other._top_k)
   , _max_tokens(other._max_tokens)
   , _log_level(other._log_level)
+  , _n_past(other._n_past)
   , _seed(other._seed) {
 }

@@ -103,6 +105,7 @@ void Llama::reset() {
   _top_p = 1.0f;
   _min_p = 0.0f;
   _max_tokens = 150;
+  _n_past = 0;
   _grammar_src.clear();
   _grammar_root.clear();
   _seed = LLAMA_DEFAULT_SEED;
@@ -138,7 +141,10 @@ bool Llama::construct(string model_path, int n_ctx, int n_batch, int n_gpu_layers, int log_level)
     } else {
       _vocab = llama_model_get_vocab(_model);
     }
+    _template = llama_model_chat_template(_model, nullptr);
   }
+
+
   return _last_error.empty();
 }
```
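Worth noting for the new `_template` member: `llama_model_chat_template` returns a pointer into metadata owned by the model (or `nullptr` when the GGUF carries no template), so the raw `const char *` is only valid while `_model` is alive. A minimal guard sketch, assuming a `llama_model *model` loaded elsewhere; falling back to the built-in `"chatml"` name is an illustration, not something this commit does:

```cpp
#include "llama.h"

// Sketch: fetch the model's embedded chat template. The returned pointer is
// owned by the model, so it must not outlive it. Assumes `model` was loaded
// elsewhere; the "chatml" fallback is an assumption, not part of this commit.
const char *tmpl = llama_model_chat_template(model, /*name=*/nullptr);
if (tmpl == nullptr) {
  // llama_chat_apply_template also resolves well-known template names,
  // so a named builtin can stand in for template-less GGUFs.
  tmpl = "chatml";
}
```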

```diff
@@ -261,7 +267,20 @@ bool Llama::make_space_for_tokens(int n_tokens, int keep_min) {
   return true;
 }

-bool Llama::generate(LlamaIter &iter, const string &prompt) {
+bool Llama::add_message(LlamaIter &iter, const string &role, const string &content) {
+  llama_chat_message msg = {role.c_str(), content.c_str()};
+
+  int buf_size = 2 * (int)(role.size() + content.size() + 64);
+  vector<char> buf(buf_size);
+  bool add_ass = (role == "user");
+
+  int32_t n = llama_chat_apply_template(_template, &msg, 1, add_ass, buf.data(), buf.size());
+  if (n > (int32_t)buf.size()) {
+    buf.resize(n);
+    llama_chat_apply_template(_template, &msg, 1, add_ass, buf.data(), buf.size());
+  }
+  string prompt(buf.data(), n);
+
   if (!configure_sampler()) {
     return false;
   }
```
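The grow-and-retry idiom in `add_message` works because `llama_chat_apply_template` returns the byte count the formatted message needs, so a first pass with an undersized buffer can be resized and repeated. A self-contained sketch of the same pattern (the `format_message` helper name and the initial size guess are illustrative, not from this commit):

```cpp
#include <string>
#include <vector>

#include "llama.h"

// Sketch: render one chat message through a llama.cpp chat template.
// llama_chat_apply_template returns the required byte count, so when the
// first pass overflows the buffer we resize and run it once more.
static std::string format_message(const char *tmpl, const std::string &role,
                                  const std::string &content, bool add_assistant) {
  llama_chat_message msg = {role.c_str(), content.c_str()};
  std::vector<char> buf(role.size() + content.size() + 128);
  int32_t n = llama_chat_apply_template(tmpl, &msg, 1, add_assistant,
                                        buf.data(), (int32_t)buf.size());
  if (n > (int32_t)buf.size()) {
    buf.resize(n);
    n = llama_chat_apply_template(tmpl, &msg, 1, add_assistant,
                                  buf.data(), (int32_t)buf.size());
  }
  return n > 0 ? std::string(buf.data(), n) : std::string();
}
```

Setting `add_ass` only for `"user"` turns follows the template convention of opening the assistant turn right after user input, so generation resumes in the model's voice.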
```diff
@@ -271,7 +290,7 @@ bool Llama::generate(LlamaIter &iter, const string &prompt) {
     return false;
   }

-  if (!make_space_for_tokens(prompt_tokens.size(), 0)) {
+  if (!make_space_for_tokens(prompt_tokens.size(), _n_past)) {
     return false;
   }

@@ -303,6 +322,7 @@ bool Llama::generate(LlamaIter &iter, const string &prompt) {
     }
   }

+  _n_past += prompt_tokens.size();
   iter._t_start = std::chrono::high_resolution_clock::now();
   iter._llama = this;
   iter._has_next = true;
```
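Passing `_n_past` as `keep_min` and bumping it after the prompt is decoded is what turns the old one-shot `generate` into a stateful chat: earlier turns stay in the KV cache, and each new message is appended behind them instead of replacing them. A rough sketch of that incremental pattern with plain llama.cpp calls (`feed_tokens` is a hypothetical helper, not from this commit):

```cpp
#include <vector>

#include "llama.h"

// Sketch: decode new tokens behind the tokens already held in the KV cache.
// n_past counts cached positions; after a successful decode it advances so
// the next message continues the same conversation.
static bool feed_tokens(llama_context *ctx, std::vector<llama_token> &toks, int &n_past) {
  llama_batch batch = llama_batch_get_one(toks.data(), (int32_t)toks.size());
  if (llama_decode(ctx, batch) != 0) {
    return false;  // out of context space or other decode failure
  }
  n_past += (int)toks.size();
  return true;
}
```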

llama/llama-sb.h

Lines changed: 4 additions & 1 deletion
```diff
@@ -51,7 +51,7 @@ struct Llama {
   bool construct(string model_path, int n_ctx, int n_batch, int n_gpu_layers, int log_level);

   // generation
-  bool generate(LlamaIter &iter, const string &prompt);
+  bool add_message(LlamaIter &iter, const string &role, const string &content);
   string next(LlamaIter &iter);
   string all(LlamaIter &iter);

@@ -81,6 +81,7 @@ struct Llama {
   bool make_space_for_tokens(int n_tokens, int keep_min);
   vector<llama_token> tokenize(const string &prompt);
   string token_to_string(LlamaIter &iter, llama_token tok);
+  bool encode(const string &role, const string &content, bool add_assistant_prompt);

   llama_model *_model;
   llama_context *_ctx;
@@ -90,6 +91,7 @@ struct Llama {
   string _grammar_src;
   string _grammar_root;
   string _last_error;
+  const char *_template;
   int32_t _penalty_last_n;
   float _penalty_repeat;
   float _penalty_freq;
@@ -100,5 +102,6 @@ struct Llama {
   int _top_k;
   int _max_tokens;
   int _log_level;
+  int _n_past;
   unsigned int _seed;
 };
```

llama/llama.cpp

llama/main.cpp

Lines changed: 8 additions & 7 deletions
```diff
@@ -401,20 +401,21 @@ static int cmd_llama_tokens_sec(var_s *self, int argc, slib_par_t *arg, var_s *retval) {
 }

 //
-// print llama.generate("please generate a simple program in BASIC to draw a cat")
+// print llama.add_message("user", "please generate a simple program in BASIC to draw a cat")
 //
-static int cmd_llama_generate(var_s *self, int argc, slib_par_t *arg, var_s *retval) {
+static int cmd_llama_add_message(var_s *self, int argc, slib_par_t *arg, var_s *retval) {
   int result = 0;
-  if (argc != 1) {
-    error(retval, "llama.generate", 1, 1);
+  if (argc != 2) {
+    error(retval, "llama.add_message", 2, 2);
   } else {
     int id = get_llama_class_id(self, retval);
     if (id != -1) {
       int iter_id = ++g_nextId;
       LlamaIter &iter = g_llama_iter[iter_id];
       Llama &llama = g_llama.at(id);
-      auto prompt = get_param_str(argc, arg, 0, "");
-      if (llama.generate(iter, prompt)) {
+      auto role = get_param_str(argc, arg, 0, "");
+      auto content = get_param_str(argc, arg, 1, "");
+      if (llama.add_message(iter, role, content)) {
         map_init_id(retval, iter_id, CLASS_ID_LLAMA_ITER);
         v_create_callback(retval, "all", cmd_llama_all);
         v_create_callback(retval, "has_next", cmd_llama_has_next);
@@ -441,7 +442,7 @@ static int cmd_create_llama(int argc, slib_par_t *params, var_t *retval) {
   if (llama.construct(model, n_ctx, n_batch, n_gpu_layers, n_log_level)) {
     map_init_id(retval, id, CLASS_ID_LLAMA);
     v_create_callback(retval, "add_stop", cmd_llama_add_stop);
-    v_create_callback(retval, "generate", cmd_llama_generate);
+    v_create_callback(retval, "add_message", cmd_llama_add_message);
     v_create_callback(retval, "reset", cmd_llama_reset);
     v_create_callback(retval, "set_penalty_repeat", cmd_llama_set_penalty_repeat);
     v_create_callback(retval, "set_penalty_freq", cmd_llama_set_penalty_freq);
```

llama/samples/nitro_cli.bas

Lines changed: 9 additions & 23 deletions
```diff
@@ -134,11 +134,7 @@ func handle_cmd(cmd)
 end

 '
-' Loads knowledge_files then returns the following format:
-'
-' <|turn|>system
-' {nitro.md...}
-' <|turn|>
+' Loads knowledge_files
 '
 func initialize_agent()
   local prompt = ""
@@ -164,28 +160,18 @@ func initialize_agent()
   print " ╚═══════════════════════════════════════╝"
   print
   print RESET
-  return "<|turn|>system\n" + prompt + "\n<|turn|>"
+  return prompt
 end

 '
-' Execute the given tool, then returns the following format:
-'
-' <|turn|>tool
-' {tool_output}
-' <|turn|>
-' <|turn|>model
+' Execute the given tool
 '
 func process_tool(tool)
-  return "<|turn|>tool\n" + handle_cmd(trim(tool)) + "\n<|turn|>\n<|turn|>model"
+  return handle_cmd(trim(tool))
 end

 '
-' Process user input, then returns the following format:
-'
-' <|turn|>user
-' {user_input}
-' <|turn|>
-' <|turn|>model
+' Returns the user input
 '
 func process_input()
   local user_input
@@ -194,7 +180,7 @@ func process_input()
   if user_input == "exit" OR user_input = "quit" then
     stop
   endif
-  return "<|turn|>user\n" + user_input + "\n<|turn|>\n<|turn|>model"
+  return user_input
 end

 '
@@ -219,7 +205,7 @@ end
 sub main()
   ' note: this construct requires recent sbasic fixes
   local llama = create_llama()
-  local iter = llama.generate(initialize_agent())
+  local iter = llama.add_message("system", initialize_agent())

   while 1
     local buffer = ""
@@ -259,15 +245,15 @@ sub main()
     ' Flush remaining line buffer
     if len(buffer) > 0 and left(trim(buffer), 5) == "TOOL:" then
       ' TOOL:xxx should always appear on the final line
-      iter = llama.generate(process_tool(buffer))
+      iter = llama.add_message("tool", process_tool(buffer))
     else
       if len(buffer) > 0 then
         ' TODO: trim any trailing <|turn|>
         print text_colour + buffer + RESET
       endif
       print
       print "--- Tokens/sec: " + round(iter.tokens_sec(), 2) + " ---\n"
-      iter = llama.generate(process_input())
+      iter = llama.add_message("user", process_input())
     endif
   wend
 end
```
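The sample no longer hand-assembles `<|turn|>` framing in BASIC; role tagging now happens on the C++ side through the model's chat template, so `initialize_agent`, `process_tool`, and `process_input` return plain strings.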

llama/test_main.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -59,7 +59,7 @@ int main(int argc, char ** argv) {
   if (llama.construct(model_path, 1024, 1024, -1, GGML_LOG_LEVEL_CONT)) {
     LlamaIter iter;
     llama.set_max_tokens(n_predict);
-    llama.generate(iter, prompt);
+    llama.add_message(iter, "user", prompt);
     while (iter._has_next) {
       auto out = llama.next(iter);
       printf("\033[33m");
```
