hicli/sync: send account data to frontend

This commit is contained in:
Tulir Asokan 2024-10-25 19:15:22 +03:00
parent 5768b2202b
commit 72e1bd428e
6 changed files with 63 additions and 12 deletions

View file

@ -27,6 +27,9 @@ const (
INSERT INTO room_account_data (user_id, room_id, type, content) VALUES ($1, $2, $3, $4)
ON CONFLICT (user_id, room_id, type) DO UPDATE SET content = excluded.content
`
getGlobalAccountDataQuery = `
SELECT user_id, '', type, content FROM account_data WHERE user_id = $1
`
)
type AccountDataQuery struct {
@ -41,12 +44,27 @@ func unsafeJSONString(content json.RawMessage) *string {
return &str
}
func (adq *AccountDataQuery) Put(ctx context.Context, userID id.UserID, eventType event.Type, content json.RawMessage) error {
return adq.Exec(ctx, upsertAccountDataQuery, userID, eventType.Type, unsafeJSONString(content))
func (adq *AccountDataQuery) Put(ctx context.Context, userID id.UserID, eventType event.Type, content json.RawMessage) (*AccountData, error) {
ad := &AccountData{
UserID: userID,
Type: eventType.Type,
Content: content,
}
return ad, adq.Exec(ctx, upsertAccountDataQuery, userID, eventType.Type, unsafeJSONString(content))
}
func (adq *AccountDataQuery) PutRoom(ctx context.Context, userID id.UserID, roomID id.RoomID, eventType event.Type, content json.RawMessage) error {
return adq.Exec(ctx, upsertRoomAccountDataQuery, userID, roomID, eventType.Type, unsafeJSONString(content))
func (adq *AccountDataQuery) PutRoom(ctx context.Context, userID id.UserID, roomID id.RoomID, eventType event.Type, content json.RawMessage) (*AccountData, error) {
ad := &AccountData{
UserID: userID,
RoomID: roomID,
Type: eventType.Type,
Content: content,
}
return ad, adq.Exec(ctx, upsertRoomAccountDataQuery, userID, roomID, eventType.Type, unsafeJSONString(content))
}
func (adq *AccountDataQuery) GetAllGlobal(ctx context.Context, userID id.UserID) ([]*AccountData, error) {
return adq.QueryMany(ctx, getGlobalAccountDataQuery, userID)
}
type AccountData struct {

View file

@ -17,6 +17,7 @@ type SyncRoom struct {
Meta *database.Room `json:"meta"`
Timeline []database.TimelineRowTuple `json:"timeline"`
State map[event.Type]map[string]database.EventRowID `json:"state"`
AccountData map[event.Type]*database.AccountData `json:"account_data"`
Events []*database.Event `json:"events"`
Reset bool `json:"reset"`
Notifications []SyncNotification `json:"notifications"`
@ -29,6 +30,7 @@ type SyncNotification struct {
type SyncComplete struct {
Rooms map[id.RoomID]*SyncRoom `json:"rooms"`
AccountData map[event.Type]*database.AccountData `json:"account_data"`
LeftRooms []id.RoomID `json:"left_rooms"`
}

View file

@ -19,6 +19,7 @@ func (h *HiClient) getInitialSyncRoom(ctx context.Context, room *database.Room)
Timeline: make([]database.TimelineRowTuple, 0),
State: map[event.Type]map[string]database.EventRowID{},
Notifications: make([]SyncNotification, 0),
AccountData: make(map[event.Type]*database.AccountData),
}
if room.PreviewEventRowID != 0 {
previewEvent, err := h.DB.Event.GetByRowID(ctx, room.PreviewEventRowID)
@ -53,7 +54,7 @@ func (h *HiClient) getInitialSyncRoom(ctx context.Context, room *database.Room)
func (h *HiClient) GetInitialSync(ctx context.Context, batchSize int) iter.Seq[*SyncComplete] {
return func(yield func(*SyncComplete) bool) {
maxTS := time.Now().Add(1 * time.Hour)
for {
for i := 0; ; i++ {
rooms, err := h.DB.Room.GetBySortTS(ctx, maxTS, batchSize)
if err != nil {
if ctx.Err() == nil {
@ -65,6 +66,19 @@ func (h *HiClient) GetInitialSync(ctx context.Context, batchSize int) iter.Seq[*
Rooms: make(map[id.RoomID]*SyncRoom, len(rooms)-1),
LeftRooms: make([]id.RoomID, 0),
}
if i == 0 {
ad, err := h.DB.AccountData.GetAllGlobal(ctx, h.Account.UserID)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get global account data")
return
}
payload.AccountData = make(map[event.Type]*database.AccountData, len(ad))
for _, data := range ad {
payload.AccountData[event.Type{Type: data.Type, Class: event.AccountDataEventType}] = data
}
} else {
payload.AccountData = make(map[event.Type]*database.AccountData)
}
for _, room := range rooms {
if room.SortingTimestamp == rooms[len(rooms)-1].SortingTimestamp {
break

View file

@ -104,9 +104,11 @@ func (h *HiClient) processSyncResponse(ctx context.Context, resp *mautrix.RespSy
ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true
}
accountData := make(map[event.Type]*database.AccountData, len(resp.AccountData.Events))
var err error
for _, evt := range resp.AccountData.Events {
evt.Type.Class = event.AccountDataEventType
err := h.DB.AccountData.Put(ctx, h.Account.UserID, evt.Type, evt.Content.VeryRaw)
accountData[evt.Type], err = h.DB.AccountData.Put(ctx, h.Account.UserID, evt.Type, evt.Content.VeryRaw)
if err != nil {
return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err)
}
@ -120,6 +122,7 @@ func (h *HiClient) processSyncResponse(ctx context.Context, resp *mautrix.RespSy
}
}
}
ctx.Value(syncContextKey).(*syncContext).evt.AccountData = accountData
for roomID, room := range resp.Rooms.Join {
err := h.processSyncJoinedRoom(ctx, roomID, room)
if err != nil {
@ -133,7 +136,7 @@ func (h *HiClient) processSyncResponse(ctx context.Context, resp *mautrix.RespSy
}
}
h.Account.NextBatch = resp.NextBatch
err := h.DB.Account.PutNextBatch(ctx, h.Account.UserID, resp.NextBatch)
err = h.DB.Account.PutNextBatch(ctx, h.Account.UserID, resp.NextBatch)
if err != nil {
return fmt.Errorf("failed to save next_batch: %w", err)
}
@ -179,10 +182,11 @@ func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID,
}
}
accountData := make(map[event.Type]*database.AccountData, len(room.AccountData.Events))
for _, evt := range room.AccountData.Events {
evt.Type.Class = event.AccountDataEventType
evt.RoomID = roomID
err = h.DB.AccountData.PutRoom(ctx, h.Account.UserID, roomID, evt.Type, evt.Content.VeryRaw)
accountData[evt.Type], err = h.DB.AccountData.PutRoom(ctx, h.Account.UserID, roomID, evt.Type, evt.Content.VeryRaw)
if err != nil {
return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err)
}
@ -216,6 +220,7 @@ func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID,
&room.Summary,
receiptsList,
newOwnReceipts,
accountData,
)
if err != nil {
return err
@ -513,6 +518,7 @@ func (h *HiClient) processStateAndTimeline(
summary *mautrix.LazyLoadSummary,
receipts []*database.Receipt,
newOwnReceipts []id.EventID,
accountData map[event.Type]*database.AccountData,
) error {
updatedRoom := &database.Room{
ID: room.ID,
@ -747,10 +753,11 @@ func (h *HiClient) processStateAndTimeline(
}
}
// TODO why is *old* unread count sometimes zero when processing the read receipt that is making it zero?
if roomChanged || len(newOwnReceipts) > 0 || len(timelineRowTuples) > 0 || len(allNewEvents) > 0 {
if roomChanged || len(accountData) > 0 || len(newOwnReceipts) > 0 || len(timelineRowTuples) > 0 || len(allNewEvents) > 0 {
ctx.Value(syncContextKey).(*syncContext).evt.Rooms[room.ID] = &SyncRoom{
Meta: room,
Timeline: timelineRowTuples,
AccountData: accountData,
State: changedState,
Reset: timeline.Limited,
Events: allNewEvents,

View file

@ -14,7 +14,9 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
import {
DBAccountData,
DBRoom,
DBRoomAccountData,
EventRowID,
RawDBEvent,
TimelineRowTuple,
@ -71,6 +73,7 @@ export interface SyncRoom {
state: Record<EventType, Record<string, EventRowID>>
reset: boolean
notifications: SyncNotification[]
account_data: Record<EventType, DBRoomAccountData>
}
export interface SyncNotification {
@ -81,6 +84,7 @@ export interface SyncNotification {
export interface SyncCompleteData {
rooms: Record<RoomID, SyncRoom>
left_rooms: RoomID[]
account_data: Record<EventType, DBAccountData>
}
export interface SyncCompleteEvent extends RPCCommand<SyncCompleteData> {

View file

@ -129,7 +129,13 @@ export interface MemDBEvent extends BaseDBEvent {
export interface DBAccountData {
user_id: UserID
room_id?: RoomID
type: EventType
content: unknown
}
export interface DBRoomAccountData {
user_id: UserID
room_id: RoomID
type: EventType
content: unknown
}