Files
SmsReceiver-go/database/database.go
OpenClaw Agent 1da899a0f4 feat: v2.0.0 完整代码优化升级
🔴 高优先级 (6项全部完成):
- 数据库事务支持 (InsertMessageWithLog)
- SQL注入修复 (参数化查询)
- 配置验证 (Validate方法)
- 会话密钥强化 (长度验证)
- 签名验证增强 (SignVerificationResult)
- 密码哈希支持 (bcrypt)

🟡 中优先级 (15项全部完成):
- 连接池配置 (MaxOpenConns, MaxIdleConns)
- 查询优化 (范围查询, 索引)
- 健康检查增强 (/health 端点)
- API版本控制 (/api/v1/*)
- 认证中间件 (RequireAuth, RequireAPIAuth)
- 定时任务优化 (robfig/cron)
- 配置文件示例 (config.example.yaml)
- 常量定义 (config/constants.go)
- 开发文档 (DEVELOPMENT.md)

🟢 低优先级 (9项全部完成):
- Docker支持 (Dockerfile, docker-compose.yml)
- Makefile构建脚本
- 优化报告 (OPTIMIZATION_REPORT.md)
- 密码哈希工具 (tools/password_hash.go)
- 14个新文件
- 30项优化100%完成

版本: v2.0.0
2026-02-08 18:59:29 +08:00

415 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package database
import (
"database/sql"
"fmt"
"log"
"strings"
"time"
"sms-receiver-go/config"
"sms-receiver-go/models"
_ "github.com/mattn/go-sqlite3"
)
// db is the package-level SQLite handle used by every query in this package.
// Init assigns it and Close releases it.
var db *sql.DB

// Init opens the SQLite database at cfg.Path, verifies the connection,
// configures the connection pool and creates any missing tables and indexes.
//
// If any step fails, the handle opened by this call is closed again and is not
// left behind in db, so a failed Init leaks no connections. The returned error
// wraps the underlying cause.
func Init(cfg *config.DatabaseConfig) error {
	if cfg == nil {
		return fmt.Errorf("打开数据库失败: 数据库配置为空")
	}

	handle, err := sql.Open("sqlite3", cfg.Path)
	if err != nil {
		return fmt.Errorf("打开数据库失败: %w", err)
	}
	if err := handle.Ping(); err != nil {
		// Best effort: the Ping error is the one worth reporting.
		_ = handle.Close()
		return fmt.Errorf("数据库连接失败: %w", err)
	}

	// Connection pool settings.
	// NOTE(review): SQLite serialises writers. With more than one open
	// connection, concurrent writes may fail with SQLITE_BUSY ("database is
	// locked") unless a busy timeout is configured via cfg.Path — confirm this
	// is acceptable for the expected write load.
	handle.SetMaxOpenConns(25)
	handle.SetMaxIdleConns(5)
	handle.SetConnMaxLifetime(5 * time.Minute)

	// createTables works on the package-level handle, so publish it first and
	// withdraw it again if schema creation fails.
	db = handle
	if err := createTables(); err != nil {
		db = nil
		_ = handle.Close()
		return fmt.Errorf("创建表失败: %w", err)
	}

	log.Printf("数据库初始化成功: %s", cfg.Path)
	return nil
}
// createTables creates the sms_messages and receive_logs tables and their
// indexes. Every statement uses IF NOT EXISTS, so running it against an
// existing database leaves the schema unchanged.
//
// The index definitions are sent as one multi-statement string; this relies on
// the SQLite driver executing every statement contained in a single Exec call.
func createTables() error {
	ddl := [...]string{
		// Received SMS messages.
		`
CREATE TABLE IF NOT EXISTS sms_messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
from_number TEXT NOT NULL,
content TEXT NOT NULL,
timestamp INTEGER NOT NULL,
device_info TEXT,
sim_info TEXT,
sign_verified INTEGER,
ip_address TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
`,
		// Receive log: one row per receive attempt, with status and optional error text.
		`
CREATE TABLE IF NOT EXISTS receive_logs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
from_number TEXT NOT NULL,
content TEXT NOT NULL,
timestamp INTEGER NOT NULL,
sign TEXT,
sign_valid INTEGER,
ip_address TEXT,
status TEXT NOT NULL,
error_message TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
`,
		// Indexes backing the sender filter, ordering and date-range queries.
		`
CREATE INDEX IF NOT EXISTS idx_messages_from ON sms_messages(from_number);
CREATE INDEX IF NOT EXISTS idx_messages_timestamp ON sms_messages(timestamp);
CREATE INDEX IF NOT EXISTS idx_messages_created ON sms_messages(created_at);
CREATE INDEX IF NOT EXISTS idx_logs_created ON receive_logs(created_at);
CREATE INDEX IF NOT EXISTS idx_logs_status ON receive_logs(status);
`,
	}

	for i := range ddl {
		if _, execErr := db.Exec(ddl[i]); execErr != nil {
			return fmt.Errorf("执行 SQL 失败: %w", execErr)
		}
	}
	return nil
}
// InsertMessage stores msg as a new row in sms_messages and returns the row ID
// assigned by SQLite. created_at is left to the column default.
func InsertMessage(msg *models.SMSMessage) (int64, error) {
	const insertSQL = `
INSERT INTO sms_messages (from_number, content, timestamp, device_info, sim_info, sign_verified, ip_address)
VALUES (?, ?, ?, ?, ?, ?, ?)
`
	values := []interface{}{
		msg.FromNumber,
		msg.Content,
		msg.Timestamp,
		msg.DeviceInfo,
		msg.SIMInfo,
		msg.SignVerified,
		msg.IPAddress,
	}

	res, execErr := db.Exec(insertSQL, values...)
	if execErr != nil {
		return 0, fmt.Errorf("插入消息失败: %w", execErr)
	}
	return res.LastInsertId()
}
// InsertLog inserts one row into receive_logs and returns the row ID assigned
// by SQLite. created_at is left to the column default.
func InsertLog(log *models.ReceiveLog) (int64, error) {
	const insertSQL = `
INSERT INTO receive_logs (from_number, content, timestamp, sign, sign_valid, ip_address, status, error_message)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
`
	res, execErr := db.Exec(insertSQL,
		log.FromNumber, log.Content, log.Timestamp,
		log.Sign, log.SignValid, log.IPAddress,
		log.Status, log.ErrorMessage,
	)
	if execErr != nil {
		return 0, fmt.Errorf("插入日志失败: %w", execErr)
	}
	return res.LastInsertId()
}
// InsertMessageWithLog inserts msg into sms_messages and log into receive_logs
// inside a single transaction, so either both rows are stored or neither is.
// It returns the ID of the new sms_messages row.
func InsertMessageWithLog(msg *models.SMSMessage, log *models.ReceiveLog) (int64, error) {
	tx, err := db.Begin()
	if err != nil {
		return 0, fmt.Errorf("开启事务失败: %w", err)
	}
	// Roll back on every exit path, including panics and any early return
	// added later, without depending on the value of err. After a successful
	// Commit, Rollback is a no-op returning sql.ErrTxDone, so its error is
	// deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	msgResult, err := tx.Exec(`
INSERT INTO sms_messages (from_number, content, timestamp, device_info, sim_info, sign_verified, ip_address)
VALUES (?, ?, ?, ?, ?, ?, ?)
`,
		msg.FromNumber,
		msg.Content,
		msg.Timestamp,
		msg.DeviceInfo,
		msg.SIMInfo,
		msg.SignVerified,
		msg.IPAddress,
	)
	if err != nil {
		return 0, fmt.Errorf("插入消息失败: %w", err)
	}
	messageID, err := msgResult.LastInsertId()
	if err != nil {
		return 0, fmt.Errorf("获取消息ID失败: %w", err)
	}

	if _, err := tx.Exec(`
INSERT INTO receive_logs (from_number, content, timestamp, sign, sign_valid, ip_address, status, error_message)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
`,
		log.FromNumber,
		log.Content,
		log.Timestamp,
		log.Sign,
		log.SignValid,
		log.IPAddress,
		log.Status,
		log.ErrorMessage,
	); err != nil {
		return 0, fmt.Errorf("插入日志失败: %w", err)
	}

	if err := tx.Commit(); err != nil {
		return 0, fmt.Errorf("提交事务失败: %w", err)
	}
	return messageID, nil
}
// GetMessages returns one page of SMS messages, newest first, together with
// the total number of messages matching the filters.
//
//   - page is 1-based; values below 1 are treated as 1.
//   - limit is the page size. It is passed to SQLite unchanged, and a negative
//     LIMIT there means "no upper bound", so callers should validate it.
//   - from, when non-empty, restricts results to that exact sender number.
//   - search, when non-empty, matches the sender number or content as a
//     literal substring; LIKE wildcards ('%', '_') in it are escaped.
//
// Rows are ordered by the SMS timestamp (consistent with the Python version),
// with id as a tie-breaker so pagination is stable. The returned slice is nil
// when nothing matches.
func GetMessages(page, limit int, from string, search string) ([]models.SMSMessage, int64, error) {
	if page < 1 {
		page = 1
	}
	offset := (page - 1) * limit

	// Build the WHERE clause. Only fixed column names and placeholders are
	// concatenated; every user-supplied value is bound through args.
	var conditions []string
	var args []interface{}
	if from != "" {
		conditions = append(conditions, "from_number = ?")
		args = append(args, from)
	}
	if search != "" {
		pattern := "%" + escapeLikePattern(search) + "%"
		conditions = append(conditions, `(from_number LIKE ? ESCAPE '\' OR content LIKE ? ESCAPE '\')`)
		args = append(args, pattern, pattern)
	}
	whereClause := ""
	if len(conditions) > 0 {
		whereClause = " WHERE " + strings.Join(conditions, " AND ")
	}

	var total int64
	if err := db.QueryRow("SELECT COUNT(*) FROM sms_messages"+whereClause, args...).Scan(&total); err != nil {
		return nil, 0, fmt.Errorf("查询总数失败: %w", err)
	}

	query := `
SELECT id, from_number, content, timestamp, device_info, sim_info, sign_verified, ip_address, created_at
FROM sms_messages` + whereClause + `
ORDER BY timestamp DESC, id DESC
LIMIT ? OFFSET ?`
	rows, err := db.Query(query, append(args, limit, offset)...)
	if err != nil {
		return nil, 0, fmt.Errorf("查询消息失败: %w", err)
	}
	defer rows.Close()

	var messages []models.SMSMessage
	for rows.Next() {
		var msg models.SMSMessage
		if err := rows.Scan(
			&msg.ID,
			&msg.FromNumber,
			&msg.Content,
			&msg.Timestamp,
			&msg.DeviceInfo,
			&msg.SIMInfo,
			&msg.SignVerified,
			&msg.IPAddress,
			&msg.CreatedAt,
		); err != nil {
			return nil, 0, fmt.Errorf("扫描消息失败: %w", err)
		}
		messages = append(messages, msg)
	}
	// rows.Next returns false on error as well as at the end; without this
	// check a mid-iteration failure would be reported as a short, successful page.
	if err := rows.Err(); err != nil {
		return nil, 0, fmt.Errorf("遍历消息失败: %w", err)
	}
	return messages, total, nil
}

// escapeLikePattern escapes '\', '%' and '_' with a backslash so that s is
// matched literally inside a LIKE pattern declared with ESCAPE '\'.
func escapeLikePattern(s string) string {
	return strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`).Replace(s)
}
// GetMessageByID loads the sms_messages row with the given id.
//
// It returns (nil, nil) when no such row exists, so callers must check the
// returned pointer as well as the error. Other query failures are wrapped.
func GetMessageByID(id int64) (*models.SMSMessage, error) {
	const query = `
SELECT id, from_number, content, timestamp, device_info, sim_info, sign_verified, ip_address, created_at
FROM sms_messages WHERE id = ?
`
	msg := new(models.SMSMessage)
	scanErr := db.QueryRow(query, id).Scan(
		&msg.ID,
		&msg.FromNumber,
		&msg.Content,
		&msg.Timestamp,
		&msg.DeviceInfo,
		&msg.SIMInfo,
		&msg.SignVerified,
		&msg.IPAddress,
		&msg.CreatedAt,
	)
	switch {
	case scanErr == sql.ErrNoRows:
		return nil, nil
	case scanErr != nil:
		return nil, fmt.Errorf("查询消息失败: %w", scanErr)
	}
	return msg, nil
}
// GetStatistics returns message counts: the total, those received today, those
// received this week (weeks start on Monday, local time), and the counts with
// sign_verified = 1 and sign_verified = 0.
//
// created_at is filled by SQLite's CURRENT_TIMESTAMP, which stores UTC text in
// "YYYY-MM-DD HH:MM:SS" form. The local day and week boundaries are therefore
// converted to UTC and formatted with that same layout before comparison.
// Binding a time.Time directly makes the driver send text in a different
// layout that carries a local offset. That text does not compare correctly
// against the stored UTC strings and miscounts messages near midnight.
func GetStatistics() (*models.Statistics, error) {
	// Layout of the text SQLite's CURRENT_TIMESTAMP writes.
	const sqliteTimestamp = "2006-01-02 15:04:05"
	toDB := func(t time.Time) string { return t.UTC().Format(sqliteTimestamp) }

	stats := &models.Statistics{}

	if err := db.QueryRow("SELECT COUNT(*) FROM sms_messages").Scan(&stats.Total); err != nil {
		return nil, fmt.Errorf("查询总数失败: %w", err)
	}

	// Use a single reference instant so "today" and "this week" agree even if
	// the call straddles midnight.
	now := time.Now()
	todayStart := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
	// AddDate rather than Add(24h): a local day is not 24h long on DST transition days.
	todayEnd := todayStart.AddDate(0, 0, 1)

	// A plain range on created_at (not a function of it) lets idx_messages_created be used.
	if err := db.QueryRow("SELECT COUNT(*) FROM sms_messages WHERE created_at >= ? AND created_at < ?",
		toDB(todayStart), toDB(todayEnd)).Scan(&stats.Today); err != nil {
		return nil, fmt.Errorf("查询今日数量失败: %w", err)
	}

	// Weekday() is 0 for Sunday; shift so Monday -> 0, ..., Sunday -> 6.
	daysSinceMonday := (int(now.Weekday()) + 6) % 7
	weekStart := todayStart.AddDate(0, 0, -daysSinceMonday)
	if err := db.QueryRow("SELECT COUNT(*) FROM sms_messages WHERE created_at >= ?",
		toDB(weekStart)).Scan(&stats.Week); err != nil {
		return nil, fmt.Errorf("查询本周数量失败: %w", err)
	}

	if err := db.QueryRow("SELECT COUNT(*) FROM sms_messages WHERE sign_verified = 1").Scan(&stats.Verified); err != nil {
		return nil, fmt.Errorf("查询验证通过数量失败: %w", err)
	}
	if err := db.QueryRow("SELECT COUNT(*) FROM sms_messages WHERE sign_verified = 0").Scan(&stats.Unverified); err != nil {
		return nil, fmt.Errorf("查询验证未通过数量失败: %w", err)
	}

	return stats, nil
}
// GetLogs returns one page of receive_logs, most recent first, together with
// the total number of log rows.
//
// page is 1-based; values below 1 are treated as 1. limit is passed to SQLite
// unchanged, and a negative LIMIT there means "no upper bound". created_at has
// only one-second resolution, so id is used as a tie-breaker to keep
// consecutive pages from overlapping or skipping rows. The returned slice is
// nil when there are no rows on the page.
func GetLogs(page, limit int) ([]models.ReceiveLog, int64, error) {
	if page < 1 {
		page = 1
	}
	offset := (page - 1) * limit

	var total int64
	if err := db.QueryRow("SELECT COUNT(*) FROM receive_logs").Scan(&total); err != nil {
		return nil, 0, fmt.Errorf("查询日志总数失败: %w", err)
	}

	rows, err := db.Query(`
SELECT id, from_number, content, timestamp, sign, sign_valid, ip_address, status, error_message, created_at
FROM receive_logs
ORDER BY created_at DESC, id DESC
LIMIT ? OFFSET ?
`, limit, offset)
	if err != nil {
		return nil, 0, fmt.Errorf("查询日志失败: %w", err)
	}
	defer rows.Close()

	var logs []models.ReceiveLog
	for rows.Next() {
		// Named entry (not log) so the standard log package is not shadowed.
		var entry models.ReceiveLog
		if err := rows.Scan(
			&entry.ID,
			&entry.FromNumber,
			&entry.Content,
			&entry.Timestamp,
			&entry.Sign,
			&entry.SignValid,
			&entry.IPAddress,
			&entry.Status,
			&entry.ErrorMessage,
			&entry.CreatedAt,
		); err != nil {
			return nil, 0, fmt.Errorf("扫描日志失败: %w", err)
		}
		logs = append(logs, entry)
	}
	// Surface errors that ended iteration early instead of returning a short page.
	if err := rows.Err(); err != nil {
		return nil, 0, fmt.Errorf("遍历日志失败: %w", err)
	}
	return logs, total, nil
}
// CleanupOldMessages deletes sms_messages rows whose created_at is more than
// days days in the past, and returns the number of rows removed.
//
// created_at holds the UTC text written by SQLite's CURRENT_TIMESTAMP, so the
// cutoff is converted to UTC and formatted with the same layout. A local-time
// cutoff string would shift the cutoff by the server's UTC offset.
//
// days must be positive. A value of 0 or less would put the cutoff at or after
// the current time and delete every message, so it is rejected with an error.
func CleanupOldMessages(days int) (int64, error) {
	if days <= 0 {
		return 0, fmt.Errorf("清理旧消息失败: 保留天数必须大于 0, 实际为 %d", days)
	}

	cutoff := time.Now().AddDate(0, 0, -days).UTC().Format("2006-01-02 15:04:05")
	result, err := db.Exec("DELETE FROM sms_messages WHERE created_at < ?", cutoff)
	if err != nil {
		return 0, fmt.Errorf("清理旧消息失败: %w", err)
	}

	deleted, err := result.RowsAffected()
	if err != nil {
		return 0, fmt.Errorf("获取删除行数失败: %w", err)
	}
	return deleted, nil
}
// GetDB exposes the package-level *sql.DB handle that Init assigns.
// It is nil if Init has not been called.
func GetDB() *sql.DB {
	return db
}

// Close closes the package-level database handle. It is a no-op that returns
// nil when no handle has been set; otherwise it returns the result of
// (*sql.DB).Close.
func Close() error {
	if db == nil {
		return nil
	}
	return db.Close()
}