talk-llama : sync llama.cpp
Some checks failed
Bindings Tests (Ruby) / ubuntu-22 (push) Has been cancelled
CI / determine-tag (push) Has been cancelled
CI / ubuntu-22 (linux/amd64) (push) Has been cancelled
CI / ubuntu-22 (linux/ppc64le) (push) Has been cancelled
CI / ubuntu-22-arm64 (linux/arm64) (push) Has been cancelled
CI / ubuntu-22-arm-v7 (linux/arm/v7) (push) Has been cancelled
CI / macOS-latest (generic/platform=iOS) (push) Has been cancelled
CI / macOS-latest (generic/platform=macOS) (push) Has been cancelled
CI / macOS-latest (generic/platform=tvOS) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/amd64, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/amd64, Release) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/ppc64le, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/ppc64le, Release) (push) Has been cancelled
CI / ubuntu-22-gcc-arm64 (linux/arm64, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc-arm64 (linux/arm64, Release) (push) Has been cancelled
CI / ubuntu-22-gcc-arm-v7 (linux/arm/v7, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc-arm-v7 (linux/arm/v7, Release) (push) Has been cancelled
CI / ubuntu-22-clang (linux/amd64, Debug) (push) Has been cancelled
CI / ubuntu-22-clang (linux/amd64, Release) (push) Has been cancelled
CI / ubuntu-22-clang (linux/arm64, Debug) (push) Has been cancelled
CI / ubuntu-22-clang (linux/arm64, Release) (push) Has been cancelled
CI / ubuntu-22-clang (linux/ppc64le, Debug) (push) Has been cancelled
CI / ubuntu-22-clang (linux/ppc64le, Release) (push) Has been cancelled
CI / ubuntu-22-gcc-sanitized (linux/amd64, ADDRESS) (push) Has been cancelled
CI / ubuntu-22-gcc-sanitized (linux/amd64, THREAD) (push) Has been cancelled
CI / ubuntu-22-gcc-sanitized (linux/amd64, UNDEFINED) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/amd64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/arm/v7, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/arm64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/ppc64le, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/amd64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm/v7, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/ppc64le, icx, icpx, ON) (push) Has been cancelled
CI / windows-msys2 (Release, clang-x86_64, CLANG64) (push) Has been cancelled
CI / windows-msys2 (Release, ucrt-x86_64, UCRT64) (push) Has been cancelled
CI / windows (Win32, Release, win32-x86, x86, 2.28.5, ON) (push) Has been cancelled
CI / windows (x64, Release, win32-x86-64, x64, 2.28.5, ON) (push) Has been cancelled
CI / windows-blas (Win32, ON, Release, x86, 2.28.5, ON) (push) Has been cancelled
CI / windows-blas (x64, ON, Release, x64, 2.28.5, ON) (push) Has been cancelled
CI / windows-cublas (x64, Release, ON, 11.8.0, ON, 2.28.5) (push) Has been cancelled
CI / windows-cublas (x64, Release, ON, 12.2.0, ON, 2.28.5) (push) Has been cancelled
CI / emscripten (Release) (push) Has been cancelled
CI / ios-xcode-build (Release) (push) Has been cancelled
CI / android (push) Has been cancelled
CI / android_java (push) Has been cancelled
CI / bindings-java (push) Has been cancelled
CI / quantize (push) Has been cancelled
CI / release (push) Has been cancelled
CI / coreml-base-en (push) Has been cancelled
CI / vad (push) Has been cancelled
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main-musa.Dockerfile platform:linux/amd64 tag:main-musa]) (push) Has been cancelled
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main.Dockerfile platform:linux/amd64 tag:main]) (push) Has been cancelled
Examples WASM / deploy-wasm-github-pages (push) Has been cancelled

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-05-13 13:20:19 +03:00
parent a14c89aefa
commit f890560575
25 changed files with 2847 additions and 1125 deletions

View File

