Files
inp2p/internal/server/sdwan.go

126 lines
2.5 KiB
Go

package server
import (
"encoding/json"
"errors"
"os"
"sort"
"sync"
"time"
"github.com/openp2p-cn/inp2p/pkg/protocol"
)
// sdwanStore persists SD-WAN configuration as JSON in the file at path and
// keeps the decoded configuration in memory.
//
// The store holds either a single config (cfg, persisted by save) or
// per-tenant configs keyed by tenant ID (multi, persisted by saveTenant).
// NOTE(review): save and saveTenant each rewrite the whole file with a
// different top-level JSON shape, so calling one discards on disk what the
// other wrote.
type sdwanStore struct {
// mu is taken by every method: the write lock in load/save/saveTenant and
// the read lock in get/getTenant.
mu sync.RWMutex
// path is the JSON file read by load and rewritten by save/saveTenant.
path string
// cfg is the single, non-tenant configuration.
cfg protocol.SDWANConfig
// multi maps a tenant ID to that tenant's configuration.
multi map[int64]protocol.SDWANConfig
}
// newSDWANStore returns a store backed by the JSON file at path, pre-loaded
// with whatever that file currently holds.
//
// Load errors are deliberately discarded: a missing, unreadable, or
// malformed file yields an empty store instead of failing construction.
// NOTE(review): a later save/saveTenant then rewrites the file, so a
// malformed file is silently replaced — consider surfacing this error.
func newSDWANStore(path string) *sdwanStore {
	store := &sdwanStore{
		path:  path,
		multi: make(map[int64]protocol.SDWANConfig),
	}
	_ = store.load()
	return store
}
// load populates the store from the JSON file at s.path.
//
// A missing file is not an error and leaves the store untouched. The file
// may hold either a non-empty map of tenant ID to config (tried first, and
// stored in s.multi) or a single config object (stored in s.cfg). Every
// config read is passed through normalizeSDWAN. Any other read or decode
// error is returned.
func (s *sdwanStore) load() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	raw, err := os.ReadFile(s.path)
	switch {
	case errors.Is(err, os.ErrNotExist):
		return nil
	case err != nil:
		return err
	}

	// Object keys that are not integers make this decode fail, which is how
	// a single-config file falls through to the branch below.
	var tenants map[int64]protocol.SDWANConfig
	if json.Unmarshal(raw, &tenants) == nil && len(tenants) != 0 {
		for id, c := range tenants {
			tenants[id] = normalizeSDWAN(c)
		}
		s.multi = tenants
		return nil
	}

	var single protocol.SDWANConfig
	if err := json.Unmarshal(raw, &single); err != nil {
		return err
	}
	s.cfg = normalizeSDWAN(single)
	return nil
}
// save normalizes cfg, stamps UpdatedAt with the current Unix time, and
// writes it to s.path as indented JSON (file mode 0644). Only after the
// write succeeds does it become the store's single config.
//
// NOTE(review): this replaces the whole file, discarding on disk any
// per-tenant configs previously written by saveTenant.
func (s *sdwanStore) save(cfg protocol.SDWANConfig) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	normalized := normalizeSDWAN(cfg)
	normalized.UpdatedAt = time.Now().Unix()

	payload, err := json.MarshalIndent(normalized, "", " ")
	if err == nil {
		err = os.WriteFile(s.path, payload, 0644)
	}
	if err != nil {
		return err
	}
	s.cfg = normalized
	return nil
}
// saveTenant normalizes cfg, stamps UpdatedAt with the current Unix time,
// and persists it as tenantID's configuration. The full tenant map is
// written to s.path as indented JSON (file mode 0644).
//
// The in-memory map is replaced only after the file has been written
// successfully. On a marshal or write error, getTenant keeps returning the
// previous state, matching how save commits s.cfg.
//
// NOTE(review): this replaces the whole file, discarding on disk any single
// config previously written by save.
func (s *sdwanStore) saveTenant(tenantID int64, cfg protocol.SDWANConfig) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	cfg = normalizeSDWAN(cfg)
	cfg.UpdatedAt = time.Now().Unix()

	// Build the next state on a copy so a failed write leaves s.multi
	// untouched. Ranging over a nil s.multi is safe.
	next := make(map[int64]protocol.SDWANConfig, len(s.multi)+1)
	for id, c := range s.multi {
		next[id] = c
	}
	next[tenantID] = cfg

	b, err := json.MarshalIndent(next, "", " ")
	if err != nil {
		return err
	}
	if err := os.WriteFile(s.path, b, 0644); err != nil {
		return err
	}
	s.multi = next
	return nil
}
// get returns the single (non-tenant) configuration most recently loaded or
// saved, or the zero value if there is none. The copy is taken under the
// read lock. The returned Nodes slice still shares its backing array with
// the store, so callers must not modify its elements.
func (s *sdwanStore) get() protocol.SDWANConfig {
	s.mu.RLock()
	snapshot := s.cfg
	s.mu.RUnlock()
	return snapshot
}
// getTenant returns the configuration stored for tenantID, or the zero
// SDWANConfig when that tenant has none (including when no tenant map
// exists). The returned Nodes slice shares its backing array with the
// store, so callers must not modify its elements.
func (s *sdwanStore) getTenant(tenantID int64) protocol.SDWANConfig {
	s.mu.RLock()
	defer s.mu.RUnlock()

	cfg, found := s.multi[tenantID]
	if !found {
		return protocol.SDWANConfig{}
	}
	return cfg
}
// normalizeSDWAN returns a canonical copy of c:
//   - an empty Mode defaults to "hub";
//   - Enabled is always set to true;
//   - Nodes are de-duplicated by Node name (the last entry for a name wins,
//     entries with an empty Node are dropped) and sorted by Node name so the
//     persisted JSON is stable.
//
// The returned Nodes slice is freshly allocated, so the caller's c.Nodes
// backing array is never modified. Nodes are rebuilt from their Node and IP
// fields only. NOTE(review): any other SDWANNode fields are dropped.
func normalizeSDWAN(c protocol.SDWANConfig) protocol.SDWANConfig {
	if c.Mode == "" {
		c.Mode = "hub"
	}
	// NOTE(review): this makes it impossible to persist a disabled config;
	// confirm that forcing Enabled is intended.
	if !c.Enabled {
		c.Enabled = true
	}
	if c.Nodes == nil {
		// Leave a nil slice nil (as before) so its JSON encoding is unchanged.
		return c
	}

	ipByNode := make(map[string]string, len(c.Nodes))
	for _, n := range c.Nodes {
		if n.Node == "" {
			continue
		}
		ipByNode[n.Node] = n.IP
	}

	// Allocate rather than reuse c.Nodes[:0]: c is a shallow copy, so
	// appending into c.Nodes[:0] would overwrite the caller's elements.
	nodes := make([]protocol.SDWANNode, 0, len(ipByNode))
	for node, ip := range ipByNode {
		nodes = append(nodes, protocol.SDWANNode{Node: node, IP: ip})
	}
	sort.Slice(nodes, func(i, j int) bool { return nodes[i].Node < nodes[j].Node })
	c.Nodes = nodes
	return c
}