diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index ef7d5fa0..05984d8c 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -333,10 +333,11 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer, assert(tensor->view_src->buffer->buft == buffer->buft); return GGML_STATUS_SUCCESS; } - - ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; - tensor->extra = extra; - ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx. + if (tensor->type == GGML_TYPE_Q4_0) { + ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; + tensor->extra = extra; + ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx. + } if (ggml_is_quantized(tensor->type)) { // initialize padding to 0 to avoid possible NaN values @@ -486,6 +487,22 @@ catch (sycl::exception const &exc) { std::exit(1); } +static void ggml_backend_sycl_buffer_reset(ggml_backend_buffer_t buffer) { + GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__); + if (buffer == nullptr) { + return; + } + + ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context; + + if (ctx != nullptr) { + for (ggml_tensor_extra_gpu * extra : ctx->tensor_extras) { + release_extra_gpu(extra); + } + ctx->tensor_extras.clear(); // reset the tensor_extras vector + } +} + static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = { /* .free_buffer = */ ggml_backend_sycl_buffer_free_buffer, /* .get_base = */ ggml_backend_sycl_buffer_get_base, @@ -495,7 +512,7 @@ static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = { /* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor, /* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor, /* .clear = */ ggml_backend_sycl_buffer_clear, - /* .reset = */ NULL, + /* .reset = */ ggml_backend_sycl_buffer_reset, }; // sycl buffer type @@ -576,7 +593,6 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) { static std::mutex mutex; std::lock_guard lock(mutex); - GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n"); auto dev_count = ggml_backend_sycl_get_device_count(); @@ -3761,7 +3777,6 @@ bool ggml_backend_is_sycl(ggml_backend_t backend) { } int ggml_backend_sycl_get_device_count() { - GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_count\n"); return ggml_sycl_info().device_count; }