diff --git a/pkg/concurrency/concurrency_suite_test.go b/pkg/concurrency/concurrency_suite_test.go new file mode 100644 index 00000000..d3175965 --- /dev/null +++ b/pkg/concurrency/concurrency_suite_test.go @@ -0,0 +1,13 @@ +package concurrency + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestConcurrency(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Concurrency test suite") +} diff --git a/pkg/concurrency/jobresult.go b/pkg/concurrency/jobresult.go new file mode 100644 index 00000000..e0d77a38 --- /dev/null +++ b/pkg/concurrency/jobresult.go @@ -0,0 +1,69 @@ +package concurrency + +import ( + "context" + "sync" +) + +// This is a Read-ONLY structure that contains the result of an arbitrary asynchronous action +type JobResult[RequestType any, ResultType any] struct { + request *RequestType + result *ResultType + err error + once sync.Once + done *chan struct{} +} + +// This structure is returned in a pair with a JobResult and serves as the structure that has access to be updated. +type WritableJobResult[RequestType any, ResultType any] struct { + *JobResult[RequestType, ResultType] +} + +// Wait blocks until the result is ready and then returns the result, or the context expires. +// Returns *ResultType instead of ResultType since its possible we have only an error and nil for ResultType. +// Is this correct and idiomatic? +func (jr *JobResult[RequestType, ResultType]) Wait(ctx context.Context) (*ResultType, error) { + if jr.done == nil { // If the channel is blanked out, result is ready. + return jr.result, jr.err + } + select { + case <-*jr.done: // Wait for the result to be ready + jr.done = nil + if jr.err != nil { + return nil, jr.err + } + return jr.result, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// Accessor function to allow holders of JobResults to access the associated request, without allowing the pointer to be updated. +func (jr *JobResult[RequestType, ResultType]) Request() *RequestType { + return jr.request +} + +// This is the function that actually updates the Result and Error on the JobResult... but it's normally not accessible +func (jr *JobResult[RequestType, ResultType]) setResult(result ResultType, err error) { + jr.once.Do(func() { + jr.result = &result + jr.err = err + close(*jr.done) // Signal that the result is ready - since this is only ran once, jr.done cannot be set to nil yet. + }) +} + +// Only the WritableJobResult can actually call setResult - prevents accidental corruption +func (wjr *WritableJobResult[RequestType, ResultType]) SetResult(result ResultType, err error) { + wjr.JobResult.setResult(result, err) +} + +// NewJobResult binds a request to a matched pair of JobResult and WritableJobResult +func NewJobResult[RequestType any, ResultType any](request RequestType) (*JobResult[RequestType, ResultType], *WritableJobResult[RequestType, ResultType]) { + done := make(chan struct{}) + jr := &JobResult[RequestType, ResultType]{ + once: sync.Once{}, + request: &request, + done: &done, + } + return jr, &WritableJobResult[RequestType, ResultType]{JobResult: jr} +} diff --git a/pkg/concurrency/jobresult_test.go b/pkg/concurrency/jobresult_test.go new file mode 100644 index 00000000..a7bd1ff6 --- /dev/null +++ b/pkg/concurrency/jobresult_test.go @@ -0,0 +1,80 @@ +package concurrency_test + +import ( + "context" + "fmt" + "time" + + . "github.com/mudler/LocalAI/pkg/concurrency" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("pkg/concurrency unit tests", func() { + It("can be used to recieve a result across goroutines", func() { + jr, wjr := NewJobResult[string, string]("foo") + Expect(jr).ToNot(BeNil()) + Expect(wjr).ToNot(BeNil()) + + go func(wjr *WritableJobResult[string, string]) { + time.Sleep(time.Second * 5) + wjr.SetResult("bar", nil) + }(wjr) + + resPtr, err := jr.Wait(context.Background()) + Expect(err).To(BeNil()) + Expect(jr.Request).ToNot(BeNil()) + Expect(*jr.Request()).To(Equal("foo")) + Expect(resPtr).ToNot(BeNil()) + Expect(*resPtr).To(Equal("bar")) + + }) + + It("can be used to recieve an error across goroutines", func() { + jr, wjr := NewJobResult[string, string]("foo") + Expect(jr).ToNot(BeNil()) + Expect(wjr).ToNot(BeNil()) + + go func(wjr *WritableJobResult[string, string]) { + time.Sleep(time.Second * 5) + wjr.SetResult("", fmt.Errorf("test")) + }(wjr) + + _, err := jr.Wait(context.Background()) + Expect(jr.Request).ToNot(BeNil()) + Expect(*jr.Request()).To(Equal("foo")) + Expect(err).ToNot(BeNil()) + Expect(err).To(MatchError("test")) + }) + + It("can properly handle timeouts", func() { + jr, wjr := NewJobResult[string, string]("foo") + Expect(jr).ToNot(BeNil()) + Expect(wjr).ToNot(BeNil()) + + go func(wjr *WritableJobResult[string, string]) { + time.Sleep(time.Second * 5) + wjr.SetResult("bar", nil) + }(wjr) + + timeout1s, c1 := context.WithTimeoutCause(context.Background(), time.Second, fmt.Errorf("timeout")) + timeout10s, c2 := context.WithTimeoutCause(context.Background(), time.Second*10, fmt.Errorf("timeout")) + + _, err := jr.Wait(timeout1s) + Expect(jr.Request).ToNot(BeNil()) + Expect(*jr.Request()).To(Equal("foo")) + Expect(err).ToNot(BeNil()) + Expect(err).To(MatchError(context.DeadlineExceeded)) + + resPtr, err := jr.Wait(timeout10s) + Expect(jr.Request).ToNot(BeNil()) + Expect(*jr.Request()).To(Equal("foo")) + Expect(err).To(BeNil()) + Expect(resPtr).ToNot(BeNil()) + Expect(*resPtr).To(Equal("bar")) + + // Is this needed? Cleanup Either Way. + c1() + c2() + }) +}) diff --git a/pkg/downloader/downloader_suite_test.go b/pkg/downloader/downloader_suite_test.go new file mode 100644 index 00000000..752def7a --- /dev/null +++ b/pkg/downloader/downloader_suite_test.go @@ -0,0 +1,13 @@ +package downloader + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestDownloader(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Downloader test suite") +}