websocket: add support for resuming sessions

This commit is contained in:
Tulir Asokan 2024-12-06 23:51:47 +02:00
parent 0930e94fb2
commit 7d6bbe77b9
13 changed files with 288 additions and 72 deletions

View file

@ -152,7 +152,7 @@ func main() {
URL: "/",
})
gmx.SubscribeEvents(nil, func(command *hicli.JSONCommand) {
gmx.EventBuffer.Subscribe(0, nil, func(command *hicli.JSONCommand) {
app.EmitEvent("hicli_event", command)
})

157
pkg/gomuks/buffer.go Normal file
View file

@ -0,0 +1,157 @@
// gomuks - A Matrix client written in Go.
// Copyright (C) 2024 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package gomuks
import (
"encoding/json"
"fmt"
"maps"
"slices"
"sync"
"github.com/coder/websocket"
"go.mau.fi/gomuks/pkg/hicli"
)
type WebsocketCloseFunc func(websocket.StatusCode, string)
type EventBuffer struct {
lock sync.RWMutex
buf []*hicli.JSONCommand
minID int64
maxID int64
MaxSize int
websocketClosers map[uint64]WebsocketCloseFunc
lastAckedID map[uint64]int64
eventListeners map[uint64]func(*hicli.JSONCommand)
nextListenerID uint64
}
func NewEventBuffer(maxSize int) *EventBuffer {
return &EventBuffer{
websocketClosers: make(map[uint64]WebsocketCloseFunc),
lastAckedID: make(map[uint64]int64),
eventListeners: make(map[uint64]func(*hicli.JSONCommand)),
buf: make([]*hicli.JSONCommand, 0, maxSize*2),
MaxSize: maxSize,
minID: -1,
}
}
func (eb *EventBuffer) HicliEventHandler(evt any) {
data, err := json.Marshal(evt)
if err != nil {
panic(fmt.Errorf("failed to marshal event %T: %w", evt, err))
}
allowCache := true
if syncComplete, ok := evt.(*hicli.SyncComplete); ok && syncComplete.Since != nil && *syncComplete.Since == "" {
// Don't cache initial sync responses
allowCache = false
} else if _, ok := evt.(*hicli.Typing); ok {
// Also don't cache typing events
allowCache = false
}
eb.lock.Lock()
defer eb.lock.Unlock()
jc := &hicli.JSONCommand{
Command: hicli.EventTypeName(evt),
Data: data,
}
if allowCache {
eb.addToBuffer(jc)
}
for _, listener := range eb.eventListeners {
listener(jc)
}
}
func (eb *EventBuffer) GetClosers() []WebsocketCloseFunc {
eb.lock.Lock()
defer eb.lock.Unlock()
return slices.Collect(maps.Values(eb.websocketClosers))
}
func (eb *EventBuffer) Unsubscribe(listenerID uint64) {
eb.lock.Lock()
defer eb.lock.Unlock()
delete(eb.eventListeners, listenerID)
delete(eb.websocketClosers, listenerID)
}
func (eb *EventBuffer) addToBuffer(evt *hicli.JSONCommand) {
eb.maxID--
evt.RequestID = eb.maxID
if len(eb.lastAckedID) > 0 {
eb.buf = append(eb.buf, evt)
} else {
eb.minID = eb.maxID - 1
}
if len(eb.buf) > eb.MaxSize {
eb.buf = eb.buf[len(eb.buf)-eb.MaxSize:]
eb.minID = eb.buf[0].RequestID
}
}
func (eb *EventBuffer) ClearListenerLastAckedID(listenerID uint64) {
eb.lock.Lock()
defer eb.lock.Unlock()
delete(eb.lastAckedID, listenerID)
eb.gc()
}
func (eb *EventBuffer) SetLastAckedID(listenerID uint64, ackedID int64) {
eb.lock.Lock()
defer eb.lock.Unlock()
eb.lastAckedID[listenerID] = ackedID
eb.gc()
}
func (eb *EventBuffer) gc() {
neededMinID := eb.maxID
for lid, evtID := range eb.lastAckedID {
if evtID > eb.minID {
delete(eb.lastAckedID, lid)
} else if evtID > neededMinID {
neededMinID = evtID
}
}
if neededMinID < eb.minID {
eb.buf = eb.buf[eb.minID-neededMinID:]
eb.minID = neededMinID
}
}
func (eb *EventBuffer) Subscribe(resumeFrom int64, closeForRestart WebsocketCloseFunc, cb func(*hicli.JSONCommand)) (uint64, []*hicli.JSONCommand) {
eb.lock.Lock()
defer eb.lock.Unlock()
eb.nextListenerID++
id := eb.nextListenerID
eb.eventListeners[id] = cb
if closeForRestart != nil {
eb.websocketClosers[id] = closeForRestart
}
var resumeData []*hicli.JSONCommand
if resumeFrom < eb.minID {
resumeData = eb.buf[eb.minID-resumeFrom+1:]
eb.lastAckedID[id] = resumeFrom
} else {
eb.lastAckedID[id] = eb.maxID
}
return id, resumeData
}

View file

@ -20,13 +20,11 @@ import (
"context"
"embed"
"fmt"
"maps"
"net/http"
"os"
"os/signal"
"path/filepath"
"runtime"
"slices"
"sync"
"syscall"
"time"
@ -65,17 +63,13 @@ type Gomuks struct {
stopOnce sync.Once
stopChan chan struct{}
websocketClosers map[uint64]WebsocketCloseFunc
eventListeners map[uint64]func(*hicli.JSONCommand)
nextListenerID uint64
eventListenersLock sync.RWMutex
EventBuffer *EventBuffer
}
func NewGomuks() *Gomuks {
return &Gomuks{
stopChan: make(chan struct{}),
eventListeners: make(map[uint64]func(*hicli.JSONCommand)),
websocketClosers: make(map[uint64]WebsocketCloseFunc),
EventBuffer: NewEventBuffer(512),
}
}
@ -176,7 +170,7 @@ func (gmx *Gomuks) StartClient() {
nil,
gmx.Log.With().Str("component", "hicli").Logger(),
[]byte("meow"),
hicli.JSONEventHandler(gmx.OnEvent).HandleEvent,
gmx.EventBuffer.HicliEventHandler,
)
gmx.Client.LogoutFunc = gmx.Logout
httpClient := gmx.Client.Client.Client
@ -218,10 +212,7 @@ func (gmx *Gomuks) WaitForInterrupt() {
}
func (gmx *Gomuks) DirectStop() {
gmx.eventListenersLock.Lock()
closers := slices.Collect(maps.Values(gmx.websocketClosers))
gmx.eventListenersLock.Unlock()
for _, closer := range closers {
for _, closer := range gmx.EventBuffer.GetClosers() {
closer(websocket.StatusServiceRestart, "Server shutting down")
}
gmx.Client.Stop()
@ -231,33 +222,6 @@ func (gmx *Gomuks) DirectStop() {
}
}
func (gmx *Gomuks) OnEvent(evt *hicli.JSONCommand) {
gmx.eventListenersLock.RLock()
defer gmx.eventListenersLock.RUnlock()
for _, listener := range gmx.eventListeners {
listener(evt)
}
}
type WebsocketCloseFunc func(websocket.StatusCode, string)
func (gmx *Gomuks) SubscribeEvents(closeForRestart WebsocketCloseFunc, cb func(command *hicli.JSONCommand)) func() {
gmx.eventListenersLock.Lock()
defer gmx.eventListenersLock.Unlock()
gmx.nextListenerID++
id := gmx.nextListenerID
gmx.eventListeners[id] = cb
if closeForRestart != nil {
gmx.websocketClosers[id] = closeForRestart
}
return func() {
gmx.eventListenersLock.Lock()
defer gmx.eventListenersLock.Unlock()
delete(gmx.eventListeners, id)
delete(gmx.websocketClosers, id)
}
}
func (gmx *Gomuks) Run() {
gmx.InitDirectories()
err := gmx.LoadConfig()

View file

@ -22,6 +22,7 @@ import (
"errors"
"net/http"
"runtime/debug"
"strconv"
"sync"
"sync/atomic"
"time"
@ -52,6 +53,12 @@ const (
var emptyObject = json.RawMessage("{}")
type PingRequestData struct {
LastReceivedID int64 `json:"last_received_id"`
}
var runID = time.Now().UnixNano()
func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
var conn *websocket.Conn
log := zerolog.Ctx(r.Context())
@ -80,15 +87,23 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
log.Warn().Err(acceptErr).Msg("Failed to accept websocket connection")
return
}
log.Info().Msg("Accepted new websocket connection")
resumeFrom, _ := strconv.ParseInt(r.URL.Query().Get("last_received_event"), 10, 64)
resumeRunID, _ := strconv.ParseInt(r.URL.Query().Get("run_id"), 10, 64)
log.Info().
Int64("resume_from", resumeFrom).
Int64("resume_run_id", resumeRunID).
Int64("current_run_id", runID).
Msg("Accepted new websocket connection")
conn.SetReadLimit(128 * 1024)
ctx, cancel := context.WithCancel(context.Background())
ctx = log.WithContext(ctx)
unsubscribe := func() {}
var listenerID uint64
evts := make(chan *hicli.JSONCommand, 32)
forceClose := func() {
cancel()
unsubscribe()
if listenerID != 0 {
gmx.EventBuffer.Unsubscribe(listenerID)
}
_ = conn.CloseNow()
close(evts)
}
@ -99,7 +114,11 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
_ = conn.Close(statusCode, reason)
closeOnce.Do(forceClose)
}
unsubscribe = gmx.SubscribeEvents(closeManually, func(evt *hicli.JSONCommand) {
if resumeRunID != runID {
resumeFrom = 0
}
var resumeData []*hicli.JSONCommand
listenerID, resumeData = gmx.EventBuffer.Subscribe(resumeFrom, closeManually, func(evt *hicli.JSONCommand) {
if ctx.Err() != nil {
return
}
@ -115,6 +134,7 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
}()
}
})
didResume := resumeData != nil
lastDataReceived := &atomic.Int64{}
lastDataReceived.Store(time.Now().UnixMilli())
@ -133,6 +153,16 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
go func() {
defer recoverPanic("event loop")
defer closeOnce.Do(forceClose)
for _, cmd := range resumeData {
err := writeCmd(ctx, conn, cmd)
if err != nil {
log.Err(err).Int64("req_id", cmd.RequestID).Msg("Failed to write outgoing event from resume data")
return
} else {
log.Trace().Int64("req_id", cmd.RequestID).Msg("Sent outgoing event from resume data")
}
}
resumeData = nil
ticker := time.NewTicker(60 * time.Second)
defer ticker.Stop()
ctxDone := ctx.Done()
@ -176,7 +206,22 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
Str("command", cmd.Command).
RawJSON("data", cmd.Data).
Msg("Received command")
resp := gmx.Client.SubmitJSONCommand(ctx, cmd)
var resp *hicli.JSONCommand
if cmd.Command == "ping" {
resp = &hicli.JSONCommand{
Command: "pong",
RequestID: cmd.RequestID,
}
var pingData PingRequestData
err := json.Unmarshal(cmd.Data, &pingData)
if err != nil {
log.Err(err).Msg("Failed to parse ping data")
} else if pingData.LastReceivedID != 0 {
gmx.EventBuffer.SetLastAckedID(listenerID, pingData.LastReceivedID)
}
} else {
resp = gmx.Client.SubmitJSONCommand(ctx, cmd)
}
if ctx.Err() != nil {
return
}
@ -188,7 +233,15 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
log.Trace().Int64("req_id", cmd.RequestID).Msg("Sent response to command")
}
}
initErr := writeCmd(ctx, conn, &hicli.JSONCommandCustom[*hicli.ClientState]{
initErr := writeCmd(ctx, conn, &hicli.JSONCommandCustom[string]{
Command: "run_id",
Data: strconv.FormatInt(runID, 10),
})
if initErr != nil {
log.Err(initErr).Msg("Failed to write init client state message")
return
}
initErr = writeCmd(ctx, conn, &hicli.JSONCommandCustom[*hicli.ClientState]{
Command: "client_state",
Data: gmx.Client.State(),
})
@ -205,10 +258,10 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
return
}
go sendImageAuthToken()
if gmx.Client.IsLoggedIn() {
if gmx.Client.IsLoggedIn() && !didResume {
go gmx.sendInitialData(ctx, conn)
}
log.Debug().Msg("Connection initialization complete")
log.Debug().Bool("did_resume", didResume).Msg("Connection initialization complete")
var closeErr websocket.CloseError
for {
msgType, reader, err := conn.Reader(ctx)
@ -218,6 +271,9 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
Stringer("status_code", closeErr.Code).
Str("reason", closeErr.Reason).
Msg("Connection closed")
if closeErr.Code == websocket.StatusGoingAway {
gmx.EventBuffer.ClearListenerLastAckedID(listenerID)
}
} else {
log.Err(err).Msg("Failed to read message")
}

View file

@ -30,6 +30,8 @@ type SyncNotification struct {
}
type SyncComplete struct {
Since *string `json:"since,omitempty"`
ClearState bool `json:"clear_state,omitempty"`
Rooms map[id.RoomID]*SyncRoom `json:"rooms"`
AccountData map[event.Type]*database.AccountData `json:"account_data"`
LeftRooms []id.RoomID `json:"left_rooms"`

View file

@ -76,6 +76,9 @@ func (h *HiClient) GetInitialSync(ctx context.Context, batchSize int) iter.Seq[*
LeftRooms: make([]id.RoomID, 0),
AccountData: make(map[event.Type]*database.AccountData),
}
if i == 0 {
payload.ClearState = true
}
for _, room := range rooms {
if room.SortingTimestamp == rooms[len(rooms)-1].SortingTimestamp {
break

View file

@ -76,12 +76,6 @@ func (h *HiClient) dispatchCurrentState() {
}
func (h *HiClient) SubmitJSONCommand(ctx context.Context, req *JSONCommand) *JSONCommand {
if req.Command == "ping" {
return &JSONCommand{
Command: "pong",
RequestID: req.RequestID,
}
}
log := h.Log.With().Int64("request_id", req.RequestID).Str("command", req.Command).Logger()
ctx, cancel := context.WithCancelCause(ctx)
defer func() {

View file

@ -29,6 +29,7 @@ func (h *hiSyncer) ProcessResponse(ctx context.Context, resp *mautrix.RespSync,
c := (*HiClient)(h)
c.lastSync = time.Now()
ctx = context.WithValue(ctx, syncContextKey, &syncContext{evt: &SyncComplete{
Since: &since,
Rooms: make(map[id.RoomID]*SyncRoom, len(resp.Rooms.Join)),
LeftRooms: make([]id.RoomID, 0, len(resp.Rooms.Leave)),
}})

View file

@ -324,12 +324,12 @@ export default class Client {
this.initComplete.emit(false)
this.syncStatus.emit({ type: "waiting", error_count: 0 })
this.state.clearCache()
localStorage.clear()
this.store.clear()
}
async logout() {
await this.rpc.logout()
this.clearState()
localStorage.clear()
}
}

View file

@ -75,7 +75,7 @@ export default abstract class RPCClient {
public abstract start(): void
public abstract stop(): void
protected onCommand(data: RPCCommand<unknown>) {
protected onCommand(data: RPCCommand) {
if (data.command === "response" || data.command === "error") {
const target = this.pendingRequests.get(data.request_id)
if (!target) {

View file

@ -154,6 +154,10 @@ export class StateStore {
}
applySync(sync: SyncCompleteData) {
if (sync.clear_state && this.rooms.size > 0) {
console.info("Clearing state store as sync told to reset and there are rooms in the store")
this.clear()
}
const resyncRoomList = this.roomList.current.length === 0
const changedRoomListEntries = new Map<RoomID, RoomListEntry | null>()
for (const [roomID, data] of Object.entries(sync.rooms)) {

View file

@ -28,7 +28,7 @@ import {
UserID,
} from "./mxtypes.ts"
export interface RPCCommand<T> {
interface BaseRPCCommand<T> {
command: string
request_id: number
data: T
@ -39,7 +39,7 @@ export interface TypingEventData {
user_ids: UserID[]
}
export interface TypingEvent extends RPCCommand<TypingEventData> {
export interface TypingEvent extends BaseRPCCommand<TypingEventData> {
command: "typing"
}
@ -48,7 +48,7 @@ export interface SendCompleteData {
error: string | null
}
export interface SendCompleteEvent extends RPCCommand<SendCompleteData> {
export interface SendCompleteEvent extends BaseRPCCommand<SendCompleteData> {
command: "send_complete"
}
@ -58,11 +58,11 @@ export interface EventsDecryptedData {
events: RawDBEvent[]
}
export interface EventsDecryptedEvent extends RPCCommand<EventsDecryptedData> {
export interface EventsDecryptedEvent extends BaseRPCCommand<EventsDecryptedData> {
command: "events_decrypted"
}
export interface ImageAuthTokenEvent extends RPCCommand<string> {
export interface ImageAuthTokenEvent extends BaseRPCCommand<string> {
command: "image_auth_token"
}
@ -85,9 +85,11 @@ export interface SyncCompleteData {
rooms: Record<RoomID, SyncRoom>
left_rooms: RoomID[]
account_data: Record<EventType, DBAccountData>
since?: string
clear_state?: boolean
}
export interface SyncCompleteEvent extends RPCCommand<SyncCompleteData> {
export interface SyncCompleteEvent extends BaseRPCCommand<SyncCompleteData> {
command: "sync_complete"
}
@ -103,7 +105,7 @@ export type ClientState = {
homeserver_url: string
}
export interface ClientStateEvent extends RPCCommand<ClientState> {
export interface ClientStateEvent extends BaseRPCCommand<ClientState> {
command: "client_state"
}
@ -114,14 +116,26 @@ export interface SyncStatus {
last_sync?: number
}
export interface SyncStatusEvent extends RPCCommand<SyncStatus> {
export interface SyncStatusEvent extends BaseRPCCommand<SyncStatus> {
command: "sync_status"
}
export interface InitCompleteEvent extends RPCCommand<void> {
export interface InitCompleteEvent extends BaseRPCCommand<void> {
command: "init_complete"
}
export interface RunIDEvent extends BaseRPCCommand<string> {
command: "run_id"
}
export interface ResponseCommand extends BaseRPCCommand<unknown> {
command: "response"
}
export interface ErrorCommand extends BaseRPCCommand<unknown> {
command: "error"
}
export type RPCEvent =
ClientStateEvent |
SyncStatusEvent |
@ -130,4 +144,7 @@ export type RPCEvent =
EventsDecryptedEvent |
SyncCompleteEvent |
ImageAuthTokenEvent |
InitCompleteEvent
InitCompleteEvent |
RunIDEvent
export type RPCCommand = RPCEvent | ResponseCommand | ErrorCommand

View file

@ -23,6 +23,8 @@ export default class WSClient extends RPCClient {
#conn: WebSocket | null = null
#lastMessage: number = 0
#pingInterval: number | null = null
#lastReceivedEvt: number = 0
#resumeRunID: string = ""
constructor(readonly addr: string) {
super()
@ -31,8 +33,13 @@ export default class WSClient extends RPCClient {
start() {
try {
this.#lastMessage = Date.now()
console.info("Connecting to websocket", this.addr)
this.#conn = new WebSocket(this.addr)
const params = new URLSearchParams({
run_id: this.#resumeRunID,
last_received_event: this.#lastReceivedEvt.toString(),
}).toString()
const addr = this.#lastReceivedEvt && this.#resumeRunID ? `${this.addr}?${params}` : this.addr
console.info("Connecting to websocket", addr)
this.#conn = new WebSocket(addr)
this.#conn.onmessage = this.#onMessage
this.#conn.onopen = this.#onOpen
this.#conn.onerror = this.#onError
@ -49,7 +56,13 @@ export default class WSClient extends RPCClient {
this.#conn?.close(4002, "Ping timeout")
return
}
this.send(JSON.stringify({ command: "ping", request_id: this.nextRequestID }))
this.send(JSON.stringify({
command: "ping",
data: {
last_received_id: this.#lastReceivedEvt,
},
request_id: this.nextRequestID,
}))
}
stop() {
@ -72,7 +85,7 @@ export default class WSClient extends RPCClient {
#onMessage = (ev: MessageEvent) => {
this.#lastMessage = Date.now()
let parsed: RPCCommand<unknown>
let parsed: RPCCommand
try {
parsed = JSON.parse(ev.data)
if (!parsed.command) {
@ -84,6 +97,11 @@ export default class WSClient extends RPCClient {
this.#conn?.close(1003, "Malformed JSON")
return
}
if (parsed.request_id < 0) {
this.#lastReceivedEvt = parsed.request_id
} else if (parsed.command === "run_id") {
this.#resumeRunID = parsed.data
}
this.onCommand(parsed)
}