diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 19d22d63..eac8f557 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -3847,21 +3847,27 @@ static void concat_f32(const float *x,const float *y, float *dst, const int ne } } -static void upscale_f32(const float *x, float *dst, const int ne00, const int nb02, const int scale_factor, - const sycl::nd_item<3> &item_ct1) { - int ne0 = ne00 * scale_factor; - int nidx = item_ct1.get_local_id(2) + - item_ct1.get_group(2) * item_ct1.get_local_range(2); - if (nidx >= ne0) { +static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01, + const int nb02, const int nb03, const int ne10, const int ne11, + const int ne12, const int ne13, const float sf0, const float sf1, + const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) { + int index = item_ct1.get_local_id(0) + + item_ct1.get_group(0) * item_ct1.get_local_range(0); + if (index >= ne10 * ne11 * ne12 * ne13) { return; } // operation - int i00 = nidx / scale_factor; - int i01 = item_ct1.get_group(1) / scale_factor; - int offset_src = i00 + i01 * ne00 + item_ct1.get_group(0) * nb02; - int offset_dst = nidx + item_ct1.get_group(1) * ne0 + - item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1); - dst[offset_dst] = x[offset_src]; + int i10 = index % ne10; + int i11 = (index / ne10) % ne11; + int i12 = (index / (ne10 * ne11)) % ne12; + int i13 = (index / (ne10 * ne11 * ne12)) % ne13; + + int i00 = i10 / sf0; + int i01 = i11 / sf1; + int i02 = i12 / sf2; + int i03 = i13 / sf3; + + dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00); } static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02, @@ -10085,18 +10091,17 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst, }); } -static void upscale_f32_sycl(const float *x, float *dst, const int ne00, - const int ne01, const int ne02, - const int scale_factor, dpct::queue_ptr stream) { - int ne0 = (ne00 * scale_factor); - int num_blocks = (ne0 + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; - sycl::range<3> gridDim(ne02, (ne01 * scale_factor), num_blocks); +static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01, + const int nb02, const int nb03, const int ne10, const int ne11, + const int ne12, const int ne13, const float sf0, const float sf1, + const float sf2, const float sf3, dpct::queue_ptr stream) { + int dst_size = ne10 * ne11 * ne12 * ne13; + int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; + sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE); stream->parallel_for( - sycl::nd_range<3>(gridDim * - sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - upscale_f32(x, dst, ne00, ne00 * ne01, scale_factor, item_ct1); + sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1); }); } @@ -13985,15 +13990,15 @@ inline void ggml_sycl_op_upscale(const ggml_tensor *src0, GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors -#pragma message("TODO: generalize upscale operator") -#pragma message(" https://github.com/ggerganov/ggml/pull/814") - GGML_ASSERT(false && "TODO: generalize upscale operator"); + const float sf0 = (float)dst->ne[0]/src0->ne[0]; + const float sf1 = (float)dst->ne[1]/src0->ne[1]; + const float sf2 = (float)dst->ne[2]/src0->ne[2]; + const float sf3 = (float)dst->ne[3]/src0->ne[3]; - const int scale_factor = dst->op_params[0]; - - upscale_f32_sycl(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream); + upscale_f32_sycl(src0_dd, dst_dd, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, + main_stream); (void) src1; (void) dst;