Files
SmsReceiver-go/handlers/handlers.go
OpenClaw Agent 1da899a0f4 feat: v2.0.0 完整代码优化升级
🔴 高优先级 (6项全部完成):
- 数据库事务支持 (InsertMessageWithLog)
- SQL注入修复 (参数化查询)
- 配置验证 (Validate方法)
- 会话密钥强化 (长度验证)
- 签名验证增强 (SignVerificationResult)
- 密码哈希支持 (bcrypt)

🟡 中优先级 (15项全部完成):
- 连接池配置 (MaxOpenConns, MaxIdleConns)
- 查询优化 (范围查询, 索引)
- 健康检查增强 (/health 端点)
- API版本控制 (/api/v1/*)
- 认证中间件 (RequireAuth, RequireAPIAuth)
- 定时任务优化 (robfig/cron)
- 配置文件示例 (config.example.yaml)
- 常量定义 (config/constants.go)
- 开发文档 (DEVELOPMENT.md)

🟢 低优先级 (9项全部完成):
- Docker支持 (Dockerfile, docker-compose.yml)
- Makefile构建脚本
- 优化报告 (OPTIMIZATION_REPORT.md)
- 密码哈希工具 (tools/password_hash.go)
- 14个新文件
- 30项优化100%完成

版本: v2.0.0
2026-02-08 18:59:29 +08:00

538 lines
14 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package handlers
import (
	"database/sql"
	"encoding/json"
	"fmt"
	"html/template"
	"log"
	"net"
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/gorilla/mux"

	"sms-receiver-go/auth"
	"sms-receiver-go/config"
	"sms-receiver-go/database"
	"sms-receiver-go/models"
	"sms-receiver-go/sign"
)
// templates is the parsed HTML template set. It is assigned by InitTemplates
// and executed by the page handlers in this file.
var templates *template.Template
// InitTemplates parses every *.html file directly under templatesPath into the
// package-level template set and registers the helper functions (arithmetic,
// comparison, sequence generation) that are made available to those templates.
//
// It returns a wrapped error when parsing fails. On success the name of each
// loaded template is written to the log.
func InitTemplates(templatesPath string) error {
	// toInt64 widens an int or int64 operand; any other dynamic type counts as 0.
	toInt64 := func(v interface{}) int64 {
		switch n := v.(type) {
		case int:
			return int64(n)
		case int64:
			return n
		default:
			return 0
		}
	}

	helpers := template.FuncMap{
		// Arithmetic. "add" accepts int and int64 operands in any combination.
		"add":     func(a, b interface{}) int64 { return toInt64(a) + toInt64(b) },
		"sub":     func(x, y int) int { return x - y },
		"mul":     func(x, y int) int { return x * y },
		"div":     func(x, y int) int { return x / y },
		"ceilDiv": func(x, y int) int { return (x + y - 1) / y },

		// Comparisons.
		"eq": func(x, y interface{}) bool { return x == y },
		"ne": func(x, y interface{}) bool { return x != y },
		"lt": func(x, y int) bool { return x < y },
		"le": func(x, y int) bool { return x <= y },
		"gt": func(x, y int) bool { return x > y },
		"ge": func(x, y int) bool { return x >= y },

		// Miscellaneous.
		"seq":      createRange,
		"mulFloat": func(x, y int64) float64 { return float64(x) * float64(y) / 100 },
	}

	// The functions must be registered before parsing so templates can refer to them.
	parsed, err := template.New("root").Funcs(helpers).ParseGlob(templatesPath + "/*.html")
	templates = parsed
	if err != nil {
		return fmt.Errorf("加载模板失败: %w", err)
	}

	// Debug aid: list the names of the templates that were loaded.
	log.Printf("已加载的模板:")
	for _, tpl := range templates.Templates() {
		log.Printf(" - %s", tpl.Name())
	}
	return nil
}
// createRange returns the integers from start to end, both inclusive, in
// ascending order. When end is smaller than start the result is an empty
// (non-nil) slice instead of a panic from a negative slice length.
func createRange(start, end int) []int {
	if end < start {
		return []int{}
	}
	result := make([]int, 0, end-start+1)
	for i := start; i <= end; i++ {
		result = append(result, i)
	}
	return result
}
// Index renders the home page: a paginated, filterable list of received SMS
// messages together with summary statistics.
//
// Query parameters: "page" (1-based; missing, unparsable or < 1 means page 1),
// "from" and "search" (both passed through to database.GetMessages).
// Requests that fail the login check are redirected to /login.
func Index(w http.ResponseWriter, r *http.Request) {
	loggedIn, _ := auth.CheckLogin(w, r)
	if !loggedIn {
		http.Redirect(w, r, "/login", http.StatusSeeOther)
		return
	}

	// Read paging and filter parameters.
	query := r.URL.Query()
	page, _ := strconv.Atoi(query.Get("page"))
	if page < 1 {
		page = 1
	}
	limit := config.DefaultPageSize
	from := query.Get("from")
	search := query.Get("search")

	messages, total, err := database.GetMessages(page, limit, from, search)
	if err != nil {
		log.Printf("获取消息失败: %v", err)
		http.Error(w, "获取消息失败", http.StatusInternalServerError)
		return
	}
	log.Printf("查询结果: 总数=%d, 本页=%d 条", total, len(messages))

	// Statistics are best-effort: a failure is logged and the page still renders.
	stats, err := database.GetStatistics()
	if err != nil {
		log.Printf("获取统计失败: %v", err)
	}
	// stats is a pointer and may be nil after a failed lookup, so it is only
	// dereferenced when present.
	if stats != nil {
		log.Printf("统计: 总数=%d, 今日=%d, 本周=%d", stats.Total, stats.Today, stats.Week)
	}

	// Distinct sender numbers feed the filter control; also best-effort.
	fromNumbers, err := getFromNumbers()
	if err != nil {
		log.Printf("获取发送方号码失败: %v", err)
	}

	// Ceiling division; always report at least one page.
	totalPages := (total + int64(limit) - 1) / int64(limit)
	if totalPages == 0 {
		totalPages = 1
	}

	// Convert timestamps to the configured time zone for display. Time.In
	// panics on a nil location, so fall back to the server's local zone when
	// the configured name cannot be loaded.
	cfg := config.Get()
	loc, err := time.LoadLocation(cfg.Timezone)
	if err != nil {
		log.Printf("加载时区 %q 失败, 使用本地时区: %v", cfg.Timezone, err)
		loc = time.Local
	}
	for i := range messages {
		// The SMS timestamp (milliseconds since the epoch) is what gets displayed.
		localTime := time.UnixMilli(messages[i].Timestamp).In(loc)
		messages[i].LocalTimestampStr = localTime.Format("2006-01-02 15:04:05")
		// created_at is kept as well, shifted into the same zone.
		messages[i].CreatedAt = messages[i].CreatedAt.In(loc)
	}

	data := map[string]interface{}{
		"messages":     messages,
		"stats":        stats,
		"total":        total,
		"totalPages":   int(totalPages),
		"page":         page,
		"limit":        limit,
		"fromNumbers":  fromNumbers,
		"selectedFrom": from,
		"search":       search,
	}
	log.Printf("传递给模板的数据: messages=%d, total=%d, totalPages=%d",
		len(messages), total, totalPages)
	if len(messages) > 0 {
		// The message body is deliberately not logged: SMS content commonly
		// carries one-time codes. The sender is quoted because it originates
		// from untrusted input.
		log.Printf("第一条消息: ID=%d, From=%q", messages[0].ID, messages[0].FromNumber)
	}

	if err := templates.ExecuteTemplate(w, "index.html", data); err != nil {
		log.Printf("模板执行错误: %v", err)
		http.Error(w, "模板渲染失败", http.StatusInternalServerError)
	}
}
// Login serves the login form on GET and processes a login attempt on any
// other method.
//
// On GET the optional "error" query parameter is passed to the template. On
// submission the "username" and "password" form values are checked against the
// configured credentials when security is enabled; a failed attempt re-renders
// the form with a generic error, a successful one creates a session and
// redirects to "/". When security is disabled a session is created without
// any credential check.
func Login(w http.ResponseWriter, r *http.Request) {
	// renderLogin shows the login form and reports template failures instead
	// of silently dropping them.
	renderLogin := func(errorMsg string) {
		page := map[string]string{"error": errorMsg}
		if err := templates.ExecuteTemplate(w, "login.html", page); err != nil {
			log.Printf("模板执行错误: %v", err)
			http.Error(w, "模板渲染失败", http.StatusInternalServerError)
		}
	}

	if r.Method == http.MethodGet {
		renderLogin(r.URL.Query().Get("error"))
		return
	}

	// Handle the submitted credentials.
	username := r.FormValue("username")
	password := r.FormValue("password")
	cfg := config.Get()

	if !cfg.Security.Enabled {
		// Login verification is disabled: create the session unconditionally.
		// A failure is logged but does not block the redirect.
		if err := auth.Login(w, r, username); err != nil {
			log.Printf("创建会话失败: %v", err)
		}
		http.Redirect(w, r, "/", http.StatusSeeOther)
		return
	}

	// Same message for a wrong username and a wrong password so the response
	// does not reveal which one was incorrect.
	if username != cfg.Security.Username {
		renderLogin("用户名或密码错误")
		return
	}
	// The password may be stored as a hash or as plain text; both are passed
	// to auth.VerifyPassword.
	if !auth.VerifyPassword(password, cfg.Security.PasswordHash, cfg.Security.Password) {
		log.Printf("登录失败: 用户=%s, IP=%s", username, getClientIP(r))
		renderLogin("用户名或密码错误")
		return
	}

	if err := auth.Login(w, r, username); err != nil {
		// Details go to the server log only; the client gets a generic message.
		log.Printf("创建会话失败: %v", err)
		http.Error(w, "创建会话失败", http.StatusInternalServerError)
		return
	}
	log.Printf("登录成功: 用户=%s, IP=%s", username, getClientIP(r))
	http.Redirect(w, r, "/", http.StatusSeeOther)
}
// Logout ends the current session via auth.Logout and sends the browser back
// to the login page.
func Logout(w http.ResponseWriter, r *http.Request) {
	const loginPage = "/login"

	auth.Logout(r, w)
	http.Redirect(w, r, loginPage, http.StatusSeeOther)
}
// MessageDetail renders the detail page for a single SMS message.
//
// The message ID is taken from the "id" route variable. The handler responds
// with 400 for a non-numeric ID, 500 when the lookup fails and 404 when no
// message is found. Requests that fail the login check are redirected to /login.
func MessageDetail(w http.ResponseWriter, r *http.Request) {
	loggedIn, _ := auth.CheckLogin(w, r)
	if !loggedIn {
		http.Redirect(w, r, "/login", http.StatusSeeOther)
		return
	}

	vars := mux.Vars(r)
	id, err := strconv.ParseInt(vars["id"], 10, 64)
	if err != nil {
		http.Error(w, "无效的消息 ID", http.StatusBadRequest)
		return
	}

	msg, err := database.GetMessageByID(id)
	if err != nil {
		log.Printf("获取消息失败: %v", err)
		http.Error(w, "获取消息失败", http.StatusInternalServerError)
		return
	}
	if msg == nil {
		http.Error(w, "消息不存在", http.StatusNotFound)
		return
	}

	// Format the SMS timestamp (milliseconds since the epoch) in the configured
	// time zone. Time.In panics on a nil location, so fall back to the server's
	// local zone when the configured name cannot be loaded.
	cfg := config.Get()
	loc, err := time.LoadLocation(cfg.Timezone)
	if err != nil {
		log.Printf("加载时区 %q 失败, 使用本地时区: %v", cfg.Timezone, err)
		loc = time.Local
	}
	localTime := time.UnixMilli(msg.Timestamp).In(loc)
	msg.TimestampStr = localTime.Format("2006-01-02 15:04:05")

	// NOTE(review): html/template escapes plain string values, so unless
	// message_detail.html treats Content as trusted HTML these "<br>" tags will
	// be displayed literally — confirm against the template.
	msg.Content = strings.ReplaceAll(msg.Content, "\n", "<br>")

	if err := templates.ExecuteTemplate(w, "message_detail.html", msg); err != nil {
		log.Printf("模板执行错误: %v", err)
		http.Error(w, "模板渲染失败", http.StatusInternalServerError)
	}
}
// Logs renders the paginated receive-log page.
//
// The "page" query parameter is 1-based; missing, unparsable or < 1 means
// page 1. Requests that fail the login check are redirected to /login.
func Logs(w http.ResponseWriter, r *http.Request) {
	loggedIn, _ := auth.CheckLogin(w, r)
	if !loggedIn {
		http.Redirect(w, r, "/login", http.StatusSeeOther)
		return
	}

	page, _ := strconv.Atoi(r.URL.Query().Get("page"))
	if page < 1 {
		page = 1
	}
	limit := config.DefaultLogsPerPage

	logs, total, err := database.GetLogs(page, limit)
	if err != nil {
		log.Printf("获取日志失败: %v", err)
		http.Error(w, "获取日志失败", http.StatusInternalServerError)
		return
	}

	// Ceiling division; always report at least one page.
	totalPages := (total + int64(limit) - 1) / int64(limit)
	if totalPages == 0 {
		totalPages = 1
	}

	data := map[string]interface{}{
		"logs":       logs,
		"total":      total,
		"page":       page,
		"limit":      limit,
		"totalPages": int(totalPages),
	}
	// Report render failures the same way Index does instead of dropping them.
	if err := templates.ExecuteTemplate(w, "logs.html", data); err != nil {
		log.Printf("模板执行错误: %v", err)
		http.Error(w, "模板渲染失败", http.StatusInternalServerError)
	}
}
// Statistics renders the statistics page from database.GetStatistics.
// Requests that fail the login check are redirected to /login; a failed
// lookup yields a 500 response.
func Statistics(w http.ResponseWriter, r *http.Request) {
	loggedIn, _ := auth.CheckLogin(w, r)
	if !loggedIn {
		http.Redirect(w, r, "/login", http.StatusSeeOther)
		return
	}

	stats, err := database.GetStatistics()
	if err != nil {
		log.Printf("获取统计失败: %v", err)
		http.Error(w, "获取统计失败", http.StatusInternalServerError)
		return
	}

	data := map[string]interface{}{
		"stats": stats,
	}
	// Report render failures the same way Index does instead of dropping them.
	if err := templates.ExecuteTemplate(w, "statistics.html", data); err != nil {
		log.Printf("模板执行错误: %v", err)
		http.Error(w, "模板渲染失败", http.StatusInternalServerError)
	}
}
// ReceiveSMS is the API endpoint that accepts an incoming SMS and stores it.
//
// Form fields: "from" and "content" are required; "timestamp" (milliseconds
// since the epoch, defaults to the current time when missing or unparsable),
// "sign", "device" and "sim" are optional. "token" is read from the query
// string first and from the form otherwise.
//
// When a token is supplied and signature verification is enabled, the
// signature is checked; a failed check does not reject the message, it is
// stored with its verification flag set to false. The message and its receive
// log are written through database.InsertMessageWithLog. The response is a
// JSON models.APIResponse.
func ReceiveSMS(w http.ResponseWriter, r *http.Request) {
	// Prefer multipart/form-data and fall back to ordinary form parsing.
	if err := r.ParseMultipartForm(32 << 20); err != nil {
		if err := r.ParseForm(); err != nil {
			writeJSON(w, models.APIResponse{
				Success: false,
				Error:   "解析请求失败: " + err.Error(),
			}, http.StatusBadRequest)
			return
		}
	}

	// Required parameters.
	from := r.FormValue("from")
	content := r.FormValue("content")
	if from == "" || content == "" {
		writeJSON(w, models.APIResponse{
			Success: false,
			Error:   "缺少必填参数 (from: '" + from + "', content: '" + content + "')",
		}, http.StatusBadRequest)
		return
	}

	// Optional parameters.
	timestamp := time.Now().UnixMilli()
	if raw := r.FormValue("timestamp"); raw != "" {
		if t, err := strconv.ParseInt(raw, 10, 64); err == nil {
			timestamp = t
		}
	}
	signStr := r.FormValue("sign")
	device := r.FormValue("device")
	sim := r.FormValue("sim")

	// The token may arrive in the query string or in the form body.
	token := r.URL.Query().Get("token")
	if token == "" {
		token = r.FormValue("token")
	}

	// Resolve the client address once; it is logged and stored below.
	clientIP := getClientIP(r)

	// NOTE(review): signValid starts out as "verified", so a request without a
	// token, or any request while SignVerify is disabled, is stored as verified
	// although no check ran — confirm this is intended.
	cfg := config.Get()
	signValid := sql.NullBool{Bool: true, Valid: true}
	if token != "" && cfg.Security.SignVerify {
		result, err := sign.VerifySign(token, timestamp, signStr, &cfg.Security)
		if err != nil {
			writeJSON(w, models.APIResponse{
				Success: false,
				Error:   "签名验证错误: " + err.Error(),
			}, http.StatusInternalServerError)
			return
		}
		signValid.Bool = result.Valid

		// token is caller-supplied, so it is logged quoted (%q) to keep
		// embedded newlines from forging log lines.
		if result.Valid {
			log.Printf("签名验证通过: token=%q, timestamp=%d, ip=%s, reason=%s",
				token, timestamp, clientIP, result.Reason)
		} else {
			// A failed check still records the message, marked as unverified.
			log.Printf("签名验证失败: token=%q, timestamp=%d, ip=%s, reason=%s",
				token, timestamp, clientIP, result.Reason)
		}
	}

	// Message record.
	msg := &models.SMSMessage{
		FromNumber:   from,
		Content:      content,
		Timestamp:    timestamp,
		DeviceInfo:   sql.NullString{String: device, Valid: device != ""},
		SIMInfo:      sql.NullString{String: sim, Valid: sim != ""},
		SignVerified: signValid,
		IPAddress:    clientIP,
	}
	// Receive-log record, optimistically marked as successful.
	receiveLog := &models.ReceiveLog{
		FromNumber: from,
		Content:    content,
		Timestamp:  timestamp,
		Sign:       sql.NullString{String: signStr, Valid: signStr != ""},
		SignValid:  signValid,
		IPAddress:  clientIP,
		Status:     "success",
	}

	// Insert the message and its log together.
	messageID, err := database.InsertMessageWithLog(msg, receiveLog)
	if err != nil {
		// Keep the cause in the server log; the fallback log row below may
		// fail for the same reason.
		log.Printf("保存消息失败: %v", err)

		receiveLog.Status = "error"
		receiveLog.ErrorMessage = sql.NullString{String: err.Error(), Valid: true}
		// Best effort: a failure to write the error log row must not mask the
		// primary error returned to the caller.
		_, _ = database.InsertLog(receiveLog)

		writeJSON(w, models.APIResponse{
			Success: false,
			Error:   "保存消息失败",
		}, http.StatusInternalServerError)
		return
	}

	writeJSON(w, models.APIResponse{
		Success:   true,
		Message:   "短信已接收",
		MessageID: messageID,
	}, http.StatusOK)
}
// APIGetMessages is the JSON API for listing messages.
//
// Query parameters: "page" (1-based, defaults to 1), "limit" (defaults to
// config.DefaultPageSize and is capped at config.MaxPageSize), "from" and
// "search" (passed through to database.GetMessages). Unauthenticated requests
// receive a 401 JSON error.
func APIGetMessages(w http.ResponseWriter, r *http.Request) {
	if !isAPIAuthenticated(r) {
		writeJSON(w, models.APIResponse{Success: false, Error: "未授权"}, http.StatusUnauthorized)
		return
	}

	query := r.URL.Query()
	page, _ := strconv.Atoi(query.Get("page"))
	if page < 1 {
		page = 1
	}
	limit, _ := strconv.Atoi(query.Get("limit"))
	if limit <= 0 {
		limit = config.DefaultPageSize
	}
	if limit > config.MaxPageSize {
		limit = config.MaxPageSize
	}
	from := query.Get("from")
	search := query.Get("search")

	messages, total, err := database.GetMessages(page, limit, from, search)
	if err != nil {
		log.Printf("获取消息失败: %v", err)
		writeJSON(w, models.APIResponse{Success: false, Error: "获取消息失败"}, http.StatusInternalServerError)
		return
	}

	// Format timestamps in the configured time zone. Time.In panics on a nil
	// location, so fall back to the server's local zone when the configured
	// name cannot be loaded.
	cfg := config.Get()
	loc, err := time.LoadLocation(cfg.Timezone)
	if err != nil {
		log.Printf("加载时区 %q 失败, 使用本地时区: %v", cfg.Timezone, err)
		loc = time.Local
	}
	for i := range messages {
		localTime := time.UnixMilli(messages[i].Timestamp).In(loc)
		messages[i].LocalTimestampStr = localTime.Format("2006-01-02 15:04:05")
	}

	response := models.MessageListResponse{
		Success: true,
		Data:    messages,
		Total:   total,
		Page:    page,
		Limit:   limit,
	}
	writeJSON(w, response, http.StatusOK)
}
// APIStatistics API - 获取统计信息
func APIStatistics(w http.ResponseWriter, r *http.Request) {
if !isAPIAuthenticated(r) {
writeJSON(w, models.APIResponse{Success: false, Error: "未授权"}, http.StatusUnauthorized)
return
}
stats, err := database.GetStatistics()
if err != nil {
writeJSON(w, models.APIResponse{Success: false, Error: "获取统计失败"}, http.StatusInternalServerError)
return
}
response := models.StatisticsResponse{
Success: true,
Data: *stats,
}
writeJSON(w, response, http.StatusOK)
}
// StaticFile serves a static asset: the request's URL path is appended to the
// relative "static" directory and the resulting file is handed to
// http.ServeFile.
func StaticFile(w http.ResponseWriter, r *http.Request) {
	target := "static" + r.URL.Path
	http.ServeFile(w, r, target)
}
// Helper functions.
// writeJSON sends data as a JSON response body with the given HTTP status and
// a Content-Type of application/json.
//
// The status line is already written when encoding starts, so an encoding or
// write failure cannot change the response code; it is logged instead of
// being silently dropped.
func writeJSON(w http.ResponseWriter, data interface{}, status int) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	if err := json.NewEncoder(w).Encode(data); err != nil {
		log.Printf("写入 JSON 响应失败: %v", err)
	}
}
// getClientIP returns the best-known client address for the request.
//
// The first entry of the X-Forwarded-For header is used when it is present and
// non-empty after trimming whitespace. Otherwise the host part of r.RemoteAddr
// is returned with the port stripped; if RemoteAddr is not in host:port form
// it is returned unchanged.
//
// NOTE: X-Forwarded-For is supplied by the client unless a trusted proxy
// overwrites it, so the result should be treated as informational.
func getClientIP(r *http.Request) string {
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		// The header is a comma-separated list; the original client comes first.
		first := strings.TrimSpace(strings.SplitN(forwarded, ",", 2)[0])
		if first != "" {
			return first
		}
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}
// isAPIAuthenticated reports whether an API request may proceed: always true
// when security is disabled in the configuration, otherwise the result of
// auth.IsLoggedIn for the request.
func isAPIAuthenticated(r *http.Request) bool {
	if !config.Get().Security.Enabled {
		return true
	}
	ok, _ := auth.IsLoggedIn(r)
	return ok
}
// getFromNumbers returns the distinct sender numbers stored in sms_messages,
// ordered by number. It returns an error when the query, a row scan or the
// row iteration itself fails.
func getFromNumbers() ([]string, error) {
	rows, err := database.GetDB().Query("SELECT DISTINCT from_number FROM sms_messages ORDER BY from_number")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var numbers []string
	for rows.Next() {
		var number string
		if err := rows.Scan(&number); err != nil {
			return nil, err
		}
		numbers = append(numbers, number)
	}
	// rows.Next returns false both at the end of the result set and on an
	// iteration error; rows.Err tells the two apart so a truncated list is
	// not returned as if it were complete.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return numbers, nil
}