examples : fix build after sampling refactoring
ggerganov committed Oct 15, 2023
1 parent 4a7f43f commit 225de38
Showing 8 changed files with 145 additions and 233 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -545,7 +545,7 @@ llama.o: llama.cpp ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h l
$(CXX) $(CXXFLAGS) -c $< -o $@

COMMON_H_DEPS = common/common.h common/sampling.h build-info.h common/log.h
COMMON_DEPS = $(COMMON_H_DEPS) common.o sampling.o
COMMON_DEPS = $(COMMON_H_DEPS) common.o sampling.o grammar-parser.o

common.o: common/common.cpp $(COMMON_H_DEPS)
$(CXX) $(CXXFLAGS) -c $< -o $@
138 changes: 69 additions & 69 deletions common/log.h
@@ -579,75 +579,75 @@ inline std::string log_var_to_string_impl(const std::vector<int> & var)
return buf.str();
}

#define LOG_TOKENS_TOSTR_PRETTY(ctx, tokens) \
[&tokens, &ctx]() \
{ \
std::stringstream buf; \
buf << "[ "; \
\
bool first = true; \
for (const auto &token : tokens) \
{ \
if (!first) \
buf << ", "; \
else \
first = false; \
\
auto detokenized = llama_token_to_piece(ctx, token); \
\
detokenized.erase( \
std::remove_if( \
detokenized.begin(), \
detokenized.end(), \
[](const unsigned char c) { return !std::isprint(c); }), \
detokenized.end()); \
\
buf \
<< "'" << detokenized << "'" \
<< ":" << std::to_string(token); \
} \
buf << " ]"; \
\
return buf.str(); \
}() \
.c_str()

#define LOG_BATCH_TOSTR_PRETTY(ctx, batch) \
[&batch, &ctx]() \
{ \
std::stringstream buf; \
buf << "[ "; \
\
bool first = true; \
for (int i = 0; i < batch.n_tokens; ++i) \
{ \
if (!first) \
buf << ", "; \
else \
first = false; \
\
auto detokenized = llama_token_to_piece(ctx, batch.token[i]); \
\
detokenized.erase( \
std::remove_if( \
detokenized.begin(), \
detokenized.end(), \
[](const unsigned char c) { return !std::isprint(c); }), \
detokenized.end()); \
\
buf \
<< "\n" << std::to_string(i) \
<< ":token '" << detokenized << "'" \
<< ":pos " << std::to_string(batch.pos[i]) \
<< ":n_seq_id " << std::to_string(batch.n_seq_id[i]) \
<< ":seq_id " << std::to_string(batch.seq_id[i][0]) \
<< ":logits " << std::to_string(batch.logits[i]); \
} \
buf << " ]"; \
\
return buf.str(); \
}() \
.c_str()
template <typename C, typename T>
inline std::string LOG_TOKENS_TOSTR_PRETTY(const C & ctx, const T & tokens)
{
std::stringstream buf;
buf << "[ ";

bool first = true;
for (const auto &token : tokens)
{
if (!first) {
buf << ", ";
} else {
first = false;
}

auto detokenized = llama_token_to_piece(ctx, token);

detokenized.erase(
std::remove_if(
detokenized.begin(),
detokenized.end(),
[](const unsigned char c) { return !std::isprint(c); }),
detokenized.end());

buf
<< "'" << detokenized << "'"
<< ":" << std::to_string(token);
}
buf << " ]";

return buf.str();
}

template <typename C, typename B>
inline std::string LOG_BATCH_TOSTR_PRETTY(const C & ctx, const B & batch)
{
std::stringstream buf;
buf << "[ ";

bool first = true;
for (int i = 0; i < batch.n_tokens; ++i)
{
if (!first) {
buf << ", ";
} else {
first = false;
}

auto detokenized = llama_token_to_piece(ctx, batch.token[i]);

detokenized.erase(
std::remove_if(
detokenized.begin(),
detokenized.end(),
[](const unsigned char c) { return !std::isprint(c); }),
detokenized.end());

buf
<< "\n" << std::to_string(i)
<< ":token '" << detokenized << "'"
<< ":pos " << std::to_string(batch.pos[i])
<< ":n_seq_id " << std::to_string(batch.n_seq_id[i])
<< ":seq_id " << std::to_string(batch.seq_id[i][0])
<< ":logits " << std::to_string(batch.logits[i]);
}
buf << " ]";

return buf.str();
}

#ifdef LOG_DISABLE_LOGS

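Note on the hunk above: the token and batch pretty-printers are converted from statement-expression macros into templated inline functions, so they now return a std::string instead of ending in .c_str() themselves. Call sites that hand the result to the printf-style LOG therefore add the .c_str() explicitly, as the examples/infill/infill.cpp hunks below do. A minimal usage sketch, assuming a valid llama_context * ctx and using only calls that appear in this diff:

    // tokenize some hypothetical input (ctx and add_bos come from the caller)
    std::vector<llama_token> tokens = ::llama_tokenize(ctx, "hello world", add_bos);

    // old macro form: the macro itself appended .c_str(), so this compiled as-is
    //LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, tokens));

    // new template form: the function returns std::string, so the caller adds .c_str()
    LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, tokens).c_str());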
1 change: 1 addition & 0 deletions common/sampling.h
@@ -50,6 +50,7 @@ struct llama_sampling_context {
// internal
grammar_parser::parse_state parsed_grammar;

// TODO: replace with ring-buffer
std::vector<llama_token> prev;
std::vector<llama_token_data> cur;
};
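The prev vector annotated above is the sampling context's history of recently seen tokens; the new TODO notes it could become a ring buffer, since today the examples age it by erasing from the front of the vector (see the infill.cpp hunks below). Purely as an illustration of that TODO, and not part of this commit, a fixed-capacity ring buffer could look roughly like this (llama_token comes from llama.h; the struct name is made up):

    // hypothetical replacement for the prev vector: overwrite the oldest
    // slot instead of erasing from the front of a std::vector
    struct llama_token_ring {
        std::vector<llama_token> data;
        size_t head = 0; // index of the oldest element

        explicit llama_token_ring(size_t n) : data(n, 0) {}

        void push(llama_token id) {
            data[head] = id;                 // overwrite the oldest token
            head = (head + 1) % data.size(); // advance the write position
        }

        llama_token last() const {           // most recently pushed token
            return data[(head + data.size() - 1) % data.size()];
        }
    };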
44 changes: 18 additions & 26 deletions examples/infill/infill.cpp
@@ -257,12 +257,12 @@ int main(int argc, char ** argv) {

LOG("prefix: \"%s\"\n", log_tostr(params.input_prefix));
LOG("suffix: \"%s\"\n", log_tostr(params.input_suffix));
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());

// Should not run without any tokens
if (embd_inp.empty()) {
embd_inp.push_back(llama_token_bos(ctx));
LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
}

// Tokenize negative prompt
@@ -273,10 +273,10 @@ int main(int argc, char ** argv) {
LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));

guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos);
LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());

std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp));
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());

original_prompt_len = original_inp.size();
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
Expand All @@ -294,8 +294,8 @@ int main(int argc, char ** argv) {
params.n_keep = (int)embd_inp.size();
}

LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx));
LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx));
LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str());
LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str());


// enable interactive mode if interactive start is specified
@@ -388,9 +388,6 @@ int main(int argc, char ** argv) {
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}

// TODO: replace with ring-buffer
std::vector<llama_token> last_tokens(n_ctx);
std::fill(last_tokens.begin(), last_tokens.end(), 0);
LOG_TEE("\n##### Infill mode #####\n\n");
if (params.infill) {
printf("\n************\n");
@@ -433,11 +430,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd;
std::vector<llama_token> embd_guidance;

const int n_vocab = llama_n_vocab(model);

llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);

while (n_remain != 0 || params.interactive) {
// predict
@@ -484,7 +477,7 @@ int main(int argc, char ** argv) {

LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);

LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());

}

@@ -512,7 +505,7 @@ int main(int argc, char ** argv) {
input_buf = embd_guidance.data();
input_size = embd_guidance.size();

LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance));
LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str());
} else {
input_buf = embd.data();
input_size = embd.size();
@@ -535,7 +528,7 @@ int main(int argc, char ** argv) {
n_eval = params.n_batch;
}

LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());

if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
LOG_TEE("%s : failed to eval\n", __func__);
@@ -554,12 +547,11 @@ int main(int argc, char ** argv) {

if ((int) embd_inp.size() <= n_consumed && !is_interacting) {

const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates);
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);

last_tokens.erase(last_tokens.begin());
last_tokens.push_back(id);
llama_sampling_accept(ctx_sampling, ctx, id);

LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, last_tokens));
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());

embd.push_back(id);

@@ -575,8 +567,8 @@ int main(int argc, char ** argv) {
LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
while ((int) embd_inp.size() > n_consumed) {
embd.push_back(embd_inp[n_consumed]);
last_tokens.erase(last_tokens.begin());
last_tokens.push_back(embd_inp[n_consumed]);
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
ctx_sampling->prev.push_back(embd_inp[n_consumed]);
++n_consumed;
if ((int) embd.size() >= params.n_batch) {
break;
@@ -608,7 +600,7 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed) {

// deal with eot token in infill mode
if ((last_tokens.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){
if ((ctx_sampling->prev.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){
if(is_interacting && !params.interactive_first) {
// print an eot token
printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
@@ -675,7 +667,7 @@ int main(int argc, char ** argv) {
is_interacting = false;
}
// deal with end of text token in interactive mode
else if (last_tokens.back() == llama_token_eos(ctx)) {
else if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
LOG("found EOS token\n");

if (params.interactive) {
@@ -727,7 +719,7 @@ int main(int argc, char ** argv) {
const size_t original_size = embd_inp.size();

const auto line_inp = ::llama_tokenize(ctx, buffer, false);
LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp));
LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());

embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());

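Taken together, the infill.cpp hunks show the call-site migration this commit performs: the per-example last_tokens / candidates buffers are gone, sampling state lives in the shared llama_sampling_context, tokens are drawn with llama_sampling_sample and recorded with llama_sampling_accept, and the log helpers get an explicit .c_str(). A condensed sketch of the new generation loop, pieced together only from the hunks above (ctx, ctx_guidance, params, embd and n_remain come from the surrounding example code; batching, interaction and error handling are omitted):

    // one sampling context per example, built from the common parameters
    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);

    while (n_remain != 0) {
        // draw the next token using the shared sampling state
        const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);

        // record it so the repetition/grammar state stays in sync
        llama_sampling_accept(ctx_sampling, ctx, id);

        LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());

        embd.push_back(id);
        --n_remain;
    }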
(Diffs for the remaining 4 changed files are not shown here.)
