mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-13 04:28:07 +00:00
talk-llama : sync llama.cpp
This commit is contained in:
@ -31,7 +31,7 @@
|
||||
#endif
|
||||
|
||||
// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
|
||||
static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) {
|
||||
static int llama_model_load(const std::string & fname, std::vector<std::string> & splits, llama_model & model, llama_model_params & params) {
|
||||
// loading time will be recalculated after the first eval, so
|
||||
// we take page faults deferred by mmap() into consideration
|
||||
model.t_load_us = 0;
|
||||
@ -40,7 +40,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
|
||||
model.t_start_us = tm.t_start_us;
|
||||
|
||||
try {
|
||||
llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides);
|
||||
llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides);
|
||||
|
||||
ml.print_info();
|
||||
|
||||
@ -4642,7 +4642,7 @@ struct llm_build_context {
|
||||
0);
|
||||
cb(v_states, "v_states", il);
|
||||
|
||||
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
|
||||
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
|
||||
q_pe = ggml_rope_ext(
|
||||
ctx0, q_pe, inp_pos, rope_factors,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
@ -4651,7 +4651,7 @@ struct llm_build_context {
|
||||
cb(q_pe, "q_pe", il);
|
||||
|
||||
// shared RoPE key
|
||||
k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
|
||||
k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
|
||||
k_pe = ggml_rope_ext(
|
||||
ctx0, k_pe, inp_pos, rope_factors,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
@ -6496,7 +6496,7 @@ struct llm_build_context {
|
||||
0);
|
||||
cb(v_states, "v_states", il);
|
||||
|
||||
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
|
||||
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
|
||||
q_pe = ggml_rope_ext(
|
||||
ctx0, q_pe, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
@ -6505,7 +6505,7 @@ struct llm_build_context {
|
||||
cb(q_pe, "q_pe", il);
|
||||
|
||||
// shared RoPE key
|
||||
k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
|
||||
k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
|
||||
k_pe = ggml_rope_ext(
|
||||
ctx0, k_pe, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
@ -7215,17 +7215,30 @@ struct llm_build_context {
|
||||
struct ggml_tensor * Qcur = nullptr;
|
||||
struct ggml_tensor * Kcur = nullptr;
|
||||
struct ggml_tensor * Vcur = nullptr;
|
||||
|
||||
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
|
||||
cb(cur, "wqkv", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||
cb(cur, "bqkv", il);
|
||||
|
||||
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
||||
|
||||
if (model.type == LLM_TYPE_1_5B || model.type == LLM_TYPE_4B || model.type == LLM_TYPE_9B) {
|
||||
Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
||||
if (model.layers[il].bq) {
|
||||
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||
}
|
||||
Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
||||
if (model.layers[il].bk) {
|
||||
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||
}
|
||||
Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
||||
if (model.layers[il].bv) {
|
||||
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||
}
|
||||
} else {
|
||||
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
|
||||
cb(cur, "wqkv", il);
|
||||
if (model.layers[il].bqkv) {
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||
cb(cur, "bqkv", il);
|
||||
}
|
||||
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
||||
}
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
@ -7700,17 +7713,13 @@ struct llm_build_context {
|
||||
1
|
||||
);
|
||||
|
||||
struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
|
||||
ggml_build_forward_expand(
|
||||
gf,
|
||||
ggml_cpy(
|
||||
ctx0,
|
||||
wkv_states,
|
||||
ggml_view_1d(
|
||||
ctx0,
|
||||
kv_self.v_l[il],
|
||||
hparams.n_embd_v_s() * n_seqs,
|
||||
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
|
||||
)
|
||||
ggml_view_1d(ctx0, last_norm_att, n_embd * n_seqs, 0),
|
||||
ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
|
||||
)
|
||||
);
|
||||
|
||||
@ -8432,13 +8441,141 @@ static enum ggml_status llama_graph_compute(
|
||||
return status;
|
||||
}
|
||||
|
||||
static int llama_prepare_sbatch(
|
||||
llama_context & lctx,
|
||||
const llama_batch & batch,
|
||||
uint32_t & n_outputs) {
|
||||
const auto & model = lctx.model;
|
||||
const auto & hparams = model.hparams;
|
||||
const auto & cparams = lctx.cparams;
|
||||
|
||||
const uint32_t n_tokens_all = batch.n_tokens;
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
|
||||
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
||||
|
||||
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
||||
if (batch.token) {
|
||||
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
||||
if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) {
|
||||
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
GGML_ASSERT(n_tokens_all <= cparams.n_batch);
|
||||
GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
|
||||
|
||||
lctx.n_queued_tokens += n_tokens_all;
|
||||
lctx.embd_seq.clear();
|
||||
|
||||
// count outputs
|
||||
if (batch.logits && !embd_pooled) {
|
||||
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
||||
n_outputs += batch.logits[i] != 0;
|
||||
}
|
||||
} else if (lctx.logits_all || embd_pooled) {
|
||||
n_outputs = n_tokens_all;
|
||||
} else {
|
||||
// keep last output only
|
||||
n_outputs = 1;
|
||||
}
|
||||
|
||||
lctx.sbatch.from_batch(batch, n_embd,
|
||||
/* simple_split */ !lctx.kv_self.recurrent,
|
||||
/* logits_all */ n_outputs == n_tokens_all);
|
||||
|
||||
// reserve output buffer
|
||||
if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
|
||||
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
|
||||
return -2;
|
||||
};
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int llama_prepare_ubatch(
|
||||
llama_context & lctx,
|
||||
llama_kv_slot_restorer & kv_slot_restorer,
|
||||
llama_ubatch & ubatch,
|
||||
const uint32_t n_outputs,
|
||||
const uint32_t n_tokens_all) {
|
||||
GGML_ASSERT(lctx.sbatch.n_tokens > 0);
|
||||
|
||||
auto & kv_self = lctx.kv_self;
|
||||
const auto & cparams = lctx.cparams;
|
||||
const auto & hparams = lctx.model.hparams;
|
||||
|
||||
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
|
||||
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
||||
|
||||
if (lctx.kv_self.recurrent) {
|
||||
if (embd_pooled) {
|
||||
// Pooled embeddings cannot be split across ubatches (yet)
|
||||
ubatch = lctx.sbatch.split_seq(cparams.n_ubatch);
|
||||
} else {
|
||||
// recurrent model architectures are easier to implement
|
||||
// with equal-length sequences
|
||||
ubatch = lctx.sbatch.split_equal(cparams.n_ubatch);
|
||||
}
|
||||
} else {
|
||||
ubatch = lctx.sbatch.split_simple(cparams.n_ubatch);
|
||||
}
|
||||
|
||||
// count the outputs in this u_batch
|
||||
{
|
||||
int32_t n_outputs_new = 0;
|
||||
|
||||
if (n_outputs == n_tokens_all) {
|
||||
n_outputs_new = ubatch.n_tokens;
|
||||
} else {
|
||||
GGML_ASSERT(ubatch.output);
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
|
||||
n_outputs_new += int32_t(ubatch.output[i] != 0);
|
||||
}
|
||||
}
|
||||
|
||||
// needs to happen before the graph is built
|
||||
lctx.n_outputs = n_outputs_new;
|
||||
}
|
||||
|
||||
// non-causal masks do not use the KV cache
|
||||
if (hparams.causal_attn) {
|
||||
llama_kv_cache_update(&lctx);
|
||||
|
||||
// if we have enough unused cells before the current head ->
|
||||
// better to start searching from the beginning of the cache, hoping to fill it
|
||||
if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
|
||||
kv_self.head = 0;
|
||||
}
|
||||
|
||||
const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
|
||||
if (!slot) {
|
||||
return 1;
|
||||
}
|
||||
kv_slot_restorer.save(slot);
|
||||
|
||||
if (!kv_self.recurrent) {
|
||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||
// after enough generations, the benefit from this heuristic disappears
|
||||
// if we start defragmenting the cache, the benefit from this will be more important
|
||||
const uint32_t pad = llama_kv_cache_get_padding(cparams);
|
||||
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
|
||||
//kv_self.n = llama_kv_cache_cell_max(kv_self);
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// decode a batch of tokens by evaluating the transformer
|
||||
// in case of unsuccessful decoding (error or warning),
|
||||
// the kv_cache state will be returned to its original state
|
||||
// (for non-recurrent models) or cleaned (for recurrent models)
|
||||
//
|
||||
// - lctx: llama context
|
||||
// - batch: batch to evaluate
|
||||
// - inp_batch: batch to evaluate
|
||||
//
|
||||
// return 0 on success
|
||||
// return positive int on warning
|
||||
@ -8455,37 +8592,18 @@ static int llama_decode_impl(
|
||||
return -1;
|
||||
}
|
||||
|
||||
// temporary allocate memory for the input batch if needed
|
||||
// temporarily allocate memory for the input batch if needed
|
||||
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
|
||||
|
||||
const llama_batch & batch = batch_allocr.batch;
|
||||
const uint32_t n_tokens_all = batch.n_tokens;
|
||||
|
||||
const auto & model = lctx.model;
|
||||
const auto & vocab = model.vocab;
|
||||
const auto & hparams = model.hparams;
|
||||
const auto & cparams = lctx.cparams;
|
||||
|
||||
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
||||
|
||||
if (batch.token) {
|
||||
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
||||
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
|
||||
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
GGML_ASSERT(n_tokens_all <= cparams.n_batch);
|
||||
|
||||
GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
|
||||
|
||||
if (lctx.t_compute_start_us == 0) {
|
||||
lctx.t_compute_start_us = ggml_time_us();
|
||||
}
|
||||
lctx.n_queued_tokens += n_tokens_all;
|
||||
|
||||
auto & kv_self = lctx.kv_self;
|
||||
llama_kv_slot_restorer kv_slot_restorer(kv_self);
|
||||
|
||||
@ -8495,99 +8613,27 @@ static int llama_decode_impl(
|
||||
uint32_t n_outputs = 0;
|
||||
uint32_t n_outputs_prev = 0;
|
||||
|
||||
const auto n_ubatch = cparams.n_ubatch;
|
||||
|
||||
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
|
||||
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
||||
|
||||
lctx.embd_seq.clear();
|
||||
|
||||
// count outputs
|
||||
if (batch.logits && !embd_pooled) {
|
||||
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
||||
n_outputs += batch.logits[i] != 0;
|
||||
{
|
||||
const int ret = llama_prepare_sbatch(lctx, batch, n_outputs);
|
||||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
} else if (lctx.logits_all || embd_pooled) {
|
||||
n_outputs = n_tokens_all;
|
||||
} else {
|
||||
// keep last output only
|
||||
n_outputs = 1;
|
||||
}
|
||||
|
||||
lctx.sbatch.from_batch(batch, n_embd,
|
||||
/* simple_split */ !kv_self.recurrent,
|
||||
/* logits_all */ n_outputs == n_tokens_all);
|
||||
|
||||
// reserve output buffer
|
||||
if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
|
||||
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
|
||||
return -2;
|
||||
};
|
||||
|
||||
while (lctx.sbatch.n_tokens > 0) {
|
||||
llama_ubatch ubatch;
|
||||
if (kv_self.recurrent) {
|
||||
if (embd_pooled) {
|
||||
// Pooled embeddings cannot be split across ubatches (yet)
|
||||
ubatch = lctx.sbatch.split_seq(n_ubatch);
|
||||
} else {
|
||||
// recurrent model architectures are easier to implement
|
||||
// with equal-length sequences
|
||||
ubatch = lctx.sbatch.split_equal(n_ubatch);
|
||||
}
|
||||
} else {
|
||||
ubatch = lctx.sbatch.split_simple(n_ubatch);
|
||||
}
|
||||
const uint32_t n_tokens = ubatch.n_tokens;
|
||||
|
||||
// count the outputs in this u_batch
|
||||
{
|
||||
int32_t n_outputs_new = 0;
|
||||
|
||||
if (n_outputs == n_tokens_all) {
|
||||
n_outputs_new = n_tokens;
|
||||
} else {
|
||||
GGML_ASSERT(ubatch.output);
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
|
||||
}
|
||||
const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
|
||||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
// needs to happen before the graph is built
|
||||
lctx.n_outputs = n_outputs_new;
|
||||
}
|
||||
|
||||
int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
|
||||
ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
|
||||
const int n_threads = ubatch.n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
|
||||
ggml_threadpool_t threadpool = ubatch.n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
|
||||
|
||||
GGML_ASSERT(n_threads > 0);
|
||||
|
||||
// non-causal masks do not use the KV cache
|
||||
if (hparams.causal_attn) {
|
||||
llama_kv_cache_update(&lctx);
|
||||
|
||||
// if we have enough unused cells before the current head ->
|
||||
// better to start searching from the beginning of the cache, hoping to fill it
|
||||
if (kv_self.head > kv_self.used + 2*n_tokens) {
|
||||
kv_self.head = 0;
|
||||
}
|
||||
|
||||
const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
|
||||
if (!slot) {
|
||||
return 1;
|
||||
}
|
||||
kv_slot_restorer.save(slot);
|
||||
|
||||
if (!kv_self.recurrent) {
|
||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||
// after enough generations, the benefit from this heuristic disappears
|
||||
// if we start defragmenting the cache, the benefit from this will be more important
|
||||
const uint32_t pad = llama_kv_cache_get_padding(cparams);
|
||||
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
|
||||
//kv_self.n = llama_kv_cache_cell_max(kv_self);
|
||||
}
|
||||
}
|
||||
|
||||
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
|
||||
|
||||
ggml_backend_sched_reset(lctx.sched.get());
|
||||
@ -8640,7 +8686,7 @@ static int llama_decode_impl(
|
||||
|
||||
// update the kv ring buffer
|
||||
{
|
||||
kv_self.head += n_tokens;
|
||||
kv_self.head += ubatch.n_tokens;
|
||||
|
||||
// Ensure kv cache head points to a valid index.
|
||||
if (kv_self.head >= kv_self.size) {
|
||||
@ -9374,14 +9420,9 @@ int64_t llama_time_us(void) {
|
||||
return ggml_time_us();
|
||||
}
|
||||
|
||||
struct llama_model * llama_load_model_from_file(
|
||||
const char * path_model,
|
||||
struct llama_model_params params) {
|
||||
return llama_model_load_from_file(path_model, params);
|
||||
}
|
||||
|
||||
struct llama_model * llama_model_load_from_file(
|
||||
const char * path_model,
|
||||
static struct llama_model * llama_model_load_from_file_impl(
|
||||
const std::string & path_model,
|
||||
std::vector<std::string> & splits,
|
||||
struct llama_model_params params) {
|
||||
ggml_time_init();
|
||||
|
||||
@ -9404,53 +9445,13 @@ struct llama_model * llama_model_load_from_file(
|
||||
};
|
||||
}
|
||||
|
||||
if (params.rpc_servers != nullptr && params.rpc_servers[0] != '\0') {
|
||||
// split the servers set them into model->rpc_servers
|
||||
std::string servers(params.rpc_servers);
|
||||
size_t pos = 0;
|
||||
while ((pos = servers.find(',')) != std::string::npos) {
|
||||
std::string server = servers.substr(0, pos);
|
||||
model->rpc_servers.push_back(server);
|
||||
servers.erase(0, pos + 1);
|
||||
}
|
||||
model->rpc_servers.push_back(servers);
|
||||
}
|
||||
|
||||
// add RPC devices
|
||||
if (!model->rpc_servers.empty()) {
|
||||
ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
|
||||
if (!rpc_reg) {
|
||||
LLAMA_LOG_ERROR("%s: failed to find RPC backend\n", __func__);
|
||||
llama_model_free(model);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
|
||||
ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
|
||||
if (!ggml_backend_rpc_add_device_fn) {
|
||||
LLAMA_LOG_ERROR("%s: failed to find RPC device add function\n", __func__);
|
||||
llama_model_free(model);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (const std::string & server : model->rpc_servers) {
|
||||
ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
|
||||
if (dev) {
|
||||
model->devices.push_back(dev);
|
||||
} else {
|
||||
LLAMA_LOG_ERROR("%s: failed to add RPC device for server '%s'\n", __func__, server.c_str());
|
||||
llama_model_free(model);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// create list of devices to use with this model
|
||||
if (params.devices) {
|
||||
for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) {
|
||||
model->devices.push_back(*dev);
|
||||
}
|
||||
} else {
|
||||
std::vector<ggml_backend_dev_t> rpc_servers;
|
||||
// use all available devices
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
||||
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
||||
@ -9461,10 +9462,19 @@ struct llama_model * llama_model_load_from_file(
|
||||
break;
|
||||
|
||||
case GGML_BACKEND_DEVICE_TYPE_GPU:
|
||||
model->devices.push_back(dev);
|
||||
ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
|
||||
if (ggml_backend_reg_name(reg) == std::string("RPC")) {
|
||||
rpc_servers.push_back(dev);
|
||||
} else {
|
||||
model->devices.push_back(dev);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
// add RPC servers at the front of the list
|
||||
if (!rpc_servers.empty()) {
|
||||
model->devices.insert(model->devices.begin(), rpc_servers.begin(), rpc_servers.end());
|
||||
}
|
||||
}
|
||||
|
||||
// if using single GPU mode, remove all except the main GPU
|
||||
@ -9485,7 +9495,7 @@ struct llama_model * llama_model_load_from_file(
|
||||
LLAMA_LOG_INFO("%s: using device %s (%s) - %zu MiB free\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), free/1024/1024);
|
||||
}
|
||||
|
||||
const int status = llama_model_load(path_model, *model, params);
|
||||
const int status = llama_model_load(path_model, splits, *model, params);
|
||||
GGML_ASSERT(status <= 0);
|
||||
if (status < 0) {
|
||||
if (status == -1) {
|
||||
@ -9501,6 +9511,35 @@ struct llama_model * llama_model_load_from_file(
|
||||
return model;
|
||||
}
|
||||
|
||||
// deprecated
|
||||
struct llama_model * llama_load_model_from_file(
|
||||
const char * path_model,
|
||||
struct llama_model_params params) {
|
||||
return llama_model_load_from_file(path_model, params);
|
||||
}
|
||||
|
||||
struct llama_model * llama_model_load_from_file(
|
||||
const char * path_model,
|
||||
struct llama_model_params params) {
|
||||
std::vector<std::string> splits = {};
|
||||
return llama_model_load_from_file_impl(path_model, splits, params);
|
||||
}
|
||||
|
||||
struct llama_model * llama_model_load_from_splits(
|
||||
const char ** paths,
|
||||
size_t n_paths,
|
||||
struct llama_model_params params) {
|
||||
std::vector<std::string> splits;
|
||||
if (n_paths == 0) {
|
||||
LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
for (size_t i = 0; i < n_paths; ++i) {
|
||||
splits.push_back(paths[i]);
|
||||
}
|
||||
return llama_model_load_from_file_impl(splits.front(), splits, params);
|
||||
}
|
||||
|
||||
struct llama_context * llama_init_from_model(
|
||||
struct llama_model * model,
|
||||
struct llama_context_params params) {
|
||||
|
Reference in New Issue
Block a user