gomuks/pkg/hicli/database/accountdata.go

92 lines
2.8 KiB
Go

// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package database
import (
"context"
"database/sql"
"encoding/json"
"unsafe"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
const (
upsertAccountDataQuery = `
INSERT INTO account_data (user_id, type, content) VALUES ($1, $2, $3)
ON CONFLICT (user_id, type) DO UPDATE SET content = excluded.content
`
upsertRoomAccountDataQuery = `
INSERT INTO room_account_data (user_id, room_id, type, content) VALUES ($1, $2, $3, $4)
ON CONFLICT (user_id, room_id, type) DO UPDATE SET content = excluded.content
`
getGlobalAccountDataQuery = `
SELECT user_id, '', type, content FROM account_data WHERE user_id = $1
`
getRoomAccountDataQuery = `
SELECT user_id, room_id, type, content FROM room_account_data WHERE user_id = $1 AND room_id = $2
`
)
type AccountDataQuery struct {
*dbutil.QueryHelper[*AccountData]
}
func unsafeJSONString(content json.RawMessage) *string {
if content == nil {
return nil
}
str := unsafe.String(unsafe.SliceData(content), len(content))
return &str
}
func (adq *AccountDataQuery) Put(ctx context.Context, userID id.UserID, eventType event.Type, content json.RawMessage) (*AccountData, error) {
ad := &AccountData{
UserID: userID,
Type: eventType.Type,
Content: content,
}
return ad, adq.Exec(ctx, upsertAccountDataQuery, userID, eventType.Type, unsafeJSONString(content))
}
func (adq *AccountDataQuery) PutRoom(ctx context.Context, userID id.UserID, roomID id.RoomID, eventType event.Type, content json.RawMessage) (*AccountData, error) {
ad := &AccountData{
UserID: userID,
RoomID: roomID,
Type: eventType.Type,
Content: content,
}
return ad, adq.Exec(ctx, upsertRoomAccountDataQuery, userID, roomID, eventType.Type, unsafeJSONString(content))
}
func (adq *AccountDataQuery) GetAllGlobal(ctx context.Context, userID id.UserID) ([]*AccountData, error) {
return adq.QueryMany(ctx, getGlobalAccountDataQuery, userID)
}
func (adq *AccountDataQuery) GetAllRoom(ctx context.Context, userID id.UserID, roomID id.RoomID) ([]*AccountData, error) {
return adq.QueryMany(ctx, getRoomAccountDataQuery, userID, roomID)
}
type AccountData struct {
UserID id.UserID `json:"user_id"`
RoomID id.RoomID `json:"room_id,omitempty"`
Type string `json:"type"`
Content json.RawMessage `json:"content"`
}
func (a *AccountData) Scan(row dbutil.Scannable) (*AccountData, error) {
var roomID sql.NullString
err := row.Scan(&a.UserID, &roomID, &a.Type, (*[]byte)(&a.Content))
if err != nil {
return nil, err
}
a.RoomID = id.RoomID(roomID.String)
return a, nil
}