mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-18 20:27:53 +00:00
ggml : mul_mat_id use the same tensor for all the experts (llama/6387)
* ggml : update mul_mat_id to use the same tensor for all the experts * update cuda * minor * update metal * update test-backend-ops * fix cuda * Update ggml-metal.m Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * update convert.py * update convert-hf-to-gguf.py * update convert.py for mixtral hf models * Update convert-hf-to-gguf.py Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * cuda : support non-pow-2 number of experts * allow quantize to work for split and merged experts models in the same way * cleanup + disable mmap automatically with split tensors models * update imatrix * test-backend-ops : test qwen argsort * update grok model loading * llama : add merged experts tensors to the grok tensor map * minor * gguf : bump version * fix quantizing of merged experts * convert-hf-to-gguf.py : update grok (untested) * make linter happy * cuda/argsort : use shared memory instead of pool memory * convert : fix grok tensor names * metal : add support for non-pow-2 argsort * llama : more loader cleanup, better error checking * cuda : fix warning * llama : still use mmap for loading old models, but copy the data to a host buffer * add review note * llama : remove ffn tensor counting + add sanity check ggml-ci * convert : fix handling of n_experts == None ggml-ci * imatrix : fix ncall counters * llama : produce error if imatrix size does not match * quantize : terminate on errors + trace logs ggml-ci * metal : pad shared memory to 16 bytes --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
f12e982c0b
commit
1dce94cf26
214
ggml-cuda.cu
214
ggml-cuda.cu
@ -401,10 +401,8 @@ GGML_CALL static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t
|
||||
GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
|
||||
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
|
||||
|
||||
if (tensor->view_src != NULL && tensor->view_offs == 0) {
|
||||
if (tensor->view_src != NULL) {
|
||||
assert(tensor->view_src->buffer->buft == buffer->buft);
|
||||
tensor->backend = tensor->view_src->backend;
|
||||
tensor->extra = tensor->view_src->extra;
|
||||
return;
|
||||
}
|
||||
|
||||
@ -1962,227 +1960,49 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
|
||||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
template<typename ... Srcs>
|
||||
static __global__ void k_compute_batched_ptrs_id(
|
||||
const void ** ptrs_src, void ** ptrs_dst,
|
||||
int ne12, int ne13,
|
||||
int ne23,
|
||||
int nb02, int nb03,
|
||||
int nb12, int nb13,
|
||||
int nb2, int nb3,
|
||||
int r2, int r3,
|
||||
ggml_type src0_type, half * src0_as_f16, int64_t src0_ne,
|
||||
const half * src1_f16, half * dst_f16,
|
||||
const int32_t * ids, const int id,
|
||||
Srcs... src0s) {
|
||||
|
||||
int i = ids[id];
|
||||
|
||||
half * src0_f16;
|
||||
const void * srcs_ar[] = { (const half *) src0s... };
|
||||
if (src0_type == GGML_TYPE_F16) {
|
||||
src0_f16 = (half *) srcs_ar[i];
|
||||
} else {
|
||||
src0_f16 = src0_as_f16;
|
||||
if (threadIdx.x == 0 && threadIdx.y == 0) {
|
||||
const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(src0_type);
|
||||
to_fp16(srcs_ar[i], src0_f16, src0_ne, cudaStreamFireAndForget);
|
||||
}
|
||||
}
|
||||
|
||||
int i13 = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int i12 = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
|
||||
if (i13 >= ne13 || i12 >= ne12) {
|
||||
return;
|
||||
}
|
||||
|
||||
int i03 = i13 / r3;
|
||||
int i02 = i12 / r2;
|
||||
|
||||
ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_f16 + i02*nb02 + i03*nb03;
|
||||
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_f16 + i12*nb12/2 + i13*nb13/2;
|
||||
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2;
|
||||
}
|
||||
|
||||
static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
|
||||
const struct ggml_tensor * ids = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
const struct ggml_tensor * src00 = dst->src[2];
|
||||
|
||||
const int id = dst->op_params[0];
|
||||
|
||||
GGML_ASSERT(!ggml_is_transposed(src00));
|
||||
GGML_ASSERT(!ggml_is_transposed(src1));
|
||||
|
||||
GGML_ASSERT(src00->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
|
||||
const int64_t ne00 = src00->ne[0]; GGML_UNUSED(ne00);
|
||||
const int64_t ne01 = src00->ne[1];
|
||||
const int64_t ne02 = src00->ne[2];
|
||||
const int64_t ne03 = src00->ne[3];
|
||||
|
||||
//const int64_t nb01 = src00->nb[1];
|
||||
const int64_t nb02 = src00->nb[2]; GGML_UNUSED(nb02);
|
||||
const int64_t nb03 = src00->nb[3]; GGML_UNUSED(nb03);
|
||||
|
||||
const int64_t ne10 = src1->ne[0];
|
||||
const int64_t ne11 = src1->ne[1];
|
||||
const int64_t ne12 = src1->ne[2];
|
||||
const int64_t ne13 = src1->ne[3];
|
||||
|
||||
//const int64_t nb11 = src1->nb[1];
|
||||
const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
|
||||
const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
|
||||
|
||||
const int64_t ne1 = ggml_nelements(src1);
|
||||
const int64_t ne = ggml_nelements(dst);
|
||||
|
||||
ggml_cuda_set_device(g_main_device);
|
||||
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
|
||||
|
||||
CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream));
|
||||
|
||||
//ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
//void * src0_ddq = src0_extra->data_device[g_main_device];
|
||||
//half * src0_as_f16 = (half *) src0_ddq;
|
||||
|
||||
ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
|
||||
|
||||
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
|
||||
|
||||
// convert src1 to fp16
|
||||
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
|
||||
GGML_ASSERT(to_fp16_cuda != nullptr);
|
||||
|
||||
size_t src1_as = 0;
|
||||
half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
|
||||
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
|
||||
|
||||
size_t dst_as = 0;
|
||||
half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
|
||||
|
||||
GGML_ASSERT(ne12 % ne02 == 0);
|
||||
GGML_ASSERT(ne13 % ne03 == 0);
|
||||
|
||||
// broadcast factors
|
||||
const int64_t r2 = ne12/ne02;
|
||||
const int64_t r3 = ne13/ne03;
|
||||
|
||||
const half alpha_f16 = 1.0f;
|
||||
const half beta_f16 = 0.0f;
|
||||
|
||||
// use cublasGemmBatchedEx
|
||||
const int ne23 = ne12*ne13;
|
||||
|
||||
const void ** ptrs_src = nullptr;
|
||||
void ** ptrs_dst = nullptr;
|
||||
|
||||
size_t ptrs_src_s = 0;
|
||||
size_t ptrs_dst_s = 0;
|
||||
|
||||
ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
|
||||
ptrs_dst = ( void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
|
||||
|
||||
int64_t src0_ne = ggml_nelements(src00);
|
||||
half * src0_as_f16 = nullptr;
|
||||
size_t src0_as = 0;
|
||||
if (src00->type != GGML_TYPE_F16) {
|
||||
src0_as_f16 = (half *) ggml_cuda_pool_malloc(src0_ne * sizeof(half), &src0_as);
|
||||
}
|
||||
|
||||
static_assert(GGML_MAX_SRC == 6, "GGML_MAX_SRC == 6");
|
||||
dim3 block_dims(ne13, ne12);
|
||||
k_compute_batched_ptrs_id<<<1, block_dims, 0, main_stream>>>(
|
||||
ptrs_src, ptrs_dst,
|
||||
ne12, ne13,
|
||||
ne23,
|
||||
ne00*ne01*sizeof(half), ne00*ne01*ne02*sizeof(half),
|
||||
nb12, nb13,
|
||||
dst->nb[2], dst->nb[3],
|
||||
r2, r3,
|
||||
src00->type, src0_as_f16, src0_ne,
|
||||
src1_as_f16, dst_f16,
|
||||
(const int *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device], id,
|
||||
dst->src[2] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[2]->extra)->data_device[g_main_device] : nullptr,
|
||||
dst->src[3] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[3]->extra)->data_device[g_main_device] : nullptr,
|
||||
dst->src[4] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[4]->extra)->data_device[g_main_device] : nullptr,
|
||||
dst->src[5] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[5]->extra)->data_device[g_main_device] : nullptr
|
||||
);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
ne01, ne11, ne10,
|
||||
&alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, ne00,
|
||||
(const void **) (ptrs_src + 1*ne23), CUDA_R_16F, ne10,
|
||||
&beta_f16, ( void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
|
||||
ne23,
|
||||
CUBLAS_COMPUTE_16F,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
|
||||
if (src0_as != 0) {
|
||||
ggml_cuda_pool_free(src0_as_f16, src0_as);
|
||||
}
|
||||
if (ptrs_src_s != 0) {
|
||||
ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
|
||||
}
|
||||
if (ptrs_dst_s != 0) {
|
||||
ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
|
||||
}
|
||||
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
|
||||
|
||||
ggml_cuda_pool_free(src1_as_f16, src1_as);
|
||||
ggml_cuda_pool_free(dst_f16, dst_as);
|
||||
}
|
||||
#endif
|
||||
|
||||
static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
#if 0
|
||||
ggml_cuda_mul_mat_id_cublas(dst);
|
||||
// TODO: mmq/mmv support
|
||||
#endif
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
const ggml_tensor * ids = dst->src[2];
|
||||
|
||||
GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
const size_t nb11 = src1->nb[1];
|
||||
const size_t nb1 = dst->nb[1];
|
||||
|
||||
const struct ggml_tensor * ids = src0;
|
||||
const int32_t id = ((int32_t *) dst->op_params)[0];
|
||||
const int32_t n_as = ((int32_t *) dst->op_params)[1];
|
||||
const int32_t n_as = src0->ne[2];
|
||||
|
||||
std::vector<char> ids_host(ggml_nbytes(ids));
|
||||
const char * ids_dev = (const char *) ids->data;
|
||||
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
|
||||
CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
|
||||
ggml_tensor src0_row = *src0;
|
||||
ggml_tensor src1_row = *src1;
|
||||
ggml_tensor dst_row = *dst;
|
||||
|
||||
char * src0_original = (char *) src0->data;
|
||||
char * src1_original = (char *) src1->data;
|
||||
char * dst_original = (char *) dst->data;
|
||||
|
||||
src0_row.ne[2] = 1;
|
||||
src0_row.ne[3] = 1;
|
||||
src0_row.nb[3] = src0->nb[2];
|
||||
|
||||
if (src1->ne[1] == 1) {
|
||||
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
|
||||
const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
|
||||
|
||||
GGML_ASSERT(row_id >= 0 && row_id < n_as);
|
||||
|
||||
const struct ggml_tensor * src0_row = dst->src[row_id + 2];
|
||||
|
||||
src0_row.data = src0_original + row_id*src0->nb[2];
|
||||
src1_row.data = src1_original + i01*src1->nb[1];
|
||||
dst_row.data = dst_original + i01*dst->nb[1];
|
||||
|
||||
ggml_cuda_mul_mat(ctx, src0_row, &src1_row, &dst_row);
|
||||
ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
|
||||
}
|
||||
} else {
|
||||
ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
|
||||
@ -2192,8 +2012,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
dst_row.data = dst_contiguous.get();
|
||||
|
||||
for (int32_t row_id = 0; row_id < n_as; ++row_id) {
|
||||
const struct ggml_tensor * src0_row = dst->src[row_id + 2];
|
||||
|
||||
int64_t num_src1_rows = 0;
|
||||
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
|
||||
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
|
||||
@ -2213,6 +2031,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
continue;
|
||||
}
|
||||
|
||||
src0_row.data = src0_original + row_id*src0->nb[2];
|
||||
|
||||
src1_row.ne[1] = num_src1_rows;
|
||||
dst_row.ne[1] = num_src1_rows;
|
||||
|
||||
@ -2224,7 +2044,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
dst_row.nb[2] = num_src1_rows*nb1;
|
||||
dst_row.nb[3] = num_src1_rows*nb1;
|
||||
|
||||
ggml_cuda_mul_mat(ctx, src0_row, &src1_row, &dst_row);
|
||||
ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
|
||||
|
||||
num_src1_rows = 0;
|
||||
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
|
||||
@ -2389,7 +2209,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (err != cudaSuccess) {
|
||||
fprintf(stderr, "%s: %s failed\n", __func__, ggml_op_desc(dst));
|
||||
GGML_ASSERT(false);
|
||||
CUDA_CHECK(err);
|
||||
}
|
||||
|
||||
return true;
|
||||
|
@ -8,32 +8,41 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) {
|
||||
}
|
||||
|
||||
template<ggml_sort_order order>
|
||||
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols) {
|
||||
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
|
||||
// bitonic sort
|
||||
int col = threadIdx.x;
|
||||
int row = blockIdx.y;
|
||||
|
||||
if (col >= ncols) return;
|
||||
if (col >= ncols_pad) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float * x_row = x + row * ncols;
|
||||
int * dst_row = dst + row * ncols;
|
||||
extern __shared__ int dst_row[];
|
||||
|
||||
// initialize indices
|
||||
if (col < ncols) {
|
||||
dst_row[col] = col;
|
||||
}
|
||||
dst_row[col] = col;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int k = 2; k <= ncols; k *= 2) {
|
||||
for (int k = 2; k <= ncols_pad; k *= 2) {
|
||||
for (int j = k / 2; j > 0; j /= 2) {
|
||||
int ixj = col ^ j;
|
||||
if (ixj > col) {
|
||||
if ((col & k) == 0) {
|
||||
if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
|
||||
if (dst_row[col] >= ncols ||
|
||||
(dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
|
||||
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
|
||||
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
|
||||
) {
|
||||
ggml_cuda_swap(dst_row[col], dst_row[ixj]);
|
||||
}
|
||||
} else {
|
||||
if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
|
||||
if (dst_row[ixj] >= ncols ||
|
||||
(dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
|
||||
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
|
||||
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
|
||||
) {
|
||||
ggml_cuda_swap(dst_row[col], dst_row[ixj]);
|
||||
}
|
||||
}
|
||||
@ -41,18 +50,35 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
// copy the result to dst without the padding
|
||||
if (col < ncols) {
|
||||
dst[row * ncols + col] = dst_row[col];
|
||||
}
|
||||
}
|
||||
|
||||
static int next_power_of_2(int x) {
|
||||
int n = 1;
|
||||
while (n < x) {
|
||||
n *= 2;
|
||||
}
|
||||
return n;
|
||||
}
|
||||
|
||||
static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
|
||||
// bitonic sort requires ncols to be power of 2
|
||||
GGML_ASSERT((ncols & (ncols - 1)) == 0);
|
||||
const int ncols_pad = next_power_of_2(ncols);
|
||||
|
||||
const dim3 block_dims(ncols, 1, 1);
|
||||
const dim3 block_dims(ncols_pad, 1, 1);
|
||||
const dim3 block_nums(1, nrows, 1);
|
||||
const size_t shared_mem = ncols_pad * sizeof(int);
|
||||
|
||||
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
|
||||
k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
|
||||
} else if (order == GGML_SORT_ORDER_DESC) {
|
||||
k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
|
||||
k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
209
ggml-metal.m
209
ggml-metal.m
@ -1685,37 +1685,31 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
{
|
||||
//GGML_ASSERT(ne00 == ne10);
|
||||
//GGML_ASSERT(ne03 == ne13);
|
||||
|
||||
GGML_ASSERT(src0t == GGML_TYPE_I32);
|
||||
|
||||
const int n_as = ((int32_t *) dst->op_params)[1];
|
||||
|
||||
// TODO: make this more general
|
||||
GGML_ASSERT(n_as <= 8);
|
||||
const int n_as = src0->ne[2];
|
||||
|
||||
// max size of the src1ids array in the kernel shared buffer
|
||||
GGML_ASSERT(ne11 <= 4096);
|
||||
|
||||
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
||||
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
||||
const int64_t ne22 = src2 ? src2->ne[2] : 0;
|
||||
const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
|
||||
// src2 = ids
|
||||
const int64_t ne20 = src2->ne[0]; GGML_UNUSED(ne20);
|
||||
const int64_t ne21 = src2->ne[1];
|
||||
const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
|
||||
const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
|
||||
|
||||
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
||||
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
||||
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
||||
const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
|
||||
const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
|
||||
const uint64_t nb21 = src2->nb[1];
|
||||
const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
|
||||
const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
|
||||
|
||||
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
|
||||
const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
|
||||
|
||||
GGML_ASSERT(!ggml_is_transposed(src2));
|
||||
GGML_ASSERT(src2t == GGML_TYPE_I32);
|
||||
|
||||
GGML_ASSERT(!ggml_is_transposed(src0));
|
||||
GGML_ASSERT(!ggml_is_transposed(src1));
|
||||
|
||||
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||
|
||||
const uint r2 = ne12/ne22;
|
||||
const uint r3 = ne13/ne23;
|
||||
|
||||
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
||||
// to the matrix-vector kernel
|
||||
int ne11_mm_min = n_as;
|
||||
@ -1723,7 +1717,10 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
const int idx = ((int32_t *) dst->op_params)[0];
|
||||
|
||||
// batch size
|
||||
GGML_ASSERT(ne01 == ne11);
|
||||
GGML_ASSERT(ne21 == ne11); // ?
|
||||
GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
|
||||
const uint r2 = 1;
|
||||
const uint r3 = 1;
|
||||
|
||||
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
||||
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
||||
@ -1732,7 +1729,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
// indirect matrix multiplication
|
||||
// !!!
|
||||
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
||||
ne20 % 32 == 0 && ne20 >= 64 &&
|
||||
ne00 % 32 == 0 && ne00 >= 64 &&
|
||||
ne11 > ne11_mm_min) {
|
||||
|
||||
// some Metal matrix data types require aligned pointers
|
||||
@ -1745,7 +1742,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
switch (src2->type) {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
|
||||
@ -1774,36 +1771,27 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
|
||||
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
||||
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
|
||||
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
|
||||
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
|
||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
|
||||
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
||||
[encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
|
||||
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
|
||||
[encoder setBytes:&idx length:sizeof(idx) atIndex:18];
|
||||
// TODO: how to make this an array? read Metal docs
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
// NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
|
||||
struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
|
||||
|
||||
size_t offs_src_cur = 0;
|
||||
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
|
||||
|
||||
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
|
||||
}
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
|
||||
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:9];
|
||||
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:10];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:13];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:14];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:15];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:16];
|
||||
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
|
||||
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
|
||||
[encoder setBytes:&idx length:sizeof(idx) atIndex:19];
|
||||
|
||||
[encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||
} else {
|
||||
int nth0 = 32;
|
||||
int nth1 = 1;
|
||||
@ -1813,7 +1801,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
// use custom matrix x vector kernel
|
||||
switch (src2t) {
|
||||
switch (src0t) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||
@ -1947,8 +1935,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
}
|
||||
};
|
||||
|
||||
if (ggml_is_quantized(src2t)) {
|
||||
GGML_ASSERT(ne20 >= nth0*nth1);
|
||||
if (ggml_is_quantized(src0t)) {
|
||||
GGML_ASSERT(ne00 >= nth0*nth1);
|
||||
}
|
||||
|
||||
const int64_t _ne1 = 1; // kernels needs a reference in constant memory
|
||||
@ -1957,75 +1945,66 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
|
||||
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
||||
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
|
||||
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
|
||||
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
|
||||
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
|
||||
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
|
||||
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
||||
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
|
||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
||||
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
|
||||
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
|
||||
[encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
|
||||
[encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
|
||||
[encoder setBytes:&idx length:sizeof(idx) atIndex:22];
|
||||
// TODO: how to make this an array? read Metal docs
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
// NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
|
||||
struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
|
||||
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7];
|
||||
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
|
||||
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
||||
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
|
||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
||||
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
|
||||
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:20];
|
||||
[encoder setBytes:&r2 length:sizeof(r2) atIndex:21];
|
||||
[encoder setBytes:&r3 length:sizeof(r3) atIndex:22];
|
||||
[encoder setBytes:&idx length:sizeof(idx) atIndex:23];
|
||||
|
||||
size_t offs_src_cur = 0;
|
||||
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
|
||||
|
||||
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
|
||||
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
|
||||
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
|
||||
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
|
||||
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 || src2t == GGML_TYPE_Q5_0 ||
|
||||
src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 || src2t == GGML_TYPE_Q2_K ||
|
||||
src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ1_M || src2t == GGML_TYPE_IQ2_S) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
|
||||
const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
||||
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
||||
const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_S) {
|
||||
const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
||||
else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
|
||||
const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) {
|
||||
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
|
||||
const int mem_size = 32*sizeof(float);
|
||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src2t == GGML_TYPE_Q4_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
else if (src0t == GGML_TYPE_Q4_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src2t == GGML_TYPE_Q3_K) {
|
||||
else if (src0t == GGML_TYPE_Q3_K) {
|
||||
#ifdef GGML_QKK_64
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
#else
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
#endif
|
||||
}
|
||||
else if (src2t == GGML_TYPE_Q5_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
else if (src0t == GGML_TYPE_Q5_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src2t == GGML_TYPE_Q6_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
else if (src0t == GGML_TYPE_Q6_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
} else {
|
||||
const int64_t ny = (_ne1 + nrows - 1)/nrows;
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
}
|
||||
} break;
|
||||
@ -2432,6 +2411,16 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
|
||||
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
||||
|
||||
// bitonic sort requires the number of elements to be power of 2
|
||||
int64_t ne00_padded = 1;
|
||||
while (ne00_padded < ne00) {
|
||||
ne00_padded *= 2;
|
||||
}
|
||||
|
||||
// Metal kernels require the buffer size to be multiple of 16 bytes
|
||||
// https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
|
||||
const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
switch (order) {
|
||||
@ -2441,11 +2430,13 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||
[encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
|
||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
{
|
||||
|
410
ggml-metal.metal
410
ggml-metal.metal
File diff suppressed because it is too large
Load Diff
57
ggml.c
57
ggml.c
@ -4573,45 +4573,38 @@ void ggml_mul_mat_set_prec(
|
||||
|
||||
// ggml_mul_mat_id
|
||||
|
||||
// NOTE: id will be removed in the future and instead all the experts listed in ids will be computed
|
||||
// this will allow computing all the used experts in a single matrix multiplication
|
||||
struct ggml_tensor * ggml_mul_mat_id(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * const as[],
|
||||
int n_as,
|
||||
struct ggml_tensor * as,
|
||||
struct ggml_tensor * ids,
|
||||
int id,
|
||||
struct ggml_tensor * b) {
|
||||
|
||||
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1);
|
||||
GGML_ASSERT(ids->ne[1] == b->ne[1]);
|
||||
GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
|
||||
GGML_ASSERT(ids->ne[1] == b->ne[1]); // must have an expert per b row
|
||||
GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
|
||||
GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2);
|
||||
GGML_ASSERT(id >= 0 && id < ids->ne[0]);
|
||||
GGML_ASSERT(id >= 0 && id < ids->ne[0]); // valid id
|
||||
GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat
|
||||
|
||||
bool is_node = false;
|
||||
|
||||
if (as[0]->grad || b->grad) {
|
||||
if (as->grad || b->grad) {
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
const int64_t ne[4] = { as[0]->ne[1], b->ne[1], b->ne[2], b->ne[3] };
|
||||
const int64_t ne[4] = { as->ne[1], b->ne[1], b->ne[2], b->ne[3] };
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
||||
|
||||
ggml_set_op_params_i32(result, 0, id);
|
||||
ggml_set_op_params_i32(result, 1, n_as);
|
||||
|
||||
result->op = GGML_OP_MUL_MAT_ID;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src[0] = ids;
|
||||
result->src[0] = as;
|
||||
result->src[1] = b;
|
||||
|
||||
for (int i = 0; i < n_as; i++) {
|
||||
struct ggml_tensor * a = as[i];
|
||||
GGML_ASSERT(ggml_are_same_shape(as[0], a));
|
||||
GGML_ASSERT(ggml_can_mul_mat(a, b));
|
||||
GGML_ASSERT(!ggml_is_transposed(a));
|
||||
result->src[i + 2] = a;
|
||||
}
|
||||
result->src[2] = ids;
|
||||
|
||||
return result;
|
||||
}
|
||||
@ -10948,10 +10941,9 @@ static void ggml_compute_forward_mul_mat_id(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * ids = dst->src[0];
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[2]; // only for GGML_TENSOR_BINARY_OP_LOCALS
|
||||
const struct ggml_tensor * ids = dst->src[2];
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
@ -10981,13 +10973,13 @@ static void ggml_compute_forward_mul_mat_id(
|
||||
GGML_ASSERT(nb1 <= nb2);
|
||||
GGML_ASSERT(nb2 <= nb3);
|
||||
|
||||
// broadcast factors
|
||||
const int64_t r2 = ne12/ne02;
|
||||
const int64_t r3 = ne13/ne03;
|
||||
// broadcast is not supported with mmid
|
||||
assert(ne12 == 1);
|
||||
assert(ne13 == 1);
|
||||
|
||||
// row groups
|
||||
const int id = ggml_get_op_params_i32(dst, 0);
|
||||
const int n_as = ggml_get_op_params_i32(dst, 1);
|
||||
const int n_as = src0->ne[2];
|
||||
|
||||
char * wdata_src1_end = (src1->type == vec_dot_type) ?
|
||||
(char *) params->wdata :
|
||||
@ -11047,7 +11039,7 @@ static void ggml_compute_forward_mul_mat_id(
|
||||
continue;
|
||||
}
|
||||
|
||||
const struct ggml_tensor * src0_cur = dst->src[cur_a + 2];
|
||||
size_t src0_offset = cur_a*src0->nb[2];
|
||||
|
||||
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
||||
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
||||
@ -11082,9 +11074,6 @@ static void ggml_compute_forward_mul_mat_id(
|
||||
continue;
|
||||
}
|
||||
|
||||
assert(ne12 % ne02 == 0);
|
||||
assert(ne13 % ne03 == 0);
|
||||
|
||||
// block-tiling attempt
|
||||
const int64_t blck_0 = 16;
|
||||
const int64_t blck_1 = 16;
|
||||
@ -11101,14 +11090,14 @@ static void ggml_compute_forward_mul_mat_id(
|
||||
const int64_t i11 = MMID_MATRIX_ROW(cur_a, _i11);
|
||||
|
||||
// broadcast src0 into src1
|
||||
const int64_t i03 = i13/r3;
|
||||
const int64_t i02 = i12/r2;
|
||||
//const int64_t i03 = i13/r3;
|
||||
//const int64_t i02 = i12/r2;
|
||||
|
||||
const int64_t i1 = i11;
|
||||
const int64_t i2 = i12;
|
||||
const int64_t i3 = i13;
|
||||
|
||||
const char * src0_row = (const char *) src0_cur->data + (0 + i02*nb02 + i03*nb03);
|
||||
const char * src0_row = (const char *) src0->data + src0_offset;
|
||||
|
||||
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
|
||||
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
|
||||
@ -18464,13 +18453,13 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
{
|
||||
cur = 0;
|
||||
const struct ggml_tensor * src0 = node->src[2];
|
||||
const struct ggml_tensor * src0 = node->src[0];
|
||||
const struct ggml_tensor * src1 = node->src[1];
|
||||
const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
|
||||
if (src1->type != vec_dot_type) {
|
||||
cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
|
||||
}
|
||||
const int n_as = ggml_get_op_params_i32(node, 1);
|
||||
const int n_as = src0->ne[2];
|
||||
cur += GGML_PAD(cur, sizeof(int64_t)); // align
|
||||
cur += n_as * sizeof(int64_t); // matrix_row_counts
|
||||
cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
|
||||
|
3
ggml.h
3
ggml.h
@ -1164,8 +1164,7 @@ extern "C" {
|
||||
// ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
|
||||
GGML_API struct ggml_tensor * ggml_mul_mat_id(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * const as[],
|
||||
int n_as,
|
||||
struct ggml_tensor * as,
|
||||
struct ggml_tensor * ids,
|
||||
int id,
|
||||
struct ggml_tensor * b);
|
||||
|
Loading…
Reference in New Issue
Block a user