// Package mux provides stream multiplexing over a single net.Conn. // // Wire format per frame: // // StreamID (4B, big-endian) // Flags (1B) // Length (2B, big-endian, max 65535) // Data (Length bytes) // // Total header = 7 bytes. // // Flags: // // 0x01 SYN — open a new stream // 0x02 FIN — close a stream // 0x04 DATA — payload data // 0x08 PING — keepalive (StreamID=0) // 0x10 PONG — keepalive response (StreamID=0) // 0x20 RST — reset/abort a stream package mux import ( "encoding/binary" "errors" "fmt" "io" "log" "net" "sync" "sync/atomic" "time" ) const ( headerSize = 7 maxPayload = 65535 FlagSYN byte = 0x01 FlagFIN byte = 0x02 FlagDATA byte = 0x04 FlagPING byte = 0x08 FlagPONG byte = 0x10 FlagRST byte = 0x20 defaultWindowSize = 256 * 1024 // 256KB per stream receive buffer pingInterval = 15 * time.Second pingTimeout = 10 * time.Second acceptBacklog = 64 ) var ( ErrSessionClosed = errors.New("mux: session closed") ErrStreamClosed = errors.New("mux: stream closed") ErrStreamReset = errors.New("mux: stream reset by peer") ErrTimeout = errors.New("mux: timeout") ErrAcceptBacklog = errors.New("mux: accept backlog full") ) // ─── Session ─── // A Session multiplexes many Streams over a single underlying net.Conn. type Session struct { conn net.Conn streams map[uint32]*Stream mu sync.RWMutex nextID uint32 // client uses odd, server uses even isServer bool acceptCh chan *Stream writeMu sync.Mutex // serialize frame writes closed int32 quit chan struct{} once sync.Once // stats BytesSent int64 BytesReceived int64 } // NewSession wraps a net.Conn as a mux session. // isServer determines stream ID allocation: server=even, client=odd. 
func NewSession(conn net.Conn, isServer bool) *Session { s := &Session{ conn: conn, streams: make(map[uint32]*Stream), acceptCh: make(chan *Stream, acceptBacklog), quit: make(chan struct{}), isServer: isServer, } if isServer { s.nextID = 2 } else { s.nextID = 1 } go s.readLoop() go s.pingLoop() return s } // Open creates a new outbound stream. func (s *Session) Open() (*Stream, error) { if s.IsClosed() { return nil, ErrSessionClosed } id := atomic.AddUint32(&s.nextID, 2) - 2 // increment by 2 to keep odd/even st := newStream(id, s) s.mu.Lock() s.streams[id] = st s.mu.Unlock() // Send SYN if err := s.writeFrame(id, FlagSYN, nil); err != nil { s.mu.Lock() delete(s.streams, id) s.mu.Unlock() return nil, err } return st, nil } // Accept waits for an inbound stream opened by the remote side. func (s *Session) Accept() (*Stream, error) { select { case st := <-s.acceptCh: return st, nil case <-s.quit: return nil, ErrSessionClosed } } // Close shuts down the session and all streams. func (s *Session) Close() error { s.once.Do(func() { atomic.StoreInt32(&s.closed, 1) close(s.quit) s.mu.Lock() for _, st := range s.streams { st.closeLocal() } s.streams = make(map[uint32]*Stream) s.mu.Unlock() s.conn.Close() }) return nil } // IsClosed reports if the session is closed. func (s *Session) IsClosed() bool { return atomic.LoadInt32(&s.closed) == 1 } // NumStreams returns active stream count. 
func (s *Session) NumStreams() int { s.mu.RLock() defer s.mu.RUnlock() return len(s.streams) } // ─── Frame I/O ─── func (s *Session) writeFrame(streamID uint32, flags byte, data []byte) error { if len(data) > maxPayload { return fmt.Errorf("mux: payload too large: %d > %d", len(data), maxPayload) } hdr := make([]byte, headerSize) binary.BigEndian.PutUint32(hdr[0:4], streamID) hdr[4] = flags binary.BigEndian.PutUint16(hdr[5:7], uint16(len(data))) s.writeMu.Lock() defer s.writeMu.Unlock() s.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if _, err := s.conn.Write(hdr); err != nil { return err } if len(data) > 0 { if _, err := s.conn.Write(data); err != nil { return err } } atomic.AddInt64(&s.BytesSent, int64(headerSize+len(data))) return nil } func (s *Session) readLoop() { hdr := make([]byte, headerSize) for { if _, err := io.ReadFull(s.conn, hdr); err != nil { if !s.IsClosed() { log.Printf("[mux] read header error: %v", err) } s.Close() return } streamID := binary.BigEndian.Uint32(hdr[0:4]) flags := hdr[4] length := binary.BigEndian.Uint16(hdr[5:7]) var data []byte if length > 0 { data = make([]byte, length) if _, err := io.ReadFull(s.conn, data); err != nil { if !s.IsClosed() { log.Printf("[mux] read data error: %v", err) } s.Close() return } } atomic.AddInt64(&s.BytesReceived, int64(headerSize+int(length))) s.handleFrame(streamID, flags, data) } } func (s *Session) handleFrame(streamID uint32, flags byte, data []byte) { // Ping/Pong on StreamID 0 if flags&FlagPING != 0 { s.writeFrame(0, FlagPONG, nil) return } if flags&FlagPONG != 0 { return // pong received, connection alive } // SYN — new inbound stream if flags&FlagSYN != 0 { st := newStream(streamID, s) s.mu.Lock() s.streams[streamID] = st s.mu.Unlock() select { case s.acceptCh <- st: default: log.Printf("[mux] accept backlog full, dropping stream %d", streamID) s.writeFrame(streamID, FlagRST, nil) s.mu.Lock() delete(s.streams, streamID) s.mu.Unlock() } return } // Find the stream s.mu.RLock() st, ok 
:= s.streams[streamID] s.mu.RUnlock() if !ok { if flags&FlagRST == 0 { s.writeFrame(streamID, FlagRST, nil) } return } // RST if flags&FlagRST != 0 { st.resetByPeer() s.mu.Lock() delete(s.streams, streamID) s.mu.Unlock() return } // DATA if flags&FlagDATA != 0 && len(data) > 0 { st.pushData(data) } // FIN if flags&FlagFIN != 0 { st.finByPeer() } } func (s *Session) removeStream(id uint32) { s.mu.Lock() delete(s.streams, id) s.mu.Unlock() } func (s *Session) pingLoop() { ticker := time.NewTicker(pingInterval) defer ticker.Stop() for { select { case <-ticker.C: if err := s.writeFrame(0, FlagPING, nil); err != nil { return } case <-s.quit: return } } } // ─── Stream ─── // A Stream is a virtual connection within a Session, implementing net.Conn. type Stream struct { id uint32 sess *Session readBuf *ringBuffer readCh chan struct{} // signaled when data arrives closed int32 finRecv int32 // remote sent FIN finSent int32 // we sent FIN reset int32 mu sync.Mutex } func newStream(id uint32, sess *Session) *Stream { return &Stream{ id: id, sess: sess, readBuf: newRingBuffer(defaultWindowSize), readCh: make(chan struct{}, 1), } } // Read implements io.Reader. func (st *Stream) Read(p []byte) (int, error) { for { if atomic.LoadInt32(&st.reset) == 1 { return 0, ErrStreamReset } n := st.readBuf.Read(p) if n > 0 { return n, nil } // Buffer empty — check if FIN received if atomic.LoadInt32(&st.finRecv) == 1 { return 0, io.EOF } if atomic.LoadInt32(&st.closed) == 1 { return 0, ErrStreamClosed } // Wait for data select { case <-st.readCh: case <-st.sess.quit: return 0, ErrSessionClosed } } } // Write implements io.Writer. 
func (st *Stream) Write(p []byte) (int, error) { if atomic.LoadInt32(&st.closed) == 1 || atomic.LoadInt32(&st.reset) == 1 { return 0, ErrStreamClosed } total := 0 for len(p) > 0 { chunk := p if len(chunk) > maxPayload { chunk = p[:maxPayload] } if err := st.sess.writeFrame(st.id, FlagDATA, chunk); err != nil { return total, err } total += len(chunk) p = p[len(chunk):] } return total, nil } // Close sends FIN and closes the stream. func (st *Stream) Close() error { if !atomic.CompareAndSwapInt32(&st.closed, 0, 1) { return nil } if atomic.CompareAndSwapInt32(&st.finSent, 0, 1) { st.sess.writeFrame(st.id, FlagFIN, nil) } st.sess.removeStream(st.id) st.notify() return nil } // LocalAddr implements net.Conn. func (st *Stream) LocalAddr() net.Addr { return st.sess.conn.LocalAddr() } func (st *Stream) RemoteAddr() net.Addr { return st.sess.conn.RemoteAddr() } func (st *Stream) SetDeadline(t time.Time) error { return nil // TODO: implement per-stream deadlines } func (st *Stream) SetReadDeadline(t time.Time) error { return nil } func (st *Stream) SetWriteDeadline(t time.Time) error { return nil } func (st *Stream) pushData(data []byte) { st.readBuf.Write(data) st.notify() } func (st *Stream) finByPeer() { atomic.StoreInt32(&st.finRecv, 1) st.notify() } func (st *Stream) resetByPeer() { atomic.StoreInt32(&st.reset, 1) atomic.StoreInt32(&st.closed, 1) st.notify() } func (st *Stream) closeLocal() { atomic.StoreInt32(&st.closed, 1) st.notify() } func (st *Stream) notify() { select { case st.readCh <- struct{}{}: default: } } // ─── Ring Buffer ─── // Lock-free-ish ring buffer for stream receive data. 
type ringBuffer struct {
	buf  []byte
	r, w int        // read and write positions; the buffer is empty when r == w
	mu   sync.Mutex // guards every field above and below
	size int        // len(buf); one slot is always left unused, so capacity is size-1
}

