mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-02-06 10:59:11 +00:00
ggml: fix gradient allocation logic (ggml/966)
* ggml: fix gradient allocation logic * gradient allocation in ggml_build_backward_expand * fixup * fix test-backend-ops grad * suggestions by slaren * fix test1.c * fix legacy opt API * fix test-grad0 * remove keep arg
This commit is contained in:
parent
6c91da80b8
commit
0ac6666cd2
@ -1410,14 +1410,14 @@ extern "C" {
|
||||
// supports 3D: a->ne[2] == b->ne[1]
|
||||
GGML_API struct ggml_tensor * ggml_get_rows(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
struct ggml_tensor * a, // data
|
||||
struct ggml_tensor * b); // row indices
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_get_rows_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c);
|
||||
struct ggml_tensor * a, // gradients of ggml_get_rows result
|
||||
struct ggml_tensor * b, // row indices
|
||||
struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_diag(
|
||||
struct ggml_context * ctx,
|
||||
@ -1568,9 +1568,9 @@ extern "C" {
|
||||
// a - dy
|
||||
GGML_API struct ggml_tensor * ggml_rope_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c,
|
||||
struct ggml_tensor * a, // gradients of ggml_rope result
|
||||
struct ggml_tensor * b, // positions
|
||||
struct ggml_tensor * c, // freq factors
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx_orig,
|
||||
@ -2037,14 +2037,14 @@ extern "C" {
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
struct ggml_tensor * a, // logits
|
||||
struct ggml_tensor * b); // labels
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c);
|
||||
struct ggml_tensor * a, // logits
|
||||
struct ggml_tensor * b, // labels
|
||||
struct ggml_tensor * c); // gradients of cross_entropy_loss result
|
||||
|
||||
// AdamW optimizer step
|
||||
// Paper: https://arxiv.org/pdf/1711.05101v3.pdf
|
||||
@ -2066,7 +2066,7 @@ extern "C" {
|
||||
GGML_API void ggml_set_loss(struct ggml_tensor * tensor);
|
||||
|
||||
GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
|
||||
GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate, bool keep);
|
||||
GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate);
|
||||
|
||||
GGML_API void ggml_build_opt_adamw(
|
||||
struct ggml_context * ctx,
|
||||
|
760
ggml/src/ggml.c
760
ggml/src/ggml.c
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user