//go:build !js // +build !js package websocket import ( "bytes" "crypto/sha1" "encoding/base64" "errors" "fmt" "io" "log" "net/http" "net/textproto" "net/url" "path/filepath" "strings" "nhooyr.io/websocket/internal/errd" ) // AcceptOptions represents Accept's options. type AcceptOptions struct { // Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client. // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to // reject it, close the connection when c.Subprotocol() == "". Subprotocols []string // InsecureSkipVerify is used to disable Accept's origin verification behaviour. // // You probably want to use OriginPatterns instead. InsecureSkipVerify bool // OriginPatterns lists the host patterns for authorized origins. // The request host is always authorized. // Use this to enable cross origin WebSockets. // // i.e javascript running on example.com wants to access a WebSocket server at chat.example.com. // In such a case, example.com is the origin and chat.example.com is the request host. // One would set this field to []string{"example.com"} to authorize example.com to connect. // // Each pattern is matched case insensitively against the request origin host // with filepath.Match. // See https://golang.org/pkg/path/filepath/#Match // // Please ensure you understand the ramifications of enabling this. // If used incorrectly your WebSocket server will be open to CSRF attacks. // // Do not use * as a pattern to allow any origin, prefer to use InsecureSkipVerify instead // to bring attention to the danger of such a setting. OriginPatterns []string // CompressionMode controls the compression mode. // Defaults to CompressionDisabled. // // See docs on CompressionMode for details. CompressionMode CompressionMode // CompressionThreshold controls the minimum size of a message before compression is applied. 
// // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes // for CompressionContextTakeover. CompressionThreshold int } func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions { var o AcceptOptions if opts != nil { o = *opts } return &o } // Accept accepts a WebSocket handshake from a client and upgrades the // the connection to a WebSocket. // // Accept will not allow cross origin requests by default. // See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests. // // Accept will write a response to w on all errors. func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { return accept(w, r, opts) } func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) { defer errd.Wrap(&err, "failed to accept WebSocket connection") errCode, err := verifyClientRequest(w, r) if err != nil { http.Error(w, err.Error(), errCode) return nil, err } opts = opts.cloneWithDefaults() if !opts.InsecureSkipVerify { err = authenticateOrigin(r, opts.OriginPatterns) if err != nil { if errors.Is(err, filepath.ErrBadPattern) { log.Printf("websocket: %v", err) err = errors.New(http.StatusText(http.StatusForbidden)) } http.Error(w, err.Error(), http.StatusForbidden) return nil, err } } hj, ok := w.(http.Hijacker) if !ok { err = errors.New("http.ResponseWriter does not implement http.Hijacker") http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) return nil, err } w.Header().Set("Upgrade", "websocket") w.Header().Set("Connection", "Upgrade") key := r.Header.Get("Sec-WebSocket-Key") w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) subproto := selectSubprotocol(r, opts.Subprotocols) if subproto != "" { w.Header().Set("Sec-WebSocket-Protocol", subproto) } copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode) if ok { w.Header().Set("Sec-WebSocket-Extensions", copts.String()) } 
w.WriteHeader(http.StatusSwitchingProtocols) // See https://github.com/nhooyr/websocket/issues/166 if ginWriter, ok := w.(interface { WriteHeaderNow() }); ok { ginWriter.WriteHeaderNow() } netConn, brw, err := hj.Hijack() if err != nil { err = fmt.Errorf("failed to hijack connection: %w", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return nil, err } // https://github.com/golang/go/issues/32314 b, _ := brw.Reader.Peek(brw.Reader.Buffered()) brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) return newConn(connConfig{ subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), rwc: netConn, client: false, copts: copts, flateThreshold: opts.CompressionThreshold, br: brw.Reader, bw: brw.Writer, }), nil } func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) { if !r.ProtoAtLeast(1, 1) { return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) } if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) } if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) } if r.Method != "GET" { return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method) } if r.Header.Get("Sec-WebSocket-Version") != "13" { w.Header().Set("Sec-WebSocket-Version", "13") return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", 
r.Header.Get("Sec-WebSocket-Version")) } websocketSecKeys := r.Header.Values("Sec-WebSocket-Key") if len(websocketSecKeys) == 0 { return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key") } if len(websocketSecKeys) > 1 { return http.StatusBadRequest, errors.New("WebSocket protocol violation: multiple Sec-WebSocket-Key headers") } // The RFC states to remove any leading or trailing whitespace. websocketSecKey := strings.TrimSpace(websocketSecKeys[0]) if v, err := base64.StdEncoding.DecodeString(websocketSecKey); err != nil || len(v) != 16 { return http.StatusBadRequest, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Key %q, must be a 16 byte base64 encoded string", websocketSecKey) } return 0, nil } func authenticateOrigin(r *http.Request, originHosts []string) error { origin := r.Header.Get("Origin") if origin == "" { return nil } u, err := url.Parse(origin) if err != nil { return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) } if strings.EqualFold(r.Host, u.Host) { return nil } for _, hostPattern := range originHosts { matched, err := match(hostPattern, u.Host) if err != nil { return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err) } if matched { return nil } } if u.Host == "" { return fmt.Errorf("request Origin %q is not a valid URL with a host", origin) } return fmt.Errorf("request Origin %q is not authorized for Host %q", u.Host, r.Host) } func match(pattern, s string) (bool, error) { return filepath.Match(strings.ToLower(pattern), strings.ToLower(s)) } func selectSubprotocol(r *http.Request, subprotocols []string) string { cps := headerTokens(r.Header, "Sec-WebSocket-Protocol") for _, sp := range subprotocols { for _, cp := range cps { if strings.EqualFold(sp, cp) { return cp } } } return "" } func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) { if mode == CompressionDisabled { return nil, false } for _, ext := 
range extensions { switch ext.name { // We used to implement x-webkit-deflate-frame too for Safari but Safari has bugs... // See https://github.com/nhooyr/websocket/issues/218 case "permessage-deflate": copts, ok := acceptDeflate(ext, mode) if ok { return copts, true } } } return nil, false } func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) { copts := mode.opts() for _, p := range ext.params { switch p { case "client_no_context_takeover": copts.clientNoContextTakeover = true continue case "server_no_context_takeover": copts.serverNoContextTakeover = true continue case "client_max_window_bits", "server_max_window_bits=15": continue } if strings.HasPrefix(p, "client_max_window_bits=") { // We can't adjust the deflate window, but decoding with a larger window is acceptable. continue } return nil, false } return copts, true } func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool { for _, t := range headerTokens(h, key) { if strings.EqualFold(t, token) { return true } } return false } type websocketExtension struct { name string params []string } func websocketExtensions(h http.Header) []websocketExtension { var exts []websocketExtension extStrs := headerTokens(h, "Sec-WebSocket-Extensions") for _, extStr := range extStrs { if extStr == "" { continue } vals := strings.Split(extStr, ";") for i := range vals { vals[i] = strings.TrimSpace(vals[i]) } e := websocketExtension{ name: vals[0], params: vals[1:], } exts = append(exts, e) } return exts } func headerTokens(h http.Header, key string) []string { key = textproto.CanonicalMIMEHeaderKey(key) var tokens []string for _, v := range h[key] { v = strings.TrimSpace(v) for _, t := range strings.Split(v, ",") { t = strings.TrimSpace(t) tokens = append(tokens, t) } } return tokens } var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") func secWebSocketAccept(secWebSocketKey string) string { h := sha1.New() h.Write([]byte(secWebSocketKey)) h.Write(keyGUID) 
return base64.StdEncoding.EncodeToString(h.Sum(nil)) }
//go:build !js // +build !js package websocket import ( "context" "encoding/binary" "errors" "fmt" "net" "time" "nhooyr.io/websocket/internal/errd" ) // StatusCode represents a WebSocket status code. // https://tools.ietf.org/html/rfc6455#section-7.4 type StatusCode int // https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number // // These are only the status codes defined by the protocol. // // You can define custom codes in the 3000-4999 range. // The 3000-3999 range is reserved for use by libraries, frameworks and applications. // The 4000-4999 range is reserved for private use. const ( StatusNormalClosure StatusCode = 1000 StatusGoingAway StatusCode = 1001 StatusProtocolError StatusCode = 1002 StatusUnsupportedData StatusCode = 1003 // 1004 is reserved and so unexported. statusReserved StatusCode = 1004 // StatusNoStatusRcvd cannot be sent in a close message. // It is reserved for when a close message is received without // a status code. StatusNoStatusRcvd StatusCode = 1005 // StatusAbnormalClosure is exported for use only with Wasm. // In non Wasm Go, the returned error will indicate whether the // connection was closed abnormally. StatusAbnormalClosure StatusCode = 1006 StatusInvalidFramePayloadData StatusCode = 1007 StatusPolicyViolation StatusCode = 1008 StatusMessageTooBig StatusCode = 1009 StatusMandatoryExtension StatusCode = 1010 StatusInternalError StatusCode = 1011 StatusServiceRestart StatusCode = 1012 StatusTryAgainLater StatusCode = 1013 StatusBadGateway StatusCode = 1014 // StatusTLSHandshake is only exported for use with Wasm. // In non Wasm Go, the returned error will indicate whether there was // a TLS handshake failure. StatusTLSHandshake StatusCode = 1015 ) // CloseError is returned when the connection is closed with a status and reason. // // Use Go 1.13's errors.As to check for this error. // Also see the CloseStatus helper. 
type CloseError struct { Code StatusCode Reason string } func (ce CloseError) Error() string { return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) } // CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab // the status code from a CloseError. // // -1 will be returned if the passed error is nil or not a CloseError. func CloseStatus(err error) StatusCode { var ce CloseError if errors.As(err, &ce) { return ce.Code } return -1 } // Close performs the WebSocket close handshake with the given status code and reason. // // It will write a WebSocket close frame with a timeout of 5s and then wait 5s for // the peer to send a close frame. // All data messages received from the peer during the close handshake will be discarded. // // The connection can only be closed once. Additional calls to Close // are no-ops. // // The maximum length of reason must be 125 bytes. Avoid // sending a dynamic reason. // // Close will unblock all goroutines interacting with the connection once // complete. func (c *Conn) Close(code StatusCode, reason string) error { defer c.wg.Wait() return c.closeHandshake(code, reason) } // CloseNow closes the WebSocket connection without attempting a close handshake. // Use when you do not want the overhead of the close handshake. 
func (c *Conn) CloseNow() (err error) { defer c.wg.Wait() defer errd.Wrap(&err, "failed to close WebSocket") if c.isClosed() { return net.ErrClosed } c.close(nil) return c.closeErr } func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { defer errd.Wrap(&err, "failed to close WebSocket") writeErr := c.writeClose(code, reason) closeHandshakeErr := c.waitCloseHandshake() if writeErr != nil { return writeErr } if CloseStatus(closeHandshakeErr) == -1 && !errors.Is(net.ErrClosed, closeHandshakeErr) { return closeHandshakeErr } return nil } func (c *Conn) writeClose(code StatusCode, reason string) error { c.closeMu.Lock() wroteClose := c.wroteClose c.wroteClose = true c.closeMu.Unlock() if wroteClose { return net.ErrClosed } ce := CloseError{ Code: code, Reason: reason, } var p []byte var marshalErr error if ce.Code != StatusNoStatusRcvd { p, marshalErr = ce.bytes() } writeErr := c.writeControl(context.Background(), opClose, p) if CloseStatus(writeErr) != -1 { // Not a real error if it's due to a close frame being received. writeErr = nil } // We do this after in case there was an error writing the close frame. 
c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) if marshalErr != nil { return marshalErr } return writeErr } func (c *Conn) waitCloseHandshake() error { defer c.close(nil) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() err := c.readMu.lock(ctx) if err != nil { return err } defer c.readMu.unlock() if c.readCloseFrameErr != nil { return c.readCloseFrameErr } for i := int64(0); i < c.msgReader.payloadLength; i++ { _, err := c.br.ReadByte() if err != nil { return err } } for { h, err := c.readLoop(ctx) if err != nil { return err } for i := int64(0); i < h.payloadLength; i++ { _, err := c.br.ReadByte() if err != nil { return err } } } } func parseClosePayload(p []byte) (CloseError, error) { if len(p) == 0 { return CloseError{ Code: StatusNoStatusRcvd, }, nil } if len(p) < 2 { return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) } ce := CloseError{ Code: StatusCode(binary.BigEndian.Uint16(p)), Reason: string(p[2:]), } if !validWireCloseCode(ce.Code) { return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code) } return ce, nil } // See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number // and https://tools.ietf.org/html/rfc6455#section-7.4.1 func validWireCloseCode(code StatusCode) bool { switch code { case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: return false } if code >= StatusNormalClosure && code <= StatusBadGateway { return true } if code >= 3000 && code <= 4999 { return true } return false } func (ce CloseError) bytes() ([]byte, error) { p, err := ce.bytesErr() if err != nil { err = fmt.Errorf("failed to marshal close frame: %w", err) ce = CloseError{ Code: StatusInternalError, } p, _ = ce.bytesErr() } return p, err } const maxCloseReason = maxControlPayload - 2 func (ce CloseError) bytesErr() ([]byte, error) { if len(ce.Reason) > maxCloseReason { return nil, fmt.Errorf("reason string max is %v 
but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) } if !validWireCloseCode(ce.Code) { return nil, fmt.Errorf("status code %v cannot be set", ce.Code) } buf := make([]byte, 2+len(ce.Reason)) binary.BigEndian.PutUint16(buf, uint16(ce.Code)) copy(buf[2:], ce.Reason) return buf, nil } func (c *Conn) setCloseErr(err error) { c.closeMu.Lock() c.setCloseErrLocked(err) c.closeMu.Unlock() } func (c *Conn) setCloseErrLocked(err error) { if c.closeErr == nil && err != nil { c.closeErr = fmt.Errorf("WebSocket closed: %w", err) } } func (c *Conn) isClosed() bool { select { case <-c.closed: return true default: return false } }
//go:build !js // +build !js package websocket import ( "compress/flate" "io" "sync" ) // CompressionMode represents the modes available to the permessage-deflate extension. // See https://tools.ietf.org/html/rfc7692 // // Works in all modern browsers except Safari which does not implement the permessage-deflate extension. // // Compression is only used if the peer supports the mode selected. type CompressionMode int const ( // CompressionDisabled disables the negotiation of the permessage-deflate extension. // // This is the default. Do not enable compression without benchmarking for your particular use case first. CompressionDisabled CompressionMode = iota // CompressionContextTakeover compresses each message greater than 128 bytes reusing the 32 KB sliding window from // previous messages. i.e compression context across messages is preserved. // // As most WebSocket protocols are text based and repetitive, this compression mode can be very efficient. // // The memory overhead is a fixed 32 KB sliding window, a fixed 1.2 MB flate.Writer and a sync.Pool of 40 KB flate.Reader's // that are used when reading and then returned. // // Thus, it uses more memory than CompressionNoContextTakeover but compresses more efficiently. // // If the peer does not support CompressionContextTakeover then we will fall back to CompressionNoContextTakeover. CompressionContextTakeover // CompressionNoContextTakeover compresses each message greater than 512 bytes. Each message is compressed with // a new 1.2 MB flate.Writer pulled from a sync.Pool. Each message is read with a 40 KB flate.Reader pulled from // a sync.Pool. // // This means less efficient compression as the sliding window from previous messages will not be used but the // memory overhead will be lower as there will be no fixed cost for the flate.Writer nor the 32 KB sliding window. // Especially if the connections are long lived and seldom written to. 
// // Thus, it uses less memory than CompressionContextTakeover but compresses less efficiently. // // If the peer does not support CompressionNoContextTakeover then we will fall back to CompressionDisabled. CompressionNoContextTakeover ) func (m CompressionMode) opts() *compressionOptions { return &compressionOptions{ clientNoContextTakeover: m == CompressionNoContextTakeover, serverNoContextTakeover: m == CompressionNoContextTakeover, } } type compressionOptions struct { clientNoContextTakeover bool serverNoContextTakeover bool } func (copts *compressionOptions) String() string { s := "permessage-deflate" if copts.clientNoContextTakeover { s += "; client_no_context_takeover" } if copts.serverNoContextTakeover { s += "; server_no_context_takeover" } return s } // These bytes are required to get flate.Reader to return. // They are removed when sending to avoid the overhead as // WebSocket framing tell's when the message has ended but then // we need to add them back otherwise flate.Reader keeps // trying to read more bytes. const deflateMessageTail = "\x00\x00\xff\xff" type trimLastFourBytesWriter struct { w io.Writer tail []byte } func (tw *trimLastFourBytesWriter) reset() { if tw != nil && tw.tail != nil { tw.tail = tw.tail[:0] } } func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { if tw.tail == nil { tw.tail = make([]byte, 0, 4) } extra := len(tw.tail) + len(p) - 4 if extra <= 0 { tw.tail = append(tw.tail, p...) return len(p), nil } // Now we need to write as many extra bytes as we can from the previous tail. if extra > len(tw.tail) { extra = len(tw.tail) } if extra > 0 { _, err := tw.w.Write(tw.tail[:extra]) if err != nil { return 0, err } // Shift remaining bytes in tail over. n := copy(tw.tail, tw.tail[extra:]) tw.tail = tw.tail[:n] } // If p is less than or equal to 4 bytes, // all of it is is part of the tail. if len(p) <= 4 { tw.tail = append(tw.tail, p...) return len(p), nil } // Otherwise, only the last 4 bytes are. 
tw.tail = append(tw.tail, p[len(p)-4:]...) p = p[:len(p)-4] n, err := tw.w.Write(p) return n + 4, err } var flateReaderPool sync.Pool func getFlateReader(r io.Reader, dict []byte) io.Reader { fr, ok := flateReaderPool.Get().(io.Reader) if !ok { return flate.NewReaderDict(r, dict) } fr.(flate.Resetter).Reset(r, dict) return fr } func putFlateReader(fr io.Reader) { flateReaderPool.Put(fr) } var flateWriterPool sync.Pool func getFlateWriter(w io.Writer) *flate.Writer { fw, ok := flateWriterPool.Get().(*flate.Writer) if !ok { fw, _ = flate.NewWriter(w, flate.BestSpeed) return fw } fw.Reset(w) return fw } func putFlateWriter(w *flate.Writer) { flateWriterPool.Put(w) } type slidingWindow struct { buf []byte } var swPoolMu sync.RWMutex var swPool = map[int]*sync.Pool{} func slidingWindowPool(n int) *sync.Pool { swPoolMu.RLock() p, ok := swPool[n] swPoolMu.RUnlock() if ok { return p } p = &sync.Pool{} swPoolMu.Lock() swPool[n] = p swPoolMu.Unlock() return p } func (sw *slidingWindow) init(n int) { if sw.buf != nil { return } if n == 0 { n = 32768 } p := slidingWindowPool(n) sw2, ok := p.Get().(*slidingWindow) if ok { *sw = *sw2 } else { sw.buf = make([]byte, 0, n) } } func (sw *slidingWindow) close() { sw.buf = sw.buf[:0] swPoolMu.Lock() swPool[cap(sw.buf)].Put(sw) swPoolMu.Unlock() } func (sw *slidingWindow) write(p []byte) { if len(p) >= cap(sw.buf) { sw.buf = sw.buf[:cap(sw.buf)] p = p[len(p)-cap(sw.buf):] copy(sw.buf, p) return } left := cap(sw.buf) - len(sw.buf) if left < len(p) { // We need to shift spaceNeeded bytes from the end to make room for p at the end. spaceNeeded := len(p) - left copy(sw.buf, sw.buf[spaceNeeded:]) sw.buf = sw.buf[:len(sw.buf)-spaceNeeded] } sw.buf = append(sw.buf, p...) }
//go:build !js // +build !js package websocket import ( "bufio" "context" "errors" "fmt" "io" "net" "runtime" "strconv" "sync" "sync/atomic" ) // MessageType represents the type of a WebSocket message. // See https://tools.ietf.org/html/rfc6455#section-5.6 type MessageType int // MessageType constants. const ( // MessageText is for UTF-8 encoded text messages like JSON. MessageText MessageType = iota + 1 // MessageBinary is for binary messages like protobufs. MessageBinary ) // Conn represents a WebSocket connection. // All methods may be called concurrently except for Reader and Read. // // You must always read from the connection. Otherwise control // frames will not be handled. See Reader and CloseRead. // // Be sure to call Close on the connection when you // are finished with it to release associated resources. // // On any error from any method, the connection is closed // with an appropriate reason. // // This applies to context expirations as well unfortunately. // See https://github.com/nhooyr/websocket/issues/242#issuecomment-633182220 type Conn struct { noCopy subprotocol string rwc io.ReadWriteCloser client bool copts *compressionOptions flateThreshold int br *bufio.Reader bw *bufio.Writer readTimeout chan context.Context writeTimeout chan context.Context // Read state. readMu *mu readHeaderBuf [8]byte readControlBuf [maxControlPayload]byte msgReader *msgReader readCloseFrameErr error // Write state. 
msgWriter *msgWriter writeFrameMu *mu writeBuf []byte writeHeaderBuf [8]byte writeHeader header wg sync.WaitGroup closed chan struct{} closeMu sync.Mutex closeErr error wroteClose bool pingCounter int32 activePingsMu sync.Mutex activePings map[string]chan<- struct{} } type connConfig struct { subprotocol string rwc io.ReadWriteCloser client bool copts *compressionOptions flateThreshold int br *bufio.Reader bw *bufio.Writer } func newConn(cfg connConfig) *Conn { c := &Conn{ subprotocol: cfg.subprotocol, rwc: cfg.rwc, client: cfg.client, copts: cfg.copts, flateThreshold: cfg.flateThreshold, br: cfg.br, bw: cfg.bw, readTimeout: make(chan context.Context), writeTimeout: make(chan context.Context), closed: make(chan struct{}), activePings: make(map[string]chan<- struct{}), } c.readMu = newMu(c) c.writeFrameMu = newMu(c) c.msgReader = newMsgReader(c) c.msgWriter = newMsgWriter(c) if c.client { c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) } if c.flate() && c.flateThreshold == 0 { c.flateThreshold = 128 if !c.msgWriter.flateContextTakeover() { c.flateThreshold = 512 } } runtime.SetFinalizer(c, func(c *Conn) { c.close(errors.New("connection garbage collected")) }) c.wg.Add(1) go func() { defer c.wg.Done() c.timeoutLoop() }() return c } // Subprotocol returns the negotiated subprotocol. // An empty string means the default protocol. func (c *Conn) Subprotocol() string { return c.subprotocol } func (c *Conn) close(err error) { c.closeMu.Lock() defer c.closeMu.Unlock() if c.isClosed() { return } if err == nil { err = c.rwc.Close() } c.setCloseErrLocked(err) close(c.closed) runtime.SetFinalizer(c, nil) // Have to close after c.closed is closed to ensure any goroutine that wakes up // from the connection being closed also sees that c.closed is closed and returns // closeErr. 
c.rwc.Close() c.wg.Add(1) go func() { defer c.wg.Done() c.msgWriter.close() c.msgReader.close() }() } func (c *Conn) timeoutLoop() { readCtx := context.Background() writeCtx := context.Background() for { select { case <-c.closed: return case writeCtx = <-c.writeTimeout: case readCtx = <-c.readTimeout: case <-readCtx.Done(): c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) c.wg.Add(1) go func() { defer c.wg.Done() c.writeError(StatusPolicyViolation, errors.New("read timed out")) }() case <-writeCtx.Done(): c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) return } } } func (c *Conn) flate() bool { return c.copts != nil } // Ping sends a ping to the peer and waits for a pong. // Use this to measure latency or ensure the peer is responsive. // Ping must be called concurrently with Reader as it does // not read from the connection but instead waits for a Reader call // to read the pong. // // TCP Keepalives should suffice for most use cases. func (c *Conn) Ping(ctx context.Context) error { p := atomic.AddInt32(&c.pingCounter, 1) err := c.ping(ctx, strconv.Itoa(int(p))) if err != nil { return fmt.Errorf("failed to ping: %w", err) } return nil } func (c *Conn) ping(ctx context.Context, p string) error { pong := make(chan struct{}, 1) c.activePingsMu.Lock() c.activePings[p] = pong c.activePingsMu.Unlock() defer func() { c.activePingsMu.Lock() delete(c.activePings, p) c.activePingsMu.Unlock() }() err := c.writeControl(ctx, opPing, []byte(p)) if err != nil { return err } select { case <-c.closed: return net.ErrClosed case <-ctx.Done(): err := fmt.Errorf("failed to wait for pong: %w", ctx.Err()) c.close(err) return err case <-pong: return nil } } type mu struct { c *Conn ch chan struct{} } func newMu(c *Conn) *mu { return &mu{ c: c, ch: make(chan struct{}, 1), } } func (m *mu) forceLock() { m.ch <- struct{}{} } func (m *mu) tryLock() bool { select { case m.ch <- struct{}{}: return true default: return false } } func (m *mu) lock(ctx context.Context) 
error { select { case <-m.c.closed: return net.ErrClosed case <-ctx.Done(): err := fmt.Errorf("failed to acquire lock: %w", ctx.Err()) m.c.close(err) return err case m.ch <- struct{}{}: // To make sure the connection is certainly alive. // As it's possible the send on m.ch was selected // over the receive on closed. select { case <-m.c.closed: // Make sure to release. m.unlock() return net.ErrClosed default: } return nil } } func (m *mu) unlock() { select { case <-m.ch: default: } } type noCopy struct{} func (*noCopy) Lock() {}
//go:build !js // +build !js package websocket import ( "bufio" "bytes" "context" "crypto/rand" "encoding/base64" "fmt" "io" "net/http" "net/url" "strings" "sync" "time" "nhooyr.io/websocket/internal/errd" ) // DialOptions represents Dial's options. type DialOptions struct { // HTTPClient is used for the connection. // Its Transport must return writable bodies for WebSocket handshakes. // http.Transport does beginning with Go 1.12. HTTPClient *http.Client // HTTPHeader specifies the HTTP headers included in the handshake request. HTTPHeader http.Header // Host optionally overrides the Host HTTP header to send. If empty, the value // of URL.Host will be used. Host string // Subprotocols lists the WebSocket subprotocols to negotiate with the server. Subprotocols []string // CompressionMode controls the compression mode. // Defaults to CompressionDisabled. // // See docs on CompressionMode for details. CompressionMode CompressionMode // CompressionThreshold controls the minimum size of a message before compression is applied. // // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes // for CompressionContextTakeover. 
CompressionThreshold int } func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) { var cancel context.CancelFunc var o DialOptions if opts != nil { o = *opts } if o.HTTPClient == nil { o.HTTPClient = http.DefaultClient } if o.HTTPClient.Timeout > 0 { ctx, cancel = context.WithTimeout(ctx, o.HTTPClient.Timeout) newClient := *o.HTTPClient newClient.Timeout = 0 o.HTTPClient = &newClient } if o.HTTPHeader == nil { o.HTTPHeader = http.Header{} } newClient := *o.HTTPClient oldCheckRedirect := o.HTTPClient.CheckRedirect newClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { switch req.URL.Scheme { case "ws": req.URL.Scheme = "http" case "wss": req.URL.Scheme = "https" } if oldCheckRedirect != nil { return oldCheckRedirect(req, via) } return nil } o.HTTPClient = &newClient return ctx, cancel, &o } // Dial performs a WebSocket handshake on url. // // The response is the WebSocket handshake response from the server. // You never need to close resp.Body yourself. // // If an error occurs, the returned response may be non nil. // However, you can only read the first 1024 bytes of the body. // // This function requires at least Go 1.12 as it uses a new feature // in net/http to perform WebSocket handshakes. // See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861 // // URLs with http/https schemes will work and are interpreted as ws/wss. 
func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { return dial(ctx, u, opts, nil) } func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) { defer errd.Wrap(&err, "failed to WebSocket dial") var cancel context.CancelFunc ctx, cancel, opts = opts.cloneWithDefaults(ctx) if cancel != nil { defer cancel() } secWebSocketKey, err := secWebSocketKey(rand) if err != nil { return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err) } var copts *compressionOptions if opts.CompressionMode != CompressionDisabled { copts = opts.CompressionMode.opts() } resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey) if err != nil { return nil, resp, err } respBody := resp.Body resp.Body = nil defer func() { if err != nil { // We read a bit of the body for easier debugging. r := io.LimitReader(respBody, 1024) timer := time.AfterFunc(time.Second*3, func() { respBody.Close() }) defer timer.Stop() b, _ := io.ReadAll(r) respBody.Close() resp.Body = io.NopCloser(bytes.NewReader(b)) } }() copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp) if err != nil { return nil, resp, err } rwc, ok := respBody.(io.ReadWriteCloser) if !ok { return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody) } return newConn(connConfig{ subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), rwc: rwc, client: true, copts: copts, flateThreshold: opts.CompressionThreshold, br: getBufioReader(rwc), bw: getBufioWriter(rwc), }), resp, nil } func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) { u, err := url.Parse(urls) if err != nil { return nil, fmt.Errorf("failed to parse url: %w", err) } switch u.Scheme { case "ws": u.Scheme = "http" case "wss": u.Scheme = "https" case "http", "https": default: return nil, fmt.Errorf("unexpected url scheme: 
%q", u.Scheme) } req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) if err != nil { return nil, fmt.Errorf("failed to create new http request: %w", err) } if len(opts.Host) > 0 { req.Host = opts.Host } req.Header = opts.HTTPHeader.Clone() req.Header.Set("Connection", "Upgrade") req.Header.Set("Upgrade", "websocket") req.Header.Set("Sec-WebSocket-Version", "13") req.Header.Set("Sec-WebSocket-Key", secWebSocketKey) if len(opts.Subprotocols) > 0 { req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } if copts != nil { req.Header.Set("Sec-WebSocket-Extensions", copts.String()) } resp, err := opts.HTTPClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to send handshake request: %w", err) } return resp, nil } func secWebSocketKey(rr io.Reader) (string, error) { if rr == nil { rr = rand.Reader } b := make([]byte, 16) _, err := io.ReadFull(rr, b) if err != nil { return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err) } return base64.StdEncoding.EncodeToString(b), nil } func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) { if resp.StatusCode != http.StatusSwitchingProtocols { return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) } if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") { return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) } if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") { return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) } if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) { return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", 
resp.Header.Get("Sec-WebSocket-Accept"), secWebSocketKey, ) } err := verifySubprotocol(opts.Subprotocols, resp) if err != nil { return nil, err } return verifyServerExtensions(copts, resp.Header) } func verifySubprotocol(subprotos []string, resp *http.Response) error { proto := resp.Header.Get("Sec-WebSocket-Protocol") if proto == "" { return nil } for _, sp2 := range subprotos { if strings.EqualFold(sp2, proto) { return nil } } return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) } func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) { exts := websocketExtensions(h) if len(exts) == 0 { return nil, nil } ext := exts[0] if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil { return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:]) } _copts := *copts copts = &_copts for _, p := range ext.params { switch p { case "client_no_context_takeover": copts.clientNoContextTakeover = true continue case "server_no_context_takeover": copts.serverNoContextTakeover = true continue } if strings.HasPrefix(p, "server_max_window_bits=") { // We can't adjust the deflate window, but decoding with a larger window is acceptable. continue } return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) } return copts, nil } var bufioReaderPool sync.Pool func getBufioReader(r io.Reader) *bufio.Reader { br, ok := bufioReaderPool.Get().(*bufio.Reader) if !ok { return bufio.NewReader(r) } br.Reset(r) return br } func putBufioReader(br *bufio.Reader) { bufioReaderPool.Put(br) } var bufioWriterPool sync.Pool func getBufioWriter(w io.Writer) *bufio.Writer { bw, ok := bufioWriterPool.Get().(*bufio.Writer) if !ok { return bufio.NewWriter(w) } bw.Reset(w) return bw } func putBufioWriter(bw *bufio.Writer) { bufioWriterPool.Put(bw) }
//go:build !js package websocket import ( "bufio" "encoding/binary" "fmt" "io" "math" "math/bits" "nhooyr.io/websocket/internal/errd" ) // opcode represents a WebSocket opcode. type opcode int // https://tools.ietf.org/html/rfc6455#section-11.8. const ( opContinuation opcode = iota opText opBinary // 3 - 7 are reserved for further non-control frames. _ _ _ _ _ opClose opPing opPong // 11-16 are reserved for further control frames. ) // header represents a WebSocket frame header. // See https://tools.ietf.org/html/rfc6455#section-5.2. type header struct { fin bool rsv1 bool rsv2 bool rsv3 bool opcode opcode payloadLength int64 masked bool maskKey uint32 } // readFrameHeader reads a header from the reader. // See https://tools.ietf.org/html/rfc6455#section-5.2. func readFrameHeader(r *bufio.Reader, readBuf []byte) (h header, err error) { defer errd.Wrap(&err, "failed to read frame header") b, err := r.ReadByte() if err != nil { return header{}, err } h.fin = b&(1<<7) != 0 h.rsv1 = b&(1<<6) != 0 h.rsv2 = b&(1<<5) != 0 h.rsv3 = b&(1<<4) != 0 h.opcode = opcode(b & 0xf) b, err = r.ReadByte() if err != nil { return header{}, err } h.masked = b&(1<<7) != 0 payloadLength := b &^ (1 << 7) switch { case payloadLength < 126: h.payloadLength = int64(payloadLength) case payloadLength == 126: _, err = io.ReadFull(r, readBuf[:2]) h.payloadLength = int64(binary.BigEndian.Uint16(readBuf)) case payloadLength == 127: _, err = io.ReadFull(r, readBuf) h.payloadLength = int64(binary.BigEndian.Uint64(readBuf)) } if err != nil { return header{}, err } if h.payloadLength < 0 { return header{}, fmt.Errorf("received negative payload length: %v", h.payloadLength) } if h.masked { _, err = io.ReadFull(r, readBuf[:4]) if err != nil { return header{}, err } h.maskKey = binary.LittleEndian.Uint32(readBuf) } return h, nil } // maxControlPayload is the maximum length of a control frame payload. // See https://tools.ietf.org/html/rfc6455#section-5.5. 
const maxControlPayload = 125

// writeFrameHeader writes the bytes of the header to w.
// See https://tools.ietf.org/html/rfc6455#section-5.2
//
// buf is scratch space for the extended payload length and the mask key.
// In the 64-bit length case all of buf is written, so it must be exactly
// 8 bytes long for the header to be encoded correctly.
// NOTE(review): a negative payloadLength is not rejected; it is encoded as a
// length of 0.
func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) {
	defer errd.Wrap(&err, "failed to write frame header")

	// First byte: FIN, RSV1-3, then the 4-bit opcode.
	var b byte
	if h.fin {
		b |= 1 << 7
	}
	if h.rsv1 {
		b |= 1 << 6
	}
	if h.rsv2 {
		b |= 1 << 5
	}
	if h.rsv3 {
		b |= 1 << 4
	}

	b |= byte(h.opcode)

	err = w.WriteByte(b)
	if err != nil {
		return err
	}

	// Second byte: MASK bit plus a 7-bit length or extended-length marker.
	lengthByte := byte(0)
	if h.masked {
		lengthByte |= 1 << 7
	}

	switch {
	case h.payloadLength > math.MaxUint16:
		lengthByte |= 127
	case h.payloadLength > 125:
		lengthByte |= 126
	case h.payloadLength >= 0:
		lengthByte |= byte(h.payloadLength)
	}
	err = w.WriteByte(lengthByte)
	if err != nil {
		return err
	}

	// Extended payload length, in network (big-endian) byte order.
	switch {
	case h.payloadLength > math.MaxUint16:
		binary.BigEndian.PutUint64(buf, uint64(h.payloadLength))
		_, err = w.Write(buf)
	case h.payloadLength > 125:
		binary.BigEndian.PutUint16(buf, uint16(h.payloadLength))
		_, err = w.Write(buf[:2])
	}
	if err != nil {
		return err
	}

	// The mask key is held in little-endian form (see mask), so it is
	// serialized little endian to reproduce the original key bytes.
	if h.masked {
		binary.LittleEndian.PutUint32(buf, h.maskKey)
		_, err = w.Write(buf[:4])
		if err != nil {
			return err
		}
	}

	return nil
}

// mask applies the WebSocket masking algorithm to b in place
// with the given key.
// See https://tools.ietf.org/html/rfc6455#section-5.3
//
// The returned value is the correctly rotated key to use to
// continue masking/unmasking the same message.
//
// It is optimized for LittleEndian and expects the key
// to be in little endian.
//
// See https://github.com/golang/go/issues/31586
func mask(key uint32, b []byte) uint32 {
	if len(b) >= 8 {
		// The 4-byte key repeated twice. Every bulk step below advances by a
		// multiple of 4 bytes, so the key never needs rotating inside them.
		key64 := uint64(key)<<32 | uint64(key)

		// At some point in the future we can clean these unrolled loops up.
		// See https://github.com/golang/go/issues/31586#issuecomment-487436401

		// XOR 128 bytes per iteration until fewer than 128 remain.
		for len(b) >= 128 {
			v := binary.LittleEndian.Uint64(b)
			binary.LittleEndian.PutUint64(b, v^key64)
			v = binary.LittleEndian.Uint64(b[8:16])
			binary.LittleEndian.PutUint64(b[8:16], v^key64)
			v = binary.LittleEndian.Uint64(b[16:24])
			binary.LittleEndian.PutUint64(b[16:24], v^key64)
			v = binary.LittleEndian.Uint64(b[24:32])
			binary.LittleEndian.PutUint64(b[24:32], v^key64)
			v = binary.LittleEndian.Uint64(b[32:40])
			binary.LittleEndian.PutUint64(b[32:40], v^key64)
			v = binary.LittleEndian.Uint64(b[40:48])
			binary.LittleEndian.PutUint64(b[40:48], v^key64)
			v = binary.LittleEndian.Uint64(b[48:56])
			binary.LittleEndian.PutUint64(b[48:56], v^key64)
			v = binary.LittleEndian.Uint64(b[56:64])
			binary.LittleEndian.PutUint64(b[56:64], v^key64)
			v = binary.LittleEndian.Uint64(b[64:72])
			binary.LittleEndian.PutUint64(b[64:72], v^key64)
			v = binary.LittleEndian.Uint64(b[72:80])
			binary.LittleEndian.PutUint64(b[72:80], v^key64)
			v = binary.LittleEndian.Uint64(b[80:88])
			binary.LittleEndian.PutUint64(b[80:88], v^key64)
			v = binary.LittleEndian.Uint64(b[88:96])
			binary.LittleEndian.PutUint64(b[88:96], v^key64)
			v = binary.LittleEndian.Uint64(b[96:104])
			binary.LittleEndian.PutUint64(b[96:104], v^key64)
			v = binary.LittleEndian.Uint64(b[104:112])
			binary.LittleEndian.PutUint64(b[104:112], v^key64)
			v = binary.LittleEndian.Uint64(b[112:120])
			binary.LittleEndian.PutUint64(b[112:120], v^key64)
			v = binary.LittleEndian.Uint64(b[120:128])
			binary.LittleEndian.PutUint64(b[120:128], v^key64)
			b = b[128:]
		}

		// Then XOR 64 bytes at a time until fewer than 64 remain.
		for len(b) >= 64 {
			v := binary.LittleEndian.Uint64(b)
			binary.LittleEndian.PutUint64(b, v^key64)
			v = binary.LittleEndian.Uint64(b[8:16])
			binary.LittleEndian.PutUint64(b[8:16], v^key64)
			v = binary.LittleEndian.Uint64(b[16:24])
			binary.LittleEndian.PutUint64(b[16:24], v^key64)
			v = binary.LittleEndian.Uint64(b[24:32])
			binary.LittleEndian.PutUint64(b[24:32], v^key64)
			v = binary.LittleEndian.Uint64(b[32:40])
			binary.LittleEndian.PutUint64(b[32:40], v^key64)
			v = binary.LittleEndian.Uint64(b[40:48])
			binary.LittleEndian.PutUint64(b[40:48], v^key64)
			v = binary.LittleEndian.Uint64(b[48:56])
			binary.LittleEndian.PutUint64(b[48:56], v^key64)
			v = binary.LittleEndian.Uint64(b[56:64])
			binary.LittleEndian.PutUint64(b[56:64], v^key64)
			b = b[64:]
		}

		// Then XOR 32 bytes at a time until fewer than 32 remain.
		for len(b) >= 32 {
			v := binary.LittleEndian.Uint64(b)
			binary.LittleEndian.PutUint64(b, v^key64)
			v = binary.LittleEndian.Uint64(b[8:16])
			binary.LittleEndian.PutUint64(b[8:16], v^key64)
			v = binary.LittleEndian.Uint64(b[16:24])
			binary.LittleEndian.PutUint64(b[16:24], v^key64)
			v = binary.LittleEndian.Uint64(b[24:32])
			binary.LittleEndian.PutUint64(b[24:32], v^key64)
			b = b[32:]
		}

		// Then XOR 16 bytes at a time until fewer than 16 remain.
		for len(b) >= 16 {
			v := binary.LittleEndian.Uint64(b)
			binary.LittleEndian.PutUint64(b, v^key64)
			v = binary.LittleEndian.Uint64(b[8:16])
			binary.LittleEndian.PutUint64(b[8:16], v^key64)
			b = b[16:]
		}

		// Then XOR 8 bytes at a time until fewer than 8 remain.
		for len(b) >= 8 {
			v := binary.LittleEndian.Uint64(b)
			binary.LittleEndian.PutUint64(b, v^key64)
			b = b[8:]
		}
	}

	// Then XOR 4 bytes at a time until fewer than 4 remain.
	for len(b) >= 4 {
		v := binary.LittleEndian.Uint32(b)
		binary.LittleEndian.PutUint32(b, v^key)
		b = b[4:]
	}

	// XOR the remaining 0-3 bytes one at a time. The low byte of key is the
	// next mask byte; rotating right by 8 moves the following mask byte into
	// place, so the returned key continues the pattern for the next call.
	for i := range b {
		b[i] ^= byte(key)
		key = bits.RotateLeft32(key, -8)
	}

	return key
}
package bpool import ( "bytes" "sync" ) var bpool sync.Pool // Get returns a buffer from the pool or creates a new one if // the pool is empty. func Get() *bytes.Buffer { b := bpool.Get() if b == nil { return &bytes.Buffer{} } return b.(*bytes.Buffer) } // Put returns a buffer into the pool. func Put(b *bytes.Buffer) { b.Reset() bpool.Put(b) }
package errd import ( "fmt" ) // Wrap wraps err with fmt.Errorf if err is non nil. // Intended for use with defer and a named error return. // Inspired by https://github.com/golang/go/issues/32676. func Wrap(err *error, f string, v ...interface{}) { if *err != nil { *err = fmt.Errorf(f+": %w", append(v, *err)...) } }
package util // WriterFunc is used to implement one off io.Writers. type WriterFunc func(p []byte) (int, error) func (f WriterFunc) Write(p []byte) (int, error) { return f(p) } // ReaderFunc is used to implement one off io.Readers. type ReaderFunc func(p []byte) (int, error) func (f ReaderFunc) Read(p []byte) (int, error) { return f(p) }
package xsync import ( "fmt" "runtime/debug" ) // Go allows running a function in another goroutine // and waiting for its error. func Go(fn func() error) <-chan error { errs := make(chan error, 1) go func() { defer func() { r := recover() if r != nil { select { case errs <- fmt.Errorf("panic in go fn: %v, %s", r, debug.Stack()): default: } } }() errs <- fn() }() return errs }
package xsync import ( "sync/atomic" ) // Int64 represents an atomic int64. type Int64 struct { // We do not use atomic.Load/StoreInt64 since it does not // work on 32 bit computers but we need 64 bit integers. i atomic.Value } // Load loads the int64. func (v *Int64) Load() int64 { i, _ := v.i.Load().(int64) return i } // Store stores the int64. func (v *Int64) Store(i int64) { v.i.Store(i) }
package websocket import ( "context" "fmt" "io" "math" "net" "sync/atomic" "time" ) // NetConn converts a *websocket.Conn into a net.Conn. // // It's for tunneling arbitrary protocols over WebSockets. // Few users of the library will need this but it's tricky to implement // correctly and so provided in the library. // See https://github.com/nhooyr/websocket/issues/100. // // Every Write to the net.Conn will correspond to a message write of // the given type on *websocket.Conn. // // The passed ctx bounds the lifetime of the net.Conn. If cancelled, // all reads and writes on the net.Conn will be cancelled. // // If a message is read that is not of the correct type, the connection // will be closed with StatusUnsupportedData and an error will be returned. // // Close will close the *websocket.Conn with StatusNormalClosure. // // When a deadline is hit and there is an active read or write goroutine, the // connection will be closed. This is different from most net.Conn implementations // where only the reading/writing goroutines are interrupted but the connection // is kept alive. // // The Addr methods will return the real addresses for connections obtained // from websocket.Accept. But for connections obtained from websocket.Dial, a mock net.Addr // will be returned that gives "websocket" for Network() and "websocket/unknown-addr" for // String(). This is because websocket.Dial only exposes a io.ReadWriteCloser instead of the // full net.Conn to us. // // When running as WASM, the Addr methods will always return the mock address described above. // // A received StatusNormalClosure or StatusGoingAway close frame will be translated to // io.EOF when reading. // // Furthermore, the ReadLimit is set to -1 to disable it. 
func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { c.SetReadLimit(-1) nc := &netConn{ c: c, msgType: msgType, readMu: newMu(c), writeMu: newMu(c), } nc.writeCtx, nc.writeCancel = context.WithCancel(ctx) nc.readCtx, nc.readCancel = context.WithCancel(ctx) nc.writeTimer = time.AfterFunc(math.MaxInt64, func() { if !nc.writeMu.tryLock() { // If the lock cannot be acquired, then there is an // active write goroutine and so we should cancel the context. nc.writeCancel() return } defer nc.writeMu.unlock() // Prevents future writes from writing until the deadline is reset. atomic.StoreInt64(&nc.writeExpired, 1) }) if !nc.writeTimer.Stop() { <-nc.writeTimer.C } nc.readTimer = time.AfterFunc(math.MaxInt64, func() { if !nc.readMu.tryLock() { // If the lock cannot be acquired, then there is an // active read goroutine and so we should cancel the context. nc.readCancel() return } defer nc.readMu.unlock() // Prevents future reads from reading until the deadline is reset. atomic.StoreInt64(&nc.readExpired, 1) }) if !nc.readTimer.Stop() { <-nc.readTimer.C } return nc } type netConn struct { c *Conn msgType MessageType writeTimer *time.Timer writeMu *mu writeExpired int64 writeCtx context.Context writeCancel context.CancelFunc readTimer *time.Timer readMu *mu readExpired int64 readCtx context.Context readCancel context.CancelFunc readEOFed bool reader io.Reader } var _ net.Conn = &netConn{} func (nc *netConn) Close() error { nc.writeTimer.Stop() nc.writeCancel() nc.readTimer.Stop() nc.readCancel() return nc.c.Close(StatusNormalClosure, "") } func (nc *netConn) Write(p []byte) (int, error) { nc.writeMu.forceLock() defer nc.writeMu.unlock() if atomic.LoadInt64(&nc.writeExpired) == 1 { return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded) } err := nc.c.Write(nc.writeCtx, nc.msgType, p) if err != nil { return 0, err } return len(p), nil } func (nc *netConn) Read(p []byte) (int, error) { nc.readMu.forceLock() defer nc.readMu.unlock() for { n, err 
:= nc.read(p) if err != nil { return n, err } if n == 0 { continue } return n, nil } } func (nc *netConn) read(p []byte) (int, error) { if atomic.LoadInt64(&nc.readExpired) == 1 { return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded) } if nc.readEOFed { return 0, io.EOF } if nc.reader == nil { typ, r, err := nc.c.Reader(nc.readCtx) if err != nil { switch CloseStatus(err) { case StatusNormalClosure, StatusGoingAway: nc.readEOFed = true return 0, io.EOF } return 0, err } if typ != nc.msgType { err := fmt.Errorf("unexpected frame type read (expected %v): %v", nc.msgType, typ) nc.c.Close(StatusUnsupportedData, err.Error()) return 0, err } nc.reader = r } n, err := nc.reader.Read(p) if err == io.EOF { nc.reader = nil err = nil } return n, err } type websocketAddr struct { } func (a websocketAddr) Network() string { return "websocket" } func (a websocketAddr) String() string { return "websocket/unknown-addr" } func (nc *netConn) SetDeadline(t time.Time) error { nc.SetWriteDeadline(t) nc.SetReadDeadline(t) return nil } func (nc *netConn) SetWriteDeadline(t time.Time) error { atomic.StoreInt64(&nc.writeExpired, 0) if t.IsZero() { nc.writeTimer.Stop() } else { dur := time.Until(t) if dur <= 0 { dur = 1 } nc.writeTimer.Reset(dur) } return nil } func (nc *netConn) SetReadDeadline(t time.Time) error { atomic.StoreInt64(&nc.readExpired, 0) if t.IsZero() { nc.readTimer.Stop() } else { dur := time.Until(t) if dur <= 0 { dur = 1 } nc.readTimer.Reset(dur) } return nil }
//go:build !js // +build !js package websocket import "net" func (nc *netConn) RemoteAddr() net.Addr { if unc, ok := nc.c.rwc.(net.Conn); ok { return unc.RemoteAddr() } return websocketAddr{} } func (nc *netConn) LocalAddr() net.Addr { if unc, ok := nc.c.rwc.(net.Conn); ok { return unc.LocalAddr() } return websocketAddr{} }
//go:build !js // +build !js package websocket import ( "bufio" "context" "errors" "fmt" "io" "net" "strings" "time" "nhooyr.io/websocket/internal/errd" "nhooyr.io/websocket/internal/util" "nhooyr.io/websocket/internal/xsync" ) // Reader reads from the connection until there is a WebSocket // data message to be read. It will handle ping, pong and close frames as appropriate. // // It returns the type of the message and an io.Reader to read it. // The passed context will also bound the reader. // Ensure you read to EOF otherwise the connection will hang. // // Call CloseRead if you do not expect any data messages from the peer. // // Only one Reader may be open at a time. // // If you need a separate timeout on the Reader call and the Read itself, // use time.AfterFunc to cancel the context passed in. // See https://github.com/nhooyr/websocket/issues/87#issue-451703332 // Most users should not need this. func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { return c.reader(ctx) } // Read is a convenience method around Reader to read a single message // from the connection. func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { typ, r, err := c.Reader(ctx) if err != nil { return 0, nil, err } b, err := io.ReadAll(r) return typ, b, err } // CloseRead starts a goroutine to read from the connection until it is closed // or a data message is received. // // Once CloseRead is called you cannot read any messages from the connection. // The returned context will be cancelled when the connection is closed. // // If a data message is received, the connection will be closed with StatusPolicyViolation. // // Call CloseRead when you do not expect to read any more messages. // Since it actively reads from the connection, it will ensure that ping, pong and close // frames are responded to. This means c.Ping and c.Close will still work as expected. 
func (c *Conn) CloseRead(ctx context.Context) context.Context {
	ctx, cancel := context.WithCancel(ctx)

	c.wg.Add(1)
	go func() {
		// Deferred calls run in reverse order: cancel the returned context,
		// mark this goroutine done in c.wg, then close the connection.
		defer c.CloseNow()
		defer c.wg.Done()
		defer cancel()
		// Reader handles control frames internally and only returns without
		// an error once a data message arrives.
		_, _, err := c.Reader(ctx)
		if err == nil {
			c.Close(StatusPolicyViolation, "unexpected data message")
		}
	}()
	return ctx
}

// SetReadLimit sets the max number of bytes to read for a single message.
// It applies to the Reader and Read methods.
//
// By default, the connection has a message read limit of 32768 bytes.
//
// When the limit is hit, the connection will be closed with StatusMessageTooBig.
//
// Set to -1 to disable. Any negative value disables the limit, since
// limitReader treats a negative budget as unlimited. The stored limit is
// loaded when the next message starts (see limitReader.reset).
func (c *Conn) SetReadLimit(n int64) {
	if n >= 0 {
		// We read one more byte than the limit in case
		// there is a fin frame that needs to be read.
		n++
	}

	c.msgReader.limitReader.limit.Store(n)
}

// defaultReadLimit is the per-message read limit used until SetReadLimit is
// called.
const defaultReadLimit = 32768

// newMsgReader creates the connection's message reader. fin starts true so
// that the first call to reader does not report an unfinished message.
func newMsgReader(c *Conn) *msgReader {
	mr := &msgReader{
		c:   c,
		fin: true,
	}
	mr.readFunc = mr.read

	mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1)
	return mr
}

// resetFlate prepares decompression for a new compressed message. With
// context takeover, a 32768-byte sliding window dictionary is
// (re)initialized and passed to the flate reader. The limitReader is then
// pointed at the flate reader, so limits apply to decompressed bytes.
func (mr *msgReader) resetFlate() {
	if mr.flateContextTakeover() {
		if mr.dict == nil {
			mr.dict = &slidingWindow{}
		}
		mr.dict.init(32768)
	}
	if mr.flateBufio == nil {
		mr.flateBufio = getBufioReader(mr.readFunc)
	}

	if mr.flateContextTakeover() {
		mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf)
	} else {
		mr.flateReader = getFlateReader(mr.flateBufio, nil)
	}
	mr.limitReader.r = mr.flateReader
	// Re-arm the trailer that read feeds to the decompressor at message end.
	mr.flateTail.Reset(deflateMessageTail)
}

