diff --git a/pkg/model/loader_test.go b/pkg/model/loader_test.go new file mode 100644 index 00000000..4c3c1a88 --- /dev/null +++ b/pkg/model/loader_test.go @@ -0,0 +1,105 @@ +package model_test + +import ( + "github.com/go-skynet/LocalAI/pkg/model" + . "github.com/go-skynet/LocalAI/pkg/model" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +const chatML = `<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}} +{{- if .FunctionCall }} + +{{- else if eq .RoleName "tool" }} + +{{- end }} +{{- if .Content}} +{{.Content }} +{{- end }} +{{- if .FunctionCall}} +{{toJson .FunctionCall}} +{{- end }} +{{- if .FunctionCall }} + +{{- else if eq .RoleName "tool" }} + +{{- end }} +<|im_end|>` + +var testMatch map[string]map[string]interface{} = map[string]map[string]interface{}{ + "user": { + "template": chatML, + "expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...\n<|im_end|>", + "data": model.ChatMessageTemplateData{ + SystemPrompt: "", + Role: "user", + RoleName: "user", + Content: "A long time ago in a galaxy far, far away...", + FunctionCall: nil, + FunctionName: "", + LastMessage: false, + Function: false, + MessageIndex: 0, + }, + }, + "assistant": { + "template": chatML, + "expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...\n<|im_end|>", + "data": model.ChatMessageTemplateData{ + SystemPrompt: "", + Role: "assistant", + RoleName: "assistant", + Content: "A long time ago in a galaxy far, far away...", + FunctionCall: nil, + FunctionName: "", + LastMessage: false, + Function: false, + MessageIndex: 0, + }, + }, + "function_call": { + "template": chatML, + "expected": "<|im_start|>assistant\n\n{\"function\":\"test\"}\n\n<|im_end|>", + "data": model.ChatMessageTemplateData{ + SystemPrompt: "", + Role: "assistant", + RoleName: "assistant", + Content: "", + FunctionCall: map[string]string{"function": "test"}, + FunctionName: "", + LastMessage: false, + Function: false, + MessageIndex: 0, + }, + }, + "function_response": { + "template": chatML, + "expected": "<|im_start|>tool\n\nResponse from tool\n\n<|im_end|>", + "data": model.ChatMessageTemplateData{ + SystemPrompt: "", + Role: "tool", + RoleName: "tool", + Content: "Response from tool", + FunctionCall: nil, + FunctionName: "", + LastMessage: false, + Function: false, + MessageIndex: 0, + }, + }, +} + +var _ = Describe("Templates", func() { + Context("chat message", func() { + modelLoader := NewModelLoader("") + for key := range testMatch { + foo := testMatch[key] + It("renders correctly "+key, func() { + templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(model.ChatMessageTemplateData)) + Expect(err).ToNot(HaveOccurred()) + Expect(templated).To(Equal(foo["expected"]), templated) + }) + } + }) +}) diff --git a/pkg/model/model_suite_test.go b/pkg/model/model_suite_test.go new file mode 100644 index 00000000..6fa9c004 --- /dev/null +++ b/pkg/model/model_suite_test.go @@ -0,0 +1,13 @@ +package model_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestModel(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "LocalAI model test") +}