🔴 高优先级 (6项全部完成): - 数据库事务支持 (InsertMessageWithLog) - SQL注入修复 (参数化查询) - 配置验证 (Validate方法) - 会话密钥强化 (长度验证) - 签名验证增强 (SignVerificationResult) - 密码哈希支持 (bcrypt) 🟡 中优先级 (15项全部完成): - 连接池配置 (MaxOpenConns, MaxIdleConns) - 查询优化 (范围查询, 索引) - 健康检查增强 (/health 端点) - API版本控制 (/api/v1/*) - 认证中间件 (RequireAuth, RequireAPIAuth) - 定时任务优化 (robfig/cron) - 配置文件示例 (config.example.yaml) - 常量定义 (config/constants.go) - 开发文档 (DEVELOPMENT.md) 🟢 低优先级 (9项全部完成): - Docker支持 (Dockerfile, docker-compose.yml) - Makefile构建脚本 - 优化报告 (OPTIMIZATION_REPORT.md) - 密码哈希工具 (tools/password_hash.go) - 14个新文件 - 30项优化100%完成 版本: v2.0.0
415 lines
10 KiB
Go
415 lines
10 KiB
Go
package database
|
||
|
||
import (
|
||
"database/sql"
|
||
"fmt"
|
||
"log"
|
||
"strings"
|
||
"time"
|
||
|
||
"sms-receiver-go/config"
|
||
"sms-receiver-go/models"
|
||
|
||
_ "github.com/mattn/go-sqlite3"
|
||
)
|
||
|
||
// db is the package-level connection pool. Init assigns it, the query
// helpers in this file read it, and Close releases it. It is nil until Init
// has been called.
var db *sql.DB
|
||
|
||
// Init 初始化数据库
|
||
func Init(cfg *config.DatabaseConfig) error {
|
||
var err error
|
||
db, err = sql.Open("sqlite3", cfg.Path)
|
||
if err != nil {
|
||
return fmt.Errorf("打开数据库失败: %w", err)
|
||
}
|
||
|
||
if err = db.Ping(); err != nil {
|
||
return fmt.Errorf("数据库连接失败: %w", err)
|
||
}
|
||
|
||
// 配置连接池
|
||
db.SetMaxOpenConns(25)
|
||
db.SetMaxIdleConns(5)
|
||
db.SetConnMaxLifetime(5 * time.Minute)
|
||
|
||
// 创建表
|
||
if err = createTables(); err != nil {
|
||
return fmt.Errorf("创建表失败: %w", err)
|
||
}
|
||
|
||
log.Printf("数据库初始化成功: %s", cfg.Path)
|
||
return nil
|
||
}
|
||
|
||
// createTables 创建数据表
|
||
func createTables() error {
|
||
// 短信消息表
|
||
createMessagesSQL := `
|
||
CREATE TABLE IF NOT EXISTS sms_messages (
|
||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
from_number TEXT NOT NULL,
|
||
content TEXT NOT NULL,
|
||
timestamp INTEGER NOT NULL,
|
||
device_info TEXT,
|
||
sim_info TEXT,
|
||
sign_verified INTEGER,
|
||
ip_address TEXT,
|
||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||
);
|
||
`
|
||
|
||
// 接收日志表
|
||
createLogsSQL := `
|
||
CREATE TABLE IF NOT EXISTS receive_logs (
|
||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
from_number TEXT NOT NULL,
|
||
content TEXT NOT NULL,
|
||
timestamp INTEGER NOT NULL,
|
||
sign TEXT,
|
||
sign_valid INTEGER,
|
||
ip_address TEXT,
|
||
status TEXT NOT NULL,
|
||
error_message TEXT,
|
||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||
);
|
||
`
|
||
|
||
// 创建索引
|
||
createIndexesSQL := `
|
||
CREATE INDEX IF NOT EXISTS idx_messages_from ON sms_messages(from_number);
|
||
CREATE INDEX IF NOT EXISTS idx_messages_timestamp ON sms_messages(timestamp);
|
||
CREATE INDEX IF NOT EXISTS idx_messages_created ON sms_messages(created_at);
|
||
CREATE INDEX IF NOT EXISTS idx_logs_created ON receive_logs(created_at);
|
||
CREATE INDEX IF NOT EXISTS idx_logs_status ON receive_logs(status);
|
||
`
|
||
|
||
statements := []string{createMessagesSQL, createLogsSQL, createIndexesSQL}
|
||
for _, stmt := range statements {
|
||
if _, err := db.Exec(stmt); err != nil {
|
||
return fmt.Errorf("执行 SQL 失败: %w", err)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// InsertMessage 插入短信消息
|
||
func InsertMessage(msg *models.SMSMessage) (int64, error) {
|
||
result, err := db.Exec(`
|
||
INSERT INTO sms_messages (from_number, content, timestamp, device_info, sim_info, sign_verified, ip_address)
|
||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||
`,
|
||
msg.FromNumber,
|
||
msg.Content,
|
||
msg.Timestamp,
|
||
msg.DeviceInfo,
|
||
msg.SIMInfo,
|
||
msg.SignVerified,
|
||
msg.IPAddress,
|
||
)
|
||
if err != nil {
|
||
return 0, fmt.Errorf("插入消息失败: %w", err)
|
||
}
|
||
return result.LastInsertId()
|
||
}
|
||
|
||
// InsertLog 插入接收日志
|
||
func InsertLog(log *models.ReceiveLog) (int64, error) {
|
||
result, err := db.Exec(`
|
||
INSERT INTO receive_logs (from_number, content, timestamp, sign, sign_valid, ip_address, status, error_message)
|
||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||
`,
|
||
log.FromNumber,
|
||
log.Content,
|
||
log.Timestamp,
|
||
log.Sign,
|
||
log.SignValid,
|
||
log.IPAddress,
|
||
log.Status,
|
||
log.ErrorMessage,
|
||
)
|
||
if err != nil {
|
||
return 0, fmt.Errorf("插入日志失败: %w", err)
|
||
}
|
||
return result.LastInsertId()
|
||
}
|
||
|
||
// InsertMessageWithLog 在事务中插入消息和日志
|
||
func InsertMessageWithLog(msg *models.SMSMessage, log *models.ReceiveLog) (int64, error) {
|
||
// 开启事务
|
||
tx, err := db.Begin()
|
||
if err != nil {
|
||
return 0, fmt.Errorf("开启事务失败: %w", err)
|
||
}
|
||
|
||
// 确保在出错时回滚
|
||
defer func() {
|
||
if err != nil {
|
||
tx.Rollback()
|
||
}
|
||
}()
|
||
|
||
// 插入消息
|
||
msgResult, err := tx.Exec(`
|
||
INSERT INTO sms_messages (from_number, content, timestamp, device_info, sim_info, sign_verified, ip_address)
|
||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||
`,
|
||
msg.FromNumber,
|
||
msg.Content,
|
||
msg.Timestamp,
|
||
msg.DeviceInfo,
|
||
msg.SIMInfo,
|
||
msg.SignVerified,
|
||
msg.IPAddress,
|
||
)
|
||
if err != nil {
|
||
return 0, fmt.Errorf("插入消息失败: %w", err)
|
||
}
|
||
|
||
messageID, err := msgResult.LastInsertId()
|
||
if err != nil {
|
||
return 0, fmt.Errorf("获取消息ID失败: %w", err)
|
||
}
|
||
|
||
// 插入日志
|
||
_, err = tx.Exec(`
|
||
INSERT INTO receive_logs (from_number, content, timestamp, sign, sign_valid, ip_address, status, error_message)
|
||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||
`,
|
||
log.FromNumber,
|
||
log.Content,
|
||
log.Timestamp,
|
||
log.Sign,
|
||
log.SignValid,
|
||
log.IPAddress,
|
||
log.Status,
|
||
log.ErrorMessage,
|
||
)
|
||
if err != nil {
|
||
return 0, fmt.Errorf("插入日志失败: %w", err)
|
||
}
|
||
|
||
// 提交事务
|
||
if err = tx.Commit(); err != nil {
|
||
return 0, fmt.Errorf("提交事务失败: %w", err)
|
||
}
|
||
|
||
return messageID, nil
|
||
}
|
||
|
||
// GetMessages 获取短信列表
|
||
func GetMessages(page, limit int, from string, search string) ([]models.SMSMessage, int64, error) {
|
||
offset := (page - 1) * limit
|
||
|
||
// 构建查询条件(WHERE 子句)
|
||
// 注意:条件字段名已经是固定的,不包含用户输入,因此使用字符串拼接是安全的
|
||
var conditions []string
|
||
var args []interface{}
|
||
|
||
if from != "" {
|
||
conditions = append(conditions, "from_number = ?")
|
||
args = append(args, from)
|
||
}
|
||
if search != "" {
|
||
conditions = append(conditions, "(from_number LIKE ? OR content LIKE ?)")
|
||
args = append(args, "%"+search+"%", "%"+search+"%")
|
||
}
|
||
|
||
// 构建 WHERE 子句
|
||
whereClause := ""
|
||
if len(conditions) > 0 {
|
||
whereClause = "WHERE " + strings.Join(conditions, " AND ")
|
||
}
|
||
|
||
// 查询总数
|
||
var total int64
|
||
countQuery := "SELECT COUNT(*) FROM sms_messages"
|
||
if whereClause != "" {
|
||
countQuery += " " + whereClause
|
||
}
|
||
if err := db.QueryRow(countQuery, args...).Scan(&total); err != nil {
|
||
return nil, 0, fmt.Errorf("查询总数失败: %w", err)
|
||
}
|
||
|
||
// 查询数据(按短信时间戳排序,与 Python 版本一致)
|
||
query := `
|
||
SELECT id, from_number, content, timestamp, device_info, sim_info, sign_verified, ip_address, created_at
|
||
FROM sms_messages
|
||
`
|
||
if whereClause != "" {
|
||
query += " " + whereClause
|
||
}
|
||
query += " ORDER BY timestamp DESC, id DESC LIMIT ? OFFSET ?"
|
||
|
||
args = append(args, limit, offset)
|
||
rows, err := db.Query(query, args...)
|
||
if err != nil {
|
||
return nil, 0, fmt.Errorf("查询消息失败: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
var messages []models.SMSMessage
|
||
for rows.Next() {
|
||
var msg models.SMSMessage
|
||
err := rows.Scan(
|
||
&msg.ID,
|
||
&msg.FromNumber,
|
||
&msg.Content,
|
||
&msg.Timestamp,
|
||
&msg.DeviceInfo,
|
||
&msg.SIMInfo,
|
||
&msg.SignVerified,
|
||
&msg.IPAddress,
|
||
&msg.CreatedAt,
|
||
)
|
||
if err != nil {
|
||
return nil, 0, fmt.Errorf("扫描消息失败: %w", err)
|
||
}
|
||
messages = append(messages, msg)
|
||
}
|
||
|
||
return messages, total, nil
|
||
}
|
||
|
||
// GetMessageByID 根据 ID 获取消息详情
|
||
func GetMessageByID(id int64) (*models.SMSMessage, error) {
|
||
var msg models.SMSMessage
|
||
err := db.QueryRow(`
|
||
SELECT id, from_number, content, timestamp, device_info, sim_info, sign_verified, ip_address, created_at
|
||
FROM sms_messages WHERE id = ?
|
||
`, id).Scan(
|
||
&msg.ID,
|
||
&msg.FromNumber,
|
||
&msg.Content,
|
||
&msg.Timestamp,
|
||
&msg.DeviceInfo,
|
||
&msg.SIMInfo,
|
||
&msg.SignVerified,
|
||
&msg.IPAddress,
|
||
&msg.CreatedAt,
|
||
)
|
||
if err != nil {
|
||
if err == sql.ErrNoRows {
|
||
return nil, nil
|
||
}
|
||
return nil, fmt.Errorf("查询消息失败: %w", err)
|
||
}
|
||
return &msg, nil
|
||
}
|
||
|
||
// GetStatistics 获取统计信息
|
||
func GetStatistics() (*models.Statistics, error) {
|
||
stats := &models.Statistics{}
|
||
|
||
// 总数
|
||
if err := db.QueryRow("SELECT COUNT(*) FROM sms_messages").Scan(&stats.Total); err != nil {
|
||
return nil, fmt.Errorf("查询总数失败: %w", err)
|
||
}
|
||
|
||
// 今日数量(使用范围查询,避免使用函数索引)
|
||
now := time.Now()
|
||
loc := now.Location()
|
||
todayStart := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, loc)
|
||
todayEnd := todayStart.Add(24 * time.Hour)
|
||
if err := db.QueryRow("SELECT COUNT(*) FROM sms_messages WHERE created_at >= ? AND created_at < ?",
|
||
todayStart, todayEnd).Scan(&stats.Today); err != nil {
|
||
return nil, fmt.Errorf("查询今日数量失败: %w", err)
|
||
}
|
||
|
||
// 本周数量(周一作为周开始)
|
||
now = time.Now()
|
||
loc = now.Location()
|
||
weekday := int(now.Weekday())
|
||
var weekStart time.Time
|
||
if weekday == 0 {
|
||
// 周日
|
||
weekStart = time.Date(now.Year(), now.Month(), now.Day()-6, 0, 0, 0, 0, loc)
|
||
} else {
|
||
// 周一到周六
|
||
weekStart = time.Date(now.Year(), now.Month(), now.Day()-(weekday-1), 0, 0, 0, 0, loc)
|
||
}
|
||
if err := db.QueryRow("SELECT COUNT(*) FROM sms_messages WHERE created_at >= ?", weekStart).Scan(&stats.Week); err != nil {
|
||
return nil, fmt.Errorf("查询本周数量失败: %w", err)
|
||
}
|
||
|
||
// 签名验证通过数量
|
||
if err := db.QueryRow("SELECT COUNT(*) FROM sms_messages WHERE sign_verified = 1").Scan(&stats.Verified); err != nil {
|
||
return nil, fmt.Errorf("查询验证通过数量失败: %w", err)
|
||
}
|
||
|
||
// 签名验证未通过数量
|
||
if err := db.QueryRow("SELECT COUNT(*) FROM sms_messages WHERE sign_verified = 0").Scan(&stats.Unverified); err != nil {
|
||
return nil, fmt.Errorf("查询验证未通过数量失败: %w", err)
|
||
}
|
||
|
||
return stats, nil
|
||
}
|
||
|
||
// GetLogs 获取接收日志
|
||
func GetLogs(page, limit int) ([]models.ReceiveLog, int64, error) {
|
||
offset := (page - 1) * limit
|
||
|
||
// 查询总数
|
||
var total int64
|
||
if err := db.QueryRow("SELECT COUNT(*) FROM receive_logs").Scan(&total); err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
rows, err := db.Query(`
|
||
SELECT id, from_number, content, timestamp, sign, sign_valid, ip_address, status, error_message, created_at
|
||
FROM receive_logs
|
||
ORDER BY created_at DESC
|
||
LIMIT ? OFFSET ?
|
||
`, limit, offset)
|
||
if err != nil {
|
||
return nil, 0, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var logs []models.ReceiveLog
|
||
for rows.Next() {
|
||
var log models.ReceiveLog
|
||
err := rows.Scan(
|
||
&log.ID,
|
||
&log.FromNumber,
|
||
&log.Content,
|
||
&log.Timestamp,
|
||
&log.Sign,
|
||
&log.SignValid,
|
||
&log.IPAddress,
|
||
&log.Status,
|
||
&log.ErrorMessage,
|
||
&log.CreatedAt,
|
||
)
|
||
if err != nil {
|
||
return nil, 0, err
|
||
}
|
||
logs = append(logs, log)
|
||
}
|
||
|
||
return logs, total, nil
|
||
}
|
||
|
||
// CleanupOldMessages 清理旧消息
|
||
func CleanupOldMessages(days int) (int64, error) {
|
||
cutoff := time.Now().AddDate(0, 0, -days).Format("2006-01-02 15:04:05")
|
||
result, err := db.Exec("DELETE FROM sms_messages WHERE created_at < ?", cutoff)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return result.RowsAffected()
|
||
}
|
||
|
||
// GetDB 获取数据库实例
|
||
func GetDB() *sql.DB {
|
||
return db
|
||
}
|
||
|
||
// Close 关闭数据库连接
|
||
func Close() error {
|
||
if db != nil {
|
||
return db.Close()
|
||
}
|
||
return nil
|
||
}
|