Files
tao-memory-mcp/main.go
2026-03-14 16:16:32 +08:00

296 lines
7.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
	"bytes"
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
	"os"
	"strings"
	"time"
)
// --- MCP Protocol Types ---

// MCPRequest is one incoming JSON-RPC 2.0 message as POSTed to MessageHandler.
// ID is nil when the client omits "id" (JSON-RPC notifications). Params is
// kept as raw JSON so each method handler can decode its own parameter shape.
//
// NOTE: field order is observable - dispatchMCP records unknown requests with
// fmt's %+v verb, which prints fields in declaration order.
type MCPRequest struct {
	JSONRPC string          `json:"jsonrpc"`
	ID      any             `json:"id,omitempty"`
	Method  string          `json:"method"`
	Params  json.RawMessage `json:"params,omitempty"`
}
// MCPResponse is one outgoing JSON-RPC 2.0 message, serialized and pushed to
// the client over its SSE stream. Result carries the success payload; Error is
// meant for a JSON-RPC error object. Fields holding a nil value are left out of
// the encoded JSON.
//
// Keep the declaration order: encoding/json emits keys in that order.
type MCPResponse struct {
	JSONRPC string `json:"jsonrpc"`
	ID      any    `json:"id,omitempty"`
	Result  any    `json:"result,omitempty"`
	Error   any    `json:"error,omitempty"`
}
// MCPContent is one entry of the "content" array in an MCP tool result.
// This file only ever produces entries with Type "text".
type MCPContent struct {
	Type string `json:"type"`
	Text string `json:"text"`
}
// sendMCPResponse builds (it does not transmit) a successful JSON-RPC 2.0
// response carrying result for the request identified by id.
func sendMCPResponse(id, result any) MCPResponse {
	var out MCPResponse
	out.JSONRPC = "2.0"
	out.ID = id
	out.Result = result
	return out
}
// dispatchMCP processes one decoded MCP JSON-RPC message. When the message is
// a request, it delivers the JSON-encoded response over the SSE stream that
// SSEHandler registered in s.conns under token.
//
// MessageHandler starts it on its own goroutine, so it must never block
// indefinitely. If the SSE channel stays full for sendTimeout, the response is
// dropped and logged.
//
// Behavior by method:
//   - "initialize", "ping", "tools/list", "resources/list", "prompts/list" and
//     "tools/call" are answered with a result.
//   - "notifications/initialized" is logged only; notifications get no reply.
//   - Any other method is recorded as an "agent_action" memory via s.Record.
//     If the message carries an id (a request rather than a notification),
//     the client also receives a JSON-RPC -32601 "method not found" error, so
//     it does not wait forever for a reply.
func (s *TaoServer) dispatchMCP(token string, req MCPRequest) {
	const sendTimeout = 10 * time.Second

	resp := MCPResponse{JSONRPC: "2.0", ID: req.ID}

	switch req.Method {
	case "initialize":
		resp.Result = map[string]any{
			"protocolVersion": "2025-06-18",
			"capabilities": map[string]any{
				"tools":     map[string]any{"listChanged": false},
				"resources": map[string]any{"listChanged": false},
				"prompts":   map[string]any{"listChanged": false},
				"logging":   map[string]any{},
			},
			"serverInfo": map[string]string{
				"name":    "Tao-Memory-Server",
				"version": "1.2.0",
			},
		}
	case "notifications/initialized":
		// Do not log token: it doubles as the auth credential (see checkAuth).
		log.Printf("[MCP Notify] initialized")
		return
	case "ping":
		// MCP liveness check; the reply is an empty result object.
		resp.Result = map[string]any{}
	case "tools/list":
		resp.Result = map[string]any{"tools": buildToolList()}
	case "resources/list":
		// Advertised in the "initialize" capabilities above; none are exposed.
		resp.Result = map[string]any{"resources": []any{}}
	case "prompts/list":
		// Advertised in the "initialize" capabilities above; none are exposed.
		resp.Result = map[string]any{"prompts": []any{}}
	case "tools/call":
		resp.Result, resp.Error = mcpToolCallResult(req.Params)
	default:
		if err := s.Record("agent_action", fmt.Sprintf("执行指令: %+v", req), 2); err != nil {
			log.Printf("[MCP] record agent_action method=%s: %v", req.Method, err)
		}
		if req.ID == nil || strings.HasPrefix(req.Method, "notifications/") {
			return // notifications never get a response
		}
		resp.Error = map[string]any{
			"code":    -32601,
			"message": "method not found: " + req.Method,
		}
	}

	if token == "" {
		log.Printf("[MCP Response] missing token for method=%s", req.Method)
		return
	}
	v, ok := s.conns.Load(token)
	if !ok {
		log.Printf("[MCP Response] no SSE channel for method=%s", req.Method)
		return
	}
	ch, ok := v.(chan string)
	if !ok {
		log.Printf("[MCP Response] unexpected SSE channel type %T for method=%s", v, req.Method)
		return
	}
	b, err := json.Marshal(resp)
	if err != nil {
		log.Printf("[MCP Response] marshal method=%s: %v", req.Method, err)
		return
	}
	// Bounded send: the SSE reader may be backed up or may have just
	// disconnected, and a plain send on a full channel would park this
	// goroutine forever.
	select {
	case ch <- string(b):
		log.Printf("[MCP Response] sent via SSE method=%s", req.Method)
	case <-time.After(sendTimeout):
		log.Printf("[MCP Response] SSE channel full, dropped method=%s", req.Method)
	}
}