// newRingBuffer returns a ring buffer backed by size bytes of storage, able
// to hold at most size-1 bytes at a time. A non-positive size yields a buffer
// that accepts nothing.
func newRingBuffer(size int) *ringBuffer {
	if size < 0 {
		size = 0
	}
	return &ringBuffer{
		buf:  make([]byte, size),
		size: size,
	}
}

// Write copies as much of p as fits into the buffer and returns the number of
// bytes stored, which is less than len(p) when the buffer fills up.
func (rb *ringBuffer) Write(p []byte) int {
	rb.mu.Lock()
	defer rb.mu.Unlock()

	if rb.size == 0 {
		return 0 // no storage; also avoids a modulo by zero below
	}

	// One slot stays empty so that a full buffer is distinguishable from an
	// empty one (both would otherwise have r == w).
	free := rb.size - 1 - rb.used()
	if len(p) > free {
		p = p[:free]
	}

	// At most two contiguous runs: up to the end of the backing array, then
	// the remainder wrapped around to the start.
	n := copy(rb.buf[rb.w:], p)
	n += copy(rb.buf, p[n:])
	rb.w = (rb.w + n) % rb.size
	return n
}

// Read moves up to len(p) buffered bytes into p, oldest first, and returns
// how many were copied. It returns 0 when the buffer is empty.
func (rb *ringBuffer) Read(p []byte) int {
	rb.mu.Lock()
	defer rb.mu.Unlock()

	avail := rb.used()
	if len(p) > avail {
		p = p[:avail]
	}
	if len(p) == 0 {
		return 0
	}

	// Mirror of Write: the run up to the end of the array, then the wrapped
	// remainder. p was capped to the buffered length, so neither copy can
	// read past the write position.
	n := copy(p, rb.buf[rb.r:])
	n += copy(p[n:], rb.buf)
	rb.r = (rb.r + n) % rb.size
	return n
}

// Len returns the number of bytes currently buffered.
func (rb *ringBuffer) Len() int {
	rb.mu.Lock()
	defer rb.mu.Unlock()
	return rb.used()
}

// used returns the number of buffered bytes. The caller must hold rb.mu.
func (rb *ringBuffer) used() int {
	if rb.w >= rb.r {
		return rb.w - rb.r
	}
	return rb.size - rb.r + rb.w
}