whisper.cpp/examples/talk-llama/llama-adapter.h

#pragma once

#include "llama.h"
#include "ggml-cpp.h"

#include <string>
#include <unordered_map>
#include <vector>

// TODO: pimpl

//
// llama_adapter_cvec
//

// a control vector ("cvec"): one direction per layer, added to the hidden
// state at inference time to steer the model
struct llama_adapter_cvec {
    // tensor holding the direction for layer il, or nullptr if none is active
    struct ggml_tensor * tensor_for(int il) const;

    // add the layer's direction to cur; returns cur unchanged for layers
    // without a loaded direction
    struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const;

    // load a control vector from data (n_embd floats per layer) and activate
    // it for layers [il_start, il_end]
    int32_t apply(
            const llama_model & model,
            const float * data,
            size_t len,
            int32_t n_embd,
            int32_t il_start,
            int32_t il_end);

private:
    bool init(const llama_model & model);

    int32_t layer_start = -1;
    int32_t layer_end   = -1;

    std::vector<ggml_context_ptr> ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;

    std::vector<struct ggml_tensor *> tensors; // per layer
};
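
// Usage sketch (illustrative only; `model`, `ctx0`, `cur`, `il`, `n_embd` and
// `load_directions` are hypothetical caller-side names, and treating a return
// of 0 as success is an assumption, not documented by this header):
//
//   llama_adapter_cvec cvec;
//   std::vector<float> dirs = load_directions();
//   if (cvec.apply(model, dirs.data(), dirs.size(), n_embd,
//                  /*il_start=*/10, /*il_end=*/20) == 0) {
//       // later, while building the compute graph for layer il:
//       cur = cvec.apply_to(ctx0, cur, il);
//   }
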
//
// llama_adapter_lora
//

struct llama_adapter_lora_weight {
    // low-rank pair: the adapter's weight delta is built from b and a,
    // scaled by get_scale()
    struct ggml_tensor * a = nullptr;
    struct ggml_tensor * b = nullptr;

    // get actual scale based on rank and alpha
    float get_scale(float alpha, float adapter_scale) const {
        const float rank = (float) b->ne[0];
        const float scale = alpha ? adapter_scale * alpha / rank : adapter_scale;
        return scale;
    }

    llama_adapter_lora_weight() = default;
    llama_adapter_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {}
};
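
// Worked example (hypothetical numbers): with alpha = 16 and a rank-8 adapter
// (b->ne[0] == 8), get_scale(16.0f, 1.0f) == 1.0f * 16.0f / 8.0f == 2.0f;
// with alpha == 0 the rank correction is skipped and adapter_scale is
// returned unchanged.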

struct llama_adapter_lora {
    // map tensor name to lora_a_b
    std::unordered_map<std::string, struct llama_adapter_lora_weight> ab_map;

    std::vector<ggml_context_ptr> ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;

    float alpha;

    llama_adapter_lora() = default;
    ~llama_adapter_lora() = default;

    // look up the LoRA pair for a base-model weight tensor by name;
    // returns nullptr if the adapter does not modify that tensor
    llama_adapter_lora_weight * get_weight(struct ggml_tensor * w);
};
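
// Usage sketch (illustrative only; `adapter`, `ctx0`, `cur` and `w` are
// hypothetical caller-side names, and the graph code is a simplified sketch
// of how a LoRA delta is commonly folded in, not the exact upstream code):
//
//   llama_adapter_lora_weight * lw = adapter.get_weight(w);
//   if (lw != nullptr) {
//       const float scale = lw->get_scale(adapter.alpha, /*adapter_scale=*/1.0f);
//       // delta = scale * (b @ (a @ x)), added on top of the base matmul:
//       struct ggml_tensor * ax  = ggml_mul_mat(ctx0, lw->a, cur);
//       struct ggml_tensor * bax = ggml_mul_mat(ctx0, lw->b, ax);
//       cur = ggml_add(ctx0, cur, ggml_scale(ctx0, bax, scale));
//   }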