// putFlateReader returns the current flate reader, if any, to its pool.
func (mr *msgReader) putFlateReader() {
	if mr.flateReader != nil {
		putFlateReader(mr.flateReader)
		mr.flateReader = nil
	}
}

// close releases pooled resources. It takes readMu with forceLock and does
// not release it, so later reads cannot proceed.
func (mr *msgReader) close() {
	mr.c.readMu.forceLock()
	mr.putFlateReader()
	if mr.dict != nil {
		mr.dict.close()
		mr.dict = nil
	}
	if mr.flateBufio != nil {
		putBufioReader(mr.flateBufio)
	}

	// Only the client owns c.br (it came from getBufioReader in dial).
	if mr.c.client {
		putBufioReader(mr.c.br)
		mr.c.br = nil
	}
}

// flateContextTakeover reports whether the peer's compressor keeps its
// context across messages, i.e. whether the read-side dictionary persists.
func (mr *msgReader) flateContextTakeover() bool {
	if mr.c.client {
		return !mr.c.copts.serverNoContextTakeover
	}
	return !mr.c.copts.clientNoContextTakeover
}

// readRSV1Illegal reports whether the RSV1 (compression) bit is not
// permitted on h.
func (c *Conn) readRSV1Illegal(h header) bool {
	// If compression is disabled, rsv1 is illegal.
	if !c.flate() {
		return true
	}
	// rsv1 is only allowed on data frames beginning messages.
	if h.opcode != opText && h.opcode != opBinary {
		return true
	}
	return false
}

