diarization : try conv and self-attention embeddings

This commit is contained in:
Georgi Gerganov 2023-02-19 12:19:52 +02:00
parent d11f35920e
commit ec44ad0a75
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 210 additions and 113 deletions

75
ggml.c
View File

@ -8529,28 +8529,30 @@ void ggml_svd_reduce_dims(
float * A0 = (float *) malloc(n * m * sizeof(float));
// average vector
float * M = (float *) malloc(m * sizeof(float));
//float * M = (float *) malloc(m * sizeof(float));
{
for (int j = 0; j < m; ++j) {
M[j] = 0.0f;
}
for (int i = 0; i < n; ++i) {
for (int j = 0; j < m; ++j) {
M[j] += A[i * m + j];
}
}
for (int j = 0; j < m; ++j) {
M[j] /= (float) n;
}
}
//{
// for (int j = 0; j < m; ++j) {
// M[j] = 0.0f;
// }
// for (int i = 0; i < n; ++i) {
// for (int j = 0; j < m; ++j) {
// M[j] += A[i * m + j];
// }
// }
// for (int j = 0; j < m; ++j) {
// M[j] /= (float) n;
// }
//}
// subtract average vector
for (int i = 0; i < n; ++i) {
for (int j = 0; j < m; ++j) {
A[i * m + j] -= M[j];
}
}
//// subtract average vector
//for (int i = 0; i < n; ++i) {
// for (int j = 0; j < m; ++j) {
// A[i * m + j] -= M[j];
// }
//}
//free(M);
memcpy(A0, A, n * m * sizeof(float));
@ -8616,11 +8618,11 @@ void ggml_svd_reduce_dims(
}
// print S
//printf("S:\n");
//for (int i = 0; i < n; ++i) {
// printf("- %d = %9.5f\n", i, S[i]);
//}
//printf("\n");
printf("S:\n");
for (int i = 0; i < n; ++i) {
printf("- %d = %9.5f\n", i, S[i]);
}
printf("\n");
// print V
//printf("V:\n");
@ -8652,16 +8654,16 @@ void ggml_svd_reduce_dims(
}
// normalize U
//for (int i = 0; i < n; ++i) {
// double sum = 0.0;
// for (int j = 0; j < m; ++j) {
// sum += U[i * m + j] * U[i * m + j];
// }
// sum = sqrt(sum);
// for (int j = 0; j < m; ++j) {
// U[i * m + j] /= sum*sqrt((double) m);
// }
//}
for (int i = 0; i < n; ++i) {
double sum = 0.0;
for (int j = 0; j < m; ++j) {
sum += U[i * m + j] * U[i * m + j];
}
sum = sqrt(sum);
for (int j = 0; j < m; ++j) {
U[i * m + j] /= sum*sqrt((double) m);
}
}
// print U
//printf("U:\n");
@ -8674,12 +8676,11 @@ void ggml_svd_reduce_dims(
//}
//printf("\n");
printf("n = %d, m = %d, nd = %d\n", n, m, nd);
// project A0 onto U
for (int i = 0; i < n; ++i) {
for (int j = 0; j < nd; ++j) {
A[i * nd + j] = 0.0f;
//if (j == 0) continue;
for (int k = 0; k < m; ++k) {
A[i * nd + j] += A0[i * m + k] * U[j * m + k];
}

View File

@ -268,6 +268,14 @@ static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
{ MODEL_LARGE, 71ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_KV_ENC_SELF = {
{ MODEL_TINY, 23ull*MB },
{ MODEL_BASE, 26ull*MB },
{ MODEL_SMALL, 216ull*MB },
{ MODEL_MEDIUM, 243ull*MB },
{ MODEL_LARGE, 271ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
{ MODEL_TINY, 9ull*MB },
{ MODEL_BASE, 18ull*MB },
@ -571,6 +579,7 @@ struct whisper_context {
// cross-attention KV cache for the decoders
// shared between all decoders
whisper_kv_cache kv_cross;
whisper_kv_cache kv_enc_self;
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
@ -807,7 +816,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
MEM_REQ_SCRATCH3.at (model.type) +
scale*MEM_REQ_MODEL.at (model.type) +
scale*MEM_REQ_KV_CROSS.at(model.type) +
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
// this is the memory required by one decoder
const size_t mem_required_decoder =
@ -838,6 +847,11 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
return false;
}
if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_ENC_SELF.at(model.type), wctx.kv_enc_self, wctx.wtype, model.hparams.n_audio_ctx)) {
fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
return false;
}
{
const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v);
fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
@ -1415,6 +1429,9 @@ static bool whisper_encode(
}
}
struct ggml_cgraph gf = {};
gf.n_threads = n_threads;
struct ggml_tensor * cur;
// convolution + gelu
@ -1442,6 +1459,18 @@ static bool whisper_encode(
cur = ggml_gelu(ctx0, cur);
}
//{
// //printf("cur: %d %d %d %d, size element = %d\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_element_size(cur));
// wctx.use_buf(ctx0, -1);
// struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_enc_self.k, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.k)*n_state)*(0*n_ctx));
// //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_enc_self.v, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.v)*n_state)*(il*n_ctx));
// ggml_build_forward_expand(&gf, ggml_cpy(ctx0, cur, k));
// //ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
//}
wctx.use_buf(ctx0, 3);
// ===================================================================
@ -1522,6 +1551,18 @@ static bool whisper_encode(
Vcur),
Vcur);
//{
// //printf("Kcur: %d %d %d %d, size element = %d\n", Kcur->ne[0], Kcur->ne[1], Kcur->ne[2], Kcur->ne[3], ggml_element_size(Kcur));
// wctx.use_buf(ctx0, -1);
// struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_enc_self.k, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.k)*n_state)*(il*n_ctx));
// struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_enc_self.v, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.v)*n_state)*(il*n_ctx));
// ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
// ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
//}
// ------
wctx.use_buf(ctx0, 0);
@ -1606,6 +1647,18 @@ static bool whisper_encode(
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
{
//printf("cur: %d %d %d %d, size element = %d\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_element_size(cur));
wctx.use_buf(ctx0, -1);
struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_enc_self.k, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.k)*n_state)*(il*n_ctx));
//struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_enc_self.v, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.v)*n_state)*(il*n_ctx));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, cur, k));
//ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
}
}
// projection
@ -1715,8 +1768,6 @@ static bool whisper_encode(
// run the computation
{
struct ggml_cgraph gf = {};
gf.n_threads = n_threads;
ggml_build_forward_expand(&gf, cur);
ggml_graph_compute (ctx0, &gf);
@ -4858,7 +4909,7 @@ void whisper_full_cluster_segments(struct whisper_context * ctx) {
const int n_state = ctx->model.hparams.n_audio_state;
const int n_layer = ctx->model.hparams.n_audio_layer;
#if 1
#if 0
// use the last layer of the encoder
{
std::vector<float> embd(n_segments*n_state);
@ -4878,7 +4929,7 @@ void whisper_full_cluster_segments(struct whisper_context * ctx) {
const int n_features = std::min(4, n_segments);
ggml_svd_reduce_dims(n_state, n_segments, embd.data(), n_features);
#else
#elif 0
// use cross kv cache of various layers
for (int il = 0; il < n_layer; ++il) {
std::vector<float> embd(n_segments*n_ctx*n_state);
@ -4900,10 +4951,56 @@ void whisper_full_cluster_segments(struct whisper_context * ctx) {
const int n_features = std::min(4, n_segments);
ggml_svd_reduce_dims(n_ctx*n_state, n_segments, embd.data(), n_features);
#elif 0
// use conv embedding
for (int il = 0; il < 1; ++il) {
std::vector<float> embd(n_segments*n_ctx*n_state);
for (int i = 0; i < n_segments; ++i) {
const auto & segment_i = ctx->result_all[i];
printf("%s: layer %2d, segment %3d: t0 = %7d, t1 = %7d, text = %s\n", __func__, il, i, (int) segment_i.t0, (int) segment_i.t1, segment_i.text.c_str());
ctx->mel.n_len = segment_i.t1;
whisper_encode(*ctx, segment_i.t0, 7, true);
const size_t offs = ggml_element_size(ctx->kv_enc_self.k)*(il*n_ctx*n_state);
const ggml_fp16_t * f = (const ggml_fp16_t * )((const char *) ctx->kv_enc_self.k->data + offs);
for (int j = 0; j < n_ctx*n_state; ++j) {
embd[i*n_ctx*n_state + j] = ggml_fp16_to_fp32(f[j]);
}
}
const int n_features = std::min(3, n_segments);
ggml_svd_reduce_dims(n_ctx*n_state, n_segments, embd.data(), n_features);
#else
// use enc self kv cache of various layers
for (int il = 0; il < n_layer; ++il) {
std::vector<float> embd(n_segments*n_ctx*n_state);
for (int i = 0; i < n_segments; ++i) {
const auto & segment_i = ctx->result_all[i];
printf("%s: layer %2d, segment %3d: t0 = %7d, t1 = %7d, text = %s\n", __func__, il, i, (int) segment_i.t0, (int) segment_i.t1, segment_i.text.c_str());
ctx->mel.n_len = segment_i.t1;
whisper_encode(*ctx, segment_i.t0, 7, true);
const size_t offs = ggml_element_size(ctx->kv_enc_self.k)*(il*n_ctx*n_state);
const ggml_fp16_t * f = (const ggml_fp16_t * )((const char *) ctx->kv_enc_self.k->data + offs);
for (int j = 0; j < n_ctx*n_state; ++j) {
embd[i*n_ctx*n_state + j] = ggml_fp16_to_fp32(f[j]);
}
}
const int n_features = std::min(4, n_segments);
ggml_svd_reduce_dims(n_ctx*n_state, n_segments, embd.data(), n_features);
#endif
std::vector<std::vector<float>> features(n_segments);
std::vector<std::vector<double>> features(n_segments);
for (int i = 0; i < n_segments; ++i) {
features[i].resize(n_features);
@ -4915,8 +5012,8 @@ void whisper_full_cluster_segments(struct whisper_context * ctx) {
// fuzzy c-means clustering
const int n_clusters = 2;
std::vector<std::vector<float>> centroids(n_clusters, std::vector<float>(n_features, 0.0));
std::vector<std::vector<float>> membership(n_segments, std::vector<float>(n_clusters, 0.0));
std::vector<std::vector<double>> centroids(n_clusters, std::vector<double>(n_features, 0.0));
std::vector<std::vector<double>> membership(n_segments, std::vector<double>(n_clusters, 0.0));
// initialize the centroids
for (int i = 0; i < n_clusters; ++i) {
@ -4928,8 +5025,11 @@ void whisper_full_cluster_segments(struct whisper_context * ctx) {
// initialize the membership
for (int i = 0; i < n_segments; ++i) {
//membership[i][i % n_clusters] = 1.0;
//for (int j = 0; j < n_clusters; ++j) {
// membership[i][j] = rand() / (float) RAND_MAX;
//}
for (int j = 0; j < n_clusters; ++j) {
membership[i][j] = rand() / (float) RAND_MAX;
membership[i][j] = 1.0 / n_clusters;
}
}
@ -4937,42 +5037,47 @@ void whisper_full_cluster_segments(struct whisper_context * ctx) {
// iterate
for (int i = 0; i < niter; ++i) {
// update the centroids
for (int j = 0; j < n_clusters; ++j) {
for (int k = 0; k < n_features; ++k) {
centroids[j][k] = 0.0;
}
}
for (int j = 0; j < n_segments; ++j) {
for (int k = 0; k < n_clusters; ++k) {
for (int l = 0; l < n_features; ++l) {
centroids[k][l] += membership[j][k]*features[j][l];
// print the membership
if (i == niter - 1) {
//{
for (int i = 0; i < n_segments; ++i) {
#if 1
printf("%s: membership %3d: ", __func__, i);
for (int j = 0; j < n_clusters; ++j) {
printf("%.1f ", membership[i][j]);
}
printf(" '%s'\n", ctx->result_all[i].text.c_str());
#else
printf("%s: features : ", __func__);
for (int j = 0; j < n_features; ++j) {
printf("%8.3f ", features[i][j]);
}
printf(" '%s'\n", ctx->result_all[i].text.c_str());
#endif
}
}
printf("----------------\n");
for (int j = 0; j < n_clusters; ++j) {
float sum = 0.0;
for (int k = 0; k < n_segments; ++k) {
sum += membership[k][j];
}
for (int k = 0; k < n_features; ++k) {
centroids[j][k] /= sum;
// print the centroids
for (int i = 0; i < n_clusters; ++i) {
printf("%s: centroid %d: ", __func__, i);
for (int j = 0; j < n_features; ++j) {
printf("%f ", centroids[i][j]);
}
printf("\n");
}
}
// update the membership
for (int j = 0; j < n_segments; ++j) {
for (int k = 0; k < n_clusters; ++k) {
float sum = 0.0;
double sum = 0.0;
for (int l = 0; l < n_clusters; ++l) {
//sum += std::pow(whisper_distance(features[j], centroids[k])/whisper_distance(features[j], centroids[l]), 2.0/(2.0 - 1.0));
double d0 = 0.0;
double d1 = 0.0;
#if 1
// use the euclidean distance
{
for (int m = 0; m < n_features; ++m) {
@ -4985,68 +5090,59 @@ void whisper_full_cluster_segments(struct whisper_context * ctx) {
}
d1 = std::sqrt(d1);
}
#else
// use the cosine distance
//{
// double dot = 0.0;
// double norm0 = 0.0;
// double norm1 = 0.0;
{
double dot = 0.0;
double norm0 = 0.0;
double norm1 = 0.0;
// for (int m = 0; m < n_features; ++m) {
// dot += features[j][m]*centroids[k][m];
// norm0 += std::pow(features[j][m], 2.0);
// norm1 += std::pow(centroids[k][m], 2.0);
// }
for (int m = 0; m < n_features; ++m) {
dot += features[j][m]*centroids[k][m];
norm0 += std::pow(features[j][m], 2.0);
norm1 += std::pow(centroids[k][m], 2.0);
}
// d0 = 1.0 - dot/(std::sqrt(norm0)*std::sqrt(norm1));
d0 = 1.0 - dot/(std::sqrt(norm0)*std::sqrt(norm1));
// dot = 0.0;
// norm0 = 0.0;
// norm1 = 0.0;
dot = 0.0;
norm0 = 0.0;
norm1 = 0.0;
// for (int m = 0; m < n_features; ++m) {
// dot += features[j][m]*centroids[l][m];
// norm0 += std::pow(features[j][m], 2.0);
// norm1 += std::pow(centroids[l][m], 2.0);
// }
for (int m = 0; m < n_features; ++m) {
dot += features[j][m]*centroids[l][m];
norm0 += std::pow(features[j][m], 2.0);
norm1 += std::pow(centroids[l][m], 2.0);
}
// d1 = 1.0 - dot/(std::sqrt(norm0)*std::sqrt(norm1));
//}
d1 = 1.0 - dot/(std::sqrt(norm0)*std::sqrt(norm1));
}
#endif
sum += std::pow(d0/d1, 2.0/(1.15 - 1.0));
if (d1 > 0.0) {
sum += std::pow(d0/d1, 2.0/(1.20 - 1.0));
} else {
sum += 1.0;
}
}
membership[j][k] = sum == 0.0 ? 0.0 : 1.0/sum;
membership[j][k] = sum == 0.0 ? 1.0 : 1.0/sum;
}
}
// print the membership
if (i == niter - 1) {
//{
for (int i = 0; i < n_segments; ++i) {
printf("%s: membership %3d: ", __func__, i);
for (int j = 0; j < n_clusters; ++j) {
printf("%f ", membership[i][j]);
// update the centroids
for (int j = 0; j < n_clusters; ++j) {
for (int k = 0; k < n_features; ++k) {
double sum = 0.0;
double sum2 = 0.0;
for (int l = 0; l < n_segments; ++l) {
sum += membership[l][j]*features[l][k];
sum2 += membership[l][j];
}
printf(" '%s'\n", ctx->result_all[i].text.c_str());
//printf("%s: features : ", __func__);
//for (int j = 0; j < n_features; ++j) {
// printf("%8.3f ", features[i][j]);
//}
//printf(" '%s'\n", ctx->result_all[i].text.c_str());
centroids[j][k] = sum2 == 0.0 ? 0.0 : sum/sum2;
}
printf("----------------\n");
}
}
// print the centroids
for (int i = 0; i < n_clusters; ++i) {
printf("%s: centroid %d: ", __func__, i);
for (int j = 0; j < n_features; ++j) {
printf("%f ", centroids[i][j]);
}
printf("\n");
}
}
// restore the mel length