// mcpToolCallResult decodes "tools/call" params, runs the named tool from
// ToolRegistry and wraps its text output as an MCP tool result.
//
// Tool-level failures (unknown tool, handler error, handler panic) are
// returned inside the result with isError=true, so the calling model can see
// them. Only params that fail to decode yield a JSON-RPC error object (-32602),
// returned as rpcErr.
func mcpToolCallResult(raw json.RawMessage) (result any, rpcErr any) {
	var params struct {
		Name      string                 `json:"name"`
		Arguments map[string]interface{} `json:"arguments"`
	}
	if err := json.Unmarshal(raw, &params); err != nil {
		return nil, map[string]any{
			"code":    -32602,
			"message": "invalid params: " + err.Error(),
		}
	}
	tool, ok := ToolRegistry[params.Name]
	if !ok {
		return mcpTextResult("error: tool not found", true), nil
	}
	if params.Arguments == nil {
		// "arguments" may be omitted; hand handlers an empty map, not nil.
		params.Arguments = map[string]interface{}{}
	}
	text, err := func() (out string, err error) {
		// dispatchMCP runs on a bare goroutine, so an unrecovered panic in a
		// tool handler would terminate the whole server process.
		defer func() {
			if p := recover(); p != nil {
				err = fmt.Errorf("tool %q panicked: %v", params.Name, p)
			}
		}()
		return tool.Handler(params.Arguments)
	}()
	if err != nil {
		return mcpTextResult("error: "+err.Error(), true), nil
	}
	return mcpTextResult(text, false), nil
}

// mcpTextResult builds an MCP tool result holding one text content item. The
// "isError" key is included only when isError is true.
func mcpTextResult(text string, isError bool) map[string]any {
	res := map[string]any{
		"content": []MCPContent{{Type: "text", Text: text}},
	}
	if isError {
		res["isError"] = true
	}
	return res
}
// getEnv looks up the environment variable key and returns its value. It
// returns def instead when the variable is unset or set to the empty string.
func getEnv(key, def string) string {
	val := os.Getenv(key)
	if val == "" {
		return def
	}
	return val
}
// --- Authentication (keep it simple) ---

// checkAuth reports whether r is authorized to use the server.
//
// The expected secret is read from TAO_AUTH_TOKEN on every call. When it is
// unset or empty, authentication is disabled and every request passes.
// Otherwise the request must carry the secret in one of two places:
//   - an "Authorization: Bearer <token>" header, where the scheme is matched
//     case-insensitively as RFC 7235 requires;
//   - a "token" query parameter, for clients that cannot set headers.
//
// Secrets are compared with crypto/subtle, so response timing does not reveal
// how much of a guessed token was correct.
func (s *TaoServer) checkAuth(r *http.Request) bool {
	want := getEnv("TAO_AUTH_TOKEN", "")
	if want == "" {
		return true // authentication disabled when no token is configured
	}
	matches := func(got string) bool {
		return subtle.ConstantTimeCompare([]byte(got), []byte(want)) == 1
	}
	if scheme, cred, ok := strings.Cut(r.Header.Get("Authorization"), " "); ok &&
		strings.EqualFold(scheme, "Bearer") && matches(cred) {
		return true
	}
	// An absent (empty) query token never matches, because want is non-empty.
	return matches(r.URL.Query().Get("token"))
}
// requireAuth wraps next so that it only runs for requests accepted by
// checkAuth; all other requests get 401 Unauthorized.
//
// CORS pre-flight (OPTIONS) requests are answered directly with permissive
// CORS headers and 200 OK, without an auth check, because pre-flights are
// sent without credentials.
func (s *TaoServer) requireAuth(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		switch {
		case r.Method == http.MethodOptions:
			hdr := w.Header()
			hdr.Set("Access-Control-Allow-Origin", "*")
			hdr.Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
			hdr.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
			w.WriteHeader(http.StatusOK)
		case !s.checkAuth(r):
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
		default:
			next(w, r)
		}
	}
}
// --- Senses (Webhook Adapters) ---