// readLoop reads frame headers, handling control frames inline, until it
// finds a data or continuation frame header, which it returns.
func (c *Conn) readLoop(ctx context.Context) (header, error) {
	for {
		h, err := c.readFrameHeader(ctx)
		if err != nil {
			return header{}, err
		}

		if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 {
			err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
			c.writeError(StatusProtocolError, err)
			return header{}, err
		}

		// Clients must mask every frame they send (RFC 6455 section 5.1).
		if !c.client && !h.masked {
			return header{}, errors.New("received unmasked frame from client")
		}

		switch h.opcode {
		case opClose, opPing, opPong:
			err = c.handleControl(ctx, h)
			if err != nil {
				// Pass through CloseErrors when receiving a close frame.
				if h.opcode == opClose && CloseStatus(err) != -1 {
					return header{}, err
				}
				return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
			}
		case opContinuation, opText, opBinary:
			return h, nil
		default:
			err := fmt.Errorf("received unknown opcode %v", h.opcode)
			c.writeError(StatusProtocolError, err)
			return header{}, err
		}
	}
}

// readFrameHeader reads one frame header from c.br.
//
// ctx is handed to whatever receives on c.readTimeout before the blocking
// read, and context.Background() is handed over afterwards.
// NOTE(review): the receiver is not visible here; presumably it is a watcher
// that enforces ctx cancellation during the read. Confirm.
// Read errors that are not explained by closure or ctx cancellation close
// the connection.
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
	select {
	case <-c.closed:
		return header{}, net.ErrClosed
	case c.readTimeout <- ctx:
	}

	h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
	if err != nil {
		select {
		case <-c.closed:
			return header{}, net.ErrClosed
		case <-ctx.Done():
			return header{}, ctx.Err()
		default:
			c.close(err)
			return header{}, err
		}
	}

	select {
	case <-c.closed:
		return header{}, net.ErrClosed
	case c.readTimeout <- context.Background():
	}

	return h, nil
}

