优化端口探测和进度显示

This commit is contained in:
user123
2026-01-30 20:17:03 +08:00
parent 099e9b88e4
commit f9e9fd8e10
3 changed files with 119 additions and 99 deletions

Binary file not shown.

View File

@@ -23,6 +23,8 @@ import (
"time" "time"
) )
type ProgressCallback func(percent float64)
var ( var (
cachedNpmPrefix string cachedNpmPrefix string
cachedNodePath string cachedNodePath string
@@ -35,6 +37,7 @@ var (
const ( const (
downloadConcurrentThreshold int64 = 20 * 1024 * 1024 downloadConcurrentThreshold int64 = 20 * 1024 * 1024
downloadConcurrentParts = 4 downloadConcurrentParts = 4
DefaultGatewayPort = 18789
) )
const gitProxyEnv = "GIT_PROXY" const gitProxyEnv = "GIT_PROXY"
@@ -414,7 +417,7 @@ func ConfigureGitProxy() error {
} }
// downloadFile 下载文件 // downloadFile 下载文件
func downloadFile(url, dest, expectedSHA256 string) error { func downloadFile(url, dest, expectedSHA256 string, onProgress ProgressCallback) error {
if ok, err := verifyFileSHA256(dest, expectedSHA256); err == nil && ok { if ok, err := verifyFileSHA256(dest, expectedSHA256); err == nil && ok {
return nil return nil
} }
@@ -432,7 +435,7 @@ func downloadFile(url, dest, expectedSHA256 string) error {
return err return err
} }
if err := downloadWithResume(url, partPath, size, acceptRanges); err != nil { if err := downloadWithResume(url, partPath, size, acceptRanges, onProgress); err != nil {
return err return err
} }
@@ -448,19 +451,19 @@ func downloadFile(url, dest, expectedSHA256 string) error {
return os.Rename(partPath, dest) return os.Rename(partPath, dest)
} }
func downloadWithResume(url, dest string, size int64, acceptRanges bool) error { func downloadWithResume(url, dest string, size int64, acceptRanges bool, onProgress ProgressCallback) error {
if size > 0 && acceptRanges { if size > 0 && acceptRanges {
if info, err := os.Stat(dest); err == nil && info.Size() > 0 && info.Size() < size { if info, err := os.Stat(dest); err == nil && info.Size() > 0 && info.Size() < size {
return downloadRange(url, dest, info.Size(), size-1, size) return downloadRange(url, dest, info.Size(), size-1, size, onProgress)
} }
if size >= downloadConcurrentThreshold { if size >= downloadConcurrentThreshold {
return downloadConcurrent(url, dest, size, downloadConcurrentParts) return downloadConcurrent(url, dest, size, downloadConcurrentParts, onProgress)
} }
} }
return downloadRange(url, dest, 0, -1, size) return downloadRange(url, dest, 0, -1, size, onProgress)
} }
func downloadRange(url, dest string, start, end, total int64) error { func downloadRange(url, dest string, start, end, total int64, onProgress ProgressCallback) error {
out, err := os.OpenFile(dest, os.O_CREATE|os.O_WRONLY, 0644) out, err := os.OpenFile(dest, os.O_CREATE|os.O_WRONLY, 0644)
if err != nil { if err != nil {
return fmt.Errorf("创建文件失败: %v", err) return fmt.Errorf("创建文件失败: %v", err)
@@ -502,20 +505,27 @@ func downloadRange(url, dest string, start, end, total int64) error {
if total <= 0 && resp.ContentLength > 0 { if total <= 0 && resp.ContentLength > 0 {
total = start + resp.ContentLength total = start + resp.ContentLength
} }
progress := newProgressReporter(total, start)
progress.Start() reader := &countingReader{
reader := &countingReader{r: resp.Body, written: progress.written} r: resp.Body,
update: func(n int64) {
if total > 0 && onProgress != nil {
current := start + n
percent := float64(current) * 100 / float64(total)
onProgress(percent)
}
},
}
if _, err = io.Copy(out, reader); err != nil { if _, err = io.Copy(out, reader); err != nil {
progress.Stop()
return fmt.Errorf("写入文件失败: %v", err) return fmt.Errorf("写入文件失败: %v", err)
} }
progress.Stop()
return nil return nil
} }
func downloadConcurrent(url, dest string, size int64, parts int) error { func downloadConcurrent(url, dest string, size int64, parts int, onProgress ProgressCallback) error {
if parts < 2 { if parts < 2 {
return downloadRange(url, dest, 0, -1, size) return downloadRange(url, dest, 0, -1, size, onProgress)
} }
out, err := os.OpenFile(dest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) out, err := os.OpenFile(dest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
@@ -529,8 +539,9 @@ func downloadConcurrent(url, dest string, size int64, parts int) error {
var wg sync.WaitGroup var wg sync.WaitGroup
errCh := make(chan error, parts) errCh := make(chan error, parts)
progress := newProgressReporter(size, 0)
progress.Start() // Shared progress tracking
var written int64
partSize := size / int64(parts) partSize := size / int64(parts)
for i := 0; i < parts; i++ { for i := 0; i < parts; i++ {
@@ -560,7 +571,19 @@ func downloadConcurrent(url, dest string, size int64, parts int) error {
errCh <- fmt.Errorf("分段下载失败,状态码: %d", resp.StatusCode) errCh <- fmt.Errorf("分段下载失败,状态码: %d", resp.StatusCode)
return return
} }
writer := &writeAtWriter{file: out, offset: s, written: progress.written}
writer := &writeAtWriter{
file: out,
offset: s,
update: func(n int) {
newVal := atomic.AddInt64(&written, int64(n))
if onProgress != nil {
percent := float64(newVal) * 100 / float64(size)
onProgress(percent)
}
},
}
if _, err := io.Copy(writer, resp.Body); err != nil { if _, err := io.Copy(writer, resp.Body); err != nil {
errCh <- fmt.Errorf("写入文件失败: %v", err) errCh <- fmt.Errorf("写入文件失败: %v", err)
return return
@@ -571,7 +594,6 @@ func downloadConcurrent(url, dest string, size int64, parts int) error {
wg.Wait() wg.Wait()
close(errCh) close(errCh)
out.Close() out.Close()
progress.Stop()
for err := range errCh { for err := range errCh {
if err != nil { if err != nil {
@@ -584,89 +606,35 @@ func downloadConcurrent(url, dest string, size int64, parts int) error {
type writeAtWriter struct { type writeAtWriter struct {
file *os.File file *os.File
offset int64 offset int64
written *int64 update func(int)
} }
func (w *writeAtWriter) Write(p []byte) (int, error) { func (w *writeAtWriter) Write(p []byte) (int, error) {
n, err := w.file.WriteAt(p, w.offset) n, err := w.file.WriteAt(p, w.offset)
w.offset += int64(n) w.offset += int64(n)
if w.written != nil && n > 0 { if w.update != nil && n > 0 {
atomic.AddInt64(w.written, int64(n)) w.update(n)
} }
return n, err return n, err
} }
type countingReader struct { type countingReader struct {
r io.Reader r io.Reader
written *int64 totalRead int64
update func(int64)
} }
func (c *countingReader) Read(p []byte) (int, error) { func (c *countingReader) Read(p []byte) (int, error) {
n, err := c.r.Read(p) n, err := c.r.Read(p)
if n > 0 && c.written != nil { if n > 0 {
atomic.AddInt64(c.written, int64(n)) c.totalRead += int64(n)
if c.update != nil {
c.update(c.totalRead)
}
} }
return n, err return n, err
} }
type progressReporter struct {
total int64
written *int64
done chan struct{}
once sync.Once
}
func newProgressReporter(total, initial int64) *progressReporter {
current := initial
return &progressReporter{
total: total,
written: &current,
done: make(chan struct{}),
}
}
func (p *progressReporter) Start() {
if p == nil || p.total <= 0 {
return
}
p.print()
go func() {
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ticker.C:
p.print()
case <-p.done:
p.print()
fmt.Print("\n")
return
}
}
}()
}
func (p *progressReporter) Stop() {
if p == nil || p.total <= 0 {
return
}
p.once.Do(func() {
close(p.done)
})
}
func (p *progressReporter) print() {
current := atomic.LoadInt64(p.written)
if current < 0 {
current = 0
}
if current > p.total {
current = p.total
}
percent := float64(current) * 100 / float64(p.total)
fmt.Printf("\r下载进度: %.2f%%", percent)
}
func probeRemoteFile(url string) (int64, bool, error) { func probeRemoteFile(url string) (int64, bool, error) {
client := &http.Client{Timeout: 30 * time.Second} client := &http.Client{Timeout: 30 * time.Second}
req, err := http.NewRequest("HEAD", url, nil) req, err := http.NewRequest("HEAD", url, nil)
@@ -747,7 +715,7 @@ func fileSHA256(path string) (string, error) {
} }
// InstallNode 安装 Node.js // InstallNode 安装 Node.js
func InstallNode() error { func InstallNode(onProgress ProgressCallback) error {
if _, ok := CheckNode(); ok { if _, ok := CheckNode(); ok {
return nil return nil
} }
@@ -756,12 +724,10 @@ func InstallNode() error {
tempDir := os.TempDir() tempDir := os.TempDir()
msiPath := filepath.Join(tempDir, "node-v24.13.0-x64.msi") msiPath := filepath.Join(tempDir, "node-v24.13.0-x64.msi")
if err := downloadFile(msiUrl, msiPath, nodeMsiSHA256); err != nil { if err := downloadFile(msiUrl, msiPath, nodeMsiSHA256, onProgress); err != nil {
return err return err
} }
fmt.Println("正在安装 Node.js (可能需要管理员权限)...")
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
installCmd := exec.Command("msiexec", "/i", msiPath, "/qn") installCmd := exec.Command("msiexec", "/i", msiPath, "/qn")
output, err := installCmd.CombinedOutput() output, err := installCmd.CombinedOutput()
@@ -790,7 +756,7 @@ func InstallNode() error {
} }
// InstallGit 安装 Git // InstallGit 安装 Git
func InstallGit() error { func InstallGit(onProgress ProgressCallback) error {
if _, ok := CheckGit(); ok { if _, ok := CheckGit(); ok {
return nil return nil
} }
@@ -799,12 +765,10 @@ func InstallGit() error {
tempDir := os.TempDir() tempDir := os.TempDir()
exePath := filepath.Join(tempDir, "Git-2.52.0-64-bit.exe") exePath := filepath.Join(tempDir, "Git-2.52.0-64-bit.exe")
fmt.Println("正在下载 Git...") if err := downloadFile(gitUrl, exePath, gitExeSHA256, onProgress); err != nil {
if err := downloadFile(gitUrl, exePath, gitExeSHA256); err != nil {
return fmt.Errorf("git 下载失败: %v", err) return fmt.Errorf("git 下载失败: %v", err)
} }
fmt.Println("正在安装 Git (可能需要管理员权限)...")
installCmd := exec.Command(exePath, installCmd := exec.Command(exePath,
"/VERYSILENT", "/VERYSILENT",
"/NORESTART", "/NORESTART",
@@ -825,7 +789,7 @@ func InstallGit() error {
} }
// InstallOpenclawNpm 安装包 // InstallOpenclawNpm 安装包
func InstallOpenclawNpm(tag string) error { func InstallOpenclawNpm(tag string, onProgress ProgressCallback) error {
SetupNodeEnv() SetupNodeEnv()
pkgName := "openclaw" pkgName := "openclaw"
@@ -947,7 +911,7 @@ func GenerateAndWriteConfig(opts ConfigOptions) error {
Gateway: GatewayConfig{ Gateway: GatewayConfig{
Mode: "local", Mode: "local",
Bind: "loopback", Bind: "loopback",
Port: 18789, Port: DefaultGatewayPort,
Auth: &AuthConfig{ Auth: &AuthConfig{
Mode: "token", Mode: "token",
Token: token, Token: token,
@@ -1053,6 +1017,34 @@ func GetConfigPath() (string, error) {
return filepath.Join(userHome, ".openclaw", "openclaw.json"), nil return filepath.Join(userHome, ".openclaw", "openclaw.json"), nil
} }
func loadConfig() (*OpenclawConfig, error) {
configPath, err := GetConfigPath()
if err != nil {
return nil, err
}
data, err := os.ReadFile(configPath)
if err != nil {
return nil, err
}
var config OpenclawConfig
if err := json.Unmarshal(data, &config); err != nil {
return nil, err
}
return &config, nil
}
// GetGatewayPort 获取网关端口
func GetGatewayPort() int {
config, err := loadConfig()
if err != nil {
return DefaultGatewayPort
}
if config.Gateway.Port > 0 {
return config.Gateway.Port
}
return DefaultGatewayPort
}
// GetGatewayToken 获取 Gateway Token // GetGatewayToken 获取 Gateway Token
func GetGatewayToken() (string, error) { func GetGatewayToken() (string, error) {
configPath, err := GetConfigPath() configPath, err := GetConfigPath()
@@ -1091,7 +1083,8 @@ func StartGateway() error {
// IsGatewayRunning 检查端口 // IsGatewayRunning 检查端口
func IsGatewayRunning() bool { func IsGatewayRunning() bool {
conn, err := net.DialTimeout("tcp", "127.0.0.1:18789", 500*time.Millisecond) port := GetGatewayPort()
conn, err := net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", port), 500*time.Millisecond)
if err == nil { if err == nil {
conn.Close() conn.Close()
return true return true
@@ -1101,6 +1094,7 @@ func IsGatewayRunning() bool {
// KillGateway 停止网关 // KillGateway 停止网关
func KillGateway() error { func KillGateway() error {
port := GetGatewayPort()
cmd := exec.Command("netstat", "-ano") cmd := exec.Command("netstat", "-ano")
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
out, err := cmd.Output() out, err := cmd.Output()
@@ -1110,9 +1104,10 @@ func KillGateway() error {
scanner := bufio.NewScanner(bytes.NewReader(out)) scanner := bufio.NewScanner(bytes.NewReader(out))
var pid string var pid string
portStr := fmt.Sprintf(":%d", port)
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
if strings.Contains(line, ":18789") && strings.Contains(line, "LISTENING") { if strings.Contains(line, portStr) && strings.Contains(line, "LISTENING") {
fields := strings.Fields(line) fields := strings.Fields(line)
if len(fields) > 0 { if len(fields) > 0 {
pid = fields[len(fields)-1] pid = fields[len(fields)-1]

View File

@@ -65,6 +65,7 @@ type Model struct {
actionType ActionType actionType ActionType
spinner spinner.Model spinner spinner.Model
progressMsg string progressMsg string
progressPercent float64
actionErr error actionErr error
actionDone bool actionDone bool
@@ -107,6 +108,7 @@ type actionResultMsg struct {
type installProgressMsg struct { type installProgressMsg struct {
step string step string
percent float64
err error err error
done bool done bool
channel chan installProgressMsg channel chan installProgressMsg
@@ -236,11 +238,13 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.actionErr = msg.err m.actionErr = msg.err
m.actionDone = true m.actionDone = true
m.progressMsg = fmt.Sprintf("安装失败: %v", msg.err) m.progressMsg = fmt.Sprintf("安装失败: %v", msg.err)
m.progressPercent = 0
return m, nil return m, nil
} }
if msg.done { if msg.done {
m.actionDone = true m.actionDone = true
m.progressMsg = "安装流程完成!" m.progressMsg = "安装流程完成!"
m.progressPercent = 0
m.envRefreshActive = true m.envRefreshActive = true
m.envRefreshAttempt = 0 m.envRefreshAttempt = 0
m.envRefreshMax = 5 m.envRefreshMax = 5
@@ -248,6 +252,7 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, envRefreshCmd(0) return m, envRefreshCmd(0)
} }
m.progressMsg = msg.step m.progressMsg = msg.step
m.progressPercent = msg.percent
return m, waitForInstallProgress(msg.channel) return m, waitForInstallProgress(msg.channel)
case envRefreshMsg: case envRefreshMsg:
@@ -324,6 +329,7 @@ func (m Model) handleMenuSelect() (tea.Model, tea.Cmd) {
m.actionDone = false m.actionDone = false
m.actionErr = nil m.actionErr = nil
m.progressMsg = "准备安装..." m.progressMsg = "准备安装..."
m.progressPercent = 0
return m, runInstallFlowCmd() return m, runInstallFlowCmd()
case 3: // 卸载 case 3: // 卸载
m.state = StateAction m.state = StateAction
@@ -331,6 +337,7 @@ func (m Model) handleMenuSelect() (tea.Model, tea.Cmd) {
m.actionDone = false m.actionDone = false
m.actionErr = nil m.actionErr = nil
m.progressMsg = "正在卸载..." m.progressMsg = "正在卸载..."
m.progressPercent = 0
return m, runUninstallCmd return m, runUninstallCmd
case 4: // 退出 case 4: // 退出
return m, tea.Quit return m, tea.Quit
@@ -517,7 +524,8 @@ func (m Model) renderDashboard() string {
if m.gatewayOk && m.gatewayToken != "" { if m.gatewayOk && m.gatewayToken != "" {
statusRows = append(statusRows, "") statusRows = append(statusRows, "")
url := fmt.Sprintf("http://127.0.0.1:18789/?token=%s", m.gatewayToken) port := sys.GetGatewayPort()
url := fmt.Sprintf("http://127.0.0.1:%d/?token=%s", port, m.gatewayToken)
statusRows = append(statusRows, fmt.Sprintf("访问地址: %s", style.SuccessStyle.Render(url))) statusRows = append(statusRows, fmt.Sprintf("访问地址: %s", style.SuccessStyle.Render(url)))
} }
@@ -671,6 +679,17 @@ func (m Model) renderAction() string {
style.SubHeaderStyle.Render(title), style.SubHeaderStyle.Render(title),
"", "",
fmt.Sprintf("%s %s", icon, m.progressMsg), fmt.Sprintf("%s %s", icon, m.progressMsg),
)
if m.progressPercent > 0 {
content = lipgloss.JoinVertical(lipgloss.Center,
content,
fmt.Sprintf("%.0f%%", m.progressPercent),
)
}
content = lipgloss.JoinVertical(lipgloss.Center,
content,
"", "",
) )
@@ -774,14 +793,18 @@ func runInstallFlowCmd() tea.Cmd {
// 1. Install Node // 1. Install Node
ch <- installProgressMsg{step: "正在安装 Node.js...", channel: ch} ch <- installProgressMsg{step: "正在安装 Node.js...", channel: ch}
if err := sys.InstallNode(); err != nil { if err := sys.InstallNode(func(percent float64) {
ch <- installProgressMsg{step: "正在安装 Node.js...", percent: percent, channel: ch}
}); err != nil {
ch <- installProgressMsg{err: fmt.Errorf("node.js 安装失败: %v", err), channel: ch} ch <- installProgressMsg{err: fmt.Errorf("node.js 安装失败: %v", err), channel: ch}
return return
} }
// 2. Install Git // 2. Install Git
ch <- installProgressMsg{step: "正在安装 Git...", channel: ch} ch <- installProgressMsg{step: "正在安装 Git...", channel: ch}
if err := sys.InstallGit(); err != nil { if err := sys.InstallGit(func(percent float64) {
ch <- installProgressMsg{step: "正在安装 Git...", percent: percent, channel: ch}
}); err != nil {
ch <- installProgressMsg{err: fmt.Errorf("git 安装失败: %v", err), channel: ch} ch <- installProgressMsg{err: fmt.Errorf("git 安装失败: %v", err), channel: ch}
return return
} }
@@ -802,7 +825,9 @@ func runInstallFlowCmd() tea.Cmd {
// 5. Install OpenClaw // 5. Install OpenClaw
ch <- installProgressMsg{step: "正在安装 OpenClaw...", channel: ch} ch <- installProgressMsg{step: "正在安装 OpenClaw...", channel: ch}
if err := sys.InstallOpenclawNpm("latest"); err != nil { if err := sys.InstallOpenclawNpm("latest", func(percent float64) {
ch <- installProgressMsg{step: "正在安装 OpenClaw...", percent: percent, channel: ch}
}); err != nil {
ch <- installProgressMsg{err: fmt.Errorf("openclaw 安装失败: %v", err), channel: ch} ch <- installProgressMsg{err: fmt.Errorf("openclaw 安装失败: %v", err), channel: ch}
return return
} }