From 1db1d2db5c76336352c6224f55a09a4fe415526f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 17 Oct 2024 20:22:53 +0300 Subject: [PATCH] all: move hicli from mautrix-go and add more features --- go.mod | 16 +- go.sum | 13 +- gomuks.go | 4 +- main.go | 3 +- media.go | 4 +- pkg/hicli/LICENSE | 373 ++++++++ pkg/hicli/cryptohelper.go | 64 ++ pkg/hicli/database/account.go | 73 ++ pkg/hicli/database/accountdata.go | 67 ++ pkg/hicli/database/cachedmedia.go | 149 +++ pkg/hicli/database/database.go | 73 ++ pkg/hicli/database/event.go | 509 +++++++++++ pkg/hicli/database/receipt.go | 81 ++ pkg/hicli/database/room.go | 278 ++++++ pkg/hicli/database/sessionrequest.go | 68 ++ pkg/hicli/database/state.go | 94 ++ pkg/hicli/database/statestore.go | 187 ++++ pkg/hicli/database/timeline.go | 135 +++ .../database/upgrades/00-latest-revision.sql | 255 ++++++ .../upgrades/02-explicit-avatar-flag.sql | 2 + .../upgrades/03-more-event-fields.sql | 6 + pkg/hicli/database/upgrades/upgrades.go | 22 + pkg/hicli/decryptionqueue.go | 209 +++++ pkg/hicli/events.go | 60 ++ pkg/hicli/hicli.go | 251 ++++++ pkg/hicli/hitest/hitest.go | 110 +++ pkg/hicli/html.go | 493 ++++++++++ pkg/hicli/json-commands.go | 179 ++++ pkg/hicli/json.go | 119 +++ pkg/hicli/login.go | 88 ++ pkg/hicli/paginate.go | 245 +++++ pkg/hicli/pushrules.go | 80 ++ pkg/hicli/send.go | 287 ++++++ pkg/hicli/sync.go | 850 ++++++++++++++++++ pkg/hicli/syncwrap.go | 96 ++ pkg/hicli/verify.go | 161 ++++ pkg/rainbow/goldmark.go | 120 +++ pkg/rainbow/gradient.go | 56 ++ server.go | 47 +- web/src/App.tsx | 2 + web/src/api/client.ts | 2 + web/src/api/statestore/main.ts | 52 +- web/src/api/statestore/room.ts | 2 + web/src/api/types/hievents.ts | 13 +- web/src/api/types/hitypes.ts | 17 + web/src/ui/MainScreen.tsx | 1 + web/src/ui/roomlist/Entry.tsx | 5 + web/src/ui/roomlist/RoomList.css | 20 + web/src/ui/timeline/TimelineEvent.css | 15 +- web/src/ui/timeline/TimelineEvent.tsx | 6 +- web/src/ui/timeline/content/MessageBody.tsx | 16 +- web/src/util/focus.ts | 2 +- websocket.go | 33 +- 53 files changed, 6068 insertions(+), 45 deletions(-) create mode 100644 pkg/hicli/LICENSE create mode 100644 pkg/hicli/cryptohelper.go create mode 100644 pkg/hicli/database/account.go create mode 100644 pkg/hicli/database/accountdata.go create mode 100644 pkg/hicli/database/cachedmedia.go create mode 100644 pkg/hicli/database/database.go create mode 100644 pkg/hicli/database/event.go create mode 100644 pkg/hicli/database/receipt.go create mode 100644 pkg/hicli/database/room.go create mode 100644 pkg/hicli/database/sessionrequest.go create mode 100644 pkg/hicli/database/state.go create mode 100644 pkg/hicli/database/statestore.go create mode 100644 pkg/hicli/database/timeline.go create mode 100644 pkg/hicli/database/upgrades/00-latest-revision.sql create mode 100644 pkg/hicli/database/upgrades/02-explicit-avatar-flag.sql create mode 100644 pkg/hicli/database/upgrades/03-more-event-fields.sql create mode 100644 pkg/hicli/database/upgrades/upgrades.go create mode 100644 pkg/hicli/decryptionqueue.go create mode 100644 pkg/hicli/events.go create mode 100644 pkg/hicli/hicli.go create mode 100644 pkg/hicli/hitest/hitest.go create mode 100644 pkg/hicli/html.go create mode 100644 pkg/hicli/json-commands.go create mode 100644 pkg/hicli/json.go create mode 100644 pkg/hicli/login.go create mode 100644 pkg/hicli/paginate.go create mode 100644 pkg/hicli/pushrules.go create mode 100644 pkg/hicli/send.go create mode 100644 pkg/hicli/sync.go create mode 100644 pkg/hicli/syncwrap.go create mode 100644 pkg/hicli/verify.go create mode 100644 pkg/rainbow/goldmark.go create mode 100644 pkg/rainbow/gradient.go diff --git a/go.mod b/go.mod index 425dbaf..485d8b3 100644 --- a/go.mod +++ b/go.mod @@ -5,33 +5,35 @@ go 1.23.0 toolchain go1.23.2 require ( + github.com/chzyer/readline v1.5.1 github.com/coder/websocket v1.8.12 + github.com/lucasb-eyer/go-colorful v1.2.0 github.com/mattn/go-sqlite3 v1.14.24 + github.com/rivo/uniseg v0.4.7 github.com/rs/zerolog v1.33.0 + github.com/tidwall/gjson v1.18.0 + github.com/tidwall/sjson v1.2.5 + github.com/yuin/goldmark v1.7.7 go.mau.fi/util v0.8.1 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.28.0 + golang.org/x/net v0.30.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 - maunium.net/go/mautrix v0.21.2-0.20241016143340-1d4c2d255455 + maunium.net/go/mautrix v0.21.2-0.20241017173032-367828429297 + mvdan.cc/xurls/v2 v2.5.0 ) require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect - github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7 // indirect - github.com/rivo/uniseg v0.4.7 // indirect github.com/rs/xid v1.6.0 // indirect - github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect - github.com/tidwall/sjson v1.2.5 // indirect - github.com/yuin/goldmark v1.7.7 // indirect golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c // indirect - golang.org/x/net v0.30.0 // indirect golang.org/x/sys v0.26.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index 2f19861..b2b3704 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,12 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= +github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM= +github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= +github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI= +github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= +github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= +github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= @@ -53,6 +59,7 @@ golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBn golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8= golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= +golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -66,5 +73,7 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= -maunium.net/go/mautrix v0.21.2-0.20241016143340-1d4c2d255455 h1:6UznUe9nDojckcvBNq0h1vZFM/KmA7Xs24YowG8iE+4= -maunium.net/go/mautrix v0.21.2-0.20241016143340-1d4c2d255455/go.mod h1:7F/S6XAdyc/6DW+Q7xyFXRSPb6IjfqMb1OMepQ8C8OE= +maunium.net/go/mautrix v0.21.2-0.20241017173032-367828429297 h1:8CybV+x9HPh4p41nIqJMKHI0bUF0g0bEozq49ytNVlc= +maunium.net/go/mautrix v0.21.2-0.20241017173032-367828429297/go.mod h1:sjCZR1R/3NET/WjkcXPL6WpAHlWKku9HjRsdOkbM8Qw= +mvdan.cc/xurls/v2 v2.5.0 h1:lyBNOm8Wo71UknhUs4QTFUNNMyxy2JEIaKKo0RWOh+8= +mvdan.cc/xurls/v2 v2.5.0/go.mod h1:yQgaGQ1rFtJUzkmKiHYSSfuQxqfYmd//X6PxvholpeE= diff --git a/gomuks.go b/gomuks.go index 4d58cf7..7f35968 100644 --- a/gomuks.go +++ b/gomuks.go @@ -34,7 +34,8 @@ import ( "go.mau.fi/util/dbutil" "go.mau.fi/util/exerrors" "go.mau.fi/util/exzerolog" - "maunium.net/go/mautrix/hicli" + + "go.mau.fi/gomuks/pkg/hicli" ) type Gomuks struct { @@ -144,6 +145,7 @@ func (gmx *Gomuks) SetupLog() { } func (gmx *Gomuks) StartClient() { + hicli.HTMLSanitizerImgSrcTemplate = "_gomuks/media/%s/%s" rawDB, err := dbutil.NewFromConfig("gomuks", dbutil.Config{ PoolConfig: dbutil.PoolConfig{ Type: "sqlite3-fk-wal", diff --git a/main.go b/main.go index 9706301..5c2550c 100644 --- a/main.go +++ b/main.go @@ -27,7 +27,8 @@ import ( _ "go.mau.fi/util/dbutil/litestream" flag "maunium.net/go/mauflag" "maunium.net/go/mautrix" - "maunium.net/go/mautrix/hicli" + + "go.mau.fi/gomuks/pkg/hicli" ) var ( diff --git a/media.go b/media.go index e5bc21b..f36989b 100644 --- a/media.go +++ b/media.go @@ -16,10 +16,10 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/jsontime" "go.mau.fi/util/ptr" - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/hicli/database" "maunium.net/go/mautrix/id" + + "go.mau.fi/gomuks/pkg/hicli/database" ) var ErrBadGateway = mautrix.RespError{ diff --git a/pkg/hicli/LICENSE b/pkg/hicli/LICENSE new file mode 100644 index 0000000..a612ad9 --- /dev/null +++ b/pkg/hicli/LICENSE @@ -0,0 +1,373 @@ +Mozilla Public License Version 2.0 +================================== + +1. Definitions +-------------- + +1.1. "Contributor" + means each individual or legal entity that creates, contributes to + the creation of, or owns Covered Software. + +1.2. "Contributor Version" + means the combination of the Contributions of others (if any) used + by a Contributor and that particular Contributor's Contribution. + +1.3. "Contribution" + means Covered Software of a particular Contributor. + +1.4. "Covered Software" + means Source Code Form to which the initial Contributor has attached + the notice in Exhibit A, the Executable Form of such Source Code + Form, and Modifications of such Source Code Form, in each case + including portions thereof. + +1.5. "Incompatible With Secondary Licenses" + means + + (a) that the initial Contributor has attached the notice described + in Exhibit B to the Covered Software; or + + (b) that the Covered Software was made available under the terms of + version 1.1 or earlier of the License, but not also under the + terms of a Secondary License. + +1.6. "Executable Form" + means any form of the work other than Source Code Form. + +1.7. "Larger Work" + means a work that combines Covered Software with other material, in + a separate file or files, that is not Covered Software. + +1.8. "License" + means this document. + +1.9. "Licensable" + means having the right to grant, to the maximum extent possible, + whether at the time of the initial grant or subsequently, any and + all of the rights conveyed by this License. + +1.10. "Modifications" + means any of the following: + + (a) any file in Source Code Form that results from an addition to, + deletion from, or modification of the contents of Covered + Software; or + + (b) any new file in Source Code Form that contains any Covered + Software. + +1.11. "Patent Claims" of a Contributor + means any patent claim(s), including without limitation, method, + process, and apparatus claims, in any patent Licensable by such + Contributor that would be infringed, but for the grant of the + License, by the making, using, selling, offering for sale, having + made, import, or transfer of either its Contributions or its + Contributor Version. + +1.12. "Secondary License" + means either the GNU General Public License, Version 2.0, the GNU + Lesser General Public License, Version 2.1, the GNU Affero General + Public License, Version 3.0, or any later versions of those + licenses. + +1.13. "Source Code Form" + means the form of the work preferred for making modifications. + +1.14. "You" (or "Your") + means an individual or a legal entity exercising rights under this + License. For legal entities, "You" includes any entity that + controls, is controlled by, or is under common control with You. For + purposes of this definition, "control" means (a) the power, direct + or indirect, to cause the direction or management of such entity, + whether by contract or otherwise, or (b) ownership of more than + fifty percent (50%) of the outstanding shares or beneficial + ownership of such entity. + +2. License Grants and Conditions +-------------------------------- + +2.1. Grants + +Each Contributor hereby grants You a world-wide, royalty-free, +non-exclusive license: + +(a) under intellectual property rights (other than patent or trademark) + Licensable by such Contributor to use, reproduce, make available, + modify, display, perform, distribute, and otherwise exploit its + Contributions, either on an unmodified basis, with Modifications, or + as part of a Larger Work; and + +(b) under Patent Claims of such Contributor to make, use, sell, offer + for sale, have made, import, and otherwise transfer either its + Contributions or its Contributor Version. + +2.2. Effective Date + +The licenses granted in Section 2.1 with respect to any Contribution +become effective for each Contribution on the date the Contributor first +distributes such Contribution. + +2.3. Limitations on Grant Scope + +The licenses granted in this Section 2 are the only rights granted under +this License. No additional rights or licenses will be implied from the +distribution or licensing of Covered Software under this License. +Notwithstanding Section 2.1(b) above, no patent license is granted by a +Contributor: + +(a) for any code that a Contributor has removed from Covered Software; + or + +(b) for infringements caused by: (i) Your and any other third party's + modifications of Covered Software, or (ii) the combination of its + Contributions with other software (except as part of its Contributor + Version); or + +(c) under Patent Claims infringed by Covered Software in the absence of + its Contributions. + +This License does not grant any rights in the trademarks, service marks, +or logos of any Contributor (except as may be necessary to comply with +the notice requirements in Section 3.4). + +2.4. Subsequent Licenses + +No Contributor makes additional grants as a result of Your choice to +distribute the Covered Software under a subsequent version of this +License (see Section 10.2) or under the terms of a Secondary License (if +permitted under the terms of Section 3.3). + +2.5. Representation + +Each Contributor represents that the Contributor believes its +Contributions are its original creation(s) or it has sufficient rights +to grant the rights to its Contributions conveyed by this License. + +2.6. Fair Use + +This License is not intended to limit any rights You have under +applicable copyright doctrines of fair use, fair dealing, or other +equivalents. + +2.7. Conditions + +Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted +in Section 2.1. + +3. Responsibilities +------------------- + +3.1. Distribution of Source Form + +All distribution of Covered Software in Source Code Form, including any +Modifications that You create or to which You contribute, must be under +the terms of this License. You must inform recipients that the Source +Code Form of the Covered Software is governed by the terms of this +License, and how they can obtain a copy of this License. You may not +attempt to alter or restrict the recipients' rights in the Source Code +Form. + +3.2. Distribution of Executable Form + +If You distribute Covered Software in Executable Form then: + +(a) such Covered Software must also be made available in Source Code + Form, as described in Section 3.1, and You must inform recipients of + the Executable Form how they can obtain a copy of such Source Code + Form by reasonable means in a timely manner, at a charge no more + than the cost of distribution to the recipient; and + +(b) You may distribute such Executable Form under the terms of this + License, or sublicense it under different terms, provided that the + license for the Executable Form does not attempt to limit or alter + the recipients' rights in the Source Code Form under this License. + +3.3. Distribution of a Larger Work + +You may create and distribute a Larger Work under terms of Your choice, +provided that You also comply with the requirements of this License for +the Covered Software. If the Larger Work is a combination of Covered +Software with a work governed by one or more Secondary Licenses, and the +Covered Software is not Incompatible With Secondary Licenses, this +License permits You to additionally distribute such Covered Software +under the terms of such Secondary License(s), so that the recipient of +the Larger Work may, at their option, further distribute the Covered +Software under the terms of either this License or such Secondary +License(s). + +3.4. Notices + +You may not remove or alter the substance of any license notices +(including copyright notices, patent notices, disclaimers of warranty, +or limitations of liability) contained within the Source Code Form of +the Covered Software, except that You may alter any license notices to +the extent required to remedy known factual inaccuracies. + +3.5. Application of Additional Terms + +You may choose to offer, and to charge a fee for, warranty, support, +indemnity or liability obligations to one or more recipients of Covered +Software. However, You may do so only on Your own behalf, and not on +behalf of any Contributor. You must make it absolutely clear that any +such warranty, support, indemnity, or liability obligation is offered by +You alone, and You hereby agree to indemnify every Contributor for any +liability incurred by such Contributor as a result of warranty, support, +indemnity or liability terms You offer. You may include additional +disclaimers of warranty and limitations of liability specific to any +jurisdiction. + +4. Inability to Comply Due to Statute or Regulation +--------------------------------------------------- + +If it is impossible for You to comply with any of the terms of this +License with respect to some or all of the Covered Software due to +statute, judicial order, or regulation then You must: (a) comply with +the terms of this License to the maximum extent possible; and (b) +describe the limitations and the code they affect. Such description must +be placed in a text file included with all distributions of the Covered +Software under this License. Except to the extent prohibited by statute +or regulation, such description must be sufficiently detailed for a +recipient of ordinary skill to be able to understand it. + +5. Termination +-------------- + +5.1. The rights granted under this License will terminate automatically +if You fail to comply with any of its terms. However, if You become +compliant, then the rights granted under this License from a particular +Contributor are reinstated (a) provisionally, unless and until such +Contributor explicitly and finally terminates Your grants, and (b) on an +ongoing basis, if such Contributor fails to notify You of the +non-compliance by some reasonable means prior to 60 days after You have +come back into compliance. Moreover, Your grants from a particular +Contributor are reinstated on an ongoing basis if such Contributor +notifies You of the non-compliance by some reasonable means, this is the +first time You have received notice of non-compliance with this License +from such Contributor, and You become compliant prior to 30 days after +Your receipt of the notice. + +5.2. If You initiate litigation against any entity by asserting a patent +infringement claim (excluding declaratory judgment actions, +counter-claims, and cross-claims) alleging that a Contributor Version +directly or indirectly infringes any patent, then the rights granted to +You by any and all Contributors for the Covered Software under Section +2.1 of this License shall terminate. + +5.3. In the event of termination under Sections 5.1 or 5.2 above, all +end user license agreements (excluding distributors and resellers) which +have been validly granted by You or Your distributors under this License +prior to termination shall survive termination. + +************************************************************************ +* * +* 6. Disclaimer of Warranty * +* ------------------------- * +* * +* Covered Software is provided under this License on an "as is" * +* basis, without warranty of any kind, either expressed, implied, or * +* statutory, including, without limitation, warranties that the * +* Covered Software is free of defects, merchantable, fit for a * +* particular purpose or non-infringing. The entire risk as to the * +* quality and performance of the Covered Software is with You. * +* Should any Covered Software prove defective in any respect, You * +* (not any Contributor) assume the cost of any necessary servicing, * +* repair, or correction. This disclaimer of warranty constitutes an * +* essential part of this License. No use of any Covered Software is * +* authorized under this License except under this disclaimer. * +* * +************************************************************************ + +************************************************************************ +* * +* 7. Limitation of Liability * +* -------------------------- * +* * +* Under no circumstances and under no legal theory, whether tort * +* (including negligence), contract, or otherwise, shall any * +* Contributor, or anyone who distributes Covered Software as * +* permitted above, be liable to You for any direct, indirect, * +* special, incidental, or consequential damages of any character * +* including, without limitation, damages for lost profits, loss of * +* goodwill, work stoppage, computer failure or malfunction, or any * +* and all other commercial damages or losses, even if such party * +* shall have been informed of the possibility of such damages. This * +* limitation of liability shall not apply to liability for death or * +* personal injury resulting from such party's negligence to the * +* extent applicable law prohibits such limitation. Some * +* jurisdictions do not allow the exclusion or limitation of * +* incidental or consequential damages, so this exclusion and * +* limitation may not apply to You. * +* * +************************************************************************ + +8. Litigation +------------- + +Any litigation relating to this License may be brought only in the +courts of a jurisdiction where the defendant maintains its principal +place of business and such litigation shall be governed by laws of that +jurisdiction, without reference to its conflict-of-law provisions. +Nothing in this Section shall prevent a party's ability to bring +cross-claims or counter-claims. + +9. Miscellaneous +---------------- + +This License represents the complete agreement concerning the subject +matter hereof. If any provision of this License is held to be +unenforceable, such provision shall be reformed only to the extent +necessary to make it enforceable. Any law or regulation which provides +that the language of a contract shall be construed against the drafter +shall not be used to construe this License against a Contributor. + +10. Versions of the License +--------------------------- + +10.1. New Versions + +Mozilla Foundation is the license steward. Except as provided in Section +10.3, no one other than the license steward has the right to modify or +publish new versions of this License. Each version will be given a +distinguishing version number. + +10.2. Effect of New Versions + +You may distribute the Covered Software under the terms of the version +of the License under which You originally received the Covered Software, +or under the terms of any subsequent version published by the license +steward. + +10.3. Modified Versions + +If you create software not governed by this License, and you want to +create a new license for such software, you may create and use a +modified version of this License if you rename the license and remove +any references to the name of the license steward (except to note that +such modified license differs from this License). + +10.4. Distributing Source Code Form that is Incompatible With Secondary +Licenses + +If You choose to distribute Source Code Form that is Incompatible With +Secondary Licenses under the terms of this version of the License, the +notice described in Exhibit B of this License must be attached. + +Exhibit A - Source Code Form License Notice +------------------------------------------- + + This Source Code Form is subject to the terms of the Mozilla Public + License, v. 2.0. If a copy of the MPL was not distributed with this + file, You can obtain one at http://mozilla.org/MPL/2.0/. + +If it is not possible or desirable to put the notice in a particular +file, then You may include the notice in a location (such as a LICENSE +file in a relevant directory) where a recipient would be likely to look +for such a notice. + +You may add additional accurate notices of copyright ownership. + +Exhibit B - "Incompatible With Secondary Licenses" Notice +--------------------------------------------------------- + + This Source Code Form is "Incompatible With Secondary Licenses", as + defined by the Mozilla Public License, v. 2.0. diff --git a/pkg/hicli/cryptohelper.go b/pkg/hicli/cryptohelper.go new file mode 100644 index 0000000..bc5a00a --- /dev/null +++ b/pkg/hicli/cryptohelper.go @@ -0,0 +1,64 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "fmt" + "time" + + "github.com/rs/zerolog" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type hiCryptoHelper HiClient + +var _ mautrix.CryptoHelper = (*hiCryptoHelper)(nil) + +func (h *hiCryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (*event.EncryptedEventContent, error) { + roomMeta, err := h.DB.Room.Get(ctx, roomID) + if err != nil { + return nil, fmt.Errorf("failed to get room metadata: %w", err) + } else if roomMeta == nil { + return nil, fmt.Errorf("unknown room") + } + return (*HiClient)(h).Encrypt(ctx, roomMeta, evtType, content) +} + +func (h *hiCryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) { + return h.Crypto.DecryptMegolmEvent(ctx, evt) +} + +func (h *hiCryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { + return h.Crypto.WaitForSession(ctx, roomID, senderKey, sessionID, timeout) +} + +func (h *hiCryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { + err := h.Crypto.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{ + userID: {deviceID}, + h.Account.UserID: {"*"}, + }) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("room_id", roomID). + Stringer("session_id", sessionID). + Stringer("user_id", userID). + Msg("Failed to send room key request") + } else { + zerolog.Ctx(ctx).Debug(). + Stringer("room_id", roomID). + Stringer("session_id", sessionID). + Stringer("user_id", userID). + Msg("Sent room key request") + } +} + +func (h *hiCryptoHelper) Init(ctx context.Context) error { + return nil +} diff --git a/pkg/hicli/database/account.go b/pkg/hicli/database/account.go new file mode 100644 index 0000000..8bb4bfd --- /dev/null +++ b/pkg/hicli/database/account.go @@ -0,0 +1,73 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "errors" + + "go.mau.fi/util/dbutil" + "maunium.net/go/mautrix/id" +) + +const ( + getAccountQuery = `SELECT user_id, device_id, access_token, homeserver_url, next_batch FROM account WHERE user_id = $1` + putNextBatchQuery = `UPDATE account SET next_batch = $1 WHERE user_id = $2` + upsertAccountQuery = ` + INSERT INTO account (user_id, device_id, access_token, homeserver_url, next_batch) + VALUES ($1, $2, $3, $4, $5) ON CONFLICT (user_id) + DO UPDATE SET device_id = excluded.device_id, + access_token = excluded.access_token, + homeserver_url = excluded.homeserver_url, + next_batch = excluded.next_batch + ` +) + +type AccountQuery struct { + *dbutil.QueryHelper[*Account] +} + +func (aq *AccountQuery) GetFirstUserID(ctx context.Context) (userID id.UserID, err error) { + var exists bool + if exists, err = aq.GetDB().TableExists(ctx, "account"); err != nil || !exists { + return + } + err = aq.GetDB().QueryRow(ctx, `SELECT user_id FROM account LIMIT 1`).Scan(&userID) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + +func (aq *AccountQuery) Get(ctx context.Context, userID id.UserID) (*Account, error) { + return aq.QueryOne(ctx, getAccountQuery, userID) +} + +func (aq *AccountQuery) PutNextBatch(ctx context.Context, userID id.UserID, nextBatch string) error { + return aq.Exec(ctx, putNextBatchQuery, nextBatch, userID) +} + +func (aq *AccountQuery) Put(ctx context.Context, account *Account) error { + return aq.Exec(ctx, upsertAccountQuery, account.sqlVariables()...) +} + +type Account struct { + UserID id.UserID + DeviceID id.DeviceID + AccessToken string + HomeserverURL string + NextBatch string +} + +func (a *Account) Scan(row dbutil.Scannable) (*Account, error) { + return dbutil.ValueOrErr(a, row.Scan(&a.UserID, &a.DeviceID, &a.AccessToken, &a.HomeserverURL, &a.NextBatch)) +} + +func (a *Account) sqlVariables() []any { + return []any{a.UserID, a.DeviceID, a.AccessToken, a.HomeserverURL, a.NextBatch} +} diff --git a/pkg/hicli/database/accountdata.go b/pkg/hicli/database/accountdata.go new file mode 100644 index 0000000..55c7826 --- /dev/null +++ b/pkg/hicli/database/accountdata.go @@ -0,0 +1,67 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "encoding/json" + "unsafe" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const ( + upsertAccountDataQuery = ` + INSERT INTO account_data (user_id, type, content) VALUES ($1, $2, $3) + ON CONFLICT (user_id, type) DO UPDATE SET content = excluded.content + ` + upsertRoomAccountDataQuery = ` + INSERT INTO room_account_data (user_id, room_id, type, content) VALUES ($1, $2, $3, $4) + ON CONFLICT (user_id, room_id, type) DO UPDATE SET content = excluded.content + ` +) + +type AccountDataQuery struct { + *dbutil.QueryHelper[*AccountData] +} + +func unsafeJSONString(content json.RawMessage) *string { + if content == nil { + return nil + } + str := unsafe.String(unsafe.SliceData(content), len(content)) + return &str +} + +func (adq *AccountDataQuery) Put(ctx context.Context, userID id.UserID, eventType event.Type, content json.RawMessage) error { + return adq.Exec(ctx, upsertAccountDataQuery, userID, eventType.Type, unsafeJSONString(content)) +} + +func (adq *AccountDataQuery) PutRoom(ctx context.Context, userID id.UserID, roomID id.RoomID, eventType event.Type, content json.RawMessage) error { + return adq.Exec(ctx, upsertRoomAccountDataQuery, userID, roomID, eventType.Type, unsafeJSONString(content)) +} + +type AccountData struct { + UserID id.UserID `json:"user_id"` + RoomID id.RoomID `json:"room_id,omitempty"` + Type string `json:"type"` + Content json.RawMessage `json:"content"` +} + +func (a *AccountData) Scan(row dbutil.Scannable) (*AccountData, error) { + var roomID sql.NullString + err := row.Scan(&a.UserID, &roomID, &a.Type, (*[]byte)(&a.Content)) + if err != nil { + return nil, err + } + a.RoomID = id.RoomID(roomID.String) + return a, nil +} diff --git a/pkg/hicli/database/cachedmedia.go b/pkg/hicli/database/cachedmedia.go new file mode 100644 index 0000000..35cb9a9 --- /dev/null +++ b/pkg/hicli/database/cachedmedia.go @@ -0,0 +1,149 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "net/http" + "slices" + "time" + + "go.mau.fi/util/dbutil" + "go.mau.fi/util/jsontime" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto/attachment" + "maunium.net/go/mautrix/id" +) + +const ( + insertCachedMediaQuery = ` + INSERT INTO cached_media (mxc, event_rowid, enc_file, file_name, mime_type, size, hash, error) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ON CONFLICT (mxc) DO NOTHING + ` + upsertCachedMediaQuery = ` + INSERT INTO cached_media (mxc, event_rowid, enc_file, file_name, mime_type, size, hash, error) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ON CONFLICT (mxc) DO UPDATE + SET enc_file = excluded.enc_file, + file_name = excluded.file_name, + mime_type = excluded.mime_type, + size = excluded.size, + hash = excluded.hash, + error = excluded.error + WHERE excluded.error IS NULL OR cached_media.hash IS NULL + ` + getCachedMediaQuery = ` + SELECT mxc, event_rowid, enc_file, file_name, mime_type, size, hash, error + FROM cached_media + WHERE mxc = $1 + ` +) + +type CachedMediaQuery struct { + *dbutil.QueryHelper[*CachedMedia] +} + +func (cmq *CachedMediaQuery) Add(ctx context.Context, cm *CachedMedia) error { + return cmq.Exec(ctx, insertCachedMediaQuery, cm.sqlVariables()...) +} + +func (cmq *CachedMediaQuery) Put(ctx context.Context, cm *CachedMedia) error { + return cmq.Exec(ctx, upsertCachedMediaQuery, cm.sqlVariables()...) +} + +func (cmq *CachedMediaQuery) Get(ctx context.Context, mxc id.ContentURI) (*CachedMedia, error) { + return cmq.QueryOne(ctx, getCachedMediaQuery, &mxc) +} + +type MediaError struct { + Matrix *mautrix.RespError `json:"data"` + StatusCode int `json:"status_code"` + ReceivedAt jsontime.UnixMilli `json:"received_at"` + Attempts int `json:"attempts"` +} + +const MaxMediaBackoff = 7 * 24 * time.Hour + +func (me *MediaError) backoff() time.Duration { + return min(time.Duration(2< 0 { + err = eq.Exec(ctx, updateReactionCountsQuery, evtID, dbutil.JSON{Data: &res.Counts}) + if err != nil { + return err + } + } + } + return nil + }) +} + +type EventRowID int64 + +func (m EventRowID) GetMassInsertValues() [1]any { + return [1]any{m} +} + +type LocalContent struct { + SanitizedHTML string `json:"sanitized_html,omitempty"` + HTMLVersion int `json:"html_version,omitempty"` +} + +type UnreadType int + +func (ut UnreadType) Is(flag UnreadType) bool { + return ut&flag != 0 +} + +const ( + UnreadTypeNone UnreadType = 0b0000 + UnreadTypeNormal UnreadType = 0b0001 + UnreadTypeNotify UnreadType = 0b0010 + UnreadTypeHighlight UnreadType = 0b0100 + UnreadTypeSound UnreadType = 0b1000 +) + +type Event struct { + RowID EventRowID `json:"rowid"` + TimelineRowID TimelineRowID `json:"timeline_rowid"` + + RoomID id.RoomID `json:"room_id"` + ID id.EventID `json:"event_id"` + Sender id.UserID `json:"sender"` + Type string `json:"type"` + StateKey *string `json:"state_key,omitempty"` + Timestamp jsontime.UnixMilli `json:"timestamp"` + + Content json.RawMessage `json:"content"` + Decrypted json.RawMessage `json:"decrypted,omitempty"` + DecryptedType string `json:"decrypted_type,omitempty"` + Unsigned json.RawMessage `json:"unsigned,omitempty"` + LocalContent *LocalContent `json:"local_content,omitempty"` + + TransactionID string `json:"transaction_id,omitempty"` + + RedactedBy id.EventID `json:"redacted_by,omitempty"` + RelatesTo id.EventID `json:"relates_to,omitempty"` + RelationType event.RelationType `json:"relation_type,omitempty"` + + MegolmSessionID id.SessionID `json:"-,omitempty"` + DecryptionError string `json:"decryption_error,omitempty"` + SendError string `json:"send_error,omitempty"` + + Reactions map[string]int `json:"reactions,omitempty"` + LastEditRowID *EventRowID `json:"last_edit_rowid,omitempty"` + UnreadType UnreadType `json:"unread_type,omitempty"` +} + +func MautrixToEvent(evt *event.Event) *Event { + dbEvt := &Event{ + RoomID: evt.RoomID, + ID: evt.ID, + Sender: evt.Sender, + Type: evt.Type.Type, + StateKey: evt.StateKey, + Timestamp: jsontime.UM(time.UnixMilli(evt.Timestamp)), + Content: evt.Content.VeryRaw, + MegolmSessionID: getMegolmSessionID(evt), + TransactionID: evt.Unsigned.TransactionID, + } + if !strings.HasPrefix(dbEvt.TransactionID, "hicli-mautrix-go_") { + dbEvt.TransactionID = "" + } + dbEvt.RelatesTo, dbEvt.RelationType = getRelatesToFromEvent(evt) + dbEvt.Unsigned, _ = json.Marshal(&evt.Unsigned) + if evt.Unsigned.RedactedBecause != nil { + dbEvt.RedactedBy = evt.Unsigned.RedactedBecause.ID + } + return dbEvt +} + +func (e *Event) AsRawMautrix() *event.Event { + if e == nil { + return nil + } + evt := &event.Event{ + RoomID: e.RoomID, + ID: e.ID, + Sender: e.Sender, + Type: event.Type{Type: e.Type, Class: event.MessageEventType}, + StateKey: e.StateKey, + Timestamp: e.Timestamp.UnixMilli(), + Content: event.Content{VeryRaw: e.Content}, + } + if e.Decrypted != nil { + evt.Content.VeryRaw = e.Decrypted + evt.Type.Type = e.DecryptedType + evt.Mautrix.WasEncrypted = true + } + if e.StateKey != nil { + evt.Type.Class = event.StateEventType + } + _ = json.Unmarshal(e.Unsigned, &evt.Unsigned) + return evt +} + +func (e *Event) Scan(row dbutil.Scannable) (*Event, error) { + var timestamp int64 + var transactionID, redactedBy, relatesTo, relationType, megolmSessionID, decryptionError, sendError, decryptedType sql.NullString + err := row.Scan( + &e.RowID, + &e.TimelineRowID, + &e.RoomID, + &e.ID, + &e.Sender, + &e.Type, + &e.StateKey, + ×tamp, + (*[]byte)(&e.Content), + (*[]byte)(&e.Decrypted), + &decryptedType, + (*[]byte)(&e.Unsigned), + dbutil.JSON{Data: &e.LocalContent}, + &transactionID, + &redactedBy, + &relatesTo, + &relationType, + &megolmSessionID, + &decryptionError, + &sendError, + dbutil.JSON{Data: &e.Reactions}, + &e.LastEditRowID, + &e.UnreadType, + ) + if err != nil { + return nil, err + } + e.Timestamp = jsontime.UM(time.UnixMilli(timestamp)) + e.TransactionID = transactionID.String + e.RedactedBy = id.EventID(redactedBy.String) + e.RelatesTo = id.EventID(relatesTo.String) + e.RelationType = event.RelationType(relationType.String) + e.MegolmSessionID = id.SessionID(megolmSessionID.String) + e.DecryptedType = decryptedType.String + e.DecryptionError = decryptionError.String + e.SendError = sendError.String + return e, nil +} + +var relatesToPath = exgjson.Path("m.relates_to", "event_id") +var relationTypePath = exgjson.Path("m.relates_to", "rel_type") + +func getRelatesToFromEvent(evt *event.Event) (id.EventID, event.RelationType) { + if evt.StateKey != nil { + return "", "" + } + return GetRelatesToFromBytes(evt.Content.VeryRaw) +} + +func GetRelatesToFromBytes(content []byte) (id.EventID, event.RelationType) { + results := gjson.GetManyBytes(content, relatesToPath, relationTypePath) + if len(results) == 2 && results[0].Exists() && results[1].Exists() && results[0].Type == gjson.String && results[1].Type == gjson.String { + return id.EventID(results[0].Str), event.RelationType(results[1].Str) + } + return "", "" +} + +func getMegolmSessionID(evt *event.Event) id.SessionID { + if evt.Type != event.EventEncrypted { + return "" + } + res := gjson.GetBytes(evt.Content.VeryRaw, "session_id") + if res.Exists() && res.Type == gjson.String { + return id.SessionID(res.Str) + } + return "" +} + +func (e *Event) sqlVariables() []any { + var reactions any + if e.Reactions != nil { + reactions = e.Reactions + } + return []any{ + e.RoomID, + e.ID, + e.Sender, + e.Type, + e.StateKey, + e.Timestamp.UnixMilli(), + unsafeJSONString(e.Content), + unsafeJSONString(e.Decrypted), + dbutil.StrPtr(e.DecryptedType), + unsafeJSONString(e.Unsigned), + dbutil.JSONPtr(e.LocalContent), + dbutil.StrPtr(e.TransactionID), + dbutil.StrPtr(e.RedactedBy), + dbutil.StrPtr(e.RelatesTo), + dbutil.StrPtr(e.RelationType), + dbutil.StrPtr(e.MegolmSessionID), + dbutil.StrPtr(e.DecryptionError), + dbutil.StrPtr(e.SendError), + dbutil.JSON{Data: reactions}, + e.LastEditRowID, + e.UnreadType, + } +} + +func (e *Event) GetNonPushUnreadType() UnreadType { + if e.RelationType == event.RelReplace { + return UnreadTypeNone + } + switch e.Type { + case event.EventMessage.Type, event.EventSticker.Type: + return UnreadTypeNormal + case event.EventEncrypted.Type: + switch e.DecryptedType { + case event.EventMessage.Type, event.EventSticker.Type: + return UnreadTypeNormal + } + } + return UnreadTypeNone +} + +func (e *Event) CanUseForPreview() bool { + return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type || + (e.Type == event.EventEncrypted.Type && + (e.DecryptedType == event.EventMessage.Type || e.DecryptedType == event.EventSticker.Type))) && + e.RelationType != event.RelReplace && e.RedactedBy == "" +} + +func (e *Event) BumpsSortingTimestamp() bool { + return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type || e.Type == event.EventEncrypted.Type) && + e.RelationType != event.RelReplace +} diff --git a/pkg/hicli/database/receipt.go b/pkg/hicli/database/receipt.go new file mode 100644 index 0000000..8b20816 --- /dev/null +++ b/pkg/hicli/database/receipt.go @@ -0,0 +1,81 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "slices" + "time" + + "go.mau.fi/util/dbutil" + "go.mau.fi/util/jsontime" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const ( + upsertReceiptQuery = ` + INSERT INTO receipt (room_id, user_id, receipt_type, thread_id, event_id, timestamp) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (room_id, user_id, receipt_type, thread_id) DO UPDATE + SET event_id = excluded.event_id, + timestamp = excluded.timestamp + ` +) + +var receiptMassInserter = dbutil.NewMassInsertBuilder[*Receipt, [1]any](upsertReceiptQuery, "($1, $%d, $%d, $%d, $%d, $%d)") + +type ReceiptQuery struct { + *dbutil.QueryHelper[*Receipt] +} + +func (rq *ReceiptQuery) Put(ctx context.Context, receipt *Receipt) error { + return rq.Exec(ctx, upsertReceiptQuery, receipt.sqlVariables()...) +} + +func (rq *ReceiptQuery) PutMany(ctx context.Context, roomID id.RoomID, receipts ...*Receipt) error { + if len(receipts) > 1000 { + return rq.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error { + for receiptChunk := range slices.Chunk(receipts, 200) { + err := rq.PutMany(ctx, roomID, receiptChunk...) + if err != nil { + return err + } + } + return nil + }) + } + query, params := receiptMassInserter.Build([1]any{roomID}, receipts) + return rq.Exec(ctx, query, params...) +} + +type Receipt struct { + RoomID id.RoomID `json:"room_id"` + UserID id.UserID `json:"user_id"` + ReceiptType event.ReceiptType `json:"receipt_type"` + ThreadID event.ThreadID `json:"thread_id"` + EventID id.EventID `json:"event_id"` + Timestamp jsontime.UnixMilli `json:"timestamp"` +} + +func (r *Receipt) Scan(row dbutil.Scannable) (*Receipt, error) { + var ts int64 + err := row.Scan(&r.RoomID, &r.UserID, &r.ReceiptType, &r.ThreadID, &r.EventID, &ts) + if err != nil { + return nil, err + } + r.Timestamp = jsontime.UM(time.UnixMilli(ts)) + return r, nil +} + +func (r *Receipt) sqlVariables() []any { + return []any{r.RoomID, r.UserID, r.ReceiptType, r.ThreadID, r.EventID, r.Timestamp.UnixMilli()} +} + +func (r *Receipt) GetMassInsertValues() [5]any { + return [5]any{r.UserID, r.ReceiptType, r.ThreadID, r.EventID, r.Timestamp.UnixMilli()} +} diff --git a/pkg/hicli/database/room.go b/pkg/hicli/database/room.go new file mode 100644 index 0000000..fe38f2f --- /dev/null +++ b/pkg/hicli/database/room.go @@ -0,0 +1,278 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "errors" + "time" + + "go.mau.fi/util/dbutil" + "go.mau.fi/util/jsontime" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const ( + getRoomBaseQuery = ` + SELECT room_id, creation_content, name, name_quality, avatar, explicit_avatar, topic, canonical_alias, + lazy_load_summary, encryption_event, has_member_list, preview_event_rowid, sorting_timestamp, + unread_highlights, unread_notifications, unread_messages, prev_batch + FROM room + ` + getRoomsBySortingTimestampQuery = getRoomBaseQuery + `WHERE sorting_timestamp < $1 AND sorting_timestamp > 0 ORDER BY sorting_timestamp DESC LIMIT $2` + getRoomByIDQuery = getRoomBaseQuery + `WHERE room_id = $1` + ensureRoomExistsQuery = ` + INSERT INTO room (room_id) VALUES ($1) + ON CONFLICT (room_id) DO NOTHING + ` + upsertRoomFromSyncQuery = ` + UPDATE room + SET creation_content = COALESCE(room.creation_content, $2), + name = COALESCE($3, room.name), + name_quality = CASE WHEN $3 IS NOT NULL THEN $4 ELSE room.name_quality END, + avatar = COALESCE($5, room.avatar), + explicit_avatar = CASE WHEN $5 IS NOT NULL THEN $6 ELSE room.explicit_avatar END, + topic = COALESCE($7, room.topic), + canonical_alias = COALESCE($8, room.canonical_alias), + lazy_load_summary = COALESCE($9, room.lazy_load_summary), + encryption_event = COALESCE($10, room.encryption_event), + has_member_list = room.has_member_list OR $11, + preview_event_rowid = COALESCE($12, room.preview_event_rowid), + sorting_timestamp = COALESCE($13, room.sorting_timestamp), + unread_highlights = COALESCE($14, room.unread_highlights), + unread_notifications = COALESCE($15, room.unread_notifications), + unread_messages = COALESCE($16, room.unread_messages), + prev_batch = COALESCE($17, room.prev_batch) + WHERE room_id = $1 + ` + setRoomPrevBatchQuery = ` + UPDATE room SET prev_batch = $2 WHERE room_id = $1 + ` + updateRoomPreviewIfLaterOnTimelineQuery = ` + UPDATE room + SET preview_event_rowid = $2 + WHERE room_id = $1 + AND COALESCE((SELECT rowid FROM timeline WHERE event_rowid = $2), -1) + > COALESCE((SELECT rowid FROM timeline WHERE event_rowid = preview_event_rowid), 0) + RETURNING preview_event_rowid + ` + recalculateRoomPreviewEventQuery = ` + SELECT rowid + FROM event + WHERE + room_id = $1 + AND (type IN ('m.room.message', 'm.sticker') + OR (type = 'm.room.encrypted' + AND decrypted_type IN ('m.room.message', 'm.sticker'))) + AND relation_type <> 'm.replace' + AND redacted_by IS NULL + ORDER BY timestamp DESC + LIMIT 1 + ` +) + +type RoomQuery struct { + *dbutil.QueryHelper[*Room] +} + +func (rq *RoomQuery) Get(ctx context.Context, roomID id.RoomID) (*Room, error) { + return rq.QueryOne(ctx, getRoomByIDQuery, roomID) +} + +func (rq *RoomQuery) GetBySortTS(ctx context.Context, maxTS time.Time, limit int) ([]*Room, error) { + return rq.QueryMany(ctx, getRoomsBySortingTimestampQuery, maxTS.UnixMilli(), limit) +} + +func (rq *RoomQuery) Upsert(ctx context.Context, room *Room) error { + return rq.Exec(ctx, upsertRoomFromSyncQuery, room.sqlVariables()...) +} + +func (rq *RoomQuery) CreateRow(ctx context.Context, roomID id.RoomID) error { + return rq.Exec(ctx, ensureRoomExistsQuery, roomID) +} + +func (rq *RoomQuery) SetPrevBatch(ctx context.Context, roomID id.RoomID, prevBatch string) error { + return rq.Exec(ctx, setRoomPrevBatchQuery, roomID, prevBatch) +} + +func (rq *RoomQuery) UpdatePreviewIfLaterOnTimeline(ctx context.Context, roomID id.RoomID, rowID EventRowID) (previewChanged bool, err error) { + var newPreviewRowID EventRowID + err = rq.GetDB().QueryRow(ctx, updateRoomPreviewIfLaterOnTimelineQuery, roomID, rowID).Scan(&newPreviewRowID) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } else if err == nil { + previewChanged = newPreviewRowID == rowID + } + return +} + +func (rq *RoomQuery) RecalculatePreview(ctx context.Context, roomID id.RoomID) (rowID EventRowID, err error) { + err = rq.GetDB().QueryRow(ctx, recalculateRoomPreviewEventQuery, roomID).Scan(&rowID) + return +} + +type NameQuality int + +const ( + NameQualityNil NameQuality = iota + NameQualityParticipants + NameQualityCanonicalAlias + NameQualityExplicit +) + +const PrevBatchPaginationComplete = "fi.mau.gomuks.pagination_complete" + +type Room struct { + ID id.RoomID `json:"room_id"` + CreationContent *event.CreateEventContent `json:"creation_content,omitempty"` + + Name *string `json:"name,omitempty"` + NameQuality NameQuality `json:"name_quality"` + Avatar *id.ContentURI `json:"avatar,omitempty"` + ExplicitAvatar bool `json:"explicit_avatar"` + Topic *string `json:"topic,omitempty"` + CanonicalAlias *id.RoomAlias `json:"canonical_alias,omitempty"` + + LazyLoadSummary *mautrix.LazyLoadSummary `json:"lazy_load_summary,omitempty"` + + EncryptionEvent *event.EncryptionEventContent `json:"encryption_event,omitempty"` + HasMemberList bool `json:"has_member_list"` + + PreviewEventRowID EventRowID `json:"preview_event_rowid"` + SortingTimestamp jsontime.UnixMilli `json:"sorting_timestamp"` + UnreadHighlights int `json:"unread_highlights"` + UnreadNotifications int `json:"unread_notifications"` + UnreadMessages int `json:"unread_messages"` + + PrevBatch string `json:"prev_batch"` +} + +func (r *Room) CheckChangesAndCopyInto(other *Room) (hasChanges bool) { + if r.Name != nil && r.NameQuality >= other.NameQuality { + other.Name = r.Name + other.NameQuality = r.NameQuality + hasChanges = true + } + if r.Avatar != nil { + other.Avatar = r.Avatar + other.ExplicitAvatar = r.ExplicitAvatar + hasChanges = true + } + if r.Topic != nil { + other.Topic = r.Topic + hasChanges = true + } + if r.CanonicalAlias != nil { + other.CanonicalAlias = r.CanonicalAlias + hasChanges = true + } + if r.LazyLoadSummary != nil { + other.LazyLoadSummary = r.LazyLoadSummary + hasChanges = true + } + if r.EncryptionEvent != nil && other.EncryptionEvent == nil { + other.EncryptionEvent = r.EncryptionEvent + hasChanges = true + } + if r.HasMemberList && !other.HasMemberList { + hasChanges = true + other.HasMemberList = true + } + if r.PreviewEventRowID > other.PreviewEventRowID { + other.PreviewEventRowID = r.PreviewEventRowID + hasChanges = true + } + if r.SortingTimestamp.After(other.SortingTimestamp.Time) { + other.SortingTimestamp = r.SortingTimestamp + hasChanges = true + } + if r.UnreadHighlights != other.UnreadHighlights { + other.UnreadHighlights = r.UnreadHighlights + hasChanges = true + } + if r.UnreadNotifications != other.UnreadNotifications { + other.UnreadNotifications = r.UnreadNotifications + hasChanges = true + } + if r.UnreadMessages != other.UnreadMessages { + other.UnreadMessages = r.UnreadMessages + hasChanges = true + } + if r.PrevBatch != "" && other.PrevBatch == "" { + other.PrevBatch = r.PrevBatch + hasChanges = true + } + return +} + +func (r *Room) Scan(row dbutil.Scannable) (*Room, error) { + var prevBatch sql.NullString + var previewEventRowID, sortingTimestamp sql.NullInt64 + err := row.Scan( + &r.ID, + dbutil.JSON{Data: &r.CreationContent}, + &r.Name, + &r.NameQuality, + &r.Avatar, + &r.ExplicitAvatar, + &r.Topic, + &r.CanonicalAlias, + dbutil.JSON{Data: &r.LazyLoadSummary}, + dbutil.JSON{Data: &r.EncryptionEvent}, + &r.HasMemberList, + &previewEventRowID, + &sortingTimestamp, + &r.UnreadHighlights, + &r.UnreadNotifications, + &r.UnreadMessages, + &prevBatch, + ) + if err != nil { + return nil, err + } + r.PrevBatch = prevBatch.String + r.PreviewEventRowID = EventRowID(previewEventRowID.Int64) + r.SortingTimestamp = jsontime.UM(time.UnixMilli(sortingTimestamp.Int64)) + return r, nil +} + +func (r *Room) sqlVariables() []any { + return []any{ + r.ID, + dbutil.JSONPtr(r.CreationContent), + r.Name, + r.NameQuality, + r.Avatar, + r.ExplicitAvatar, + r.Topic, + r.CanonicalAlias, + dbutil.JSONPtr(r.LazyLoadSummary), + dbutil.JSONPtr(r.EncryptionEvent), + r.HasMemberList, + dbutil.NumPtr(r.PreviewEventRowID), + dbutil.UnixMilliPtr(r.SortingTimestamp.Time), + r.UnreadHighlights, + r.UnreadNotifications, + r.UnreadMessages, + dbutil.StrPtr(r.PrevBatch), + } +} + +func (r *Room) BumpSortingTimestamp(evt *Event) bool { + if !evt.BumpsSortingTimestamp() || evt.Timestamp.Before(r.SortingTimestamp.Time) { + return false + } + r.SortingTimestamp = evt.Timestamp + now := time.Now() + if r.SortingTimestamp.After(now) { + r.SortingTimestamp = jsontime.UM(now) + } + return true +} diff --git a/pkg/hicli/database/sessionrequest.go b/pkg/hicli/database/sessionrequest.go new file mode 100644 index 0000000..fabd7c3 --- /dev/null +++ b/pkg/hicli/database/sessionrequest.go @@ -0,0 +1,68 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + + "go.mau.fi/util/dbutil" + "maunium.net/go/mautrix/id" +) + +const ( + putSessionRequestQueueEntry = ` + INSERT INTO session_request (room_id, session_id, sender, min_index, backup_checked, request_sent) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (session_id) DO UPDATE + SET min_index = MIN(excluded.min_index, session_request.min_index), + backup_checked = excluded.backup_checked OR session_request.backup_checked, + request_sent = excluded.request_sent OR session_request.request_sent + ` + removeSessionRequestQuery = ` + DELETE FROM session_request WHERE session_id = $1 AND min_index >= $2 + ` + getNextSessionsToRequestQuery = ` + SELECT room_id, session_id, sender, min_index, backup_checked, request_sent + FROM session_request + WHERE request_sent = false OR backup_checked = false + ORDER BY backup_checked, rowid + LIMIT $1 + ` +) + +type SessionRequestQuery struct { + *dbutil.QueryHelper[*SessionRequest] +} + +func (srq *SessionRequestQuery) Next(ctx context.Context, count int) ([]*SessionRequest, error) { + return srq.QueryMany(ctx, getNextSessionsToRequestQuery, count) +} + +func (srq *SessionRequestQuery) Remove(ctx context.Context, sessionID id.SessionID, minIndex uint32) error { + return srq.Exec(ctx, removeSessionRequestQuery, sessionID, minIndex) +} + +func (srq *SessionRequestQuery) Put(ctx context.Context, sr *SessionRequest) error { + return srq.Exec(ctx, putSessionRequestQueueEntry, sr.sqlVariables()...) +} + +type SessionRequest struct { + RoomID id.RoomID + SessionID id.SessionID + Sender id.UserID + MinIndex uint32 + BackupChecked bool + RequestSent bool +} + +func (s *SessionRequest) Scan(row dbutil.Scannable) (*SessionRequest, error) { + return dbutil.ValueOrErr(s, row.Scan(&s.RoomID, &s.SessionID, &s.Sender, &s.MinIndex, &s.BackupChecked, &s.RequestSent)) +} + +func (s *SessionRequest) sqlVariables() []any { + return []any{s.RoomID, s.SessionID, s.Sender, s.MinIndex, s.BackupChecked, s.RequestSent} +} diff --git a/pkg/hicli/database/state.go b/pkg/hicli/database/state.go new file mode 100644 index 0000000..62e6b7b --- /dev/null +++ b/pkg/hicli/database/state.go @@ -0,0 +1,94 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "fmt" + "slices" + + "go.mau.fi/util/dbutil" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const ( + setCurrentStateQuery = ` + INSERT INTO current_state (room_id, event_type, state_key, event_rowid, membership) VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (room_id, event_type, state_key) DO UPDATE SET event_rowid = excluded.event_rowid, membership = excluded.membership + ` + addCurrentStateQuery = ` + INSERT INTO current_state (room_id, event_type, state_key, event_rowid, membership) VALUES ($1, $2, $3, $4, $5) + ON CONFLICT DO NOTHING + ` + deleteCurrentStateQuery = ` + DELETE FROM current_state WHERE room_id = $1 + ` + getCurrentRoomStateQuery = ` + SELECT event.rowid, -1, + event.room_id, event.event_id, sender, event.type, event.state_key, timestamp, content, decrypted, decrypted_type, + unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type, + megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type + FROM current_state cs + JOIN event ON cs.event_rowid = event.rowid + WHERE cs.room_id = $1 + ` + getCurrentStateEventQuery = getCurrentRoomStateQuery + `AND cs.event_type = $2 AND cs.state_key = $3` +) + +var massInsertCurrentStateBuilder = dbutil.NewMassInsertBuilder[*CurrentStateEntry, [1]any](addCurrentStateQuery, "($1, $%d, $%d, $%d, $%d)") + +const currentStateMassInsertBatchSize = 1000 + +type CurrentStateEntry struct { + EventType event.Type + StateKey string + EventRowID EventRowID + Membership event.Membership +} + +func (cse *CurrentStateEntry) GetMassInsertValues() [4]any { + return [4]any{cse.EventType.Type, cse.StateKey, cse.EventRowID, dbutil.StrPtr(cse.Membership)} +} + +type CurrentStateQuery struct { + *dbutil.QueryHelper[*Event] +} + +func (csq *CurrentStateQuery) Set(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, eventRowID EventRowID, membership event.Membership) error { + return csq.Exec(ctx, setCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership)) +} + +func (csq *CurrentStateQuery) AddMany(ctx context.Context, roomID id.RoomID, deleteOld bool, entries []*CurrentStateEntry) error { + var err error + if deleteOld { + err = csq.Exec(ctx, deleteCurrentStateQuery, roomID) + if err != nil { + return fmt.Errorf("failed to delete old state: %w", err) + } + } + for entryChunk := range slices.Chunk(entries, currentStateMassInsertBatchSize) { + query, params := massInsertCurrentStateBuilder.Build([1]any{roomID}, entryChunk) + err = csq.Exec(ctx, query, params...) + if err != nil { + return err + } + } + return nil +} + +func (csq *CurrentStateQuery) Add(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, eventRowID EventRowID, membership event.Membership) error { + return csq.Exec(ctx, addCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership)) +} + +func (csq *CurrentStateQuery) Get(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*Event, error) { + return csq.QueryOne(ctx, getCurrentStateEventQuery, roomID, eventType.Type, stateKey) +} + +func (csq *CurrentStateQuery) GetAll(ctx context.Context, roomID id.RoomID) ([]*Event, error) { + return csq.QueryMany(ctx, getCurrentRoomStateQuery, roomID) +} diff --git a/pkg/hicli/database/statestore.go b/pkg/hicli/database/statestore.go new file mode 100644 index 0000000..2f39ff0 --- /dev/null +++ b/pkg/hicli/database/statestore.go @@ -0,0 +1,187 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "errors" + "fmt" + "slices" + + "go.mau.fi/util/dbutil" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const ( + getMembershipQuery = ` + SELECT membership FROM current_state + WHERE room_id = $1 AND event_type = 'm.room.member' AND state_key = $2 + ` + getStateEventContentQuery = ` + SELECT event.content FROM current_state cs + LEFT JOIN event ON event.rowid = cs.event_rowid + WHERE cs.room_id = $1 AND cs.event_type = $2 AND cs.state_key = $3 + ` + getRoomJoinedMembersQuery = ` + SELECT state_key FROM current_state + WHERE room_id = $1 AND event_type = 'm.room.member' AND membership = 'join' + ` + getRoomJoinedOrInvitedMembersQuery = ` + SELECT state_key FROM current_state + WHERE room_id = $1 AND event_type = 'm.room.member' AND membership IN ('join', 'invite') + ` + getHasFetchedMembersQuery = ` + SELECT has_member_list FROM room WHERE room_id = $1 + ` + isRoomEncryptedQuery = ` + SELECT room.encryption_event IS NOT NULL FROM room WHERE room_id = $1 + ` + getRoomEncryptionEventQuery = ` + SELECT room.encryption_event FROM room WHERE room_id = $1 + ` + findSharedRoomsQuery = ` + SELECT room_id FROM current_state + WHERE event_type = 'm.room.member' AND state_key = $1 AND membership = 'join' + ` +) + +type ClientStateStore struct { + *Database +} + +var _ mautrix.StateStore = (*ClientStateStore)(nil) +var _ mautrix.StateStoreUpdater = (*ClientStateStore)(nil) +var _ crypto.StateStore = (*ClientStateStore)(nil) + +func (c *ClientStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool { + return c.IsMembership(ctx, roomID, userID, event.MembershipJoin) +} + +func (c *ClientStateStore) IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool { + return c.IsMembership(ctx, roomID, userID, event.MembershipInvite, event.MembershipJoin) +} + +func (c *ClientStateStore) IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool { + var membership event.Membership + err := c.QueryRow(ctx, getMembershipQuery, roomID, userID).Scan(&membership) + if errors.Is(err, sql.ErrNoRows) { + err = nil + membership = event.MembershipLeave + } + return slices.Contains(allowedMemberships, membership) +} + +func (c *ClientStateStore) GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) { + content, err := c.TryGetMember(ctx, roomID, userID) + if content == nil { + content = &event.MemberEventContent{Membership: event.MembershipLeave} + } + return content, err +} + +func (c *ClientStateStore) TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (content *event.MemberEventContent, err error) { + err = c.QueryRow(ctx, getStateEventContentQuery, roomID, event.StateMember.Type, userID).Scan(&dbutil.JSON{Data: &content}) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + +func (c *ClientStateStore) IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error) { + //TODO implement me + panic("implement me") +} + +func (c *ClientStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (content *event.PowerLevelsEventContent, err error) { + err = c.QueryRow(ctx, getStateEventContentQuery, roomID, event.StatePowerLevels.Type, "").Scan(&dbutil.JSON{Data: &content}) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + +func (c *ClientStateStore) GetRoomJoinedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) { + rows, err := c.Query(ctx, getRoomJoinedMembersQuery, roomID) + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() +} + +func (c *ClientStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) { + rows, err := c.Query(ctx, getRoomJoinedOrInvitedMembersQuery, roomID) + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() +} + +func (c *ClientStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (hasFetched bool, err error) { + //err = c.QueryRow(ctx, getHasFetchedMembersQuery, roomID).Scan(&hasFetched) + //if errors.Is(err, sql.ErrNoRows) { + // err = nil + //} + //return + return false, fmt.Errorf("not implemented") +} + +func (c *ClientStateStore) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error { + return fmt.Errorf("not implemented") +} + +func (c *ClientStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { + return nil, fmt.Errorf("not implemented") +} + +func (c *ClientStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (isEncrypted bool, err error) { + err = c.QueryRow(ctx, isRoomEncryptedQuery, roomID).Scan(&isEncrypted) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + +func (c *ClientStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (content *event.EncryptionEventContent, err error) { + err = c.QueryRow(ctx, getRoomEncryptionEventQuery, roomID). + Scan(&dbutil.JSON{Data: &content}) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + +func (c *ClientStateStore) FindSharedRooms(ctx context.Context, userID id.UserID) ([]id.RoomID, error) { + // TODO for multiuser support, this might need to filter by the local user's membership + rows, err := c.Query(ctx, findSharedRoomsQuery, userID) + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList() +} + +// Update methods are all intentionally no-ops as the state store wants to have the full event + +func (c *ClientStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error { + return nil +} + +func (c *ClientStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error { + return nil +} + +func (c *ClientStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error { + return nil +} + +func (c *ClientStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error { + return nil +} + +func (c *ClientStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error { + return nil +} + +func (c *ClientStateStore) UpdateState(ctx context.Context, evt *event.Event) {} + +func (c *ClientStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error { + return nil +} diff --git a/pkg/hicli/database/timeline.go b/pkg/hicli/database/timeline.go new file mode 100644 index 0000000..da9c194 --- /dev/null +++ b/pkg/hicli/database/timeline.go @@ -0,0 +1,135 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "errors" + "sync" + + "go.mau.fi/util/dbutil" + "maunium.net/go/mautrix/id" +) + +const ( + clearTimelineQuery = ` + DELETE FROM timeline WHERE room_id = $1 + ` + appendTimelineQuery = ` + INSERT INTO timeline (room_id, event_rowid) VALUES ($1, $2) + ON CONFLICT DO NOTHING + RETURNING rowid, event_rowid + ` + prependTimelineQuery = ` + INSERT INTO timeline (room_id, rowid, event_rowid) VALUES ($1, $2, $3) + ` + checkTimelineContainsQuery = ` + SELECT EXISTS(SELECT 1 FROM timeline WHERE room_id = $1 AND event_rowid = $2) + ` + findMinRowIDQuery = `SELECT MIN(rowid) FROM timeline` + getTimelineQuery = ` + SELECT event.rowid, timeline.rowid, + event.room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, + unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type, + megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type + FROM timeline + JOIN event ON event.rowid = timeline.event_rowid + WHERE timeline.room_id = $1 AND ($2 = 0 OR timeline.rowid < $2) + ORDER BY timeline.rowid DESC + LIMIT $3 + ` +) + +type TimelineRowID int64 + +type TimelineRowTuple struct { + Timeline TimelineRowID `json:"timeline_rowid"` + Event EventRowID `json:"event_rowid"` +} + +var timelineRowTupleScanner = dbutil.ConvertRowFn[TimelineRowTuple](func(row dbutil.Scannable) (trt TimelineRowTuple, err error) { + err = row.Scan(&trt.Timeline, &trt.Event) + return +}) + +func (trt TimelineRowTuple) GetMassInsertValues() [2]any { + return [2]any{trt.Timeline, trt.Event} +} + +var appendTimelineQueryBuilder = dbutil.NewMassInsertBuilder[EventRowID, [1]any](appendTimelineQuery, "($1, $%d)") +var prependTimelineQueryBuilder = dbutil.NewMassInsertBuilder[TimelineRowTuple, [1]any](prependTimelineQuery, "($1, $%d, $%d)") + +type TimelineQuery struct { + *dbutil.QueryHelper[*Event] + + minRowID TimelineRowID + minRowIDFound bool + prependLock sync.Mutex +} + +// Clear clears the timeline of a given room. +func (tq *TimelineQuery) Clear(ctx context.Context, roomID id.RoomID) error { + return tq.Exec(ctx, clearTimelineQuery, roomID) +} + +func (tq *TimelineQuery) reserveRowIDs(ctx context.Context, count int) (startFrom TimelineRowID, err error) { + tq.prependLock.Lock() + defer tq.prependLock.Unlock() + if !tq.minRowIDFound { + err = tq.GetDB().QueryRow(ctx, findMinRowIDQuery).Scan(&tq.minRowID) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return + } + if tq.minRowID >= 0 { + // No negative row IDs exist, start at -2 + tq.minRowID = -2 + } else { + // We fetched the lowest row ID, but we want the next available one, so decrement one + tq.minRowID-- + } + tq.minRowIDFound = true + } + startFrom = tq.minRowID + tq.minRowID -= TimelineRowID(count) + return +} + +// Prepend adds the given event row IDs to the beginning of the timeline. +// The events must be sorted in reverse chronological order (newest event first). +func (tq *TimelineQuery) Prepend(ctx context.Context, roomID id.RoomID, rowIDs []EventRowID) (prependEntries []TimelineRowTuple, err error) { + var startFrom TimelineRowID + startFrom, err = tq.reserveRowIDs(ctx, len(rowIDs)) + if err != nil { + return + } + prependEntries = make([]TimelineRowTuple, len(rowIDs)) + for i, rowID := range rowIDs { + prependEntries[i] = TimelineRowTuple{ + Timeline: startFrom - TimelineRowID(i), + Event: rowID, + } + } + query, params := prependTimelineQueryBuilder.Build([1]any{roomID}, prependEntries) + err = tq.Exec(ctx, query, params...) + return +} + +// Append adds the given event row IDs to the end of the timeline. +func (tq *TimelineQuery) Append(ctx context.Context, roomID id.RoomID, rowIDs []EventRowID) ([]TimelineRowTuple, error) { + query, params := appendTimelineQueryBuilder.Build([1]any{roomID}, rowIDs) + return timelineRowTupleScanner.NewRowIter(tq.GetDB().Query(ctx, query, params...)).AsList() +} + +func (tq *TimelineQuery) Get(ctx context.Context, roomID id.RoomID, limit int, before TimelineRowID) ([]*Event, error) { + return tq.QueryMany(ctx, getTimelineQuery, roomID, before, limit) +} + +func (tq *TimelineQuery) Has(ctx context.Context, roomID id.RoomID, eventRowID EventRowID) (exists bool, err error) { + err = tq.GetDB().QueryRow(ctx, checkTimelineContainsQuery, roomID, eventRowID).Scan(&exists) + return +} diff --git a/pkg/hicli/database/upgrades/00-latest-revision.sql b/pkg/hicli/database/upgrades/00-latest-revision.sql new file mode 100644 index 0000000..0808a6e --- /dev/null +++ b/pkg/hicli/database/upgrades/00-latest-revision.sql @@ -0,0 +1,255 @@ +-- v0 -> v3 (compatible with v1+): Latest revision +CREATE TABLE account ( + user_id TEXT NOT NULL PRIMARY KEY, + device_id TEXT NOT NULL, + access_token TEXT NOT NULL, + homeserver_url TEXT NOT NULL, + + next_batch TEXT NOT NULL +) STRICT; + +CREATE TABLE room ( + room_id TEXT NOT NULL PRIMARY KEY, + creation_content TEXT, + + name TEXT, + name_quality INTEGER NOT NULL DEFAULT 0, + avatar TEXT, + explicit_avatar INTEGER NOT NULL DEFAULT 0, + topic TEXT, + canonical_alias TEXT, + lazy_load_summary TEXT, + + encryption_event TEXT, + has_member_list INTEGER NOT NULL DEFAULT false, + + preview_event_rowid INTEGER, + sorting_timestamp INTEGER, + unread_highlights INTEGER NOT NULL DEFAULT 0, + unread_notifications INTEGER NOT NULL DEFAULT 0, + unread_messages INTEGER NOT NULL DEFAULT 0, + + prev_batch TEXT, + + CONSTRAINT room_preview_event_fkey FOREIGN KEY (preview_event_rowid) REFERENCES event (rowid) ON DELETE SET NULL +) STRICT; +CREATE INDEX room_type_idx ON room (creation_content ->> 'type'); +CREATE INDEX room_sorting_timestamp_idx ON room (sorting_timestamp DESC); +-- CREATE INDEX room_sorting_timestamp_idx ON room (unread_notifications > 0); +-- CREATE INDEX room_sorting_timestamp_idx ON room (unread_messages > 0); + +CREATE TABLE account_data ( + user_id TEXT NOT NULL, + type TEXT NOT NULL, + content TEXT NOT NULL, + + PRIMARY KEY (user_id, type) +) STRICT; + +CREATE TABLE room_account_data ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + content TEXT NOT NULL, + + PRIMARY KEY (user_id, room_id, type), + CONSTRAINT room_account_data_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE +) STRICT; +CREATE INDEX room_account_data_room_id_idx ON room_account_data (room_id); + +CREATE TABLE event ( + rowid INTEGER PRIMARY KEY, + + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + sender TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT, + timestamp INTEGER NOT NULL, + + content TEXT NOT NULL, + decrypted TEXT, + decrypted_type TEXT, + unsigned TEXT NOT NULL, + local_content TEXT, + + transaction_id TEXT, + + redacted_by TEXT, + relates_to TEXT, + relation_type TEXT, + + megolm_session_id TEXT, + decryption_error TEXT, + send_error TEXT, + + reactions TEXT, + last_edit_rowid INTEGER, + unread_type INTEGER NOT NULL DEFAULT 0, + + CONSTRAINT event_id_unique_key UNIQUE (event_id), + CONSTRAINT transaction_id_unique_key UNIQUE (transaction_id), + CONSTRAINT event_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE +) STRICT; +CREATE INDEX event_room_id_idx ON event (room_id); +CREATE INDEX event_redacted_by_idx ON event (room_id, redacted_by); +CREATE INDEX event_relates_to_idx ON event (room_id, relates_to); +CREATE INDEX event_megolm_session_id_idx ON event (room_id, megolm_session_id); + +CREATE TRIGGER event_update_redacted_by + AFTER INSERT + ON event + WHEN NEW.type = 'm.room.redaction' +BEGIN + UPDATE event SET redacted_by = NEW.event_id WHERE room_id = NEW.room_id AND event_id = NEW.content ->> 'redacts'; +END; + +CREATE TRIGGER event_update_last_edit_when_redacted + AFTER UPDATE + ON event + WHEN OLD.redacted_by IS NULL + AND NEW.redacted_by IS NOT NULL + AND NEW.relation_type = 'm.replace' + AND NEW.state_key IS NULL +BEGIN + UPDATE event + SET last_edit_rowid = COALESCE( + (SELECT rowid + FROM event edit + WHERE edit.room_id = event.room_id + AND edit.relates_to = event.event_id + AND edit.relation_type = 'm.replace' + AND edit.type = event.type + AND edit.sender = event.sender + AND edit.redacted_by IS NULL + AND edit.state_key IS NULL + ORDER BY edit.timestamp DESC + LIMIT 1), + 0) + WHERE event_id = NEW.relates_to + AND last_edit_rowid = NEW.rowid + AND state_key IS NULL + AND (relation_type IS NULL OR relation_type NOT IN ('m.replace', 'm.annotation')); +END; + +CREATE TRIGGER event_insert_update_last_edit + AFTER INSERT + ON event + WHEN NEW.relation_type = 'm.replace' + AND NEW.redacted_by IS NULL + AND NEW.state_key IS NULL +BEGIN + UPDATE event + SET last_edit_rowid = NEW.rowid + WHERE event_id = NEW.relates_to + AND type = NEW.type + AND sender = NEW.sender + AND state_key IS NULL + AND (relation_type IS NULL OR relation_type NOT IN ('m.replace', 'm.annotation')) + AND NEW.timestamp > + COALESCE((SELECT prev_edit.timestamp FROM event prev_edit WHERE prev_edit.rowid = event.last_edit_rowid), 0); +END; + +CREATE TRIGGER event_insert_fill_reactions + AFTER INSERT + ON event + WHEN NEW.type = 'm.reaction' + AND NEW.relation_type = 'm.annotation' + AND NEW.redacted_by IS NULL + AND typeof(NEW.content ->> '$."m.relates_to".key') = 'text' +BEGIN + UPDATE event + SET reactions=json_set( + reactions, + '$.' || json_quote(NEW.content ->> '$."m.relates_to".key'), + coalesce( + reactions ->> ('$.' || json_quote(NEW.content ->> '$."m.relates_to".key')), + 0 + ) + 1) + WHERE event_id = NEW.relates_to + AND reactions IS NOT NULL; +END; + +CREATE TRIGGER event_redact_fill_reactions + AFTER UPDATE + ON event + WHEN NEW.type = 'm.reaction' + AND NEW.relation_type = 'm.annotation' + AND NEW.redacted_by IS NOT NULL + AND OLD.redacted_by IS NULL + AND typeof(NEW.content ->> '$."m.relates_to".key') = 'text' +BEGIN + UPDATE event + SET reactions=json_set( + reactions, + '$.' || json_quote(NEW.content ->> '$."m.relates_to".key'), + coalesce( + reactions ->> ('$.' || json_quote(NEW.content ->> '$."m.relates_to".key')), + 0 + ) - 1) + WHERE event_id = NEW.relates_to + AND reactions IS NOT NULL; +END; + +CREATE TABLE cached_media ( + mxc TEXT NOT NULL PRIMARY KEY, + event_rowid INTEGER, + enc_file TEXT, + file_name TEXT, + mime_type TEXT, + size INTEGER, + hash BLOB, + error TEXT, + + CONSTRAINT cached_media_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid) ON DELETE SET NULL +) STRICT; + +CREATE TABLE session_request ( + room_id TEXT NOT NULL, + session_id TEXT NOT NULL, + sender TEXT NOT NULL, + min_index INTEGER NOT NULL, + backup_checked INTEGER NOT NULL DEFAULT false, + request_sent INTEGER NOT NULL DEFAULT false, + + PRIMARY KEY (session_id), + CONSTRAINT session_request_queue_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE +) STRICT; +CREATE INDEX session_request_room_idx ON session_request (room_id); + +CREATE TABLE timeline ( + rowid INTEGER PRIMARY KEY, + room_id TEXT NOT NULL, + event_rowid INTEGER NOT NULL, + + CONSTRAINT timeline_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE, + CONSTRAINT timeline_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid) ON DELETE CASCADE, + CONSTRAINT timeline_event_unique_key UNIQUE (event_rowid) +) STRICT; +CREATE INDEX timeline_room_id_idx ON timeline (room_id); + +CREATE TABLE current_state ( + room_id TEXT NOT NULL, + event_type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_rowid INTEGER NOT NULL, + + membership TEXT, + + PRIMARY KEY (room_id, event_type, state_key), + CONSTRAINT current_state_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE, + CONSTRAINT current_state_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid) +) STRICT, WITHOUT ROWID; + +CREATE TABLE receipt ( + room_id TEXT NOT NULL, + user_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + thread_id TEXT NOT NULL, + event_id TEXT NOT NULL, + timestamp INTEGER NOT NULL, + + PRIMARY KEY (room_id, user_id, receipt_type, thread_id), + CONSTRAINT receipt_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE + -- note: there's no foreign key on event ID because receipts could point at events that are too far in history. +) STRICT; diff --git a/pkg/hicli/database/upgrades/02-explicit-avatar-flag.sql b/pkg/hicli/database/upgrades/02-explicit-avatar-flag.sql new file mode 100644 index 0000000..c11e880 --- /dev/null +++ b/pkg/hicli/database/upgrades/02-explicit-avatar-flag.sql @@ -0,0 +1,2 @@ +-- v2 (compatible with v1+): Add explicit avatar flag to rooms +ALTER TABLE room ADD COLUMN explicit_avatar INTEGER NOT NULL DEFAULT 0; diff --git a/pkg/hicli/database/upgrades/03-more-event-fields.sql b/pkg/hicli/database/upgrades/03-more-event-fields.sql new file mode 100644 index 0000000..3e07ad7 --- /dev/null +++ b/pkg/hicli/database/upgrades/03-more-event-fields.sql @@ -0,0 +1,6 @@ +-- v3 (compatible with v1+): Add more fields to events +ALTER TABLE event ADD COLUMN local_content TEXT; +ALTER TABLE event ADD COLUMN unread_type INTEGER NOT NULL DEFAULT 0; +ALTER TABLE room ADD COLUMN unread_highlights INTEGER NOT NULL DEFAULT 0; +ALTER TABLE room ADD COLUMN unread_notifications INTEGER NOT NULL DEFAULT 0; +ALTER TABLE room ADD COLUMN unread_messages INTEGER NOT NULL DEFAULT 0; diff --git a/pkg/hicli/database/upgrades/upgrades.go b/pkg/hicli/database/upgrades/upgrades.go new file mode 100644 index 0000000..9d0bd1a --- /dev/null +++ b/pkg/hicli/database/upgrades/upgrades.go @@ -0,0 +1,22 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package upgrades + +import ( + "embed" + + "go.mau.fi/util/dbutil" +) + +var Table dbutil.UpgradeTable + +//go:embed *.sql +var upgrades embed.FS + +func init() { + Table.RegisterFS(upgrades) +} diff --git a/pkg/hicli/decryptionqueue.go b/pkg/hicli/decryptionqueue.go new file mode 100644 index 0000000..00088b8 --- /dev/null +++ b/pkg/hicli/decryptionqueue.go @@ -0,0 +1,209 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "fmt" + "sync" + + "github.com/rs/zerolog" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" + + "go.mau.fi/gomuks/pkg/hicli/database" +) + +func (h *HiClient) fetchFromKeyBackup(ctx context.Context, roomID id.RoomID, sessionID id.SessionID) (*crypto.InboundGroupSession, error) { + data, err := h.Client.GetKeyBackupForRoomAndSession(ctx, h.KeyBackupVersion, roomID, sessionID) + if err != nil { + return nil, err + } else if data == nil { + return nil, nil + } + decrypted, err := data.SessionData.Decrypt(h.KeyBackupKey) + if err != nil { + return nil, err + } + return h.Crypto.ImportRoomKeyFromBackup(ctx, h.KeyBackupVersion, roomID, sessionID, decrypted) +} + +func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.RoomID, sessionID id.SessionID, firstKnownIndex uint32) { + log := zerolog.Ctx(ctx) + err := h.DB.SessionRequest.Remove(ctx, sessionID, firstKnownIndex) + if err != nil { + log.Warn().Err(err).Msg("Failed to remove session request after receiving megolm session") + } + events, err := h.DB.Event.GetFailedByMegolmSessionID(ctx, roomID, sessionID) + if err != nil { + log.Err(err).Msg("Failed to get events that failed to decrypt to retry decryption") + return + } else if len(events) == 0 { + log.Trace().Msg("No events to retry decryption for") + return + } + decrypted := events[:0] + for _, evt := range events { + if evt.Decrypted != nil { + continue + } + + var mautrixEvt *event.Event + mautrixEvt, evt.Decrypted, evt.DecryptedType, err = h.decryptEvent(ctx, evt.AsRawMautrix()) + if err != nil { + log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session") + } else { + decrypted = append(decrypted, evt) + h.postDecryptProcess(ctx, nil, evt, mautrixEvt) + } + } + if len(decrypted) > 0 { + var newPreview database.EventRowID + err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { + for _, evt := range decrypted { + err = h.DB.Event.UpdateDecrypted(ctx, evt) + if err != nil { + return fmt.Errorf("failed to save decrypted content for %s: %w", evt.ID, err) + } + if evt.CanUseForPreview() { + var previewChanged bool + previewChanged, err = h.DB.Room.UpdatePreviewIfLaterOnTimeline(ctx, evt.RoomID, evt.RowID) + if err != nil { + return fmt.Errorf("failed to update room %s preview to %d: %w", evt.RoomID, evt.RowID, err) + } else if previewChanged { + newPreview = evt.RowID + } + } + } + return nil + }) + if err != nil { + log.Err(err).Msg("Failed to save decrypted events") + } else { + h.EventHandler(&EventsDecrypted{Events: decrypted, PreviewEventRowID: newPreview, RoomID: roomID}) + } + } +} + +func (h *HiClient) WakeupRequestQueue() { + select { + case h.requestQueueWakeup <- struct{}{}: + default: + } +} + +func (h *HiClient) RunRequestQueue(ctx context.Context) { + log := zerolog.Ctx(ctx).With().Str("action", "request queue").Logger() + ctx = log.WithContext(ctx) + log.Info().Msg("Starting key request queue") + defer func() { + log.Info().Msg("Stopping key request queue") + }() + for { + err := h.FetchKeysForOutdatedUsers(ctx) + if err != nil { + log.Err(err).Msg("Failed to fetch outdated device lists for tracked users") + } + madeRequests, err := h.RequestQueuedSessions(ctx) + if err != nil { + log.Err(err).Msg("Failed to handle session request queue") + } else if madeRequests { + continue + } + select { + case <-ctx.Done(): + return + case <-h.requestQueueWakeup: + } + } +} + +func (h *HiClient) requestQueuedSession(ctx context.Context, req *database.SessionRequest, doneFunc func()) { + defer doneFunc() + log := zerolog.Ctx(ctx) + if !req.BackupChecked { + sess, err := h.fetchFromKeyBackup(ctx, req.RoomID, req.SessionID) + if err != nil { + log.Err(err). + Stringer("session_id", req.SessionID). + Msg("Failed to fetch session from key backup") + + // TODO should this have retries instead of just storing it's checked? + req.BackupChecked = true + err = h.DB.SessionRequest.Put(ctx, req) + if err != nil { + log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after trying to check backup") + } + } else if sess == nil || sess.Internal.FirstKnownIndex() > req.MinIndex { + req.BackupChecked = true + err = h.DB.SessionRequest.Put(ctx, req) + if err != nil { + log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after checking backup") + } + } else { + log.Debug().Stringer("session_id", req.SessionID). + Msg("Found session with sufficiently low first known index, removing from queue") + err = h.DB.SessionRequest.Remove(ctx, req.SessionID, sess.Internal.FirstKnownIndex()) + if err != nil { + log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to remove session from request queue") + } + } + } else { + err := h.Crypto.SendRoomKeyRequest(ctx, req.RoomID, "", req.SessionID, "", map[id.UserID][]id.DeviceID{ + h.Account.UserID: {"*"}, + req.Sender: {"*"}, + }) + //var err error + if err != nil { + log.Err(err). + Stringer("session_id", req.SessionID). + Msg("Failed to send key request") + } else { + log.Debug().Stringer("session_id", req.SessionID).Msg("Sent key request") + req.RequestSent = true + err = h.DB.SessionRequest.Put(ctx, req) + if err != nil { + log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after sending request") + } + } + } +} + +const MaxParallelRequests = 5 + +func (h *HiClient) RequestQueuedSessions(ctx context.Context) (bool, error) { + sessions, err := h.DB.SessionRequest.Next(ctx, MaxParallelRequests) + if err != nil { + return false, fmt.Errorf("failed to get next events to decrypt: %w", err) + } else if len(sessions) == 0 { + return false, nil + } + var wg sync.WaitGroup + wg.Add(len(sessions)) + for _, req := range sessions { + go h.requestQueuedSession(ctx, req, wg.Done) + } + wg.Wait() + + return true, err +} + +func (h *HiClient) FetchKeysForOutdatedUsers(ctx context.Context) error { + outdatedUsers, err := h.Crypto.CryptoStore.GetOutdatedTrackedUsers(ctx) + if err != nil { + return err + } else if len(outdatedUsers) == 0 { + return nil + } + _, err = h.Crypto.FetchKeys(ctx, outdatedUsers, false) + if err != nil { + return err + } + // TODO backoff for users that fail to be fetched? + return nil +} diff --git a/pkg/hicli/events.go b/pkg/hicli/events.go new file mode 100644 index 0000000..df46afd --- /dev/null +++ b/pkg/hicli/events.go @@ -0,0 +1,60 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" + + "go.mau.fi/gomuks/pkg/hicli/database" +) + +type SyncRoom struct { + Meta *database.Room `json:"meta"` + Timeline []database.TimelineRowTuple `json:"timeline"` + State map[event.Type]map[string]database.EventRowID `json:"state"` + Events []*database.Event `json:"events"` + Reset bool `json:"reset"` + Notifications []SyncNotification `json:"notifications"` +} + +type SyncNotification struct { + RowID database.EventRowID `json:"event_rowid"` + Sound bool `json:"sound"` +} + +type SyncComplete struct { + Rooms map[id.RoomID]*SyncRoom `json:"rooms"` +} + +func (c *SyncComplete) IsEmpty() bool { + return len(c.Rooms) == 0 +} + +type EventsDecrypted struct { + RoomID id.RoomID `json:"room_id"` + PreviewEventRowID database.EventRowID `json:"preview_event_rowid,omitempty"` + Events []*database.Event `json:"events"` +} + +type Typing struct { + RoomID id.RoomID `json:"room_id"` + event.TypingEventContent +} + +type SendComplete struct { + Event *database.Event `json:"event"` + Error error `json:"error"` +} + +type ClientState struct { + IsLoggedIn bool `json:"is_logged_in"` + IsVerified bool `json:"is_verified"` + UserID id.UserID `json:"user_id,omitempty"` + DeviceID id.DeviceID `json:"device_id,omitempty"` + HomeserverURL string `json:"homeserver_url,omitempty"` +} diff --git a/pkg/hicli/hicli.go b/pkg/hicli/hicli.go new file mode 100644 index 0000000..b5e2468 --- /dev/null +++ b/pkg/hicli/hicli.go @@ -0,0 +1,251 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Package hicli contains a highly opinionated high-level framework for developing instant messaging clients on Matrix. +package hicli + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "net/url" + "sync" + "sync/atomic" + "time" + + "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" + "go.mau.fi/util/exerrors" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/backup" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/pushrules" + + "go.mau.fi/gomuks/pkg/hicli/database" +) + +type HiClient struct { + DB *database.Database + Account *database.Account + Client *mautrix.Client + Crypto *crypto.OlmMachine + CryptoStore *crypto.SQLCryptoStore + ClientStore *database.ClientStateStore + Log zerolog.Logger + + Verified bool + + KeyBackupVersion id.KeyBackupVersion + KeyBackupKey *backup.MegolmBackupKey + + PushRules atomic.Pointer[pushrules.PushRuleset] + + EventHandler func(evt any) + + firstSyncReceived bool + syncingID int + syncLock sync.Mutex + stopSync atomic.Pointer[context.CancelFunc] + encryptLock sync.Mutex + + requestQueueWakeup chan struct{} + + jsonRequestsLock sync.Mutex + jsonRequests map[int64]context.CancelCauseFunc + + paginationInterrupterLock sync.Mutex + paginationInterrupter map[id.RoomID]context.CancelCauseFunc +} + +var ErrTimelineReset = errors.New("got limited timeline sync response") + +func New(rawDB, cryptoDB *dbutil.Database, log zerolog.Logger, pickleKey []byte, evtHandler func(any)) *HiClient { + if cryptoDB == nil { + cryptoDB = rawDB + } + if rawDB.Owner == "" { + rawDB.Owner = "hicli" + rawDB.IgnoreForeignTables = true + } + if rawDB.Log == nil { + rawDB.Log = dbutil.ZeroLogger(log.With().Str("db_section", "hicli").Logger()) + } + db := database.New(rawDB) + c := &HiClient{ + DB: db, + Log: log, + + requestQueueWakeup: make(chan struct{}, 1), + jsonRequests: make(map[int64]context.CancelCauseFunc), + paginationInterrupter: make(map[id.RoomID]context.CancelCauseFunc), + + EventHandler: evtHandler, + } + c.ClientStore = &database.ClientStateStore{Database: db} + c.Client = &mautrix.Client{ + UserAgent: mautrix.DefaultUserAgent, + Client: &http.Client{ + Transport: &http.Transport{ + DialContext: (&net.Dialer{Timeout: 10 * time.Second}).DialContext, + TLSHandshakeTimeout: 10 * time.Second, + // This needs to be relatively high to allow initial syncs + ResponseHeaderTimeout: 180 * time.Second, + ForceAttemptHTTP2: true, + }, + Timeout: 180 * time.Second, + }, + Syncer: (*hiSyncer)(c), + Store: (*hiStore)(c), + StateStore: c.ClientStore, + Log: log.With().Str("component", "mautrix client").Logger(), + } + c.CryptoStore = crypto.NewSQLCryptoStore(cryptoDB, dbutil.ZeroLogger(log.With().Str("db_section", "crypto").Logger()), "", "", pickleKey) + cryptoLog := log.With().Str("component", "crypto").Logger() + c.Crypto = crypto.NewOlmMachine(c.Client, &cryptoLog, c.CryptoStore, c.ClientStore) + c.Crypto.SessionReceived = c.handleReceivedMegolmSession + c.Crypto.DisableRatchetTracking = true + c.Crypto.DisableDecryptKeyFetching = true + c.Client.Crypto = (*hiCryptoHelper)(c) + return c +} + +func (h *HiClient) IsLoggedIn() bool { + return h.Account != nil +} + +func (h *HiClient) Start(ctx context.Context, userID id.UserID, expectedAccount *database.Account) error { + if expectedAccount != nil && userID != expectedAccount.UserID { + panic(fmt.Errorf("invalid parameters: different user ID in expected account and user ID")) + } + err := h.DB.Upgrade(ctx) + if err != nil { + return fmt.Errorf("failed to upgrade hicli db: %w", err) + } + err = h.CryptoStore.DB.Upgrade(ctx) + if err != nil { + return fmt.Errorf("failed to upgrade crypto db: %w", err) + } + account, err := h.DB.Account.Get(ctx, userID) + if err != nil { + return err + } else if account == nil && expectedAccount != nil { + err = h.DB.Account.Put(ctx, expectedAccount) + if err != nil { + return err + } + account = expectedAccount + } else if expectedAccount != nil && expectedAccount.DeviceID != account.DeviceID { + return fmt.Errorf("device ID mismatch: expected %s, got %s", expectedAccount.DeviceID, account.DeviceID) + } + if account != nil { + zerolog.Ctx(ctx).Debug().Stringer("user_id", account.UserID).Msg("Preparing client with existing credentials") + h.Account = account + h.CryptoStore.AccountID = account.UserID.String() + h.CryptoStore.DeviceID = account.DeviceID + h.Client.UserID = account.UserID + h.Client.DeviceID = account.DeviceID + h.Client.AccessToken = account.AccessToken + h.Client.HomeserverURL, err = url.Parse(account.HomeserverURL) + if err != nil { + return err + } + err = h.CheckServerVersions(ctx) + if err != nil { + return err + } + err = h.Crypto.Load(ctx) + if err != nil { + return fmt.Errorf("failed to load olm machine: %w", err) + } + + h.Verified, err = h.checkIsCurrentDeviceVerified(ctx) + if err != nil { + return err + } + zerolog.Ctx(ctx).Debug().Bool("verified", h.Verified).Msg("Checked current device verification status") + if h.Verified { + err = h.loadPrivateKeys(ctx) + if err != nil { + return err + } + go h.Sync() + } + } + return nil +} + +var ErrFailedToCheckServerVersions = errors.New("failed to check server versions") +var ErrOutdatedServer = errors.New("homeserver is outdated") +var MinimumSpecVersion = mautrix.SpecV11 + +func (h *HiClient) CheckServerVersions(ctx context.Context) error { + versions, err := h.Client.Versions(ctx) + if err != nil { + return exerrors.NewDualError(ErrFailedToCheckServerVersions, err) + } else if !versions.Contains(MinimumSpecVersion) { + return fmt.Errorf("%w (minimum: %s, highest supported: %s)", ErrOutdatedServer, MinimumSpecVersion, versions.GetLatest()) + } + return nil +} + +func (h *HiClient) IsSyncing() bool { + return h.stopSync.Load() != nil +} + +func (h *HiClient) Sync() { + h.Client.StopSync() + if fn := h.stopSync.Load(); fn != nil { + (*fn)() + } + h.syncLock.Lock() + defer h.syncLock.Unlock() + h.syncingID++ + syncingID := h.syncingID + log := h.Log.With(). + Str("action", "sync"). + Int("sync_id", syncingID). + Logger() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + h.stopSync.Store(&cancel) + go h.RunRequestQueue(h.Log.WithContext(ctx)) + go h.LoadPushRules(h.Log.WithContext(ctx)) + ctx = log.WithContext(ctx) + log.Info().Msg("Starting syncing") + err := h.Client.SyncWithContext(ctx) + if err != nil && ctx.Err() == nil { + log.Err(err).Msg("Fatal error in syncer") + } else { + log.Info().Msg("Syncing stopped") + } +} + +func (h *HiClient) LoadPushRules(ctx context.Context) { + rules, err := h.Client.GetPushRules(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to load push rules") + return + } + h.PushRules.Store(rules) + zerolog.Ctx(ctx).Debug().Msg("Updated push rules from fetch") +} + +func (h *HiClient) Stop() { + h.Client.StopSync() + if fn := h.stopSync.Swap(nil); fn != nil { + (*fn)() + } + h.syncLock.Lock() + //lint:ignore SA2001 just acquire the lock to make sure Sync is done + h.syncLock.Unlock() + err := h.DB.Close() + if err != nil { + h.Log.Err(err).Msg("Failed to close database cleanly") + } +} diff --git a/pkg/hicli/hitest/hitest.go b/pkg/hicli/hitest/hitest.go new file mode 100644 index 0000000..820c44e --- /dev/null +++ b/pkg/hicli/hitest/hitest.go @@ -0,0 +1,110 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package main + +import ( + "context" + "fmt" + "io" + "strings" + + "github.com/chzyer/readline" + _ "github.com/mattn/go-sqlite3" + "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" + _ "go.mau.fi/util/dbutil/litestream" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/exzerolog" + "go.mau.fi/zeroconfig" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" + + "go.mau.fi/gomuks/pkg/hicli" +) + +var writerTypeReadline zeroconfig.WriterType = "hitest_readline" + +func main() { + hicli.InitialDeviceDisplayName = "mautrix hitest" + rl := exerrors.Must(readline.New("> ")) + defer func() { + _ = rl.Close() + }() + zeroconfig.RegisterWriter(writerTypeReadline, func(config *zeroconfig.WriterConfig) (io.Writer, error) { + return rl.Stdout(), nil + }) + debug := zerolog.DebugLevel + log := exerrors.Must((&zeroconfig.Config{ + MinLevel: &debug, + Writers: []zeroconfig.WriterConfig{{ + Type: writerTypeReadline, + Format: zeroconfig.LogFormatPrettyColored, + }}, + }).Compile()) + exzerolog.SetupDefaults(log) + + rawDB := exerrors.Must(dbutil.NewWithDialect("hicli.db", "sqlite3-fk-wal")) + ctx := log.WithContext(context.Background()) + cli := hicli.New(rawDB, nil, *log, []byte("meow"), func(a any) { + _, _ = fmt.Fprintf(rl, "Received event of type %T\n", a) + switch evt := a.(type) { + case *hicli.SyncComplete: + for _, room := range evt.Rooms { + name := "name unset" + if room.Meta.Name != nil { + name = *room.Meta.Name + } + _, _ = fmt.Fprintf(rl, "Room %s (%s) in sync:\n", name, room.Meta.ID) + _, _ = fmt.Fprintf(rl, " Preview: %d, sort: %v\n", room.Meta.PreviewEventRowID, room.Meta.SortingTimestamp) + _, _ = fmt.Fprintf(rl, " Timeline: +%d %v, reset: %t\n", len(room.Timeline), room.Timeline, room.Reset) + } + case *hicli.EventsDecrypted: + for _, decrypted := range evt.Events { + _, _ = fmt.Fprintf(rl, "Delayed decryption of %s completed: %s / %s\n", decrypted.ID, decrypted.DecryptedType, decrypted.Decrypted) + } + if evt.PreviewEventRowID != 0 { + _, _ = fmt.Fprintf(rl, "Room preview updated: %+v\n", evt.PreviewEventRowID) + } + case *hicli.Typing: + _, _ = fmt.Fprintf(rl, "Typing list in %s: %+v\n", evt.RoomID, evt.UserIDs) + } + }) + userID, _ := cli.DB.Account.GetFirstUserID(ctx) + exerrors.PanicIfNotNil(cli.Start(ctx, userID, nil)) + if !cli.IsLoggedIn() { + rl.SetPrompt("User ID: ") + userID := id.UserID(exerrors.Must(rl.Readline())) + _, serverName := exerrors.Must2(userID.Parse()) + discovery := exerrors.Must(mautrix.DiscoverClientAPI(ctx, serverName)) + password := exerrors.Must(rl.ReadPassword("Password: ")) + recoveryCode := exerrors.Must(rl.ReadPassword("Recovery code: ")) + exerrors.PanicIfNotNil(cli.LoginAndVerify(ctx, discovery.Homeserver.BaseURL, userID.String(), string(password), string(recoveryCode))) + } + rl.SetPrompt("> ") + + for { + line, err := rl.Readline() + if err != nil { + break + } + fields := strings.Fields(line) + if len(fields) == 0 { + continue + } + switch strings.ToLower(fields[0]) { + case "send": + resp, err := cli.Send(ctx, id.RoomID(fields[1]), event.EventMessage, &event.MessageEventContent{ + Body: strings.Join(fields[2:], " "), + MsgType: event.MsgText, + }) + _, _ = fmt.Fprintln(rl, err) + _, _ = fmt.Fprintf(rl, "%+v\n", resp) + } + } + cli.Stop() +} diff --git a/pkg/hicli/html.go b/pkg/hicli/html.go new file mode 100644 index 0000000..b0b7df1 --- /dev/null +++ b/pkg/hicli/html.go @@ -0,0 +1,493 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "bytes" + "errors" + "fmt" + "io" + "net/url" + "regexp" + "slices" + "strconv" + "strings" + + "golang.org/x/net/html" + "golang.org/x/net/html/atom" + "maunium.net/go/mautrix/id" + "mvdan.cc/xurls/v2" +) + +func tagIsAllowed(tag atom.Atom) bool { + switch tag { + case atom.Del, atom.H1, atom.H2, atom.H3, atom.H4, atom.H5, atom.H6, atom.Blockquote, atom.P, + atom.A, atom.Ul, atom.Ol, atom.Sup, atom.Sub, atom.Li, atom.B, atom.I, atom.U, atom.Strong, + atom.Em, atom.S, atom.Code, atom.Hr, atom.Br, atom.Div, atom.Table, atom.Thead, atom.Tbody, + atom.Tr, atom.Th, atom.Td, atom.Caption, atom.Pre, atom.Span, atom.Font, atom.Img, + atom.Details, atom.Summary: + return true + default: + return false + } +} + +func isSelfClosing(tag atom.Atom) bool { + switch tag { + case atom.Img, atom.Br, atom.Hr: + return true + default: + return false + } +} + +var languageRegex = regexp.MustCompile(`^language-[a-zA-Z0-9-]+$`) +var allowedColorRegex = regexp.MustCompile(`^#[0-9a-fA-F]{6}$`) + +// This is approximately a mirror of web/src/util/mediasize.ts in gomuks +func calculateMediaSize(widthInt, heightInt int) (width, height float64, ok bool) { + if widthInt <= 0 || heightInt <= 0 { + return + } + width = float64(widthInt) + height = float64(heightInt) + const imageContainerWidth float64 = 320 + const imageContainerHeight float64 = 240 + const imageContainerAspectRatio = imageContainerWidth / imageContainerHeight + if width > imageContainerWidth || height > imageContainerHeight { + aspectRatio := width / height + if aspectRatio > imageContainerAspectRatio { + width = imageContainerWidth + height = imageContainerWidth / aspectRatio + } else if aspectRatio < imageContainerAspectRatio { + width = imageContainerHeight * aspectRatio + height = imageContainerHeight + } else { + width = imageContainerWidth + height = imageContainerHeight + } + } + ok = true + return +} + +func parseImgAttributes(attrs []html.Attribute) (src, alt, title string, isCustomEmoji bool, width, height int) { + for _, attr := range attrs { + switch attr.Key { + case "src": + src = attr.Val + case "alt": + alt = attr.Val + case "title": + title = attr.Val + case "data-mx-emoticon": + isCustomEmoji = true + case "width": + width, _ = strconv.Atoi(attr.Val) + case "height": + height, _ = strconv.Atoi(attr.Val) + } + } + return +} + +func parseSpanAttributes(attrs []html.Attribute) (bgColor, textColor, spoiler, maths string, isSpoiler bool) { + for _, attr := range attrs { + switch attr.Key { + case "data-mx-bg-color": + if allowedColorRegex.MatchString(attr.Val) { + bgColor = attr.Val + } + case "data-mx-color", "color": + if allowedColorRegex.MatchString(attr.Val) { + textColor = attr.Val + } + case "data-mx-spoiler": + spoiler = attr.Val + isSpoiler = true + case "data-mx-maths": + maths = attr.Val + } + } + return +} + +func parseAAttributes(attrs []html.Attribute) (href string) { + for _, attr := range attrs { + switch attr.Key { + case "href": + href = strings.TrimSpace(attr.Val) + } + } + return +} + +func attributeIsAllowed(tag atom.Atom, attr html.Attribute) bool { + switch tag { + case atom.Ol: + switch attr.Key { + case "start": + _, err := strconv.Atoi(attr.Val) + return err == nil + } + case atom.Code: + switch attr.Key { + case "class": + return languageRegex.MatchString(attr.Val) + } + case atom.Div: + switch attr.Key { + case "data-mx-maths": + return true + } + } + return false +} + +// Funny user IDs will just need to be linkified by the sender, no auto-linkification for them. +var plainUserOrAliasMentionRegex = regexp.MustCompile(`[@#][a-zA-Z0-9._=/+-]{0,254}:[a-zA-Z0-9.-]+(?:\d{1,5})?`) + +func getNextItem(items [][]int, minIndex int) (index, start, end int, ok bool) { + for i, item := range items { + if item[0] >= minIndex { + return i, item[0], item[1], true + } + } + return -1, -1, -1, false +} + +func writeMention(w *strings.Builder, mention []byte) { + uri := &id.MatrixURI{ + Sigil1: rune(mention[0]), + MXID1: string(mention[1:]), + } + w.WriteString(`') + writeEscapedBytes(w, mention) + w.WriteString("") +} + +func writeURL(w *strings.Builder, addr []byte) { + parsedURL, err := url.Parse(string(addr)) + if err != nil { + writeEscapedBytes(w, addr) + return + } + if parsedURL.Scheme == "" { + parsedURL.Scheme = "https" + } + w.WriteString(`') + writeEscapedBytes(w, addr) + w.WriteString("") +} + +func linkifyAndWriteBytes(w *strings.Builder, s []byte) { + mentions := plainUserOrAliasMentionRegex.FindAllIndex(s, -1) + urls := xurls.Relaxed().FindAllIndex(s, -1) + minIndex := 0 + for { + mentionIdx, nextMentionStart, nextMentionEnd, hasMention := getNextItem(mentions, minIndex) + urlIdx, nextURLStart, nextURLEnd, hasURL := getNextItem(urls, minIndex) + if hasMention && (!hasURL || nextMentionStart <= nextURLStart) { + writeEscapedBytes(w, s[minIndex:nextMentionStart]) + writeMention(w, s[nextMentionStart:nextMentionEnd]) + minIndex = nextMentionEnd + mentions = mentions[mentionIdx:] + } else if hasURL && (!hasMention || nextURLStart < nextMentionStart) { + writeEscapedBytes(w, s[minIndex:nextURLStart]) + writeURL(w, s[nextURLStart:nextURLEnd]) + minIndex = nextURLEnd + urls = urls[urlIdx:] + } else { + break + } + } + writeEscapedBytes(w, s[minIndex:]) +} + +const escapedChars = "&'<>\"\r" + +func writeEscapedBytes(w *strings.Builder, s []byte) { + i := bytes.IndexAny(s, escapedChars) + for i != -1 { + w.Write(s[:i]) + var esc string + switch s[i] { + case '&': + esc = "&" + case '\'': + // "'" is shorter than "'" and apos was not in HTML until HTML5. + esc = "'" + case '<': + esc = "<" + case '>': + esc = ">" + case '"': + // """ is shorter than """. + esc = """ + case '\r': + esc = " " + default: + panic("unrecognized escape character") + } + s = s[i+1:] + w.WriteString(esc) + i = bytes.IndexAny(s, escapedChars) + } + w.Write(s) +} + +func writeEscapedString(w *strings.Builder, s string) { + i := strings.IndexAny(s, escapedChars) + for i != -1 { + w.WriteString(s[:i]) + var esc string + switch s[i] { + case '&': + esc = "&" + case '\'': + // "'" is shorter than "'" and apos was not in HTML until HTML5. + esc = "'" + case '<': + esc = "<" + case '>': + esc = ">" + case '"': + // """ is shorter than """. + esc = """ + case '\r': + esc = " " + default: + panic("unrecognized escape character") + } + s = s[i+1:] + w.WriteString(esc) + i = strings.IndexAny(s, escapedChars) + } + w.WriteString(s) +} + +func writeAttribute(w *strings.Builder, key, value string) { + w.WriteByte(' ') + w.WriteString(key) + w.WriteString(`="`) + writeEscapedString(w, value) + w.WriteByte('"') +} + +func matrixURIClassName(uri *id.MatrixURI) string { + switch uri.Sigil1 { + case '@': + return "hicli-matrix-uri hicli-matrix-uri-user" + case '#': + return "hicli-matrix-uri hicli-matrix-uri-room-alias" + case '!': + if uri.Sigil2 == '$' { + return "hicli-matrix-uri hicli-matrix-uri-event-id" + } + return "hicli-matrix-uri hicli-matrix-uri-room-id" + default: + return "hicli-matrix-uri hicli-matrix-uri-unknown" + } +} + +func writeA(w *strings.Builder, attr []html.Attribute) { + w.WriteString("`) + w.WriteString(spoiler) + w.WriteString(" ") + } + w.WriteByte('<') + w.WriteString("span") + if isSpoiler { + writeAttribute(w, "class", "hicli-spoiler") + } + var style string + if bgColor != "" { + style += fmt.Sprintf("background-color: %s;", bgColor) + } + if textColor != "" { + style += fmt.Sprintf("color: %s;", textColor) + } + if style != "" { + writeAttribute(w, "style", style) + } +} + +type tagStack []atom.Atom + +func (ts *tagStack) contains(tags ...atom.Atom) bool { + for i := len(*ts) - 1; i >= 0; i-- { + for _, tag := range tags { + if (*ts)[i] == tag { + return true + } + } + } + return false +} + +func (ts *tagStack) push(tag atom.Atom) { + *ts = append(*ts, tag) +} + +func (ts *tagStack) pop(tag atom.Atom) bool { + if len(*ts) > 0 && (*ts)[len(*ts)-1] == tag { + *ts = (*ts)[:len(*ts)-1] + return true + } + return false +} + +func sanitizeAndLinkifyHTML(body string) (string, error) { + tz := html.NewTokenizer(strings.NewReader(body)) + var built strings.Builder + ts := make(tagStack, 2) +Loop: + for { + switch tz.Next() { + case html.ErrorToken: + err := tz.Err() + if errors.Is(err, io.EOF) { + break Loop + } + return "", err + case html.StartTagToken, html.SelfClosingTagToken: + token := tz.Token() + if !tagIsAllowed(token.DataAtom) { + continue + } + tagIsSelfClosing := isSelfClosing(token.DataAtom) + if token.Type == html.SelfClosingTagToken && !tagIsSelfClosing { + continue + } + switch token.DataAtom { + case atom.A: + writeA(&built, token.Attr) + case atom.Img: + writeImg(&built, token.Attr) + case atom.Span, atom.Font: + writeSpan(&built, token.Attr) + default: + built.WriteByte('<') + built.WriteString(token.Data) + for _, attr := range token.Attr { + if attributeIsAllowed(token.DataAtom, attr) { + writeAttribute(&built, attr.Key, attr.Val) + } + } + } + built.WriteByte('>') + if !tagIsSelfClosing { + ts.push(token.DataAtom) + } + case html.EndTagToken: + tagName, _ := tz.TagName() + tag := atom.Lookup(tagName) + if tagIsAllowed(tag) && ts.pop(tag) { + built.WriteString("') + } + case html.TextToken: + if ts.contains(atom.Pre, atom.Code, atom.A) { + writeEscapedBytes(&built, tz.Text()) + } else { + linkifyAndWriteBytes(&built, tz.Text()) + } + case html.DoctypeToken, html.CommentToken: + // ignore + } + } + slices.Reverse(ts) + for _, t := range ts { + built.WriteString("') + } + return built.String(), nil +} diff --git a/pkg/hicli/json-commands.go b/pkg/hicli/json-commands.go new file mode 100644 index 0000000..378b7b8 --- /dev/null +++ b/pkg/hicli/json-commands.go @@ -0,0 +1,179 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "time" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" + + "go.mau.fi/gomuks/pkg/hicli/database" +) + +func (h *HiClient) handleJSONCommand(ctx context.Context, req *JSONCommand) (any, error) { + switch req.Command { + case "get_state": + return h.State(), nil + case "cancel": + return unmarshalAndCall(req.Data, func(params *cancelRequestParams) (bool, error) { + h.jsonRequestsLock.Lock() + cancelTarget, ok := h.jsonRequests[params.RequestID] + h.jsonRequestsLock.Unlock() + if ok { + return false, nil + } + if params.Reason == "" { + cancelTarget(nil) + } else { + cancelTarget(errors.New(params.Reason)) + } + return true, nil + }) + case "send_message": + return unmarshalAndCall(req.Data, func(params *sendMessageParams) (*database.Event, error) { + return h.SendMessage(ctx, params.RoomID, params.Text, params.MediaPath, params.ReplyTo, params.Mentions) + }) + case "send_event": + return unmarshalAndCall(req.Data, func(params *sendEventParams) (*database.Event, error) { + return h.Send(ctx, params.RoomID, params.EventType, params.Content) + }) + case "mark_read": + return unmarshalAndCall(req.Data, func(params *markReadParams) (bool, error) { + return true, h.MarkRead(ctx, params.RoomID, params.EventID, params.ReceiptType) + }) + case "set_typing": + return unmarshalAndCall(req.Data, func(params *setTypingParams) (bool, error) { + return true, h.SetTyping(ctx, params.RoomID, time.Duration(params.Timeout)*time.Millisecond) + }) + case "get_event": + return unmarshalAndCall(req.Data, func(params *getEventParams) (*database.Event, error) { + return h.GetEvent(ctx, params.RoomID, params.EventID) + }) + case "get_events_by_rowids": + return unmarshalAndCall(req.Data, func(params *getEventsByRowIDsParams) ([]*database.Event, error) { + return h.GetEventsByRowIDs(ctx, params.RowIDs) + }) + case "get_room_state": + return unmarshalAndCall(req.Data, func(params *getRoomStateParams) ([]*database.Event, error) { + return h.GetRoomState(ctx, params.RoomID, params.FetchMembers, params.Refetch) + }) + case "paginate": + return unmarshalAndCall(req.Data, func(params *paginateParams) (*PaginationResponse, error) { + return h.Paginate(ctx, params.RoomID, params.MaxTimelineID, params.Limit) + }) + case "paginate_server": + return unmarshalAndCall(req.Data, func(params *paginateParams) (*PaginationResponse, error) { + return h.PaginateServer(ctx, params.RoomID, params.Limit) + }) + case "ensure_group_session_shared": + return unmarshalAndCall(req.Data, func(params *ensureGroupSessionSharedParams) (bool, error) { + return true, h.EnsureGroupSessionShared(ctx, params.RoomID) + }) + case "login": + return unmarshalAndCall(req.Data, func(params *loginParams) (bool, error) { + return true, h.LoginPassword(ctx, params.HomeserverURL, params.Username, params.Password) + }) + case "verify": + return unmarshalAndCall(req.Data, func(params *verifyParams) (bool, error) { + return true, h.VerifyWithRecoveryKey(ctx, params.RecoveryKey) + }) + case "discover_homeserver": + return unmarshalAndCall(req.Data, func(params *discoverHomeserverParams) (*mautrix.ClientWellKnown, error) { + _, homeserver, err := params.UserID.Parse() + if err != nil { + return nil, err + } + return mautrix.DiscoverClientAPI(ctx, homeserver) + }) + default: + return nil, fmt.Errorf("unknown command %q", req.Command) + } +} + +func unmarshalAndCall[T, O any](data json.RawMessage, fn func(*T) (O, error)) (output O, err error) { + var input T + err = json.Unmarshal(data, &input) + if err != nil { + return + } + return fn(&input) +} + +type cancelRequestParams struct { + RequestID int64 `json:"request_id"` + Reason string `json:"reason"` +} + +type sendMessageParams struct { + RoomID id.RoomID `json:"room_id"` + Text string `json:"text"` + MediaPath string `json:"media_path"` + ReplyTo id.EventID `json:"reply_to"` + Mentions *event.Mentions `json:"mentions"` +} + +type sendEventParams struct { + RoomID id.RoomID `json:"room_id"` + EventType event.Type `json:"type"` + Content json.RawMessage `json:"content"` +} + +type markReadParams struct { + RoomID id.RoomID `json:"room_id"` + EventID id.EventID `json:"event_id"` + ReceiptType event.ReceiptType `json:"receipt_type"` +} + +type setTypingParams struct { + RoomID id.RoomID `json:"room_id"` + Timeout int `json:"timeout"` +} + +type getEventParams struct { + RoomID id.RoomID `json:"room_id"` + EventID id.EventID `json:"event_id"` +} + +type getEventsByRowIDsParams struct { + RowIDs []database.EventRowID `json:"row_ids"` +} + +type getRoomStateParams struct { + RoomID id.RoomID `json:"room_id"` + Refetch bool `json:"refetch"` + FetchMembers bool `json:"fetch_members"` +} + +type ensureGroupSessionSharedParams struct { + RoomID id.RoomID `json:"room_id"` +} + +type loginParams struct { + HomeserverURL string `json:"homeserver_url"` + Username string `json:"username"` + Password string `json:"password"` +} + +type verifyParams struct { + RecoveryKey string `json:"recovery_key"` +} + +type discoverHomeserverParams struct { + UserID id.UserID `json:"user_id"` +} + +type paginateParams struct { + RoomID id.RoomID `json:"room_id"` + MaxTimelineID database.TimelineRowID `json:"max_timeline_id"` + Limit int `json:"limit"` +} diff --git a/pkg/hicli/json.go b/pkg/hicli/json.go new file mode 100644 index 0000000..a27fd00 --- /dev/null +++ b/pkg/hicli/json.go @@ -0,0 +1,119 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sync/atomic" + + "go.mau.fi/util/exerrors" +) + +type JSONCommand struct { + Command string `json:"command"` + RequestID int64 `json:"request_id"` + Data json.RawMessage `json:"data"` +} + +type JSONEventHandler func(*JSONCommand) + +var outgoingEventCounter atomic.Int64 + +func (jeh JSONEventHandler) HandleEvent(evt any) { + var command string + switch evt.(type) { + case *SyncComplete: + command = "sync_complete" + case *EventsDecrypted: + command = "events_decrypted" + case *Typing: + command = "typing" + case *SendComplete: + command = "send_complete" + case *ClientState: + command = "client_state" + default: + panic(fmt.Errorf("unknown event type %T", evt)) + } + data, err := json.Marshal(evt) + if err != nil { + panic(fmt.Errorf("failed to marshal event %T: %w", evt, err)) + } + jeh(&JSONCommand{ + Command: command, + RequestID: -outgoingEventCounter.Add(1), + Data: data, + }) +} + +func (h *HiClient) State() *ClientState { + state := &ClientState{} + if acc := h.Account; acc != nil { + state.IsLoggedIn = true + state.UserID = acc.UserID + state.DeviceID = acc.DeviceID + state.HomeserverURL = acc.HomeserverURL + state.IsVerified = h.Verified + } + return state +} + +func (h *HiClient) dispatchCurrentState() { + h.EventHandler(h.State()) +} + +func (h *HiClient) SubmitJSONCommand(ctx context.Context, req *JSONCommand) *JSONCommand { + if req.Command == "ping" { + return &JSONCommand{ + Command: "pong", + RequestID: req.RequestID, + } + } + log := h.Log.With().Int64("request_id", req.RequestID).Str("command", req.Command).Logger() + ctx, cancel := context.WithCancelCause(ctx) + defer func() { + cancel(nil) + h.jsonRequestsLock.Lock() + delete(h.jsonRequests, req.RequestID) + h.jsonRequestsLock.Unlock() + }() + ctx = log.WithContext(ctx) + h.jsonRequestsLock.Lock() + h.jsonRequests[req.RequestID] = cancel + h.jsonRequestsLock.Unlock() + resp, err := h.handleJSONCommand(ctx, req) + if err != nil { + if errors.Is(err, context.Canceled) { + causeErr := context.Cause(ctx) + if causeErr != ctx.Err() { + err = fmt.Errorf("%w: %w", err, causeErr) + } + } + return &JSONCommand{ + Command: "error", + RequestID: req.RequestID, + Data: exerrors.Must(json.Marshal(err.Error())), + } + } + var respData json.RawMessage + respData, err = json.Marshal(resp) + if err != nil { + return &JSONCommand{ + Command: "error", + RequestID: req.RequestID, + Data: exerrors.Must(json.Marshal(fmt.Sprintf("failed to marshal response json: %v", err))), + } + } + return &JSONCommand{ + Command: "response", + RequestID: req.RequestID, + Data: respData, + } +} diff --git a/pkg/hicli/login.go b/pkg/hicli/login.go new file mode 100644 index 0000000..06ad693 --- /dev/null +++ b/pkg/hicli/login.go @@ -0,0 +1,88 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "fmt" + "net/url" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/id" + + "go.mau.fi/gomuks/pkg/hicli/database" +) + +var InitialDeviceDisplayName = "mautrix hiclient" + +func (h *HiClient) LoginPassword(ctx context.Context, homeserverURL, username, password string) error { + var err error + h.Client.HomeserverURL, err = url.Parse(homeserverURL) + if err != nil { + return err + } + return h.Login(ctx, &mautrix.ReqLogin{ + Type: mautrix.AuthTypePassword, + Identifier: mautrix.UserIdentifier{ + Type: mautrix.IdentifierTypeUser, + User: username, + }, + Password: password, + InitialDeviceDisplayName: InitialDeviceDisplayName, + }) +} + +func (h *HiClient) Login(ctx context.Context, req *mautrix.ReqLogin) error { + err := h.CheckServerVersions(ctx) + if err != nil { + return err + } + req.StoreCredentials = true + req.StoreHomeserverURL = true + resp, err := h.Client.Login(ctx, req) + if err != nil { + return err + } + defer h.dispatchCurrentState() + h.Account = &database.Account{ + UserID: resp.UserID, + DeviceID: resp.DeviceID, + AccessToken: resp.AccessToken, + HomeserverURL: h.Client.HomeserverURL.String(), + } + h.CryptoStore.AccountID = resp.UserID.String() + h.CryptoStore.DeviceID = resp.DeviceID + err = h.DB.Account.Put(ctx, h.Account) + if err != nil { + return err + } + err = h.Crypto.Load(ctx) + if err != nil { + return fmt.Errorf("failed to load olm machine: %w", err) + } + err = h.Crypto.ShareKeys(ctx, 0) + if err != nil { + return err + } + _, err = h.Crypto.FetchKeys(ctx, []id.UserID{h.Account.UserID}, true) + if err != nil { + return fmt.Errorf("failed to fetch own devices: %w", err) + } + return nil +} + +func (h *HiClient) LoginAndVerify(ctx context.Context, homeserverURL, username, password, recoveryKey string) error { + err := h.LoginPassword(ctx, homeserverURL, username, password) + if err != nil { + return err + } + err = h.VerifyWithRecoveryKey(ctx, recoveryKey) + if err != nil { + return err + } + return nil +} diff --git a/pkg/hicli/paginate.go b/pkg/hicli/paginate.go new file mode 100644 index 0000000..8d7bf05 --- /dev/null +++ b/pkg/hicli/paginate.go @@ -0,0 +1,245 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "errors" + "fmt" + + "github.com/rs/zerolog" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" + + "go.mau.fi/gomuks/pkg/hicli/database" +) + +var ErrPaginationAlreadyInProgress = errors.New("pagination is already in progress") + +func (h *HiClient) GetEventsByRowIDs(ctx context.Context, rowIDs []database.EventRowID) ([]*database.Event, error) { + events, err := h.DB.Event.GetByRowIDs(ctx, rowIDs...) + if err != nil { + return nil, err + } else if len(events) == 0 { + return events, nil + } + firstRoomID := events[0].RoomID + allInSameRoom := true + for _, evt := range events { + h.ReprocessExistingEvent(ctx, evt) + if evt.RoomID != firstRoomID { + allInSameRoom = false + break + } + } + if allInSameRoom { + err = h.DB.Event.FillLastEditRowIDs(ctx, firstRoomID, events) + if err != nil { + return events, fmt.Errorf("failed to fill last edit row IDs: %w", err) + } + err = h.DB.Event.FillReactionCounts(ctx, firstRoomID, events) + if err != nil { + return events, fmt.Errorf("failed to fill reaction counts: %w", err) + } + } else { + // TODO slow path where events are collected and filling is done one room at a time? + } + return events, nil +} + +func (h *HiClient) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*database.Event, error) { + if evt, err := h.DB.Event.GetByID(ctx, eventID); err != nil { + return nil, fmt.Errorf("failed to get event from database: %w", err) + } else if evt != nil { + h.ReprocessExistingEvent(ctx, evt) + return evt, nil + } else if serverEvt, err := h.Client.GetEvent(ctx, roomID, eventID); err != nil { + return nil, fmt.Errorf("failed to get event from server: %w", err) + } else { + return h.processEvent(ctx, serverEvt, nil, nil, false) + } +} + +func (h *HiClient) GetRoomState(ctx context.Context, roomID id.RoomID, fetchMembers, refetch bool) ([]*database.Event, error) { + var evts []*event.Event + if refetch { + resp, err := h.Client.StateAsArray(ctx, roomID) + if err != nil { + return nil, fmt.Errorf("failed to refetch state: %w", err) + } + evts = resp + } else if fetchMembers { + resp, err := h.Client.Members(ctx, roomID) + if err != nil { + return nil, fmt.Errorf("failed to fetch members: %w", err) + } + evts = resp.Chunk + } + if evts != nil { + err := h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { + room, err := h.DB.Room.Get(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to get room from database: %w", err) + } + updatedRoom := &database.Room{ + ID: room.ID, + HasMemberList: true, + } + entries := make([]*database.CurrentStateEntry, len(evts)) + for i, evt := range evts { + dbEvt, err := h.processEvent(ctx, evt, room.LazyLoadSummary, nil, false) + if err != nil { + return fmt.Errorf("failed to process event %s: %w", evt.ID, err) + } + entries[i] = &database.CurrentStateEntry{ + EventType: evt.Type, + StateKey: *evt.StateKey, + EventRowID: dbEvt.RowID, + } + if evt.Type == event.StateMember { + entries[i].Membership = event.Membership(evt.Content.Raw["membership"].(string)) + } else { + processImportantEvent(ctx, evt, room, updatedRoom) + } + } + err = h.DB.CurrentState.AddMany(ctx, room.ID, refetch, entries) + if err != nil { + return err + } + roomChanged := updatedRoom.CheckChangesAndCopyInto(room) + if roomChanged { + err = h.DB.Room.Upsert(ctx, updatedRoom) + if err != nil { + return fmt.Errorf("failed to save room data: %w", err) + } + } + return nil + }) + if err != nil { + return nil, err + } + } + return h.DB.CurrentState.GetAll(ctx, roomID) +} + +type PaginationResponse struct { + Events []*database.Event `json:"events"` + HasMore bool `json:"has_more"` +} + +func (h *HiClient) Paginate(ctx context.Context, roomID id.RoomID, maxTimelineID database.TimelineRowID, limit int) (*PaginationResponse, error) { + evts, err := h.DB.Timeline.Get(ctx, roomID, limit, maxTimelineID) + if err != nil { + return nil, err + } else if len(evts) > 0 { + for _, evt := range evts { + h.ReprocessExistingEvent(ctx, evt) + } + return &PaginationResponse{Events: evts, HasMore: true}, nil + } else { + return h.PaginateServer(ctx, roomID, limit) + } +} + +func (h *HiClient) PaginateServer(ctx context.Context, roomID id.RoomID, limit int) (*PaginationResponse, error) { + ctx, cancel := context.WithCancelCause(ctx) + h.paginationInterrupterLock.Lock() + if _, alreadyPaginating := h.paginationInterrupter[roomID]; alreadyPaginating { + h.paginationInterrupterLock.Unlock() + return nil, ErrPaginationAlreadyInProgress + } + h.paginationInterrupter[roomID] = cancel + h.paginationInterrupterLock.Unlock() + defer func() { + h.paginationInterrupterLock.Lock() + delete(h.paginationInterrupter, roomID) + h.paginationInterrupterLock.Unlock() + }() + + room, err := h.DB.Room.Get(ctx, roomID) + if err != nil { + return nil, fmt.Errorf("failed to get room from database: %w", err) + } else if room.PrevBatch == database.PrevBatchPaginationComplete { + return &PaginationResponse{Events: []*database.Event{}, HasMore: false}, nil + } + resp, err := h.Client.Messages(ctx, roomID, room.PrevBatch, "", mautrix.DirectionBackward, nil, limit) + if err != nil { + return nil, fmt.Errorf("failed to get messages from server: %w", err) + } + events := make([]*database.Event, len(resp.Chunk)) + if resp.End == "" { + resp.End = database.PrevBatchPaginationComplete + } + if resp.End == database.PrevBatchPaginationComplete || len(resp.Chunk) == 0 { + err = h.DB.Room.SetPrevBatch(ctx, room.ID, resp.End) + if err != nil { + return nil, fmt.Errorf("failed to set prev_batch: %w", err) + } + return &PaginationResponse{Events: events, HasMore: resp.End != ""}, nil + } + wakeupSessionRequests := false + err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { + if err = ctx.Err(); err != nil { + return err + } + eventRowIDs := make([]database.EventRowID, len(resp.Chunk)) + decryptionQueue := make(map[id.SessionID]*database.SessionRequest) + iOffset := 0 + for i, evt := range resp.Chunk { + dbEvt, err := h.processEvent(ctx, evt, room.LazyLoadSummary, decryptionQueue, true) + if err != nil { + return err + } else if exists, err := h.DB.Timeline.Has(ctx, roomID, dbEvt.RowID); err != nil { + return fmt.Errorf("failed to check if event exists in timeline: %w", err) + } else if exists { + zerolog.Ctx(ctx).Warn(). + Int64("row_id", int64(dbEvt.RowID)). + Str("event_id", dbEvt.ID.String()). + Msg("Event already exists in timeline, skipping") + iOffset++ + continue + } + events[i-iOffset] = dbEvt + eventRowIDs[i-iOffset] = events[i-iOffset].RowID + } + if iOffset >= len(events) { + events = events[:0] + return nil + } + events = events[:len(events)-iOffset] + eventRowIDs = eventRowIDs[:len(eventRowIDs)-iOffset] + wakeupSessionRequests = len(decryptionQueue) > 0 + for _, entry := range decryptionQueue { + err = h.DB.SessionRequest.Put(ctx, entry) + if err != nil { + return fmt.Errorf("failed to save session request for %s: %w", entry.SessionID, err) + } + } + err = h.DB.Event.FillLastEditRowIDs(ctx, roomID, events) + if err != nil { + return fmt.Errorf("failed to fill last edit row IDs: %w", err) + } + err = h.DB.Room.SetPrevBatch(ctx, room.ID, resp.End) + if err != nil { + return fmt.Errorf("failed to set prev_batch: %w", err) + } + var tuples []database.TimelineRowTuple + tuples, err = h.DB.Timeline.Prepend(ctx, room.ID, eventRowIDs) + if err != nil { + return fmt.Errorf("failed to prepend events to timeline: %w", err) + } + for i, evt := range events { + evt.TimelineRowID = tuples[i].Timeline + } + return nil + }) + if err == nil && wakeupSessionRequests { + h.WakeupRequestQueue() + } + return &PaginationResponse{Events: events, HasMore: true}, err +} diff --git a/pkg/hicli/pushrules.go b/pkg/hicli/pushrules.go new file mode 100644 index 0000000..ddf7a79 --- /dev/null +++ b/pkg/hicli/pushrules.go @@ -0,0 +1,80 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + + "github.com/rs/zerolog" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/pushrules" + + "go.mau.fi/gomuks/pkg/hicli/database" +) + +type pushRoom struct { + ctx context.Context + roomID id.RoomID + h *HiClient + ll *mautrix.LazyLoadSummary +} + +func (p *pushRoom) GetOwnDisplayname() string { + // TODO implement + return "" +} + +func (p *pushRoom) GetMemberCount() int { + if p.ll == nil { + room, err := p.h.DB.Room.Get(p.ctx, p.roomID) + if err != nil { + zerolog.Ctx(p.ctx).Err(err). + Stringer("room_id", p.roomID). + Msg("Failed to get room by ID in push rule evaluator") + } else if room != nil { + p.ll = room.LazyLoadSummary + } + } + if p.ll != nil && p.ll.JoinedMemberCount != nil { + return *p.ll.JoinedMemberCount + } + // TODO query db? + return 0 +} + +func (p *pushRoom) GetEvent(id id.EventID) *event.Event { + evt, err := p.h.DB.Event.GetByID(p.ctx, id) + if err != nil { + zerolog.Ctx(p.ctx).Err(err). + Stringer("event_id", id). + Msg("Failed to get event by ID in push rule evaluator") + } + return evt.AsRawMautrix() +} + +var _ pushrules.EventfulRoom = (*pushRoom)(nil) + +func (h *HiClient) evaluatePushRules(ctx context.Context, llSummary *mautrix.LazyLoadSummary, baseType database.UnreadType, evt *event.Event) database.UnreadType { + should := h.PushRules.Load().GetMatchingRule(&pushRoom{ + ctx: ctx, + roomID: evt.RoomID, + h: h, + ll: llSummary, + }, evt).GetActions().Should() + if should.Notify { + baseType |= database.UnreadTypeNotify + } + if should.Highlight { + baseType |= database.UnreadTypeHighlight + } + if should.PlaySound { + baseType |= database.UnreadTypeSound + } + return baseType +} diff --git a/pkg/hicli/send.go b/pkg/hicli/send.go new file mode 100644 index 0000000..2e269c7 --- /dev/null +++ b/pkg/hicli/send.go @@ -0,0 +1,287 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/rs/zerolog" + "github.com/yuin/goldmark" + "go.mau.fi/util/jsontime" + "go.mau.fi/util/ptr" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" + "maunium.net/go/mautrix/id" + + "go.mau.fi/gomuks/pkg/hicli/database" + "go.mau.fi/gomuks/pkg/rainbow" +) + +var ( + rainbowWithHTML = goldmark.New(format.Extensions, format.HTMLOptions, goldmark.WithExtensions(rainbow.Extension)) +) + +func (h *HiClient) SendMessage(ctx context.Context, roomID id.RoomID, text, mediaPath string, replyTo id.EventID, mentions *event.Mentions) (*database.Event, error) { + var content event.MessageEventContent + if strings.HasPrefix(text, "/rainbow ") { + text = strings.TrimPrefix(text, "/rainbow ") + content = format.RenderMarkdownCustom(text, rainbowWithHTML) + content.FormattedBody = rainbow.ApplyColor(content.FormattedBody) + } else if strings.HasPrefix(text, "/plain ") { + text = strings.TrimPrefix(text, "/plain ") + content = format.RenderMarkdown(text, false, false) + } else if strings.HasPrefix(text, "/html ") { + text = strings.TrimPrefix(text, "/html ") + content = format.RenderMarkdown(text, false, true) + } else { + content = format.RenderMarkdown(text, true, false) + } + if mentions != nil { + content.Mentions.Room = mentions.Room + for _, userID := range mentions.UserIDs { + if userID != h.Account.UserID { + content.Mentions.Add(userID) + } + } + } + if replyTo != "" { + content.RelatesTo = (&event.RelatesTo{}).SetReplyTo(replyTo) + } + return h.Send(ctx, roomID, event.EventMessage, &content) +} + +func (h *HiClient) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID, receiptType event.ReceiptType) error { + content := &mautrix.ReqSetReadMarkers{ + FullyRead: eventID, + } + if receiptType == event.ReceiptTypeRead { + content.Read = eventID + } else if receiptType == event.ReceiptTypeReadPrivate { + content.ReadPrivate = eventID + } else { + return fmt.Errorf("invalid receipt type: %v", receiptType) + } + err := h.Client.SetReadMarkers(ctx, roomID, content) + if err != nil { + return fmt.Errorf("failed to mark event as read: %w", err) + } + return nil +} + +func (h *HiClient) SetTyping(ctx context.Context, roomID id.RoomID, timeout time.Duration) error { + _, err := h.Client.UserTyping(ctx, roomID, timeout > 0, timeout) + return err +} + +func (h *HiClient) Send(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (*database.Event, error) { + roomMeta, err := h.DB.Room.Get(ctx, roomID) + if err != nil { + return nil, fmt.Errorf("failed to get room metadata: %w", err) + } else if roomMeta == nil { + return nil, fmt.Errorf("unknown room") + } + var decryptedType event.Type + var decryptedContent json.RawMessage + var megolmSessionID id.SessionID + if roomMeta.EncryptionEvent != nil && evtType != event.EventReaction { + decryptedType = evtType + decryptedContent, err = json.Marshal(content) + if err != nil { + return nil, fmt.Errorf("failed to marshal event content: %w", err) + } + encryptedContent, err := h.Encrypt(ctx, roomMeta, evtType, content) + if err != nil { + return nil, fmt.Errorf("failed to encrypt event: %w", err) + } + megolmSessionID = encryptedContent.SessionID + content = encryptedContent + evtType = event.EventEncrypted + } + mainContent, err := json.Marshal(content) + if err != nil { + return nil, fmt.Errorf("failed to marshal event content: %w", err) + } + txnID := "hicli-" + h.Client.TxnID() + relatesTo, relationType := database.GetRelatesToFromBytes(mainContent) + dbEvt := &database.Event{ + RoomID: roomID, + ID: id.EventID(fmt.Sprintf("~%s", txnID)), + Sender: h.Account.UserID, + Type: evtType.Type, + Timestamp: jsontime.UnixMilliNow(), + Content: mainContent, + Decrypted: decryptedContent, + DecryptedType: decryptedType.Type, + Unsigned: []byte("{}"), + TransactionID: txnID, + RelatesTo: relatesTo, + RelationType: relationType, + MegolmSessionID: megolmSessionID, + DecryptionError: "", + SendError: "not sent", + Reactions: map[string]int{}, + LastEditRowID: ptr.Ptr(database.EventRowID(0)), + } + _, err = h.DB.Event.Insert(ctx, dbEvt) + if err != nil { + return nil, fmt.Errorf("failed to insert event into database: %w", err) + } + ctx = context.WithoutCancel(ctx) + go func() { + err := h.SetTyping(ctx, roomID, 0) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to stop typing while sending message") + } + }() + go func() { + var err error + defer func() { + h.EventHandler(&SendComplete{ + Event: dbEvt, + Error: err, + }) + }() + var resp *mautrix.RespSendEvent + resp, err = h.Client.SendMessageEvent(ctx, roomID, evtType, content, mautrix.ReqSendEvent{ + Timestamp: dbEvt.Timestamp.UnixMilli(), + TransactionID: txnID, + DontEncrypt: true, + }) + if err != nil { + dbEvt.SendError = err.Error() + err = fmt.Errorf("failed to send event: %w", err) + err2 := h.DB.Event.UpdateSendError(ctx, dbEvt.RowID, dbEvt.SendError) + if err2 != nil { + zerolog.Ctx(ctx).Err(err2).AnErr("send_error", err). + Msg("Failed to update send error in database after sending failed") + } + return + } + dbEvt.ID = resp.EventID + err = h.DB.Event.UpdateID(ctx, dbEvt.RowID, dbEvt.ID) + if err != nil { + err = fmt.Errorf("failed to update event ID in database: %w", err) + } + }() + return dbEvt, nil +} + +func (h *HiClient) Encrypt(ctx context.Context, room *database.Room, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) { + h.encryptLock.Lock() + defer h.encryptLock.Unlock() + encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, room.ID, evtType, content) + if errors.Is(err, crypto.SessionExpired) || errors.Is(err, crypto.NoGroupSession) || errors.Is(err, crypto.SessionNotShared) { + if err = h.shareGroupSession(ctx, room); err != nil { + err = fmt.Errorf("failed to share group session: %w", err) + } else if encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, room.ID, evtType, content); err != nil { + err = fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err) + } + } + return +} + +func (h *HiClient) EnsureGroupSessionShared(ctx context.Context, roomID id.RoomID) error { + h.encryptLock.Lock() + defer h.encryptLock.Unlock() + if session, err := h.CryptoStore.GetOutboundGroupSession(ctx, roomID); err != nil { + return fmt.Errorf("failed to get previous outbound group session: %w", err) + } else if session != nil && session.Shared && !session.Expired() { + return nil + } else if roomMeta, err := h.DB.Room.Get(ctx, roomID); err != nil { + return fmt.Errorf("failed to get room metadata: %w", err) + } else if roomMeta == nil { + return fmt.Errorf("unknown room") + } else { + return h.shareGroupSession(ctx, roomMeta) + } +} + +func (h *HiClient) loadMembers(ctx context.Context, room *database.Room) error { + if room.HasMemberList { + return nil + } + resp, err := h.Client.Members(ctx, room.ID) + if err != nil { + return fmt.Errorf("failed to get room member list: %w", err) + } + err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { + entries := make([]*database.CurrentStateEntry, len(resp.Chunk)) + for i, evt := range resp.Chunk { + dbEvt, err := h.processEvent(ctx, evt, nil, nil, true) + if err != nil { + return err + } + entries[i] = &database.CurrentStateEntry{ + EventType: evt.Type, + StateKey: *evt.StateKey, + EventRowID: dbEvt.RowID, + Membership: event.Membership(evt.Content.Raw["membership"].(string)), + } + } + err := h.DB.CurrentState.AddMany(ctx, room.ID, false, entries) + if err != nil { + return err + } + return h.DB.Room.Upsert(ctx, &database.Room{ + ID: room.ID, + HasMemberList: true, + }) + }) + if err != nil { + return fmt.Errorf("failed to process room member list: %w", err) + } + return nil +} + +func (h *HiClient) shareGroupSession(ctx context.Context, room *database.Room) error { + err := h.loadMembers(ctx, room) + if err != nil { + return err + } + shareToInvited := h.shouldShareKeysToInvitedUsers(ctx, room.ID) + var users []id.UserID + if shareToInvited { + users, err = h.ClientStore.GetRoomJoinedOrInvitedMembers(ctx, room.ID) + } else { + users, err = h.ClientStore.GetRoomJoinedMembers(ctx, room.ID) + } + if err != nil { + return fmt.Errorf("failed to get room member list: %w", err) + } else if err = h.Crypto.ShareGroupSession(ctx, room.ID, users); err != nil { + return fmt.Errorf("failed to share group session: %w", err) + } + return nil +} + +func (h *HiClient) shouldShareKeysToInvitedUsers(ctx context.Context, roomID id.RoomID) bool { + historyVisibility, err := h.DB.CurrentState.Get(ctx, roomID, event.StateHistoryVisibility, "") + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get history visibility event") + return false + } + mautrixEvt := historyVisibility.AsRawMautrix() + err = mautrixEvt.Content.ParseRaw(mautrixEvt.Type) + if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { + zerolog.Ctx(ctx).Err(err).Msg("Failed to parse history visibility event") + return false + } + hv, ok := mautrixEvt.Content.Parsed.(*event.HistoryVisibilityEventContent) + if !ok { + zerolog.Ctx(ctx).Warn().Msg("Unexpected parsed content type for history visibility event") + return false + } + return hv.HistoryVisibility == event.HistoryVisibilityInvited || + hv.HistoryVisibility == event.HistoryVisibilityShared || + hv.HistoryVisibility == event.HistoryVisibilityWorldReadable +} diff --git a/pkg/hicli/sync.go b/pkg/hicli/sync.go new file mode 100644 index 0000000..d2916cf --- /dev/null +++ b/pkg/hicli/sync.go @@ -0,0 +1,850 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "errors" + "fmt" + "slices" + "strings" + + "github.com/rs/zerolog" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.mau.fi/util/exzerolog" + "go.mau.fi/util/jsontime" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/pushrules" + + "go.mau.fi/gomuks/pkg/hicli/database" +) + +type syncContext struct { + shouldWakeupRequestQueue bool + + evt *SyncComplete +} + +func (h *HiClient) preProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { + log := zerolog.Ctx(ctx) + postponedToDevices := resp.ToDevice.Events[:0] + for _, evt := range resp.ToDevice.Events { + evt.Type.Class = event.ToDeviceEventType + err := evt.Content.ParseRaw(evt.Type) + if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { + log.Warn().Err(err). + Stringer("event_type", &evt.Type). + Stringer("sender", evt.Sender). + Msg("Failed to parse to-device event, skipping") + continue + } + + switch content := evt.Content.Parsed.(type) { + case *event.EncryptedEventContent: + h.Crypto.HandleEncryptedEvent(ctx, evt) + case *event.RoomKeyWithheldEventContent: + h.Crypto.HandleRoomKeyWithheld(ctx, content) + default: + postponedToDevices = append(postponedToDevices, evt) + } + } + resp.ToDevice.Events = postponedToDevices + + return nil +} + +func (h *HiClient) postProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) { + h.Crypto.HandleOTKCounts(ctx, &resp.DeviceOTKCount) + go h.asyncPostProcessSyncResponse(ctx, resp, since) + syncCtx := ctx.Value(syncContextKey).(*syncContext) + if syncCtx.shouldWakeupRequestQueue { + h.WakeupRequestQueue() + } + h.firstSyncReceived = true + if !syncCtx.evt.IsEmpty() { + h.EventHandler(syncCtx.evt) + } +} + +func (h *HiClient) asyncPostProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) { + for _, evt := range resp.ToDevice.Events { + switch content := evt.Content.Parsed.(type) { + case *event.SecretRequestEventContent: + h.Crypto.HandleSecretRequest(ctx, evt.Sender, content) + case *event.RoomKeyRequestEventContent: + h.Crypto.HandleRoomKeyRequest(ctx, evt.Sender, content) + } + } +} + +func (h *HiClient) processSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { + if len(resp.DeviceLists.Changed) > 0 { + zerolog.Ctx(ctx).Debug(). + Array("users", exzerolog.ArrayOfStringers(resp.DeviceLists.Changed)). + Msg("Marking changed device lists for tracked users as outdated") + err := h.Crypto.CryptoStore.MarkTrackedUsersOutdated(ctx, resp.DeviceLists.Changed) + if err != nil { + return fmt.Errorf("failed to mark changed device lists as outdated: %w", err) + } + ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true + } + + for _, evt := range resp.AccountData.Events { + evt.Type.Class = event.AccountDataEventType + err := h.DB.AccountData.Put(ctx, h.Account.UserID, evt.Type, evt.Content.VeryRaw) + if err != nil { + return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err) + } + if evt.Type == event.AccountDataPushRules { + err = evt.Content.ParseRaw(evt.Type) + if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { + zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to parse push rules in sync") + } else if pushRules, ok := evt.Content.Parsed.(*pushrules.EventContent); ok { + h.PushRules.Store(pushRules.Ruleset) + zerolog.Ctx(ctx).Debug().Msg("Updated push rules from sync") + } + } + } + for roomID, room := range resp.Rooms.Join { + err := h.processSyncJoinedRoom(ctx, roomID, room) + if err != nil { + return fmt.Errorf("failed to process joined room %s: %w", roomID, err) + } + } + for roomID, room := range resp.Rooms.Leave { + err := h.processSyncLeftRoom(ctx, roomID, room) + if err != nil { + return fmt.Errorf("failed to process left room %s: %w", roomID, err) + } + } + h.Account.NextBatch = resp.NextBatch + err := h.DB.Account.PutNextBatch(ctx, h.Account.UserID, resp.NextBatch) + if err != nil { + return fmt.Errorf("failed to save next_batch: %w", err) + } + return nil +} + +func (h *HiClient) receiptsToList(content *event.ReceiptEventContent) ([]*database.Receipt, []id.EventID) { + receiptList := make([]*database.Receipt, 0) + var newOwnReceipts []id.EventID + for eventID, receipts := range *content { + for receiptType, users := range receipts { + for userID, receiptInfo := range users { + if userID == h.Account.UserID { + newOwnReceipts = append(newOwnReceipts, eventID) + } + receiptList = append(receiptList, &database.Receipt{ + UserID: userID, + ReceiptType: receiptType, + ThreadID: receiptInfo.ThreadID, + EventID: eventID, + Timestamp: jsontime.UM(receiptInfo.Timestamp), + }) + } + } + } + return receiptList, newOwnReceipts +} + +type receiptsToSave struct { + roomID id.RoomID + receipts []*database.Receipt +} + +func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncJoinedRoom) error { + existingRoomData, err := h.DB.Room.Get(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to get room data: %w", err) + } else if existingRoomData == nil { + err = h.DB.Room.CreateRow(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to ensure room row exists: %w", err) + } + existingRoomData = &database.Room{ID: roomID, SortingTimestamp: jsontime.UnixMilliNow()} + } + + for _, evt := range room.AccountData.Events { + evt.Type.Class = event.AccountDataEventType + evt.RoomID = roomID + err = h.DB.AccountData.PutRoom(ctx, h.Account.UserID, roomID, evt.Type, evt.Content.VeryRaw) + if err != nil { + return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err) + } + } + var receipts []receiptsToSave + var newOwnReceipts []id.EventID + for _, evt := range room.Ephemeral.Events { + evt.Type.Class = event.EphemeralEventType + err = evt.Content.ParseRaw(evt.Type) + if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { + zerolog.Ctx(ctx).Debug().Err(err).Msg("Failed to parse ephemeral event content") + continue + } + switch evt.Type { + case event.EphemeralEventReceipt: + var receiptsList []*database.Receipt + receiptsList, newOwnReceipts = h.receiptsToList(evt.Content.AsReceipt()) + receipts = append(receipts, receiptsToSave{roomID, receiptsList}) + case event.EphemeralEventTyping: + go h.EventHandler(&Typing{ + RoomID: roomID, + TypingEventContent: *evt.Content.AsTyping(), + }) + } + } + err = h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary, newOwnReceipts, room.UnreadNotifications) + if err != nil { + return err + } + for _, rs := range receipts { + err = h.DB.Receipt.PutMany(ctx, rs.roomID, rs.receipts...) + if err != nil { + return fmt.Errorf("failed to save receipts: %w", err) + } + } + return nil +} + +func (h *HiClient) processSyncLeftRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncLeftRoom) error { + existingRoomData, err := h.DB.Room.Get(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to get room data: %w", err) + } else if existingRoomData == nil { + return nil + } + // TODO delete room instead of processing? + return h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary, nil, nil) +} + +func isDecryptionErrorRetryable(err error) bool { + return errors.Is(err, crypto.NoSessionFound) || errors.Is(err, olm.UnknownMessageIndex) || errors.Is(err, crypto.ErrGroupSessionWithheld) +} + +func removeReplyFallback(evt *event.Event) []byte { + if evt.Type != event.EventMessage && evt.Type != event.EventSticker { + return nil + } + _ = evt.Content.ParseRaw(evt.Type) + content, ok := evt.Content.Parsed.(*event.MessageEventContent) + if ok && content.RelatesTo.GetReplyTo() != "" { + prevFormattedBody := content.FormattedBody + content.RemoveReplyFallback() + if content.FormattedBody != prevFormattedBody { + bytes, err := sjson.SetBytes(evt.Content.VeryRaw, "formatted_body", content.FormattedBody) + bytes, err2 := sjson.SetBytes(bytes, "body", content.Body) + if err == nil && err2 == nil { + return bytes + } + } + } + return nil +} + +func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) (*event.Event, []byte, string, error) { + err := evt.Content.ParseRaw(evt.Type) + if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { + return nil, nil, "", err + } + decrypted, err := h.Crypto.DecryptMegolmEvent(ctx, evt) + if err != nil { + return nil, nil, "", err + } + withoutFallback := removeReplyFallback(decrypted) + if withoutFallback != nil { + return decrypted, withoutFallback, decrypted.Type.Type, nil + } + return decrypted, decrypted.Content.VeryRaw, decrypted.Type.Type, nil +} + +func (h *HiClient) addMediaCache( + ctx context.Context, + eventRowID database.EventRowID, + uri id.ContentURIString, + file *event.EncryptedFileInfo, + info *event.FileInfo, + fileName string, +) { + parsedMXC := uri.ParseOrIgnore() + if !parsedMXC.IsValid() { + return + } + cm := &database.CachedMedia{ + MXC: parsedMXC, + EventRowID: eventRowID, + FileName: fileName, + } + if file != nil { + cm.EncFile = &file.EncryptedFile + } + if info != nil { + cm.MimeType = info.MimeType + } + err := h.DB.CachedMedia.Put(ctx, cm) + if err != nil { + zerolog.Ctx(ctx).Warn().Err(err). + Stringer("mxc", parsedMXC). + Int64("event_rowid", int64(eventRowID)). + Msg("Failed to add cached media entry") + } +} + +func (h *HiClient) cacheMedia(ctx context.Context, evt *event.Event, rowID database.EventRowID) { + switch evt.Type { + case event.EventMessage, event.EventSticker: + content, ok := evt.Content.Parsed.(*event.MessageEventContent) + if !ok { + return + } + if content.File != nil { + h.addMediaCache(ctx, rowID, content.File.URL, content.File, content.Info, content.GetFileName()) + } else if content.URL != "" { + h.addMediaCache(ctx, rowID, content.URL, nil, content.Info, content.GetFileName()) + } + if content.GetInfo().ThumbnailFile != nil { + h.addMediaCache(ctx, rowID, content.Info.ThumbnailFile.URL, content.Info.ThumbnailFile, content.Info.ThumbnailInfo, "") + } else if content.GetInfo().ThumbnailURL != "" { + h.addMediaCache(ctx, rowID, content.Info.ThumbnailURL, nil, content.Info.ThumbnailInfo, "") + } + case event.StateRoomAvatar: + _ = evt.Content.ParseRaw(evt.Type) + content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent) + if !ok { + return + } + h.addMediaCache(ctx, rowID, content.URL, nil, nil, "") + case event.StateMember: + _ = evt.Content.ParseRaw(evt.Type) + content, ok := evt.Content.Parsed.(*event.MemberEventContent) + if !ok { + return + } + h.addMediaCache(ctx, rowID, content.AvatarURL, nil, nil, "") + } +} + +func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Event, evt *event.Event) *database.LocalContent { + if evt.Type != event.EventMessage && evt.Type != event.EventSticker { + return nil + } + _ = evt.Content.ParseRaw(evt.Type) + content, ok := evt.Content.Parsed.(*event.MessageEventContent) + if !ok { + return nil + } + if dbEvt.RelationType == event.RelReplace && content.NewContent != nil { + content = content.NewContent + } + if content != nil { + var sanitizedHTML string + if content.Format == event.FormatHTML { + sanitizedHTML, _ = sanitizeAndLinkifyHTML(content.FormattedBody) + } else { + var builder strings.Builder + linkifyAndWriteBytes(&builder, []byte(content.Body)) + sanitizedHTML = builder.String() + } + return &database.LocalContent{SanitizedHTML: sanitizedHTML, HTMLVersion: CurrentHTMLSanitizerVersion} + } + return nil +} + +const CurrentHTMLSanitizerVersion = 1 + +func (h *HiClient) ReprocessExistingEvent(ctx context.Context, evt *database.Event) { + if evt.Type != event.EventMessage.Type || evt.LocalContent == nil || evt.LocalContent.HTMLVersion >= CurrentHTMLSanitizerVersion { + return + } + evt.LocalContent = h.calculateLocalContent(ctx, evt, evt.AsRawMautrix()) + err := h.DB.Event.UpdateLocalContent(ctx, evt) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("event_id", evt.ID). + Msg("Failed to update local content") + } +} + +func (h *HiClient) postDecryptProcess(ctx context.Context, llSummary *mautrix.LazyLoadSummary, dbEvt *database.Event, evt *event.Event) { + if dbEvt.RowID != 0 { + h.cacheMedia(ctx, evt, dbEvt.RowID) + } + if evt.Sender != h.Account.UserID { + dbEvt.UnreadType = h.evaluatePushRules(ctx, llSummary, dbEvt.GetNonPushUnreadType(), evt) + } + dbEvt.LocalContent = h.calculateLocalContent(ctx, dbEvt, evt) +} + +func (h *HiClient) processEvent( + ctx context.Context, + evt *event.Event, + llSummary *mautrix.LazyLoadSummary, + decryptionQueue map[id.SessionID]*database.SessionRequest, + checkDB bool, +) (*database.Event, error) { + if checkDB { + dbEvt, err := h.DB.Event.GetByID(ctx, evt.ID) + if err != nil { + return nil, fmt.Errorf("failed to check if event %s exists: %w", evt.ID, err) + } else if dbEvt != nil { + return dbEvt, nil + } + } + dbEvt := database.MautrixToEvent(evt) + contentWithoutFallback := removeReplyFallback(evt) + if contentWithoutFallback != nil { + dbEvt.Content = contentWithoutFallback + } + var decryptionErr error + var decryptedMautrixEvt *event.Event + if evt.Type == event.EventEncrypted && dbEvt.RedactedBy == "" { + decryptedMautrixEvt, dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt) + if decryptionErr != nil { + dbEvt.DecryptionError = decryptionErr.Error() + } + } else if evt.Type == event.EventRedaction { + if evt.Redacts != "" && gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str != evt.Redacts.String() { + var err error + evt.Content.VeryRaw, err = sjson.SetBytes(evt.Content.VeryRaw, "redacts", evt.Redacts) + if err != nil { + return dbEvt, fmt.Errorf("failed to set redacts field: %w", err) + } + } else if evt.Redacts == "" { + evt.Redacts = id.EventID(gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str) + } + } + if decryptedMautrixEvt != nil { + h.postDecryptProcess(ctx, llSummary, dbEvt, decryptedMautrixEvt) + } else { + h.postDecryptProcess(ctx, llSummary, dbEvt, evt) + } + _, err := h.DB.Event.Upsert(ctx, dbEvt) + if err != nil { + return dbEvt, fmt.Errorf("failed to save event %s: %w", evt.ID, err) + } + if decryptedMautrixEvt != nil { + h.cacheMedia(ctx, decryptedMautrixEvt, dbEvt.RowID) + } else { + h.cacheMedia(ctx, evt, dbEvt.RowID) + } + if decryptionErr != nil && isDecryptionErrorRetryable(decryptionErr) { + req, ok := decryptionQueue[dbEvt.MegolmSessionID] + if !ok { + req = &database.SessionRequest{ + RoomID: evt.RoomID, + SessionID: dbEvt.MegolmSessionID, + Sender: evt.Sender, + } + } + minIndex, _ := crypto.ParseMegolmMessageIndex(evt.Content.AsEncrypted().MegolmCiphertext) + req.MinIndex = min(uint32(minIndex), req.MinIndex) + if decryptionQueue != nil { + decryptionQueue[dbEvt.MegolmSessionID] = req + } else { + err = h.DB.SessionRequest.Put(ctx, req) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("session_id", dbEvt.MegolmSessionID). + Msg("Failed to save session request") + } else { + h.WakeupRequestQueue() + } + } + } + return dbEvt, err +} + +func (h *HiClient) processStateAndTimeline( + ctx context.Context, + room *database.Room, + state *mautrix.SyncEventsList, + timeline *mautrix.SyncTimeline, + summary *mautrix.LazyLoadSummary, + newOwnReceipts []id.EventID, + serverNotificationCounts *mautrix.UnreadNotificationCounts, +) error { + updatedRoom := &database.Room{ + ID: room.ID, + + SortingTimestamp: room.SortingTimestamp, + NameQuality: room.NameQuality, + UnreadHighlights: room.UnreadHighlights, + UnreadNotifications: room.UnreadNotifications, + UnreadMessages: room.UnreadMessages, + } + if serverNotificationCounts != nil { + updatedRoom.UnreadHighlights = serverNotificationCounts.HighlightCount + updatedRoom.UnreadNotifications = serverNotificationCounts.NotificationCount + } + heroesChanged := false + if summary.Heroes == nil && summary.JoinedMemberCount == nil && summary.InvitedMemberCount == nil { + summary = room.LazyLoadSummary + } else if room.LazyLoadSummary == nil || + !slices.Equal(summary.Heroes, room.LazyLoadSummary.Heroes) || + !intPtrEqual(summary.JoinedMemberCount, room.LazyLoadSummary.JoinedMemberCount) || + !intPtrEqual(summary.InvitedMemberCount, room.LazyLoadSummary.InvitedMemberCount) { + updatedRoom.LazyLoadSummary = summary + heroesChanged = true + } + decryptionQueue := make(map[id.SessionID]*database.SessionRequest) + allNewEvents := make([]*database.Event, 0, len(state.Events)+len(timeline.Events)) + newNotifications := make([]SyncNotification, 0) + recalculatePreviewEvent := false + addOldEvent := func(rowID database.EventRowID, evtID id.EventID) (dbEvt *database.Event, err error) { + if rowID != 0 { + dbEvt, err = h.DB.Event.GetByRowID(ctx, rowID) + } else { + dbEvt, err = h.DB.Event.GetByID(ctx, evtID) + } + if err != nil { + return nil, fmt.Errorf("failed to get redaction target: %w", err) + } else if dbEvt == nil { + return nil, nil + } + allNewEvents = append(allNewEvents, dbEvt) + return dbEvt, nil + } + processRedaction := func(evt *event.Event) error { + dbEvt, err := addOldEvent(0, evt.Redacts) + if err != nil { + return fmt.Errorf("failed to get redaction target: %w", err) + } + if dbEvt == nil { + return nil + } + if dbEvt.RelationType == event.RelReplace || dbEvt.RelationType == event.RelAnnotation { + _, err = addOldEvent(0, dbEvt.RelatesTo) + if err != nil { + return fmt.Errorf("failed to get relation target of redaction target: %w", err) + } + } + if updatedRoom.PreviewEventRowID == dbEvt.RowID { + updatedRoom.PreviewEventRowID = 0 + recalculatePreviewEvent = true + } + return nil + } + processNewEvent := func(evt *event.Event, isTimeline, isUnread bool) (database.EventRowID, error) { + evt.RoomID = room.ID + dbEvt, err := h.processEvent(ctx, evt, summary, decryptionQueue, false) + if err != nil { + return -1, err + } + if isUnread && dbEvt.UnreadType.Is(database.UnreadTypeNotify) { + newNotifications = append(newNotifications, SyncNotification{ + RowID: dbEvt.RowID, + Sound: dbEvt.UnreadType.Is(database.UnreadTypeSound), + }) + } + if isTimeline { + if dbEvt.CanUseForPreview() { + updatedRoom.PreviewEventRowID = dbEvt.RowID + recalculatePreviewEvent = false + } + updatedRoom.BumpSortingTimestamp(dbEvt) + } + if evt.StateKey != nil { + var membership event.Membership + if evt.Type == event.StateMember { + membership = event.Membership(gjson.GetBytes(evt.Content.VeryRaw, "membership").Str) + if summary != nil && slices.Contains(summary.Heroes, id.UserID(*evt.StateKey)) { + heroesChanged = true + } + } else if evt.Type == event.StateElementFunctionalMembers { + heroesChanged = true + } + err = h.DB.CurrentState.Set(ctx, room.ID, evt.Type, *evt.StateKey, dbEvt.RowID, membership) + if err != nil { + return -1, fmt.Errorf("failed to save current state event ID %s for %s/%s: %w", evt.ID, evt.Type.Type, *evt.StateKey, err) + } + processImportantEvent(ctx, evt, room, updatedRoom) + } + allNewEvents = append(allNewEvents, dbEvt) + if evt.Type == event.EventRedaction && evt.Redacts != "" { + err = processRedaction(evt) + if err != nil { + return -1, fmt.Errorf("failed to process redaction: %w", err) + } + } else if dbEvt.RelationType == event.RelReplace || dbEvt.RelationType == event.RelAnnotation { + _, err = addOldEvent(0, dbEvt.RelatesTo) + if err != nil { + return -1, fmt.Errorf("failed to get relation target of event: %w", err) + } + } + return dbEvt.RowID, nil + } + changedState := make(map[event.Type]map[string]database.EventRowID) + setNewState := func(evtType event.Type, stateKey string, rowID database.EventRowID) { + if _, ok := changedState[evtType]; !ok { + changedState[evtType] = make(map[string]database.EventRowID) + } + changedState[evtType][stateKey] = rowID + } + for _, evt := range state.Events { + evt.Type.Class = event.StateEventType + rowID, err := processNewEvent(evt, false, false) + if err != nil { + return err + } + setNewState(evt.Type, *evt.StateKey, rowID) + } + var timelineRowTuples []database.TimelineRowTuple + var err error + if len(timeline.Events) > 0 { + timelineIDs := make([]database.EventRowID, len(timeline.Events)) + readUpToIndex := -1 + for i := len(timeline.Events) - 1; i >= 0; i-- { + if slices.Contains(newOwnReceipts, timeline.Events[i].ID) { + readUpToIndex = i + break + } + } + for i, evt := range timeline.Events { + if evt.StateKey != nil { + evt.Type.Class = event.StateEventType + } else { + evt.Type.Class = event.MessageEventType + } + timelineIDs[i], err = processNewEvent(evt, true, i > readUpToIndex) + if err != nil { + return err + } + if evt.StateKey != nil { + setNewState(evt.Type, *evt.StateKey, timelineIDs[i]) + } + } + for _, entry := range decryptionQueue { + err = h.DB.SessionRequest.Put(ctx, entry) + if err != nil { + return fmt.Errorf("failed to save session request for %s: %w", entry.SessionID, err) + } + } + if len(decryptionQueue) > 0 { + ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true + } + if timeline.Limited { + err = h.DB.Timeline.Clear(ctx, room.ID) + if err != nil { + return fmt.Errorf("failed to clear old timeline: %w", err) + } + updatedRoom.PrevBatch = timeline.PrevBatch + h.paginationInterrupterLock.Lock() + if interrupt, ok := h.paginationInterrupter[room.ID]; ok { + interrupt(ErrTimelineReset) + } + h.paginationInterrupterLock.Unlock() + } + timelineRowTuples, err = h.DB.Timeline.Append(ctx, room.ID, timelineIDs) + if err != nil { + return fmt.Errorf("failed to append timeline: %w", err) + } + } else { + timelineRowTuples = make([]database.TimelineRowTuple, 0) + } + if recalculatePreviewEvent && updatedRoom.PreviewEventRowID == 0 { + updatedRoom.PreviewEventRowID, err = h.DB.Room.RecalculatePreview(ctx, room.ID) + if err != nil { + return fmt.Errorf("failed to recalculate preview event: %w", err) + } + _, err = addOldEvent(updatedRoom.PreviewEventRowID, "") + if err != nil { + return fmt.Errorf("failed to get preview event: %w", err) + } + } + // Calculate name from participants if participants changed and current name was generated from participants, or if the room name was unset + if (heroesChanged && updatedRoom.NameQuality <= database.NameQualityParticipants) || updatedRoom.NameQuality == database.NameQualityNil { + name, dmAvatarURL, err := h.calculateRoomParticipantName(ctx, room.ID, summary) + if err != nil { + return fmt.Errorf("failed to calculate room name: %w", err) + } + updatedRoom.Name = &name + updatedRoom.NameQuality = database.NameQualityParticipants + if !dmAvatarURL.IsEmpty() && !room.ExplicitAvatar { + updatedRoom.Avatar = &dmAvatarURL + } + } + if timeline.PrevBatch != "" && (room.PrevBatch == "" || timeline.Limited) { + updatedRoom.PrevBatch = timeline.PrevBatch + } + roomChanged := updatedRoom.CheckChangesAndCopyInto(room) + if roomChanged { + err = h.DB.Room.Upsert(ctx, updatedRoom) + if err != nil { + return fmt.Errorf("failed to save room data: %w", err) + } + } + if roomChanged || len(timelineRowTuples) > 0 || len(allNewEvents) > 0 { + ctx.Value(syncContextKey).(*syncContext).evt.Rooms[room.ID] = &SyncRoom{ + Meta: room, + Timeline: timelineRowTuples, + State: changedState, + Reset: timeline.Limited, + Events: allNewEvents, + Notifications: newNotifications, + } + } + return nil +} + +func joinMemberNames(names []string, totalCount int) string { + if len(names) == 1 { + return names[0] + } else if len(names) < 5 || (len(names) == 5 && totalCount <= 6) { + return strings.Join(names[:len(names)-1], ", ") + " and " + names[len(names)-1] + } else { + return fmt.Sprintf("%s and %d others", strings.Join(names[:4], ", "), totalCount-5) + } +} + +func (h *HiClient) calculateRoomParticipantName(ctx context.Context, roomID id.RoomID, summary *mautrix.LazyLoadSummary) (string, id.ContentURI, error) { + var primaryAvatarURL id.ContentURI + if summary == nil || len(summary.Heroes) == 0 { + return "Empty room", primaryAvatarURL, nil + } + var functionalMembers []id.UserID + functionalMembersEvt, err := h.DB.CurrentState.Get(ctx, roomID, event.StateElementFunctionalMembers, "") + if err != nil { + return "", primaryAvatarURL, fmt.Errorf("failed to get %s event: %w", event.StateElementFunctionalMembers.Type, err) + } else if functionalMembersEvt != nil { + mautrixEvt := functionalMembersEvt.AsRawMautrix() + _ = mautrixEvt.Content.ParseRaw(mautrixEvt.Type) + content, ok := mautrixEvt.Content.Parsed.(*event.ElementFunctionalMembersContent) + if ok { + functionalMembers = content.ServiceMembers + } + } + var members, leftMembers []string + var memberCount int + if summary.JoinedMemberCount != nil && *summary.JoinedMemberCount > 0 { + memberCount = *summary.JoinedMemberCount + } else if summary.InvitedMemberCount != nil { + memberCount = *summary.InvitedMemberCount + } + for _, hero := range summary.Heroes { + if slices.Contains(functionalMembers, hero) { + memberCount-- + continue + } else if len(members) >= 5 { + break + } + heroEvt, err := h.DB.CurrentState.Get(ctx, roomID, event.StateMember, hero.String()) + if err != nil { + return "", primaryAvatarURL, fmt.Errorf("failed to get %s's member event: %w", hero, err) + } else if heroEvt == nil { + leftMembers = append(leftMembers, hero.String()) + continue + } + membership := gjson.GetBytes(heroEvt.Content, "membership").Str + name := gjson.GetBytes(heroEvt.Content, "displayname").Str + if name == "" { + name = hero.String() + } + avatarURL := gjson.GetBytes(heroEvt.Content, "avatar_url").Str + if avatarURL != "" { + primaryAvatarURL = id.ContentURIString(avatarURL).ParseOrIgnore() + } + if membership == "join" || membership == "invite" { + members = append(members, name) + } else { + leftMembers = append(leftMembers, name) + } + } + if len(members)+len(leftMembers) > 1 || !primaryAvatarURL.IsValid() { + primaryAvatarURL = id.ContentURI{} + } + if len(members) > 0 { + return joinMemberNames(members, memberCount), primaryAvatarURL, nil + } else if len(leftMembers) > 0 { + return fmt.Sprintf("Empty room (was %s)", joinMemberNames(leftMembers, memberCount)), primaryAvatarURL, nil + } else { + return "Empty room", primaryAvatarURL, nil + } +} + +func intPtrEqual(a, b *int) bool { + if a == nil || b == nil { + return a == b + } + return *a == *b +} + +func processImportantEvent(ctx context.Context, evt *event.Event, existingRoomData, updatedRoom *database.Room) (roomDataChanged bool) { + if evt.StateKey == nil { + return + } + switch evt.Type { + case event.StateCreate, event.StateRoomName, event.StateCanonicalAlias, event.StateRoomAvatar, event.StateTopic, event.StateEncryption: + if *evt.StateKey != "" { + return + } + default: + return + } + err := evt.Content.ParseRaw(evt.Type) + if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { + zerolog.Ctx(ctx).Warn().Err(err). + Stringer("event_type", &evt.Type). + Stringer("event_id", evt.ID). + Msg("Failed to parse state event, skipping") + return + } + switch evt.Type { + case event.StateCreate: + updatedRoom.CreationContent, _ = evt.Content.Parsed.(*event.CreateEventContent) + case event.StateEncryption: + newEncryption, _ := evt.Content.Parsed.(*event.EncryptionEventContent) + if existingRoomData.EncryptionEvent == nil || existingRoomData.EncryptionEvent.Algorithm == newEncryption.Algorithm { + updatedRoom.EncryptionEvent = newEncryption + } + case event.StateRoomName: + content, ok := evt.Content.Parsed.(*event.RoomNameEventContent) + if ok { + updatedRoom.Name = &content.Name + updatedRoom.NameQuality = database.NameQualityExplicit + if content.Name == "" { + if updatedRoom.CanonicalAlias != nil && *updatedRoom.CanonicalAlias != "" { + updatedRoom.Name = (*string)(updatedRoom.CanonicalAlias) + updatedRoom.NameQuality = database.NameQualityCanonicalAlias + } else if existingRoomData.CanonicalAlias != nil && *existingRoomData.CanonicalAlias != "" { + updatedRoom.Name = (*string)(existingRoomData.CanonicalAlias) + updatedRoom.NameQuality = database.NameQualityCanonicalAlias + } else { + updatedRoom.NameQuality = database.NameQualityNil + } + } + } + case event.StateCanonicalAlias: + content, ok := evt.Content.Parsed.(*event.CanonicalAliasEventContent) + if ok { + updatedRoom.CanonicalAlias = &content.Alias + if updatedRoom.NameQuality <= database.NameQualityCanonicalAlias { + updatedRoom.Name = (*string)(&content.Alias) + updatedRoom.NameQuality = database.NameQualityCanonicalAlias + if content.Alias == "" { + updatedRoom.NameQuality = database.NameQualityNil + } + } + } + case event.StateRoomAvatar: + content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent) + if ok { + url, _ := content.URL.Parse() + updatedRoom.Avatar = &url + updatedRoom.ExplicitAvatar = true + } + case event.StateTopic: + content, ok := evt.Content.Parsed.(*event.TopicEventContent) + if ok { + updatedRoom.Topic = &content.Topic + } + } + return +} diff --git a/pkg/hicli/syncwrap.go b/pkg/hicli/syncwrap.go new file mode 100644 index 0000000..1383720 --- /dev/null +++ b/pkg/hicli/syncwrap.go @@ -0,0 +1,96 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "fmt" + "time" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/id" +) + +type hiSyncer HiClient + +var _ mautrix.Syncer = (*hiSyncer)(nil) + +type contextKey int + +const ( + syncContextKey contextKey = iota +) + +func (h *hiSyncer) ProcessResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { + c := (*HiClient)(h) + ctx = context.WithValue(ctx, syncContextKey, &syncContext{evt: &SyncComplete{Rooms: make(map[id.RoomID]*SyncRoom, len(resp.Rooms.Join))}}) + err := c.preProcessSyncResponse(ctx, resp, since) + if err != nil { + return err + } + err = c.DB.DoTxn(ctx, nil, func(ctx context.Context) error { + return c.processSyncResponse(ctx, resp, since) + }) + if err != nil { + return err + } + c.postProcessSyncResponse(ctx, resp, since) + return nil +} + +func (h *hiSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.Duration, error) { + (*HiClient)(h).Log.Err(err).Msg("Sync failed, retrying in 1 second") + return 1 * time.Second, nil +} + +func (h *hiSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter { + if !h.Verified { + return &mautrix.Filter{ + Presence: mautrix.FilterPart{ + NotRooms: []id.RoomID{"*"}, + }, + Room: mautrix.RoomFilter{ + NotRooms: []id.RoomID{"*"}, + }, + } + } + return &mautrix.Filter{ + Presence: mautrix.FilterPart{ + NotRooms: []id.RoomID{"*"}, + }, + Room: mautrix.RoomFilter{ + State: mautrix.FilterPart{ + LazyLoadMembers: true, + }, + Timeline: mautrix.FilterPart{ + Limit: 100, + LazyLoadMembers: true, + }, + }, + } +} + +type hiStore HiClient + +var _ mautrix.SyncStore = (*hiStore)(nil) + +// Filter ID save and load are intentionally no-ops: we want to recreate filters when restarting syncing + +func (h *hiStore) SaveFilterID(_ context.Context, _ id.UserID, _ string) error { return nil } +func (h *hiStore) LoadFilterID(_ context.Context, _ id.UserID) (string, error) { return "", nil } + +func (h *hiStore) SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) error { + // This is intentionally a no-op: we don't want to save the next batch before processing the sync + return nil +} + +func (h *hiStore) LoadNextBatch(_ context.Context, userID id.UserID) (string, error) { + if h.Account.UserID != userID { + return "", fmt.Errorf("mismatching user ID") + } + return h.Account.NextBatch, nil +} diff --git a/pkg/hicli/verify.go b/pkg/hicli/verify.go new file mode 100644 index 0000000..71ce8c1 --- /dev/null +++ b/pkg/hicli/verify.go @@ -0,0 +1,161 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "encoding/base64" + "fmt" + + "github.com/rs/zerolog" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/backup" + "maunium.net/go/mautrix/crypto/ssss" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func (h *HiClient) checkIsCurrentDeviceVerified(ctx context.Context) (bool, error) { + keys := h.Crypto.GetOwnCrossSigningPublicKeys(ctx) + if keys == nil { + return false, fmt.Errorf("own cross-signing keys not found") + } + isVerified, err := h.Crypto.CryptoStore.IsKeySignedBy(ctx, h.Account.UserID, h.Crypto.GetAccount().SigningKey(), h.Account.UserID, keys.SelfSigningKey) + if err != nil { + return false, fmt.Errorf("failed to check if current device is signed by own self-signing key: %w", err) + } + return isVerified, nil +} + +func (h *HiClient) fetchKeyBackupKey(ctx context.Context, ssssKey *ssss.Key) error { + latestVersion, err := h.Client.GetKeyBackupLatestVersion(ctx) + if err != nil { + return fmt.Errorf("failed to get key backup latest version: %w", err) + } + h.KeyBackupVersion = latestVersion.Version + data, err := h.Crypto.SSSS.GetDecryptedAccountData(ctx, event.AccountDataMegolmBackupKey, ssssKey) + if err != nil { + return fmt.Errorf("failed to get megolm backup key from SSSS: %w", err) + } + key, err := backup.MegolmBackupKeyFromBytes(data) + if err != nil { + return fmt.Errorf("failed to parse megolm backup key: %w", err) + } + err = h.CryptoStore.PutSecret(ctx, id.SecretMegolmBackupV1, base64.StdEncoding.EncodeToString(key.Bytes())) + if err != nil { + return fmt.Errorf("failed to store megolm backup key: %w", err) + } + h.KeyBackupKey = key + return nil +} + +func (h *HiClient) getAndDecodeSecret(ctx context.Context, secret id.Secret) ([]byte, error) { + secretData, err := h.CryptoStore.GetSecret(ctx, secret) + if err != nil { + return nil, fmt.Errorf("failed to get secret %s: %w", secret, err) + } + data, err := base64.StdEncoding.DecodeString(secretData) + if err != nil { + return nil, fmt.Errorf("failed to decode secret %s: %w", secret, err) + } + return data, nil +} + +func (h *HiClient) loadPrivateKeys(ctx context.Context) error { + zerolog.Ctx(ctx).Debug().Msg("Loading cross-signing private keys") + masterKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSMaster) + if err != nil { + return fmt.Errorf("failed to get master key: %w", err) + } + selfSigningKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSSelfSigning) + if err != nil { + return fmt.Errorf("failed to get self-signing key: %w", err) + } + userSigningKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSUserSigning) + if err != nil { + return fmt.Errorf("failed to get user signing key: %w", err) + } + err = h.Crypto.ImportCrossSigningKeys(crypto.CrossSigningSeeds{ + MasterKey: masterKeySeed, + SelfSigningKey: selfSigningKeySeed, + UserSigningKey: userSigningKeySeed, + }) + if err != nil { + return fmt.Errorf("failed to import cross-signing private keys: %w", err) + } + zerolog.Ctx(ctx).Debug().Msg("Loading key backup key") + keyBackupKey, err := h.getAndDecodeSecret(ctx, id.SecretMegolmBackupV1) + if err != nil { + return fmt.Errorf("failed to get megolm backup key: %w", err) + } + h.KeyBackupKey, err = backup.MegolmBackupKeyFromBytes(keyBackupKey) + if err != nil { + return fmt.Errorf("failed to parse megolm backup key: %w", err) + } + zerolog.Ctx(ctx).Debug().Msg("Fetching key backup version") + latestVersion, err := h.Client.GetKeyBackupLatestVersion(ctx) + if err != nil { + return fmt.Errorf("failed to get key backup latest version: %w", err) + } + h.KeyBackupVersion = latestVersion.Version + zerolog.Ctx(ctx).Debug().Msg("Secrets loaded") + return nil +} + +func (h *HiClient) storeCrossSigningPrivateKeys(ctx context.Context) error { + keys := h.Crypto.CrossSigningKeys + err := h.CryptoStore.PutSecret(ctx, id.SecretXSMaster, base64.StdEncoding.EncodeToString(keys.MasterKey.Seed())) + if err != nil { + return err + } + err = h.CryptoStore.PutSecret(ctx, id.SecretXSSelfSigning, base64.StdEncoding.EncodeToString(keys.SelfSigningKey.Seed())) + if err != nil { + return err + } + err = h.CryptoStore.PutSecret(ctx, id.SecretXSUserSigning, base64.StdEncoding.EncodeToString(keys.UserSigningKey.Seed())) + if err != nil { + return err + } + return nil +} + +func (h *HiClient) VerifyWithRecoveryKey(ctx context.Context, code string) error { + defer h.dispatchCurrentState() + keyID, keyData, err := h.Crypto.SSSS.GetDefaultKeyData(ctx) + if err != nil { + return fmt.Errorf("failed to get default SSSS key data: %w", err) + } + key, err := keyData.VerifyRecoveryKey(keyID, code) + if err != nil { + return err + } + err = h.Crypto.FetchCrossSigningKeysFromSSSS(ctx, key) + if err != nil { + return fmt.Errorf("failed to fetch cross-signing keys from SSSS: %w", err) + } + err = h.Crypto.SignOwnDevice(ctx, h.Crypto.OwnIdentity()) + if err != nil { + return fmt.Errorf("failed to sign own device: %w", err) + } + err = h.Crypto.SignOwnMasterKey(ctx) + if err != nil { + return fmt.Errorf("failed to sign own master key: %w", err) + } + err = h.storeCrossSigningPrivateKeys(ctx) + if err != nil { + return fmt.Errorf("failed to store cross-signing private keys: %w", err) + } + err = h.fetchKeyBackupKey(ctx, key) + if err != nil { + return fmt.Errorf("failed to fetch key backup key: %w", err) + } + h.Verified = true + if !h.IsSyncing() { + go h.Sync() + } + return nil +} diff --git a/pkg/rainbow/goldmark.go b/pkg/rainbow/goldmark.go new file mode 100644 index 0000000..59a3617 --- /dev/null +++ b/pkg/rainbow/goldmark.go @@ -0,0 +1,120 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package rainbow + +import ( + "fmt" + "unicode" + + "github.com/rivo/uniseg" + "github.com/yuin/goldmark" + "github.com/yuin/goldmark/ast" + "github.com/yuin/goldmark/renderer" + "github.com/yuin/goldmark/renderer/html" + "github.com/yuin/goldmark/util" + "go.mau.fi/util/random" +) + +// Extension is a goldmark extension that adds rainbow text coloring to the HTML renderer. +var Extension = &extRainbow{} + +type extRainbow struct{} +type rainbowRenderer struct { + HardWraps bool + ColorID string +} + +var defaultRB = &rainbowRenderer{HardWraps: true, ColorID: random.String(16)} + +func (er *extRainbow) Extend(m goldmark.Markdown) { + m.Renderer().AddOptions(renderer.WithNodeRenderers(util.Prioritized(defaultRB, 0))) +} + +func (rb *rainbowRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer) { + reg.Register(ast.KindText, rb.renderText) + reg.Register(ast.KindString, rb.renderString) +} + +type rainbowBufWriter struct { + util.BufWriter + ColorID string +} + +func (rbw rainbowBufWriter) WriteString(s string) (int, error) { + i := 0 + graphemes := uniseg.NewGraphemes(s) + for graphemes.Next() { + runes := graphemes.Runes() + if len(runes) == 1 && unicode.IsSpace(runes[0]) { + i2, err := rbw.BufWriter.WriteRune(runes[0]) + i += i2 + if err != nil { + return i, err + } + continue + } + i2, err := fmt.Fprintf(rbw.BufWriter, "%s", rbw.ColorID, graphemes.Str()) + i += i2 + if err != nil { + return i, err + } + } + return i, nil +} + +func (rbw rainbowBufWriter) Write(data []byte) (int, error) { + return rbw.WriteString(string(data)) +} + +func (rbw rainbowBufWriter) WriteByte(c byte) error { + _, err := rbw.WriteRune(rune(c)) + return err +} + +func (rbw rainbowBufWriter) WriteRune(r rune) (int, error) { + if unicode.IsSpace(r) { + return rbw.BufWriter.WriteRune(r) + } else { + return fmt.Fprintf(rbw.BufWriter, "%c", rbw.ColorID, r) + } +} + +func (rb *rainbowRenderer) renderText(w util.BufWriter, source []byte, node ast.Node, entering bool) (ast.WalkStatus, error) { + if !entering { + return ast.WalkContinue, nil + } + n := node.(*ast.Text) + segment := n.Segment + if n.IsRaw() { + html.DefaultWriter.RawWrite(rainbowBufWriter{w, rb.ColorID}, segment.Value(source)) + } else { + html.DefaultWriter.Write(rainbowBufWriter{w, rb.ColorID}, segment.Value(source)) + if n.HardLineBreak() || (n.SoftLineBreak() && rb.HardWraps) { + _, _ = w.WriteString("
\n") + } else if n.SoftLineBreak() { + _ = w.WriteByte('\n') + } + } + return ast.WalkContinue, nil +} + +func (rb *rainbowRenderer) renderString(w util.BufWriter, source []byte, node ast.Node, entering bool) (ast.WalkStatus, error) { + if !entering { + return ast.WalkContinue, nil + } + n := node.(*ast.String) + if n.IsCode() { + _, _ = w.Write(n.Value) + } else { + if n.IsRaw() { + html.DefaultWriter.RawWrite(rainbowBufWriter{w, rb.ColorID}, n.Value) + } else { + html.DefaultWriter.Write(rainbowBufWriter{w, rb.ColorID}, n.Value) + } + } + return ast.WalkContinue, nil +} diff --git a/pkg/rainbow/gradient.go b/pkg/rainbow/gradient.go new file mode 100644 index 0000000..34c499e --- /dev/null +++ b/pkg/rainbow/gradient.go @@ -0,0 +1,56 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package rainbow + +import ( + "regexp" + "strings" + + "github.com/lucasb-eyer/go-colorful" +) + +// GradientTable from https://github.com/lucasb-eyer/go-colorful/blob/master/doc/gradientgen/gradientgen.go +type GradientTable []struct { + Col colorful.Color + Pos float64 +} + +func (gt GradientTable) GetInterpolatedColorFor(t float64) colorful.Color { + for i := 0; i < len(gt)-1; i++ { + c1 := gt[i] + c2 := gt[i+1] + if c1.Pos <= t && t <= c2.Pos { + t := (t - c1.Pos) / (c2.Pos - c1.Pos) + return c1.Col.BlendHcl(c2.Col, t).Clamped() + } + } + return gt[len(gt)-1].Col +} + +var Gradient = GradientTable{ + {colorful.LinearRgb(1, 0, 0), 0 / 11.0}, + {colorful.LinearRgb(1, 0.5, 0), 1 / 11.0}, + {colorful.LinearRgb(1, 1, 0), 2 / 11.0}, + {colorful.LinearRgb(0.5, 1, 0), 3 / 11.0}, + {colorful.LinearRgb(0, 1, 0), 4 / 11.0}, + {colorful.LinearRgb(0, 1, 0.5), 5 / 11.0}, + {colorful.LinearRgb(0, 1, 1), 6 / 11.0}, + {colorful.LinearRgb(0, 0.5, 1), 7 / 11.0}, + {colorful.LinearRgb(0, 0, 1), 8 / 11.0}, + {colorful.LinearRgb(0.5, 0, 1), 9 / 11.0}, + {colorful.LinearRgb(1, 0, 1), 10 / 11.0}, + {colorful.LinearRgb(1, 0, 0.5), 11 / 11.0}, +} + +func ApplyColor(htmlBody string) string { + count := strings.Count(htmlBody, defaultRB.ColorID) + i := -1 + return regexp.MustCompile(defaultRB.ColorID).ReplaceAllStringFunc(htmlBody, func(match string) string { + i++ + return Gradient.GetInterpolatedColorFor(float64(i) / float64(count)).Hex() + }) +} diff --git a/server.go b/server.go index c20baec..0cccee1 100644 --- a/server.go +++ b/server.go @@ -82,11 +82,12 @@ var ( ) type tokenData struct { - Username string `json:"username"` - Expiry jsontime.Unix `json:"expiry"` + Username string `json:"username"` + Expiry jsontime.Unix `json:"expiry"` + ImageOnly bool `json:"image_only,omitempty"` } -func (gmx *Gomuks) validateAuth(token string) bool { +func (gmx *Gomuks) validateAuth(token string, imageOnly bool) bool { if len(token) > 500 { return false } @@ -110,19 +111,31 @@ func (gmx *Gomuks) validateAuth(token string) bool { var td tokenData err = json.Unmarshal(rawJSON, &td) - return err == nil && td.Username == gmx.Config.Web.Username && td.Expiry.After(time.Now()) + return err == nil && td.Username == gmx.Config.Web.Username && td.Expiry.After(time.Now()) && td.ImageOnly == imageOnly } func (gmx *Gomuks) generateToken() (string, time.Time) { expiry := time.Now().Add(7 * 24 * time.Hour) - data := exerrors.Must(json.Marshal(tokenData{ + return gmx.signToken(tokenData{ Username: gmx.Config.Web.Username, Expiry: jsontime.U(expiry), - })) + }), expiry +} + +func (gmx *Gomuks) generateImageToken() string { + return gmx.signToken(tokenData{ + Username: gmx.Config.Web.Username, + Expiry: jsontime.U(time.Now().Add(1 * time.Hour)), + ImageOnly: true, + }) +} + +func (gmx *Gomuks) signToken(td tokenData) string { + data := exerrors.Must(json.Marshal(td)) hasher := hmac.New(sha256.New, []byte(gmx.Config.Web.TokenKey)) hasher.Write(data) checksum := hasher.Sum(nil) - return base64.RawURLEncoding.EncodeToString(data) + "." + base64.RawURLEncoding.EncodeToString(checksum), expiry + return base64.RawURLEncoding.EncodeToString(data) + "." + base64.RawURLEncoding.EncodeToString(checksum) } func (gmx *Gomuks) writeTokenCookie(w http.ResponseWriter) { @@ -137,7 +150,7 @@ func (gmx *Gomuks) writeTokenCookie(w http.ResponseWriter) { func (gmx *Gomuks) Authenticate(w http.ResponseWriter, r *http.Request) { authCookie, err := r.Cookie("gomuks_auth") - if err == nil && gmx.validateAuth(authCookie.Value) { + if err == nil && gmx.validateAuth(authCookie.Value, false) { gmx.writeTokenCookie(w) w.WriteHeader(http.StatusOK) } else if username, password, ok := r.BasicAuth(); !ok { @@ -164,9 +177,23 @@ func isUserFetch(header http.Header) bool { header.Get("Sec-Fetch-User") == "?1" } +func isImageFetch(header http.Header) bool { + return header.Get("Sec-Fetch-Site") == "cross-site" && + header.Get("Sec-Fetch-Mode") == "no-cors" && + header.Get("Sec-Fetch-Dest") == "image" +} + func (gmx *Gomuks) AuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Sec-Fetch-Site") != "" && r.Header.Get("Sec-Fetch-Site") != "same-origin" && !isUserFetch(r.Header) { + if strings.HasPrefix(r.URL.Path, "/media") && + isImageFetch(r.Header) && + gmx.validateAuth(r.URL.Query().Get("image_auth"), true) && + r.URL.Query().Get("encrypted") == "false" { + next.ServeHTTP(w, r) + return + } else if r.Header.Get("Sec-Fetch-Site") != "" && + r.Header.Get("Sec-Fetch-Site") != "same-origin" && + !isUserFetch(r.Header) { hlog.FromRequest(r).Debug(). Str("site", r.Header.Get("Sec-Fetch-Site")). Str("dest", r.Header.Get("Sec-Fetch-Dest")). @@ -181,7 +208,7 @@ func (gmx *Gomuks) AuthMiddleware(next http.Handler) http.Handler { if err != nil { ErrMissingCookie.Write(w) return - } else if !gmx.validateAuth(authCookie.Value) { + } else if !gmx.validateAuth(authCookie.Value, false) { ErrInvalidCookie.Write(w) return } diff --git a/web/src/App.tsx b/web/src/App.tsx index 1bee2dc..4644dd0 100644 --- a/web/src/App.tsx +++ b/web/src/App.tsx @@ -29,6 +29,8 @@ function App() { const clientState = useEventAsState(client.state) ;((window as unknown) as { client: Client }).client = client useEffect(() => { + Notification.requestPermission() + .then(permission => console.log("Notification permission:", permission)) client.rpc.start() return () => client.rpc.stop() }, [client]) diff --git a/web/src/api/client.ts b/web/src/api/client.ts index 6173d27..e898bd6 100644 --- a/web/src/api/client.ts +++ b/web/src/api/client.ts @@ -39,6 +39,8 @@ export default class Client { this.store.applyDecrypted(ev.data) } else if (ev.command === "send_complete") { this.store.applySendComplete(ev.data) + } else if (ev.command === "image_auth_token") { + this.store.imageAuthToken = ev.data } } diff --git a/web/src/api/statestore/main.ts b/web/src/api/statestore/main.ts index 3b1a625..6b7aa3e 100644 --- a/web/src/api/statestore/main.ts +++ b/web/src/api/statestore/main.ts @@ -14,9 +14,11 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . import unhomoglyph from "unhomoglyph" +import { getMediaURL } from "@/api/media.ts" import { NonNullCachedEventDispatcher } from "@/util/eventdispatcher.ts" +import { focused } from "@/util/focus.ts" import type { - ContentURI, + ContentURI, EventRowID, EventsDecryptedData, MemDBEvent, RoomID, @@ -34,6 +36,8 @@ export interface RoomListEntry { name: string search_name: string avatar?: ContentURI + unread: number + highlighted: boolean } // eslint-disable-next-line no-misleading-character-class @@ -49,9 +53,13 @@ export function toSearchableString(str: string): string { export class StateStore { readonly rooms: Map = new Map() readonly roomList = new NonNullCachedEventDispatcher([]) + switchRoom?: (roomID: RoomID) => void + imageAuthToken?: string #roomListEntryChanged(entry: SyncRoom, oldEntry: RoomStateStore): boolean { return entry.meta.sorting_timestamp !== oldEntry.meta.current.sorting_timestamp || + entry.meta.unread_notifications !== oldEntry.meta.current.unread_notifications || + entry.meta.unread_highlights !== oldEntry.meta.current.unread_highlights || entry.meta.preview_event_rowid !== oldEntry.meta.current.preview_event_rowid || entry.events.findIndex(evt => evt.rowid === entry.meta.preview_event_rowid) !== -1 } @@ -71,6 +79,8 @@ export class StateStore { name, search_name: toSearchableString(name), avatar: entry.meta.avatar, + unread: entry.meta.unread_notifications, + highlighted: entry.meta.unread_highlights > 0, } } @@ -90,6 +100,12 @@ export class StateStore { if (roomListEntryChanged) { changedRoomListEntries.set(roomID, this.#makeRoomListEntry(data, room)) } + + if (Notification.permission === "granted" && !focused.current) { + for (const notification of data.notifications) { + this.showNotification(room, notification.event_rowid, notification.sound) + } + } } let updatedRoomList: RoomListEntry[] | undefined @@ -117,6 +133,40 @@ export class StateStore { } } + showNotification(room: RoomStateStore, rowid: EventRowID, sound: boolean) { + const evt = room.eventsByRowID.get(rowid) + if (!evt || typeof evt.content.body !== "string") { + return + } + let body = evt.content.body + if (body.length > 200) { + body = body.slice(0, 150) + " […]" + } + const memberEvt = room.getStateEvent("m.room.member", evt.sender) + const icon = `${getMediaURL(memberEvt?.content.avatar_url)}&image_auth=${this.imageAuthToken}` + const roomName = room.meta.current.name ?? "Unnamed room" + const senderName = memberEvt?.content.displayname ?? evt.sender + const title = senderName === roomName ? senderName : `${senderName} (${roomName})` + const notif = new Notification(title, { + body, + icon, + badge: "/gomuks.png", + // timestamp: evt.timestamp, + // image: ..., + tag: rowid.toString(), + }) + notif.onclick = () => this.onClickNotification(room.roomID) + if (sound) { + // TODO play sound + } + } + + onClickNotification(roomID: RoomID) { + if (this.switchRoom) { + this.switchRoom(roomID) + } + } + applySendComplete(data: SendCompleteData) { const room = this.rooms.get(data.event.room_id) if (!room) { diff --git a/web/src/api/statestore/room.ts b/web/src/api/statestore/room.ts index 8c05571..fd5333b 100644 --- a/web/src/api/statestore/room.ts +++ b/web/src/api/statestore/room.ts @@ -147,6 +147,7 @@ export class RoomStateStore { if (memEvt.last_edit) { memEvt.orig_content = memEvt.content memEvt.content = memEvt.last_edit.content["m.new_content"] + memEvt.local_content = memEvt.last_edit.local_content } } else if (memEvt.relation_type === "m.replace" && memEvt.relates_to) { const editTarget = this.eventsByID.get(memEvt.relates_to) @@ -154,6 +155,7 @@ export class RoomStateStore { editTarget.last_edit = memEvt editTarget.orig_content = editTarget.content editTarget.content = memEvt.content["m.new_content"] + editTarget.local_content = memEvt.local_content } } this.eventsByRowID.set(memEvt.rowid, memEvt) diff --git a/web/src/api/types/hievents.ts b/web/src/api/types/hievents.ts index c208e67..e12541b 100644 --- a/web/src/api/types/hievents.ts +++ b/web/src/api/types/hievents.ts @@ -60,12 +60,22 @@ export interface EventsDecryptedEvent extends RPCCommand { command: "events_decrypted" } +export interface ImageAuthTokenEvent extends RPCCommand { + command: "image_auth_token" +} + export interface SyncRoom { meta: DBRoom timeline: TimelineRowTuple[] events: RawDBEvent[] state: Record> reset: boolean + notifications: SyncNotification[] +} + +export interface SyncNotification { + event_rowid: EventRowID + sound: boolean } export interface SyncCompleteData { @@ -97,4 +107,5 @@ export type RPCEvent = TypingEvent | SendCompleteEvent | EventsDecryptedEvent | - SyncCompleteEvent + SyncCompleteEvent | + ImageAuthTokenEvent diff --git a/web/src/api/types/hitypes.ts b/web/src/api/types/hitypes.ts index 5f847d6..f637a1d 100644 --- a/web/src/api/types/hitypes.ts +++ b/web/src/api/types/hitypes.ts @@ -59,6 +59,9 @@ export interface DBRoom { preview_event_rowid: EventRowID sorting_timestamp: number + unread_highlights: number + unread_notifications: number + unread_messages: number prev_batch: string } @@ -66,6 +69,18 @@ export interface DBRoom { //eslint-disable-next-line @typescript-eslint/no-explicit-any export type UnknownEventContent = Record +export enum UnreadType { + None = 0b0000, + Normal = 0b0001, + Notify = 0b0010, + Highlight = 0b0100, + Sound = 0b1000, +} + +export interface LocalContent { + sanitized_html?: TrustedHTML +} + export interface BaseDBEvent { rowid: EventRowID timeline_rowid: TimelineRowID @@ -79,6 +94,7 @@ export interface BaseDBEvent { content: UnknownEventContent unsigned: EventUnsigned + local_content?: LocalContent transaction_id?: string @@ -91,6 +107,7 @@ export interface BaseDBEvent { reactions?: Record last_edit_rowid?: EventRowID + unread_type: UnreadType } export interface RawDBEvent extends BaseDBEvent { diff --git a/web/src/ui/MainScreen.tsx b/web/src/ui/MainScreen.tsx index b44018f..d5743d3 100644 --- a/web/src/ui/MainScreen.tsx +++ b/web/src/ui/MainScreen.tsx @@ -31,6 +31,7 @@ const MainScreen = () => { .catch(err => console.error("Failed to load room state", err)) } }, [client]) + client.store.switchRoom = setActiveRoom const clearActiveRoom = useCallback(() => setActiveRoomID(null), []) return
diff --git a/web/src/ui/roomlist/Entry.tsx b/web/src/ui/roomlist/Entry.tsx index 5ec9be9..d053063 100644 --- a/web/src/ui/roomlist/Entry.tsx +++ b/web/src/ui/roomlist/Entry.tsx @@ -60,6 +60,11 @@ const Entry = ({ room, setActiveRoom, isActive, hidden }: RoomListEntryProps) =>
{room.name}
{previewText &&
{croppedPreviewText}
} + {room.unread ?
+
+ {room.unread} +
+
: null} } diff --git a/web/src/ui/roomlist/RoomList.css b/web/src/ui/roomlist/RoomList.css index 36693bc..9a9bec7 100644 --- a/web/src/ui/roomlist/RoomList.css +++ b/web/src/ui/roomlist/RoomList.css @@ -78,6 +78,26 @@ div.room-entry { } } } + + > div.room-entry-unreads { + display: flex; + align-items: center; + width: 3rem; + + > div.unread-count { + width: 1.5rem; + height: 1.5rem; + border-radius: 50%; + background-color: green; + text-align: center; + color: white; + font-weight: bold; + + &.highlighted { + background-color: darkred; + } + } + } } img.avatar { diff --git a/web/src/ui/timeline/TimelineEvent.css b/web/src/ui/timeline/TimelineEvent.css index f99a679..74eae40 100644 --- a/web/src/ui/timeline/TimelineEvent.css +++ b/web/src/ui/timeline/TimelineEvent.css @@ -166,7 +166,20 @@ div.html-body { vertical-align: middle; } - span[data-mx-spoiler] { + img.hicli-custom-emoji { + vertical-align: middle; + height: 24px; + width: auto; + max-width: 72px; + } + + img.hicli-sizeless-inline-img { + height: 24px; + width: auto; + max-width: 72px; + } + + span[data-mx-spoiler], span.hicli-spoiler { filter: blur(4px); transition: filter .5s; cursor: pointer; diff --git a/web/src/ui/timeline/TimelineEvent.tsx b/web/src/ui/timeline/TimelineEvent.tsx index d608467..03fed1a 100644 --- a/web/src/ui/timeline/TimelineEvent.tsx +++ b/web/src/ui/timeline/TimelineEvent.tsx @@ -24,7 +24,11 @@ import { ReplyIDBody } from "./ReplyBody.tsx" import EncryptedBody from "./content/EncryptedBody.tsx" import HiddenEvent from "./content/HiddenEvent.tsx" import MemberBody from "./content/MemberBody.tsx" -import { MediaMessageBody, TextMessageBody, UnknownMessageBody } from "./content/MessageBody.tsx" +import { + MediaMessageBody, + TextMessageBody, + UnknownMessageBody, +} from "./content/MessageBody.tsx" import RedactedBody from "./content/RedactedBody.tsx" import { EventContentProps } from "./content/props.ts" import ErrorIcon from "../../icons/error.svg?react" diff --git a/web/src/ui/timeline/content/MessageBody.tsx b/web/src/ui/timeline/content/MessageBody.tsx index b20a444..44c8027 100644 --- a/web/src/ui/timeline/content/MessageBody.tsx +++ b/web/src/ui/timeline/content/MessageBody.tsx @@ -13,11 +13,9 @@ // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . -import { CSSProperties, use, useMemo } from "react" -import sanitizeHtml from "sanitize-html" +import { CSSProperties, use } from "react" import { getEncryptedMediaURL, getMediaURL } from "@/api/media.ts" import type { EventType, MediaMessageEventContent, MessageEventContent } from "@/api/types" -import { sanitizeHtmlParams } from "@/util/html.ts" import { calculateMediaSize } from "@/util/mediasize.ts" import { LightboxContext } from "../../Lightbox.tsx" import { EventContentProps } from "./props.ts" @@ -32,14 +30,10 @@ const onClickHTML = (evt: React.MouseEvent) => { export const TextMessageBody = ({ event }: EventContentProps) => { const content = event.content as MessageEventContent - const __html = useMemo(() => { - if (content.format === "org.matrix.custom.html") { - return sanitizeHtml(content.formatted_body!, sanitizeHtmlParams) - } - return undefined - }, [content.format, content.formatted_body]) - if (__html) { - return
+ if (event.local_content?.sanitized_html) { + return
} return
{content.body}
} diff --git a/web/src/util/focus.ts b/web/src/util/focus.ts index 2d268df..c3717a1 100644 --- a/web/src/util/focus.ts +++ b/web/src/util/focus.ts @@ -15,7 +15,7 @@ // along with this program. If not, see . import { NonNullCachedEventDispatcher, useNonNullEventAsState } from "@/util/eventdispatcher.ts" -const focused = new NonNullCachedEventDispatcher(document.hasFocus()) +export const focused = new NonNullCachedEventDispatcher(document.hasFocus()) window.addEventListener("focus", () => focused.emit(true)) window.addEventListener("blur", () => focused.emit(false)) diff --git a/websocket.go b/websocket.go index 48eac6f..d6a89ba 100644 --- a/websocket.go +++ b/websocket.go @@ -28,10 +28,12 @@ import ( "github.com/coder/websocket" "github.com/rs/zerolog" + "go.mau.fi/util/exerrors" "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/hicli" - "maunium.net/go/mautrix/hicli/database" "maunium.net/go/mautrix/id" + + "go.mau.fi/gomuks/pkg/hicli" + "go.mau.fi/gomuks/pkg/hicli/database" ) func writeCmd(ctx context.Context, conn *websocket.Conn, cmd *hicli.JSONCommand) error { @@ -120,6 +122,17 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) { lastDataReceived := &atomic.Int64{} lastDataReceived.Store(time.Now().UnixMilli()) const RecvTimeout = 60 * time.Second + lastImageAuthTokenSent := time.Now() + sendImageAuthToken := func() { + err := writeCmd(ctx, conn, &hicli.JSONCommand{ + Command: "image_auth_token", + Data: exerrors.Must(json.Marshal(gmx.generateImageToken())), + }) + if err != nil { + log.Err(err).Msg("Failed to write image auth token message") + return + } + } go func() { defer recoverPanic("event loop") defer closeOnce.Do(forceClose) @@ -137,6 +150,10 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) { log.Trace().Int64("req_id", cmd.RequestID).Msg("Sent outgoing event") } case <-ticker.C: + if time.Since(lastImageAuthTokenSent) > 30*time.Minute { + sendImageAuthToken() + lastImageAuthTokenSent = time.Now() + } if time.Now().UnixMilli()-lastDataReceived.Load() > RecvTimeout.Milliseconds() { log.Warn().Msg("No data received in a minute, closing connection") _ = conn.Close(StatusPingTimeout, "Ping timeout") @@ -187,6 +204,7 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) { log.Err(initErr).Msg("Failed to write init message") return } + go sendImageAuthToken() go gmx.sendInitialData(ctx, conn) log.Debug().Msg("Connection initialization complete") var closeErr websocket.CloseError @@ -242,10 +260,11 @@ func (gmx *Gomuks) sendInitialData(ctx context.Context, conn *websocket.Conn) { } maxTS = room.SortingTimestamp.Time syncRoom := &hicli.SyncRoom{ - Meta: room, - Events: make([]*database.Event, 0, 2), - Timeline: make([]database.TimelineRowTuple, 0), - State: map[event.Type]map[string]database.EventRowID{}, + Meta: room, + Events: make([]*database.Event, 0, 2), + Timeline: make([]database.TimelineRowTuple, 0), + State: map[event.Type]map[string]database.EventRowID{}, + Notifications: make([]hicli.SyncNotification, 0), } payload.Rooms[room.ID] = syncRoom if room.PreviewEventRowID != 0 { @@ -255,6 +274,7 @@ func (gmx *Gomuks) sendInitialData(ctx context.Context, conn *websocket.Conn) { return } if previewEvent != nil { + gmx.Client.ReprocessExistingEvent(ctx, previewEvent) previewMember, err := gmx.Client.DB.CurrentState.Get(ctx, room.ID, event.StateMember, previewEvent.Sender.String()) if err != nil { log.Err(err).Msg("Failed to get preview member event for room") @@ -269,6 +289,7 @@ func (gmx *Gomuks) sendInitialData(ctx context.Context, conn *websocket.Conn) { if err != nil { log.Err(err).Msg("Failed to get last edit for preview event") } else if lastEdit != nil { + gmx.Client.ReprocessExistingEvent(ctx, lastEdit) syncRoom.Events = append(syncRoom.Events, lastEdit) } }