// readFramePayload fills p completely from c.br, using the same
// c.readTimeout handshake and error classification as readFrameHeader.
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
	select {
	case <-c.closed:
		return 0, net.ErrClosed
	case c.readTimeout <- ctx:
	}

	n, err := io.ReadFull(c.br, p)
	if err != nil {
		select {
		case <-c.closed:
			return n, net.ErrClosed
		case <-ctx.Done():
			return n, ctx.Err()
		default:
			err = fmt.Errorf("failed to read frame payload: %w", err)
			c.close(err)
			return n, err
		}
	}

	select {
	case <-c.closed:
		return n, net.ErrClosed
	case c.readTimeout <- context.Background():
	}

	return n, err
}

// handleControl processes one close, ping or pong frame whose header is h.
//
// Control frames must be unfragmented and at most maxControlPayload bytes.
// Handling, including any reply, is bounded by a 5 second timeout.
//   - ping: replies with a pong carrying the same payload.
//   - pong: does a non-blocking notify of the pending Ping whose payload
//     matches, if any.
//   - close: records the close error, echoes the close frame, closes the
//     connection and returns the error. The deferred assignment stores
//     whatever err the close path returns in c.readCloseFrameErr.
func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
	if h.payloadLength < 0 || h.payloadLength > maxControlPayload {
		err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength)
		c.writeError(StatusProtocolError, err)
		return err
	}

	if !h.fin {
		err := errors.New("received fragmented control frame")
		c.writeError(StatusProtocolError, err)
		return err
	}

	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
	defer cancel()

	b := c.readControlBuf[:h.payloadLength]
	_, err = c.readFramePayload(ctx, b)
	if err != nil {
		return err
	}

	if h.masked {
		mask(h.maskKey, b)
	}

	switch h.opcode {
	case opPing:
		return c.writeControl(ctx, opPong, b)
	case opPong:
		c.activePingsMu.Lock()
		pong, ok := c.activePings[string(b)]
		c.activePingsMu.Unlock()
		if ok {
			select {
			case pong <- struct{}{}:
			default:
			}
		}
		return nil
	}

	// Only opClose reaches this point.
	defer func() {
		c.readCloseFrameErr = err
	}()

	ce, err := parseClosePayload(b)
	if err != nil {
		err = fmt.Errorf("received invalid close payload: %w", err)
		c.writeError(StatusProtocolError, err)
		return err
	}

	err = fmt.Errorf("received close frame: %w", ce)
	c.setCloseErr(err)
	c.writeClose(ce.Code, ce.Reason)
	c.close(err)
	return err
}

