mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-18 20:27:57 +00:00
groundwork: add pkg/concurrency and the associated test file (#2745)
groundwork: add pkg/concurrency and the associated test case Signed-off-by: Dave Lee <dave@gray101.com>
This commit is contained in:
parent
63fc22baab
commit
fc29c04f82
13
pkg/concurrency/concurrency_suite_test.go
Normal file
13
pkg/concurrency/concurrency_suite_test.go
Normal file
@ -0,0 +1,13 @@
|
||||
package concurrency
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestConcurrency(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Concurrency test suite")
|
||||
}
|
69
pkg/concurrency/jobresult.go
Normal file
69
pkg/concurrency/jobresult.go
Normal file
@ -0,0 +1,69 @@
|
||||
package concurrency
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// This is a Read-ONLY structure that contains the result of an arbitrary asynchronous action
|
||||
type JobResult[RequestType any, ResultType any] struct {
|
||||
request *RequestType
|
||||
result *ResultType
|
||||
err error
|
||||
once sync.Once
|
||||
done *chan struct{}
|
||||
}
|
||||
|
||||
// This structure is returned in a pair with a JobResult and serves as the structure that has access to be updated.
|
||||
type WritableJobResult[RequestType any, ResultType any] struct {
|
||||
*JobResult[RequestType, ResultType]
|
||||
}
|
||||
|
||||
// Wait blocks until the result is ready and then returns the result, or the context expires.
|
||||
// Returns *ResultType instead of ResultType since its possible we have only an error and nil for ResultType.
|
||||
// Is this correct and idiomatic?
|
||||
func (jr *JobResult[RequestType, ResultType]) Wait(ctx context.Context) (*ResultType, error) {
|
||||
if jr.done == nil { // If the channel is blanked out, result is ready.
|
||||
return jr.result, jr.err
|
||||
}
|
||||
select {
|
||||
case <-*jr.done: // Wait for the result to be ready
|
||||
jr.done = nil
|
||||
if jr.err != nil {
|
||||
return nil, jr.err
|
||||
}
|
||||
return jr.result, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Accessor function to allow holders of JobResults to access the associated request, without allowing the pointer to be updated.
|
||||
func (jr *JobResult[RequestType, ResultType]) Request() *RequestType {
|
||||
return jr.request
|
||||
}
|
||||
|
||||
// This is the function that actually updates the Result and Error on the JobResult... but it's normally not accessible
|
||||
func (jr *JobResult[RequestType, ResultType]) setResult(result ResultType, err error) {
|
||||
jr.once.Do(func() {
|
||||
jr.result = &result
|
||||
jr.err = err
|
||||
close(*jr.done) // Signal that the result is ready - since this is only ran once, jr.done cannot be set to nil yet.
|
||||
})
|
||||
}
|
||||
|
||||
// Only the WritableJobResult can actually call setResult - prevents accidental corruption
|
||||
func (wjr *WritableJobResult[RequestType, ResultType]) SetResult(result ResultType, err error) {
|
||||
wjr.JobResult.setResult(result, err)
|
||||
}
|
||||
|
||||
// NewJobResult binds a request to a matched pair of JobResult and WritableJobResult
|
||||
func NewJobResult[RequestType any, ResultType any](request RequestType) (*JobResult[RequestType, ResultType], *WritableJobResult[RequestType, ResultType]) {
|
||||
done := make(chan struct{})
|
||||
jr := &JobResult[RequestType, ResultType]{
|
||||
once: sync.Once{},
|
||||
request: &request,
|
||||
done: &done,
|
||||
}
|
||||
return jr, &WritableJobResult[RequestType, ResultType]{JobResult: jr}
|
||||
}
|
80
pkg/concurrency/jobresult_test.go
Normal file
80
pkg/concurrency/jobresult_test.go
Normal file
@ -0,0 +1,80 @@
|
||||
package concurrency_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
. "github.com/mudler/LocalAI/pkg/concurrency"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("pkg/concurrency unit tests", func() {
|
||||
It("can be used to recieve a result across goroutines", func() {
|
||||
jr, wjr := NewJobResult[string, string]("foo")
|
||||
Expect(jr).ToNot(BeNil())
|
||||
Expect(wjr).ToNot(BeNil())
|
||||
|
||||
go func(wjr *WritableJobResult[string, string]) {
|
||||
time.Sleep(time.Second * 5)
|
||||
wjr.SetResult("bar", nil)
|
||||
}(wjr)
|
||||
|
||||
resPtr, err := jr.Wait(context.Background())
|
||||
Expect(err).To(BeNil())
|
||||
Expect(jr.Request).ToNot(BeNil())
|
||||
Expect(*jr.Request()).To(Equal("foo"))
|
||||
Expect(resPtr).ToNot(BeNil())
|
||||
Expect(*resPtr).To(Equal("bar"))
|
||||
|
||||
})
|
||||
|
||||
It("can be used to recieve an error across goroutines", func() {
|
||||
jr, wjr := NewJobResult[string, string]("foo")
|
||||
Expect(jr).ToNot(BeNil())
|
||||
Expect(wjr).ToNot(BeNil())
|
||||
|
||||
go func(wjr *WritableJobResult[string, string]) {
|
||||
time.Sleep(time.Second * 5)
|
||||
wjr.SetResult("", fmt.Errorf("test"))
|
||||
}(wjr)
|
||||
|
||||
_, err := jr.Wait(context.Background())
|
||||
Expect(jr.Request).ToNot(BeNil())
|
||||
Expect(*jr.Request()).To(Equal("foo"))
|
||||
Expect(err).ToNot(BeNil())
|
||||
Expect(err).To(MatchError("test"))
|
||||
})
|
||||
|
||||
It("can properly handle timeouts", func() {
|
||||
jr, wjr := NewJobResult[string, string]("foo")
|
||||
Expect(jr).ToNot(BeNil())
|
||||
Expect(wjr).ToNot(BeNil())
|
||||
|
||||
go func(wjr *WritableJobResult[string, string]) {
|
||||
time.Sleep(time.Second * 5)
|
||||
wjr.SetResult("bar", nil)
|
||||
}(wjr)
|
||||
|
||||
timeout1s, c1 := context.WithTimeoutCause(context.Background(), time.Second, fmt.Errorf("timeout"))
|
||||
timeout10s, c2 := context.WithTimeoutCause(context.Background(), time.Second*10, fmt.Errorf("timeout"))
|
||||
|
||||
_, err := jr.Wait(timeout1s)
|
||||
Expect(jr.Request).ToNot(BeNil())
|
||||
Expect(*jr.Request()).To(Equal("foo"))
|
||||
Expect(err).ToNot(BeNil())
|
||||
Expect(err).To(MatchError(context.DeadlineExceeded))
|
||||
|
||||
resPtr, err := jr.Wait(timeout10s)
|
||||
Expect(jr.Request).ToNot(BeNil())
|
||||
Expect(*jr.Request()).To(Equal("foo"))
|
||||
Expect(err).To(BeNil())
|
||||
Expect(resPtr).ToNot(BeNil())
|
||||
Expect(*resPtr).To(Equal("bar"))
|
||||
|
||||
// Is this needed? Cleanup Either Way.
|
||||
c1()
|
||||
c2()
|
||||
})
|
||||
})
|
13
pkg/downloader/downloader_suite_test.go
Normal file
13
pkg/downloader/downloader_suite_test.go
Normal file
@ -0,0 +1,13 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestDownloader(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Downloader test suite")
|
||||
}
|
Loading…
Reference in New Issue
Block a user