mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-18 20:27:57 +00:00
groundwork: add pkg/concurrency and the associated test file (#2745)
groundwork: add pkg/concurrency and the associated test case Signed-off-by: Dave Lee <dave@gray101.com>
This commit is contained in:
parent
63fc22baab
commit
fc29c04f82
13
pkg/concurrency/concurrency_suite_test.go
Normal file
13
pkg/concurrency/concurrency_suite_test.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
package concurrency
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConcurrency(t *testing.T) {
|
||||||
|
RegisterFailHandler(Fail)
|
||||||
|
RunSpecs(t, "Concurrency test suite")
|
||||||
|
}
|
69
pkg/concurrency/jobresult.go
Normal file
69
pkg/concurrency/jobresult.go
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
package concurrency
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// This is a Read-ONLY structure that contains the result of an arbitrary asynchronous action
|
||||||
|
type JobResult[RequestType any, ResultType any] struct {
|
||||||
|
request *RequestType
|
||||||
|
result *ResultType
|
||||||
|
err error
|
||||||
|
once sync.Once
|
||||||
|
done *chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// This structure is returned in a pair with a JobResult and serves as the structure that has access to be updated.
|
||||||
|
type WritableJobResult[RequestType any, ResultType any] struct {
|
||||||
|
*JobResult[RequestType, ResultType]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait blocks until the result is ready and then returns the result, or the context expires.
|
||||||
|
// Returns *ResultType instead of ResultType since its possible we have only an error and nil for ResultType.
|
||||||
|
// Is this correct and idiomatic?
|
||||||
|
func (jr *JobResult[RequestType, ResultType]) Wait(ctx context.Context) (*ResultType, error) {
|
||||||
|
if jr.done == nil { // If the channel is blanked out, result is ready.
|
||||||
|
return jr.result, jr.err
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-*jr.done: // Wait for the result to be ready
|
||||||
|
jr.done = nil
|
||||||
|
if jr.err != nil {
|
||||||
|
return nil, jr.err
|
||||||
|
}
|
||||||
|
return jr.result, nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accessor function to allow holders of JobResults to access the associated request, without allowing the pointer to be updated.
|
||||||
|
func (jr *JobResult[RequestType, ResultType]) Request() *RequestType {
|
||||||
|
return jr.request
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is the function that actually updates the Result and Error on the JobResult... but it's normally not accessible
|
||||||
|
func (jr *JobResult[RequestType, ResultType]) setResult(result ResultType, err error) {
|
||||||
|
jr.once.Do(func() {
|
||||||
|
jr.result = &result
|
||||||
|
jr.err = err
|
||||||
|
close(*jr.done) // Signal that the result is ready - since this is only ran once, jr.done cannot be set to nil yet.
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only the WritableJobResult can actually call setResult - prevents accidental corruption
|
||||||
|
func (wjr *WritableJobResult[RequestType, ResultType]) SetResult(result ResultType, err error) {
|
||||||
|
wjr.JobResult.setResult(result, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewJobResult binds a request to a matched pair of JobResult and WritableJobResult
|
||||||
|
func NewJobResult[RequestType any, ResultType any](request RequestType) (*JobResult[RequestType, ResultType], *WritableJobResult[RequestType, ResultType]) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
jr := &JobResult[RequestType, ResultType]{
|
||||||
|
once: sync.Once{},
|
||||||
|
request: &request,
|
||||||
|
done: &done,
|
||||||
|
}
|
||||||
|
return jr, &WritableJobResult[RequestType, ResultType]{JobResult: jr}
|
||||||
|
}
|
80
pkg/concurrency/jobresult_test.go
Normal file
80
pkg/concurrency/jobresult_test.go
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
package concurrency_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
. "github.com/mudler/LocalAI/pkg/concurrency"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("pkg/concurrency unit tests", func() {
|
||||||
|
It("can be used to recieve a result across goroutines", func() {
|
||||||
|
jr, wjr := NewJobResult[string, string]("foo")
|
||||||
|
Expect(jr).ToNot(BeNil())
|
||||||
|
Expect(wjr).ToNot(BeNil())
|
||||||
|
|
||||||
|
go func(wjr *WritableJobResult[string, string]) {
|
||||||
|
time.Sleep(time.Second * 5)
|
||||||
|
wjr.SetResult("bar", nil)
|
||||||
|
}(wjr)
|
||||||
|
|
||||||
|
resPtr, err := jr.Wait(context.Background())
|
||||||
|
Expect(err).To(BeNil())
|
||||||
|
Expect(jr.Request).ToNot(BeNil())
|
||||||
|
Expect(*jr.Request()).To(Equal("foo"))
|
||||||
|
Expect(resPtr).ToNot(BeNil())
|
||||||
|
Expect(*resPtr).To(Equal("bar"))
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
|
It("can be used to recieve an error across goroutines", func() {
|
||||||
|
jr, wjr := NewJobResult[string, string]("foo")
|
||||||
|
Expect(jr).ToNot(BeNil())
|
||||||
|
Expect(wjr).ToNot(BeNil())
|
||||||
|
|
||||||
|
go func(wjr *WritableJobResult[string, string]) {
|
||||||
|
time.Sleep(time.Second * 5)
|
||||||
|
wjr.SetResult("", fmt.Errorf("test"))
|
||||||
|
}(wjr)
|
||||||
|
|
||||||
|
_, err := jr.Wait(context.Background())
|
||||||
|
Expect(jr.Request).ToNot(BeNil())
|
||||||
|
Expect(*jr.Request()).To(Equal("foo"))
|
||||||
|
Expect(err).ToNot(BeNil())
|
||||||
|
Expect(err).To(MatchError("test"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("can properly handle timeouts", func() {
|
||||||
|
jr, wjr := NewJobResult[string, string]("foo")
|
||||||
|
Expect(jr).ToNot(BeNil())
|
||||||
|
Expect(wjr).ToNot(BeNil())
|
||||||
|
|
||||||
|
go func(wjr *WritableJobResult[string, string]) {
|
||||||
|
time.Sleep(time.Second * 5)
|
||||||
|
wjr.SetResult("bar", nil)
|
||||||
|
}(wjr)
|
||||||
|
|
||||||
|
timeout1s, c1 := context.WithTimeoutCause(context.Background(), time.Second, fmt.Errorf("timeout"))
|
||||||
|
timeout10s, c2 := context.WithTimeoutCause(context.Background(), time.Second*10, fmt.Errorf("timeout"))
|
||||||
|
|
||||||
|
_, err := jr.Wait(timeout1s)
|
||||||
|
Expect(jr.Request).ToNot(BeNil())
|
||||||
|
Expect(*jr.Request()).To(Equal("foo"))
|
||||||
|
Expect(err).ToNot(BeNil())
|
||||||
|
Expect(err).To(MatchError(context.DeadlineExceeded))
|
||||||
|
|
||||||
|
resPtr, err := jr.Wait(timeout10s)
|
||||||
|
Expect(jr.Request).ToNot(BeNil())
|
||||||
|
Expect(*jr.Request()).To(Equal("foo"))
|
||||||
|
Expect(err).To(BeNil())
|
||||||
|
Expect(resPtr).ToNot(BeNil())
|
||||||
|
Expect(*resPtr).To(Equal("bar"))
|
||||||
|
|
||||||
|
// Is this needed? Cleanup Either Way.
|
||||||
|
c1()
|
||||||
|
c2()
|
||||||
|
})
|
||||||
|
})
|
13
pkg/downloader/downloader_suite_test.go
Normal file
13
pkg/downloader/downloader_suite_test.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
package downloader
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDownloader(t *testing.T) {
|
||||||
|
RegisterFailHandler(Fail)
|
||||||
|
RunSpecs(t, "Downloader test suite")
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user