// reader implements Conn.Reader. It holds readMu while locating the first
// frame of the next data message. A previous message that was not read to
// completion is treated as fatal and closes the connection.
func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
	defer errd.Wrap(&err, "failed to get reader")

	err = c.readMu.lock(ctx)
	if err != nil {
		return 0, nil, err
	}
	defer c.readMu.unlock()

	if !c.msgReader.fin {
		err = errors.New("previous message not read to completion")
		c.close(fmt.Errorf("failed to get reader: %w", err))
		return 0, nil, err
	}

	h, err := c.readLoop(ctx)
	if err != nil {
		return 0, nil, err
	}

	if h.opcode == opContinuation {
		err := errors.New("received continuation frame without text or binary frame")
		c.writeError(StatusProtocolError, err)
		return 0, nil, err
	}

	c.msgReader.reset(ctx, h)

	return MessageType(h.opcode), c.msgReader, nil
}

// msgReader is the io.Reader returned by Conn.Reader. It streams the
// payload of one data message across its continuation frames, optionally
// decompressing it, subject to the limitReader budget.
type msgReader struct {
	c *Conn

	ctx         context.Context
	flate       bool
	flateReader io.Reader
	flateBufio  *bufio.Reader
	flateTail   strings.Reader
	limitReader *limitReader
	dict        *slidingWindow

	// State of the frame currently being read.
	fin           bool
	payloadLength int64
	maskKey       uint32

	// readFunc caches util.ReaderFunc(mr.read) so the method value is not
	// re-allocated on every use.
	readFunc util.ReaderFunc
}

