mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-18 20:27:53 +00:00
talk-llama : sync llama.cpp
This commit is contained in:
parent
bd41733db2
commit
e72e4158de
@ -11,6 +11,10 @@
|
||||
# include "ggml-cuda.h"
|
||||
#elif defined(GGML_USE_CLBLAST)
|
||||
# include "ggml-opencl.h"
|
||||
#elif defined(GGML_USE_VULKAN)
|
||||
# include "ggml-vulkan.h"
|
||||
#elif defined(GGML_USE_SYCL)
|
||||
# include "ggml-sycl.h"
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_METAL
|
||||
@ -52,6 +56,7 @@
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
#include <cfloat>
|
||||
#include <cinttypes>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
@ -196,6 +201,7 @@ enum llm_arch {
|
||||
LLM_ARCH_PHI2,
|
||||
LLM_ARCH_PLAMO,
|
||||
LLM_ARCH_CODESHELL,
|
||||
LLM_ARCH_ORION,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
@ -217,6 +223,7 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_PHI2, "phi2" },
|
||||
{ LLM_ARCH_PLAMO, "plamo" },
|
||||
{ LLM_ARCH_CODESHELL, "codeshell" },
|
||||
{ LLM_ARCH_ORION, "orion" },
|
||||
};
|
||||
|
||||
enum llm_kv {
|
||||
@ -641,6 +648,25 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_ORION,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
@ -1256,8 +1282,14 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer
|
||||
if (host_buffer) {
|
||||
buft = ggml_backend_cuda_host_buffer_type();
|
||||
}
|
||||
#elif defined(GGML_USE_SYCL)
|
||||
buft = ggml_backend_sycl_host_buffer_type();
|
||||
#elif defined(GGML_USE_CPU_HBM)
|
||||
buft = ggml_backend_cpu_hbm_buffer_type();
|
||||
#elif defined(GGML_USE_VULKAN)
|
||||
if (host_buffer) {
|
||||
buft = ggml_backend_vk_host_buffer_type();
|
||||
}
|
||||
#endif
|
||||
|
||||
if (buft == nullptr) {
|
||||
@ -1275,6 +1307,10 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(int gpu) {
|
||||
buft = ggml_backend_metal_buffer_type();
|
||||
#elif defined(GGML_USE_CUBLAS)
|
||||
buft = ggml_backend_cuda_buffer_type(gpu);
|
||||
#elif defined(GGML_USE_VULKAN)
|
||||
buft = ggml_backend_vk_buffer_type();
|
||||
#elif defined(GGML_USE_SYCL)
|
||||
buft = ggml_backend_sycl_buffer_type(gpu);
|
||||
#elif defined(GGML_USE_CLBLAST)
|
||||
buft = ggml_backend_opencl_buffer_type();
|
||||
#endif
|
||||
@ -1332,6 +1368,7 @@ enum e_model {
|
||||
MODEL_7B,
|
||||
MODEL_8B,
|
||||
MODEL_13B,
|
||||
MODEL_14B,
|
||||
MODEL_15B,
|
||||
MODEL_30B,
|
||||
MODEL_34B,
|
||||
@ -2683,6 +2720,7 @@ static const char * llama_model_type_name(e_model type) {
|
||||
case MODEL_7B: return "7B";
|
||||
case MODEL_8B: return "8B";
|
||||
case MODEL_13B: return "13B";
|
||||
case MODEL_14B: return "14B";
|
||||
case MODEL_15B: return "15B";
|
||||
case MODEL_30B: return "30B";
|
||||
case MODEL_34B: return "34B";
|
||||
@ -2950,7 +2988,15 @@ static void llm_load_hparams(
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_ORION:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 40: model.type = e_model::MODEL_14B; break;
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
default: (void)0;
|
||||
}
|
||||
|
||||
@ -3933,6 +3979,38 @@ static bool llm_load_tensors(
|
||||
layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff});
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_ORION:
|
||||
{
|
||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||
{
|
||||
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||
model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
|
||||
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
|
||||
}
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||
ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||
layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd});
|
||||
|
||||
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
|
||||
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
|
||||
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
|
||||
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
|
||||
|
||||
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||
layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd});
|
||||
|
||||
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
|
||||
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
|
||||
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
|
||||
}
|
||||
} break;
|
||||
|
||||
|
||||
default:
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
@ -4563,6 +4641,126 @@ struct llm_build_context {
|
||||
ctx0 = nullptr;
|
||||
}
|
||||
}
|
||||
struct ggml_cgraph * build_orion() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
||||
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
|
||||
cb(inpL, "inp_embd", -1);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
|
||||
cb(inp_pos, "inp_pos", -1);
|
||||
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
|
||||
cb(KQ_mask, "KQ_mask", -1);
|
||||
|
||||
// shift the entire K-cache if needed
|
||||
if (do_rope_shift) {
|
||||
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
|
||||
}
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||
model.layers[il].attn_norm, model.layers[il].attn_norm_b,
|
||||
LLM_NORM, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
// compute Q and K and RoPE them
|
||||
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
// if (model.layers[il].bq) {
|
||||
// Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||
// cb(Qcur, "Qcur", il);
|
||||
// }
|
||||
|
||||
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
// if (model.layers[il].bk) {
|
||||
// Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||
// cb(Kcur, "Kcur", il);
|
||||
// }
|
||||
|
||||
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
// if (model.layers[il].bv) {
|
||||
// Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||
// cb(Vcur, "Vcur", il);
|
||||
// }
|
||||
|
||||
Qcur = ggml_rope_custom(
|
||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
|
||||
hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
Kcur = ggml_rope_custom(
|
||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
|
||||
hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
||||
model.layers[il].wo, NULL,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// feed-forward network
|
||||
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||
model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
|
||||
LLM_NORM, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = llm_build_ffn(ctx0, cur,
|
||||
model.layers[il].ffn_up, NULL,
|
||||
model.layers[il].ffn_gate, NULL,
|
||||
model.layers[il].ffn_down, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.output_norm, model.output_norm_b,
|
||||
LLM_NORM, cb, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
// lm_head
|
||||
cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
|
||||
|
||||
struct ggml_cgraph * build_llama() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
||||
@ -6520,6 +6718,10 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
{
|
||||
result = llm.build_codeshell();
|
||||
} break;
|
||||
case LLM_ARCH_ORION:
|
||||
{
|
||||
result = llm.build_orion();
|
||||
} break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
@ -6652,7 +6854,7 @@ static int llama_decode_internal(
|
||||
}
|
||||
|
||||
const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 1;
|
||||
if (ggml_cpu_has_cublas() && fully_offloaded) {
|
||||
if ((ggml_cpu_has_cublas() || ggml_cpu_has_vulkan()) && fully_offloaded) {
|
||||
n_threads = 1;
|
||||
}
|
||||
|
||||
@ -7946,6 +8148,11 @@ void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * c
|
||||
}
|
||||
|
||||
void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
|
||||
// TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
|
||||
// if (k >= (int32_t)candidates->size) {
|
||||
// return;
|
||||
// }
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
k = std::max(k, (int) min_keep);
|
||||
@ -8054,21 +8261,56 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
|
||||
return;
|
||||
}
|
||||
|
||||
llama_sample_softmax(ctx, candidates);
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
float scale = candidates->data[0].p; // scale by max prob
|
||||
size_t i = 1; // first token always matches
|
||||
bool min_p_applied = false;
|
||||
|
||||
for (; i < candidates->size; ++i) {
|
||||
if (candidates->data[i].p < p * scale && i >= min_keep) {
|
||||
break; // prob too small
|
||||
// if the candidates aren't sorted, try the unsorted implementation first
|
||||
if (!candidates->sorted) {
|
||||
std::vector<llama_token_data> filtered_tokens;
|
||||
|
||||
float max_logit = -FLT_MAX;
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
max_logit = std::max(max_logit, candidates->data[i].logit);
|
||||
}
|
||||
const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
|
||||
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
if (candidates->data[i].logit >= min_logit) {
|
||||
filtered_tokens.push_back(candidates->data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// if we have enough values the operation was a success
|
||||
if (filtered_tokens.size() >= min_keep) {
|
||||
memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
|
||||
candidates->size = filtered_tokens.size();
|
||||
min_p_applied = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Resize the output vector to keep only the matching tokens
|
||||
candidates->size = i;
|
||||
// if the candidates are sorted or the unsorted implementation failed, use this implementation
|
||||
if (!min_p_applied) {
|
||||
// Sort the logits in descending order
|
||||
if (!candidates->sorted) {
|
||||
std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit > b.logit;
|
||||
});
|
||||
candidates->sorted = true;
|
||||
}
|
||||
|
||||
const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
|
||||
size_t i = 1; // first token always matches
|
||||
|
||||
for (; i < candidates->size; ++i) {
|
||||
if (candidates->data[i].logit < min_logit && i >= min_keep) {
|
||||
break; // prob too small
|
||||
}
|
||||
}
|
||||
|
||||
// Resize the output vector to keep only the matching tokens
|
||||
candidates->size = i;
|
||||
}
|
||||
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
@ -9997,6 +10239,26 @@ struct llama_context * llama_new_context_with_model(
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(GGML_USE_VULKAN)
|
||||
if (model->n_gpu_layers > 0) {
|
||||
ggml_backend_t backend = ggml_backend_vk_init();
|
||||
if (backend == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize Vulkan backend\n", __func__);
|
||||
llama_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
ctx->backends.push_back(backend);
|
||||
}
|
||||
#elif defined(GGML_USE_SYCL)
|
||||
if (model->n_gpu_layers > 0) {
|
||||
ggml_backend_t backend = ggml_backend_sycl_init(model->main_gpu);
|
||||
if (backend == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d backend\n", __func__, model->main_gpu);
|
||||
llama_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
ctx->backends.push_back(backend);
|
||||
}
|
||||
#endif
|
||||
ctx->backend_cpu = ggml_backend_cpu_init();
|
||||
if (ctx->backend_cpu == nullptr) {
|
||||
|
@ -6,6 +6,9 @@
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
#include "ggml-cuda.h"
|
||||
#define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES
|
||||
#elif defined(GGML_USE_SYCL)
|
||||
#include "ggml-sycl.h"
|
||||
#define LLAMA_MAX_DEVICES GGML_SYCL_MAX_DEVICES
|
||||
#else
|
||||
#define LLAMA_MAX_DEVICES 1
|
||||
#endif // GGML_USE_CUBLAS
|
||||
@ -46,7 +49,7 @@
|
||||
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
||||
#define LLAMA_SESSION_VERSION 4
|
||||
|
||||
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
|
||||
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) || defined(GGML_USE_VULKAN) || defined(GGML_USE_SYCL)
|
||||
// Defined when llama.cpp is compiled with support for offloading model layers to GPU.
|
||||
#define LLAMA_SUPPORTS_GPU_OFFLOAD
|
||||
#endif
|
||||
|
Loading…
Reference in New Issue
Block a user