// GiteaHandler ingests a Gitea push webhook.
//
// When the payload lists at least one commit, the message of the first listed
// commit is stored via s.Record under the "code" kind with level 4, together
// with the repository name. Responses:
//   - 400 Bad Request: undecodable or oversized body;
//   - 500 Internal Server Error: the record could not be stored, so the sender
//     sees the failure instead of a false success;
//   - 200 OK: otherwise.
func (s *TaoServer) GiteaHandler(w http.ResponseWriter, r *http.Request) {
	// Cap the body: the endpoint faces external senders, and push payloads
	// grow with the number of commits.
	const maxBody = 8 << 20

	var payload struct {
		Repository struct {
			Name string `json:"name"`
		} `json:"repository"`
		Commits []struct {
			Message string `json:"message"`
		} `json:"commits"`
	}
	if err := json.NewDecoder(http.MaxBytesReader(w, r.Body, maxBody)).Decode(&payload); err != nil {
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}
	if len(payload.Commits) > 0 {
		summary := fmt.Sprintf("代码演化于 [%s]: %s", payload.Repository.Name, payload.Commits[0].Message)
		if err := s.Record("code", summary, 4); err != nil {
			log.Printf("[Ingest Gitea] record repo=%s: %v", payload.Repository.Name, err)
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}
	}
	w.WriteHeader(http.StatusOK)
}
// SmsHandler ingests an SMS push from SmsReceiver-go.
//
// The sender and content are stored via s.Record under the "sms" kind with
// level 3. Responses:
//   - 400 Bad Request: undecodable or oversized body;
//   - 500 Internal Server Error: the record could not be stored;
//   - 200 OK: otherwise.
func (s *TaoServer) SmsHandler(w http.ResponseWriter, r *http.Request) {
	// Cap the body: an SMS payload is tiny, and the endpoint faces external
	// senders.
	const maxBody = 1 << 20

	var payload struct {
		From    string `json:"from"`
		Content string `json:"content"`
	}
	if err := json.NewDecoder(http.MaxBytesReader(w, r.Body, maxBody)).Decode(&payload); err != nil {
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}
	summary := fmt.Sprintf("收到信号 [%s]: %s", payload.From, payload.Content)
	if err := s.Record("sms", summary, 3); err != nil {
		log.Printf("[Ingest SMS] record: %v", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}
	w.WriteHeader(http.StatusOK)
}
// --- MCP SSE ---

// SSEHandler serves the long-lived Server-Sent Events stream of the MCP SSE
// transport. It works in two phases:
//  1. It emits one "endpoint" event telling the client where to POST its
//     JSON-RPC messages.
//  2. Until the client disconnects, it relays responses queued by dispatchMCP
//     as "message" events and writes a ":ping" comment every 5 seconds.
//
// Responses are routed by the "token" query parameter. Only streams opened
// with ?token=... register a message channel in s.conns; the token is also
// appended to the advertised endpoint so that header-less clients can POST.
//
// TAO_ENDPOINT_STYLE controls the advertised path: "message" (the default)
// advertises "message", and any other value advertises "mcp/message".
func (s *TaoServer) SSEHandler(w http.ResponseWriter, r *http.Request) {
	// Log the path only: the query string may carry the auth token.
	log.Printf("[SSE Connect] Remote=%s Path=%s", r.RemoteAddr, r.URL.Path)
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("Access-Control-Allow-Origin", "*")
	w.Header().Set("X-Accel-Buffering", "no")
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
		return
	}

	// Advertise the POST endpoint, spelled the way the client expects to join it.
	endpoint := "mcp/message"
	if getEnv("TAO_ENDPOINT_STYLE", "message") == "message" {
		endpoint = "message"
	}
	token := r.URL.Query().Get("token")
	if token != "" {
		sep := "?"
		if strings.Contains(endpoint, "?") {
			sep = "&"
		}
		// Escape the token so values containing '+', '/', '=' or '&' (e.g.
		// base64 tokens) survive the round trip through r.URL.Query() in
		// MessageHandler and checkAuth.
		endpoint += sep + "token=" + url.QueryEscape(token)
	}
	fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", endpoint)
	flusher.Flush()

	// msgChan stays nil without a token; receiving from a nil channel is never
	// ready, so the select below then only serves pings and disconnects.
	var msgChan chan string
	if token != "" {
		msgChan = make(chan string, 50)
		s.conns.Store(token, msgChan)
		defer func() {
			// A newer stream using the same token may have replaced our entry;
			// only remove the channel if it is still ours.
			if cur, ok := s.conns.Load(token); ok && cur == msgChan {
				s.conns.Delete(token)
			}
		}()
	}

	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-r.Context().Done():
			return
		case <-ticker.C:
			fmt.Fprintf(w, ":ping\n\n")
			flusher.Flush()
		case msg := <-msgChan:
			fmt.Fprintf(w, "event: message\ndata: %s\n\n", msg)
			flusher.Flush()
		}
	}
}
// --- MCP Message ---