// reset prepares mr to read the message whose first frame header is h.
// RSV1 on the first frame marks the message as compressed.
func (mr *msgReader) reset(ctx context.Context, h header) {
	mr.ctx = ctx
	mr.flate = h.rsv1
	mr.limitReader.reset(mr.readFunc)

	if mr.flate {
		mr.resetFlate()
	}

	mr.setFrame(h)
}

// setFrame records the state of the frame about to be read.
func (mr *msgReader) setFrame(h header) {
	mr.fin = h.fin
	mr.payloadLength = h.payloadLength
	mr.maskKey = h.maskKey
}

// Read reads message payload bytes, returning io.EOF at the end of the
// message. Errors other than end of message close the connection.
func (mr *msgReader) Read(p []byte) (n int, err error) {
	err = mr.c.readMu.lock(mr.ctx)
	if err != nil {
		return 0, fmt.Errorf("failed to read: %w", err)
	}
	defer mr.c.readMu.unlock()

	n, err = mr.limitReader.Read(p)
	if mr.flate && mr.flateContextTakeover() {
		// Feed the decompressed bytes into the sliding window dictionary.
		p = p[:n]
		mr.dict.write(p)
	}
	// && binds tighter than ||: any io.EOF ends the message, while
	// io.ErrUnexpectedEOF only does so on the final frame of a compressed
	// message.
	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
		mr.putFlateReader()
		return n, io.EOF
	}
	if err != nil {
		err = fmt.Errorf("failed to read: %w", err)
		mr.c.close(err)
	}
	return n, err
}

