diff --git a/main.go b/main.go index f45e035..069dcbe 100644 --- a/main.go +++ b/main.go @@ -392,6 +392,45 @@ func (s *TaoServer) MessageHandler(w http.ResponseWriter, r *http.Request) { go s.dispatchMCP(token, client, req) } +// --- MCP Unified Gateway (/mcp) --- +func (s *TaoServer) MCPUnifiedHandler(w http.ResponseWriter, r *http.Request) { + if r.Method == "OPTIONS" { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + w.WriteHeader(http.StatusOK) + return + } + if r.Method == http.MethodGet { + s.SSEHandler(w, r) + return + } + if r.Method == http.MethodPost { + accept := r.Header.Get("Accept") + if strings.Contains(accept, "text/event-stream") { + s.SSEHandler(w, r) + return + } + if r.ContentLength == 0 { + s.SSEHandler(w, r) + return + } + bodyBytes, _ := io.ReadAll(r.Body) + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + if len(bodyBytes) == 0 { + s.SSEHandler(w, r) + return + } + if strings.Contains(string(bodyBytes), "\"jsonrpc\":\"2.0\"") { + s.MessageHandler(w, r) + return + } + http.Error(w, "Bad Request", 400) + return + } + http.Error(w, "Method Not Allowed", 405) +} + // --- 主程序 (Main) --- func main() { if getEnv("TAO_AUTH_TOKEN", "") == "" && !getEnvBool("TAO_ALLOW_ANON", false) { @@ -422,6 +461,8 @@ func main() { // MCP SSE + Message http.HandleFunc("/mcp/sse", server.requireAuth(server.SSEHandler)) + http.HandleFunc("/mcp", server.requireAuth(server.MCPUnifiedHandler)) + http.HandleFunc("/mcp/", server.requireAuth(server.MCPUnifiedHandler)) http.HandleFunc("/mcp/message", server.requireAuth(server.MessageHandler)) fmt.Printf("Tao Memory Server 启动。道场地址: :%s\n", server.config.Port)