// MessageHandler accepts one JSON-RPC message POSTed by an MCP client.
//
// It replies 202 Accepted immediately and hands the decoded message to
// dispatchMCP on a new goroutine. Any reply is delivered asynchronously over
// the SSE stream registered under the request's "token" query parameter.
//
// It also sets permissive CORS headers and answers OPTIONS pre-flights itself.
// Bodies larger than maxBody, or that do not decode as an MCPRequest, are
// rejected with 400 Bad Request.
func (s *TaoServer) MessageHandler(w http.ResponseWriter, r *http.Request) {
	const maxBody = 4 << 20

	hdr := w.Header()
	hdr.Set("Access-Control-Allow-Origin", "*")
	hdr.Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
	hdr.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
	if r.Method == http.MethodOptions {
		w.WriteHeader(http.StatusOK)
		return
	}

	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxBody))
	if err != nil {
		log.Printf("[MCP POST] read body from %s: %v", r.RemoteAddr, err)
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}
	// Log the path only: the query string may carry the auth token.
	log.Printf("[MCP POST] From=%s Path=%s Body=%s", r.RemoteAddr, r.URL.Path, body)

	var req MCPRequest
	if err := json.NewDecoder(bytes.NewReader(body)).Decode(&req); err != nil {
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}
	token := r.URL.Query().Get("token")
	w.WriteHeader(http.StatusAccepted)
	go s.dispatchMCP(token, req)
}
// --- Main program ---

// main builds the TaoServer from the MEMORY_ROOT and PORT environment
// variables and registers its tools. It then mounts the webhook and MCP routes
// behind requireAuth and serves HTTP until the listener fails.
func main() {
	server := &TaoServer{
		config: Config{
			MemoryRoot: getEnv("MEMORY_ROOT", "./knowledge_ocean"),
			Port:       getEnv("PORT", "5001"),
		},
	}
	server.RegisterTools()

	// Webhook listeners (the perception layer).
	http.HandleFunc("/ingest/gitea", server.requireAuth(server.GiteaHandler))
	http.HandleFunc("/ingest/sms", server.requireAuth(server.SmsHandler))
	// MCP transport: SSE stream + message POST endpoint.
	http.HandleFunc("/mcp/sse", server.requireAuth(server.SSEHandler))
	http.HandleFunc("/mcp/message", server.requireAuth(server.MessageHandler))

	fmt.Printf("Tao Memory Server 启动。道场地址: :%s\n", server.config.Port)
	httpSrv := &http.Server{
		Addr: ":" + server.config.Port, // nil Handler => http.DefaultServeMux
		// Bound header reads and idle keep-alives so slow or stalled clients
		// cannot pin connections. ReadTimeout/WriteTimeout are deliberately
		// left unset because they would sever long-lived SSE streams.
		ReadHeaderTimeout: 10 * time.Second,
		IdleTimeout:       120 * time.Second,
	}
	log.Fatal(httpSrv.ListenAndServe())
}