From 757c1b444e72853dc1fbafb9536ca5c01c9f1227 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 12 Oct 2024 00:52:23 +0300 Subject: [PATCH] websocket: disconnect if no data received in a minute --- web/src/api/rpc.ts | 8 ++++++-- web/src/api/wsclient.ts | 20 ++++++++++++++++++++ websocket.go | 18 +++++++++++++++++- 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/web/src/api/rpc.ts b/web/src/api/rpc.ts index 5e850f3..de90b7f 100644 --- a/web/src/api/rpc.ts +++ b/web/src/api/rpc.ts @@ -47,7 +47,7 @@ export default abstract class RPCClient { resolve: (data: unknown) => void, reject: (err: Error) => void }> = new Map() - protected nextRequestID: number = 1 + #requestIDCounter: number = 1 protected abstract isConnected: boolean protected abstract send(data: string): void @@ -83,6 +83,10 @@ export default abstract class RPCClient { ) } + protected get nextRequestID(): number { + return this.#requestIDCounter++ + } + request(command: string, data: Req): CancellablePromise { if (!this.isConnected) { return new CancellablePromise((_resolve, reject) => { @@ -90,7 +94,7 @@ export default abstract class RPCClient { }, () => { }) } - const request_id = this.nextRequestID++ + const request_id = this.nextRequestID return new CancellablePromise((resolve, reject) => { if (!this.isConnected) { reject(new Error("Websocket not connected")) diff --git a/web/src/api/wsclient.ts b/web/src/api/wsclient.ts index 1847090..91701e9 100644 --- a/web/src/api/wsclient.ts +++ b/web/src/api/wsclient.ts @@ -16,8 +16,13 @@ import RPCClient from "./rpc.ts" import type { RPCCommand } from "./types" +const PING_INTERVAL = 15_000 +const RECV_TIMEOUT = 4 * PING_INTERVAL + export default class WSClient extends RPCClient { #conn: WebSocket | null = null + #lastMessage: number = 0 + #pingInterval: number | null = null constructor(readonly addr: string) { super() @@ -25,18 +30,32 @@ export default class WSClient extends RPCClient { start() { try { + this.#lastMessage = Date.now() console.info("Connecting to websocket", this.addr) this.#conn = new WebSocket(this.addr) this.#conn.onmessage = this.#onMessage this.#conn.onopen = this.#onOpen this.#conn.onerror = this.#onError this.#conn.onclose = this.#onClose + this.#pingInterval = setInterval(this.#pingLoop, PING_INTERVAL) } catch (err) { this.#dispatchConnectionStatus(false, err as Error) } } + #pingLoop = () => { + if (Date.now() - this.#lastMessage > RECV_TIMEOUT) { + console.warn("Websocket ping timeout, last message at", this.#lastMessage) + this.#conn?.close(4002, "Ping timeout") + return + } + this.send(JSON.stringify({ command: "ping", request_id: this.nextRequestID })) + } + stop() { + if (this.#pingInterval !== null) { + clearInterval(this.#pingInterval) + } this.#conn?.close(1000, "Client closed") } @@ -52,6 +71,7 @@ export default class WSClient extends RPCClient { } #onMessage = (ev: MessageEvent) => { + this.#lastMessage = Date.now() let parsed: RPCCommand try { parsed = JSON.parse(ev.data) diff --git a/websocket.go b/websocket.go index f768f57..752e748 100644 --- a/websocket.go +++ b/websocket.go @@ -23,6 +23,7 @@ import ( "net/http" "runtime/debug" "sync" + "sync/atomic" "time" "github.com/coder/websocket" @@ -45,7 +46,10 @@ func writeCmd(ctx context.Context, conn *websocket.Conn, cmd *hicli.JSONCommand) return writer.Close() } -const StatusEventsStuck = 4001 +const ( + StatusEventsStuck = 4001 + StatusPingTimeout = 4002 +) func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Sec-Fetch-Mode") != "websocket" { @@ -115,9 +119,14 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) { } }) + lastDataReceived := &atomic.Int64{} + lastDataReceived.Store(time.Now().UnixMilli()) + const RecvTimeout = 60 * time.Second go func() { defer recoverPanic("event loop") defer closeOnce.Do(forceClose) + ticker := time.NewTicker(60 * time.Second) + defer ticker.Stop() ctxDone := ctx.Done() for { select { @@ -129,6 +138,12 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) { } else { log.Trace().Int64("req_id", cmd.RequestID).Msg("Sent outgoing event") } + case <-ticker.C: + if time.Now().UnixMilli()-lastDataReceived.Load() > RecvTimeout.Milliseconds() { + log.Warn().Msg("No data received in a minute, closing connection") + _ = conn.Close(StatusPingTimeout, "Ping timeout") + return + } case <-ctxDone: return } @@ -191,6 +206,7 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) { _ = conn.Close(websocket.StatusUnsupportedData, "Non-text message") return } + lastDataReceived.Store(time.Now().UnixMilli()) var cmd hicli.JSONCommand err = json.NewDecoder(reader).Decode(&cmd) if err != nil {