diff --git a/pkg/hicli/database/accountdata.go b/pkg/hicli/database/accountdata.go index 51233a4..01d9aca 100644 --- a/pkg/hicli/database/accountdata.go +++ b/pkg/hicli/database/accountdata.go @@ -30,6 +30,9 @@ const ( getGlobalAccountDataQuery = ` SELECT user_id, '', type, content FROM account_data WHERE user_id = $1 ` + getRoomAccountDataQuery = ` + SELECT user_id, room_id, type, content FROM room_account_data WHERE user_id = $1 AND room_id = $2 + ` ) type AccountDataQuery struct { @@ -67,6 +70,10 @@ func (adq *AccountDataQuery) GetAllGlobal(ctx context.Context, userID id.UserID) return adq.QueryMany(ctx, getGlobalAccountDataQuery, userID) } +func (adq *AccountDataQuery) GetAllRoom(ctx context.Context, userID id.UserID, roomID id.RoomID) ([]*AccountData, error) { + return adq.QueryMany(ctx, getRoomAccountDataQuery, userID, roomID) +} + type AccountData struct { UserID id.UserID `json:"user_id"` RoomID id.RoomID `json:"room_id,omitempty"` diff --git a/pkg/hicli/init.go b/pkg/hicli/init.go index c1354c8..ae124a1 100644 --- a/pkg/hicli/init.go +++ b/pkg/hicli/init.go @@ -19,17 +19,26 @@ func (h *HiClient) getInitialSyncRoom(ctx context.Context, room *database.Room) Timeline: make([]database.TimelineRowTuple, 0), State: map[event.Type]map[string]database.EventRowID{}, Notifications: make([]SyncNotification, 0), - AccountData: make(map[event.Type]*database.AccountData), + } + ad, err := h.DB.AccountData.GetAllRoom(ctx, h.Account.UserID, room.ID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("room_id", room.ID).Msg("Failed to get room account data") + syncRoom.AccountData = make(map[event.Type]*database.AccountData) + } else { + syncRoom.AccountData = make(map[event.Type]*database.AccountData, len(ad)) + for _, data := range ad { + syncRoom.AccountData[event.Type{Type: data.Type, Class: event.AccountDataEventType}] = data + } } if room.PreviewEventRowID != 0 { previewEvent, err := h.DB.Event.GetByRowID(ctx, room.PreviewEventRowID) if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get preview event for room") + zerolog.Ctx(ctx).Err(err).Stringer("room_id", room.ID).Msg("Failed to get preview event for room") } else if previewEvent != nil { h.ReprocessExistingEvent(ctx, previewEvent) previewMember, err := h.DB.CurrentState.Get(ctx, room.ID, event.StateMember, previewEvent.Sender.String()) if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get preview member event for room") + zerolog.Ctx(ctx).Err(err).Stringer("room_id", room.ID).Msg("Failed to get preview member event for room") } else if previewMember != nil { syncRoom.Events = append(syncRoom.Events, previewMember) syncRoom.State[event.StateMember] = map[string]database.EventRowID{ @@ -39,7 +48,7 @@ func (h *HiClient) getInitialSyncRoom(ctx context.Context, room *database.Room) if previewEvent.LastEditRowID != nil { lastEdit, err := h.DB.Event.GetByRowID(ctx, *previewEvent.LastEditRowID) if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get last edit for preview event") + zerolog.Ctx(ctx).Err(err).Stringer("room_id", room.ID).Msg("Failed to get last edit for preview event") } else if lastEdit != nil { h.ReprocessExistingEvent(ctx, lastEdit) syncRoom.Events = append(syncRoom.Events, lastEdit)