CUDA: fix MMQ for non-contiguous src0, add tests (llama/10021)

* CUDA: fix MMQ for non-contiguous src0, add tests

* revise test code
This commit is contained in:
Johannes Gäßler
2024-10-24 11:09:36 +02:00
committed by Georgi Gerganov
parent 10eb603a3c
commit ab0385f43b
3 changed files with 13 additions and 11 deletions

View File

@ -8,8 +8,6 @@ void ggml_cuda_op_mul_mat_q(
const int64_t ne00 = src0->ne[0];
const int64_t nb01 = src0->nb[1];
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
GGML_ASSERT(ne10 % QK8_1 == 0);
@ -17,7 +15,7 @@ void ggml_cuda_op_mul_mat_q(
const int64_t ne0 = dst->ne[0];
const int64_t row_diff = row_high - row_low;
const int64_t stride00 = nb01 / ggml_type_size(src0->type);
const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;