diff --git a/desktop/main.go b/desktop/main.go index fb78b82..d13d358 100644 --- a/desktop/main.go +++ b/desktop/main.go @@ -152,7 +152,7 @@ func main() { URL: "/", }) - gmx.SubscribeEvents(nil, func(command *hicli.JSONCommand) { + gmx.EventBuffer.Subscribe(0, nil, func(command *hicli.JSONCommand) { app.EmitEvent("hicli_event", command) }) diff --git a/pkg/gomuks/buffer.go b/pkg/gomuks/buffer.go new file mode 100644 index 0000000..021102e --- /dev/null +++ b/pkg/gomuks/buffer.go @@ -0,0 +1,157 @@ +// gomuks - A Matrix client written in Go. +// Copyright (C) 2024 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package gomuks + +import ( + "encoding/json" + "fmt" + "maps" + "slices" + "sync" + + "github.com/coder/websocket" + + "go.mau.fi/gomuks/pkg/hicli" +) + +type WebsocketCloseFunc func(websocket.StatusCode, string) + +type EventBuffer struct { + lock sync.RWMutex + buf []*hicli.JSONCommand + minID int64 + maxID int64 + MaxSize int + + websocketClosers map[uint64]WebsocketCloseFunc + lastAckedID map[uint64]int64 + eventListeners map[uint64]func(*hicli.JSONCommand) + nextListenerID uint64 +} + +func NewEventBuffer(maxSize int) *EventBuffer { + return &EventBuffer{ + websocketClosers: make(map[uint64]WebsocketCloseFunc), + lastAckedID: make(map[uint64]int64), + eventListeners: make(map[uint64]func(*hicli.JSONCommand)), + buf: make([]*hicli.JSONCommand, 0, maxSize*2), + MaxSize: maxSize, + minID: -1, + } +} + +func (eb *EventBuffer) HicliEventHandler(evt any) { + data, err := json.Marshal(evt) + if err != nil { + panic(fmt.Errorf("failed to marshal event %T: %w", evt, err)) + } + allowCache := true + if syncComplete, ok := evt.(*hicli.SyncComplete); ok && syncComplete.Since != nil && *syncComplete.Since == "" { + // Don't cache initial sync responses + allowCache = false + } else if _, ok := evt.(*hicli.Typing); ok { + // Also don't cache typing events + allowCache = false + } + eb.lock.Lock() + defer eb.lock.Unlock() + jc := &hicli.JSONCommand{ + Command: hicli.EventTypeName(evt), + Data: data, + } + if allowCache { + eb.addToBuffer(jc) + } + for _, listener := range eb.eventListeners { + listener(jc) + } +} + +func (eb *EventBuffer) GetClosers() []WebsocketCloseFunc { + eb.lock.Lock() + defer eb.lock.Unlock() + return slices.Collect(maps.Values(eb.websocketClosers)) +} + +func (eb *EventBuffer) Unsubscribe(listenerID uint64) { + eb.lock.Lock() + defer eb.lock.Unlock() + delete(eb.eventListeners, listenerID) + delete(eb.websocketClosers, listenerID) +} + +func (eb *EventBuffer) addToBuffer(evt *hicli.JSONCommand) { + eb.maxID-- + evt.RequestID = eb.maxID + if len(eb.lastAckedID) > 0 { + eb.buf = append(eb.buf, evt) + } else { + eb.minID = eb.maxID - 1 + } + if len(eb.buf) > eb.MaxSize { + eb.buf = eb.buf[len(eb.buf)-eb.MaxSize:] + eb.minID = eb.buf[0].RequestID + } +} + +func (eb *EventBuffer) ClearListenerLastAckedID(listenerID uint64) { + eb.lock.Lock() + defer eb.lock.Unlock() + delete(eb.lastAckedID, listenerID) + eb.gc() +} + +func (eb *EventBuffer) SetLastAckedID(listenerID uint64, ackedID int64) { + eb.lock.Lock() + defer eb.lock.Unlock() + eb.lastAckedID[listenerID] = ackedID + eb.gc() +} + +func (eb *EventBuffer) gc() { + neededMinID := eb.maxID + for lid, evtID := range eb.lastAckedID { + if evtID > eb.minID { + delete(eb.lastAckedID, lid) + } else if evtID > neededMinID { + neededMinID = evtID + } + } + if neededMinID < eb.minID { + eb.buf = eb.buf[eb.minID-neededMinID:] + eb.minID = neededMinID + } +} + +func (eb *EventBuffer) Subscribe(resumeFrom int64, closeForRestart WebsocketCloseFunc, cb func(*hicli.JSONCommand)) (uint64, []*hicli.JSONCommand) { + eb.lock.Lock() + defer eb.lock.Unlock() + eb.nextListenerID++ + id := eb.nextListenerID + eb.eventListeners[id] = cb + if closeForRestart != nil { + eb.websocketClosers[id] = closeForRestart + } + var resumeData []*hicli.JSONCommand + if resumeFrom < eb.minID { + resumeData = eb.buf[eb.minID-resumeFrom+1:] + eb.lastAckedID[id] = resumeFrom + } else { + eb.lastAckedID[id] = eb.maxID + } + return id, resumeData +} diff --git a/pkg/gomuks/gomuks.go b/pkg/gomuks/gomuks.go index d294943..3764c50 100644 --- a/pkg/gomuks/gomuks.go +++ b/pkg/gomuks/gomuks.go @@ -20,13 +20,11 @@ import ( "context" "embed" "fmt" - "maps" "net/http" "os" "os/signal" "path/filepath" "runtime" - "slices" "sync" "syscall" "time" @@ -65,17 +63,13 @@ type Gomuks struct { stopOnce sync.Once stopChan chan struct{} - websocketClosers map[uint64]WebsocketCloseFunc - eventListeners map[uint64]func(*hicli.JSONCommand) - nextListenerID uint64 - eventListenersLock sync.RWMutex + EventBuffer *EventBuffer } func NewGomuks() *Gomuks { return &Gomuks{ - stopChan: make(chan struct{}), - eventListeners: make(map[uint64]func(*hicli.JSONCommand)), - websocketClosers: make(map[uint64]WebsocketCloseFunc), + stopChan: make(chan struct{}), + EventBuffer: NewEventBuffer(512), } } @@ -176,7 +170,7 @@ func (gmx *Gomuks) StartClient() { nil, gmx.Log.With().Str("component", "hicli").Logger(), []byte("meow"), - hicli.JSONEventHandler(gmx.OnEvent).HandleEvent, + gmx.EventBuffer.HicliEventHandler, ) gmx.Client.LogoutFunc = gmx.Logout httpClient := gmx.Client.Client.Client @@ -218,10 +212,7 @@ func (gmx *Gomuks) WaitForInterrupt() { } func (gmx *Gomuks) DirectStop() { - gmx.eventListenersLock.Lock() - closers := slices.Collect(maps.Values(gmx.websocketClosers)) - gmx.eventListenersLock.Unlock() - for _, closer := range closers { + for _, closer := range gmx.EventBuffer.GetClosers() { closer(websocket.StatusServiceRestart, "Server shutting down") } gmx.Client.Stop() @@ -231,33 +222,6 @@ func (gmx *Gomuks) DirectStop() { } } -func (gmx *Gomuks) OnEvent(evt *hicli.JSONCommand) { - gmx.eventListenersLock.RLock() - defer gmx.eventListenersLock.RUnlock() - for _, listener := range gmx.eventListeners { - listener(evt) - } -} - -type WebsocketCloseFunc func(websocket.StatusCode, string) - -func (gmx *Gomuks) SubscribeEvents(closeForRestart WebsocketCloseFunc, cb func(command *hicli.JSONCommand)) func() { - gmx.eventListenersLock.Lock() - defer gmx.eventListenersLock.Unlock() - gmx.nextListenerID++ - id := gmx.nextListenerID - gmx.eventListeners[id] = cb - if closeForRestart != nil { - gmx.websocketClosers[id] = closeForRestart - } - return func() { - gmx.eventListenersLock.Lock() - defer gmx.eventListenersLock.Unlock() - delete(gmx.eventListeners, id) - delete(gmx.websocketClosers, id) - } -} - func (gmx *Gomuks) Run() { gmx.InitDirectories() err := gmx.LoadConfig() diff --git a/pkg/gomuks/websocket.go b/pkg/gomuks/websocket.go index c3d35d9..0a78446 100644 --- a/pkg/gomuks/websocket.go +++ b/pkg/gomuks/websocket.go @@ -22,6 +22,7 @@ import ( "errors" "net/http" "runtime/debug" + "strconv" "sync" "sync/atomic" "time" @@ -52,6 +53,12 @@ const ( var emptyObject = json.RawMessage("{}") +type PingRequestData struct { + LastReceivedID int64 `json:"last_received_id"` +} + +var runID = time.Now().UnixNano() + func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) { var conn *websocket.Conn log := zerolog.Ctx(r.Context()) @@ -80,15 +87,23 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) { log.Warn().Err(acceptErr).Msg("Failed to accept websocket connection") return } - log.Info().Msg("Accepted new websocket connection") + resumeFrom, _ := strconv.ParseInt(r.URL.Query().Get("last_received_event"), 10, 64) + resumeRunID, _ := strconv.ParseInt(r.URL.Query().Get("run_id"), 10, 64) + log.Info(). + Int64("resume_from", resumeFrom). + Int64("resume_run_id", resumeRunID). + Int64("current_run_id", runID). + Msg("Accepted new websocket connection") conn.SetReadLimit(128 * 1024) ctx, cancel := context.WithCancel(context.Background()) ctx = log.WithContext(ctx) - unsubscribe := func() {} + var listenerID uint64 evts := make(chan *hicli.JSONCommand, 32) forceClose := func() { cancel() - unsubscribe() + if listenerID != 0 { + gmx.EventBuffer.Unsubscribe(listenerID) + } _ = conn.CloseNow() close(evts) } @@ -99,7 +114,11 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) { _ = conn.Close(statusCode, reason) closeOnce.Do(forceClose) } - unsubscribe = gmx.SubscribeEvents(closeManually, func(evt *hicli.JSONCommand) { + if resumeRunID != runID { + resumeFrom = 0 + } + var resumeData []*hicli.JSONCommand + listenerID, resumeData = gmx.EventBuffer.Subscribe(resumeFrom, closeManually, func(evt *hicli.JSONCommand) { if ctx.Err() != nil { return } @@ -115,6 +134,7 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) { }() } }) + didResume := resumeData != nil lastDataReceived := &atomic.Int64{} lastDataReceived.Store(time.Now().UnixMilli()) @@ -133,6 +153,16 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) { go func() { defer recoverPanic("event loop") defer closeOnce.Do(forceClose) + for _, cmd := range resumeData { + err := writeCmd(ctx, conn, cmd) + if err != nil { + log.Err(err).Int64("req_id", cmd.RequestID).Msg("Failed to write outgoing event from resume data") + return + } else { + log.Trace().Int64("req_id", cmd.RequestID).Msg("Sent outgoing event from resume data") + } + } + resumeData = nil ticker := time.NewTicker(60 * time.Second) defer ticker.Stop() ctxDone := ctx.Done() @@ -176,7 +206,22 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) { Str("command", cmd.Command). RawJSON("data", cmd.Data). Msg("Received command") - resp := gmx.Client.SubmitJSONCommand(ctx, cmd) + var resp *hicli.JSONCommand + if cmd.Command == "ping" { + resp = &hicli.JSONCommand{ + Command: "pong", + RequestID: cmd.RequestID, + } + var pingData PingRequestData + err := json.Unmarshal(cmd.Data, &pingData) + if err != nil { + log.Err(err).Msg("Failed to parse ping data") + } else if pingData.LastReceivedID != 0 { + gmx.EventBuffer.SetLastAckedID(listenerID, pingData.LastReceivedID) + } + } else { + resp = gmx.Client.SubmitJSONCommand(ctx, cmd) + } if ctx.Err() != nil { return } @@ -188,7 +233,15 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) { log.Trace().Int64("req_id", cmd.RequestID).Msg("Sent response to command") } } - initErr := writeCmd(ctx, conn, &hicli.JSONCommandCustom[*hicli.ClientState]{ + initErr := writeCmd(ctx, conn, &hicli.JSONCommandCustom[string]{ + Command: "run_id", + Data: strconv.FormatInt(runID, 10), + }) + if initErr != nil { + log.Err(initErr).Msg("Failed to write init client state message") + return + } + initErr = writeCmd(ctx, conn, &hicli.JSONCommandCustom[*hicli.ClientState]{ Command: "client_state", Data: gmx.Client.State(), }) @@ -205,10 +258,10 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) { return } go sendImageAuthToken() - if gmx.Client.IsLoggedIn() { + if gmx.Client.IsLoggedIn() && !didResume { go gmx.sendInitialData(ctx, conn) } - log.Debug().Msg("Connection initialization complete") + log.Debug().Bool("did_resume", didResume).Msg("Connection initialization complete") var closeErr websocket.CloseError for { msgType, reader, err := conn.Reader(ctx) @@ -218,6 +271,9 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) { Stringer("status_code", closeErr.Code). Str("reason", closeErr.Reason). Msg("Connection closed") + if closeErr.Code == websocket.StatusGoingAway { + gmx.EventBuffer.ClearListenerLastAckedID(listenerID) + } } else { log.Err(err).Msg("Failed to read message") } diff --git a/pkg/hicli/events.go b/pkg/hicli/events.go index af153cc..c3ee935 100644 --- a/pkg/hicli/events.go +++ b/pkg/hicli/events.go @@ -30,6 +30,8 @@ type SyncNotification struct { } type SyncComplete struct { + Since *string `json:"since,omitempty"` + ClearState bool `json:"clear_state,omitempty"` Rooms map[id.RoomID]*SyncRoom `json:"rooms"` AccountData map[event.Type]*database.AccountData `json:"account_data"` LeftRooms []id.RoomID `json:"left_rooms"` diff --git a/pkg/hicli/init.go b/pkg/hicli/init.go index ae124a1..6b6f7a0 100644 --- a/pkg/hicli/init.go +++ b/pkg/hicli/init.go @@ -76,6 +76,9 @@ func (h *HiClient) GetInitialSync(ctx context.Context, batchSize int) iter.Seq[* LeftRooms: make([]id.RoomID, 0), AccountData: make(map[event.Type]*database.AccountData), } + if i == 0 { + payload.ClearState = true + } for _, room := range rooms { if room.SortingTimestamp == rooms[len(rooms)-1].SortingTimestamp { break diff --git a/pkg/hicli/json.go b/pkg/hicli/json.go index 8f7346c..5c1274a 100644 --- a/pkg/hicli/json.go +++ b/pkg/hicli/json.go @@ -76,12 +76,6 @@ func (h *HiClient) dispatchCurrentState() { } func (h *HiClient) SubmitJSONCommand(ctx context.Context, req *JSONCommand) *JSONCommand { - if req.Command == "ping" { - return &JSONCommand{ - Command: "pong", - RequestID: req.RequestID, - } - } log := h.Log.With().Int64("request_id", req.RequestID).Str("command", req.Command).Logger() ctx, cancel := context.WithCancelCause(ctx) defer func() { diff --git a/pkg/hicli/syncwrap.go b/pkg/hicli/syncwrap.go index 5e41f73..5c14e88 100644 --- a/pkg/hicli/syncwrap.go +++ b/pkg/hicli/syncwrap.go @@ -29,6 +29,7 @@ func (h *hiSyncer) ProcessResponse(ctx context.Context, resp *mautrix.RespSync, c := (*HiClient)(h) c.lastSync = time.Now() ctx = context.WithValue(ctx, syncContextKey, &syncContext{evt: &SyncComplete{ + Since: &since, Rooms: make(map[id.RoomID]*SyncRoom, len(resp.Rooms.Join)), LeftRooms: make([]id.RoomID, 0, len(resp.Rooms.Leave)), }}) diff --git a/web/src/api/client.ts b/web/src/api/client.ts index 1db38b4..a3dab90 100644 --- a/web/src/api/client.ts +++ b/web/src/api/client.ts @@ -324,12 +324,12 @@ export default class Client { this.initComplete.emit(false) this.syncStatus.emit({ type: "waiting", error_count: 0 }) this.state.clearCache() - localStorage.clear() this.store.clear() } async logout() { await this.rpc.logout() this.clearState() + localStorage.clear() } } diff --git a/web/src/api/rpc.ts b/web/src/api/rpc.ts index 0eb06de..dd7df6f 100644 --- a/web/src/api/rpc.ts +++ b/web/src/api/rpc.ts @@ -75,7 +75,7 @@ export default abstract class RPCClient { public abstract start(): void public abstract stop(): void - protected onCommand(data: RPCCommand) { + protected onCommand(data: RPCCommand) { if (data.command === "response" || data.command === "error") { const target = this.pendingRequests.get(data.request_id) if (!target) { diff --git a/web/src/api/statestore/main.ts b/web/src/api/statestore/main.ts index 91884cb..08b63a9 100644 --- a/web/src/api/statestore/main.ts +++ b/web/src/api/statestore/main.ts @@ -154,6 +154,10 @@ export class StateStore { } applySync(sync: SyncCompleteData) { + if (sync.clear_state && this.rooms.size > 0) { + console.info("Clearing state store as sync told to reset and there are rooms in the store") + this.clear() + } const resyncRoomList = this.roomList.current.length === 0 const changedRoomListEntries = new Map() for (const [roomID, data] of Object.entries(sync.rooms)) { diff --git a/web/src/api/types/hievents.ts b/web/src/api/types/hievents.ts index 96603b4..acaf973 100644 --- a/web/src/api/types/hievents.ts +++ b/web/src/api/types/hievents.ts @@ -28,7 +28,7 @@ import { UserID, } from "./mxtypes.ts" -export interface RPCCommand { +interface BaseRPCCommand { command: string request_id: number data: T @@ -39,7 +39,7 @@ export interface TypingEventData { user_ids: UserID[] } -export interface TypingEvent extends RPCCommand { +export interface TypingEvent extends BaseRPCCommand { command: "typing" } @@ -48,7 +48,7 @@ export interface SendCompleteData { error: string | null } -export interface SendCompleteEvent extends RPCCommand { +export interface SendCompleteEvent extends BaseRPCCommand { command: "send_complete" } @@ -58,11 +58,11 @@ export interface EventsDecryptedData { events: RawDBEvent[] } -export interface EventsDecryptedEvent extends RPCCommand { +export interface EventsDecryptedEvent extends BaseRPCCommand { command: "events_decrypted" } -export interface ImageAuthTokenEvent extends RPCCommand { +export interface ImageAuthTokenEvent extends BaseRPCCommand { command: "image_auth_token" } @@ -85,9 +85,11 @@ export interface SyncCompleteData { rooms: Record left_rooms: RoomID[] account_data: Record + since?: string + clear_state?: boolean } -export interface SyncCompleteEvent extends RPCCommand { +export interface SyncCompleteEvent extends BaseRPCCommand { command: "sync_complete" } @@ -103,7 +105,7 @@ export type ClientState = { homeserver_url: string } -export interface ClientStateEvent extends RPCCommand { +export interface ClientStateEvent extends BaseRPCCommand { command: "client_state" } @@ -114,14 +116,26 @@ export interface SyncStatus { last_sync?: number } -export interface SyncStatusEvent extends RPCCommand { +export interface SyncStatusEvent extends BaseRPCCommand { command: "sync_status" } -export interface InitCompleteEvent extends RPCCommand { +export interface InitCompleteEvent extends BaseRPCCommand { command: "init_complete" } +export interface RunIDEvent extends BaseRPCCommand { + command: "run_id" +} + +export interface ResponseCommand extends BaseRPCCommand { + command: "response" +} + +export interface ErrorCommand extends BaseRPCCommand { + command: "error" +} + export type RPCEvent = ClientStateEvent | SyncStatusEvent | @@ -130,4 +144,7 @@ export type RPCEvent = EventsDecryptedEvent | SyncCompleteEvent | ImageAuthTokenEvent | - InitCompleteEvent + InitCompleteEvent | + RunIDEvent + +export type RPCCommand = RPCEvent | ResponseCommand | ErrorCommand diff --git a/web/src/api/wsclient.ts b/web/src/api/wsclient.ts index 91701e9..5d81a74 100644 --- a/web/src/api/wsclient.ts +++ b/web/src/api/wsclient.ts @@ -23,6 +23,8 @@ export default class WSClient extends RPCClient { #conn: WebSocket | null = null #lastMessage: number = 0 #pingInterval: number | null = null + #lastReceivedEvt: number = 0 + #resumeRunID: string = "" constructor(readonly addr: string) { super() @@ -31,8 +33,13 @@ export default class WSClient extends RPCClient { start() { try { this.#lastMessage = Date.now() - console.info("Connecting to websocket", this.addr) - this.#conn = new WebSocket(this.addr) + const params = new URLSearchParams({ + run_id: this.#resumeRunID, + last_received_event: this.#lastReceivedEvt.toString(), + }).toString() + const addr = this.#lastReceivedEvt && this.#resumeRunID ? `${this.addr}?${params}` : this.addr + console.info("Connecting to websocket", addr) + this.#conn = new WebSocket(addr) this.#conn.onmessage = this.#onMessage this.#conn.onopen = this.#onOpen this.#conn.onerror = this.#onError @@ -49,7 +56,13 @@ export default class WSClient extends RPCClient { this.#conn?.close(4002, "Ping timeout") return } - this.send(JSON.stringify({ command: "ping", request_id: this.nextRequestID })) + this.send(JSON.stringify({ + command: "ping", + data: { + last_received_id: this.#lastReceivedEvt, + }, + request_id: this.nextRequestID, + })) } stop() { @@ -72,7 +85,7 @@ export default class WSClient extends RPCClient { #onMessage = (ev: MessageEvent) => { this.#lastMessage = Date.now() - let parsed: RPCCommand + let parsed: RPCCommand try { parsed = JSON.parse(ev.data) if (!parsed.command) { @@ -84,6 +97,11 @@ export default class WSClient extends RPCClient { this.#conn?.close(1003, "Malformed JSON") return } + if (parsed.request_id < 0) { + this.#lastReceivedEvt = parsed.request_id + } else if (parsed.command === "run_id") { + this.#resumeRunID = parsed.data + } this.onCommand(parsed) }