// read returns raw (still compressed, if flate) payload bytes of the
// current message, moving on to continuation frames as each frame is
// exhausted.
//
// After the final frame of a compressed message, it serves flateTail so the
// decompressor sees the trailer, then EOF. Only servers unmask: frames they
// receive from clients are masked.
func (mr *msgReader) read(p []byte) (int, error) {
	for {
		if mr.payloadLength == 0 {
			if mr.fin {
				if mr.flate {
					return mr.flateTail.Read(p)
				}
				return 0, io.EOF
			}

			h, err := mr.c.readLoop(mr.ctx)
			if err != nil {
				return 0, err
			}
			if h.opcode != opContinuation {
				err := errors.New("received new data message without finishing the previous message")
				mr.c.writeError(StatusProtocolError, err)
				return 0, err
			}
			mr.setFrame(h)

			continue
		}

		if int64(len(p)) > mr.payloadLength {
			p = p[:mr.payloadLength]
		}

		// readFramePayload uses io.ReadFull, so on success n == len(p) and
		// p can be unmasked in full below.
		n, err := mr.c.readFramePayload(mr.ctx, p)
		if err != nil {
			return n, err
		}

		mr.payloadLength -= int64(n)

		if !mr.c.client {
			mr.maskKey = mask(mr.maskKey, p)
		}

		return n, nil
	}
}

// limitReader enforces the per-message read limit. limit holds the
// configured value (stored as limit+1 by SetReadLimit), and n is the
// remaining budget for the current message. A negative n means unlimited.
type limitReader struct {
	c     *Conn
	r     io.Reader
	limit xsync.Int64
	n     int64
}

func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader {
	lr := &limitReader{
		c: c,
	}
	lr.limit.Store(limit)
	lr.reset(r)
	return lr
}

// reset reloads the budget from limit and switches to reading from r.
func (lr *limitReader) reset(r io.Reader) {
	lr.n = lr.limit.Load()
	lr.r = r
}

// Read reads at most the remaining budget from lr.r. Once the budget is
// exhausted it fails the connection with StatusMessageTooBig.
func (lr *limitReader) Read(p []byte) (int, error) {
	if lr.n < 0 {
		return lr.r.Read(p)
	}

	if lr.n == 0 {
		// NOTE(review): lr.limit holds SetReadLimit's value plus one, so this
		// message reports one byte more than the user-facing limit. Left
		// unchanged because existing tests may match the current text.
		err := fmt.Errorf("read limited at %v bytes", lr.limit.Load())
		lr.c.writeError(StatusMessageTooBig, err)
		return 0, err
	}

	if int64(len(p)) > lr.n {
		p = p[:lr.n]
	}
	n, err := lr.r.Read(p)
	lr.n -= int64(n)
	if lr.n < 0 {
		lr.n = 0
	}
	return n, err
}
//go:build !js // +build !js package websocket import ( "bufio" "context" "crypto/rand" "encoding/binary" "errors" "fmt" "io" "net" "time" "compress/flate" "nhooyr.io/websocket/internal/errd" "nhooyr.io/websocket/internal/util" ) // Writer returns a writer bounded by the context that will write // a WebSocket message of type dataType to the connection. // // You must close the writer once you have written the entire message. // // Only one writer can be open at a time, multiple calls will block until the previous writer // is closed. func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { w, err := c.writer(ctx, typ) if err != nil { return nil, fmt.Errorf("failed to get writer: %w", err) } return w, nil } // Write writes a message to the connection. // // See the Writer method if you want to stream a message. // // If compression is disabled or the compression threshold is not met, then it // will write the message in a single frame. func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { _, err := c.write(ctx, typ, p) if err != nil { return fmt.Errorf("failed to write msg: %w", err) } return nil } type msgWriter struct { c *Conn mu *mu writeMu *mu closed bool ctx context.Context opcode opcode flate bool trimWriter *trimLastFourBytesWriter flateWriter *flate.Writer } func newMsgWriter(c *Conn) *msgWriter { mw := &msgWriter{ c: c, mu: newMu(c), writeMu: newMu(c), } return mw } func (mw *msgWriter) ensureFlate() { if mw.trimWriter == nil { mw.trimWriter = &trimLastFourBytesWriter{ w: util.WriterFunc(mw.write), } } if mw.flateWriter == nil { mw.flateWriter = getFlateWriter(mw.trimWriter) } mw.flate = true } func (mw *msgWriter) flateContextTakeover() bool { if mw.c.client { return !mw.c.copts.clientNoContextTakeover } return !mw.c.copts.serverNoContextTakeover } func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { err := c.msgWriter.reset(ctx, typ) if err != nil { return nil, err } 
return c.msgWriter, nil } func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) { mw, err := c.writer(ctx, typ) if err != nil { return 0, err } if !c.flate() { defer c.msgWriter.mu.unlock() return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p) } n, err := mw.Write(p) if err != nil { return n, err } err = mw.Close() return n, err } func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error { err := mw.mu.lock(ctx) if err != nil { return err } mw.ctx = ctx mw.opcode = opcode(typ) mw.flate = false mw.closed = false mw.trimWriter.reset() return nil } func (mw *msgWriter) putFlateWriter() { if mw.flateWriter != nil { putFlateWriter(mw.flateWriter) mw.flateWriter = nil } } // Write writes the given bytes to the WebSocket connection. func (mw *msgWriter) Write(p []byte) (_ int, err error) { err = mw.writeMu.lock(mw.ctx) if err != nil { return 0, fmt.Errorf("failed to write: %w", err) } defer mw.writeMu.unlock() if mw.closed { return 0, errors.New("cannot use closed writer") } defer func() { if err != nil { err = fmt.Errorf("failed to write: %w", err) mw.c.close(err) } }() if mw.c.flate() { // Only enables flate if the length crosses the // threshold on the first frame if mw.opcode != opContinuation && len(p) >= mw.c.flateThreshold { mw.ensureFlate() } } if mw.flate { return mw.flateWriter.Write(p) } return mw.write(p) } func (mw *msgWriter) write(p []byte) (int, error) { n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p) if err != nil { return n, fmt.Errorf("failed to write data frame: %w", err) } mw.opcode = opContinuation return n, nil } // Close flushes the frame to the connection. 
func (mw *msgWriter) Close() (err error) {
	// Close finishes the current message: it flushes whatever the flate writer
	// still holds (emitted as frames via mw.write), writes an empty frame with
	// FIN set, and only on success releases mw.mu so the next message can be
	// started. On any error mw.mu is not unlocked here.
	defer errd.Wrap(&err, "failed to close writer")

	err = mw.writeMu.lock(mw.ctx)
	if err != nil {
		return err
	}
	defer mw.writeMu.unlock()

	if mw.closed {
		return errors.New("writer already closed")
	}
	mw.closed = true

	if mw.flate {
		err = mw.flateWriter.Flush()
		if err != nil {
			return fmt.Errorf("failed to flush flate: %w", err)
		}
	}

	// Empty FIN frame. writeFrame only sets RSV1 for opText/opBinary, i.e.
	// when this is also the message's first frame.
	_, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
	if err != nil {
		return fmt.Errorf("failed to write fin frame: %w", err)
	}

	// Without context takeover the compressor state must not carry over to the
	// next message, so the flate writer is released.
	if mw.flate && !mw.flateContextTakeover() {
		mw.putFlateWriter()
	}
	mw.mu.unlock()
	return nil
}