@ -80,6 +80,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_236B: return "236B";
case LLM_TYPE_290B: return "290B";
case LLM_TYPE_314B: return "314B";
case LLM_TYPE_405B: return "405B";
case LLM_TYPE_671B: return "671B";
case LLM_TYPE_SMALL: return "0.1B";
case LLM_TYPE_MEDIUM: return "0.4B";
@ -116,6 +117,10 @@ static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_
{ LLAMA_ROPE_SCALING_TYPE_LONGROPE, "longrope" },
};
// Return the textual name registered for `rope_scaling_type` in
// LLAMA_ROPE_SCALING_TYPES. The lookup goes through `.at()`, so a value with
// no entry in the table is not silently mapped to a default.
std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type) {
    const char * type_name = LLAMA_ROPE_SCALING_TYPES.at(rope_scaling_type);
    return std::string(type_name);
}
static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
if (kv.second == name) {
@ -298,6 +303,10 @@ static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & de
// add extra buffer types, only if no GPU device is present
// ref: https://github.com/ggml-org/llama.cpp/issues/12481#issuecomment-2743136094
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (cpu_dev == nullptr) {
throw std::runtime_error(format("%s: no CPU backend found", __func__));
}
auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
@ -582,6 +591,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
switch (hparams.n_layer) {
case 32: type = LLM_TYPE_7B; break;
case 80: type = LLM_TYPE_70B; break;
case 162: type = LLM_TYPE_405B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
@ -773,6 +783,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
// fall through
case LLM_ARCH_QWEN2:
{
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break;
@ -1481,6 +1492,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (cpu_dev == nullptr) {
throw std::runtime_error(format("%s: no CPU backend found", __func__));
}
const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);
const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1);
auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev {
@ -1648,8 +1662,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) {
std::regex pattern(overrides->pattern);
if (std::regex_search(tensor_name, pattern)) {
LLAMA_LOG_DEBUG("tensor %s buffer type overriden to %s\n", tensor_name.c_str(), ggml_backend_buft_name(overrides->buft));
buft = overrides->buft;
LLAMA_LOG_DEBUG("tensor %s (%zu MiB %s) buffer type overridden to %s\n",
tensor_name.c_str(),
ggml_nbytes(t_meta) / 1024 / 1024, ggml_type_name(t_meta->type),
ggml_backend_buft_name(buft));
break;
}
}
@ -1666,6 +1683,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
auto * buft_dev = ggml_backend_buft_get_device(buft);
if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) {
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (!cpu_dev) {
throw std::runtime_error("no CPU backend found");
}
buft = ggml_backend_dev_buffer_type(cpu_dev);
}
@ -1847,7 +1867,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
if (n_ff > 0) {
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
}
if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
@ -1857,9 +1879,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
}
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
if (n_ff > 0) {
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
// optional MLP bias
layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
@ -3503,7 +3527,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
@ -4108,6 +4136,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
if (!dev) {
// FIXME: workaround for CPU backend buft having a NULL device
dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (!dev) {
throw std::runtime_error(format("%s: no CPU backend found", __func__));
}
}
ggml_backend_dev_props props;
ggml_backend_dev_get_props(dev, &props);
@ -4237,7 +4268,7 @@ uint64_t llama_model::n_elements() const {
}
void llama_model::print_info() const {
const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
const std::string rope_scaling_type = llama_rope_scaling_type_name(hparams.rope_scaling_type_train);
auto print_f = [](const std::function<uint32_t(uint32_t)> & f, uint32_t n) {
bool is_var = false;
@ -4298,7 +4329,7 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn);
LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type);
LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type);
LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type);
LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str());
LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
@ -4445,6 +4476,19 @@ const ggml_tensor * llama_model::get_tensor(const char * name) const {
return it->second;
}
// Select the RoPE frequency-factor tensor for layer `il`.
//
// Order of preference:
//   1. the layer's `rope_freqs` tensor, when it is set;
//   2. `rope_long` when the per-sequence context exceeds `hparams.n_ctx_orig_yarn`;
//   3. `rope_short` otherwise.
//
// The returned pointer is whatever the layer holds, so it may be nullptr.
ggml_tensor * llama_model::get_rope_factors(uint32_t n_ctx_per_seq, int il) const {
    const auto & layer = layers[il];

    // explicit frequency factors take precedence over the long/short pair
    if (layer.rope_freqs) {
        return layer.rope_freqs;
    }

    // choose long/short freq factors based on the context size
    const bool use_long = n_ctx_per_seq > hparams.n_ctx_orig_yarn;
    return use_long ? layer.rope_long : layer.rope_short;
}
struct llm_build_llama : public llm_graph_context {
llm_build_llama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
@ -4485,7 +4529,7 @@ struct llm_build_llama : public llm_graph_context {
// self-attention
{
// rope freq factors for llama3; may return nullptr for llama2 and other models
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@ -4691,6 +4735,7 @@ struct llm_build_deci : public llm_graph_context {
ggml_tensor * inpSA = inpL;
const int64_t n_head_kv = hparams.n_head_kv(il);
const int64_t n_head = hparams.n_head(il);
const int64_t n_ff = hparams.n_ff(il);
if (n_head == 0) {
// attention-free layer of Llama-3_1-Nemotron-51B
@ -4710,7 +4755,7 @@ struct llm_build_deci : public llm_graph_context {
} else if (n_head > 0) {
// self-attention
// rope freq factors for llama3; may return nullptr for llama2 and other models
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@ -4766,6 +4811,11 @@ struct llm_build_deci : public llm_graph_context {
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
// FFN-free layer of Llama-3_1-Nemotron-Ultra-253B
if (n_ff == 0) {
continue;
}
// For Granite architecture
if (hparams.f_residual_scale) {
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
@ -7192,7 +7242,7 @@ struct llm_build_phi3 : public llm_graph_context {
// self-attention
{
// rope freq factors for 128k context
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
ggml_tensor* attn_norm_output = build_norm(inpL,
model.layers[il].attn_norm,
@ -7944,7 +7994,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
// norm
cur = build_norm(inpL,
@ -8711,7 +8761,7 @@ struct llm_build_mamba : public llm_graph_context {
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const auto kv_head = kv_self->head;
@ -9012,7 +9062,7 @@ struct llm_build_cohere2 : public llm_graph_context {
// self-attention
{
// rope freq factors for 128k context
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@ -9950,7 +10000,7 @@ struct llm_build_deepseek : public llm_graph_context {
// self-attention
{
// rope freq factors for llama3; may return nullptr for llama2 and other models
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@ -11314,7 +11364,7 @@ struct llm_build_exaone : public llm_graph_context {
// self-attention
{
// rope freq factors for llama3; may return nullptr for llama2 and other models
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@ -11459,7 +11509,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const auto n_tokens = ubatch.n_tokens;
const auto n_seqs = ubatch.n_seqs;
@ -11855,7 +11905,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
ggml_tensor *& first_layer_value,
const llama_ubatch & ubatch,
int il) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const auto n_tokens = ubatch.n_tokens;
const auto n_seqs = ubatch.n_seqs;
@ -12695,7 +12745,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
// self-attention
{
// rope freq factors for llama3; may return nullptr for llama2 and other models
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@ -12815,36 +12865,46 @@ struct llm_build_bailingmoe : public llm_graph_context {
}
};
llama_memory_i * llama_model::create_memory() const {
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
llama_memory_i * res;
switch (arch) {
case LLM_ARCH_BERT:
case LLM_ARCH_JINA_BERT_V2:
case LLM_ARCH_NOMIC_BERT:
case LLM_ARCH_NOMIC_BERT_MOE:
{
res = nullptr;
} break;
case LLM_ARCH_MAMBA:
case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2:
case LLM_ARCH_RWKV7:
case LLM_ARCH_ARWKV7:
{
res = new llama_kv_cache_unified(hparams, {
/*.get_rope_factors =*/ nullptr
});
res = new llama_kv_cache_recurrent(
*this,
GGML_TYPE_F32,
GGML_TYPE_F32,
cparams.offload_kqv,
std::max((uint32_t) 1, cparams.n_seq_max));
} break;
default:
{
res = new llama_kv_cache_unified(hparams, {
/*.get_rope_factors =*/ [this](uint32_t n_ctx_per_seq, int il) {
// choose long/short freq factors based on the context size
if (layers[il].rope_freqs != nullptr) {
return layers[il].rope_freqs;
}
const auto padding = llama_kv_cache_unified::get_padding(cparams);
if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
return layers[il].rope_long;
}
cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
return layers[il].rope_short;
}
});
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
res = new llama_kv_cache_unified(
*this,
params.type_k,
params.type_v,
!cparams.flash_attn,
cparams.offload_kqv,
cparams.n_ctx,
padding);
}
}
@ -13226,8 +13286,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_DECI:
case LLM_ARCH_BAICHUAN:
case LLM_ARCH_STARCODER:
case LLM_ARCH_PLAMO:
case LLM_ARCH_ORION:
case LLM_ARCH_INTERNLM2:
case LLM_ARCH_MINICPM:
case LLM_ARCH_XVERSE:
@ -13265,6 +13323,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_PHI2:
case LLM_ARCH_PHI3:
case LLM_ARCH_PHIMOE:
case LLM_ARCH_PLAMO:
case LLM_ARCH_GEMMA:
case LLM_ARCH_GEMMA2:
case LLM_ARCH_GEMMA3:
@ -13272,6 +13331,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_OPENELM:
case LLM_ARCH_GPTNEOX:
case LLM_ARCH_CODESHELL:
case LLM_ARCH_ORION:
case LLM_ARCH_NEMOTRON:
case LLM_ARCH_EXAONE:
case LLM_ARCH_MINICPM3:
@ -13344,6 +13404,14 @@ const char * llama_model_chat_template(const llama_model * model, const char * n
: LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
const auto & it = model->gguf_kv.find(key);
if (it == model->gguf_kv.end()) {
// one-off fix for very popular models (so we are not flooded with issues)
// do not extend this list unless absolutely necessary
// Mistral-Small-2503 does not have built-in chat template
llama_vocab_pre_type pre_type = model->vocab.get_pre_type();
if (pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) {
return "mistral-v7-tekken";
}
return nullptr;
}