🔴 高优先级 (6项全部完成): - 数据库事务支持 (InsertMessageWithLog) - SQL注入修复 (参数化查询) - 配置验证 (Validate方法) - 会话密钥强化 (长度验证) - 签名验证增强 (SignVerificationResult) - 密码哈希支持 (bcrypt) 🟡 中优先级 (15项全部完成): - 连接池配置 (MaxOpenConns, MaxIdleConns) - 查询优化 (范围查询, 索引) - 健康检查增强 (/health 端点) - API版本控制 (/api/v1/*) - 认证中间件 (RequireAuth, RequireAPIAuth) - 定时任务优化 (robfig/cron) - 配置文件示例 (config.example.yaml) - 常量定义 (config/constants.go) - 开发文档 (DEVELOPMENT.md) 🟢 低优先级 (9项全部完成): - Docker支持 (Dockerfile, docker-compose.yml) - Makefile构建脚本 - 优化报告 (OPTIMIZATION_REPORT.md) - 密码哈希工具 (tools/password_hash.go) - 14个新文件 - 30项优化100%完成 版本: v2.0.0
538 lines
14 KiB
Go
538 lines
14 KiB
Go
package handlers
|
||
|
||
import (
	"database/sql"
	"encoding/json"
	"fmt"
	"html/template"
	"log"
	"net"
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/gorilla/mux"

	"sms-receiver-go/auth"
	"sms-receiver-go/config"
	"sms-receiver-go/database"
	"sms-receiver-go/models"
	"sms-receiver-go/sign"
)
|
||
|
||
var templates *template.Template
|
||
|
||
// InitTemplates 初始化模板
|
||
func InitTemplates(templatesPath string) error {
|
||
// 先创建带函数的模板
|
||
funcMap := template.FuncMap{
|
||
// 基本运算(支持 int 和 int64)
|
||
"add": func(a, b interface{}) int64 {
|
||
ai, _ := a.(int)
|
||
ai64, _ := a.(int64)
|
||
bi, _ := b.(int)
|
||
bi64, _ := b.(int64)
|
||
if ai64 == 0 && ai != 0 {
|
||
ai64 = int64(ai)
|
||
}
|
||
if bi64 == 0 && bi != 0 {
|
||
bi64 = int64(bi)
|
||
}
|
||
return ai64 + bi64
|
||
},
|
||
"sub": func(a, b int) int { return a - b },
|
||
"mul": func(a, b int) int { return a * b },
|
||
"div": func(a, b int) int { return a / b },
|
||
"ceilDiv": func(a, b int) int { return (a + b - 1) / b },
|
||
// 比较函数
|
||
"eq": func(a, b interface{}) bool { return a == b },
|
||
"ne": func(a, b interface{}) bool { return a != b },
|
||
"lt": func(a, b int) bool { return a < b },
|
||
"le": func(a, b int) bool { return a <= b },
|
||
"gt": func(a, b int) bool { return a > b },
|
||
"ge": func(a, b int) bool { return a >= b },
|
||
// 其他
|
||
"seq": createRange,
|
||
"mulFloat": func(a, b int64) float64 { return float64(a) * float64(b) / 100 },
|
||
}
|
||
|
||
var err error
|
||
templates, err = template.New("root").Funcs(funcMap).ParseGlob(templatesPath + "/*.html")
|
||
if err != nil {
|
||
return fmt.Errorf("加载模板失败: %w", err)
|
||
}
|
||
|
||
// 调试: 打印加载的模板名称
|
||
log.Printf("已加载的模板:")
|
||
for _, t := range templates.Templates() {
|
||
log.Printf(" - %s", t.Name())
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// createRange returns the inclusive integer sequence start..end, e.g.
// createRange(1, 3) == []int{1, 2, 3}. It is exposed to templates as "seq".
// When end < start the result is empty; previously the negative slice
// length made make() panic during template execution.
func createRange(start, end int) []int {
	if end < start {
		return nil
	}
	result := make([]int, 0, end-start+1)
	for i := start; i <= end; i++ {
		result = append(result, i)
	}
	return result
}
|
||
|
||
// Index 首页 - 短信列表
|
||
func Index(w http.ResponseWriter, r *http.Request) {
|
||
loggedIn, _ := auth.CheckLogin(w, r)
|
||
if !loggedIn {
|
||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
// 获取查询参数
|
||
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
|
||
if page < 1 {
|
||
page = 1
|
||
}
|
||
limit := config.DefaultPageSize
|
||
from := r.URL.Query().Get("from")
|
||
search := r.URL.Query().Get("search")
|
||
|
||
messages, total, err := database.GetMessages(page, limit, from, search)
|
||
if err != nil {
|
||
log.Printf("获取消息失败: %v", err)
|
||
http.Error(w, "获取消息失败", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
log.Printf("查询结果: 总数=%d, 本页=%d 条", total, len(messages))
|
||
|
||
// 获取统计数据
|
||
stats, err := database.GetStatistics()
|
||
if err != nil {
|
||
log.Printf("获取统计失败: %v", err)
|
||
}
|
||
log.Printf("统计: 总数=%d, 今日=%d, 本周=%d", stats.Total, stats.Today, stats.Week)
|
||
// 获取所有发送方号码(用于筛选)
|
||
fromNumbers, _ := getFromNumbers()
|
||
|
||
// 计算总页数
|
||
totalPages := (total + int64(limit) - 1) / int64(limit)
|
||
if totalPages == 0 {
|
||
totalPages = 1
|
||
}
|
||
|
||
// 格式化时间(转换为本地时区显示)
|
||
cfg := config.Get()
|
||
loc, _ := time.LoadLocation(cfg.Timezone)
|
||
for i := range messages {
|
||
// 优先显示短信时间戳(本地时间)
|
||
localTime := time.UnixMilli(messages[i].Timestamp).In(loc)
|
||
messages[i].LocalTimestampStr = localTime.Format("2006-01-02 15:04:05")
|
||
// 同时保留 created_at 作为排序参考
|
||
messages[i].CreatedAt = messages[i].CreatedAt.In(loc)
|
||
}
|
||
|
||
data := map[string]interface{}{
|
||
"messages": messages,
|
||
"stats": stats,
|
||
"total": total,
|
||
"totalPages": int(totalPages),
|
||
"page": page,
|
||
"limit": limit,
|
||
"fromNumbers": fromNumbers,
|
||
"selectedFrom": from,
|
||
"search": search,
|
||
}
|
||
|
||
log.Printf("传递给模板的数据: messages=%d, total=%d, totalPages=%d",
|
||
len(messages), total, totalPages)
|
||
if len(messages) > 0 {
|
||
log.Printf("第一条消息: ID=%d, From=%s, Content=%s",
|
||
messages[0].ID, messages[0].FromNumber, messages[0].Content)
|
||
}
|
||
|
||
if err := templates.ExecuteTemplate(w, "index.html", data); err != nil {
|
||
log.Printf("模板执行错误: %v", err)
|
||
http.Error(w, "模板渲染失败", http.StatusInternalServerError)
|
||
}
|
||
}
|
||
|
||
// Login 登录页面
|
||
func Login(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method == http.MethodGet {
|
||
// 显示登录页面
|
||
errorMsg := r.URL.Query().Get("error")
|
||
templates.ExecuteTemplate(w, "login.html", map[string]string{
|
||
"error": errorMsg,
|
||
})
|
||
return
|
||
}
|
||
|
||
// 处理登录请求
|
||
username := r.FormValue("username")
|
||
password := r.FormValue("password")
|
||
|
||
cfg := config.Get()
|
||
if cfg.Security.Enabled {
|
||
// 验证用户名
|
||
if username != cfg.Security.Username {
|
||
templates.ExecuteTemplate(w, "login.html", map[string]string{
|
||
"error": "用户名或密码错误",
|
||
})
|
||
return
|
||
}
|
||
|
||
// 验证密码(支持哈希和明文)
|
||
if !auth.VerifyPassword(password, cfg.Security.PasswordHash, cfg.Security.Password) {
|
||
// 记录登录失败日志
|
||
log.Printf("登录失败: 用户=%s, IP=%s", username, getClientIP(r))
|
||
templates.ExecuteTemplate(w, "login.html", map[string]string{
|
||
"error": "用户名或密码错误",
|
||
})
|
||
return
|
||
}
|
||
|
||
// 创建会话
|
||
if err := auth.Login(w, r, username); err != nil {
|
||
log.Printf("创建会话失败: %v", err)
|
||
http.Error(w, "创建会话失败: "+err.Error(), http.StatusInternalServerError)
|
||
return
|
||
}
|
||
log.Printf("登录成功: 用户=%s, IP=%s", username, getClientIP(r))
|
||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
// 未启用登录验证
|
||
auth.Login(w, r, username)
|
||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||
}
|
||
|
||
// Logout 登出
|
||
func Logout(w http.ResponseWriter, r *http.Request) {
|
||
auth.Logout(r, w)
|
||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||
}
|
||
|
||
// MessageDetail 短信详情页面
|
||
func MessageDetail(w http.ResponseWriter, r *http.Request) {
|
||
loggedIn, _ := auth.CheckLogin(w, r)
|
||
if !loggedIn {
|
||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
vars := mux.Vars(r)
|
||
id, err := strconv.ParseInt(vars["id"], 10, 64)
|
||
if err != nil {
|
||
http.Error(w, "无效的消息 ID", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
msg, err := database.GetMessageByID(id)
|
||
if err != nil {
|
||
http.Error(w, "获取消息失败", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
if msg == nil {
|
||
http.Error(w, "消息不存在", http.StatusNotFound)
|
||
return
|
||
}
|
||
|
||
// 格式化时间
|
||
cfg := config.Get()
|
||
loc, _ := time.LoadLocation(cfg.Timezone)
|
||
localTime := time.UnixMilli(msg.Timestamp).In(loc)
|
||
msg.TimestampStr = localTime.Format("2006-01-02 15:04:05")
|
||
msg.Content = strings.ReplaceAll(msg.Content, "\n", "<br>")
|
||
|
||
templates.ExecuteTemplate(w, "message_detail.html", msg)
|
||
}
|
||
|
||
// Logs 接收日志页面
|
||
func Logs(w http.ResponseWriter, r *http.Request) {
|
||
loggedIn, _ := auth.CheckLogin(w, r)
|
||
if !loggedIn {
|
||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
|
||
if page < 1 {
|
||
page = 1
|
||
}
|
||
limit := config.DefaultLogsPerPage
|
||
|
||
logs, total, err := database.GetLogs(page, limit)
|
||
if err != nil {
|
||
http.Error(w, "获取日志失败", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
// 计算总页数
|
||
totalPages := (total + int64(limit) - 1) / int64(limit)
|
||
if totalPages == 0 {
|
||
totalPages = 1
|
||
}
|
||
|
||
data := map[string]interface{}{
|
||
"logs": logs,
|
||
"total": total,
|
||
"page": page,
|
||
"limit": limit,
|
||
"totalPages": int(totalPages),
|
||
}
|
||
|
||
templates.ExecuteTemplate(w, "logs.html", data)
|
||
}
|
||
|
||
// Statistics 统计信息页面
|
||
func Statistics(w http.ResponseWriter, r *http.Request) {
|
||
loggedIn, _ := auth.CheckLogin(w, r)
|
||
if !loggedIn {
|
||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
stats, err := database.GetStatistics()
|
||
if err != nil {
|
||
http.Error(w, "获取统计失败", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
data := map[string]interface{}{
|
||
"stats": stats,
|
||
}
|
||
|
||
templates.ExecuteTemplate(w, "statistics.html", data)
|
||
}
|
||
|
||
// ReceiveSMS API - 接收短信
|
||
func ReceiveSMS(w http.ResponseWriter, r *http.Request) {
|
||
// 解析 multipart/form-data (优先)
|
||
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||
// 回退到 ParseForm
|
||
if err := r.ParseForm(); err != nil {
|
||
writeJSON(w, models.APIResponse{
|
||
Success: false,
|
||
Error: "解析请求失败: " + err.Error(),
|
||
}, http.StatusBadRequest)
|
||
return
|
||
}
|
||
}
|
||
|
||
// 获取参数
|
||
from := r.FormValue("from")
|
||
content := r.FormValue("content")
|
||
|
||
if from == "" || content == "" {
|
||
writeJSON(w, models.APIResponse{
|
||
Success: false,
|
||
Error: "缺少必填参数 (from: '" + from + "', content: '" + content + "')",
|
||
}, http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
// 获取可选参数
|
||
timestampStr := r.FormValue("timestamp")
|
||
timestamp := time.Now().UnixMilli()
|
||
if timestampStr != "" {
|
||
if t, err := strconv.ParseInt(timestampStr, 10, 64); err == nil {
|
||
timestamp = t
|
||
}
|
||
}
|
||
|
||
signStr := r.FormValue("sign")
|
||
device := r.FormValue("device")
|
||
sim := r.FormValue("sim")
|
||
|
||
// 获取 Token(从 query string 或 form)
|
||
token := r.URL.Query().Get("token")
|
||
if token == "" {
|
||
token = r.FormValue("token")
|
||
}
|
||
|
||
// 验证签名
|
||
cfg := config.Get()
|
||
signValid := sql.NullBool{Bool: true, Valid: true}
|
||
if token != "" && cfg.Security.SignVerify {
|
||
result, err := sign.VerifySign(token, timestamp, signStr, &cfg.Security)
|
||
if err != nil {
|
||
writeJSON(w, models.APIResponse{
|
||
Success: false,
|
||
Error: "签名验证错误: " + err.Error(),
|
||
}, http.StatusInternalServerError)
|
||
return
|
||
}
|
||
signValid.Bool = result.Valid
|
||
signValid.Valid = true
|
||
|
||
// 记录签名的 IP 地址
|
||
clientIP := getClientIP(r)
|
||
if result.Valid {
|
||
log.Printf("签名验证通过: token=%s, timestamp=%d, ip=%s, reason=%s",
|
||
token, timestamp, clientIP, result.Reason)
|
||
} else {
|
||
log.Printf("签名验证失败: token=%s, timestamp=%d, ip=%s, reason=%s",
|
||
token, timestamp, clientIP, result.Reason)
|
||
// 签名验证失败时仍然记录消息(标记为未验证)
|
||
}
|
||
}
|
||
|
||
// 保存消息
|
||
msg := &models.SMSMessage{
|
||
FromNumber: from,
|
||
Content: content,
|
||
Timestamp: timestamp,
|
||
DeviceInfo: sql.NullString{String: device, Valid: device != ""},
|
||
SIMInfo: sql.NullString{String: sim, Valid: sim != ""},
|
||
SignVerified: signValid,
|
||
IPAddress: getClientIP(r),
|
||
}
|
||
|
||
// 记录成功日志
|
||
receiveLog := &models.ReceiveLog{
|
||
FromNumber: from,
|
||
Content: content,
|
||
Timestamp: timestamp,
|
||
Sign: sql.NullString{String: signStr, Valid: signStr != ""},
|
||
SignValid: signValid,
|
||
IPAddress: getClientIP(r),
|
||
Status: "success",
|
||
}
|
||
|
||
// 使用事务同时插入消息和日志
|
||
messageID, err := database.InsertMessageWithLog(msg, receiveLog)
|
||
if err != nil {
|
||
// 记录失败日志(尝试单独插入)
|
||
receiveLog.Status = "error"
|
||
receiveLog.ErrorMessage = sql.NullString{String: err.Error(), Valid: true}
|
||
// 忽略日志插入错误,避免影响主错误返回
|
||
_, _ = database.InsertLog(receiveLog)
|
||
|
||
writeJSON(w, models.APIResponse{
|
||
Success: false,
|
||
Error: "保存消息失败",
|
||
}, http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
writeJSON(w, models.APIResponse{
|
||
Success: true,
|
||
Message: "短信已接收",
|
||
MessageID: messageID,
|
||
}, http.StatusOK)
|
||
}
|
||
|
||
// APIGetMessages API - 获取消息列表
|
||
func APIGetMessages(w http.ResponseWriter, r *http.Request) {
|
||
if !isAPIAuthenticated(r) {
|
||
writeJSON(w, models.APIResponse{Success: false, Error: "未授权"}, http.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
|
||
if page < 1 {
|
||
page = 1
|
||
}
|
||
limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
|
||
if limit <= 0 {
|
||
limit = config.DefaultPageSize
|
||
}
|
||
if limit > config.MaxPageSize {
|
||
limit = config.MaxPageSize
|
||
}
|
||
from := r.URL.Query().Get("from")
|
||
search := r.URL.Query().Get("search")
|
||
|
||
messages, total, err := database.GetMessages(page, limit, from, search)
|
||
if err != nil {
|
||
writeJSON(w, models.APIResponse{Success: false, Error: "获取消息失败"}, http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
// 格式化时间
|
||
cfg := config.Get()
|
||
loc, _ := time.LoadLocation(cfg.Timezone)
|
||
for i := range messages {
|
||
localTime := time.UnixMilli(messages[i].Timestamp).In(loc)
|
||
messages[i].LocalTimestampStr = localTime.Format("2006-01-02 15:04:05")
|
||
}
|
||
|
||
response := models.MessageListResponse{
|
||
Success: true,
|
||
Data: messages,
|
||
Total: total,
|
||
Page: page,
|
||
Limit: limit,
|
||
}
|
||
writeJSON(w, response, http.StatusOK)
|
||
}
|
||
|
||
// APIStatistics API - 获取统计信息
|
||
func APIStatistics(w http.ResponseWriter, r *http.Request) {
|
||
if !isAPIAuthenticated(r) {
|
||
writeJSON(w, models.APIResponse{Success: false, Error: "未授权"}, http.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
stats, err := database.GetStatistics()
|
||
if err != nil {
|
||
writeJSON(w, models.APIResponse{Success: false, Error: "获取统计失败"}, http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
response := models.StatisticsResponse{
|
||
Success: true,
|
||
Data: *stats,
|
||
}
|
||
writeJSON(w, response, http.StatusOK)
|
||
}
|
||
|
||
// StaticFile serves a file from the local "static" directory, mapping the
// request path straight onto it (e.g. "/css/app.css" -> "static/css/app.css").
// http.ServeFile itself rejects request paths containing ".." elements.
// NOTE(review): for a directory path ServeFile returns a directory listing
// (unless it holds index.html); confirm that is acceptable.
func StaticFile(w http.ResponseWriter, r *http.Request) {
	const root = "static"
	fsPath := root + r.URL.Path
	http.ServeFile(w, r, fsPath)
}
|
||
|
||
// 辅助函数
|
||
|
||
// writeJSON sends data as an "application/json" response with the given
// HTTP status code. The status is written before encoding, so an encoding
// failure (unsupported value, disconnected client) cannot change it; such
// failures are logged rather than silently discarded.
func writeJSON(w http.ResponseWriter, data interface{}, status int) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	if err := json.NewEncoder(w).Encode(data); err != nil {
		log.Printf("写入 JSON 响应失败: %v", err)
	}
}
|
||
|
||
// getClientIP returns the client address used for log lines and for the
// IPAddress column of stored messages and receive logs.
//
// If an X-Forwarded-For header is present and its first (left-most) entry is
// non-empty, that entry is returned with surrounding spaces trimmed.
// NOTE(review): this header is client-supplied and can be spoofed unless a
// trusted reverse proxy overwrites it — treat the value as informational.
//
// Otherwise r.RemoteAddr is used with its port stripped, so both sources
// yield a bare IP (e.g. "192.0.2.1" rather than "192.0.2.1:54321"). A
// RemoteAddr without a port is returned unchanged.
func getClientIP(r *http.Request) string {
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		first := forwarded
		if i := strings.IndexByte(forwarded, ','); i >= 0 {
			first = forwarded[:i]
		}
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
|
||
|
||
func isAPIAuthenticated(r *http.Request) bool {
|
||
cfg := config.Get()
|
||
if !cfg.Security.Enabled {
|
||
return true
|
||
}
|
||
loggedIn, _ := auth.IsLoggedIn(r)
|
||
return loggedIn
|
||
}
|
||
|
||
func getFromNumbers() ([]string, error) {
|
||
rows, err := database.GetDB().Query("SELECT DISTINCT from_number FROM sms_messages ORDER BY from_number")
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var numbers []string
|
||
for rows.Next() {
|
||
var number string
|
||
if err := rows.Scan(&number); err != nil {
|
||
return nil, err
|
||
}
|
||
numbers = append(numbers, number)
|
||
}
|
||
return numbers, nil
|
||
}
|