mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-04-27 06:20:07 +00:00
vulkan: use aligned loads for flash attention mask (llama/12853)
Rewrite the stride logic for the mask tensor in the flash attention (FA) shader so that the stride is forced to be aligned, which allows more efficient loads to be used.
This commit is contained in:
parent
e8ee32d12d
commit
751e42b21e
@ -201,6 +201,11 @@ void main() {
|
|||||||
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
|
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
|
||||||
uint32_t k_stride = p.nb11;
|
uint32_t k_stride = p.nb11;
|
||||||
uint32_t v_stride = p.nb21;
|
uint32_t v_stride = p.nb21;
|
||||||
|
// When using grouped query attention, all rows use the same mask (stride 0).
|
||||||
|
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
|
||||||
|
// that prevents the compiler from folding the "&" through the select
|
||||||
|
// and breaking the alignment detection.
|
||||||
|
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
|
||||||
// hint to the compiler that strides are aligned for the aligned variant of the shader
|
// hint to the compiler that strides are aligned for the aligned variant of the shader
|
||||||
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
|
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
|
||||||
{
|
{
|
||||||
@ -209,6 +214,7 @@ void main() {
|
|||||||
k_stride &= ~7;
|
k_stride &= ~7;
|
||||||
v_stride &= ~7;
|
v_stride &= ~7;
|
||||||
#endif
|
#endif
|
||||||
|
m_stride &= ~7;
|
||||||
}
|
}
|
||||||
tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
|
tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
|
||||||
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
|
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
|
||||||
@ -261,10 +267,7 @@ void main() {
|
|||||||
if (p.mask != 0) {
|
if (p.mask != 0) {
|
||||||
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
|
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
|
||||||
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
|
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
|
||||||
// When using grouped query attention, all rows use the same mask.
|
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
||||||
if (p.gqa_ratio > 1) {
|
|
||||||
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, 0, 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user