forked from Mirrors/gomuks
websocket: disconnect if no data received in a minute
This commit is contained in:
parent
6bb1d4477c
commit
757c1b444e
3 changed files with 43 additions and 3 deletions
|
@ -47,7 +47,7 @@ export default abstract class RPCClient {
|
|||
resolve: (data: unknown) => void,
|
||||
reject: (err: Error) => void
|
||||
}> = new Map()
|
||||
protected nextRequestID: number = 1
|
||||
#requestIDCounter: number = 1
|
||||
|
||||
protected abstract isConnected: boolean
|
||||
protected abstract send(data: string): void
|
||||
|
@ -83,6 +83,10 @@ export default abstract class RPCClient {
|
|||
)
|
||||
}
|
||||
|
||||
protected get nextRequestID(): number {
|
||||
return this.#requestIDCounter++
|
||||
}
|
||||
|
||||
request<Req, Resp>(command: string, data: Req): CancellablePromise<Resp> {
|
||||
if (!this.isConnected) {
|
||||
return new CancellablePromise((_resolve, reject) => {
|
||||
|
@ -90,7 +94,7 @@ export default abstract class RPCClient {
|
|||
}, () => {
|
||||
})
|
||||
}
|
||||
const request_id = this.nextRequestID++
|
||||
const request_id = this.nextRequestID
|
||||
return new CancellablePromise((resolve, reject) => {
|
||||
if (!this.isConnected) {
|
||||
reject(new Error("Websocket not connected"))
|
||||
|
|
|
@ -16,8 +16,13 @@
|
|||
import RPCClient from "./rpc.ts"
|
||||
import type { RPCCommand } from "./types"
|
||||
|
||||
const PING_INTERVAL = 15_000
|
||||
const RECV_TIMEOUT = 4 * PING_INTERVAL
|
||||
|
||||
export default class WSClient extends RPCClient {
|
||||
#conn: WebSocket | null = null
|
||||
#lastMessage: number = 0
|
||||
#pingInterval: number | null = null
|
||||
|
||||
constructor(readonly addr: string) {
|
||||
super()
|
||||
|
@ -25,18 +30,32 @@ export default class WSClient extends RPCClient {
|
|||
|
||||
start() {
|
||||
try {
|
||||
this.#lastMessage = Date.now()
|
||||
console.info("Connecting to websocket", this.addr)
|
||||
this.#conn = new WebSocket(this.addr)
|
||||
this.#conn.onmessage = this.#onMessage
|
||||
this.#conn.onopen = this.#onOpen
|
||||
this.#conn.onerror = this.#onError
|
||||
this.#conn.onclose = this.#onClose
|
||||
this.#pingInterval = setInterval(this.#pingLoop, PING_INTERVAL)
|
||||
} catch (err) {
|
||||
this.#dispatchConnectionStatus(false, err as Error)
|
||||
}
|
||||
}
|
||||
|
||||
#pingLoop = () => {
|
||||
if (Date.now() - this.#lastMessage > RECV_TIMEOUT) {
|
||||
console.warn("Websocket ping timeout, last message at", this.#lastMessage)
|
||||
this.#conn?.close(4002, "Ping timeout")
|
||||
return
|
||||
}
|
||||
this.send(JSON.stringify({ command: "ping", request_id: this.nextRequestID }))
|
||||
}
|
||||
|
||||
stop() {
|
||||
if (this.#pingInterval !== null) {
|
||||
clearInterval(this.#pingInterval)
|
||||
}
|
||||
this.#conn?.close(1000, "Client closed")
|
||||
}
|
||||
|
||||
|
@ -52,6 +71,7 @@ export default class WSClient extends RPCClient {
|
|||
}
|
||||
|
||||
#onMessage = (ev: MessageEvent) => {
|
||||
this.#lastMessage = Date.now()
|
||||
let parsed: RPCCommand<unknown>
|
||||
try {
|
||||
parsed = JSON.parse(ev.data)
|
||||
|
|
18
websocket.go
18
websocket.go
|
@ -23,6 +23,7 @@ import (
|
|||
"net/http"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
|
@ -45,7 +46,10 @@ func writeCmd(ctx context.Context, conn *websocket.Conn, cmd *hicli.JSONCommand)
|
|||
return writer.Close()
|
||||
}
|
||||
|
||||
const StatusEventsStuck = 4001
|
||||
const (
|
||||
StatusEventsStuck = 4001
|
||||
StatusPingTimeout = 4002
|
||||
)
|
||||
|
||||
func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Sec-Fetch-Mode") != "websocket" {
|
||||
|
@ -115,9 +119,14 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
})
|
||||
|
||||
lastDataReceived := &atomic.Int64{}
|
||||
lastDataReceived.Store(time.Now().UnixMilli())
|
||||
const RecvTimeout = 60 * time.Second
|
||||
go func() {
|
||||
defer recoverPanic("event loop")
|
||||
defer closeOnce.Do(forceClose)
|
||||
ticker := time.NewTicker(60 * time.Second)
|
||||
defer ticker.Stop()
|
||||
ctxDone := ctx.Done()
|
||||
for {
|
||||
select {
|
||||
|
@ -129,6 +138,12 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
|||
} else {
|
||||
log.Trace().Int64("req_id", cmd.RequestID).Msg("Sent outgoing event")
|
||||
}
|
||||
case <-ticker.C:
|
||||
if time.Now().UnixMilli()-lastDataReceived.Load() > RecvTimeout.Milliseconds() {
|
||||
log.Warn().Msg("No data received in a minute, closing connection")
|
||||
_ = conn.Close(StatusPingTimeout, "Ping timeout")
|
||||
return
|
||||
}
|
||||
case <-ctxDone:
|
||||
return
|
||||
}
|
||||
|
@ -191,6 +206,7 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
|||
_ = conn.Close(websocket.StatusUnsupportedData, "Non-text message")
|
||||
return
|
||||
}
|
||||
lastDataReceived.Store(time.Now().UnixMilli())
|
||||
var cmd hicli.JSONCommand
|
||||
err = json.NewDecoder(reader).Decode(&cmd)
|
||||
if err != nil {
|
||||
|
|
Loading…
Add table
Reference in a new issue