// close tears down the writer side. For client connections it force-locks
// writeFrameMu and hands c.bw to putBufioWriter; it then force-locks writeMu
// and releases the flate writer. NOTE(review): presumably called once during
// connection shutdown so that no further frame or message write can start —
// confirm at the call site.
func (mw *msgWriter) close() {
	if mw.c.client {
		mw.c.writeFrameMu.forceLock()
		putBufioWriter(mw.c.bw)
	}

	mw.writeMu.forceLock()
	mw.putFlateWriter()
}

// writeControl writes a single uncompressed frame with FIN set (intended for
// control opcodes), bounded by ctx and an additional 5 second timeout.
func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
	defer cancel()

	_, err := c.writeFrame(ctx, true, false, opcode, p)
	if err != nil {
		return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
	}
	return nil
}

// writeFrame handles all writes to the connection.
//
// It serialises frames on writeFrameMu. It then fills in c.writeHeader (a fresh
// masking key for clients, RSV1 for compressed text/binary frames), writes the
// header and payload into c.bw, and flushes c.bw when fin is set.
//
// Any error after ctx has been sent on c.writeTimeout closes the connection.
// If by then the connection is closed or ctx is done, the error is replaced by
// net.ErrClosed or ctx.Err() respectively, then wrapped with
// "failed to write frame".
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
	err = c.writeFrameMu.lock(ctx)
	if err != nil {
		return 0, err
	}

	// If the state says a close has already been written, we release the lock,
	// wait until the connection is closed (or ctx is done) and return that error.
	//
	// However, if the frame being written is a close, it's the close frame that
	// set the state in the first place, so we let it go through.
	c.closeMu.Lock()
	wroteClose := c.wroteClose
	c.closeMu.Unlock()
	if wroteClose && opcode != opClose {
		c.writeFrameMu.unlock()
		select {
		case <-ctx.Done():
			return 0, ctx.Err()
		case <-c.closed:
			return 0, net.ErrClosed
		}
	}
	defer c.writeFrameMu.unlock()

	// Hand ctx to whoever receives on c.writeTimeout before writing.
	// NOTE(review): presumably a watcher that closes the connection if ctx
	// expires mid-write — confirm.
	select {
	case <-c.closed:
		return 0, net.ErrClosed
	case c.writeTimeout <- ctx:
	}

	defer func() {
		if err != nil {
			// Report the connection closing or ctx ending in preference to
			// the raw I/O error, when either has happened.
			select {
			case <-c.closed:
				err = net.ErrClosed
			case <-ctx.Done():
				err = ctx.Err()
			default:
			}
			c.close(err)
			err = fmt.Errorf("failed to write frame: %w", err)
		}
	}()

	c.writeHeader.fin = fin
	c.writeHeader.opcode = opcode
	c.writeHeader.payloadLength = int64(len(p))

	if c.client {
		// Client frames are masked with a fresh random key (RFC 6455 §5.3).
		// writeHeaderBuf is scratch space for the key bytes here, before
		// writeFrameHeader reuses it for the encoded header.
		c.writeHeader.masked = true
		_, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4])
		if err != nil {
			return 0, fmt.Errorf("failed to generate masking key: %w", err)
		}
		c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:])
	}

	// RSV1 marks a compressed message and is set only on its first
	// (text/binary) frame, never on continuation or control frames (RFC 7692).
	c.writeHeader.rsv1 = false
	if flate && (opcode == opText || opcode == opBinary) {
		c.writeHeader.rsv1 = true
	}

	err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:])
	if err != nil {
		return 0, err
	}

	n, err := c.writeFramePayload(p)
	if err != nil {
		return n, err
	}

	// Only the final frame of a message forces buffered data onto the wire.
	if c.writeHeader.fin {
		err = c.bw.Flush()
		if err != nil {
			return n, fmt.Errorf("failed to flush: %w", err)
		}
	}

	// Send context.Background() on c.writeTimeout now that the write is done
	// (NOTE(review): presumably disarms the watcher — confirm). If the
	// connection closed in the meantime, a close frame still counts as success.
	select {
	case <-c.closed:
		if opcode == opClose {
			return n, nil
		}
		return n, net.ErrClosed
	case c.writeTimeout <- context.Background():
	}

	return n, nil
}

// writeFramePayload writes p into c.bw as the current frame's payload.
//
// Unmasked payloads (writeFrame only sets masked for clients) are written
// as-is. Masked payloads are copied into the buffer one chunk at a time and
// masked in place inside the buffer, so the caller's p is never modified.
// This relies on c.writeBuf aliasing the backing array of c.bw (presumably
// obtained with extractBufioWriterBuf — confirm), so that
// c.writeBuf[i:c.bw.Buffered()] is exactly the chunk just written.
func (c *Conn) writeFramePayload(p []byte) (n int, err error) {
	defer errd.Wrap(&err, "failed to write frame payload")

	if !c.writeHeader.masked {
		return c.bw.Write(p)
	}

	maskKey := c.writeHeader.maskKey
	for len(p) > 0 {
		// If the buffer is full, we need to flush.
		if c.bw.Available() == 0 {
			err = c.bw.Flush()
			if err != nil {
				return n, err
			}
		}

		// Start of next write in the buffer.
		i := c.bw.Buffered()

		// Copy no more than fits, so bufio never flushes mid-chunk and the
		// chunk stays in the buffer where it can be masked.
		j := len(p)
		if j > c.bw.Available() {
			j = c.bw.Available()
		}

		_, err := c.bw.Write(p[:j])
		if err != nil {
			return n, err
		}

		// NOTE(review): mask presumably XORs the slice in place and returns the
		// key rotated for the next chunk — confirm.
		maskKey = mask(maskKey, c.writeBuf[i:c.bw.Buffered()])

		p = p[j:]
		n += j
	}

	return n, nil
}

// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
// and returns it.
//
// It temporarily points bw at a writer that records the slice bufio hands it
// on Flush, re-sliced to its full capacity. It forces one such flush by
// writing a single byte, then points bw back at w. bufio's Reset discards
// unflushed data, so bw should hold nothing when this is called. The
// WriteByte/Flush errors are not checked; the recording writer never fails.
func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte {
	var writeBuf []byte
	bw.Reset(util.WriterFunc(func(p2 []byte) (int, error) {
		writeBuf = p2[:cap(p2)]
		return len(p2), nil
	}))

	bw.WriteByte(0)
	bw.Flush()

	bw.Reset(w)

	return writeBuf
}

// writeError passes err to setCloseErr, writes a close with code and err's
// message via writeClose, and then closes the connection with c.close(nil).
func (c *Conn) writeError(code StatusCode, err error) {
	c.setCloseErr(err)
	c.writeClose(code, err.Error())
	c.close(nil)
}
// Package wsjson provides helpers for reading and writing JSON messages. package wsjson // import "nhooyr.io/websocket/wsjson" import ( "context" "encoding/json" "fmt" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/bpool" "nhooyr.io/websocket/internal/errd" "nhooyr.io/websocket/internal/util" ) // Read reads a JSON message from c into v. // It will reuse buffers in between calls to avoid allocations. func Read(ctx context.Context, c *websocket.Conn, v interface{}) error { return read(ctx, c, v) } func read(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { defer errd.Wrap(&err, "failed to read JSON message") _, r, err := c.Reader(ctx) if err != nil { return err } b := bpool.Get() defer bpool.Put(b) _, err = b.ReadFrom(r) if err != nil { return err } err = json.Unmarshal(b.Bytes(), v) if err != nil { c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal JSON") return fmt.Errorf("failed to unmarshal JSON: %w", err) } return nil } // Write writes the JSON message v to c. // It will reuse buffers in between calls to avoid allocations. func Write(ctx context.Context, c *websocket.Conn, v interface{}) error { return write(ctx, c, v) } func write(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { defer errd.Wrap(&err, "failed to write JSON message") // json.Marshal cannot reuse buffers between calls as it has to return // a copy of the byte slice but Encoder does as it directly writes to w. err = json.NewEncoder(util.WriterFunc(func(p []byte) (int, error) { err := c.Write(ctx, websocket.MessageText, p) if err != nil { return 0, err } return len(p), nil })).Encode(v) if err != nil { return fmt.Errorf("failed to marshal JSON: %w", err) } return nil }