mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-20 05:07:52 +00:00
ggml : add ALiBi support for ggml_soft_max_ext (llama/5488)
This commit is contained in:
parent
1b25d2fa0a
commit
eca5ff9868
14
ggml-alloc.c
14
ggml-alloc.c
@ -468,7 +468,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
|
|||||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||||
struct ggml_tensor * parent = node->src[i];
|
struct ggml_tensor * parent = node->src[i];
|
||||||
if (parent == NULL) {
|
if (parent == NULL) {
|
||||||
break;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// if the node's data is external, then we cannot re-use it
|
// if the node's data is external, then we cannot re-use it
|
||||||
@ -565,7 +565,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
|
|||||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||||
struct ggml_tensor * src = node->src[j];
|
struct ggml_tensor * src = node->src[j];
|
||||||
if (src == NULL) {
|
if (src == NULL) {
|
||||||
break;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_gallocr_hash_get(galloc, src)->n_children += 1;
|
ggml_gallocr_hash_get(galloc, src)->n_children += 1;
|
||||||
@ -599,7 +599,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
|
|||||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||||
struct ggml_tensor * parent = node->src[j];
|
struct ggml_tensor * parent = node->src[j];
|
||||||
if (parent == NULL) {
|
if (parent == NULL) {
|
||||||
break;
|
continue;
|
||||||
}
|
}
|
||||||
ggml_gallocr_allocate_node(galloc, parent, buffer_id);
|
ggml_gallocr_allocate_node(galloc, parent, buffer_id);
|
||||||
}
|
}
|
||||||
@ -611,7 +611,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
|
|||||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||||
struct ggml_tensor * parent = node->src[j];
|
struct ggml_tensor * parent = node->src[j];
|
||||||
if (parent == NULL) {
|
if (parent == NULL) {
|
||||||
break;
|
continue;
|
||||||
}
|
}
|
||||||
AT_PRINTF("%s", parent->name);
|
AT_PRINTF("%s", parent->name);
|
||||||
if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
|
if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
|
||||||
@ -624,7 +624,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
|
|||||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||||
struct ggml_tensor * parent = node->src[j];
|
struct ggml_tensor * parent = node->src[j];
|
||||||
if (parent == NULL) {
|
if (parent == NULL) {
|
||||||
break;
|
continue;
|
||||||
}
|
}
|
||||||
struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
|
struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
|
||||||
p_hn->n_children -= 1;
|
p_hn->n_children -= 1;
|
||||||
@ -810,7 +810,7 @@ static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph
|
|||||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||||
struct ggml_tensor * src = node->src[j];
|
struct ggml_tensor * src = node->src[j];
|
||||||
if (src == NULL) {
|
if (src == NULL) {
|
||||||
break;
|
continue;
|
||||||
}
|
}
|
||||||
if (!ggml_gallocr_node_needs_realloc(galloc, src, node_alloc, &node_alloc->src[j])) {
|
if (!ggml_gallocr_node_needs_realloc(galloc, src, node_alloc, &node_alloc->src[j])) {
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
@ -857,7 +857,7 @@ bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph)
|
|||||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||||
struct ggml_tensor * src = node->src[j];
|
struct ggml_tensor * src = node->src[j];
|
||||||
if (src == NULL) {
|
if (src == NULL) {
|
||||||
break;
|
continue;
|
||||||
}
|
}
|
||||||
ggml_gallocr_init_tensor(galloc, src, node_alloc->buffer_id, &node_alloc->src[j]);
|
ggml_gallocr_init_tensor(galloc, src, node_alloc->buffer_id, &node_alloc->src[j]);
|
||||||
}
|
}
|
||||||
|
257
ggml-cuda.cu
257
ggml-cuda.cu
@ -5956,149 +5956,31 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
|
|||||||
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
|
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
|
|
||||||
static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
|
|
||||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
|
||||||
const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
|
|
||||||
const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
|
|
||||||
|
|
||||||
const int tid = threadIdx.x;
|
|
||||||
const int rowx = blockIdx.x;
|
|
||||||
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
|
|
||||||
|
|
||||||
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
|
|
||||||
|
|
||||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
|
||||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
|
||||||
|
|
||||||
extern __shared__ half data_soft_max_f16[];
|
|
||||||
half * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication
|
|
||||||
// (shared memory) buffer to cache values between iterations:
|
|
||||||
half2 * vals = vals_smem ? (half2 *) (buf_iw + WARP_SIZE) : (half2 *) (dst + rowx*ncols_data);
|
|
||||||
// if the buffer is larger than max. shared memory per block, use dst as temp. buffer instead
|
|
||||||
// in that case col_smem == col_data must be enforced to avoid race conditions
|
|
||||||
|
|
||||||
half2 max_val = make_half2(-INFINITY, -INFINITY);
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
|
|
||||||
const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
|
|
||||||
const int col_smem = vals_smem ? col0 + tid : col_data;
|
|
||||||
|
|
||||||
const int ix = rowx*ncols_data + col_data;
|
|
||||||
const int iy = rowy*ncols_data + col_data;
|
|
||||||
|
|
||||||
half2 val;
|
|
||||||
if (need_check && col_data + 0 >= ncols_data) {
|
|
||||||
val.x = -INFINITY;
|
|
||||||
} else {
|
|
||||||
val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f);
|
|
||||||
}
|
|
||||||
if (need_check && col_data + WARP_SIZE >= ncols_data) {
|
|
||||||
val.y = -INFINITY;
|
|
||||||
} else {
|
|
||||||
val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
|
|
||||||
}
|
|
||||||
if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) {
|
|
||||||
vals[col_smem] = val;
|
|
||||||
}
|
|
||||||
max_val = __hmax2(max_val, val);
|
|
||||||
}
|
|
||||||
|
|
||||||
// find the max value in the block
|
|
||||||
max_val = warp_reduce_max(max_val);
|
|
||||||
if (block_size > WARP_SIZE) {
|
|
||||||
if (warp_id == 0) {
|
|
||||||
buf_iw[lane_id] = -INFINITY;
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
if (lane_id == 0) {
|
|
||||||
buf_iw[warp_id] = __hmax(max_val.x, max_val.y);
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
max_val = __half2half2(buf_iw[lane_id]);
|
|
||||||
max_val = warp_reduce_max(max_val);
|
|
||||||
} else {
|
|
||||||
max_val = __half2half2(__hmax(max_val.x, max_val.y));
|
|
||||||
}
|
|
||||||
|
|
||||||
half2 tmp = make_half2(0.0f, 0.0f); // partial sums
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
|
|
||||||
const int col_smem = vals_smem ? col0 + tid : 2*col0 + 2*warp_id*WARP_SIZE + lane_id;
|
|
||||||
|
|
||||||
if (ncols_template == 0 && col_smem >= (vals_smem ? ncols_smem : ncols_data)) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
const half2 val = h2exp(vals[col_smem] - max_val);
|
|
||||||
|
|
||||||
tmp += val;
|
|
||||||
vals[col_smem] = val;
|
|
||||||
}
|
|
||||||
|
|
||||||
// find the sum of exps in the block
|
|
||||||
tmp = warp_reduce_sum(tmp);
|
|
||||||
if (block_size > WARP_SIZE) {
|
|
||||||
if (warp_id == 0) {
|
|
||||||
buf_iw[lane_id] = 0.0f;
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
if (lane_id == 0) {
|
|
||||||
buf_iw[warp_id] = tmp.x + tmp.y;
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
tmp = __half2half2(buf_iw[lane_id]);
|
|
||||||
tmp = warp_reduce_sum(tmp);
|
|
||||||
} else {
|
|
||||||
tmp = __half2half2(tmp.x + tmp.y);
|
|
||||||
}
|
|
||||||
|
|
||||||
const half2 inv_sum = make_half2(1.0f, 1.0f) / tmp;
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
|
|
||||||
const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
|
|
||||||
const int col_smem = vals_smem ? col0 + tid : col_data;
|
|
||||||
|
|
||||||
const int idst = rowx*ncols_data + col_data;
|
|
||||||
const half2 result = vals[col_smem] * inv_sum;
|
|
||||||
|
|
||||||
if (need_check && col_data + 0 >= ncols_data) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
dst[idst] = result.x;
|
|
||||||
|
|
||||||
if (need_check && col_data + WARP_SIZE >= ncols_data) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
dst[idst + WARP_SIZE] = result.y;
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
(void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
|
|
||||||
NO_DEVICE_CODE;
|
|
||||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
|
||||||
}
|
|
||||||
|
|
||||||
template <bool vals_smem, int ncols_template, int block_size_template>
|
template <bool vals_smem, int ncols_template, int block_size_template>
|
||||||
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
|
static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
|
||||||
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
||||||
|
|
||||||
const int tid = threadIdx.x;
|
const int tid = threadIdx.x;
|
||||||
const int rowx = blockIdx.x;
|
const int rowx = blockIdx.x;
|
||||||
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
|
const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
|
||||||
|
|
||||||
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
|
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
|
||||||
|
|
||||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||||
|
|
||||||
|
float slope = 0.0f;
|
||||||
|
|
||||||
|
// ALiBi
|
||||||
|
if (max_bias > 0.0f) {
|
||||||
|
const int h = rowx/nrows_y; // head index
|
||||||
|
|
||||||
|
const float base = h < n_head_log2 ? m0 : m1;
|
||||||
|
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||||
|
|
||||||
|
slope = powf(base, exp);
|
||||||
|
}
|
||||||
|
|
||||||
extern __shared__ float data_soft_max_f32[];
|
extern __shared__ float data_soft_max_f32[];
|
||||||
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
|
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
|
||||||
// shared memory buffer to cache values between iterations:
|
// shared memory buffer to cache values between iterations:
|
||||||
@ -6117,7 +5999,8 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
|
|||||||
const int ix = rowx*ncols + col;
|
const int ix = rowx*ncols + col;
|
||||||
const int iy = rowy*ncols + col;
|
const int iy = rowy*ncols + col;
|
||||||
|
|
||||||
const float val = x[ix]*scale + (y ? y[iy] : 0.0f);
|
const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + slope*pos[col];
|
||||||
|
|
||||||
vals[col] = val;
|
vals[col] = val;
|
||||||
max_val = max(max_val, val);
|
max_val = max(max_val, val);
|
||||||
}
|
}
|
||||||
@ -7589,89 +7472,53 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
|
|||||||
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
|
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
|
static void soft_max_f32_cuda(const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
|
||||||
int nth = WARP_SIZE;
|
|
||||||
while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
|
||||||
const dim3 block_dims(nth, 1, 1);
|
|
||||||
const dim3 block_nums(nrows_x, 1, 1);
|
|
||||||
const size_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half);
|
|
||||||
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
|
||||||
if (shmem <= g_device_caps[g_main_device].smpb) {
|
|
||||||
switch (ncols_x) {
|
|
||||||
case 32:
|
|
||||||
soft_max_f16<true, 32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
|
||||||
break;
|
|
||||||
case 64:
|
|
||||||
soft_max_f16<true, 64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
|
||||||
break;
|
|
||||||
case 128:
|
|
||||||
soft_max_f16<true, 128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
|
||||||
break;
|
|
||||||
case 256:
|
|
||||||
soft_max_f16<true, 256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
|
||||||
break;
|
|
||||||
case 512:
|
|
||||||
soft_max_f16<true, 512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
|
||||||
break;
|
|
||||||
case 1024:
|
|
||||||
soft_max_f16<true, 1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
|
||||||
break;
|
|
||||||
case 2048:
|
|
||||||
soft_max_f16<true, 2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
|
||||||
break;
|
|
||||||
case 4096:
|
|
||||||
soft_max_f16<true, 4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
soft_max_f16<true, 0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
const size_t shmem_low = WARP_SIZE*sizeof(half);
|
|
||||||
soft_max_f16<false, 0, 0, true><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
|
|
||||||
int nth = WARP_SIZE;
|
int nth = WARP_SIZE;
|
||||||
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
||||||
const dim3 block_dims(nth, 1, 1);
|
const dim3 block_dims(nth, 1, 1);
|
||||||
const dim3 block_nums(nrows_x, 1, 1);
|
const dim3 block_nums(nrows_x, 1, 1);
|
||||||
const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
|
const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
|
||||||
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
||||||
|
|
||||||
|
const uint32_t n_head_kv = nrows_x/nrows_y;
|
||||||
|
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
||||||
|
|
||||||
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||||
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||||
|
|
||||||
if (shmem < g_device_caps[g_main_device].smpb) {
|
if (shmem < g_device_caps[g_main_device].smpb) {
|
||||||
switch (ncols_x) {
|
switch (ncols_x) {
|
||||||
case 32:
|
case 32:
|
||||||
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||||
break;
|
break;
|
||||||
case 64:
|
case 64:
|
||||||
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||||
break;
|
break;
|
||||||
case 128:
|
case 128:
|
||||||
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||||
break;
|
break;
|
||||||
case 256:
|
case 256:
|
||||||
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||||
break;
|
break;
|
||||||
case 512:
|
case 512:
|
||||||
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||||
break;
|
break;
|
||||||
case 1024:
|
case 1024:
|
||||||
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||||
break;
|
break;
|
||||||
case 2048:
|
case 2048:
|
||||||
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||||
break;
|
break;
|
||||||
case 4096:
|
case 4096:
|
||||||
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
const size_t shmem_low = WARP_SIZE*sizeof(float);
|
const size_t shmem_low = WARP_SIZE*sizeof(float);
|
||||||
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -9092,28 +8939,34 @@ static void ggml_cuda_op_soft_max(
|
|||||||
|
|
||||||
const int64_t ne00 = src0->ne[0];
|
const int64_t ne00 = src0->ne[0];
|
||||||
const int64_t nrows_x = ggml_nrows(src0);
|
const int64_t nrows_x = ggml_nrows(src0);
|
||||||
const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
|
const int64_t nrows_y = src0->ne[1];
|
||||||
|
|
||||||
float scale = 1.0f;
|
float scale = 1.0f;
|
||||||
memcpy(&scale, dst->op_params, sizeof(float));
|
float max_bias = 0.0f;
|
||||||
|
|
||||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION >= CUDART_HMAX
|
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
||||||
#ifdef GGML_CUDA_F16
|
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
||||||
const bool use_f16_soft_max = true;
|
|
||||||
#else
|
|
||||||
const bool use_f16_soft_max = false;
|
|
||||||
#endif // GGML_CUDA_F16
|
|
||||||
#else
|
|
||||||
const bool use_f16_soft_max = false;
|
|
||||||
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_HMAX
|
|
||||||
|
|
||||||
if (use_f16_soft_max) {
|
// positions tensor
|
||||||
soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
|
float * src2_dd = dst_dd; // default to avoid null checks in the kernel
|
||||||
|
cuda_pool_alloc<float> src2_f;
|
||||||
|
|
||||||
|
ggml_tensor * src2 = dst->src[2];
|
||||||
|
const bool use_src2 = src2 != nullptr;
|
||||||
|
|
||||||
|
if (use_src2) {
|
||||||
|
const bool src2_on_device = use_src2 && src2->backend == GGML_BACKEND_GPU;
|
||||||
|
ggml_tensor_extra_gpu * src2_extra = use_src2 ? (ggml_tensor_extra_gpu *) src2->extra : nullptr;
|
||||||
|
|
||||||
|
if (src2_on_device) {
|
||||||
|
src2_dd = (float *) src2_extra->data_device[g_main_device];
|
||||||
} else {
|
} else {
|
||||||
soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
|
src2_dd = src2_f.alloc(ggml_nelements(src2));
|
||||||
|
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
(void) dst;
|
soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_cuda_op_scale(
|
static void ggml_cuda_op_scale(
|
||||||
|
33
ggml-metal.m
33
ggml-metal.m
@ -737,6 +737,7 @@ static bool ggml_metal_graph_compute(
|
|||||||
|
|
||||||
size_t offs_src0 = 0;
|
size_t offs_src0 = 0;
|
||||||
size_t offs_src1 = 0;
|
size_t offs_src1 = 0;
|
||||||
|
size_t offs_src2 = 0;
|
||||||
size_t offs_dst = 0;
|
size_t offs_dst = 0;
|
||||||
|
|
||||||
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
|
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
|
||||||
@ -755,6 +756,7 @@ static bool ggml_metal_graph_compute(
|
|||||||
|
|
||||||
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
|
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
|
||||||
struct ggml_tensor * src1 = gf->nodes[i]->src[1];
|
struct ggml_tensor * src1 = gf->nodes[i]->src[1];
|
||||||
|
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
||||||
struct ggml_tensor * dst = gf->nodes[i];
|
struct ggml_tensor * dst = gf->nodes[i];
|
||||||
|
|
||||||
switch (dst->op) {
|
switch (dst->op) {
|
||||||
@ -816,6 +818,7 @@ static bool ggml_metal_graph_compute(
|
|||||||
|
|
||||||
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
|
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
|
||||||
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
|
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
|
||||||
|
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
|
||||||
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
|
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
|
||||||
|
|
||||||
//GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
|
//GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
|
||||||
@ -1198,6 +1201,15 @@ static bool ggml_metal_graph_compute(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const float scale = ((float *) dst->op_params)[0];
|
const float scale = ((float *) dst->op_params)[0];
|
||||||
|
const float max_bias = ((float *) dst->op_params)[1];
|
||||||
|
|
||||||
|
const int64_t nrows_x = ggml_nrows(src0);
|
||||||
|
const int64_t nrows_y = src0->ne[1];
|
||||||
|
const uint32_t n_head_kv = nrows_x/nrows_y;
|
||||||
|
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
||||||
|
|
||||||
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||||
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||||
|
|
||||||
[encoder setComputePipelineState:pipeline];
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
@ -1206,11 +1218,20 @@ static bool ggml_metal_graph_compute(
|
|||||||
} else {
|
} else {
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||||
}
|
}
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
if (id_src2) {
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
} else {
|
||||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
|
||||||
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
|
}
|
||||||
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||||
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
|
||||||
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
|
||||||
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
|
||||||
|
[encoder setBytes:&scale length:sizeof(scale) atIndex:7];
|
||||||
|
[encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
|
||||||
|
[encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
|
||||||
|
[encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
|
||||||
|
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
|
||||||
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
@ -1523,8 +1544,6 @@ static bool ggml_metal_graph_compute(
|
|||||||
// max size of the src1ids array in the kernel stack
|
// max size of the src1ids array in the kernel stack
|
||||||
GGML_ASSERT(ne11 <= 512);
|
GGML_ASSERT(ne11 <= 512);
|
||||||
|
|
||||||
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
|
||||||
|
|
||||||
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
||||||
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
||||||
const int64_t ne22 = src2 ? src2->ne[2] : 0;
|
const int64_t ne22 = src2 ? src2->ne[2] : 0;
|
||||||
|
@ -351,11 +351,16 @@ kernel void kernel_sum_rows(
|
|||||||
kernel void kernel_soft_max(
|
kernel void kernel_soft_max(
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
|
device const float * src2,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
constant float & scale,
|
constant float & scale,
|
||||||
|
constant float & max_bias,
|
||||||
|
constant float & m0,
|
||||||
|
constant float & m1,
|
||||||
|
constant uint32_t & n_head_log2,
|
||||||
threadgroup float * buf [[threadgroup(0)]],
|
threadgroup float * buf [[threadgroup(0)]],
|
||||||
uint tgpig[[threadgroup_position_in_grid]],
|
uint tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tpitg[[thread_position_in_threadgroup]],
|
uint tpitg[[thread_position_in_threadgroup]],
|
||||||
@ -368,13 +373,26 @@ kernel void kernel_soft_max(
|
|||||||
|
|
||||||
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||||
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
|
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
|
||||||
|
device const float * ppos = src2 != src0 ? src2 : nullptr;
|
||||||
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||||
|
|
||||||
|
float slope = 0.0f;
|
||||||
|
|
||||||
|
// ALiBi
|
||||||
|
if (max_bias > 0.0f) {
|
||||||
|
const int64_t h = i02;
|
||||||
|
|
||||||
|
const float base = h < n_head_log2 ? m0 : m1;
|
||||||
|
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||||
|
|
||||||
|
slope = pow(base, exp);
|
||||||
|
}
|
||||||
|
|
||||||
// parallel max
|
// parallel max
|
||||||
float lmax = -INFINITY;
|
float lmax = -INFINITY;
|
||||||
|
|
||||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||||
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// find the max value in the block
|
// find the max value in the block
|
||||||
@ -399,7 +417,7 @@ kernel void kernel_soft_max(
|
|||||||
// parallel sum
|
// parallel sum
|
||||||
float lsum = 0.0f;
|
float lsum = 0.0f;
|
||||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||||
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val);
|
||||||
lsum += exp_psrc0;
|
lsum += exp_psrc0;
|
||||||
pdst[i00] = exp_psrc0;
|
pdst[i00] = exp_psrc0;
|
||||||
}
|
}
|
||||||
@ -437,11 +455,16 @@ kernel void kernel_soft_max(
|
|||||||
kernel void kernel_soft_max_4(
|
kernel void kernel_soft_max_4(
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
|
device const float * src2,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
constant float & scale,
|
constant float & scale,
|
||||||
|
constant float & max_bias,
|
||||||
|
constant float & m0,
|
||||||
|
constant float & m1,
|
||||||
|
constant uint32_t & n_head_log2,
|
||||||
threadgroup float * buf [[threadgroup(0)]],
|
threadgroup float * buf [[threadgroup(0)]],
|
||||||
uint tgpig[[threadgroup_position_in_grid]],
|
uint tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tpitg[[thread_position_in_threadgroup]],
|
uint tpitg[[thread_position_in_threadgroup]],
|
||||||
@ -454,13 +477,25 @@ kernel void kernel_soft_max_4(
|
|||||||
|
|
||||||
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||||
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
||||||
|
device const float4 * ppos = src2 != src0 ? (device const float4 *)(src2) : nullptr;
|
||||||
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||||
|
|
||||||
|
float slope = 0.0f;
|
||||||
|
|
||||||
|
if (max_bias > 0.0f) {
|
||||||
|
const int64_t h = i02;
|
||||||
|
|
||||||
|
const float base = h < n_head_log2 ? m0 : m1;
|
||||||
|
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||||
|
|
||||||
|
slope = pow(base, exp);
|
||||||
|
}
|
||||||
|
|
||||||
// parallel max
|
// parallel max
|
||||||
float4 lmax4 = -INFINITY;
|
float4 lmax4 = -INFINITY;
|
||||||
|
|
||||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||||
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]);
|
||||||
}
|
}
|
||||||
|
|
||||||
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
||||||
@ -486,7 +521,7 @@ kernel void kernel_soft_max_4(
|
|||||||
// parallel sum
|
// parallel sum
|
||||||
float4 lsum4 = 0.0f;
|
float4 lsum4 = 0.0f;
|
||||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||||
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val);
|
||||||
lsum4 += exp_psrc4;
|
lsum4 += exp_psrc4;
|
||||||
pdst4[i00] = exp_psrc4;
|
pdst4[i00] = exp_psrc4;
|
||||||
}
|
}
|
||||||
|
82
ggml.c
82
ggml.c
@ -5096,16 +5096,28 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
|||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
struct ggml_tensor * mask,
|
struct ggml_tensor * mask,
|
||||||
|
struct ggml_tensor * pos,
|
||||||
float scale,
|
float scale,
|
||||||
|
float max_bias,
|
||||||
bool inplace) {
|
bool inplace) {
|
||||||
GGML_ASSERT(ggml_is_contiguous(a));
|
GGML_ASSERT(ggml_is_contiguous(a));
|
||||||
|
|
||||||
if (mask) {
|
if (mask) {
|
||||||
GGML_ASSERT(ggml_is_contiguous(mask));
|
GGML_ASSERT(ggml_is_contiguous(mask));
|
||||||
GGML_ASSERT(mask->ne[2] == 1);
|
GGML_ASSERT(ggml_is_matrix(mask));
|
||||||
GGML_ASSERT(mask->ne[3] == 1);
|
|
||||||
GGML_ASSERT(ggml_can_repeat_rows(mask, a));
|
GGML_ASSERT(ggml_can_repeat_rows(mask, a));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (pos) {
|
||||||
|
GGML_ASSERT(ggml_is_vector(pos));
|
||||||
|
GGML_ASSERT(pos->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(pos->ne[0] == a->ne[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (max_bias > 0.0f) {
|
||||||
|
GGML_ASSERT(pos);
|
||||||
|
}
|
||||||
|
|
||||||
bool is_node = false;
|
bool is_node = false;
|
||||||
|
|
||||||
if (a->grad) {
|
if (a->grad) {
|
||||||
@ -5114,13 +5126,14 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
|||||||
|
|
||||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||||
|
|
||||||
float params[] = { scale };
|
float params[] = { scale, max_bias };
|
||||||
ggml_set_op_params(result, params, sizeof(params));
|
ggml_set_op_params(result, params, sizeof(params));
|
||||||
|
|
||||||
result->op = GGML_OP_SOFT_MAX;
|
result->op = GGML_OP_SOFT_MAX;
|
||||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||||
result->src[0] = a;
|
result->src[0] = a;
|
||||||
result->src[1] = mask;
|
result->src[1] = mask;
|
||||||
|
result->src[2] = pos;
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
@ -5128,21 +5141,23 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
|||||||
struct ggml_tensor * ggml_soft_max(
|
struct ggml_tensor * ggml_soft_max(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a) {
|
struct ggml_tensor * a) {
|
||||||
return ggml_soft_max_impl(ctx, a, NULL, 1.0f, false);
|
return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * ggml_soft_max_inplace(
|
struct ggml_tensor * ggml_soft_max_inplace(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a) {
|
struct ggml_tensor * a) {
|
||||||
return ggml_soft_max_impl(ctx, a, NULL, 1.0f, true);
|
return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * ggml_soft_max_ext(
|
struct ggml_tensor * ggml_soft_max_ext(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
struct ggml_tensor * mask,
|
struct ggml_tensor * mask,
|
||||||
float scale) {
|
struct ggml_tensor * pos,
|
||||||
return ggml_soft_max_impl(ctx, a, mask, scale, false);
|
float scale,
|
||||||
|
float max_bias) {
|
||||||
|
return ggml_soft_max_impl(ctx, a, mask, pos, scale, max_bias, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ggml_soft_max_back
|
// ggml_soft_max_back
|
||||||
@ -11495,6 +11510,7 @@ static void ggml_compute_forward_soft_max_f32(
|
|||||||
const struct ggml_compute_params * params,
|
const struct ggml_compute_params * params,
|
||||||
const struct ggml_tensor * src0,
|
const struct ggml_tensor * src0,
|
||||||
const struct ggml_tensor * src1,
|
const struct ggml_tensor * src1,
|
||||||
|
const struct ggml_tensor * src2,
|
||||||
struct ggml_tensor * dst) {
|
struct ggml_tensor * dst) {
|
||||||
assert(ggml_is_contiguous(dst));
|
assert(ggml_is_contiguous(dst));
|
||||||
assert(ggml_are_same_shape(src0, dst));
|
assert(ggml_are_same_shape(src0, dst));
|
||||||
@ -11504,15 +11520,28 @@ static void ggml_compute_forward_soft_max_f32(
|
|||||||
}
|
}
|
||||||
|
|
||||||
float scale = 1.0f;
|
float scale = 1.0f;
|
||||||
|
float max_bias = 0.0f;
|
||||||
|
|
||||||
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
||||||
|
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
||||||
|
|
||||||
// TODO: handle transposed/permuted matrices
|
// TODO: handle transposed/permuted matrices
|
||||||
|
|
||||||
const int ith = params->ith;
|
const int ith = params->ith;
|
||||||
const int nth = params->nth;
|
const int nth = params->nth;
|
||||||
|
|
||||||
|
GGML_TENSOR_UNARY_OP_LOCALS
|
||||||
|
|
||||||
const int64_t ne11 = src1 ? src1->ne[1] : 1;
|
const int64_t ne11 = src1 ? src1->ne[1] : 1;
|
||||||
|
|
||||||
|
// TODO: is this supposed to be ceil instead of floor?
|
||||||
|
// https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
|
||||||
|
const uint32_t n_head_kv = ne02;
|
||||||
|
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv));
|
||||||
|
|
||||||
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||||
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||||
|
|
||||||
const int nc = src0->ne[0];
|
const int nc = src0->ne[0];
|
||||||
const int nr = ggml_nrows(src0);
|
const int nr = ggml_nrows(src0);
|
||||||
|
|
||||||
@ -11525,6 +11554,9 @@ static void ggml_compute_forward_soft_max_f32(
|
|||||||
|
|
||||||
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
|
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
|
||||||
|
|
||||||
|
// when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
|
||||||
|
float * pos = src2 ? (float *) src2->data : src0->data;
|
||||||
|
|
||||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||||
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
|
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
|
||||||
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
|
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
|
||||||
@ -11538,6 +11570,16 @@ static void ggml_compute_forward_soft_max_f32(
|
|||||||
ggml_vec_acc_f32(nc, wp, mp);
|
ggml_vec_acc_f32(nc, wp, mp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ALiBi bias
|
||||||
|
if (max_bias > 0.0f) {
|
||||||
|
const uint32_t h = (i1/ne01)%ne02; // head
|
||||||
|
const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
|
||||||
|
|
||||||
|
for (int i = 0; i < nc; i++) {
|
||||||
|
wp[i] = wp[i] + slope*pos[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
for (int i = 0; i < nc; ++i) {
|
for (int i = 0; i < nc; ++i) {
|
||||||
//printf("p[%d] = %f\n", i, p[i]);
|
//printf("p[%d] = %f\n", i, p[i]);
|
||||||
@ -11582,11 +11624,12 @@ static void ggml_compute_forward_soft_max(
|
|||||||
const struct ggml_compute_params * params,
|
const struct ggml_compute_params * params,
|
||||||
const struct ggml_tensor * src0,
|
const struct ggml_tensor * src0,
|
||||||
const struct ggml_tensor * src1,
|
const struct ggml_tensor * src1,
|
||||||
|
const struct ggml_tensor * src2,
|
||||||
struct ggml_tensor * dst) {
|
struct ggml_tensor * dst) {
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_soft_max_f32(params, src0, src1, dst);
|
ggml_compute_forward_soft_max_f32(params, src0, src1, src2, dst);
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
@ -11730,14 +11773,8 @@ static void ggml_compute_forward_alibi_f32(
|
|||||||
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
||||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
||||||
|
|
||||||
for (int64_t i = 0; i < ne0; i++) {
|
|
||||||
for (int64_t j = 0; j < ne1; j++) {
|
|
||||||
for (int64_t k = 0; k < ne2_ne3; k++) {
|
for (int64_t k = 0; k < ne2_ne3; k++) {
|
||||||
float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
|
|
||||||
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
|
||||||
|
|
||||||
// TODO: k*nb2 or k*nb3
|
// TODO: k*nb2 or k*nb3
|
||||||
|
|
||||||
float m_k;
|
float m_k;
|
||||||
|
|
||||||
if (k < n_heads_log2_floor) {
|
if (k < n_heads_log2_floor) {
|
||||||
@ -11746,6 +11783,10 @@ static void ggml_compute_forward_alibi_f32(
|
|||||||
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
|
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < ne0; i++) {
|
||||||
|
for (int64_t j = 0; j < ne1; j++) {
|
||||||
|
float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
|
||||||
|
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
||||||
pdst[0] = i * m_k + src[0];
|
pdst[0] = i * m_k + src[0];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -11790,14 +11831,8 @@ static void ggml_compute_forward_alibi_f16(
|
|||||||
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
||||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
||||||
|
|
||||||
for (int i = 0; i < ne0; i++) {
|
|
||||||
for (int j = 0; j < ne1; j++) {
|
|
||||||
for (int k = 0; k < ne2_ne3; k++) {
|
for (int k = 0; k < ne2_ne3; k++) {
|
||||||
ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
|
|
||||||
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
|
||||||
|
|
||||||
// TODO: k*nb2 or k*nb3
|
// TODO: k*nb2 or k*nb3
|
||||||
|
|
||||||
float m_k;
|
float m_k;
|
||||||
|
|
||||||
if (k < n_heads_log2_floor) {
|
if (k < n_heads_log2_floor) {
|
||||||
@ -11806,6 +11841,11 @@ static void ggml_compute_forward_alibi_f16(
|
|||||||
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
|
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < ne0; i++) {
|
||||||
|
for (int j = 0; j < ne1; j++) {
|
||||||
|
ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
|
||||||
|
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
||||||
|
|
||||||
// we return F32
|
// we return F32
|
||||||
pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
|
pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
|
||||||
}
|
}
|
||||||
@ -15116,7 +15156,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
|
ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_SOFT_MAX_BACK:
|
case GGML_OP_SOFT_MAX_BACK:
|
||||||
{
|
{
|
||||||
|
13
ggml.h
13
ggml.h
@ -1383,13 +1383,17 @@ extern "C" {
|
|||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a);
|
struct ggml_tensor * a);
|
||||||
|
|
||||||
// fused soft_max(a*scale + mask)
|
// fused soft_max(a*scale + mask + pos[i]*(ALiBi slope))
|
||||||
// mask is optional
|
// mask is optional
|
||||||
|
// pos is required when max_bias > 0.0f
|
||||||
|
// max_bias = 0.0f for no ALiBi
|
||||||
GGML_API struct ggml_tensor * ggml_soft_max_ext(
|
GGML_API struct ggml_tensor * ggml_soft_max_ext(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
struct ggml_tensor * mask,
|
struct ggml_tensor * mask,
|
||||||
float scale);
|
struct ggml_tensor * pos,
|
||||||
|
float scale,
|
||||||
|
float max_bias);
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_soft_max_back(
|
GGML_API struct ggml_tensor * ggml_soft_max_back(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
@ -1491,12 +1495,13 @@ extern "C" {
|
|||||||
|
|
||||||
// alibi position embedding
|
// alibi position embedding
|
||||||
// in-place, returns view(a)
|
// in-place, returns view(a)
|
||||||
GGML_API struct ggml_tensor * ggml_alibi(
|
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_alibi(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
int n_past,
|
int n_past,
|
||||||
int n_head,
|
int n_head,
|
||||||
float bias_max);
|
float bias_max),
|
||||||
|
"use ggml_soft_max_ext instead (will be removed in Mar 2024)");
|
||||||
|
|
||||||
// clamp
|
// clamp
|
||||||
// in-place, returns view(a)
|
// in-place, returns view(a)
|
||||||
|
Loading…
Reference in New Issue
Block a user