vulkan: skip integer div/mod in get_offsets for batch_idx==0 (llama/10506)

This commit is contained in:
Jeff Bolz 2024-11-27 01:08:54 -06:00 committed by Georgi Gerganov
parent 2f16e51553
commit 2d6e9dd723

View File

@ -52,13 +52,16 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
#endif
#ifndef MUL_MAT_ID
const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12;
uint batch_idx_a = 0;
if (batch_idx != 0) {
const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12;
const uint i03 = i13 / p.broadcast3;
const uint i02 = i12 / p.broadcast2;
const uint i03 = i13 / p.broadcast3;
const uint i02 = i12 / p.broadcast2;
const uint batch_idx_a = i03 * p.ne02 + i02;
batch_idx_a = i03 * p.ne02 + i02;
}
#else
const uint expert_id = data_ids[expert_idx];
#endif