diff --git a/examples/talk/gpt-2.cpp b/examples/talk/gpt-2.cpp index 43ca8fa0..f1638e85 100644 --- a/examples/talk/gpt-2.cpp +++ b/examples/talk/gpt-2.cpp @@ -1,78 +1,32 @@ #include "ggml.h" #include "common-ggml.h" +#include "ggml-backend.h" +#include "ggml-alloc.h" #include "gpt-2.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#ifdef GGML_USE_CUDA +#include "ggml-cuda.h" +#endif -/////////////////////// GPT-2 BEGIN ///////////////////////// +#ifdef GGML_USE_METAL +#include "ggml-metal.h" +#endif -// default hparams (GPT-2 117M) -struct gpt2_hparams { - int32_t n_vocab = 50257; - int32_t n_ctx = 1024; - int32_t n_embd = 768; - int32_t n_head = 12; - int32_t n_layer = 12; - int32_t ftype = 1; -}; -struct gpt2_layer { - // normalization - struct ggml_tensor * ln_1_g; - struct ggml_tensor * ln_1_b; +#define GPT2_MAX_NODES 4096 - struct ggml_tensor * ln_2_g; - struct ggml_tensor * ln_2_b; +static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} - // attention - struct ggml_tensor * c_attn_attn_w; - struct ggml_tensor * c_attn_attn_b; - struct ggml_tensor * c_attn_proj_w; - struct ggml_tensor * c_attn_proj_b; - - // mlp - struct ggml_tensor * c_mlp_fc_w; - struct ggml_tensor * c_mlp_fc_b; - - struct ggml_tensor * c_mlp_proj_w; - struct ggml_tensor * c_mlp_proj_b; -}; - -struct gpt2_model { - gpt2_hparams hparams; - - // normalization - struct ggml_tensor * ln_f_g; - struct ggml_tensor * ln_f_b; - - struct ggml_tensor * wte; // position embedding - struct ggml_tensor * wpe; // token embedding - struct ggml_tensor * lm_head; // language model head - - std::vector layers; - - // key + value memory - struct ggml_tensor * memory_k; - struct ggml_tensor * memory_v; - - // - struct ggml_context * ctx; - std::map tensors; -}; // load the model's weights from a file -static bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab) { +bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab, int n_ctx = 2048, int n_gpu_layers = 0) { printf("%s: loading model from '%s'\n", __func__, fname.c_str()); auto fin = std::ifstream(fname, std::ios::binary); @@ -85,7 +39,7 @@ static bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_v { uint32_t magic; fin.read((char *) &magic, sizeof(magic)); - if (magic != 0x67676d6c) { + if (magic != GGML_FILE_MAGIC) { fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str()); return false; } @@ -102,12 +56,17 @@ static bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_v fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer)); fin.read((char *) &hparams.ftype, sizeof(hparams.ftype)); + const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR; + printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab); printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx); printf("%s: n_embd = %d\n", __func__, hparams.n_embd); printf("%s: n_head = %d\n", __func__, hparams.n_head); printf("%s: n_layer = %d\n", __func__, hparams.n_layer); printf("%s: ftype = %d\n", __func__, hparams.ftype); + printf("%s: qntvr = %d\n", __func__, qntvr); + + hparams.ftype %= GGML_QNT_VERSION_FACTOR; } // load vocab @@ -121,13 +80,16 @@ static bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_v return false; } - char word[129]; + std::string word; + std::vector buf(128); for (int i = 0; i < n_vocab; i++) { uint32_t len; fin.read((char *) &len, sizeof(len)); - word[len] = '\0'; - fin.read((char *) word, len); + + buf.resize(len); + fin.read((char *) buf.data(), len); + word.assign(buf.data(), len); vocab.token_to_id[word] = i; vocab.id_to_token[i] = word; @@ -143,67 +105,58 @@ static bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_v return false; } - auto & ctx = model.ctx; - - size_t ctx_size = 0; - - { - const auto & hparams = model.hparams; - - const int n_embd = hparams.n_embd; - const int n_layer = hparams.n_layer; - const int n_ctx = hparams.n_ctx; - const int n_vocab = hparams.n_vocab; - - ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_g - ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_b - - ctx_size += n_vocab*ggml_row_size(wtype, n_embd); // wte - ctx_size += n_ctx*ggml_row_size(GGML_TYPE_F32, n_embd); // wpe - ctx_size += n_vocab*ggml_row_size(wtype, n_embd); // lm_head - - ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_g - ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_b - - ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_g - ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_b - - ctx_size += n_layer*(ggml_row_size(wtype, 3*n_embd*n_embd)); // c_attn_attn_w - ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 3*n_embd)); // c_attn_attn_b - - ctx_size += n_layer*(ggml_row_size(wtype, n_embd*n_embd)); // c_attn_proj_w - ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_attn_proj_b - - ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_fc_w - ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd)); // c_mlp_fc_b - - ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_proj_w - ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_mlp_proj_b - - ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_k - ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_v - - ctx_size += (6 + 12*n_layer)*256; // object overhead - - printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); - } + auto & ctx = model.ctx_w; // create the ggml context { + size_t n_tensors = 2 + 6 + 12*model.hparams.n_layer; struct ggml_init_params params = { - /*.mem_size =*/ ctx_size, + /*.mem_size =*/ ggml_tensor_overhead() * n_tensors, /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ false, + /*.no_alloc =*/ true, }; - model.ctx = ggml_init(params); - if (!model.ctx) { + ctx = ggml_init(params); + if (!ctx) { fprintf(stderr, "%s: ggml_init() failed\n", __func__); return false; } } - // prepare memory for the weights + // initialize the backend +#ifdef GGML_USE_CUDA + if (n_gpu_layers > 0) { + fprintf(stderr, "%s: using CUDA backend\n", __func__); + model.backend = ggml_backend_cuda_init(0); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); + } + } +#endif + +#ifdef GGML_USE_METAL + if (n_gpu_layers > 0) { + fprintf(stderr, "%s: using Metal backend\n", __func__); + ggml_backend_metal_log_set_callback(ggml_log_callback_default, nullptr); + model.backend = ggml_backend_metal_init(); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); + } + } +#endif + + if (!model.backend) { + // fallback to CPU backend + fprintf(stderr, "%s: using CPU backend\n", __func__); + model.backend = ggml_backend_cpu_init(); + } + + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_cpu_init() failed\n", __func__); + return false; + } + + // create the tensors for the model { const auto & hparams = model.hparams; @@ -271,8 +224,35 @@ static bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_v } } + // allocate the model tensors in a backend buffer + model.buffer_w = ggml_backend_alloc_ctx_tensors(ctx, model.backend); + + printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor)); + printf("%s: backend buffer size = %6.2f MB\n", __func__, ggml_backend_buffer_get_size(model.buffer_w)/(1024.0*1024.0)); + + // override the default training context with the user-provided + model.hparams.n_ctx = n_ctx; + // key + value memory { + auto * ctx = model.ctx_kv; + + // create the ggml context + { + size_t n_tensors = 2; + struct ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead() * n_tensors, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ctx = ggml_init(params); + if (!ctx) { + fprintf(stderr, "%s: ggml_init() failed\n", __func__); + return false; + } + } + const auto & hparams = model.hparams; const int n_embd = hparams.n_embd; @@ -285,8 +265,10 @@ static bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_v model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements); model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements); - const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v); + // allocate the KV memory in a backend buffer + model.buffer_kv = ggml_backend_alloc_ctx_tensors(ctx, model.backend); + const size_t memory_size = ggml_backend_buffer_get_size(model.buffer_kv); printf("%s: memory size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem); } @@ -296,6 +278,8 @@ static bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_v bool has_lm_head = false; + std::vector read_buf; + while (true) { int32_t n_dims; int32_t length; @@ -319,41 +303,51 @@ static bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_v std::string name(length, 0); fin.read(&name[0], length); - if (model.tensors.find(name.data()) == model.tensors.end()) { - fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data()); + if (model.tensors.find(name) == model.tensors.end()) { + fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.c_str()); return false; } - auto tensor = model.tensors[name.data()]; + auto tensor = model.tensors[name]; + ggml_set_name(tensor, name.c_str()); if (ggml_nelements(tensor) != nelements) { - fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.c_str()); return false; } if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) { fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n", - __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]); + __func__, name.c_str(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]); return false; } // for debugging if (0) { - printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor)); + printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.c_str(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor)); } const size_t bpe = ggml_type_size(ggml_type(ttype)); if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", - __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); + __func__, name.c_str(), ggml_nbytes(tensor), nelements*bpe); return false; } - fin.read(reinterpret_cast(tensor->data), ggml_nbytes(tensor)); + if (ggml_backend_buffer_is_host(model.buffer_w)) { + // for some backends such as CPU and Metal, the tensor data is in system memory and we can read directly into it + fin.read(reinterpret_cast(tensor->data), ggml_nbytes(tensor)); + } else { + // read into a temporary buffer first, then copy to device memory + read_buf.resize(ggml_nbytes(tensor)); + fin.read(read_buf.data(), ggml_nbytes(tensor)); + ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor)); + } // GPT-2 models share the WTE tensor as the LM head if (name == "model/wte" && has_lm_head == false) { - memcpy(model.lm_head->data, tensor->data, ggml_nbytes(tensor)); + //ggml_backend_tensor_copy(tensor, model.lm_head); + model.lm_head = tensor; } if (name == "model/lm_head") { @@ -371,350 +365,6 @@ static bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_v return true; } -// evaluate the transformer -// -// - model: the model -// - n_threads: number of threads to use -// - n_past: the context size so far -// - embd_inp: the embeddings of the tokens in the context -// - embd_w: the predicted logits for the next token -// -// TODO: sync latest version from ggml repo -static bool gpt2_eval( - const gpt2_model & model, - const int n_threads, - const int n_past, - const std::vector & embd_inp, - std::vector & embd_w, - size_t & mem_per_token) { - const int N = embd_inp.size(); - - const auto & hparams = model.hparams; - - const int n_embd = hparams.n_embd; - const int n_layer = hparams.n_layer; - const int n_ctx = hparams.n_ctx; - const int n_head = hparams.n_head; - const int n_vocab = hparams.n_vocab; - - static size_t buf_size = 512u*1024*1024; - static void * buf = malloc(buf_size); - - if (mem_per_token > 0 && mem_per_token*N > buf_size) { - const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead - //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new); - - // reallocate - buf_size = buf_size_new; - buf = realloc(buf, buf_size); - if (buf == nullptr) { - fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size); - return false; - } - } - - struct ggml_init_params params = { - /*.mem_size =*/ buf_size, - /*.mem_buffer =*/ buf, - /*.no_alloc =*/ false, - }; - - struct ggml_context * ctx0 = ggml_init(params); - struct ggml_cgraph gf = {}; - - struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd)); - - struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - for (int i = 0; i < N; ++i) { - ((int32_t *) position->data)[i] = n_past + i; - } - - // wte + wpe - struct ggml_tensor * inpL = - ggml_add(ctx0, - ggml_get_rows(ctx0, model.wte, embd), - ggml_get_rows(ctx0, model.wpe, position)); - - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * cur; - - // norm - { - // [ 768, N] - cur = ggml_norm(ctx0, inpL, 1e-5f); - - // cur = ln_1_g*cur + ln_1_b - // [ 768, N] - cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, model.layers[il].ln_1_g, cur), - cur), - ggml_repeat(ctx0, model.layers[il].ln_1_b, cur)); - } - - // attn - // [2304, 768] - model.layers[il].c_attn_attn_w - // [2304, 1] - model.layers[il].c_attn_attn_b - // [ 768, N] - cur (in) - // [2304, N] - cur (out) - // - // cur = attn_w*cur + attn_b - // [2304, N] - { - cur = ggml_mul_mat(ctx0, - model.layers[il].c_attn_attn_w, - cur); - - cur = ggml_add(ctx0, - ggml_repeat(ctx0, model.layers[il].c_attn_attn_b, cur), - cur); - } - - // self-attention - { - struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd); - struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd); - struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd); - - // store key and value to memory - if (N >= 1) { - struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past)); - struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past)); - - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k)); - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v)); - } - - // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) - // [64, N, 12] - struct ggml_tensor * Q = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Qcur, - ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)), - 0, 2, 1, 3); - - // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3) - // [64, n_past + N, 12] - struct ggml_tensor * K = - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd), - n_embd/n_head, n_head, n_past + N), - 0, 2, 1, 3); - - // GG: flash attention - //struct ggml_tensor * V = - // ggml_cpy(ctx0, - // ggml_permute(ctx0, - // ggml_reshape_3d(ctx0, - // ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd), - // n_embd/n_head, n_head, n_past + N), - // 1, 2, 0, 3), - // ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head)); - - //struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true); - - // K * Q - // [n_past + N, N, 12] - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - - // KQ_scaled = KQ / sqrt(n_embd/n_head) - // [n_past + N, N, 12] - struct ggml_tensor * KQ_scaled = - ggml_scale(ctx0, - KQ, - 1.0f/sqrt(float(n_embd)/n_head)); - - // KQ_masked = mask_past(KQ_scaled) - // [n_past + N, N, 12] - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); - - // KQ = soft_max(KQ_masked) - // [n_past + N, N, 12] - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); - - // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() - // [n_past + N, 64, 12] - struct ggml_tensor * V_trans = - ggml_cpy(ctx0, - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd), - n_embd/n_head, n_head, n_past + N), - 1, 2, 0, 3), - ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head)); - - // KQV = transpose(V) * KQ_soft_max - // [64, N, 12] - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); - - // KQV_merged = KQV.permute(0, 2, 1, 3) - // [64, 12, N] - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - - // cur = KQV_merged.contiguous().view(n_embd, N) - // [768, N] - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); - } - - // projection - // [ 768, 768] - model.layers[il].c_attn_proj_w - // [ 768, 1] - model.layers[il].c_attn_proj_b - // [ 768, N] - cur (in) - // [ 768, N] - cur (out) - // - // cur = proj_w*cur + proj_b - // [768, N] - { - cur = ggml_mul_mat(ctx0, - model.layers[il].c_attn_proj_w, - cur); - - cur = ggml_add(ctx0, - ggml_repeat(ctx0, model.layers[il].c_attn_proj_b, cur), - cur); - } - - // add the input - cur = ggml_add(ctx0, cur, inpL); - - struct ggml_tensor * inpFF = cur; - - // feed-forward network - { - // norm - { - cur = ggml_norm(ctx0, inpFF, 1e-5f); - - // cur = ln_2_g*cur + ln_2_b - // [ 768, N] - cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, model.layers[il].ln_2_g, cur), - cur), - ggml_repeat(ctx0, model.layers[il].ln_2_b, cur)); - } - - // fully connected - // [3072, 768] - model.layers[il].c_mlp_fc_w - // [3072, 1] - model.layers[il].c_mlp_fc_b - // [ 768, N] - cur (in) - // [3072, N] - cur (out) - // - // cur = fc_w*cur + fc_b - // [3072, N] - cur = ggml_mul_mat(ctx0, - model.layers[il].c_mlp_fc_w, - cur); - - cur = ggml_add(ctx0, - ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur), - cur); - - // GELU activation - // [3072, N] - cur = ggml_gelu(ctx0, cur); - - // projection - // [ 768, 3072] - model.layers[il].c_mlp_proj_w - // [ 768, 1] - model.layers[il].c_mlp_proj_b - // [3072, N] - cur (in) - // [ 768, N] - cur (out) - // - // cur = proj_w*cur + proj_b - // [768, N] - cur = ggml_mul_mat(ctx0, - model.layers[il].c_mlp_proj_w, - cur); - - cur = ggml_add(ctx0, - ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur), - cur); - } - - // input for next layer - inpL = ggml_add(ctx0, cur, inpFF); - } - - // norm - { - // [ 768, N] - inpL = ggml_norm(ctx0, inpL, 1e-5f); - - // inpL = ln_f_g*inpL + ln_f_b - // [ 768, N] - inpL = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, model.ln_f_g, inpL), - inpL), - ggml_repeat(ctx0, model.ln_f_b, inpL)); - } - - // inpL = WTE * inpL - // [ 768, 50257] - model.lm_head - // [ 768, N] - inpL - inpL = ggml_mul_mat(ctx0, model.lm_head, inpL); - - // logits -> probs - //inpL = ggml_soft_max(ctx0, inpL); - - // run the computation - ggml_build_forward_expand (&gf, inpL); - ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); - - //if (n_past%100 == 0) { - // ggml_graph_print (&gf); - // ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot"); - //} - - //embd_w.resize(n_vocab*N); - //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N); - - // return result just for the last token - embd_w.resize(n_vocab); - memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab); - - if (mem_per_token == 0) { - mem_per_token = ggml_used_mem(ctx0)/N; - } - //printf("used_mem = %zu\n", ggml_used_mem(ctx0)); - - ggml_free(ctx0); - - return true; -} - -/////////////////////////////// GPT-2 END //////////////////////////////// - -constexpr int N_THREAD = 8; - -struct gpt2_context { - std::string prompt_base = R"(Hello, how are you? -I'm fine, thanks. How are you? -Thanks, I'm fine too. What are you doing? -I'm just sitting here. -It's a lovely day, isn't it? -Yes, it is. I love the weather this time of year. -I wish it would rain a little bit. -Me too. -)"; - - std::mt19937 rng; - - gpt_vocab vocab; - gpt2_model model; - - int32_t n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency()); - - // sampling parameters - int32_t top_k = 5; - float top_p = 0.9f; - float temp = 1.0f; -}; struct gpt2_context * gpt2_init(const char * path_model) { gpt2_context * ctx = new gpt2_context; @@ -739,6 +389,8 @@ struct gpt2_context * gpt2_init(const char * path_model) { return ctx; } + + void gpt2_free(struct gpt2_context * ctx) { delete ctx; } @@ -755,7 +407,357 @@ std::vector gpt2_tokenize(const gpt2_context * ctx, const char * return ::gpt_tokenize(ctx->vocab, text); } -std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens) { + +// build the computation graph +struct ggml_cgraph * gpt2_graph( + const gpt2_model & model, + const int n_past, + const int n_tokens) { + const int N = n_tokens; + + const auto & hparams = model.hparams; + + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_ctx = hparams.n_ctx; + const int n_head = hparams.n_head; + + // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data + static size_t buf_size = ggml_tensor_overhead()*GPT2_MAX_NODES + ggml_graph_overhead_custom(GPT2_MAX_NODES, false); + static std::vector buf(buf_size); + + struct ggml_init_params params = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf.data(), + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() + }; + + struct ggml_context * ctx = ggml_init(params); + + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, GPT2_MAX_NODES, false); + + struct ggml_tensor * embd = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N); + // at this point, the tensor data is not allocated yet and cannot be set + // we will find the tensor after the graph is allocated by its name, and set the data then + ggml_set_name(embd, "embd"); + // setting a tensor as an input will ensure that it is allocated at the beginning of the graph + // this is important to ensure that the input tensors are not overwritten before they are used + ggml_set_input(embd); + + struct ggml_tensor * position = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N); + ggml_set_name(position, "position"); + ggml_set_input(position); + + // wte + wpe + struct ggml_tensor * inpL = + ggml_add(ctx, + ggml_get_rows(ctx, model.wte, embd), + ggml_get_rows(ctx, model.wpe, position)); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * cur; + + // norm + { + // [ 768, N] + cur = ggml_norm(ctx, inpL, hparams.eps); + + // cur = ln_1_g*cur + ln_1_b + // [ 768, N] + cur = ggml_add(ctx, + ggml_mul(ctx, + cur, + model.layers[il].ln_1_g), + model.layers[il].ln_1_b); + } + + // attn + // [2304, 768] - model.layers[il].c_attn_attn_w + // [2304, 1] - model.layers[il].c_attn_attn_b + // [ 768, N] - cur (in) + // [2304, N] - cur (out) + // + // cur = attn_w*cur + attn_b + // [2304, N] + { + cur = ggml_mul_mat(ctx, + model.layers[il].c_attn_attn_w, + cur); + + cur = ggml_add(ctx, + cur, + model.layers[il].c_attn_attn_b); + } + + // self-attention + { + struct ggml_tensor * Qcur = ggml_view_2d(ctx, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd); + struct ggml_tensor * Kcur = ggml_view_2d(ctx, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd); + struct ggml_tensor * Vcur = ggml_view_2d(ctx, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd); + + // store key and value to memory + if (N >= 1) { + struct ggml_tensor * k = ggml_view_1d(ctx, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past)); + struct ggml_tensor * v = ggml_view_1d(ctx, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past)); + + ggml_build_forward_expand(gf, ggml_cpy(ctx, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx, Vcur, v)); + } + + // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) + // [64, N, 12] + struct ggml_tensor * Q = + ggml_permute(ctx, + ggml_cont_3d(ctx, Qcur, n_embd/n_head, n_head, N), + 0, 2, 1, 3); + + // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3) + // [64, n_past + N, 12] + struct ggml_tensor * K = + ggml_permute(ctx, + ggml_reshape_3d(ctx, + ggml_view_1d(ctx, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd), + n_embd/n_head, n_head, n_past + N), + 0, 2, 1, 3); + + // GG: flash attention + //struct ggml_tensor * V = + // ggml_cpy(ctx0, + // ggml_permute(ctx0, + // ggml_reshape_3d(ctx0, + // ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd), + // n_embd/n_head, n_head, n_past + N), + // 1, 2, 0, 3), + // ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head)); + + //struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true); + + // K * Q + // [n_past + N, N, 12] + struct ggml_tensor * KQ = ggml_mul_mat(ctx, K, Q); + + // KQ_scaled = KQ / sqrt(n_embd/n_head) + // [n_past + N, N, 12] + struct ggml_tensor * KQ_scaled = + ggml_scale(ctx, + KQ, + 1.0f/sqrtf(float(n_embd)/n_head)); + + // KQ_masked = mask_past(KQ_scaled) + // [n_past + N, N, 12] + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx, KQ_scaled, n_past); + + // KQ = soft_max(KQ_masked) + // [n_past + N, N, 12] + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx, KQ_masked); + + // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() + // [n_past + N, 64, 12] + struct ggml_tensor * V_trans = + ggml_cont_3d(ctx, + ggml_permute(ctx, + ggml_reshape_3d(ctx, + ggml_view_1d(ctx, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd), + n_embd/n_head, n_head, n_past + N), + 1, 2, 0, 3), + n_past + N, n_embd/n_head, n_head); + + // KQV = transpose(V) * KQ_soft_max + // [64, N, 12] + struct ggml_tensor * KQV = ggml_mul_mat(ctx, V_trans, KQ_soft_max); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + // [64, 12, N] + struct ggml_tensor * KQV_merged = ggml_permute(ctx, KQV, 0, 2, 1, 3); + + // cur = KQV_merged.contiguous().view(n_embd, N) + // [768, N] + cur = ggml_cont_2d(ctx, KQV_merged, n_embd, N); + } + + // projection + // [ 768, 768] - model.layers[il].c_attn_proj_w + // [ 768, 1] - model.layers[il].c_attn_proj_b + // [ 768, N] - cur (in) + // [ 768, N] - cur (out) + // + // cur = proj_w*cur + proj_b + // [768, N] + { + cur = ggml_mul_mat(ctx, + model.layers[il].c_attn_proj_w, + cur); + + cur = ggml_add(ctx, + cur, + model.layers[il].c_attn_proj_b); + } + + // add the input + cur = ggml_add(ctx, cur, inpL); + + struct ggml_tensor * inpFF = cur; + + // feed-forward network + { + // norm + { + cur = ggml_norm(ctx, inpFF, hparams.eps); + + // cur = ln_2_g*cur + ln_2_b + // [ 768, N] + cur = ggml_add(ctx, + ggml_mul(ctx, + cur, + model.layers[il].ln_2_g), + model.layers[il].ln_2_b); + } + + // fully connected + // [3072, 768] - model.layers[il].c_mlp_fc_w + // [3072, 1] - model.layers[il].c_mlp_fc_b + // [ 768, N] - cur (in) + // [3072, N] - cur (out) + // + // cur = fc_w*cur + fc_b + // [3072, N] + cur = ggml_mul_mat(ctx, + model.layers[il].c_mlp_fc_w, + cur); + + cur = ggml_add(ctx, + cur, + model.layers[il].c_mlp_fc_b); + + // GELU activation + // [3072, N] + cur = ggml_gelu(ctx, cur); + + // projection + // [ 768, 3072] - model.layers[il].c_mlp_proj_w + // [ 768, 1] - model.layers[il].c_mlp_proj_b + // [3072, N] - cur (in) + // [ 768, N] - cur (out) + // + // cur = proj_w*cur + proj_b + // [768, N] + cur = ggml_mul_mat(ctx, + model.layers[il].c_mlp_proj_w, + cur); + + cur = ggml_add(ctx, + cur, + model.layers[il].c_mlp_proj_b); + } + + // input for next layer + inpL = ggml_add(ctx, cur, inpFF); + } + + // norm + { + // [ 768, N] + inpL = ggml_norm(ctx, inpL, hparams.eps); + + // inpL = ln_f_g*inpL + ln_f_b + // [ 768, N] + inpL = ggml_add(ctx, + ggml_mul(ctx, + inpL, + model.ln_f_g), + model.ln_f_b); + } + + // inpL = WTE * inpL + // [ 768, 50257] - model.lm_head + // [ 768, N] - inpL + inpL = ggml_mul_mat(ctx, model.lm_head, inpL); + ggml_set_name(inpL, "logits"); + // setting a tensor as the output will ensure that it is not overwritten by subsequent operations + ggml_set_output(inpL); + + // logits -> probs + //inpL = ggml_soft_max(ctx0, inpL); + + ggml_build_forward_expand(gf, inpL); + + ggml_free(ctx); + + return gf; +} + + +// evaluate the transformer +// +// - model: the model +// - allocr: ggml_gallocr to use to allocate the compute buffer +// - n_threads: number of threads to use +// - n_past: the context size so far +// - embd_inp: the embeddings of the tokens in the context +// - embd_w: the predicted logits for the next token +// +bool gpt2_eval( + const gpt2_model & model, + ggml_gallocr_t allocr, + const int n_threads, + const int n_past, + const std::vector & embd_inp, + std::vector & embd_w) { + const int N = embd_inp.size(); + + const auto & hparams = model.hparams; + + const int n_vocab = hparams.n_vocab; + + struct ggml_cgraph * gf = gpt2_graph(model, n_past, embd_inp.size()); + + // allocate the graph tensors + ggml_gallocr_alloc_graph(allocr, gf); + + // set the graph inputs + struct ggml_tensor * embd = ggml_graph_get_tensor(gf, "embd"); + ggml_backend_tensor_set(embd, embd_inp.data(), 0, N*ggml_element_size(embd)); + + struct ggml_tensor * position = ggml_graph_get_tensor(gf, "position"); + for (int i = 0; i < N; ++i) { + int32_t v = n_past + i; + ggml_backend_tensor_set(position, &v, i*sizeof(int32_t), sizeof(v)); + } + + // set backend options + if (ggml_backend_is_cpu(model.backend)) { + ggml_backend_cpu_set_n_threads(model.backend, n_threads); + } + +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(model.backend)) { + ggml_backend_metal_set_n_cb(model.backend, n_threads); + } +#endif + + // run the computation + ggml_backend_graph_compute(model.backend, gf); + + //if (n_past%100 == 0) { + // ggml_graph_print (&gf); + // ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot"); + //} + + // get the graph outputs + struct ggml_tensor * logits = ggml_graph_get_tensor(gf, "logits"); + + //embd_w.resize(n_vocab*N); + //ggml_backend_tensor_get(logits, embd_w.data(), 0, sizeof(float)*n_vocab*N); + + // return result just for the last token + embd_w.resize(n_vocab); + ggml_backend_tensor_get(logits, embd_w.data(), (n_vocab*(N-1))*sizeof(float), sizeof(float)*n_vocab); + + return true; +} + + +std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens, ggml_gallocr* allocr) { int n_past = 0; std::vector embd_w; @@ -767,14 +769,12 @@ std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens) std::vector embd = embd_inp; - size_t mem_per_token = 3000000; - std::string result; for (int i = embd.size(); i < (int) embd_inp.size() + n_predict; i++) { - // predict + if (!embd.empty()) { - if (!gpt2_eval(ctx->model, ctx->n_threads, n_past, embd, embd_w, mem_per_token)) { + if (!gpt2_eval(ctx->model, allocr, ctx->n_threads, n_past, embd, embd_w)) { printf("gpt-2: failed to generate text\n"); return ""; } diff --git a/examples/talk/gpt-2.h b/examples/talk/gpt-2.h index 756fbfa9..dc5cbdbf 100644 --- a/examples/talk/gpt-2.h +++ b/examples/talk/gpt-2.h @@ -7,10 +7,101 @@ #include #include #include +#include -struct gpt2_context; -struct gpt2_context * gpt2_init(const char * path_model); +struct gpt2_layer { + // normalization + struct ggml_tensor * ln_1_g; + struct ggml_tensor * ln_1_b; + + struct ggml_tensor * ln_2_g; + struct ggml_tensor * ln_2_b; + + // attention + struct ggml_tensor * c_attn_attn_w; + struct ggml_tensor * c_attn_attn_b; + + struct ggml_tensor * c_attn_proj_w; + struct ggml_tensor * c_attn_proj_b; + + // mlp + struct ggml_tensor * c_mlp_fc_w; + struct ggml_tensor * c_mlp_fc_b; + + struct ggml_tensor * c_mlp_proj_w; + struct ggml_tensor * c_mlp_proj_b; +}; + +constexpr int N_THREAD = 8; + +// default hparams (GPT-2 117M) +struct gpt2_hparams { + int32_t n_vocab = 50257; + int32_t n_ctx = 1024; + int32_t n_embd = 768; + int32_t n_head = 12; + int32_t n_layer = 12; + int32_t ftype = 1; + float eps = 1e-5f; +}; + +struct gpt2_model { + gpt2_hparams hparams; + + // normalization + struct ggml_tensor * ln_f_g; + struct ggml_tensor * ln_f_b; + + struct ggml_tensor * wte; // position embedding + struct ggml_tensor * wpe; // token embedding + struct ggml_tensor * lm_head; // language model head + + std::vector layers; + + // key + value memory + struct ggml_tensor * memory_k; + struct ggml_tensor * memory_v; + + // + struct ggml_context * ctx_w; + struct ggml_context * ctx_kv; + + ggml_backend* backend = NULL; + + ggml_backend_buffer * buffer_w; + ggml_backend_buffer * buffer_kv; + + std::map tensors; +}; + + +struct gpt2_context { + std::string prompt_base = R"(Hello, how are you? +I'm fine, thanks. How are you? +Thanks, I'm fine too. What are you doing? +I'm just sitting here. +It's a lovely day, isn't it? +Yes, it is. I love the weather this time of year. +I wish it would rain a little bit. +Me too. +)"; + + std::mt19937 rng; + + gpt_vocab vocab; + gpt2_model model; + + int32_t n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency()); + + // sampling parameters + int32_t top_k = 5; + float top_p = 0.9f; + float temp = 1.0f; +}; + +bool gpt2_model_load(const std::string &fname, gpt2_model &model, gpt_vocab &vocab, int n_ctx, int n_gpu_layers); +struct gpt2_context *gpt2_init(const char *path_model); void gpt2_free(struct gpt2_context * ctx); const char * gpt2_get_prompt(struct gpt2_context * ctx); @@ -18,4 +109,17 @@ void gpt2_set_prompt(struct gpt2_context * ctx, const char * prompt); std::vector gpt2_tokenize(const gpt2_context * ctx, const char * text); -std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens); + +std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens, ggml_gallocr* allocr); +struct ggml_cgraph *gpt2_graph( + const gpt2_model &model, + const int n_past, + const int n_tokens); + +bool gpt2_eval( + const gpt2_model &model, + ggml_gallocr_t allocr, + const int n_threads, + const int n_past, + const std::vector &embd_inp, + std::vector &embd_w); \ No newline at end of file diff --git a/examples/talk/speak b/examples/talk/speak old mode 100644 new mode 100755 diff --git a/examples/talk/talk.cpp b/examples/talk/talk.cpp index 428f38b7..c116617c 100644 --- a/examples/talk/talk.cpp +++ b/examples/talk/talk.cpp @@ -2,6 +2,9 @@ // #include "common-sdl.h" + +#include "ggml-backend.h" +#include "ggml-alloc.h" #include "common.h" #include "whisper.h" #include "gpt-2.h" @@ -221,10 +224,12 @@ int main(int argc, char ** argv) { // init audio audio_async audio(30*1000); - if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) { + if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) + { fprintf(stderr, "%s: audio.init() failed!\n", __func__); return 1; } + printf("audio init successful....\n"); audio.resume(); @@ -239,6 +244,12 @@ int main(int argc, char ** argv) { std::vector pcmf32_prompt; gpt2_set_prompt(ctx_gpt, ""); + ggml_gallocr_t allocr = NULL; + // allocate the compute buffer + { + // create a graph allocator with the backend's default buffer type + allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(ctx_gpt->model.backend)); + } const int voice_id = rand()%6; @@ -319,7 +330,7 @@ int main(int argc, char ** argv) { std::string prompt = ::replace(::replace(k_prompt, "{0}", params.person), "{1}", prompt_base); - text_to_speak = gpt2_gen_text(ctx_gpt, prompt.c_str(), params.max_tokens); + text_to_speak = gpt2_gen_text(ctx_gpt, prompt.c_str(), params.max_tokens, allocr); //text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), ""); text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of('\n'));