mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-24 06:46:39 +00:00
feat(federation): add load balanced option (#2915)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
031627584b
commit
252961751c
@ -10,11 +10,12 @@ import (
|
|||||||
type FederatedCLI struct {
|
type FederatedCLI struct {
|
||||||
Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"`
|
Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"`
|
||||||
Peer2PeerToken string `env:"LOCALAI_P2P_TOKEN,P2P_TOKEN,TOKEN" name:"p2ptoken" help:"Token for P2P mode (optional)" group:"p2p"`
|
Peer2PeerToken string `env:"LOCALAI_P2P_TOKEN,P2P_TOKEN,TOKEN" name:"p2ptoken" help:"Token for P2P mode (optional)" group:"p2p"`
|
||||||
|
LoadBalanced bool `env:"LOCALAI_LOAD_BALANCED,LOAD_BALANCED" default:"false" help:"Enable load balancing" group:"p2p"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *FederatedCLI) Run(ctx *cliContext.Context) error {
|
func (f *FederatedCLI) Run(ctx *cliContext.Context) error {
|
||||||
|
|
||||||
fs := p2p.NewFederatedServer(f.Address, p2p.FederatedID, f.Peer2PeerToken)
|
fs := p2p.NewFederatedServer(f.Address, p2p.FederatedID, f.Peer2PeerToken, f.LoadBalanced)
|
||||||
|
|
||||||
return fs.Start(context.Background())
|
return fs.Start(context.Background())
|
||||||
}
|
}
|
||||||
|
@ -4,12 +4,44 @@ const FederatedID = "federated"
|
|||||||
|
|
||||||
type FederatedServer struct {
|
type FederatedServer struct {
|
||||||
listenAddr, service, p2ptoken string
|
listenAddr, service, p2ptoken string
|
||||||
|
requestTable map[string]int
|
||||||
|
loadBalanced bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFederatedServer(listenAddr, service, p2pToken string) *FederatedServer {
|
func NewFederatedServer(listenAddr, service, p2pToken string, loadBalanced bool) *FederatedServer {
|
||||||
return &FederatedServer{
|
return &FederatedServer{
|
||||||
listenAddr: listenAddr,
|
listenAddr: listenAddr,
|
||||||
service: service,
|
service: service,
|
||||||
p2ptoken: p2pToken,
|
p2ptoken: p2pToken,
|
||||||
|
requestTable: map[string]int{},
|
||||||
|
loadBalanced: loadBalanced,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *FederatedServer) SelectLeastUsedServer() string {
|
||||||
|
// cycle over requestTable and find the entry with the lower number
|
||||||
|
// if there are multiple entries with the same number, select one randomly
|
||||||
|
// if there are no entries, return an empty string
|
||||||
|
var min int
|
||||||
|
var minKey string
|
||||||
|
for k, v := range fs.requestTable {
|
||||||
|
if min == 0 || v < min {
|
||||||
|
min = v
|
||||||
|
minKey = k
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return minKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *FederatedServer) RecordRequest(nodeID string) {
|
||||||
|
// increment the counter for the nodeID in the requestTable
|
||||||
|
fs.requestTable[nodeID]++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *FederatedServer) EnsureRecordExist(nodeID string) {
|
||||||
|
// if the nodeID is not in the requestTable, add it with a counter of 0
|
||||||
|
_, ok := fs.requestTable[nodeID]
|
||||||
|
if !ok {
|
||||||
|
fs.requestTable[nodeID] = 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -100,10 +100,23 @@ func (fs *FederatedServer) proxy(ctx context.Context, node *node.Node) error {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// open a TCP stream to one of the tunnels
|
tunnelAddr := ""
|
||||||
// chosen randomly
|
|
||||||
// TODO: optimize this and track usage
|
if fs.loadBalanced {
|
||||||
tunnelAddr := tunnelAddresses[rand.IntN(len(tunnelAddresses))]
|
for _, t := range tunnelAddresses {
|
||||||
|
fs.EnsureRecordExist(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
tunnelAddr = fs.SelectLeastUsedServer()
|
||||||
|
log.Debug().Msgf("Selected tunnel %s", tunnelAddr)
|
||||||
|
if tunnelAddr == "" {
|
||||||
|
tunnelAddr = tunnelAddresses[rand.IntN(len(tunnelAddresses))]
|
||||||
|
}
|
||||||
|
|
||||||
|
fs.RecordRequest(tunnelAddr)
|
||||||
|
} else {
|
||||||
|
tunnelAddr = tunnelAddresses[rand.IntN(len(tunnelAddresses))]
|
||||||
|
}
|
||||||
|
|
||||||
tunnelConn, err := net.Dial("tcp", tunnelAddr)
|
tunnelConn, err := net.Dial("tcp", tunnelAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
Loading…
Reference in New Issue
Block a user