websocket: disconnect if no data received in a minute

This commit is contained in:
Tulir Asokan 2024-10-12 00:52:23 +03:00
parent 6bb1d4477c
commit 757c1b444e
3 changed files with 43 additions and 3 deletions

View file

@ -47,7 +47,7 @@ export default abstract class RPCClient {
resolve: (data: unknown) => void, resolve: (data: unknown) => void,
reject: (err: Error) => void reject: (err: Error) => void
}> = new Map() }> = new Map()
protected nextRequestID: number = 1 #requestIDCounter: number = 1
protected abstract isConnected: boolean protected abstract isConnected: boolean
protected abstract send(data: string): void protected abstract send(data: string): void
@ -83,6 +83,10 @@ export default abstract class RPCClient {
) )
} }
protected get nextRequestID(): number {
return this.#requestIDCounter++
}
request<Req, Resp>(command: string, data: Req): CancellablePromise<Resp> { request<Req, Resp>(command: string, data: Req): CancellablePromise<Resp> {
if (!this.isConnected) { if (!this.isConnected) {
return new CancellablePromise((_resolve, reject) => { return new CancellablePromise((_resolve, reject) => {
@ -90,7 +94,7 @@ export default abstract class RPCClient {
}, () => { }, () => {
}) })
} }
const request_id = this.nextRequestID++ const request_id = this.nextRequestID
return new CancellablePromise((resolve, reject) => { return new CancellablePromise((resolve, reject) => {
if (!this.isConnected) { if (!this.isConnected) {
reject(new Error("Websocket not connected")) reject(new Error("Websocket not connected"))

View file

@ -16,8 +16,13 @@
import RPCClient from "./rpc.ts" import RPCClient from "./rpc.ts"
import type { RPCCommand } from "./types" import type { RPCCommand } from "./types"
const PING_INTERVAL = 15_000
const RECV_TIMEOUT = 4 * PING_INTERVAL
export default class WSClient extends RPCClient { export default class WSClient extends RPCClient {
#conn: WebSocket | null = null #conn: WebSocket | null = null
#lastMessage: number = 0
#pingInterval: number | null = null
constructor(readonly addr: string) { constructor(readonly addr: string) {
super() super()
@ -25,18 +30,32 @@ export default class WSClient extends RPCClient {
start() { start() {
try { try {
this.#lastMessage = Date.now()
console.info("Connecting to websocket", this.addr) console.info("Connecting to websocket", this.addr)
this.#conn = new WebSocket(this.addr) this.#conn = new WebSocket(this.addr)
this.#conn.onmessage = this.#onMessage this.#conn.onmessage = this.#onMessage
this.#conn.onopen = this.#onOpen this.#conn.onopen = this.#onOpen
this.#conn.onerror = this.#onError this.#conn.onerror = this.#onError
this.#conn.onclose = this.#onClose this.#conn.onclose = this.#onClose
this.#pingInterval = setInterval(this.#pingLoop, PING_INTERVAL)
} catch (err) { } catch (err) {
this.#dispatchConnectionStatus(false, err as Error) this.#dispatchConnectionStatus(false, err as Error)
} }
} }
#pingLoop = () => {
if (Date.now() - this.#lastMessage > RECV_TIMEOUT) {
console.warn("Websocket ping timeout, last message at", this.#lastMessage)
this.#conn?.close(4002, "Ping timeout")
return
}
this.send(JSON.stringify({ command: "ping", request_id: this.nextRequestID }))
}
stop() { stop() {
if (this.#pingInterval !== null) {
clearInterval(this.#pingInterval)
}
this.#conn?.close(1000, "Client closed") this.#conn?.close(1000, "Client closed")
} }
@ -52,6 +71,7 @@ export default class WSClient extends RPCClient {
} }
#onMessage = (ev: MessageEvent) => { #onMessage = (ev: MessageEvent) => {
this.#lastMessage = Date.now()
let parsed: RPCCommand<unknown> let parsed: RPCCommand<unknown>
try { try {
parsed = JSON.parse(ev.data) parsed = JSON.parse(ev.data)

View file

@ -23,6 +23,7 @@ import (
"net/http" "net/http"
"runtime/debug" "runtime/debug"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/coder/websocket" "github.com/coder/websocket"
@ -45,7 +46,10 @@ func writeCmd(ctx context.Context, conn *websocket.Conn, cmd *hicli.JSONCommand)
return writer.Close() return writer.Close()
} }
const StatusEventsStuck = 4001 const (
StatusEventsStuck = 4001
StatusPingTimeout = 4002
)
func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) { func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Sec-Fetch-Mode") != "websocket" { if r.Header.Get("Sec-Fetch-Mode") != "websocket" {
@ -115,9 +119,14 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
} }
}) })
lastDataReceived := &atomic.Int64{}
lastDataReceived.Store(time.Now().UnixMilli())
const RecvTimeout = 60 * time.Second
go func() { go func() {
defer recoverPanic("event loop") defer recoverPanic("event loop")
defer closeOnce.Do(forceClose) defer closeOnce.Do(forceClose)
ticker := time.NewTicker(60 * time.Second)
defer ticker.Stop()
ctxDone := ctx.Done() ctxDone := ctx.Done()
for { for {
select { select {
@ -129,6 +138,12 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
} else { } else {
log.Trace().Int64("req_id", cmd.RequestID).Msg("Sent outgoing event") log.Trace().Int64("req_id", cmd.RequestID).Msg("Sent outgoing event")
} }
case <-ticker.C:
if time.Now().UnixMilli()-lastDataReceived.Load() > RecvTimeout.Milliseconds() {
log.Warn().Msg("No data received in a minute, closing connection")
_ = conn.Close(StatusPingTimeout, "Ping timeout")
return
}
case <-ctxDone: case <-ctxDone:
return return
} }
@ -191,6 +206,7 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
_ = conn.Close(websocket.StatusUnsupportedData, "Non-text message") _ = conn.Close(websocket.StatusUnsupportedData, "Non-text message")
return return
} }
lastDataReceived.Store(time.Now().UnixMilli())
var cmd hicli.JSONCommand var cmd hicli.JSONCommand
err = json.NewDecoder(reader).Decode(&cmd) err = json.NewDecoder(reader).Decode(&cmd)
if err != nil { if err != nil {