package config import ( "fmt" "log" "os" "time" "github.com/spf13/viper" ) type Config struct { App AppConfig `mapstructure:"app"` Server ServerConfig `mapstructure:"server"` Security SecurityConfig `mapstructure:"security"` SMS SMSConfig `mapstructure:"sms"` Database DatabaseConfig `mapstructure:"database"` Timezone string `mapstructure:"timezone"` APITokens []APIToken `mapstructure:"api_tokens"` } type AppConfig struct { Name string `mapstructure:"name"` Version string `mapstructure:"version"` } type ServerConfig struct { Host string `mapstructure:"host"` Port int `mapstructure:"port"` Debug bool `mapstructure:"debug"` } type SecurityConfig struct { Enabled bool `mapstructure:"enabled"` Username string `mapstructure:"username"` Password string `mapstructure:"password"` PasswordHash string `mapstructure:"password_hash"` // bcrypt 哈希值(推荐使用) SessionLifetime int `mapstructure:"session_lifetime"` SecretKey string `mapstructure:"secret_key"` SignVerify bool `mapstructure:"sign_verify"` SignMaxAge int64 `mapstructure:"sign_max_age"` } // Validate 验证配置的有效性 func (c *Config) Validate() error { // 验证数据库路径 if c.Database.Path == "" { return fmt.Errorf("数据库路径不能为空") } // 验证安全密钥 if c.Security.SecretKey == "" { return fmt.Errorf("安全密钥不能为空,请在配置文件中设置 secret_key") } // 验证密钥长度(至少16字节) key := []byte(c.Security.SecretKey) if len(key) < 16 { return fmt.Errorf("安全密钥长度不足,建议至少16字节(当前: %d 字节)", len(key)) } // 设置默认值 if c.Security.SessionLifetime == 0 { c.Security.SessionLifetime = DefaultSessionLifetime log.Printf("使用默认会话有效期: %d 秒", DefaultSessionLifetime) } if c.Security.SignMaxAge == 0 { c.Security.SignMaxAge = DefaultSignMaxAge log.Printf("使用默认签名有效期: %d 毫秒", DefaultSignMaxAge) } // 如果启用了登录验证,验证用户名和密码 if c.Security.Enabled { if c.Security.Username == "" { return fmt.Errorf("启用登录验证时,用户名不能为空") } if c.Security.Password == "" && c.Security.PasswordHash == "" { return fmt.Errorf("启用登录验证时,必须设置 password 或 password_hash") } } // 验证服务器端口 if c.Server.Port < 1 || c.Server.Port > 65535 { return fmt.Errorf("服务器端口无效: %d", c.Server.Port) } // 验证时区 if c.Timezone == "" { c.Timezone = "Asia/Shanghai" log.Printf("使用默认时区: %s", c.Timezone) } // 检查时区是否有效 if _, err := time.LoadLocation(c.Timezone); err != nil { return fmt.Errorf("无效的时区配置: %s", c.Timezone) } // 日志提示 if c.Security.Password != "" && c.Security.PasswordHash != "" { log.Printf("警告: 同时设置了 password 和 password_hash,将优先使用 password_hash") } if c.Security.Password != "" { log.Printf("警告: 使用明文密码不安全,建议使用 password_hash") } return nil } type SMSConfig struct { MaxMessages int `mapstructure:"max_messages"` AutoCleanup bool `mapstructure:"auto_cleanup"` CleanupDays int `mapstructure:"cleanup_days"` } type DatabaseConfig struct { Path string `mapstructure:"path"` } type APIToken struct { Name string `mapstructure:"name"` Token string `mapstructure:"token"` Secret string `mapstructure:"secret"` Enabled bool `mapstructure:"enabled"` } var cfg *Config func Load(configPath string) (*Config, error) { viper.SetConfigFile(configPath) viper.SetConfigType("yaml") // 允许环境变量覆盖 viper.AutomaticEnv() if err := viper.ReadInConfig(); err != nil { return nil, fmt.Errorf("读取配置文件失败: %w", err) } cfg = &Config{} if err := viper.Unmarshal(cfg); err != nil { return nil, fmt.Errorf("解析配置文件失败: %w", err) } // 验证配置 if err := cfg.Validate(); err != nil { return nil, fmt.Errorf("配置验证失败: %w", err) } return cfg, nil } func Get() *Config { return cfg } // GetSessionLifetimeDuration 返回会话 lifetime 为 duration func (c *Config) GetSessionLifetimeDuration() time.Duration { return time.Duration(c.Security.SessionLifetime) * time.Second } // GetSignMaxAgeDuration 返回签名最大有效期 func (c *Config) GetSignMaxAgeDuration() time.Duration { return time.Duration(c.Security.SignMaxAge) * time.Millisecond } // GetServerAddress 返回服务器地址 func (c *Config) GetServerAddress() string { return fmt.Sprintf("%s:%d", c.Server.Host, c.Server.Port) } // GetTokenByName 根据名称获取 Token 配置 func (c *Config) GetTokenByName(name string) *APIToken { for i := range c.APITokens { if c.APITokens[i].Name == name && c.APITokens[i].Enabled { return &c.APITokens[i] } } return nil } // GetTokenByValue 根据 token 值获取配置 func (c *Config) GetTokenByValue(token string) *APIToken { for i := range c.APITokens { if c.APITokens[i].Token == token && c.APITokens[i].Enabled { return &c.APITokens[i] } } return nil } // Save 保存配置到文件 func (c *Config) Save(path string) error { viper.Set("app", c.App) viper.Set("server", c.Server) viper.Set("security", c.Security) viper.Set("sms", c.SMS) viper.Set("database", c.Database) viper.Set("timezone", c.Timezone) viper.Set("api_tokens", c.APITokens) return viper.WriteConfigAs(path) } // LoadDefault 加载默认配置文件 func LoadDefault() (*Config, error) { configPath := "config.yaml" if _, err := os.Stat(configPath); os.IsNotExist(err) { // 尝试查找上层目录 if _, err := os.Stat("../config.yaml"); err == nil { configPath = "../config.yaml" } } return Load(configPath) }