gomuks/pkg/hicli/send.go

317 lines
10 KiB
Go

// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/rs/zerolog"
"github.com/yuin/goldmark"
"go.mau.fi/util/jsontime"
"go.mau.fi/util/ptr"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
"go.mau.fi/gomuks/pkg/hicli/database"
"go.mau.fi/gomuks/pkg/rainbow"
)
// rainbowWithHTML is a goldmark markdown renderer configured with mautrix's format extensions and
// HTML options, plus the rainbow extension. SendMessage uses it to render "/rainbow " messages.
var rainbowWithHTML = goldmark.New(format.Extensions, format.HTMLOptions, goldmark.WithExtensions(rainbow.Extension))
// SendMessage builds an m.room.message content from text (optionally merged into base) and sends it
// to the room via Send.
//
// A command prefix at the start of text selects how it is rendered:
//   - "/rainbow ": markdown rendered with rainbowWithHTML, then the formatted body is colored with
//     rainbow.ApplyColor.
//   - "/plain ": format.RenderMarkdown with both boolean flags false.
//   - "/html ": format.RenderMarkdown with only the second boolean flag (HTML) true.
//   - anything else non-empty: format.RenderMarkdown with only the first boolean flag (markdown) true.
//
// If base is non-nil it becomes the event content. When text is non-empty, base's Body, Format and
// FormattedBody are first overwritten with the rendered values. Note that this modifies *base in place.
//
// If mentions is non-nil, its Room flag is copied and every mentioned user other than the current
// account is added to the content's mentions. A non-empty replyTo replaces RelatesTo with a reply
// relation to that event.
func (h *HiClient) SendMessage(ctx context.Context, roomID id.RoomID, base *event.MessageEventContent, text string, replyTo id.EventID, mentions *event.Mentions) (*database.Event, error) {
	const (
		rainbowCmd = "/rainbow "
		plainCmd   = "/plain "
		htmlCmd    = "/html "
	)
	var content event.MessageEventContent
	switch {
	case strings.HasPrefix(text, rainbowCmd):
		text = text[len(rainbowCmd):]
		content = format.RenderMarkdownCustom(text, rainbowWithHTML)
		content.FormattedBody = rainbow.ApplyColor(content.FormattedBody)
	case strings.HasPrefix(text, plainCmd):
		text = text[len(plainCmd):]
		content = format.RenderMarkdown(text, false, false)
	case strings.HasPrefix(text, htmlCmd):
		text = text[len(htmlCmd):]
		content = format.RenderMarkdown(text, false, true)
	case text != "":
		content = format.RenderMarkdown(text, true, false)
	}

	if base != nil {
		// Only override the body of the base content when there is text to override it with.
		if text != "" {
			base.Body, base.Format, base.FormattedBody = content.Body, content.Format, content.FormattedBody
		}
		content = *base
	}

	if content.Mentions == nil {
		content.Mentions = &event.Mentions{}
	}
	if mentions != nil {
		content.Mentions.Room = mentions.Room
		for _, mentioned := range mentions.UserIDs {
			if mentioned == h.Account.UserID {
				// Never mention ourselves.
				continue
			}
			content.Mentions.Add(mentioned)
		}
	}

	if replyTo != "" {
		content.RelatesTo = (&event.RelatesTo{}).SetReplyTo(replyTo)
	}
	return h.Send(ctx, roomID, event.EventMessage, &content)
}
// MarkRead sets the room's fully-read marker to eventID and, in the same request, a read receipt of
// the given type for that event. Only event.ReceiptTypeRead and event.ReceiptTypeReadPrivate are
// accepted. Any other type returns an error without contacting the server.
func (h *HiClient) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID, receiptType event.ReceiptType) error {
	markers := &mautrix.ReqSetReadMarkers{FullyRead: eventID}
	switch receiptType {
	case event.ReceiptTypeRead:
		markers.Read = eventID
	case event.ReceiptTypeReadPrivate:
		markers.ReadPrivate = eventID
	default:
		return fmt.Errorf("invalid receipt type: %v", receiptType)
	}
	if err := h.Client.SetReadMarkers(ctx, roomID, markers); err != nil {
		return fmt.Errorf("failed to mark event as read: %w", err)
	}
	return nil
}
// SetTyping updates the current user's typing state in the room. A positive timeout reports the
// user as typing (passing timeout along). Zero or a negative duration reports them as not typing.
func (h *HiClient) SetTyping(ctx context.Context, roomID id.RoomID, timeout time.Duration) error {
	isTyping := timeout > 0
	if _, err := h.Client.UserTyping(ctx, roomID, isTyping, timeout); err != nil {
		return err
	}
	return nil
}
// Send stores a local echo of an event in the database and then sends it to the room in the
// background.
//
// The returned event has the temporary ID "~" + txnID (txnID is "hicli-" + a client transaction ID)
// and SendError "not sent". In rooms with an encryption event, every type except reactions is stored
// with the plaintext in Decrypted and encrypted later. All other events are stored with their
// plaintext Content.
//
// A background goroutine, detached from ctx's cancellation, does the following:
//  1. Encrypts the event if needed.
//  2. Sends it with the same transaction ID.
//  3. Updates the database row.
//
// When it finishes, it calls h.EventHandler with a *SendComplete carrying the event and any error.
// A separate goroutine clears the typing notification.
//
// NOTE(review): the returned *database.Event is the same pointer the background goroutine keeps
// mutating (Content, MegolmSessionID, ID, SendError). Callers that read it concurrently may race.
func (h *HiClient) Send(
	ctx context.Context,
	roomID id.RoomID,
	evtType event.Type,
	content any,
) (*database.Event, error) {
	room, err := h.DB.Room.Get(ctx, roomID)
	if err != nil {
		return nil, fmt.Errorf("failed to get room metadata: %w", err)
	} else if room == nil {
		return nil, fmt.Errorf("unknown room")
	}
	txnID := "hicli-" + h.Client.TxnID()
	dbEvt := &database.Event{
		RoomID:          room.ID,
		ID:              id.EventID(fmt.Sprintf("~%s", txnID)),
		Sender:          h.Account.UserID,
		Timestamp:       jsontime.UnixMilliNow(),
		Unsigned:        []byte("{}"),
		TransactionID:   txnID,
		DecryptionError: "",
		// Placeholder until the server accepts the event; cleared on success below.
		SendError:     "not sent",
		Reactions:     map[string]int{},
		LastEditRowID: ptr.Ptr(database.EventRowID(0)),
	}
	// Reactions are never encrypted, even in encrypted rooms.
	if room.EncryptionEvent != nil && evtType != event.EventReaction {
		// Keep the plaintext as the "decrypted" content; the ciphertext is produced in the
		// background goroutine and written with UpdateEncryptedContent.
		dbEvt.Type = event.EventEncrypted.Type
		dbEvt.DecryptedType = evtType.Type
		dbEvt.Decrypted, err = json.Marshal(content)
		if err != nil {
			return nil, fmt.Errorf("failed to marshal event content: %w", err)
		}
		dbEvt.Content = json.RawMessage("{}")
		dbEvt.RelatesTo, dbEvt.RelationType = database.GetRelatesToFromBytes(dbEvt.Decrypted)
	} else {
		dbEvt.Type = evtType.Type
		dbEvt.Content, err = json.Marshal(content)
		if err != nil {
			return nil, fmt.Errorf("failed to marshal event content: %w", err)
		}
		dbEvt.RelatesTo, dbEvt.RelationType = database.GetRelatesToFromBytes(dbEvt.Content)
	}
	dbEvt.LocalContent = h.calculateLocalContent(ctx, dbEvt, dbEvt.AsRawMautrix())
	_, err = h.DB.Event.Insert(ctx, dbEvt)
	if err != nil {
		return nil, fmt.Errorf("failed to insert event into database: %w", err)
	}
	// The goroutines below outlive this call, so they must not be cancelled with the caller's ctx.
	ctx = context.WithoutCancel(ctx)
	go func() {
		err := h.SetTyping(ctx, room.ID, 0)
		if err != nil {
			zerolog.Ctx(ctx).Err(err).Msg("Failed to stop typing while sending message")
		}
	}()
	go func() {
		var err error
		defer func() {
			// SendError is only non-empty here if sending did not succeed.
			if dbEvt.SendError != "" {
				err2 := h.DB.Event.UpdateSendError(ctx, dbEvt.RowID, dbEvt.SendError)
				if err2 != nil {
					zerolog.Ctx(ctx).Err(err2).AnErr("send_error", err).
						Msg("Failed to update send error in database after sending failed")
				}
			}
			h.EventHandler(&SendComplete{
				Event: dbEvt,
				Error: err,
			})
		}()
		if dbEvt.Decrypted != nil {
			var encryptedContent *event.EncryptedEventContent
			encryptedContent, err = h.Encrypt(ctx, room, evtType, dbEvt.Decrypted)
			if err != nil {
				dbEvt.SendError = fmt.Sprintf("failed to encrypt: %v", err)
				zerolog.Ctx(ctx).Err(err).Msg("Failed to encrypt event")
				return
			}
			evtType = event.EventEncrypted
			dbEvt.MegolmSessionID = encryptedContent.SessionID
			dbEvt.Content, err = json.Marshal(encryptedContent)
			if err != nil {
				dbEvt.SendError = fmt.Sprintf("failed to marshal encrypted content: %v", err)
				zerolog.Ctx(ctx).Err(err).Msg("Failed to marshal encrypted content")
				return
			}
			err = h.DB.Event.UpdateEncryptedContent(ctx, dbEvt)
			if err != nil {
				dbEvt.SendError = fmt.Sprintf("failed to save event after encryption: %v", err)
				zerolog.Ctx(ctx).Err(err).Msg("Failed to save event after encryption")
				return
			}
		}
		var resp *mautrix.RespSendEvent
		resp, err = h.Client.SendMessageEvent(ctx, room.ID, evtType, dbEvt.Content, mautrix.ReqSendEvent{
			Timestamp:     dbEvt.Timestamp.UnixMilli(),
			TransactionID: txnID,
			// Encryption was already handled above (or intentionally skipped).
			DontEncrypt: true,
		})
		if err != nil {
			dbEvt.SendError = err.Error()
			err = fmt.Errorf("failed to send event: %w", err)
			return
		}
		// The server accepted the event, so drop the "not sent" placeholder. Otherwise the deferred
		// handler would write it back to the database and report it in SendComplete even on success.
		dbEvt.SendError = ""
		dbEvt.ID = resp.EventID
		err = h.DB.Event.UpdateID(ctx, dbEvt.RowID, dbEvt.ID)
		if err != nil {
			err = fmt.Errorf("failed to update event ID in database: %w", err)
		}
	}()
	return dbEvt, nil
}
// Encrypt megolm-encrypts content of type evtType for the room, holding h.encryptLock for the whole
// operation.
//
// If the first attempt fails with crypto.SessionExpired, crypto.NoGroupSession or
// crypto.SessionNotShared, the group session is shared with the room members via shareGroupSession
// and encryption is retried once. Those failures are wrapped with context. Any other error from
// the first attempt is returned as-is.
func (h *HiClient) Encrypt(ctx context.Context, room *database.Room, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) {
	h.encryptLock.Lock()
	defer h.encryptLock.Unlock()

	encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, room.ID, evtType, content)
	sessionUnusable := errors.Is(err, crypto.SessionExpired) ||
		errors.Is(err, crypto.NoGroupSession) ||
		errors.Is(err, crypto.SessionNotShared)
	if !sessionUnusable {
		return encrypted, err
	}

	if shareErr := h.shareGroupSession(ctx, room); shareErr != nil {
		return encrypted, fmt.Errorf("failed to share group session: %w", shareErr)
	}
	encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, room.ID, evtType, content)
	if err != nil {
		return encrypted, fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err)
	}
	return encrypted, nil
}
// EnsureGroupSessionShared returns immediately if the room's outbound group session exists, has been
// shared and is not expired. Otherwise it looks up the room and calls shareGroupSession for it.
// The whole check runs under h.encryptLock.
func (h *HiClient) EnsureGroupSessionShared(ctx context.Context, roomID id.RoomID) error {
	h.encryptLock.Lock()
	defer h.encryptLock.Unlock()

	session, err := h.CryptoStore.GetOutboundGroupSession(ctx, roomID)
	if err != nil {
		return fmt.Errorf("failed to get previous outbound group session: %w", err)
	}
	if session != nil && session.Shared && !session.Expired() {
		// The current session is still usable, nothing to do.
		return nil
	}

	roomMeta, err := h.DB.Room.Get(ctx, roomID)
	if err != nil {
		return fmt.Errorf("failed to get room metadata: %w", err)
	}
	if roomMeta == nil {
		return fmt.Errorf("unknown room")
	}
	return h.shareGroupSession(ctx, roomMeta)
}
// loadMembers fetches the room's member list from the server, unless room.HasMemberList is already
// set. Within a single database transaction it does three things:
//   - processes each member event,
//   - records the events as current state (with their membership),
//   - upserts the room with HasMemberList set.
//
// Member events that lack a state key or a string "membership" field are rejected with an error
// (aborting the transaction) instead of panicking on a nil dereference or failed type assertion.
func (h *HiClient) loadMembers(ctx context.Context, room *database.Room) error {
	if room.HasMemberList {
		return nil
	}
	resp, err := h.Client.Members(ctx, room.ID)
	if err != nil {
		return fmt.Errorf("failed to get room member list: %w", err)
	}
	err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
		entries := make([]*database.CurrentStateEntry, len(resp.Chunk))
		for i, evt := range resp.Chunk {
			// The member list comes from the server; don't trust it to be well-formed.
			if evt.StateKey == nil {
				return fmt.Errorf("member event %s has no state key", evt.ID)
			}
			dbEvt, err := h.processEvent(ctx, evt, nil, nil, true)
			if err != nil {
				return err
			}
			membership, ok := evt.Content.Raw["membership"].(string)
			if !ok {
				return fmt.Errorf("member event %s has no valid membership field", evt.ID)
			}
			entries[i] = &database.CurrentStateEntry{
				EventType:  evt.Type,
				StateKey:   *evt.StateKey,
				EventRowID: dbEvt.RowID,
				Membership: event.Membership(membership),
			}
		}
		err := h.DB.CurrentState.AddMany(ctx, room.ID, false, entries)
		if err != nil {
			return err
		}
		return h.DB.Room.Upsert(ctx, &database.Room{
			ID:            room.ID,
			HasMemberList: true,
		})
	})
	if err != nil {
		return fmt.Errorf("failed to process room member list: %w", err)
	}
	return nil
}
// shareGroupSession makes sure the room's member list is loaded, then shares the room's megolm
// session with its joined members. Invited members are included as well when
// shouldShareKeysToInvitedUsers says so. Errors from loadMembers are returned unwrapped.
func (h *HiClient) shareGroupSession(ctx context.Context, room *database.Room) error {
	if err := h.loadMembers(ctx, room); err != nil {
		return err
	}

	var (
		users []id.UserID
		err   error
	)
	if h.shouldShareKeysToInvitedUsers(ctx, room.ID) {
		users, err = h.ClientStore.GetRoomJoinedOrInvitedMembers(ctx, room.ID)
	} else {
		users, err = h.ClientStore.GetRoomJoinedMembers(ctx, room.ID)
	}
	if err != nil {
		return fmt.Errorf("failed to get room member list: %w", err)
	}

	if err = h.Crypto.ShareGroupSession(ctx, room.ID, users); err != nil {
		return fmt.Errorf("failed to share group session: %w", err)
	}
	return nil
}
// shouldShareKeysToInvitedUsers reports whether megolm keys for the room should also be shared with
// invited (not yet joined) members. This is the case when the room's history visibility is "invited",
// "shared" or "world_readable".
//
// If no history visibility state event is stored for the room, the Matrix spec default ("shared")
// applies and the result is true. A database or parse error, or unexpected content, yields false,
// so keys are only shared with joined members.
func (h *HiClient) shouldShareKeysToInvitedUsers(ctx context.Context, roomID id.RoomID) bool {
	historyVisibility, err := h.DB.CurrentState.Get(ctx, roomID, event.StateHistoryVisibility, "")
	if err != nil {
		zerolog.Ctx(ctx).Err(err).Msg("Failed to get history visibility event")
		return false
	} else if historyVisibility == nil {
		// No m.room.history_visibility state: the spec says to assume "shared". This must be
		// handled before AsRawMautrix, which would otherwise be called on a nil event.
		return true
	}
	mautrixEvt := historyVisibility.AsRawMautrix()
	err = mautrixEvt.Content.ParseRaw(mautrixEvt.Type)
	if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
		zerolog.Ctx(ctx).Err(err).Msg("Failed to parse history visibility event")
		return false
	}
	hv, ok := mautrixEvt.Content.Parsed.(*event.HistoryVisibilityEventContent)
	if !ok {
		zerolog.Ctx(ctx).Warn().Msg("Unexpected parsed content type for history visibility event")
		return false
	}
	return hv.HistoryVisibility == event.HistoryVisibilityInvited ||
		hv.HistoryVisibility == event.HistoryVisibilityShared ||
		hv.HistoryVisibility == event.HistoryVisibilityWorldReadable
}