diff --git a/.editorconfig b/.editorconfig index d6cdd31..f683f30 100644 --- a/.editorconfig +++ b/.editorconfig @@ -11,5 +11,8 @@ insert_final_newline = true [*.{yaml,yml,sql}] indent_style = space +[*.html] +indent_size = 2 + [.gitlab-ci.yml] indent_size = 2 diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 173e36a..644bcc7 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -2,14 +2,17 @@ name: Go on: [push, pull_request] +env: + GOTOOLCHAIN: local + jobs: lint: runs-on: ubuntu-latest strategy: fail-fast: false matrix: - go-version: ["1.24", "1.25"] - name: Lint ${{ matrix.go-version == '1.25' && '(latest)' || '(old)' }} + go-version: ["1.25", "1.26"] + name: Lint ${{ matrix.go-version == '1.26' && '(latest)' || '(old)' }} steps: - uses: actions/checkout@v4 @@ -23,13 +26,11 @@ jobs: - name: Install libolm run: sudo apt-get install libolm-dev libolm3 - - name: Install goimports + - name: Install dependencies run: | go install golang.org/x/tools/cmd/goimports@latest + go install honnef.co/go/tools/cmd/staticcheck@latest export PATH="$HOME/go/bin:$PATH" - - name: Install pre-commit - run: pip install pre-commit - - - name: Lint - run: pre-commit run -a + - name: Run pre-commit + uses: pre-commit/action@v3.0.1 diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 9037dbd..2fa759a 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,3 +1,3 @@ include: - project: 'mautrix/ci' - file: '/go.yml' + file: '/gov2-as-default.yml' diff --git a/.idea/icon.svg b/.idea/icon.svg new file mode 100644 index 0000000..87eeadb --- /dev/null +++ b/.idea/icon.svg @@ -0,0 +1,16 @@ + + + + + diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6da4e37..77d6005 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v6.0.0 hooks: - id: trailing-whitespace exclude_types: [markdown] @@ -9,12 +9,18 @@ repos: - id: check-added-large-files - repo: https://github.com/tekwizely/pre-commit-golang - rev: v1.0.0-rc.1 + rev: v1.0.0-rc.4 hooks: - id: go-imports-repo + args: + - "-local" + - "go.mau.fi/mautrix-discord" + - "-w" - id: go-vet-repo-mod + - id: go-staticcheck-repo-mod - repo: https://github.com/beeper/pre-commit-go - rev: v0.3.1 + rev: v0.4.2 hooks: - id: zerolog-ban-msgf + - id: zerolog-use-stringer diff --git a/Dockerfile b/Dockerfile index 4664399..c5f7e32 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,7 +14,6 @@ ENV UID=1337 \ RUN apk add --no-cache ffmpeg su-exec ca-certificates olm bash jq curl yq-go lottieconverter COPY --from=builder /usr/bin/mautrix-discord /usr/bin/mautrix-discord -COPY --from=builder /build/example-config.yaml /opt/mautrix-discord/example-config.yaml COPY --from=builder /build/docker-run.sh /docker-run.sh VOLUME /data diff --git a/Dockerfile.ci b/Dockerfile.ci index 32b9c35..6664fe9 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -7,7 +7,6 @@ RUN apk add --no-cache ffmpeg su-exec ca-certificates bash jq curl yq-go lottiec ARG EXECUTABLE=./mautrix-discord COPY $EXECUTABLE /usr/bin/mautrix-discord -COPY ./example-config.yaml /opt/mautrix-discord/example-config.yaml COPY ./docker-run.sh /docker-run.sh VOLUME /data diff --git a/LICENSE.exceptions b/LICENSE.exceptions new file mode 100644 index 0000000..2754eb3 --- /dev/null +++ b/LICENSE.exceptions @@ -0,0 +1,12 @@ +The mautrix-discord developers grant the following special exceptions: + +* to Beeper the right to embed the program in the Beeper clients and servers, + and use and distribute the collective work without applying the license to + the whole. +* to Element the right to distribute compiled binaries of the program as a part + of the Element Server Suite and other server bundles without applying the + license. + +All exceptions are only valid under the condition that any modifications to +the source code of mautrix-discord remain publicly available under the terms +of the GNU AGPL version 3 or later. diff --git a/README.md b/README.md index 73aad93..b54d262 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,12 @@ # mautrix-discord + +> [!CAUTION] +> This branch houses a work-in-progress rewrite of the bridge to interface with +> [Megabridge/"bridgev2"][bridgev2]. This branch is **NOT** ready for general +> consumption, especially for self-hosting. + +[bridgev2]: https://github.com/mautrix/go/tree/38278ef37d199d3a9deba04b825a094eea6c1d10/bridgev2/unorganized-docs + A Matrix-Discord puppeting bridge based on [discordgo](https://github.com/bwmarrin/discordgo). ## Documentation diff --git a/ROADMAP.md b/ROADMAP.md index aab2680..4e30a1e 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -1,24 +1,24 @@ # Features & roadmap * Matrix → Discord * [ ] Message content - * [x] Plain text - * [x] Formatted messages - * [x] Media/files - * [x] Replies - * [x] Threads + * [ ] Plain text + * [ ] Formatted messages + * [ ] Media/files + * [ ] Replies + * [ ] Threads * [ ] Custom emojis - * [x] Message redactions - * [x] Reactions - * [x] Unicode emojis + * [ ] Message redactions + * [ ] Reactions + * [ ] Unicode emojis * [ ] Custom emojis (re-reacting with custom emojis sent from Discord already works) * [ ] Executing Discord bot commands - * [x] Basic arguments and subcommands + * [ ] Basic arguments and subcommands * [ ] Subcommand groups * [ ] Mention arguments * [ ] Attachment arguments * [ ] Presence - * [x] Typing notifications - * [x] Own read status + * [ ] Typing notifications + * [ ] Own read status * [ ] Power level * [ ] Membership actions * [ ] Invite @@ -31,37 +31,37 @@ * [ ] Initial room metadata * Discord → Matrix * [ ] Message content - * [x] Plain text - * [x] Formatted messages - * [x] Media/files - * [x] Replies - * [x] Threads - * [x] Auto-joining threads when opening + * [ ] Plain text + * [ ] Formatted messages + * [ ] Media/files + * [ ] Replies + * [ ] Threads + * [ ] Auto-joining threads when opening * [ ] Backfilling threads after joining - * [x] Custom emojis - * [x] Embeds + * [ ] Custom emojis + * [ ] Embeds * [ ] Interactive components - * [x] Interactions (commands) - * [x] @everyone/@here mentions into @room - * [x] Message deletions - * [x] Reactions - * [x] Unicode emojis - * [x] Custom emojis ([MSC4027](https://github.com/matrix-org/matrix-spec-proposals/pull/4027)) - * [x] Avatars + * [ ] Interactions (commands) + * [ ] @everyone/@here mentions into @room + * [ ] Message deletions + * [ ] Reactions + * [ ] Unicode emojis + * [ ] Custom emojis ([MSC4027](https://github.com/matrix-org/matrix-spec-proposals/pull/4027)) + * [ ] Avatars * [ ] Presence * [ ] Typing notifications (currently partial support: DMs work after you type in them) - * [x] Own read status + * [ ] Own read status * [ ] Role permissions * [ ] Membership actions * [ ] Invite * [ ] Join * [ ] Leave * [ ] Kick - * [x] Channel/group DM metadata changes - * [x] Title - * [x] Avatar - * [x] Description - * [x] Initial channel/group DM metadata + * [ ] Channel/group DM metadata changes + * [ ] Title + * [ ] Avatar + * [ ] Description + * [ ] Initial channel/group DM metadata * [ ] User metadata changes * [ ] Display name * [ ] Avatar @@ -69,11 +69,12 @@ * [ ] Display name * [ ] Avatar * Misc - * [x] Login methods - * [x] QR scan from mobile - * [x] Manually providing access token - * [x] Automatic portal creation - * [x] After login - * [x] When receiving DM + * [ ] Login methods + * [ ] QR scan from mobile + * [ ] Username/password + * [ ] Manually providing access token + * [ ] Automatic portal creation + * [ ] After login + * [ ] When receiving DM * [ ] Private chat creation by inviting Matrix puppet of Discord user to new room - * [x] Option to use own Matrix account for messages sent from other Discord clients + * [ ] Option to use own Matrix account for messages sent from other Discord clients diff --git a/attachments.go b/attachments.go deleted file mode 100644 index 15efbbd..0000000 --- a/attachments.go +++ /dev/null @@ -1,353 +0,0 @@ -package main - -import ( - "bytes" - "context" - "errors" - "fmt" - "image" - "io" - "net/http" - "os" - "os/exec" - "path/filepath" - "strconv" - "strings" - "sync" - "time" - - "github.com/bwmarrin/discordgo" - "github.com/gabriel-vasile/mimetype" - "go.mau.fi/util/exsync" - "go.mau.fi/util/ffmpeg" - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/crypto/attachment" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" - - "go.mau.fi/mautrix-discord/database" -) - -func downloadDiscordAttachment(cli *http.Client, url string, maxSize int64) ([]byte, error) { - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return nil, err - } - for key, value := range discordgo.DroidDownloadHeaders { - req.Header.Set(key, value) - } - - resp, err := cli.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - if resp.StatusCode > 300 { - data, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("unexpected status %d downloading %s: %s", resp.StatusCode, url, data) - } - if resp.Header.Get("Content-Length") != "" { - length, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) - if err != nil { - return nil, fmt.Errorf("failed to parse content length: %w", err) - } else if length > maxSize { - return nil, fmt.Errorf("attachment too large (%d > %d)", length, maxSize) - } - return io.ReadAll(resp.Body) - } else { - var mbe *http.MaxBytesError - data, err := io.ReadAll(http.MaxBytesReader(nil, resp.Body, maxSize)) - if err != nil && errors.As(err, &mbe) { - return nil, fmt.Errorf("attachment too large (over %d)", maxSize) - } - return data, err - } -} - -func uploadDiscordAttachment(cli *http.Client, url string, data []byte) error { - req, err := http.NewRequest(http.MethodPut, url, bytes.NewReader(data)) - if err != nil { - return err - } - for key, value := range discordgo.DroidBaseHeaders { - req.Header.Set(key, value) - } - req.Header.Set("Content-Type", "application/octet-stream") - req.Header.Set("Referer", "https://discord.com/") - req.Header.Set("Sec-Fetch-Dest", "empty") - req.Header.Set("Sec-Fetch-Mode", "cors") - req.Header.Set("Sec-Fetch-Site", "cross-site") - - resp, err := cli.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode > 300 { - respData, _ := io.ReadAll(resp.Body) - return fmt.Errorf("unexpected status %d: %s", resp.StatusCode, respData) - } - return nil -} - -func downloadMatrixAttachment(intent *appservice.IntentAPI, content *event.MessageEventContent) ([]byte, error) { - var file *event.EncryptedFileInfo - rawMXC := content.URL - - if content.File != nil { - file = content.File - rawMXC = file.URL - } - - mxc, err := rawMXC.Parse() - if err != nil { - return nil, err - } - - data, err := intent.DownloadBytes(mxc) - if err != nil { - return nil, err - } - - if file != nil { - err = file.DecryptInPlace(data) - if err != nil { - return nil, err - } - } - - return data, nil -} - -func (br *DiscordBridge) uploadMatrixAttachment(intent *appservice.IntentAPI, data []byte, url string, encrypt bool, meta AttachmentMeta, semaWg *sync.WaitGroup) (*database.File, error) { - dbFile := br.DB.File.New() - dbFile.Timestamp = time.Now() - dbFile.URL = url - dbFile.ID = meta.AttachmentID - dbFile.EmojiName = meta.EmojiName - dbFile.Size = len(data) - dbFile.MimeType = mimetype.Detect(data).String() - if meta.MimeType == "" { - meta.MimeType = dbFile.MimeType - } - if strings.HasPrefix(meta.MimeType, "image/") { - cfg, _, _ := image.DecodeConfig(bytes.NewReader(data)) - dbFile.Width = cfg.Width - dbFile.Height = cfg.Height - } - - uploadMime := meta.MimeType - if encrypt { - dbFile.Encrypted = true - dbFile.DecryptionInfo = attachment.NewEncryptedFile() - dbFile.DecryptionInfo.EncryptInPlace(data) - uploadMime = "application/octet-stream" - } - req := mautrix.ReqUploadMedia{ - ContentBytes: data, - ContentType: uploadMime, - } - if br.Config.Homeserver.AsyncMedia { - resp, err := intent.CreateMXC() - if err != nil { - return nil, err - } - dbFile.MXC = resp.ContentURI - req.MXC = resp.ContentURI - req.UnstableUploadURL = resp.UnstableUploadURL - semaWg.Add(1) - go func() { - defer semaWg.Done() - _, err = intent.UploadMedia(req) - if err != nil { - br.Log.Errorfln("Failed to upload %s: %v", req.MXC, err) - dbFile.Delete() - } - }() - } else { - uploaded, err := intent.UploadMedia(req) - if err != nil { - return nil, err - } - dbFile.MXC = uploaded.ContentURI - } - return dbFile, nil -} - -type AttachmentMeta struct { - AttachmentID string - MimeType string - EmojiName string - CopyIfMissing bool - Converter func([]byte) ([]byte, string, error) -} - -var NoMeta = AttachmentMeta{} - -type attachmentKey struct { - URL string - Encrypt bool -} - -func (br *DiscordBridge) convertLottie(data []byte) ([]byte, string, error) { - fps := br.Config.Bridge.AnimatedSticker.Args.FPS - width := br.Config.Bridge.AnimatedSticker.Args.Width - height := br.Config.Bridge.AnimatedSticker.Args.Height - target := br.Config.Bridge.AnimatedSticker.Target - var lottieTarget, outputMime string - switch target { - case "png": - lottieTarget = "png" - outputMime = "image/png" - fps = 1 - case "gif": - lottieTarget = "gif" - outputMime = "image/gif" - case "webm": - lottieTarget = "pngs" - outputMime = "video/webm" - case "webp": - lottieTarget = "pngs" - outputMime = "image/webp" - case "disable": - return data, "application/json", nil - default: - return nil, "", fmt.Errorf("invalid animated sticker target %q in bridge config", br.Config.Bridge.AnimatedSticker.Target) - } - - ctx := context.Background() - tempdir, err := os.MkdirTemp("", "mautrix_discord_lottie_") - if err != nil { - return nil, "", fmt.Errorf("failed to create temp dir: %w", err) - } - defer func() { - removErr := os.RemoveAll(tempdir) - if removErr != nil { - br.Log.Warnfln("Failed to delete lottie conversion temp dir: %v", removErr) - } - }() - - lottieOutput := filepath.Join(tempdir, "out_") - if lottieTarget != "pngs" { - lottieOutput = filepath.Join(tempdir, "output."+lottieTarget) - } - cmd := exec.CommandContext(ctx, "lottieconverter", "-", lottieOutput, lottieTarget, fmt.Sprintf("%dx%d", width, height), strconv.Itoa(fps)) - cmd.Stdin = bytes.NewReader(data) - err = cmd.Run() - if err != nil { - return nil, "", fmt.Errorf("failed to run lottieconverter: %w", err) - } - var path string - if lottieTarget == "pngs" { - var videoCodec string - outputExtension := "." + target - if target == "webm" { - videoCodec = "libvpx-vp9" - } else if target == "webp" { - videoCodec = "libwebp_anim" - } else { - panic(fmt.Errorf("impossible case: unknown target %q", target)) - } - path, err = ffmpeg.ConvertPath( - ctx, lottieOutput+"*.png", outputExtension, - []string{"-framerate", strconv.Itoa(fps), "-pattern_type", "glob"}, - []string{"-c:v", videoCodec, "-pix_fmt", "yuva420p", "-f", target}, - false, - ) - if err != nil { - return nil, "", fmt.Errorf("failed to run ffmpeg: %w", err) - } - } else { - path = lottieOutput - } - data, err = os.ReadFile(path) - if err != nil { - return nil, "", fmt.Errorf("failed to read converted file: %w", err) - } - return data, outputMime, nil -} - -func (br *DiscordBridge) copyAttachmentToMatrix(intent *appservice.IntentAPI, url string, encrypt bool, meta AttachmentMeta) (returnDBFile *database.File, returnErr error) { - isCacheable := br.Config.Bridge.CacheMedia != "never" && (br.Config.Bridge.CacheMedia == "always" || !encrypt) - returnDBFile = br.DB.File.Get(url, encrypt) - if returnDBFile == nil { - transferKey := attachmentKey{url, encrypt} - once, _ := br.attachmentTransfers.GetOrSet(transferKey, &exsync.ReturnableOnce[*database.File]{}) - returnDBFile, returnErr = once.Do(func() (onceDBFile *database.File, onceErr error) { - if isCacheable { - onceDBFile = br.DB.File.Get(url, encrypt) - if onceDBFile != nil { - return - } - } - - const attachmentSizeVal = 1 - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - onceErr = br.parallelAttachmentSemaphore.Acquire(ctx, attachmentSizeVal) - cancel() - if onceErr != nil { - br.ZLog.Warn().Err(onceErr).Msg("Failed to acquire semaphore") - onceErr = fmt.Errorf("reuploading timed out") - return - } - var semaWg sync.WaitGroup - semaWg.Add(1) - defer semaWg.Done() - go func() { - semaWg.Wait() - br.parallelAttachmentSemaphore.Release(attachmentSizeVal) - }() - - var data []byte - data, onceErr = downloadDiscordAttachment(http.DefaultClient, url, br.MediaConfig.UploadSize) - if onceErr != nil { - return - } - - if meta.Converter != nil { - data, meta.MimeType, onceErr = meta.Converter(data) - if onceErr != nil { - onceErr = fmt.Errorf("failed to convert attachment: %w", onceErr) - return - } - } - - onceDBFile, onceErr = br.uploadMatrixAttachment(intent, data, url, encrypt, meta, &semaWg) - if onceErr != nil { - return - } - if isCacheable { - onceDBFile.Insert(nil) - } - br.attachmentTransfers.Delete(transferKey) - return - }) - } - return -} - -func (portal *Portal) getEmojiMXCByDiscordID(emojiID, name string, animated bool) id.ContentURI { - mxc := portal.bridge.DMA.EmojiMXC(emojiID, name, animated) - if !mxc.IsEmpty() { - return mxc - } - var url, mimeType string - if animated { - url = discordgo.EndpointEmojiAnimated(emojiID) - mimeType = "image/gif" - } else { - url = discordgo.EndpointEmoji(emojiID) - mimeType = "image/png" - } - dbFile, err := portal.bridge.copyAttachmentToMatrix(portal.MainIntent(), url, false, AttachmentMeta{ - AttachmentID: emojiID, - MimeType: mimeType, - EmojiName: name, - }) - if err != nil { - portal.log.Warn().Err(err).Str("emoji_id", emojiID).Msg("Failed to copy emoji to Matrix") - return id.ContentURI{} - } - return dbFile.MXC -} diff --git a/backfill.go b/backfill.go deleted file mode 100644 index c4966bd..0000000 --- a/backfill.go +++ /dev/null @@ -1,383 +0,0 @@ -package main - -import ( - "context" - "crypto/sha256" - "encoding/base64" - "fmt" - "sort" - - "github.com/bwmarrin/discordgo" - "github.com/rs/zerolog" - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" - - "go.mau.fi/mautrix-discord/database" -) - -func (portal *Portal) forwardBackfillInitial(source *User, thread *Thread) { - log := portal.log - defer func() { - log.Debug().Msg("Forward backfill finished, unlocking lock") - portal.forwardBackfillLock.Unlock() - }() - // This should only be called from CreateMatrixRoom which locks forwardBackfillLock before creating the room. - if portal.forwardBackfillLock.TryLock() { - panic("forwardBackfillInitial() called without locking forwardBackfillLock") - } - - limit := portal.bridge.Config.Bridge.Backfill.Limits.Initial.Channel - if portal.GuildID == "" { - limit = portal.bridge.Config.Bridge.Backfill.Limits.Initial.DM - if thread != nil { - limit = portal.bridge.Config.Bridge.Backfill.Limits.Initial.Thread - thread.initialBackfillAttempted = true - } - } - if limit == 0 { - return - } - - with := log.With(). - Str("action", "initial backfill"). - Str("room_id", portal.MXID.String()). - Int("limit", limit) - if thread != nil { - with = with.Str("thread_id", thread.ID) - } - log = with.Logger() - - portal.backfillLimited(log, source, limit, "", thread) -} - -func (portal *Portal) ForwardBackfillMissed(source *User, serverLastMessageID string, thread *Thread) { - if portal.MXID == "" { - return - } - - limit := portal.bridge.Config.Bridge.Backfill.Limits.Missed.Channel - if portal.GuildID == "" { - limit = portal.bridge.Config.Bridge.Backfill.Limits.Missed.DM - if thread != nil { - limit = portal.bridge.Config.Bridge.Backfill.Limits.Missed.Thread - } - } - if limit == 0 { - return - } - with := portal.log.With(). - Str("action", "missed event backfill"). - Str("room_id", portal.MXID.String()). - Int("limit", limit) - if thread != nil { - with = with.Str("thread_id", thread.ID) - } - log := with.Logger() - - portal.forwardBackfillLock.Lock() - defer portal.forwardBackfillLock.Unlock() - - var lastMessage *database.Message - if thread != nil { - lastMessage = portal.bridge.DB.Message.GetLastInThread(portal.Key, thread.ID) - } else { - lastMessage = portal.bridge.DB.Message.GetLast(portal.Key) - } - if lastMessage == nil || serverLastMessageID == "" { - log.Debug().Msg("Not backfilling, no last message in database or no last message in metadata") - return - } else if !shouldBackfill(lastMessage.DiscordID, serverLastMessageID) { - log.Debug(). - Str("last_bridged_message", lastMessage.DiscordID). - Str("last_server_message", serverLastMessageID). - Msg("Not backfilling, last message in database is newer than last message in metadata") - return - } - log.Debug(). - Str("last_bridged_message", lastMessage.DiscordID). - Str("last_server_message", serverLastMessageID). - Msg("Backfilling missed messages") - if limit < 0 { - portal.backfillUnlimitedMissed(log, source, lastMessage.DiscordID, thread) - } else { - portal.backfillLimited(log, source, limit, lastMessage.DiscordID, thread) - } -} - -const messageFetchChunkSize = 50 - -func (portal *Portal) collectBackfillMessages(log zerolog.Logger, source *User, limit int, until string, thread *Thread) ([]*discordgo.Message, bool, error) { - var messages []*discordgo.Message - var before string - var foundAll bool - protoChannelID := portal.Key.ChannelID - if thread != nil { - protoChannelID = thread.ID - } - for { - log.Debug().Str("before_id", before).Msg("Fetching messages for backfill") - newMessages, err := source.Session.ChannelMessages(protoChannelID, messageFetchChunkSize, before, "", "", portal.RefererOptIfUser(source.Session, protoChannelID)...) - if err != nil { - return nil, false, err - } - if until != "" { - for i, msg := range newMessages { - if compareMessageIDs(msg.ID, until) <= 0 { - log.Debug(). - Str("message_id", msg.ID). - Str("until_id", until). - Msg("Found message that was already bridged") - newMessages = newMessages[:i] - foundAll = true - break - } - } - } - messages = append(messages, newMessages...) - log.Debug().Int("count", len(newMessages)).Msg("Added messages to backfill collection") - if len(newMessages) < messageFetchChunkSize || len(messages) >= limit { - break - } - before = newMessages[len(newMessages)-1].ID - } - if len(messages) > limit { - foundAll = false - messages = messages[:limit] - } - return messages, foundAll, nil -} - -func (portal *Portal) backfillLimited(log zerolog.Logger, source *User, limit int, after string, thread *Thread) { - messages, foundAll, err := portal.collectBackfillMessages(log, source, limit, after, thread) - if err != nil { - if source.handlePossible40002(err) { - panic(err) - } - log.Err(err).Msg("Error collecting messages to forward backfill") - return - } - log.Info(). - Int("count", len(messages)). - Bool("found_all", foundAll). - Msg("Collected messages to backfill") - sort.Sort(MessageSlice(messages)) - if !foundAll && after != "" { - _, err = portal.sendMatrixMessage(portal.MainIntent(), event.EventMessage, &event.MessageEventContent{ - MsgType: event.MsgNotice, - Body: "Some messages may have been missed here while the bridge was offline.", - }, nil, 0) - if err != nil { - log.Warn().Err(err).Msg("Failed to send missed message warning") - } else { - log.Debug().Msg("Sent warning about possibly missed messages") - } - } - portal.sendBackfillBatch(log, source, messages, thread) -} - -func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User, after string, thread *Thread) { - protoChannelID := portal.Key.ChannelID - if thread != nil { - protoChannelID = thread.ID - } - for { - log.Debug().Str("after_id", after).Msg("Fetching chunk of messages to backfill") - messages, err := source.Session.ChannelMessages(protoChannelID, messageFetchChunkSize, "", after, "", portal.RefererOptIfUser(source.Session, protoChannelID)...) - if err != nil { - log.Err(err).Msg("Error fetching chunk of messages to forward backfill") - return - } - log.Debug().Int("count", len(messages)).Msg("Fetched chunk of messages to backfill") - sort.Sort(MessageSlice(messages)) - - portal.sendBackfillBatch(log, source, messages, thread) - - if len(messages) < messageFetchChunkSize { - // Assume that was all the missing messages - log.Debug().Msg("Chunk had less than 50 messages, stopping backfill") - return - } - after = messages[len(messages)-1].ID - } -} - -func (portal *Portal) sendBackfillBatch(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) { - if portal.bridge.SpecVersions.Supports(mautrix.BeeperFeatureBatchSending) { - log.Debug().Msg("Using hungryserv, sending messages with batch send endpoint") - portal.forwardBatchSend(log, source, messages, thread) - } else { - log.Debug().Msg("Not using hungryserv, sending messages one by one") - for _, msg := range messages { - portal.handleDiscordMessageCreate(source, msg, thread) - } - } -} - -func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) { - evts, metas, dbMessages := portal.convertMessageBatch(log, source, messages, thread) - if len(evts) == 0 { - log.Warn().Msg("Didn't get any events to backfill") - return - } - log.Info().Int("events", len(evts)).Msg("Converted messages to backfill") - resp, err := portal.MainIntent().BeeperBatchSend(portal.MXID, &mautrix.ReqBeeperBatchSend{ - Forward: true, - Events: evts, - }) - if err != nil { - log.Err(err).Msg("Error sending backfill batch") - return - } - for i, evtID := range resp.EventIDs { - dbMessages[i].MXID = evtID - if metas[i] != nil && metas[i].Flags == discordgo.MessageFlagsHasThread { - // TODO proper context - ctx := log.WithContext(context.Background()) - portal.bridge.threadFound(ctx, source, &dbMessages[i], metas[i].ID, metas[i].Thread) - } - } - portal.bridge.DB.Message.MassInsert(portal.Key, dbMessages) -} - -func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) ([]*event.Event, []*discordgo.Message, []database.Message) { - var discordThreadID string - var threadRootEvent, lastThreadEvent id.EventID - if thread != nil { - discordThreadID = thread.ID - threadRootEvent = thread.RootMXID - lastThreadEvent = threadRootEvent - lastInThread := portal.bridge.DB.Message.GetLastInThread(portal.Key, thread.ID) - if lastInThread != nil { - lastThreadEvent = lastInThread.MXID - } - } - - evts := make([]*event.Event, 0, len(messages)) - dbMessages := make([]database.Message, 0, len(messages)) - metas := make([]*discordgo.Message, 0, len(messages)) - ctx := context.Background() - for _, msg := range messages { - for _, mention := range msg.Mentions { - puppet := portal.bridge.GetPuppetByID(mention.ID) - puppet.UpdateInfo(nil, mention, nil) - } - - puppet := portal.bridge.GetPuppetByID(msg.Author.ID) - puppet.UpdateInfo(source, msg.Author, msg) - intent := puppet.IntentFor(portal) - replyTo := portal.getReplyTarget(source, discordThreadID, msg.MessageReference, msg.Embeds, true) - mentions := portal.convertDiscordMentions(msg, false) - - ts, _ := discordgo.SnowflakeTimestamp(msg.ID) - log := log.With(). - Str("message_id", msg.ID). - Int("message_type", int(msg.Type)). - Str("author_id", msg.Author.ID). - Logger() - parts := portal.convertDiscordMessage(log.WithContext(ctx), puppet, intent, msg) - for i, part := range parts { - if (replyTo != nil || threadRootEvent != "") && part.Content.RelatesTo == nil { - part.Content.RelatesTo = &event.RelatesTo{} - } - if threadRootEvent != "" { - part.Content.RelatesTo.SetThread(threadRootEvent, lastThreadEvent) - } - if replyTo != nil { - part.Content.RelatesTo.SetReplyTo(replyTo.EventID) - // Only set reply for first event - replyTo = nil - } - - part.Content.Mentions = mentions - // Only set mentions for first event, but keep empty object for rest - mentions = &event.Mentions{} - - partName := part.AttachmentID - // Always use blank part name for first part so that replies and other things - // can reference it without knowing about attachments. - if i == 0 { - partName = "" - } - evt := &event.Event{ - ID: portal.deterministicEventID(msg.ID, partName), - Type: part.Type, - Sender: intent.UserID, - Timestamp: ts.UnixMilli(), - Content: event.Content{ - Parsed: part.Content, - Raw: part.Extra, - }, - } - var err error - evt.Type, err = portal.encrypt(intent, &evt.Content, evt.Type) - if err != nil { - log.Err(err).Msg("Failed to encrypt event") - continue - } - intent.AddDoublePuppetValue(&evt.Content) - evts = append(evts, evt) - dbMessages = append(dbMessages, database.Message{ - Channel: portal.Key, - DiscordID: msg.ID, - SenderID: msg.Author.ID, - Timestamp: ts, - AttachmentID: part.AttachmentID, - SenderMXID: intent.UserID, - }) - if i == 0 { - metas = append(metas, msg) - } else { - metas = append(metas, nil) - } - lastThreadEvent = evt.ID - } - } - return evts, metas, dbMessages -} - -func (portal *Portal) deterministicEventID(messageID, partName string) id.EventID { - data := fmt.Sprintf("%s/discord/%s/%s", portal.MXID, messageID, partName) - sum := sha256.Sum256([]byte(data)) - return id.EventID(fmt.Sprintf("$%s:discord.com", base64.RawURLEncoding.EncodeToString(sum[:]))) -} - -// compareMessageIDs compares two Discord message IDs. -// -// If the first ID is lower, -1 is returned. -// If the second ID is lower, 1 is returned. -// If the IDs are equal, 0 is returned. -func compareMessageIDs(id1, id2 string) int { - if id1 == id2 { - return 0 - } - if len(id1) < len(id2) { - return -1 - } else if len(id2) < len(id1) { - return 1 - } - if id1 < id2 { - return -1 - } - return 1 -} - -func shouldBackfill(latestBridgedIDStr, latestIDFromServerStr string) bool { - return compareMessageIDs(latestBridgedIDStr, latestIDFromServerStr) == -1 -} - -type MessageSlice []*discordgo.Message - -var _ sort.Interface = (MessageSlice)(nil) - -func (a MessageSlice) Len() int { - return len(a) -} - -func (a MessageSlice) Swap(i, j int) { - a[i], a[j] = a[j], a[i] -} - -func (a MessageSlice) Less(i, j int) bool { - return compareMessageIDs(a[i].ID, a[j].ID) == -1 -} diff --git a/build.sh b/build.sh index 2409c5b..aa6d009 100755 --- a/build.sh +++ b/build.sh @@ -1,2 +1,4 @@ #!/bin/sh -go build -ldflags "-X main.Tag=$(git describe --exact-match --tags 2>/dev/null) -X main.Commit=$(git rev-parse HEAD) -X 'main.BuildTime=`date '+%b %_d %Y, %H:%M:%S'`'" "$@" +MAUTRIX_VERSION=$(cat go.mod | grep 'maunium.net/go/mautrix ' | awk '{ print $2 }') +GO_LDFLAGS="-X main.Tag=$(git describe --exact-match --tags 2>/dev/null) -X main.Commit=$(git rev-parse HEAD) -X 'main.BuildTime=`date -Iseconds`' -X 'maunium.net/go/mautrix.GoModVersion=$MAUTRIX_VERSION'" +go build -ldflags="-s -w $GO_LDFLAGS" ./cmd/mautrix-discord "$@" diff --git a/cmd/authtester/main.go b/cmd/authtester/main.go new file mode 100644 index 0000000..3e712c8 --- /dev/null +++ b/cmd/authtester/main.go @@ -0,0 +1,491 @@ +package main + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "maps" + "net/http" + "net/http/cookiejar" + "os" + "os/signal" + "regexp" + "strconv" + "strings" + "syscall" + "time" + + "github.com/bwmarrin/discordgo" + "github.com/google/uuid" + "github.com/rs/zerolog" + "golang.org/x/term" + + "go.mau.fi/mautrix-discord/pkg/discordauth" +) + +const fallbackClientBuildNumber = 497254 + +var mainJSRegex = regexp.MustCompile(`src="(/assets/web\.[a-f0-9]{12,32}\.js)"`) +var buildNumberRegex = regexp.MustCompile(`(?:buildNumber|build_number):\s?['"]?(\d{6,})['"]?`) + +func main() { + if err := run(); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +func run() error { + var buildNumberFlag int + var apiBase string + var verbose bool + + flag.IntVar(&buildNumberFlag, "build-number", 0, "Discord client build number (default: auto-detect from discord.com)") + flag.StringVar(&apiBase, "api-base", "https://discord.com/api/v9", "Discord API base URL") + flag.BoolVar(&verbose, "verbose", false, "Lower the log level to debug") + flag.Parse() + + log := zerolog.New(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}). + Level(zerolog.InfoLevel). + With(). + Timestamp(). + Logger() + if verbose { + log = log.Level(zerolog.DebugLevel) + } + + ctx, stop := signal.NotifyContext(log.WithContext(context.Background()), os.Interrupt, syscall.SIGTERM) + defer stop() + + jar, err := cookiejar.New(nil) + if err != nil { + return fmt.Errorf("failed to create cookie jar: %w", err) + } + client := &http.Client{ + Timeout: 30 * time.Second, + Jar: jar, + } + captchaServer := newCaptchaServer(log.With().Str("component", "authtester captcha").Logger()) + defer func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := captchaServer.Close(shutdownCtx); err != nil { + fmt.Fprintf(os.Stderr, "Failed to gracefully terminate CAPTCHA server: %v\n", err) + } + }() + prompter := newPrompter(os.Stdin, os.Stdout, captchaServer) + + buildNumber := buildNumberFlag + if buildNumber == 0 { + fmt.Fprintln(os.Stdout, "Detecting an appropriate Discord client build number...") + buildNumber, err = fetchClientBuildNumber(ctx, client) + if err != nil { + fmt.Fprintf(os.Stderr, "Failed to detect build number automatically: %v\n", err) + fmt.Fprintf(os.Stderr, "Falling back to build number %d\n", fallbackClientBuildNumber) + buildNumber = fallbackClientBuildNumber + } + } + fmt.Fprintf(os.Stdout, "Using client build number %d\n", buildNumber) + + personality, err := newDefaultPersonality(buildNumber) + if err != nil { + return fmt.Errorf("failed to create auth personality: %w", err) + } + + machine := discordauth.NewAuthMachine(ctx, client, personality, prompter) + machine.APIBase = apiBase + if verbose { + machine.LogFilters = discordauth.LeakyDevelopmentAuthMachineLogFilters + } else { + machine.LogFilters = discordauth.DefaultAuthMachineLogFilters + } + + fmt.Fprintln(os.Stdout, "Preparing Discord auth...") + if err = machine.Prepare(ctx); err != nil { + return fmt.Errorf("failed to prepare auth machine: %w", err) + } + + login, err := prompter.promptRequired("Email or phone") + if err != nil { + return fmt.Errorf("failed to read login: %w", err) + } + password, err := prompter.promptSecretRequired("Password") + if err != nil { + return fmt.Errorf("failed to read password: %w", err) + } + + fmt.Fprintln(os.Stdout, "Logging in...") + resp, err := machine.Login(ctx, discordauth.NewCreds(login, password)) + if err != nil { + return fmt.Errorf("login failed: %w", err) + } + + fmt.Fprintln(os.Stdout, "Login succeeded.") + fmt.Fprintf(os.Stdout, "User ID: %s\n", resp.UserID) + fmt.Fprintf(os.Stdout, "Token length: %d\n", len(resp.Token.UnwrapSensitive())) + if resp.UserSettings.Locale != "" { + fmt.Fprintf(os.Stdout, "User locale: %s\n", resp.UserSettings.Locale) + } + if resp.UserSettings.Theme != "" { + fmt.Fprintf(os.Stdout, "User theme: %s\n", resp.UserSettings.Theme) + } + + return nil +} + +func newDefaultPersonality(buildNumber int) (*discordauth.Personality, error) { + launchSignature, err := discordgo.NewVanillaSignature() + if err != nil { + return nil, fmt.Errorf("failed to generate launch signature: %w", err) + } + + extraHeaders := maps.Clone(discordgo.DroidFetchHeaders) + delete(extraHeaders, "User-Agent") + + return &discordauth.Personality{ + UserAgent: discordgo.DroidBrowserUserAgent, + Locale: "en-US", + TimeZone: defaultTimeZone(), + DebugOptions: discordauth.DefaultDebugOptions, + SuperProperties: discordauth.SuperProperties{ + OS: "Windows", + Browser: "Chrome", + SystemLocale: "en-US", + HasClientMods: false, + BrowserUserAgent: discordgo.DroidBrowserUserAgent, + BrowserVersion: discordgo.DroidBrowserVersion, + OSVersion: "10", + ReleaseChannel: "stable", + ClientBuildNumber: buildNumber, + ClientLaunchID: uuid.NewString(), + LaunchSignature: launchSignature, + ClientAppState: "focused", + }, + ExtraHeaders: extraHeaders, + }, nil +} + +func defaultTimeZone() string { + timeZone := time.Now().Location().String() + if timeZone == "" || timeZone == "Local" { + return "UTC" + } + + return timeZone +} + +func fetchClientBuildNumber(ctx context.Context, client *http.Client) (int, error) { + mainPageReq, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://discord.com/channels/@me", nil) + if err != nil { + return 0, fmt.Errorf("failed to create main page request: %w", err) + } + addHeaders(mainPageReq.Header, discordgo.DroidBaseHeaders) + mainPageReq.Header.Set("Sec-Fetch-Dest", "document") + mainPageReq.Header.Set("Sec-Fetch-Mode", "navigate") + mainPageReq.Header.Set("Sec-Fetch-Site", "none") + mainPageReq.Header.Set("Sec-Fetch-User", "?1") + mainPageReq.Header.Set("Upgrade-Insecure-Requests", "1") + mainPageReq.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7") + + mainPageData, err := doRequest(ctx, client, mainPageReq) + if err != nil { + return 0, fmt.Errorf("failed to fetch main page: %w", err) + } + + mainJSMatch := mainJSRegex.FindSubmatch(mainPageData) + if mainJSMatch == nil { + return 0, fmt.Errorf("failed to find main JS URL in Discord main page") + } + + jsReq, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://discord.com"+string(mainJSMatch[1]), nil) + if err != nil { + return 0, fmt.Errorf("failed to create JS request: %w", err) + } + addHeaders(jsReq.Header, discordgo.DroidBaseHeaders) + jsReq.Header.Set("Sec-Fetch-Dest", "script") + jsReq.Header.Set("Sec-Fetch-Mode", "no-cors") + jsReq.Header.Set("Sec-Fetch-Site", "same-origin") + jsReq.Header.Set("Accept", "*/*") + + jsData, err := doRequest(ctx, client, jsReq) + if err != nil { + return 0, fmt.Errorf("failed to fetch main JS: %w", err) + } + + buildNumberMatch := buildNumberRegex.FindSubmatch(jsData) + if buildNumberMatch == nil { + return 0, fmt.Errorf("failed to find build number in Discord JS bundle") + } + + buildNumber, err := strconv.Atoi(string(buildNumberMatch[1])) + if err != nil { + return 0, fmt.Errorf("failed to parse build number %q: %w", buildNumberMatch[1], err) + } + + return buildNumber, nil +} + +func doRequest(ctx context.Context, client *http.Client, req *http.Request) ([]byte, error) { + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer func() { + _ = resp.Body.Close() + }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read %s %s response body: %w", req.Method, req.URL, err) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("unexpected status %s for %s %s", resp.Status, req.Method, req.URL) + } + if err := ctx.Err(); err != nil { + return nil, err + } + + return body, nil +} + +func addHeaders(header http.Header, values map[string]string) { + for key, value := range values { + header.Set(key, value) + } +} + +type prompter struct { + in *bufio.Reader + inFile *os.File + out io.Writer + captchaServer *captchaServer +} + +var _ discordauth.ChallengeHandler = (*prompter)(nil) + +type mfaMethodOption struct { + Type discordauth.AuthenticatorType + Label string + CodePrompt string +} + +func newPrompter(in io.Reader, out io.Writer, captchaServer *captchaServer) *prompter { + file, _ := in.(*os.File) + + return &prompter{ + in: bufio.NewReader(in), + inFile: file, + out: out, + captchaServer: captchaServer, + } +} + +func (p *prompter) promptRequired(label string) (string, error) { + value, err := p.prompt(label) + if err != nil { + return "", err + } + if value == "" { + return "", fmt.Errorf("%s is required", strings.ToLower(label)) + } + + return value, nil +} + +func (p *prompter) prompt(label string) (string, error) { + fmt.Fprintf(p.out, "%s: ", label) + line, err := p.in.ReadString('\n') + if err != nil && !errors.Is(err, io.EOF) { + return "", err + } + if errors.Is(err, io.EOF) && len(line) == 0 { + return "", io.EOF + } + + return strings.TrimRight(line, "\r\n"), nil +} + +func (p *prompter) promptSecretRequired(label string) (string, error) { + value, err := p.promptSecret(label) + if err != nil { + return "", err + } + if value == "" { + return "", fmt.Errorf("%s is required", strings.ToLower(label)) + } + + return value, nil +} + +func (p *prompter) promptSecret(label string) (string, error) { + if p.inFile == nil || !term.IsTerminal(int(p.inFile.Fd())) { + return p.prompt(label) + } + + fmt.Fprintf(p.out, "%s: ", label) + line, err := term.ReadPassword(int(p.inFile.Fd())) + fmt.Fprintln(p.out) + if err != nil { + return "", err + } + + return strings.TrimRight(string(line), "\r\n"), nil +} + +func (p *prompter) promptMFAChoice(options []mfaMethodOption) (mfaMethodOption, error) { + fmt.Fprintln(p.out) + fmt.Fprintln(p.out, "Available MFA methods:") + for i, option := range options { + fmt.Fprintf(p.out, " %d. %s\n", i+1, option.Label) + } + + for { + choice, err := p.promptRequired("Choose MFA method") + if err != nil { + return mfaMethodOption{}, err + } + + index, err := strconv.Atoi(choice) + if err == nil && index >= 1 && index <= len(options) { + return options[index-1], nil + } + + fmt.Fprintf(p.out, "Invalid choice %q. Enter a number from 1 to %d.\n", choice, len(options)) + } +} + +func supportedMFAMethods(challenge *discordauth.MFAChallenge) []mfaMethodOption { + options := make([]mfaMethodOption, 0, 3) + if challenge.TOTPEnabled { + options = append(options, mfaMethodOption{ + Type: discordauth.AuthenticatorTOTP, + Label: "TOTP authenticator", + CodePrompt: "TOTP code", + }) + } + if challenge.SMSEnabled { + options = append(options, mfaMethodOption{ + Type: discordauth.AuthenticatorSMS, + Label: "SMS code", + CodePrompt: "SMS code", + }) + } + if challenge.BackupCodesAccepted { + options = append(options, mfaMethodOption{ + Type: discordauth.AuthenticatorBackup, + Label: "Backup code", + CodePrompt: "Backup code", + }) + } + + return options +} + +func newMFAContinue(challenge *discordauth.MFAChallenge, authType discordauth.AuthenticatorType, code string) *discordauth.MFAContinue { + return &discordauth.MFAContinue{ + Type: authType, + MFAContinuation: discordauth.MFAContinuation{ + MFAState: challenge.MFAState, + Code: code, + }, + } +} + +func (p *prompter) ContinueMFA(ctx context.Context, challenge *discordauth.MFAChallenge) (*discordauth.MFAContinue, error) { + options := supportedMFAMethods(challenge) + if len(options) == 0 { + if challenge.WebAuthnCredential != nil { + panic("authtester does not support WebAuthn MFA") + } + return nil, fmt.Errorf("discord did not offer a supported MFA method") + } + + selected := options[0] + if len(options) == 1 { + fmt.Fprintln(p.out) + fmt.Fprintf(p.out, "Using MFA method: %s\n", selected.Label) + } else { + var err error + selected, err = p.promptMFAChoice(options) + if err != nil { + return nil, err + } + } + + switch selected.Type { + case discordauth.AuthenticatorSMS: + if challenge.RequestSMS == nil { + return nil, fmt.Errorf("discord MFA challenge did not provide an SMS request callback") + } + + fmt.Fprintln(p.out) + fmt.Fprintln(p.out, "Requesting an MFA SMS code...") + resp, err := challenge.RequestSMS(ctx) + if err != nil { + return nil, fmt.Errorf("failed to request SMS code: %w", err) + } + if resp != nil && resp.Phone != "" { + fmt.Fprintf(p.out, "Discord sent an MFA SMS code to %s\n", resp.Phone) + } else { + fmt.Fprintln(p.out, "Discord sent an MFA SMS code.") + } + } + + code, err := p.promptSecretRequired(selected.CodePrompt) + if err != nil { + return nil, err + } + + return newMFAContinue(challenge, selected.Type, code), nil +} + +func (p *prompter) SolveCaptcha(ctx context.Context, captcha *discordauth.Captcha) (*discordauth.CaptchaSolution, error) { + captchaData, err := json.MarshalIndent(captcha, "", " ") + if err != nil { + return nil, fmt.Errorf("failed to encode captcha challenge: %w", err) + } + + fmt.Fprintln(p.out) + fmt.Fprintln(p.out, "Received CAPTCHA challenge:") + fmt.Fprintln(p.out, string(captchaData)) + + if p.captchaServer != nil && supportsBrowserCaptcha(captcha) { + pageURL, waitForSolution, err := p.captchaServer.startChallenge(captcha) + if err != nil { + fmt.Fprintf(p.out, "Failed to start local CAPTCHA page: %v\n", err) + fmt.Fprintln(p.out, "Falling back to manual token entry.") + } else { + fmt.Fprintln(p.out) + fmt.Fprintln(p.out, "Open this page in your browser and solve the CAPTCHA:") + fmt.Fprintf(p.out, " %s\n", pageURL) + fmt.Fprintln(p.out, "If the page reports an error or you cancel it, authtester will fall back to manual token entry.") + + solution, err := waitForSolution(ctx) + switch { + case err == nil: + return &discordauth.CaptchaSolution{Solution: solution}, nil + case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): + return nil, err + case errors.Is(err, errCaptchaBrowserCanceled): + fmt.Fprintln(p.out, "Local CAPTCHA page was canceled.") + fmt.Fprintln(p.out, "Falling back to manual token entry.") + default: + fmt.Fprintf(p.out, "Local CAPTCHA page failed: %v\n", err) + fmt.Fprintln(p.out, "Falling back to manual token entry.") + } + } + } else { + fmt.Fprintln(p.out, "Local browser flow only supports hCaptcha challenges with a sitekey.") + } + + solution, err := p.promptRequired("CAPTCHA solution") + if err != nil { + return nil, err + } + + return &discordauth.CaptchaSolution{Solution: solution}, nil +} diff --git a/cmd/authtester/main_captcha.go b/cmd/authtester/main_captcha.go new file mode 100644 index 0000000..787140e --- /dev/null +++ b/cmd/authtester/main_captcha.go @@ -0,0 +1,460 @@ +package main + +import ( + "context" + _ "embed" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "strings" + "sync" + "time" + + "github.com/google/uuid" + "github.com/rs/zerolog" + + "go.mau.fi/mautrix-discord/pkg/discordauth" +) + +var errCaptchaBrowserCanceled = errors.New("captcha browser flow canceled") + +//go:embed main_captcha.html +var captchaPageHTML string + +type captchaServer struct { + mu sync.Mutex + log zerolog.Logger + handler http.Handler + server *http.Server + ln net.Listener + baseURL string + active *activeCaptcha +} + +type activeCaptcha struct { + challenge browserCaptchaChallenge + resultCh chan captchaBrowserResult +} + +type browserCaptchaChallenge struct { + ID string `json:"id"` + Service discordauth.CaptchaService `json:"service"` + SiteKey string `json:"site_key"` + RqData string `json:"rqdata,omitempty"` + Invisible bool `json:"invisible"` +} + +type captchaBrowserResult struct { + token string + err error +} + +type captchaSolveRequest struct { + ID string `json:"id"` + Token string `json:"token"` +} + +type captchaCancelRequest struct { + ID string `json:"id"` +} + +type captchaErrorRequest struct { + ID string `json:"id"` + Error string `json:"error"` +} + +type captchaErrorResponse struct { + Error string `json:"error"` +} + +func newCaptchaServer(log zerolog.Logger) *captchaServer { + cs := &captchaServer{log: log} + mux := http.NewServeMux() + mux.HandleFunc("/", cs.handlePage) + mux.HandleFunc("/api/challenge", cs.handleChallenge) + mux.HandleFunc("/api/solve", cs.handleSolve) + mux.HandleFunc("/api/cancel", cs.handleCancel) + mux.HandleFunc("/api/error", cs.handleError) + cs.handler = mux + return cs +} + +func supportsBrowserCaptcha(captcha *discordauth.Captcha) bool { + return captcha != nil && + captcha.Service == discordauth.CaptchaServiceHCaptcha && + captcha.SiteKey != nil && + strings.TrimSpace(*captcha.SiteKey) != "" +} + +func (cs *captchaServer) startChallenge(captcha *discordauth.Captcha) (string, func(context.Context) (string, error), error) { + if !supportsBrowserCaptcha(captcha) { + return "", nil, fmt.Errorf("browser flow only supports hcaptcha challenges with a sitekey") + } + if err := cs.ensureStarted(); err != nil { + return "", nil, err + } + + challenge := &activeCaptcha{ + challenge: browserCaptchaChallenge{ + ID: uuid.NewString(), + Service: captcha.Service, + SiteKey: strings.TrimSpace(*captcha.SiteKey), + Invisible: captcha.Invisible, + }, + resultCh: make(chan captchaBrowserResult, 1), + } + if captcha.RqData != nil { + challenge.challenge.RqData = *captcha.RqData + } + + cs.mu.Lock() + if cs.active != nil { + cs.log.Warn(). + Str("replaced_challenge_id", cs.active.challenge.ID). + Msg("Replacing active CAPTCHA challenge before it was resolved") + } + cs.active = challenge + pageURL := cs.baseURL + cs.mu.Unlock() + + cs.log.Info(). + Str("challenge_id", challenge.challenge.ID). + Str("captcha_service", string(challenge.challenge.Service)). + Bool("captcha_invisible", challenge.challenge.Invisible). + Bool("captcha_has_rqdata", challenge.challenge.RqData != ""). + Str("page_url", pageURL). + Msg("Started local CAPTCHA challenge") + + wait := func(ctx context.Context) (string, error) { + defer cs.clearActiveChallenge(challenge.challenge.ID) + + select { + case result := <-challenge.resultCh: + if result.err != nil { + cs.log.Warn(). + Str("challenge_id", challenge.challenge.ID). + Err(result.err). + Msg("Local CAPTCHA challenge completed with error") + return "", result.err + } + if result.token == "" { + return "", fmt.Errorf("browser page returned an empty CAPTCHA token") + } + cs.log.Info(). + Str("challenge_id", challenge.challenge.ID). + Int("token_length", len(result.token)). + Msg("Local CAPTCHA challenge returned a token") + return result.token, nil + case <-ctx.Done(): + cs.log.Warn(). + Str("challenge_id", challenge.challenge.ID). + Err(ctx.Err()). + Msg("Stopped waiting for local CAPTCHA challenge") + return "", ctx.Err() + } + } + + return pageURL, wait, nil +} + +func (cs *captchaServer) ensureStarted() error { + cs.mu.Lock() + defer cs.mu.Unlock() + + if cs.server != nil { + return nil + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return fmt.Errorf("failed to listen on 127.0.0.1: %w", err) + } + + addr := ln.Addr().(*net.TCPAddr) + server := &http.Server{ + Handler: cs.handler, + ReadHeaderTimeout: 5 * time.Second, + } + + cs.ln = ln + cs.server = server + cs.baseURL = fmt.Sprintf("http://localhost:%d/", addr.Port) + + cs.log.Info(). + Str("listen_addr", ln.Addr().String()). + Str("page_url", cs.baseURL). + Msg("Started local CAPTCHA server") + + go func() { + if err := server.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) { + cs.log.Error().Err(err).Msg("Local CAPTCHA server stopped unexpectedly") + cs.failActiveChallenge(fmt.Errorf("captcha server stopped unexpectedly: %w", err)) + } + }() + + return nil +} + +func (cs *captchaServer) Close(ctx context.Context) error { + cs.mu.Lock() + server := cs.server + cs.server = nil + cs.ln = nil + cs.baseURL = "" + cs.active = nil + cs.mu.Unlock() + + if server == nil { + return nil + } + + cs.log.Info().Msg("Shutting down local CAPTCHA server") + return server.Shutdown(ctx) +} + +func (cs *captchaServer) clearActiveChallenge(id string) { + cs.mu.Lock() + defer cs.mu.Unlock() + + if cs.active != nil && cs.active.challenge.ID == id { + cs.active = nil + } +} + +func (cs *captchaServer) failActiveChallenge(err error) { + cs.log.Error().Err(err).Msg("Failing active CAPTCHA challenge") + cs.mu.Lock() + active := cs.active + cs.active = nil + cs.mu.Unlock() + + if active == nil { + return + } + + select { + case active.resultCh <- captchaBrowserResult{err: err}: + default: + } +} + +func (cs *captchaServer) resolveActiveChallenge(id string, result captchaBrowserResult) error { + cs.mu.Lock() + active := cs.active + if active == nil { + cs.mu.Unlock() + cs.log.Warn(). + Str("challenge_id", id). + Msg("Attempted to resolve CAPTCHA challenge, but none is active") + return fmt.Errorf("no active captcha challenge") + } + if active.challenge.ID != id { + cs.mu.Unlock() + cs.log.Warn(). + Str("challenge_id", id). + Str("active_challenge_id", active.challenge.ID). + Msg("Attempted to resolve a stale CAPTCHA challenge") + return fmt.Errorf("captcha challenge is no longer current") + } + cs.active = nil + cs.mu.Unlock() + + select { + case active.resultCh <- result: + return nil + default: + cs.log.Warn(). + Str("challenge_id", id). + Msg("CAPTCHA challenge was already resolved") + return fmt.Errorf("captcha challenge already resolved") + } +} + +func (cs *captchaServer) currentChallenge() *browserCaptchaChallenge { + cs.mu.Lock() + defer cs.mu.Unlock() + + if cs.active == nil { + return nil + } + + challenge := cs.active.challenge + return &challenge +} + +func (cs *captchaServer) handlePage(w http.ResponseWriter, r *http.Request) { + log := cs.requestLogger(r) + if r.Method != http.MethodGet { + log.Warn().Msg("Rejected CAPTCHA page request with unsupported method") + writeCaptchaMethodNotAllowed(w, http.MethodGet) + return + } + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + _, _ = w.Write([]byte(captchaPageHTML)) + log.Info().Msg("Served CAPTCHA page") +} + +func (cs *captchaServer) handleChallenge(w http.ResponseWriter, r *http.Request) { + log := cs.requestLogger(r) + if r.Method != http.MethodGet { + log.Warn().Msg("Rejected CAPTCHA challenge request with unsupported method") + writeCaptchaMethodNotAllowed(w, http.MethodGet) + return + } + + challenge := cs.currentChallenge() + if challenge == nil { + log.Warn().Msg("Requested CAPTCHA challenge, but none is active") + writeCaptchaJSON(w, http.StatusNotFound, captchaErrorResponse{Error: "no active captcha challenge"}) + return + } + + log.Info(). + Str("challenge_id", challenge.ID). + Str("captcha_service", string(challenge.Service)). + Bool("captcha_invisible", challenge.Invisible). + Bool("captcha_has_rqdata", challenge.RqData != ""). + Msg("Served active CAPTCHA challenge") + writeCaptchaJSON(w, http.StatusOK, challenge) +} + +func (cs *captchaServer) handleSolve(w http.ResponseWriter, r *http.Request) { + log := cs.requestLogger(r) + if r.Method != http.MethodPost { + log.Warn().Msg("Rejected CAPTCHA solve request with unsupported method") + writeCaptchaMethodNotAllowed(w, http.MethodPost) + return + } + + var req captchaSolveRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + log.Warn().Err(err).Msg("Rejected CAPTCHA solve request with invalid JSON body") + writeCaptchaJSON(w, http.StatusBadRequest, captchaErrorResponse{Error: "invalid JSON body"}) + return + } + if strings.TrimSpace(req.ID) == "" { + log.Warn().Msg("Rejected CAPTCHA solve request without challenge id") + writeCaptchaJSON(w, http.StatusBadRequest, captchaErrorResponse{Error: "missing challenge id"}) + return + } + req.Token = strings.TrimSpace(req.Token) + if req.Token == "" { + log.Warn(). + Str("challenge_id", req.ID). + Msg("Rejected CAPTCHA solve request with empty token") + writeCaptchaJSON(w, http.StatusBadRequest, captchaErrorResponse{Error: "missing captcha token"}) + return + } + + if err := cs.resolveActiveChallenge(req.ID, captchaBrowserResult{token: req.Token}); err != nil { + log.Warn(). + Str("challenge_id", req.ID). + Err(err). + Msg("Rejected CAPTCHA solve request") + writeCaptchaJSON(w, http.StatusConflict, captchaErrorResponse{Error: err.Error()}) + return + } + + log.Info(). + Str("challenge_id", req.ID). + Int("token_length", len(req.Token)). + Msg("Accepted CAPTCHA token from browser page") + w.WriteHeader(http.StatusNoContent) +} + +func (cs *captchaServer) handleCancel(w http.ResponseWriter, r *http.Request) { + log := cs.requestLogger(r) + if r.Method != http.MethodPost { + log.Warn().Msg("Rejected CAPTCHA cancel request with unsupported method") + writeCaptchaMethodNotAllowed(w, http.MethodPost) + return + } + + var req captchaCancelRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + log.Warn().Err(err).Msg("Rejected CAPTCHA cancel request with invalid JSON body") + writeCaptchaJSON(w, http.StatusBadRequest, captchaErrorResponse{Error: "invalid JSON body"}) + return + } + if strings.TrimSpace(req.ID) == "" { + log.Warn().Msg("Rejected CAPTCHA cancel request without challenge id") + writeCaptchaJSON(w, http.StatusBadRequest, captchaErrorResponse{Error: "missing challenge id"}) + return + } + + if err := cs.resolveActiveChallenge(req.ID, captchaBrowserResult{err: errCaptchaBrowserCanceled}); err != nil { + log.Warn(). + Str("challenge_id", req.ID). + Err(err). + Msg("Rejected CAPTCHA cancel request") + writeCaptchaJSON(w, http.StatusConflict, captchaErrorResponse{Error: err.Error()}) + return + } + + log.Info(). + Str("challenge_id", req.ID). + Msg("Browser page canceled CAPTCHA flow") + w.WriteHeader(http.StatusNoContent) +} + +func (cs *captchaServer) handleError(w http.ResponseWriter, r *http.Request) { + log := cs.requestLogger(r) + if r.Method != http.MethodPost { + log.Warn().Msg("Rejected CAPTCHA error report with unsupported method") + writeCaptchaMethodNotAllowed(w, http.MethodPost) + return + } + + var req captchaErrorRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + log.Warn().Err(err).Msg("Rejected CAPTCHA error report with invalid JSON body") + writeCaptchaJSON(w, http.StatusBadRequest, captchaErrorResponse{Error: "invalid JSON body"}) + return + } + if strings.TrimSpace(req.ID) == "" { + log.Warn().Msg("Rejected CAPTCHA error report without challenge id") + writeCaptchaJSON(w, http.StatusBadRequest, captchaErrorResponse{Error: "missing challenge id"}) + return + } + + message := strings.TrimSpace(req.Error) + if message == "" { + message = "browser page reported an unknown error" + } + if err := cs.resolveActiveChallenge(req.ID, captchaBrowserResult{err: fmt.Errorf("%s", message)}); err != nil { + log.Warn(). + Str("challenge_id", req.ID). + Err(err). + Msg("Rejected CAPTCHA browser error report") + writeCaptchaJSON(w, http.StatusConflict, captchaErrorResponse{Error: err.Error()}) + return + } + + log.Warn(). + Str("challenge_id", req.ID). + Str("browser_error", message). + Msg("Browser page reported CAPTCHA error") + w.WriteHeader(http.StatusNoContent) +} + +func (cs *captchaServer) requestLogger(r *http.Request) zerolog.Logger { + return cs.log.With(). + Str("http_method", r.Method). + Str("http_path", r.URL.Path). + Str("remote_addr", r.RemoteAddr). + Logger() +} + +func writeCaptchaJSON(w http.ResponseWriter, status int, body any) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(body) +} + +func writeCaptchaMethodNotAllowed(w http.ResponseWriter, allowed string) { + w.Header().Set("Allow", allowed) + writeCaptchaJSON(w, http.StatusMethodNotAllowed, captchaErrorResponse{Error: "method not allowed"}) +} diff --git a/cmd/authtester/main_captcha.html b/cmd/authtester/main_captcha.html new file mode 100644 index 0000000..98e6b8f --- /dev/null +++ b/cmd/authtester/main_captcha.html @@ -0,0 +1,225 @@ + + + + + + mautrix-discord hCaptcha + + + +
+

mautrix-discord hCaptcha

+
+
Loading challenge state...
+
+ +
+
+ + + diff --git a/cmd/mautrix-discord/legacymigrate.go b/cmd/mautrix-discord/legacymigrate.go new file mode 100644 index 0000000..31cb8a3 --- /dev/null +++ b/cmd/mautrix-discord/legacymigrate.go @@ -0,0 +1,37 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package main + +import ( + _ "embed" +) + +const legacyMigrateRenameTables = ` +ALTER TABLE portal RENAME TO portal_old; +ALTER TABLE puppet RENAME TO puppet_old; +ALTER TABLE "user" RENAME TO user_old; +ALTER TABLE message RENAME TO message_old; +ALTER TABLE reaction RENAME TO reaction_old; +ALTER TABLE user_portal RENAME TO user_portal_old; +ALTER TABLE guild RENAME TO guild_old; +ALTER TABLE role RENAME TO role_old; +ALTER TABLE thread RENAME TO thread_old; +ALTER TABLE discord_file RENAME TO discord_file_old; +` + +//go:embed legacymigrate.sql +var legacyMigrateCopyData string diff --git a/cmd/mautrix-discord/legacymigrate.sql b/cmd/mautrix-discord/legacymigrate.sql new file mode 100644 index 0000000..884c71c --- /dev/null +++ b/cmd/mautrix-discord/legacymigrate.sql @@ -0,0 +1,382 @@ +INSERT INTO "user" (bridge_id, mxid, management_room, access_token) +SELECT + '', -- bridge_id + mxid, + management_room, + NULL -- access_token +FROM user_old; + +INSERT INTO user_login (bridge_id, user_mxid, id, remote_name, remote_profile, space_room, metadata) +SELECT + '', -- bridge_id + uo.mxid, -- user_mxid + uo.dcid, -- id + COALESCE(uo.dcid, ''), -- remote_name + NULL, -- remote_profile + uo.space_room, + -- only: postgres for next 13 lines + jsonb_build_object( + 'token', uo.discord_token, + 'heartbeat_session', COALESCE(uo.heartbeat_session, '{}'::jsonb), + 'bridged_guild_ids', COALESCE(( + SELECT jsonb_object_agg(bg.guild_id, true) + FROM ( + SELECT DISTINCT up.discord_id AS guild_id + FROM user_portal_old AS up + JOIN guild_old AS g ON g.dcid=up.discord_id + WHERE up.user_mxid=uo.mxid AND up.type='guild' AND g.bridging_mode > 0 + ) AS bg + ), '{}'::jsonb) + ) + -- only: sqlite for next 16 lines (lines commented) +-- json_object( +-- 'token', uo.discord_token, +-- 'heartbeat_session', CASE +-- WHEN uo.heartbeat_session IS NULL OR uo.heartbeat_session='' THEN json('{}') +-- ELSE json(uo.heartbeat_session) +-- END, +-- 'bridged_guild_ids', COALESCE(( +-- SELECT json_group_object(bg.guild_id, json('true')) +-- FROM ( +-- SELECT DISTINCT up.discord_id AS guild_id +-- FROM user_portal_old AS up +-- JOIN guild_old AS g ON g.dcid=up.discord_id +-- WHERE up.user_mxid=uo.mxid AND up.type='guild' AND g.bridging_mode > 0 +-- ) AS bg +-- ), json('{}')) +-- ) +FROM user_old AS uo +WHERE uo.dcid IS NOT NULL AND uo.dcid <> ''; + +INSERT INTO ghost ( + bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc, + name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata +) +SELECT + '', -- bridge_id + id, + name, + avatar, -- avatar_id + '', -- avatar_hash + avatar_url, -- avatar_mxc + name_set, + avatar_set, + contact_info_set, + is_bot, + -- only: postgres + '[]'::jsonb, -- identifiers + -- only: sqlite (line commented) +-- '[]', -- identifiers + -- only: postgres + '{}'::jsonb -- metadata + -- only: sqlite (line commented) +-- '{}' -- metadata +FROM puppet_old; + +INSERT INTO ghost ( + bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc, + name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata +) +SELECT + '', -- bridge_id + missing.sender_id, -- id + missing.sender_id, -- name + '', -- avatar_id + '', -- avatar_hash + '', -- avatar_mxc + false, -- name_set + false, -- avatar_set + false, -- contact_info_set + false, -- is_bot + -- only: postgres + '[]'::jsonb, -- identifiers + -- only: sqlite (line commented) +-- '[]', -- identifiers + -- only: postgres + '{}'::jsonb -- metadata + -- only: sqlite (line commented) +-- '{}' -- metadata +FROM ( + SELECT DISTINCT dc_sender AS sender_id FROM message_old + UNION + SELECT DISTINCT dc_sender AS sender_id FROM reaction_old +) AS missing +WHERE missing.sender_id <> '' AND NOT EXISTS( + SELECT 1 FROM ghost WHERE bridge_id='' AND id=missing.sender_id +); + +INSERT INTO portal ( + bridge_id, id, receiver, mxid, parent_id, parent_receiver, relay_bridge_id, relay_login_id, other_user_id, + name, topic, avatar_id, avatar_hash, avatar_mxc, name_set, avatar_set, topic_set, name_is_custom, in_space, room_type, + metadata +) +SELECT + '', -- bridge_id + '*' || dcid, -- id + '', -- receiver + mxid, + NULL, -- parent_id + '', -- parent_receiver + NULL, -- relay_bridge_id + NULL, -- relay_login_id + NULL, -- other_user_id + name, + '', -- topic + avatar, -- avatar_id + '', -- avatar_hash + avatar_url, -- avatar_mxc + name_set, + avatar_set, + true, -- topic_set + true, -- name_is_custom + false, -- in_space + 'space', -- room_type + -- only: postgres + '{}'::jsonb -- metadata + -- only: sqlite (line commented) +-- '{}' -- metadata +FROM guild_old; + +INSERT INTO portal ( + bridge_id, id, receiver, mxid, parent_id, parent_receiver, relay_bridge_id, relay_login_id, other_user_id, + name, topic, avatar_id, avatar_hash, avatar_mxc, name_set, avatar_set, topic_set, name_is_custom, in_space, room_type, + metadata +) +SELECT + '', -- bridge_id + p.dcid, -- id + p.receiver, -- receiver + p.mxid, + CASE + WHEN p.dc_parent_id <> '' THEN p.dc_parent_id + WHEN p.dc_guild_id <> '' THEN '*' || p.dc_guild_id + ELSE NULL + END, -- parent_id + CASE + WHEN p.dc_parent_id <> '' THEN p.dc_parent_receiver + WHEN p.dc_guild_id <> '' THEN '' + ELSE '' + END, -- parent_receiver + NULL, -- relay_bridge_id + NULL, -- relay_login_id + NULLIF(p.other_user_id, ''), -- other_user_id + p.name, + p.topic, + p.avatar, -- avatar_id + '', -- avatar_hash + p.avatar_url, -- avatar_mxc + p.name_set, + p.avatar_set, + p.topic_set, + NOT (p.type=1), -- name_is_custom + p.in_space <> '', -- in_space + CASE + WHEN p.type=1 THEN 'dm' + WHEN p.type=3 THEN 'group_dm' + WHEN p.type=4 THEN 'space' + ELSE '' + END, -- room_type + -- only: postgres for next 4 lines + jsonb_build_object( + 'guild_id', COALESCE(p.dc_guild_id, ''), + 'channel_type', p.type + ) + -- only: sqlite for next 4 lines (lines commented) +-- json_object( +-- 'guild_id', COALESCE(p.dc_guild_id, ''), +-- 'channel_type', p.type +-- ) +FROM portal_old AS p; + +INSERT INTO message ( + bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, sender_mxid, timestamp, edit_count, double_puppeted, + thread_root_id, reply_to_id, reply_to_part_id, send_txn_id, metadata +) +SELECT + '', -- bridge_id + m.dcid, -- id + m.dc_attachment_id, -- part_id + m.mxid, + m.dc_chan_id, -- room_id + m.dc_chan_receiver, -- room_receiver + m.dc_sender, -- sender_id + m.sender_mxid, + m.timestamp * 1000000, -- timestamp (ms -> ns) + CASE WHEN m.dc_edit_timestamp > 0 THEN 1 ELSE 0 END, -- edit_count + false, -- double_puppeted + CASE WHEN m.dc_thread_id <> '' THEN COALESCE(NULLIF(t.root_msg_dcid, ''), m.dc_thread_id) END, -- thread_root_id + NULL, -- reply_to_id + NULL, -- reply_to_part_id + NULL, -- send_txn_id + -- only: postgres + '{}'::jsonb -- metadata + -- only: sqlite (line commented) +-- '{}' -- metadata +FROM message_old AS m +LEFT JOIN thread_old AS t ON t.dcid=m.dc_thread_id AND (t.receiver=m.dc_chan_receiver OR t.receiver='') +WHERE EXISTS ( + SELECT 1 + FROM portal + WHERE bridge_id='' AND id=m.dc_chan_id AND receiver=m.dc_chan_receiver +); + +INSERT INTO reaction ( + bridge_id, message_id, message_part_id, sender_id, sender_mxid, emoji_id, room_id, room_receiver, mxid, timestamp, emoji, metadata +) +SELECT + '', -- bridge_id + r.dc_msg_id, -- message_id + r.dc_first_attachment_id, -- message_part_id + r.dc_sender, -- sender_id + '', -- sender_mxid + r.dc_emoji_name, -- emoji_id + m.room_id, + m.room_receiver, + r.mxid, + m.timestamp, + r.dc_emoji_name, -- emoji + -- only: postgres + '{}'::jsonb -- metadata + -- only: sqlite (line commented) +-- '{}' -- metadata +FROM reaction_old AS r +JOIN message AS m ON m.bridge_id='' AND m.id=r.dc_msg_id AND m.part_id=r.dc_first_attachment_id AND m.room_id=r.dc_chan_id AND m.room_receiver=r.dc_chan_receiver +WHERE r.dc_sender <> ''; + +INSERT INTO user_portal (bridge_id, user_mxid, login_id, portal_id, portal_receiver, in_space, preferred, last_read) +SELECT + '', -- bridge_id + up.user_mxid, + u.dcid, -- login_id + CASE WHEN up.type='guild' THEN '*' || up.discord_id ELSE up.discord_id END, -- portal_id + CASE WHEN up.type='guild' THEN '' ELSE COALESCE( + (SELECT p.receiver FROM portal_old AS p WHERE p.dcid=up.discord_id AND p.receiver=u.dcid LIMIT 1), + (SELECT p.receiver FROM portal_old AS p WHERE p.dcid=up.discord_id AND p.receiver='' LIMIT 1), + '' + ) END, -- portal_receiver + up.in_space, -- in_space + false, -- preferred + CASE WHEN up.timestamp > 0 THEN up.timestamp * 1000000 END -- last_read +FROM user_portal_old AS up +JOIN user_old AS u ON u.mxid=up.user_mxid +WHERE u.dcid IS NOT NULL AND u.dcid <> '' AND EXISTS( + SELECT 1 + FROM portal + WHERE bridge_id='' AND id=(CASE WHEN up.type='guild' THEN '*' || up.discord_id ELSE up.discord_id END) + AND receiver=(CASE WHEN up.type='guild' THEN '' ELSE COALESCE( + (SELECT p.receiver FROM portal_old AS p WHERE p.dcid=up.discord_id AND p.receiver=u.dcid LIMIT 1), + (SELECT p.receiver FROM portal_old AS p WHERE p.dcid=up.discord_id AND p.receiver='' LIMIT 1), + '' + ) END) +) +ON CONFLICT (bridge_id, user_mxid, login_id, portal_id, portal_receiver) DO NOTHING; + +-- migrate thread_old -> discord_thread (receiver already known) +INSERT INTO discord_thread (user_login_id, parent_channel_id, thread_channel_id, root_message_id) +SELECT + t.receiver AS user_login_id, + t.parent_chan_id, + t.dcid AS thread_channel_id, + t.root_msg_dcid AS root_message_id +FROM thread_old AS t +WHERE t.receiver <> '' AND t.root_msg_dcid <> '' +ON CONFLICT (user_login_id, thread_channel_id) DO UPDATE +SET parent_channel_id=excluded.parent_channel_id, root_message_id=excluded.root_message_id; + +-- migrate thread_old -> discord_thread (receiver missing; derive from guild +-- membership) +INSERT INTO discord_thread (user_login_id, parent_channel_id, thread_channel_id, root_message_id) +SELECT DISTINCT + u.dcid AS user_login_id, + t.parent_chan_id, + t.dcid AS thread_channel_id, + t.root_msg_dcid AS root_message_id +FROM thread_old AS t +JOIN portal_old AS parent ON parent.dcid=t.parent_chan_id AND parent.receiver='' +JOIN user_portal_old AS up ON up.type='guild' AND up.discord_id=parent.dc_guild_id +JOIN user_old AS u ON u.mxid=up.user_mxid +WHERE t.receiver='' AND t.root_msg_dcid <> '' AND u.dcid <> '' +ON CONFLICT (user_login_id, thread_channel_id) DO UPDATE +SET parent_channel_id=excluded.parent_channel_id, root_message_id=excluded.root_message_id; + +-- migrate message_old -> discord_thread (thread reference; receiver known) +INSERT INTO discord_thread (user_login_id, parent_channel_id, thread_channel_id, root_message_id) +SELECT DISTINCT + m.dc_chan_receiver AS user_login_id, + m.dc_chan_id AS parent_channel_id, + m.dc_thread_id AS thread_channel_id, + COALESCE(NULLIF(t.root_msg_dcid, ''), m.dc_thread_id) AS root_message_id +FROM message_old AS m +LEFT JOIN thread_old AS t ON t.dcid=m.dc_thread_id AND (t.receiver=m.dc_chan_receiver OR t.receiver='') +WHERE m.dc_chan_receiver <> '' AND m.dc_thread_id <> '' AND COALESCE(NULLIF(t.root_msg_dcid, ''), m.dc_thread_id) <> '' +ON CONFLICT (user_login_id, thread_channel_id) DO UPDATE +SET parent_channel_id=excluded.parent_channel_id, root_message_id=excluded.root_message_id; + +-- migrate message_old -> discord_thread (thread reference; eceiverr missing) +INSERT INTO discord_thread (user_login_id, parent_channel_id, thread_channel_id, root_message_id) +SELECT DISTINCT + u.dcid AS user_login_id, + m.dc_chan_id AS parent_channel_id, + m.dc_thread_id AS thread_channel_id, + COALESCE(NULLIF(t.root_msg_dcid, ''), m.dc_thread_id) AS root_message_id +FROM message_old AS m +JOIN portal_old AS parent ON parent.dcid=m.dc_chan_id AND parent.receiver='' +JOIN user_portal_old AS up ON up.type='guild' AND up.discord_id=parent.dc_guild_id +JOIN user_old AS u ON u.mxid=up.user_mxid +LEFT JOIN thread_old AS t ON t.dcid=m.dc_thread_id AND t.receiver='' +WHERE m.dc_chan_receiver='' AND m.dc_thread_id <> '' AND COALESCE(NULLIF(t.root_msg_dcid, ''), m.dc_thread_id) <> '' AND u.dcid <> '' +ON CONFLICT (user_login_id, thread_channel_id) DO UPDATE +SET parent_channel_id=excluded.parent_channel_id, root_message_id=excluded.root_message_id; + +INSERT INTO role (discord_guild_id, discord_id, name, icon, mentionable, managed, hoist, color, position, permissions) SELECT + r.dc_guild_id AS discord_guild_id, + r.dcid AS discord_id, + r.name, + r.icon, + r.mentionable, + r.managed, + r.hoist, + r.color, + r.position, + r.permissions +FROM role_old r; + +INSERT INTO custom_emoji (discord_id, name, animated, mxc) +SELECT + picked.id AS discord_id, + picked.emoji_name AS name, + CASE + WHEN picked.mime_type='image/gif' OR lower(picked.url) LIKE '%.gif%' THEN true + ELSE false + END AS animated, + picked.mxc +FROM ( + SELECT + df.id, + df.emoji_name, + df.mxc, + df.mime_type, + df.url, + ROW_NUMBER() OVER ( + PARTITION BY df.id + ORDER BY df.timestamp DESC, df.emoji_name DESC, df.mxc DESC + ) AS rn + FROM discord_file_old AS df + WHERE df.id IS NOT NULL AND df.id <> '' + AND df.emoji_name IS NOT NULL AND df.emoji_name <> '' + AND df.mxc IS NOT NULL AND df.mxc <> '' +) AS picked +WHERE picked.rn=1 +ON CONFLICT (discord_id) DO UPDATE +SET name=excluded.name, animated=excluded.animated, mxc=excluded.mxc; + +DROP TABLE thread_old; +DROP TABLE role_old; +DROP TABLE user_portal_old; +DROP TABLE reaction_old; +DROP TABLE message_old; +DROP TABLE user_old; +DROP TABLE puppet_old; +DROP TABLE portal_old; +DROP TABLE guild_old; +DROP TABLE discord_file_old; diff --git a/cmd/mautrix-discord/main.go b/cmd/mautrix-discord/main.go new file mode 100644 index 0000000..5391801 --- /dev/null +++ b/cmd/mautrix-discord/main.go @@ -0,0 +1,58 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package main + +import ( + "maunium.net/go/mautrix/bridgev2/matrix/mxmain" + + "go.mau.fi/mautrix-discord/pkg/connector" + "go.mau.fi/mautrix-discord/pkg/connector/discorddb" +) + +var ( + Tag = "unknown" + Commit = "unknown" + BuildTime = "unknown" +) + +var c = &connector.DiscordConnector{} +var m = mxmain.BridgeMain{ + Name: "mautrix-discord", + Description: "A Matrix-Discord puppeting bridge", + URL: "https://github.com/mautrix/discord", + Version: "26.03", + SemCalVer: true, + Connector: c, +} + +func main() { + m.PostInit = func() { + m.CheckLegacyDB(24, "v0.7.6", "v26.03", + m.LegacyMigrateWithAnotherUpgrader( + legacyMigrateRenameTables, + legacyMigrateCopyData, + 26, + discorddb.UpgradeTable(), + "discord_version", + 2, + ), + true, + ) + } + m.InitVersion(Tag, Commit, BuildTime) + m.Run() +} diff --git a/commands.go b/commands.go deleted file mode 100644 index 88b737c..0000000 --- a/commands.go +++ /dev/null @@ -1,901 +0,0 @@ -// mautrix-discord - A Matrix-Discord puppeting bridge. -// Copyright (C) 2022 Tulir Asokan -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package main - -import ( - "bytes" - "context" - "encoding/base64" - "errors" - "fmt" - "html" - "net/http" - "strconv" - "strings" - - "github.com/bwmarrin/discordgo" - "github.com/skip2/go-qrcode" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridge/bridgeconfig" - "maunium.net/go/mautrix/bridge/commands" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" - - "go.mau.fi/mautrix-discord/database" - "go.mau.fi/mautrix-discord/remoteauth" -) - -type WrappedCommandEvent struct { - *commands.Event - Bridge *DiscordBridge - User *User - Portal *Portal -} - -var HelpSectionPortalManagement = commands.HelpSection{Name: "Portal management", Order: 20} - -func (br *DiscordBridge) RegisterCommands() { - proc := br.CommandProcessor.(*commands.Processor) - proc.AddHandlers( - cmdLoginToken, - cmdLoginQR, - cmdLogout, - cmdPing, - cmdReconnect, - cmdDisconnect, - cmdBridge, - cmdUnbridge, - cmdDeletePortal, - cmdCreatePortal, - cmdSetRelay, - cmdUnsetRelay, - cmdGuilds, - cmdRejoinSpace, - cmdDeleteAllPortals, - cmdExec, - cmdCommands, - ) -} - -func wrapCommand(handler func(*WrappedCommandEvent)) func(*commands.Event) { - return func(ce *commands.Event) { - user := ce.User.(*User) - var portal *Portal - if ce.Portal != nil { - portal = ce.Portal.(*Portal) - } - br := ce.Bridge.Child.(*DiscordBridge) - handler(&WrappedCommandEvent{ce, br, user, portal}) - } -} - -var cmdLoginToken = &commands.FullHandler{ - Func: wrapCommand(fnLoginToken), - Name: "login-token", - Help: commands.HelpMeta{ - Section: commands.HelpSectionAuth, - Description: "Link the bridge to your Discord account by extracting the access token manually.", - Args: " <_token_>", - }, -} - -const discordTokenEpoch = 1293840000 - -func decodeToken(token string) (userID int64, err error) { - parts := strings.Split(token, ".") - if len(parts) != 3 { - err = fmt.Errorf("invalid number of parts in token") - return - } - var userIDStr []byte - userIDStr, err = base64.RawURLEncoding.DecodeString(parts[0]) - if err != nil { - err = fmt.Errorf("invalid base64 in user ID part: %w", err) - return - } - _, err = base64.RawURLEncoding.DecodeString(parts[1]) - if err != nil { - err = fmt.Errorf("invalid base64 in random part: %w", err) - return - } - _, err = base64.RawURLEncoding.DecodeString(parts[2]) - if err != nil { - err = fmt.Errorf("invalid base64 in checksum part: %w", err) - return - } - userID, err = strconv.ParseInt(string(userIDStr), 10, 64) - if err != nil { - err = fmt.Errorf("invalid number in decoded user ID part: %w", err) - return - } - return -} - -func fnLoginToken(ce *WrappedCommandEvent) { - if len(ce.Args) != 2 { - ce.Reply("**Usage**: `$cmdprefix login-token `") - return - } - ce.MarkRead() - defer ce.Redact() - if ce.User.IsLoggedIn() { - ce.Reply("You're already logged in") - return - } - token := ce.Args[1] - userID, err := decodeToken(token) - if err != nil { - ce.Reply("Invalid token") - return - } - switch strings.ToLower(ce.Args[0]) { - case "user": - // Token is used as-is - case "bot": - token = "Bot " + token - case "oauth": - token = "Bearer " + token - default: - ce.Reply("Token type must be `user`, `bot` or `oauth`") - return - } - ce.Reply("Connecting to Discord as user ID %d", userID) - if err = ce.User.Login(token); err != nil { - ce.Reply("Error connecting to Discord: %v", err) - return - } - ce.Reply("Successfully logged in as @%s", ce.User.Session.State.User.Username) -} - -var cmdLoginQR = &commands.FullHandler{ - Func: wrapCommand(fnLoginQR), - Name: "login-qr", - Aliases: []string{"login"}, - Help: commands.HelpMeta{ - Section: commands.HelpSectionAuth, - Description: "Link the bridge to your Discord account by scanning a QR code.", - }, -} - -func fnLoginQR(ce *WrappedCommandEvent) { - if ce.User.IsLoggedIn() { - ce.Reply("You're already logged in") - return - } - - client, err := remoteauth.New() - if err != nil { - ce.Reply("Failed to prepare login: %v", err) - return - } - - qrChan := make(chan string) - doneChan := make(chan struct{}) - - var qrCodeEvent id.EventID - - go func() { - code := <-qrChan - resp := sendQRCode(ce, code) - qrCodeEvent = resp - }() - - ctx := context.Background() - - if err = client.Dial(ctx, qrChan, doneChan); err != nil { - close(qrChan) - close(doneChan) - ce.Reply("Error connecting to login websocket: %v", err) - return - } - - <-doneChan - - if qrCodeEvent != "" { - _, _ = ce.MainIntent().RedactEvent(ce.RoomID, qrCodeEvent) - } - - user, err := client.Result() - if err != nil || len(user.Token) == 0 { - if restErr := (&discordgo.RESTError{}); errors.As(err, &restErr) && - restErr.Response.StatusCode == http.StatusBadRequest && - bytes.Contains(restErr.ResponseBody, []byte("captcha-required")) { - ce.Reply("Error logging in: %v\n\nCAPTCHAs are currently not supported - use token login instead", err) - } else { - ce.Reply("Error logging in: %v", err) - } - return - } else if err = ce.User.Login(user.Token); err != nil { - ce.Reply("Error connecting after login: %v", err) - return - } - ce.User.Lock() - ce.User.DiscordID = user.UserID - ce.User.Update() - ce.User.Unlock() - ce.Reply("Successfully logged in as @%s", user.Username) -} - -func sendQRCode(ce *WrappedCommandEvent, code string) id.EventID { - url, ok := uploadQRCode(ce, code) - if !ok { - return "" - } - - content := event.MessageEventContent{ - MsgType: event.MsgImage, - Body: code, - FileName: "qr.png", - URL: url.CUString(), - } - - resp, err := ce.Bot.SendMessageEvent(ce.RoomID, event.EventMessage, &content) - if err != nil { - ce.Log.Errorfln("Failed to send QR code: %v", err) - return "" - } - - return resp.EventID -} - -func uploadQRCode(ce *WrappedCommandEvent, code string) (id.ContentURI, bool) { - qrCode, err := qrcode.Encode(code, qrcode.Low, 256) - if err != nil { - ce.Log.Errorln("Failed to encode QR code:", err) - ce.Reply("Failed to encode QR code: %v", err) - return id.ContentURI{}, false - } - - resp, err := ce.Bot.UploadBytes(qrCode, "image/png") - if err != nil { - ce.Log.Errorln("Failed to upload QR code:", err) - ce.Reply("Failed to upload QR code: %v", err) - return id.ContentURI{}, false - } - - return resp.ContentURI, true -} - -var cmdLogout = &commands.FullHandler{ - Func: wrapCommand(fnLogout), - Name: "logout", - Help: commands.HelpMeta{ - Section: commands.HelpSectionAuth, - Description: "Forget the stored Discord auth token.", - }, -} - -func fnLogout(ce *WrappedCommandEvent) { - wasLoggedIn := ce.User.DiscordID != "" - ce.User.Logout(false) - if wasLoggedIn { - ce.Reply("Logged out successfully.") - } else { - ce.Reply("You weren't logged in, but data was re-cleared just to be safe.") - } -} - -var cmdPing = &commands.FullHandler{ - Func: wrapCommand(fnPing), - Name: "ping", - Help: commands.HelpMeta{ - Section: commands.HelpSectionAuth, - Description: "Check your connection to Discord", - }, -} - -func fnPing(ce *WrappedCommandEvent) { - if ce.User.Session == nil { - if ce.User.DiscordToken == "" { - ce.Reply("You're not logged in") - } else { - ce.Reply("You have a Discord token stored, but are not connected for some reason 🤔") - } - } else if ce.User.wasDisconnected { - ce.Reply("You're logged in, but the Discord connection seems to be dead 💥") - } else { - ce.Reply("You're logged in as @%s (`%s`)", ce.User.Session.State.User.Username, ce.User.DiscordID) - } -} - -var cmdDisconnect = &commands.FullHandler{ - Func: wrapCommand(fnDisconnect), - Name: "disconnect", - Help: commands.HelpMeta{ - Section: commands.HelpSectionAuth, - Description: "Disconnect from Discord (without logging out)", - }, - RequiresLogin: true, -} - -func fnDisconnect(ce *WrappedCommandEvent) { - if !ce.User.Connected() { - ce.Reply("You're already not connected") - } else if err := ce.User.Disconnect(); err != nil { - ce.Reply("Error while disconnecting: %v", err) - } else { - ce.Reply("Successfully disconnected") - } -} - -var cmdReconnect = &commands.FullHandler{ - Func: wrapCommand(fnReconnect), - Name: "reconnect", - Aliases: []string{"connect"}, - Help: commands.HelpMeta{ - Section: commands.HelpSectionAuth, - Description: "Reconnect to Discord after disconnecting", - }, - RequiresLogin: true, -} - -func fnReconnect(ce *WrappedCommandEvent) { - if ce.User.Connected() { - ce.Reply("You're already connected") - } else if err := ce.User.Connect(); err != nil { - ce.Reply("Error while reconnecting: %v", err) - } else { - ce.Reply("Successfully reconnected") - } -} - -var cmdRejoinSpace = &commands.FullHandler{ - Func: wrapCommand(fnRejoinSpace), - Name: "rejoin-space", - Help: commands.HelpMeta{ - Section: HelpSectionPortalManagement, - Description: "Ask the bridge for an invite to a space you left", - Args: "<_guild ID_/main/dms>", - }, - RequiresLogin: true, -} - -func fnRejoinSpace(ce *WrappedCommandEvent) { - if len(ce.Args) == 0 { - ce.Reply("**Usage**: `$cmdprefix rejoin-space `") - return - } - user := ce.User - if ce.Args[0] == "main" { - user.ensureInvited(nil, user.GetSpaceRoom(), false, true) - ce.Reply("Invited you to your main space ([link](%s))", user.GetSpaceRoom().URI(ce.Bridge.AS.HomeserverDomain).MatrixToURL()) - } else if ce.Args[0] == "dms" { - user.ensureInvited(nil, user.GetDMSpaceRoom(), false, true) - ce.Reply("Invited you to your DM space ([link](%s))", user.GetDMSpaceRoom().URI(ce.Bridge.AS.HomeserverDomain).MatrixToURL()) - } else if _, err := strconv.Atoi(ce.Args[0]); err == nil { - ce.Reply("Rejoining guild spaces is not yet implemented") - } else { - ce.Reply("**Usage**: `$cmdprefix rejoin-space `") - return - } -} - -var roomModerator = event.Type{Type: "fi.mau.discord.admin", Class: event.StateEventType} - -var cmdSetRelay = &commands.FullHandler{ - Func: wrapCommand(fnSetRelay), - Name: "set-relay", - Help: commands.HelpMeta{ - Section: HelpSectionPortalManagement, - Description: "Create or set a relay webhook for a portal", - Args: "[room ID] <​--url URL> OR <​--create [name]>", - }, - RequiresLogin: true, - RequiresEventLevel: roomModerator, -} - -const webhookURLFormat = "https://discord.com/api/webhooks/%d/%s" - -const selectRelayHelp = "Usage: `$cmdprefix [room ID] <​--url URL> OR <​--create [name]>`" - -func fnSetRelay(ce *WrappedCommandEvent) { - portal := ce.Portal - if len(ce.Args) > 0 && strings.HasPrefix(ce.Args[0], "!") { - portal = ce.Bridge.GetPortalByMXID(id.RoomID(ce.Args[0])) - if portal == nil { - ce.Reply("Portal with room ID %s not found", ce.Args[0]) - return - } - if ce.User.PermissionLevel < bridgeconfig.PermissionLevelAdmin { - levels, err := portal.MainIntent().PowerLevels(ce.RoomID) - if err != nil { - ce.ZLog.Warn().Err(err).Msg("Failed to check room power levels") - ce.Reply("Failed to get room power levels to see if you're allowed to use that command") - return - } else if levels.GetUserLevel(ce.User.GetMXID()) < levels.GetEventLevel(roomModerator) { - ce.Reply("You don't have admin rights in that room") - return - } - } - ce.Args = ce.Args[1:] - } else if portal == nil { - ce.Reply("You must either run the command in a portal, or specify an internal room ID as the first parameter") - return - } - log := ce.ZLog.With().Str("channel_id", portal.Key.ChannelID).Logger() - if portal.GuildID == "" { - ce.Reply("Only guild channels can have relays") - return - } else if portal.RelayWebhookID != "" { - webhookMeta, err := relayClient.WebhookWithToken(portal.RelayWebhookID, portal.RelayWebhookSecret) - if err != nil { - log.Warn().Err(err).Msg("Failed to get existing webhook info") - ce.Reply("This channel has a relay webhook set, but getting its info failed: %v", err) - return - } - ce.Reply("This channel already has a relay webhook %s (%s)", webhookMeta.Name, webhookMeta.ID) - return - } else if len(ce.Args) == 0 { - ce.Reply(selectRelayHelp) - return - } - createType := strings.ToLower(strings.TrimLeft(ce.Args[0], "-")) - var webhookMeta *discordgo.Webhook - switch createType { - case "url": - if len(ce.Args) < 2 { - ce.Reply("Usage: `$cmdprefix [room ID] --url ") - return - } - ce.Redact() - var webhookID int64 - var webhookSecret string - _, err := fmt.Sscanf(ce.Args[1], webhookURLFormat, &webhookID, &webhookSecret) - if err != nil { - log.Warn().Str("webhook_url", ce.Args[1]).Err(err).Msg("Failed to parse provided webhook URL") - ce.Reply("Invalid webhook URL") - return - } - webhookMeta, err = relayClient.WebhookWithToken(strconv.FormatInt(webhookID, 10), webhookSecret) - if err != nil { - log.Warn().Err(err).Msg("Failed to get webhook info") - ce.Reply("Failed to get webhook info: %v", err) - return - } - case "create": - perms, err := ce.User.Session.UserChannelPermissions(ce.User.DiscordID, portal.Key.ChannelID, portal.RefererOptIfUser(ce.User.Session, "")...) - if err != nil { - log.Warn().Err(err).Msg("Failed to check user permissions") - ce.Reply("Failed to check if you have permission to create webhooks") - return - } else if perms&discordgo.PermissionManageWebhooks == 0 { - log.Debug().Int64("perms", perms).Msg("User doesn't have permissions to manage webhooks in channel") - ce.Reply("You don't have permission to manage webhooks in that channel") - return - } - name := "mautrix" - if len(ce.Args) > 1 { - name = strings.Join(ce.Args[1:], " ") - } - log.Debug().Str("webhook_name", name).Msg("Creating webhook") - webhookMeta, err = ce.User.Session.WebhookCreate(portal.Key.ChannelID, name, "", portal.RefererOptIfUser(ce.User.Session, "")...) - if err != nil { - log.Warn().Err(err).Msg("Failed to create webhook") - ce.Reply("Failed to create webhook: %v", err) - return - } - default: - ce.Reply(selectRelayHelp) - return - } - if portal.Key.ChannelID != webhookMeta.ChannelID { - log.Debug(). - Str("portal_channel_id", portal.Key.ChannelID). - Str("webhook_channel_id", webhookMeta.ChannelID). - Msg("Provided webhook is for wrong channel") - ce.Reply("That webhook is not for the right channel (expected %s, webhook is for %s)", portal.Key.ChannelID, webhookMeta.ChannelID) - return - } - log.Debug().Str("webhook_id", webhookMeta.ID).Msg("Setting portal relay webhook") - portal.RelayWebhookID = webhookMeta.ID - portal.RelayWebhookSecret = webhookMeta.Token - portal.Update() - ce.Reply("Saved webhook %s (%s) as portal relay webhook", webhookMeta.Name, portal.RelayWebhookID) -} - -var cmdUnsetRelay = &commands.FullHandler{ - Func: wrapCommand(fnUnsetRelay), - Name: "unset-relay", - Help: commands.HelpMeta{ - Section: HelpSectionPortalManagement, - Description: "Disable the relay webhook and optionally delete it on Discord", - Args: "[--delete]", - }, - RequiresPortal: true, - RequiresEventLevel: roomModerator, -} - -func fnUnsetRelay(ce *WrappedCommandEvent) { - if ce.Portal.RelayWebhookID == "" { - ce.Reply("This portal doesn't have a relay webhook") - return - } - if len(ce.Args) > 0 && ce.Args[0] == "--delete" { - err := relayClient.WebhookDeleteWithToken(ce.Portal.RelayWebhookID, ce.Portal.RelayWebhookSecret) - if err != nil { - ce.Reply("Failed to delete webhook: %v", err) - return - } else { - ce.Reply("Successfully deleted webhook") - } - } else { - ce.Reply("Relay webhook disabled") - } - ce.Portal.RelayWebhookID = "" - ce.Portal.RelayWebhookSecret = "" - ce.Portal.Update() -} - -var cmdGuilds = &commands.FullHandler{ - Func: wrapCommand(fnGuilds), - Name: "guilds", - Aliases: []string{"servers", "guild", "server"}, - Help: commands.HelpMeta{ - Section: HelpSectionPortalManagement, - Description: "Guild bridging management", - Args: " [_guild ID_] [...]", - }, - RequiresLogin: true, -} - -const smallGuildsHelp = "**Usage**: `$cmdprefix guilds [guild ID] [...]`" - -const fullGuildsHelp = smallGuildsHelp + ` - -* **help** - View this help message. -* **status** - View the list of guilds and their bridging status. -* **bridge <_guild ID_> [--entire]** - Enable bridging for a guild. The --entire flag auto-creates portals for all channels. -* **bridging-mode <_guild ID_> <_mode_>** - Set the mode for bridging messages and new channels in a guild. -* **unbridge <_guild ID_>** - Unbridge a guild and delete all channel portal rooms.` - -func fnGuilds(ce *WrappedCommandEvent) { - if len(ce.Args) == 0 { - ce.Reply(fullGuildsHelp) - return - } - subcommand := strings.ToLower(ce.Args[0]) - ce.Args = ce.Args[1:] - switch subcommand { - case "status", "list": - fnListGuilds(ce) - case "bridge": - fnBridgeGuild(ce) - case "unbridge", "delete": - fnUnbridgeGuild(ce) - case "bridging-mode", "mode": - fnGuildBridgingMode(ce) - case "help": - ce.Reply(fullGuildsHelp) - default: - ce.Reply("Unknown subcommand `%s`\n\n"+smallGuildsHelp, subcommand) - } -} - -func fnListGuilds(ce *WrappedCommandEvent) { - var items []string - for _, userGuild := range ce.User.GetPortals() { - guild := ce.Bridge.GetGuildByID(userGuild.DiscordID, false) - if guild == nil { - continue - } - var avatarHTML string - if !guild.AvatarURL.IsEmpty() { - avatarHTML = fmt.Sprintf(` `, guild.AvatarURL.String()) - } - items = append(items, fmt.Sprintf("
  • %s%s (%s) - %s
  • ", avatarHTML, html.EscapeString(guild.Name), guild.ID, guild.BridgingMode.Description())) - } - if len(items) == 0 { - ce.Reply("No guilds found") - } else { - ce.ReplyAdvanced(fmt.Sprintf("

    List of guilds:

      %s
    ", strings.Join(items, "")), false, true) - } -} - -func fnBridgeGuild(ce *WrappedCommandEvent) { - if len(ce.Args) == 0 || len(ce.Args) > 2 { - ce.Reply("**Usage**: `$cmdprefix guilds bridge [--entire]") - } else if err := ce.User.bridgeGuild(ce.Args[0], len(ce.Args) == 2 && strings.ToLower(ce.Args[1]) == "--entire"); err != nil { - ce.Reply("Error bridging guild: %v", err) - } else { - ce.Reply("Successfully bridged guild") - } -} - -func fnUnbridgeGuild(ce *WrappedCommandEvent) { - if len(ce.Args) != 1 { - ce.Reply("**Usage**: `$cmdprefix guilds unbridge ") - } else if err := ce.User.unbridgeGuild(ce.Args[0]); err != nil { - ce.Reply("Error unbridging guild: %v", err) - } else { - ce.Reply("Successfully unbridged guild") - } -} - -const availableModes = "Available modes:\n" + - "* `nothing` to never bridge any messages (default when unbridged)\n" + - "* `if-portal-exists` to bridge messages in existing portals, but drop messages in unbridged channels\n" + - "* `create-on-message` to bridge all messages and create portals if necessary on incoming messages (default after bridging)\n" + - "* `everything` to bridge all messages and create portals proactively on bridge startup (default if bridged with `--entire`)\n" - -func fnGuildBridgingMode(ce *WrappedCommandEvent) { - if len(ce.Args) == 0 || len(ce.Args) > 2 { - ce.Reply("**Usage**: `$cmdprefix guilds bridging-mode [mode]`\n\n" + availableModes) - return - } - guild := ce.Bridge.GetGuildByID(ce.Args[0], false) - if guild == nil { - ce.Reply("Guild not found") - return - } - if len(ce.Args) == 1 { - ce.Reply("%s (%s) is currently set to %s (`%s`)\n\n%s", guild.PlainName, guild.ID, guild.BridgingMode.Description(), guild.BridgingMode.String(), availableModes) - return - } - mode := database.ParseGuildBridgingMode(ce.Args[1]) - if mode == database.GuildBridgeInvalid { - ce.Reply("Invalid guild bridging mode `%s`", ce.Args[1]) - return - } - guild.BridgingMode = mode - guild.Update() - ce.Reply("Set guild bridging mode to %s", mode.Description()) -} - -var cmdBridge = &commands.FullHandler{ - Func: wrapCommand(fnBridge), - Name: "bridge", - Help: commands.HelpMeta{ - Section: HelpSectionPortalManagement, - Description: "Bridge this room to a specific Discord channel", - Args: "[--replace[=delete]] <_channel ID_>", - }, - RequiresEventLevel: roomModerator, -} - -func isNumber(str string) bool { - for _, chr := range str { - if chr < '0' || chr > '9' { - return false - } - } - return true -} - -func fnBridge(ce *WrappedCommandEvent) { - if ce.Portal != nil { - ce.Reply("This is already a portal room. Unbridge with `$cmdprefix unbridge` first if you want to link it to a different channel.") - return - } - var channelID string - var unbridgeOld, deleteOld bool - fail := true - for _, arg := range ce.Args { - arg = strings.ToLower(arg) - if arg == "--replace" { - unbridgeOld = true - } else if arg == "--replace=delete" { - unbridgeOld = true - deleteOld = true - } else if channelID == "" && isNumber(arg) { - channelID = arg - fail = false - } else { - fail = true - break - } - } - if fail { - ce.Reply("**Usage**: `$cmdprefix bridge [--replace[=delete]] `") - return - } - portal := ce.User.GetExistingPortalByID(channelID) - if portal == nil { - ce.Reply("Channel not found") - return - } - portal.roomCreateLock.Lock() - defer portal.roomCreateLock.Unlock() - if portal.MXID != "" { - hasUnbridgePermission := ce.User.PermissionLevel >= bridgeconfig.PermissionLevelAdmin - if !hasUnbridgePermission { - levels, err := portal.MainIntent().PowerLevels(portal.MXID) - if errors.Is(err, mautrix.MNotFound) { - ce.ZLog.Debug().Err(err).Msg("Got M_NOT_FOUND trying to get power levels to check if user can unbridge it, assuming the room is gone") - hasUnbridgePermission = true - } else if err != nil { - ce.ZLog.Warn().Err(err).Msg("Failed to check room power levels") - ce.Reply("Failed to get power levels in old room to see if you're allowed to unbridge it") - return - } else { - hasUnbridgePermission = levels.GetUserLevel(ce.User.GetMXID()) >= levels.GetEventLevel(roomModerator) - } - } - if !unbridgeOld || !hasUnbridgePermission { - extraHelp := "Rerun the command with `--replace` or `--replace=delete` to unbridge the old room." - if !hasUnbridgePermission { - extraHelp = "Additionally, you do not have the permissions to unbridge the old room." - } - ce.Reply("That channel is already bridged to [%s](https://matrix.to/#/%s). %s", portal.Name, portal.MXID, extraHelp) - return - } - ce.ZLog.Debug(). - Str("old_room_id", portal.MXID.String()). - Bool("delete", deleteOld). - Msg("Unbridging old room") - portal.removeFromSpace() - portal.cleanup(!deleteOld) - portal.RemoveMXID() - ce.ZLog.Info(). - Str("old_room_id", portal.MXID.String()). - Bool("delete", deleteOld). - Msg("Unbridged old room to make space for new bridge") - } - if portal.Guild != nil && portal.Guild.BridgingMode < database.GuildBridgeIfPortalExists { - ce.ZLog.Debug().Str("guild_id", portal.Guild.ID).Msg("Bumping bridging mode of portal guild to if-portal-exists") - portal.Guild.BridgingMode = database.GuildBridgeIfPortalExists - portal.Guild.Update() - } - ce.ZLog.Debug().Str("channel_id", portal.Key.ChannelID).Msg("Bridging room") - portal.MXID = ce.RoomID - portal.bridge.portalsLock.Lock() - portal.bridge.portalsByMXID[portal.MXID] = portal - portal.bridge.portalsLock.Unlock() - portal.updateRoomName() - portal.updateRoomAvatar() - portal.updateRoomTopic() - portal.updateSpace(ce.User) - portal.UpdateBridgeInfo() - state, err := portal.MainIntent().State(portal.MXID) - if err != nil { - ce.ZLog.Error().Err(err).Msg("Failed to update state cache for room") - } else { - encryptionEvent, isEncrypted := state[event.StateEncryption][""] - portal.Encrypted = isEncrypted && encryptionEvent.Content.AsEncryption().Algorithm == id.AlgorithmMegolmV1 - } - portal.Update() - ce.Reply("Room successfully bridged") - ce.ZLog.Info(). - Str("channel_id", portal.Key.ChannelID). - Bool("encrypted", portal.Encrypted). - Msg("Manual bridging complete") -} - -var cmdUnbridge = &commands.FullHandler{ - Func: wrapCommand(fnUnbridge), - Name: "unbridge", - Help: commands.HelpMeta{ - Section: HelpSectionPortalManagement, - Description: "Unbridge this room from the linked Discord channel", - }, - RequiresPortal: true, - RequiresEventLevel: roomModerator, -} - -var cmdCreatePortal = &commands.FullHandler{ - Func: wrapCommand(fnCreatePortal), - Name: "create-portal", - Help: commands.HelpMeta{ - Section: HelpSectionPortalManagement, - Description: "Create a portal for a specific channel", - Args: "<_channel ID_>", - }, - RequiresLogin: true, -} - -func fnCreatePortal(ce *WrappedCommandEvent) { - meta, err := ce.User.Session.Channel(ce.Args[0]) - if err != nil { - ce.Reply("Failed to get channel info: %v", err) - return - } else if meta == nil { - ce.Reply("Channel not found") - return - } else if !ce.User.channelIsBridgeable(meta) { - ce.Reply("That channel can't be bridged") - return - } - portal := ce.User.GetPortalByMeta(meta) - if portal.Guild != nil && portal.Guild.BridgingMode == database.GuildBridgeNothing { - ce.Reply("That guild is set to not bridge any messages. Bridge the guild with `$cmdprefix guilds bridge %s` first", portal.Guild.ID) - return - } else if portal.MXID != "" { - ce.Reply("That channel is already bridged: [%s](%s)", portal.Name, portal.MXID.URI(portal.bridge.Config.Homeserver.Domain).MatrixToURL()) - return - } - err = portal.CreateMatrixRoom(ce.User, meta) - if err != nil { - ce.Reply("Failed to create portal: %v", err) - } else { - ce.Reply("Portal created: [%s](%s)", portal.Name, portal.MXID.URI(portal.bridge.Config.Homeserver.Domain).MatrixToURL()) - } -} - -var cmdDeletePortal = &commands.FullHandler{ - Func: wrapCommand(fnUnbridge), - Name: "delete-portal", - Help: commands.HelpMeta{ - Section: HelpSectionPortalManagement, - Description: "Unbridge this room and kick all Matrix users", - }, - RequiresPortal: true, - RequiresEventLevel: roomModerator, -} - -func fnUnbridge(ce *WrappedCommandEvent) { - ce.Portal.roomCreateLock.Lock() - defer ce.Portal.roomCreateLock.Unlock() - ce.Portal.removeFromSpace() - ce.Portal.cleanup(ce.Command == "unbridge") - ce.Portal.RemoveMXID() -} - -var cmdDeleteAllPortals = &commands.FullHandler{ - Func: wrapCommand(fnDeleteAllPortals), - Name: "delete-all-portals", - Help: commands.HelpMeta{ - Section: commands.HelpSectionAdmin, - Description: "Delete all portals.", - }, - RequiresAdmin: true, -} - -func fnDeleteAllPortals(ce *WrappedCommandEvent) { - portals := ce.Bridge.GetAllPortals() - guilds := ce.Bridge.GetAllGuilds() - if len(portals) == 0 && len(guilds) == 0 { - ce.Reply("Didn't find any portals") - return - } - - leave := func(mxid id.RoomID, intent *appservice.IntentAPI) { - if len(mxid) > 0 { - _, _ = intent.KickUser(mxid, &mautrix.ReqKickUser{ - Reason: "Deleting portal", - UserID: ce.User.MXID, - }) - } - } - customPuppet := ce.Bridge.GetPuppetByCustomMXID(ce.User.MXID) - if customPuppet != nil && customPuppet.CustomIntent() != nil { - intent := customPuppet.CustomIntent() - leave = func(mxid id.RoomID, _ *appservice.IntentAPI) { - if len(mxid) > 0 { - _, _ = intent.LeaveRoom(mxid) - _, _ = intent.ForgetRoom(mxid) - } - } - } - ce.Reply("Found %d channel portals and %d guild portals, deleting...", len(portals), len(guilds)) - for _, portal := range portals { - portal.Delete() - leave(portal.MXID, portal.MainIntent()) - } - for _, guild := range guilds { - guild.Delete() - leave(guild.MXID, ce.Bot) - } - ce.Reply("Finished deleting portal info. Now cleaning up rooms in background. You'll have to restart the bridge or relogin before rooms can be bridged again.") - - go func() { - for _, portal := range portals { - portal.cleanup(false) - } - ce.Reply("Finished background cleanup of deleted portal rooms.") - }() -} diff --git a/commands_botinteraction.go b/commands_botinteraction.go deleted file mode 100644 index 8dd585a..0000000 --- a/commands_botinteraction.go +++ /dev/null @@ -1,318 +0,0 @@ -// mautrix-discord - A Matrix-Discord puppeting bridge. -// Copyright (C) 2023 Tulir Asokan -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package main - -import ( - "fmt" - "strconv" - "strings" - "time" - - "github.com/bwmarrin/discordgo" - "github.com/google/shlex" - - "maunium.net/go/mautrix/bridge/commands" -) - -var HelpSectionDiscordBots = commands.HelpSection{Name: "Discord bot interaction", Order: 30} - -var cmdCommands = &commands.FullHandler{ - Func: wrapCommand(fnCommands), - Name: "commands", - Aliases: []string{"cmds", "cs"}, - Help: commands.HelpMeta{ - Section: HelpSectionDiscordBots, - Description: "View parameters of bot interaction commands on Discord", - Args: "search <_query_> OR help <_command_>", - }, - RequiresPortal: true, - RequiresLogin: true, -} - -var cmdExec = &commands.FullHandler{ - Func: wrapCommand(fnExec), - Name: "exec", - Aliases: []string{"command", "cmd", "c", "exec", "e"}, - Help: commands.HelpMeta{ - Section: HelpSectionDiscordBots, - Description: "Run bot interaction commands on Discord", - Args: "<_command_> [_arg=value ..._]", - }, - RequiresLogin: true, - RequiresPortal: true, -} - -func (portal *Portal) getCommand(user *User, command string) (*discordgo.ApplicationCommand, error) { - portal.commandsLock.Lock() - defer portal.commandsLock.Unlock() - cmd, ok := portal.commands[command] - if !ok { - results, err := user.Session.ApplicationCommandsSearch(portal.Key.ChannelID, command, portal.RefererOpt("")) - if err != nil { - return nil, err - } - for _, result := range results { - if result.Name == command { - portal.commands[result.Name] = result - cmd = result - break - } - } - if cmd == nil { - return nil, nil - } - } - return cmd, nil -} - -func getCommandOptionTypeName(optType discordgo.ApplicationCommandOptionType) string { - switch optType { - case discordgo.ApplicationCommandOptionSubCommand: - return "subcommand" - case discordgo.ApplicationCommandOptionSubCommandGroup: - return "subcommand group (unsupported)" - case discordgo.ApplicationCommandOptionString: - return "string" - case discordgo.ApplicationCommandOptionInteger: - return "integer" - case discordgo.ApplicationCommandOptionBoolean: - return "boolean" - case discordgo.ApplicationCommandOptionUser: - return "user (unsupported)" - case discordgo.ApplicationCommandOptionChannel: - return "channel (unsupported)" - case discordgo.ApplicationCommandOptionRole: - return "role (unsupported)" - case discordgo.ApplicationCommandOptionMentionable: - return "mentionable (unsupported)" - case discordgo.ApplicationCommandOptionNumber: - return "number" - case discordgo.ApplicationCommandOptionAttachment: - return "attachment (unsupported)" - default: - return fmt.Sprintf("unknown type %d", optType) - } -} - -func parseCommandOptionValue(optType discordgo.ApplicationCommandOptionType, value string) (any, error) { - switch optType { - case discordgo.ApplicationCommandOptionSubCommandGroup: - return nil, fmt.Errorf("subcommand groups aren't supported") - case discordgo.ApplicationCommandOptionString: - return value, nil - case discordgo.ApplicationCommandOptionInteger: - return strconv.ParseInt(value, 10, 64) - case discordgo.ApplicationCommandOptionBoolean: - return strconv.ParseBool(value) - case discordgo.ApplicationCommandOptionUser: - return nil, fmt.Errorf("user options aren't supported") - case discordgo.ApplicationCommandOptionChannel: - return nil, fmt.Errorf("channel options aren't supported") - case discordgo.ApplicationCommandOptionRole: - return nil, fmt.Errorf("role options aren't supported") - case discordgo.ApplicationCommandOptionMentionable: - return nil, fmt.Errorf("mentionable options aren't supported") - case discordgo.ApplicationCommandOptionNumber: - return strconv.ParseFloat(value, 64) - case discordgo.ApplicationCommandOptionAttachment: - return nil, fmt.Errorf("attachment options aren't supported") - default: - return nil, fmt.Errorf("unknown option type %d", optType) - } -} - -func indent(text, with string) string { - split := strings.Split(text, "\n") - for i, part := range split { - split[i] = with + part - } - return strings.Join(split, "\n") -} - -func formatOption(opt *discordgo.ApplicationCommandOption) string { - argText := fmt.Sprintf("* `%s`: %s", opt.Name, getCommandOptionTypeName(opt.Type)) - if strings.ToLower(opt.Description) != opt.Name { - argText += fmt.Sprintf(" - %s", opt.Description) - } - if opt.Required { - argText += " (required)" - } - if len(opt.Options) > 0 { - subopts := make([]string, len(opt.Options)) - for i, subopt := range opt.Options { - subopts[i] = indent(formatOption(subopt), " ") - } - argText += "\n" + strings.Join(subopts, "\n") - } - return argText -} - -func formatCommand(cmd *discordgo.ApplicationCommand) string { - baseText := fmt.Sprintf("$cmdprefix exec %s", cmd.Name) - if len(cmd.Options) > 0 { - args := make([]string, len(cmd.Options)) - argPlaceholder := "[arg=value ...]" - for i, opt := range cmd.Options { - args[i] = formatOption(opt) - if opt.Required { - argPlaceholder = "" - } - } - baseText = fmt.Sprintf("`%s %s` - %s\n%s", baseText, argPlaceholder, cmd.Description, strings.Join(args, "\n")) - } else { - baseText = fmt.Sprintf("`%s` - %s", baseText, cmd.Description) - } - return baseText -} - -func parseCommandOptions(opts []*discordgo.ApplicationCommandOption, subcommands []string, namedArgs map[string]string) (res []*discordgo.ApplicationCommandOptionInput, err error) { - subcommandDone := false - for _, opt := range opts { - optRes := &discordgo.ApplicationCommandOptionInput{ - Type: opt.Type, - Name: opt.Name, - } - if opt.Type == discordgo.ApplicationCommandOptionSubCommand { - if !subcommandDone && len(subcommands) > 0 && subcommands[0] == opt.Name { - subcommandDone = true - optRes.Options, err = parseCommandOptions(opt.Options, subcommands[1:], namedArgs) - if err != nil { - err = fmt.Errorf("error parsing subcommand %s: %v", opt.Name, err) - break - } - subcommands = subcommands[1:] - } else { - continue - } - } else if argVal, ok := namedArgs[opt.Name]; ok { - optRes.Value, err = parseCommandOptionValue(opt.Type, argVal) - if err != nil { - err = fmt.Errorf("error parsing parameter %s: %v", opt.Name, err) - break - } - } else if opt.Required { - switch opt.Type { - case discordgo.ApplicationCommandOptionSubCommandGroup, discordgo.ApplicationCommandOptionUser, - discordgo.ApplicationCommandOptionChannel, discordgo.ApplicationCommandOptionRole, - discordgo.ApplicationCommandOptionMentionable, discordgo.ApplicationCommandOptionAttachment: - err = fmt.Errorf("missing required parameter %s (which is not supported by the bridge)", opt.Name) - default: - err = fmt.Errorf("missing required parameter %s", opt.Name) - } - break - } else { - continue - } - res = append(res, optRes) - } - if len(subcommands) > 0 { - err = fmt.Errorf("unparsed subcommands left over (did you forget quoting for parameters with spaces?)") - } - return -} - -func executeCommand(cmd *discordgo.ApplicationCommand, args []string) (res []*discordgo.ApplicationCommandOptionInput, err error) { - namedArgs := map[string]string{} - n := 0 - for _, arg := range args { - name, value, isNamed := strings.Cut(arg, "=") - if isNamed { - namedArgs[name] = value - } else { - args[n] = arg - n++ - } - } - return parseCommandOptions(cmd.Options, args[:n], namedArgs) -} - -func fnCommands(ce *WrappedCommandEvent) { - if len(ce.Args) < 2 { - ce.Reply("**Usage**: `$cmdprefix commands search <_query_>` OR `$cmdprefix commands help <_command_>`") - return - } - subcmd := strings.ToLower(ce.Args[0]) - if subcmd == "search" { - results, err := ce.User.Session.ApplicationCommandsSearch(ce.Portal.Key.ChannelID, ce.Args[1], ce.Portal.RefererOpt("")) - if err != nil { - ce.Reply("Error searching for commands: %v", err) - return - } - formatted := make([]string, len(results)) - ce.Portal.commandsLock.Lock() - for i, result := range results { - ce.Portal.commands[result.Name] = result - formatted[i] = indent(formatCommand(result), " ") - formatted[i] = "*" + formatted[i][1:] - } - ce.Portal.commandsLock.Unlock() - ce.Reply("Found results:\n" + strings.Join(formatted, "\n")) - } else if subcmd == "help" { - command := strings.ToLower(ce.Args[1]) - cmd, err := ce.Portal.getCommand(ce.User, command) - if err != nil { - ce.Reply("Error searching for commands: %v", err) - } else if cmd == nil { - ce.Reply("Command %q not found", command) - } else { - ce.Reply(formatCommand(cmd)) - } - } -} - -func fnExec(ce *WrappedCommandEvent) { - if len(ce.Args) == 0 { - ce.Reply("**Usage**: `$cmdprefix exec [arg=value ...]`") - return - } - args, err := shlex.Split(ce.RawArgs) - if err != nil { - ce.Reply("Error parsing args with shlex: %v", err) - return - } - command := strings.ToLower(args[0]) - cmd, err := ce.Portal.getCommand(ce.User, command) - if err != nil { - ce.Reply("Error searching for commands: %v", err) - } else if cmd == nil { - ce.Reply("Command %q not found", command) - } else if options, err := executeCommand(cmd, args[1:]); err != nil { - ce.Reply("Error parsing arguments: %v\n\n**Usage:** "+formatCommand(cmd), err) - } else { - nonce := generateNonce() - ce.User.pendingInteractionsLock.Lock() - ce.User.pendingInteractions[nonce] = ce - ce.User.pendingInteractionsLock.Unlock() - err = ce.User.Session.SendInteractions(ce.Portal.GuildID, ce.Portal.Key.ChannelID, cmd, options, nonce, ce.Portal.RefererOpt("")) - if err != nil { - ce.Reply("Error sending interaction: %v", err) - ce.User.pendingInteractionsLock.Lock() - delete(ce.User.pendingInteractions, nonce) - ce.User.pendingInteractionsLock.Unlock() - } else { - go func() { - time.Sleep(10 * time.Second) - ce.User.pendingInteractionsLock.Lock() - if _, stillWaiting := ce.User.pendingInteractions[nonce]; stillWaiting { - delete(ce.User.pendingInteractions, nonce) - ce.Reply("Timed out waiting for interaction success") - } - ce.User.pendingInteractionsLock.Unlock() - }() - } - } -} diff --git a/config/bridge.go b/config/bridge.go deleted file mode 100644 index 2f78ed7..0000000 --- a/config/bridge.go +++ /dev/null @@ -1,239 +0,0 @@ -// mautrix-discord - A Matrix-Discord puppeting bridge. -// Copyright (C) 2022 Tulir Asokan -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package config - -import ( - "errors" - "fmt" - "strings" - "text/template" - - "github.com/bwmarrin/discordgo" - - "maunium.net/go/mautrix/bridge/bridgeconfig" -) - -type BridgeConfig struct { - UsernameTemplate string `yaml:"username_template"` - DisplaynameTemplate string `yaml:"displayname_template"` - ChannelNameTemplate string `yaml:"channel_name_template"` - GuildNameTemplate string `yaml:"guild_name_template"` - PrivateChatPortalMeta string `yaml:"private_chat_portal_meta"` - PrivateChannelCreateLimit int `yaml:"startup_private_channel_create_limit"` - - PortalMessageBuffer int `yaml:"portal_message_buffer"` - - PublicAddress string `yaml:"public_address"` - AvatarProxyKey string `yaml:"avatar_proxy_key"` - - DeliveryReceipts bool `yaml:"delivery_receipts"` - MessageStatusEvents bool `yaml:"message_status_events"` - MessageErrorNotices bool `yaml:"message_error_notices"` - RestrictedRooms bool `yaml:"restricted_rooms"` - AutojoinThreadOnOpen bool `yaml:"autojoin_thread_on_open"` - EmbedFieldsAsTables bool `yaml:"embed_fields_as_tables"` - MuteChannelsOnCreate bool `yaml:"mute_channels_on_create"` - SyncDirectChatList bool `yaml:"sync_direct_chat_list"` - ResendBridgeInfo bool `yaml:"resend_bridge_info"` - CustomEmojiReactions bool `yaml:"custom_emoji_reactions"` - DeletePortalOnChannelDelete bool `yaml:"delete_portal_on_channel_delete"` - DeleteGuildOnLeave bool `yaml:"delete_guild_on_leave"` - FederateRooms bool `yaml:"federate_rooms"` - PrefixWebhookMessages bool `yaml:"prefix_webhook_messages"` - EnableWebhookAvatars bool `yaml:"enable_webhook_avatars"` - UseDiscordCDNUpload bool `yaml:"use_discord_cdn_upload"` - - Proxy string `yaml:"proxy"` - - CacheMedia string `yaml:"cache_media"` - DirectMedia DirectMedia `yaml:"direct_media"` - - AnimatedSticker struct { - Target string `yaml:"target"` - Args struct { - Width int `yaml:"width"` - Height int `yaml:"height"` - FPS int `yaml:"fps"` - } `yaml:"args"` - } `yaml:"animated_sticker"` - - DoublePuppetConfig bridgeconfig.DoublePuppetConfig `yaml:",inline"` - - CommandPrefix string `yaml:"command_prefix"` - ManagementRoomText bridgeconfig.ManagementRoomTexts `yaml:"management_room_text"` - - Backfill struct { - Limits struct { - Initial BackfillLimitPart `yaml:"initial"` - Missed BackfillLimitPart `yaml:"missed"` - } `yaml:"forward_limits"` - MaxGuildMembers int `yaml:"max_guild_members"` - } `yaml:"backfill"` - - Encryption bridgeconfig.EncryptionConfig `yaml:"encryption"` - - Provisioning struct { - Prefix string `yaml:"prefix"` - SharedSecret string `yaml:"shared_secret"` - DebugEndpoints bool `yaml:"debug_endpoints"` - } `yaml:"provisioning"` - - Permissions bridgeconfig.PermissionConfig `yaml:"permissions"` - - usernameTemplate *template.Template `yaml:"-"` - displaynameTemplate *template.Template `yaml:"-"` - channelNameTemplate *template.Template `yaml:"-"` - guildNameTemplate *template.Template `yaml:"-"` -} - -type DirectMedia struct { - Enabled bool `yaml:"enabled"` - ServerName string `yaml:"server_name"` - WellKnownResponse string `yaml:"well_known_response"` - AllowProxy bool `yaml:"allow_proxy"` - ServerKey string `yaml:"server_key"` -} - -type BackfillLimitPart struct { - DM int `yaml:"dm"` - Channel int `yaml:"channel"` - Thread int `yaml:"thread"` -} - -func (bc *BridgeConfig) GetResendBridgeInfo() bool { - return bc.ResendBridgeInfo -} - -func (bc *BridgeConfig) EnableMessageStatusEvents() bool { - return bc.MessageStatusEvents -} - -func (bc *BridgeConfig) EnableMessageErrorNotices() bool { - return bc.MessageErrorNotices -} - -func boolToInt(val bool) int { - if val { - return 1 - } - return 0 -} - -func (bc *BridgeConfig) Validate() error { - _, hasWildcard := bc.Permissions["*"] - _, hasExampleDomain := bc.Permissions["example.com"] - _, hasExampleUser := bc.Permissions["@admin:example.com"] - exampleLen := boolToInt(hasWildcard) + boolToInt(hasExampleUser) + boolToInt(hasExampleDomain) - if len(bc.Permissions) <= exampleLen { - return errors.New("bridge.permissions not configured") - } - return nil -} - -type umBridgeConfig BridgeConfig - -func (bc *BridgeConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { - err := unmarshal((*umBridgeConfig)(bc)) - if err != nil { - return err - } - - bc.usernameTemplate, err = template.New("username").Parse(bc.UsernameTemplate) - if err != nil { - return err - } else if !strings.Contains(bc.FormatUsername("1234567890"), "1234567890") { - return fmt.Errorf("username template is missing user ID placeholder") - } - bc.displaynameTemplate, err = template.New("displayname").Parse(bc.DisplaynameTemplate) - if err != nil { - return err - } - bc.channelNameTemplate, err = template.New("channel_name").Parse(bc.ChannelNameTemplate) - if err != nil { - return err - } - bc.guildNameTemplate, err = template.New("guild_name").Parse(bc.GuildNameTemplate) - if err != nil { - return err - } - - return nil -} - -var _ bridgeconfig.BridgeConfig = (*BridgeConfig)(nil) - -func (bc BridgeConfig) GetDoublePuppetConfig() bridgeconfig.DoublePuppetConfig { - return bc.DoublePuppetConfig -} - -func (bc BridgeConfig) GetEncryptionConfig() bridgeconfig.EncryptionConfig { - return bc.Encryption -} - -func (bc BridgeConfig) GetCommandPrefix() string { - return bc.CommandPrefix -} - -func (bc BridgeConfig) GetManagementRoomTexts() bridgeconfig.ManagementRoomTexts { - return bc.ManagementRoomText -} - -func (bc BridgeConfig) FormatUsername(userID string) string { - var buffer strings.Builder - _ = bc.usernameTemplate.Execute(&buffer, userID) - return buffer.String() -} - -type DisplaynameParams struct { - *discordgo.User - Webhook bool - Application bool -} - -func (bc BridgeConfig) FormatDisplayname(user *discordgo.User, webhook, application bool) string { - var buffer strings.Builder - _ = bc.displaynameTemplate.Execute(&buffer, &DisplaynameParams{ - User: user, - Webhook: webhook, - Application: application, - }) - return buffer.String() -} - -type ChannelNameParams struct { - Name string - ParentName string - GuildName string - NSFW bool - Type discordgo.ChannelType -} - -func (bc BridgeConfig) FormatChannelName(params ChannelNameParams) string { - var buffer strings.Builder - _ = bc.channelNameTemplate.Execute(&buffer, params) - return buffer.String() -} - -type GuildNameParams struct { - Name string -} - -func (bc BridgeConfig) FormatGuildName(params GuildNameParams) string { - var buffer strings.Builder - _ = bc.guildNameTemplate.Execute(&buffer, params) - return buffer.String() -} diff --git a/config/upgrade.go b/config/upgrade.go deleted file mode 100644 index 1c9fe56..0000000 --- a/config/upgrade.go +++ /dev/null @@ -1,151 +0,0 @@ -// mautrix-discord - A Matrix-Discord puppeting bridge. -// Copyright (C) 2023 Tulir Asokan -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package config - -import ( - up "go.mau.fi/util/configupgrade" - "go.mau.fi/util/random" - "maunium.net/go/mautrix/bridge/bridgeconfig" - "maunium.net/go/mautrix/federation" -) - -func DoUpgrade(helper *up.Helper) { - bridgeconfig.Upgrader.DoUpgrade(helper) - - helper.Copy(up.Str, "bridge", "username_template") - helper.Copy(up.Str, "bridge", "displayname_template") - helper.Copy(up.Str, "bridge", "channel_name_template") - helper.Copy(up.Str, "bridge", "guild_name_template") - if legacyPrivateChatPortalMeta, ok := helper.Get(up.Bool, "bridge", "private_chat_portal_meta"); ok { - updatedPrivateChatPortalMeta := "default" - if legacyPrivateChatPortalMeta == "true" { - updatedPrivateChatPortalMeta = "always" - } - helper.Set(up.Str, updatedPrivateChatPortalMeta, "bridge", "private_chat_portal_meta") - } else { - helper.Copy(up.Str, "bridge", "private_chat_portal_meta") - } - helper.Copy(up.Int, "bridge", "startup_private_channel_create_limit") - helper.Copy(up.Str|up.Null, "bridge", "public_address") - if apkey, ok := helper.Get(up.Str, "bridge", "avatar_proxy_key"); !ok || apkey == "generate" { - helper.Set(up.Str, random.String(32), "bridge", "avatar_proxy_key") - } else { - helper.Copy(up.Str, "bridge", "avatar_proxy_key") - } - helper.Copy(up.Int, "bridge", "portal_message_buffer") - helper.Copy(up.Bool, "bridge", "delivery_receipts") - helper.Copy(up.Bool, "bridge", "message_status_events") - helper.Copy(up.Bool, "bridge", "message_error_notices") - helper.Copy(up.Bool, "bridge", "restricted_rooms") - helper.Copy(up.Bool, "bridge", "autojoin_thread_on_open") - helper.Copy(up.Bool, "bridge", "embed_fields_as_tables") - helper.Copy(up.Bool, "bridge", "mute_channels_on_create") - helper.Copy(up.Bool, "bridge", "sync_direct_chat_list") - helper.Copy(up.Bool, "bridge", "resend_bridge_info") - helper.Copy(up.Bool, "bridge", "custom_emoji_reactions") - helper.Copy(up.Bool, "bridge", "delete_portal_on_channel_delete") - helper.Copy(up.Bool, "bridge", "delete_guild_on_leave") - helper.Copy(up.Bool, "bridge", "federate_rooms") - helper.Copy(up.Bool, "bridge", "prefix_webhook_messages") - helper.Copy(up.Bool, "bridge", "enable_webhook_avatars") - helper.Copy(up.Bool, "bridge", "use_discord_cdn_upload") - helper.Copy(up.Str|up.Null, "bridge", "proxy") - helper.Copy(up.Str, "bridge", "cache_media") - helper.Copy(up.Bool, "bridge", "direct_media", "enabled") - helper.Copy(up.Str, "bridge", "direct_media", "server_name") - helper.Copy(up.Str|up.Null, "bridge", "direct_media", "well_known_response") - helper.Copy(up.Bool, "bridge", "direct_media", "allow_proxy") - if serverKey, ok := helper.Get(up.Str, "bridge", "direct_media", "server_key"); !ok || serverKey == "generate" { - serverKey = federation.GenerateSigningKey().SynapseString() - helper.Set(up.Str, serverKey, "bridge", "direct_media", "server_key") - } else { - helper.Copy(up.Str, "bridge", "direct_media", "server_key") - } - helper.Copy(up.Str, "bridge", "animated_sticker", "target") - helper.Copy(up.Int, "bridge", "animated_sticker", "args", "width") - helper.Copy(up.Int, "bridge", "animated_sticker", "args", "height") - helper.Copy(up.Int, "bridge", "animated_sticker", "args", "fps") - helper.Copy(up.Map, "bridge", "double_puppet_server_map") - helper.Copy(up.Bool, "bridge", "double_puppet_allow_discovery") - helper.Copy(up.Map, "bridge", "login_shared_secret_map") - helper.Copy(up.Str, "bridge", "command_prefix") - helper.Copy(up.Str, "bridge", "management_room_text", "welcome") - helper.Copy(up.Str, "bridge", "management_room_text", "welcome_connected") - helper.Copy(up.Str, "bridge", "management_room_text", "welcome_unconnected") - helper.Copy(up.Str|up.Null, "bridge", "management_room_text", "additional_help") - helper.Copy(up.Bool, "bridge", "backfill", "enabled") - helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "initial", "dm") - helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "initial", "channel") - helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "initial", "thread") - helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "missed", "dm") - helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "missed", "channel") - helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "missed", "thread") - helper.Copy(up.Int, "bridge", "backfill", "max_guild_members") - helper.Copy(up.Bool, "bridge", "encryption", "allow") - helper.Copy(up.Bool, "bridge", "encryption", "default") - helper.Copy(up.Bool, "bridge", "encryption", "require") - helper.Copy(up.Bool, "bridge", "encryption", "appservice") - helper.Copy(up.Bool, "bridge", "encryption", "msc4190") - helper.Copy(up.Bool, "bridge", "encryption", "allow_key_sharing") - helper.Copy(up.Bool, "bridge", "encryption", "plaintext_mentions") - helper.Copy(up.Bool, "bridge", "encryption", "delete_keys", "delete_outbound_on_ack") - helper.Copy(up.Bool, "bridge", "encryption", "delete_keys", "dont_store_outbound") - helper.Copy(up.Bool, "bridge", "encryption", "delete_keys", "ratchet_on_decrypt") - helper.Copy(up.Bool, "bridge", "encryption", "delete_keys", "delete_fully_used_on_decrypt") - helper.Copy(up.Bool, "bridge", "encryption", "delete_keys", "delete_prev_on_new_session") - helper.Copy(up.Bool, "bridge", "encryption", "delete_keys", "delete_on_device_delete") - helper.Copy(up.Bool, "bridge", "encryption", "delete_keys", "periodically_delete_expired") - helper.Copy(up.Bool, "bridge", "encryption", "delete_keys", "delete_outdated_inbound") - helper.Copy(up.Str, "bridge", "encryption", "verification_levels", "receive") - helper.Copy(up.Str, "bridge", "encryption", "verification_levels", "send") - helper.Copy(up.Str, "bridge", "encryption", "verification_levels", "share") - helper.Copy(up.Bool, "bridge", "encryption", "rotation", "enable_custom") - helper.Copy(up.Int, "bridge", "encryption", "rotation", "milliseconds") - helper.Copy(up.Int, "bridge", "encryption", "rotation", "messages") - helper.Copy(up.Bool, "bridge", "encryption", "rotation", "disable_device_change_key_rotation") - - helper.Copy(up.Str, "bridge", "provisioning", "prefix") - if secret, ok := helper.Get(up.Str, "bridge", "provisioning", "shared_secret"); !ok || secret == "generate" { - sharedSecret := random.String(64) - helper.Set(up.Str, sharedSecret, "bridge", "provisioning", "shared_secret") - } else { - helper.Copy(up.Str, "bridge", "provisioning", "shared_secret") - } - helper.Copy(up.Bool, "bridge", "provisioning", "debug_endpoints") - - helper.Copy(up.Map, "bridge", "permissions") - //helper.Copy(up.Bool, "bridge", "relay", "enabled") - //helper.Copy(up.Bool, "bridge", "relay", "admin_only") - //helper.Copy(up.Map, "bridge", "relay", "message_formats") -} - -var SpacedBlocks = [][]string{ - {"homeserver", "software"}, - {"appservice"}, - {"appservice", "hostname"}, - {"appservice", "database"}, - {"appservice", "id"}, - {"appservice", "as_token"}, - {"bridge"}, - {"bridge", "command_prefix"}, - {"bridge", "management_room_text"}, - {"bridge", "encryption"}, - {"bridge", "provisioning"}, - {"bridge", "permissions"}, - //{"bridge", "relay"}, - {"logging"}, -} diff --git a/custompuppet.go b/custompuppet.go deleted file mode 100644 index f1c1f05..0000000 --- a/custompuppet.go +++ /dev/null @@ -1,72 +0,0 @@ -package main - -import ( - "maunium.net/go/mautrix/id" -) - -func (puppet *Puppet) SwitchCustomMXID(accessToken string, mxid id.UserID) error { - puppet.CustomMXID = mxid - puppet.AccessToken = accessToken - puppet.Update() - err := puppet.StartCustomMXID(false) - if err != nil { - return err - } - // TODO leave rooms with default puppet - return nil -} - -func (puppet *Puppet) ClearCustomMXID() { - save := puppet.CustomMXID != "" || puppet.AccessToken != "" - puppet.bridge.puppetsLock.Lock() - if puppet.CustomMXID != "" && puppet.bridge.puppetsByCustomMXID[puppet.CustomMXID] == puppet { - delete(puppet.bridge.puppetsByCustomMXID, puppet.CustomMXID) - } - puppet.bridge.puppetsLock.Unlock() - puppet.CustomMXID = "" - puppet.AccessToken = "" - puppet.customIntent = nil - puppet.customUser = nil - if save { - puppet.Update() - } -} - -func (puppet *Puppet) StartCustomMXID(reloginOnFail bool) error { - newIntent, newAccessToken, err := puppet.bridge.DoublePuppet.Setup(puppet.CustomMXID, puppet.AccessToken, reloginOnFail) - if err != nil { - puppet.ClearCustomMXID() - return err - } - puppet.bridge.puppetsLock.Lock() - puppet.bridge.puppetsByCustomMXID[puppet.CustomMXID] = puppet - puppet.bridge.puppetsLock.Unlock() - if puppet.AccessToken != newAccessToken { - puppet.AccessToken = newAccessToken - puppet.Update() - } - puppet.customIntent = newIntent - puppet.customUser = puppet.bridge.GetUserByMXID(puppet.CustomMXID) - return nil -} - -func (user *User) tryAutomaticDoublePuppeting() { - if !user.bridge.Config.CanAutoDoublePuppet(user.MXID) { - return - } - user.log.Debug().Msg("Checking if double puppeting needs to be enabled") - puppet := user.bridge.GetPuppetByID(user.DiscordID) - if len(puppet.CustomMXID) > 0 { - user.log.Debug().Msg("User already has double-puppeting enabled") - // Custom puppet already enabled - return - } - puppet.CustomMXID = user.MXID - err := puppet.StartCustomMXID(true) - if err != nil { - user.log.Warn().Err(err).Msg("Failed to login with shared secret") - } else { - // TODO leave rooms with default puppet - user.log.Debug().Msg("Successfully automatically enabled custom puppet") - } -} diff --git a/database/database.go b/database/database.go deleted file mode 100644 index a12bab6..0000000 --- a/database/database.go +++ /dev/null @@ -1,76 +0,0 @@ -package database - -import ( - _ "embed" - - _ "github.com/lib/pq" - _ "github.com/mattn/go-sqlite3" - "go.mau.fi/util/dbutil" - "maunium.net/go/maulogger/v2" - - "go.mau.fi/mautrix-discord/database/upgrades" -) - -type Database struct { - *dbutil.Database - - User *UserQuery - Portal *PortalQuery - Puppet *PuppetQuery - Message *MessageQuery - Thread *ThreadQuery - Reaction *ReactionQuery - Guild *GuildQuery - Role *RoleQuery - File *FileQuery -} - -func New(baseDB *dbutil.Database, log maulogger.Logger) *Database { - db := &Database{Database: baseDB} - db.UpgradeTable = upgrades.Table - db.User = &UserQuery{ - db: db, - log: log.Sub("User"), - } - db.Portal = &PortalQuery{ - db: db, - log: log.Sub("Portal"), - } - db.Puppet = &PuppetQuery{ - db: db, - log: log.Sub("Puppet"), - } - db.Message = &MessageQuery{ - db: db, - log: log.Sub("Message"), - } - db.Thread = &ThreadQuery{ - db: db, - log: log.Sub("Thread"), - } - db.Reaction = &ReactionQuery{ - db: db, - log: log.Sub("Reaction"), - } - db.Guild = &GuildQuery{ - db: db, - log: log.Sub("Guild"), - } - db.Role = &RoleQuery{ - db: db, - log: log.Sub("Role"), - } - db.File = &FileQuery{ - db: db, - log: log.Sub("File"), - } - return db -} - -func strPtr[T ~string](val T) *string { - if val == "" { - return nil - } - valStr := string(val) - return &valStr -} diff --git a/database/file.go b/database/file.go deleted file mode 100644 index 2ee926f..0000000 --- a/database/file.go +++ /dev/null @@ -1,138 +0,0 @@ -package database - -import ( - "database/sql" - "encoding/json" - "errors" - "time" - - "go.mau.fi/util/dbutil" - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/crypto/attachment" - "maunium.net/go/mautrix/id" -) - -type FileQuery struct { - db *Database - log log.Logger -} - -// language=postgresql -const ( - fileSelect = "SELECT url, encrypted, mxc, id, emoji_name, size, width, height, mime_type, decryption_info, timestamp FROM discord_file" - fileInsert = ` - INSERT INTO discord_file (url, encrypted, mxc, id, emoji_name, size, width, height, mime_type, decryption_info, timestamp) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) - ` -) - -func (fq *FileQuery) New() *File { - return &File{ - db: fq.db, - log: fq.log, - } -} - -func (fq *FileQuery) Get(url string, encrypted bool) *File { - query := fileSelect + " WHERE url=$1 AND encrypted=$2" - return fq.New().Scan(fq.db.QueryRow(query, url, encrypted)) -} - -func (fq *FileQuery) GetEmojiByMXC(mxc id.ContentURI) *File { - query := fileSelect + " WHERE mxc=$1 AND emoji_name<>'' LIMIT 1" - return fq.New().Scan(fq.db.QueryRow(query, mxc.String())) -} - -type File struct { - db *Database - log log.Logger - - URL string - Encrypted bool - MXC id.ContentURI - - ID string - EmojiName string - - Size int - Width int - Height int - MimeType string - - DecryptionInfo *attachment.EncryptedFile - Timestamp time.Time -} - -func (f *File) Scan(row dbutil.Scannable) *File { - var fileID, emojiName, decryptionInfo sql.NullString - var width, height sql.NullInt32 - var timestamp int64 - var mxc string - err := row.Scan(&f.URL, &f.Encrypted, &mxc, &fileID, &emojiName, &f.Size, &width, &height, &f.MimeType, &decryptionInfo, ×tamp) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - f.log.Errorln("Database scan failed:", err) - panic(err) - } - return nil - } - f.ID = fileID.String - f.EmojiName = emojiName.String - f.Timestamp = time.UnixMilli(timestamp).UTC() - f.Width = int(width.Int32) - f.Height = int(height.Int32) - f.MXC, err = id.ParseContentURI(mxc) - if err != nil { - f.log.Errorfln("Failed to parse content URI %s: %v", mxc, err) - panic(err) - } - if decryptionInfo.Valid { - err = json.Unmarshal([]byte(decryptionInfo.String), &f.DecryptionInfo) - if err != nil { - f.log.Errorfln("Failed to unmarshal decryption info of %v: %v", f.MXC, err) - panic(err) - } - } - return f -} - -func positiveIntToNullInt32(val int) (ptr sql.NullInt32) { - if val > 0 { - ptr.Valid = true - ptr.Int32 = int32(val) - } - return -} - -func (f *File) Insert(txn dbutil.Execable) { - if txn == nil { - txn = f.db - } - var decryptionInfoStr sql.NullString - if f.DecryptionInfo != nil { - decryptionInfo, err := json.Marshal(f.DecryptionInfo) - if err != nil { - f.log.Warnfln("Failed to marshal decryption info of %v: %v", f.MXC, err) - panic(err) - } - decryptionInfoStr.Valid = true - decryptionInfoStr.String = string(decryptionInfo) - } - _, err := txn.Exec(fileInsert, - f.URL, f.Encrypted, f.MXC.String(), strPtr(f.ID), strPtr(f.EmojiName), f.Size, - positiveIntToNullInt32(f.Width), positiveIntToNullInt32(f.Height), f.MimeType, - decryptionInfoStr, f.Timestamp.UnixMilli(), - ) - if err != nil { - f.log.Warnfln("Failed to insert copied file %v: %v", f.MXC, err) - panic(err) - } -} - -func (f *File) Delete() { - _, err := f.db.Exec("DELETE FROM discord_file WHERE url=$1 AND encrypted=$2", f.URL, f.Encrypted) - if err != nil { - f.log.Warnfln("Failed to delete copied file %v: %v", f.MXC, err) - panic(err) - } -} diff --git a/database/guild.go b/database/guild.go deleted file mode 100644 index 70976a5..0000000 --- a/database/guild.go +++ /dev/null @@ -1,194 +0,0 @@ -package database - -import ( - "database/sql" - "errors" - "fmt" - "strings" - - "go.mau.fi/util/dbutil" - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/id" -) - -type GuildBridgingMode int - -const ( - // GuildBridgeNothing tells the bridge to never bridge messages, not even checking if a portal exists. - GuildBridgeNothing GuildBridgingMode = iota - // GuildBridgeIfPortalExists tells the bridge to bridge messages in channels that already have portals. - GuildBridgeIfPortalExists - // GuildBridgeCreateOnMessage tells the bridge to create portals as soon as a message is received. - GuildBridgeCreateOnMessage - // GuildBridgeEverything tells the bridge to proactively create portals on startup and when receiving channel create notifications. - GuildBridgeEverything - - GuildBridgeInvalid GuildBridgingMode = -1 -) - -func ParseGuildBridgingMode(str string) GuildBridgingMode { - str = strings.ToLower(str) - str = strings.ReplaceAll(str, "-", "") - str = strings.ReplaceAll(str, "_", "") - switch str { - case "nothing", "0": - return GuildBridgeNothing - case "ifportalexists", "1": - return GuildBridgeIfPortalExists - case "createonmessage", "2": - return GuildBridgeCreateOnMessage - case "everything", "3": - return GuildBridgeEverything - default: - return GuildBridgeInvalid - } -} - -func (gbm GuildBridgingMode) String() string { - switch gbm { - case GuildBridgeNothing: - return "nothing" - case GuildBridgeIfPortalExists: - return "if-portal-exists" - case GuildBridgeCreateOnMessage: - return "create-on-message" - case GuildBridgeEverything: - return "everything" - default: - return "" - } -} - -func (gbm GuildBridgingMode) Description() string { - switch gbm { - case GuildBridgeNothing: - return "never bridge messages" - case GuildBridgeIfPortalExists: - return "bridge messages in existing portals" - case GuildBridgeCreateOnMessage: - return "bridge all messages and create portals on first message" - case GuildBridgeEverything: - return "bridge all messages and create portals proactively" - default: - return "" - } -} - -type GuildQuery struct { - db *Database - log log.Logger -} - -const ( - guildSelect = "SELECT dcid, mxid, plain_name, name, name_set, avatar, avatar_url, avatar_set, bridging_mode FROM guild" -) - -func (gq *GuildQuery) New() *Guild { - return &Guild{ - db: gq.db, - log: gq.log, - } -} - -func (gq *GuildQuery) GetByID(dcid string) *Guild { - query := guildSelect + " WHERE dcid=$1" - return gq.New().Scan(gq.db.QueryRow(query, dcid)) -} - -func (gq *GuildQuery) GetByMXID(mxid id.RoomID) *Guild { - query := guildSelect + " WHERE mxid=$1" - return gq.New().Scan(gq.db.QueryRow(query, mxid)) -} - -func (gq *GuildQuery) GetAll() []*Guild { - rows, err := gq.db.Query(guildSelect) - if err != nil { - gq.log.Errorln("Failed to query guilds:", err) - return nil - } - - var guilds []*Guild - for rows.Next() { - guild := gq.New().Scan(rows) - if guild != nil { - guilds = append(guilds, guild) - } - } - - return guilds -} - -type Guild struct { - db *Database - log log.Logger - - ID string - MXID id.RoomID - PlainName string - Name string - NameSet bool - Avatar string - AvatarURL id.ContentURI - AvatarSet bool - - BridgingMode GuildBridgingMode -} - -func (g *Guild) Scan(row dbutil.Scannable) *Guild { - var mxid sql.NullString - var avatarURL string - err := row.Scan(&g.ID, &mxid, &g.PlainName, &g.Name, &g.NameSet, &g.Avatar, &avatarURL, &g.AvatarSet, &g.BridgingMode) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - g.log.Errorln("Database scan failed:", err) - panic(err) - } - - return nil - } - if g.BridgingMode < GuildBridgeNothing || g.BridgingMode > GuildBridgeEverything { - panic(fmt.Errorf("invalid guild bridging mode %d in guild %s", g.BridgingMode, g.ID)) - } - g.MXID = id.RoomID(mxid.String) - g.AvatarURL, _ = id.ParseContentURI(avatarURL) - return g -} - -func (g *Guild) mxidPtr() *id.RoomID { - if g.MXID != "" { - return &g.MXID - } - return nil -} - -func (g *Guild) Insert() { - query := ` - INSERT INTO guild (dcid, mxid, plain_name, name, name_set, avatar, avatar_url, avatar_set, bridging_mode) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) - ` - _, err := g.db.Exec(query, g.ID, g.mxidPtr(), g.PlainName, g.Name, g.NameSet, g.Avatar, g.AvatarURL.String(), g.AvatarSet, g.BridgingMode) - if err != nil { - g.log.Warnfln("Failed to insert %s: %v", g.ID, err) - panic(err) - } -} - -func (g *Guild) Update() { - query := ` - UPDATE guild SET mxid=$1, plain_name=$2, name=$3, name_set=$4, avatar=$5, avatar_url=$6, avatar_set=$7, bridging_mode=$8 - WHERE dcid=$9 - ` - _, err := g.db.Exec(query, g.mxidPtr(), g.PlainName, g.Name, g.NameSet, g.Avatar, g.AvatarURL.String(), g.AvatarSet, g.BridgingMode, g.ID) - if err != nil { - g.log.Warnfln("Failed to update %s: %v", g.ID, err) - panic(err) - } -} - -func (g *Guild) Delete() { - _, err := g.db.Exec("DELETE FROM guild WHERE dcid=$1", g.ID) - if err != nil { - g.log.Warnfln("Failed to delete %s: %v", g.ID, err) - panic(err) - } -} diff --git a/database/message.go b/database/message.go deleted file mode 100644 index c38483c..0000000 --- a/database/message.go +++ /dev/null @@ -1,258 +0,0 @@ -package database - -import ( - "database/sql" - "errors" - "fmt" - "strings" - "time" - - "go.mau.fi/util/dbutil" - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/id" -) - -type MessageQuery struct { - db *Database - log log.Logger -} - -const ( - messageSelect = "SELECT dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid, sender_mxid FROM message" -) - -func (mq *MessageQuery) New() *Message { - return &Message{ - db: mq.db, - log: mq.log, - } -} - -func (mq *MessageQuery) scanAll(rows dbutil.Rows, err error) []*Message { - if err != nil { - mq.log.Warnfln("Failed to query many messages: %v", err) - panic(err) - } else if rows == nil { - return nil - } - - var messages []*Message - for rows.Next() { - messages = append(messages, mq.New().Scan(rows)) - } - - return messages -} - -func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) []*Message { - query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 ORDER BY dc_attachment_id ASC" - return mq.scanAll(mq.db.Query(query, key.ChannelID, key.Receiver, discordID)) -} - -func (mq *MessageQuery) GetFirstByDiscordID(key PortalKey, discordID string) *Message { - query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 ORDER BY dc_attachment_id ASC LIMIT 1" - return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID)) -} - -func (mq *MessageQuery) GetLastByDiscordID(key PortalKey, discordID string) *Message { - query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 ORDER BY dc_attachment_id DESC LIMIT 1" - return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID)) -} - -func (mq *MessageQuery) GetClosestBefore(key PortalKey, threadID string, ts time.Time) *Message { - query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 AND timestamp<=$4 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1" - return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID, ts.UnixMilli())) -} - -func (mq *MessageQuery) GetLastInThread(key PortalKey, threadID string) *Message { - query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1" - return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID)) -} - -func (mq *MessageQuery) GetLast(key PortalKey) *Message { - query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 ORDER BY timestamp DESC LIMIT 1" - return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver)) -} - -func (mq *MessageQuery) DeleteAll(key PortalKey) { - query := "DELETE FROM message WHERE dc_chan_id=$1 AND dc_chan_receiver=$2" - _, err := mq.db.Exec(query, key.ChannelID, key.Receiver) - if err != nil { - mq.log.Warnfln("Failed to delete messages of %s: %v", key, err) - panic(err) - } -} - -func (mq *MessageQuery) GetByMXID(key PortalKey, mxid id.EventID) *Message { - query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND mxid=$3" - - row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, mxid) - if row == nil { - return nil - } - - return mq.New().Scan(row) -} - -func (mq *MessageQuery) MassInsert(key PortalKey, msgs []Message) { - if len(msgs) == 0 { - return - } - valueStringFormat := "($%d, $%d, $1, $2, $%d, $%d, $%d, $%d, $%d, $%d)" - if mq.db.Dialect == dbutil.SQLite { - valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?") - } - params := make([]interface{}, 2+len(msgs)*8) - placeholders := make([]string, len(msgs)) - params[0] = key.ChannelID - params[1] = key.Receiver - for i, msg := range msgs { - baseIndex := 2 + i*8 - params[baseIndex] = msg.DiscordID - params[baseIndex+1] = msg.AttachmentID - params[baseIndex+2] = msg.SenderID - params[baseIndex+3] = msg.Timestamp.UnixMilli() - params[baseIndex+4] = msg.editTimestampVal() - params[baseIndex+5] = msg.ThreadID - params[baseIndex+6] = msg.MXID - params[baseIndex+7] = msg.SenderMXID.String() - placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7, baseIndex+8) - } - _, err := mq.db.Exec(fmt.Sprintf(messageMassInsertTemplate, strings.Join(placeholders, ", ")), params...) - if err != nil { - mq.log.Warnfln("Failed to insert %d messages: %v", len(msgs), err) - panic(err) - } -} - -type Message struct { - db *Database - log log.Logger - - DiscordID string - AttachmentID string - Channel PortalKey - SenderID string - Timestamp time.Time - EditTimestamp time.Time - ThreadID string - - MXID id.EventID - SenderMXID id.UserID -} - -func (m *Message) DiscordProtoChannelID() string { - if m.ThreadID != "" { - return m.ThreadID - } else { - return m.Channel.ChannelID - } -} - -func (m *Message) Scan(row dbutil.Scannable) *Message { - var ts, editTS int64 - - err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &editTS, &m.ThreadID, &m.MXID, &m.SenderMXID) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - m.log.Errorln("Database scan failed:", err) - panic(err) - } - - return nil - } - - if ts != 0 { - m.Timestamp = time.UnixMilli(ts).UTC() - } - if editTS != 0 { - m.EditTimestamp = time.Unix(0, editTS).UTC() - } - - return m -} - -const messageInsertQuery = ` - INSERT INTO message ( - dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid, sender_mxid - ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) -` - -var messageMassInsertTemplate = strings.Replace(messageInsertQuery, "($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)", "%s", 1) - -type MessagePart struct { - AttachmentID string - MXID id.EventID -} - -func (m *Message) editTimestampVal() int64 { - if m.EditTimestamp.IsZero() { - return 0 - } - return m.EditTimestamp.UnixNano() -} - -func (m *Message) MassInsertParts(msgs []MessagePart) { - if len(msgs) == 0 { - return - } - valueStringFormat := "($1, $%d, $2, $3, $4, $5, $6, $7, $%d, $8)" - if m.db.Dialect == dbutil.SQLite { - valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?") - } - params := make([]interface{}, 8+len(msgs)*2) - placeholders := make([]string, len(msgs)) - params[0] = m.DiscordID - params[1] = m.Channel.ChannelID - params[2] = m.Channel.Receiver - params[3] = m.SenderID - params[4] = m.Timestamp.UnixMilli() - params[5] = m.editTimestampVal() - params[6] = m.ThreadID - params[7] = m.SenderMXID.String() - for i, msg := range msgs { - params[8+i*2] = msg.AttachmentID - params[8+i*2+1] = msg.MXID - placeholders[i] = fmt.Sprintf(valueStringFormat, 8+i*2+1, 8+i*2+2) - } - _, err := m.db.Exec(fmt.Sprintf(messageMassInsertTemplate, strings.Join(placeholders, ", ")), params...) - if err != nil { - m.log.Warnfln("Failed to insert %d parts of %s@%s: %v", len(msgs), m.DiscordID, m.Channel, err) - panic(err) - } -} - -func (m *Message) Insert() { - _, err := m.db.Exec(messageInsertQuery, - m.DiscordID, m.AttachmentID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID, - m.Timestamp.UnixMilli(), m.editTimestampVal(), m.ThreadID, m.MXID, m.SenderMXID.String()) - - if err != nil { - m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err) - panic(err) - } -} - -const editUpdateQuery = ` - UPDATE message - SET dc_edit_timestamp=$1 - WHERE dcid=$2 AND dc_attachment_id=$3 AND dc_chan_id=$4 AND dc_chan_receiver=$5 AND dc_edit_timestamp<$1 -` - -func (m *Message) UpdateEditTimestamp(ts time.Time) { - _, err := m.db.Exec(editUpdateQuery, ts.UnixNano(), m.DiscordID, m.AttachmentID, m.Channel.ChannelID, m.Channel.Receiver) - if err != nil { - m.log.Warnfln("Failed to update edit timestamp of %s@%s: %v", m.DiscordID, m.Channel, err) - panic(err) - } -} - -func (m *Message) Delete() { - query := "DELETE FROM message WHERE dcid=$1 AND dc_chan_id=$2 AND dc_chan_receiver=$3 AND dc_attachment_id=$4" - _, err := m.db.Exec(query, m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.AttachmentID) - if err != nil { - m.log.Warnfln("Failed to delete %q of %s@%s: %v", m.AttachmentID, m.DiscordID, m.Channel, err) - panic(err) - } -} diff --git a/database/portal.go b/database/portal.go deleted file mode 100644 index 3c6a8da..0000000 --- a/database/portal.go +++ /dev/null @@ -1,210 +0,0 @@ -package database - -import ( - "database/sql" - - "github.com/bwmarrin/discordgo" - "go.mau.fi/util/dbutil" - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/id" -) - -// language=postgresql -const ( - portalSelect = ` - SELECT dcid, receiver, type, other_user_id, dc_guild_id, dc_parent_id, mxid, - plain_name, name, name_set, friend_nick, topic, topic_set, avatar, avatar_url, avatar_set, - encrypted, in_space, first_event_id, relay_webhook_id, relay_webhook_secret - FROM portal - ` -) - -type PortalKey struct { - ChannelID string - Receiver string -} - -func NewPortalKey(channelID, receiver string) PortalKey { - return PortalKey{ - ChannelID: channelID, - Receiver: receiver, - } -} - -func (key PortalKey) String() string { - if key.Receiver == "" { - return key.ChannelID - } - return key.ChannelID + "-" + key.Receiver -} - -type PortalQuery struct { - db *Database - log log.Logger -} - -func (pq *PortalQuery) New() *Portal { - return &Portal{ - db: pq.db, - log: pq.log, - } -} - -func (pq *PortalQuery) GetAll() []*Portal { - return pq.getAll(portalSelect) -} - -func (pq *PortalQuery) GetAllInGuild(guildID string) []*Portal { - return pq.getAll(portalSelect+" WHERE dc_guild_id=$1", guildID) -} - -func (pq *PortalQuery) GetByID(key PortalKey) *Portal { - return pq.get(portalSelect+" WHERE dcid=$1 AND (receiver=$2 OR receiver='')", key.ChannelID, key.Receiver) -} - -func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal { - return pq.get(portalSelect+" WHERE mxid=$1", mxid) -} - -func (pq *PortalQuery) FindPrivateChatBetween(id, receiver string) *Portal { - return pq.get(portalSelect+" WHERE other_user_id=$1 AND receiver=$2 AND type=$3", id, receiver, discordgo.ChannelTypeDM) -} - -func (pq *PortalQuery) FindPrivateChatsWith(id string) []*Portal { - return pq.getAll(portalSelect+" WHERE other_user_id=$1 AND type=$2", id, discordgo.ChannelTypeDM) -} - -func (pq *PortalQuery) FindPrivateChatsOf(receiver string) []*Portal { - query := portalSelect + " portal WHERE receiver=$1 AND type=$2;" - - return pq.getAll(query, receiver, discordgo.ChannelTypeDM) -} - -func (pq *PortalQuery) getAll(query string, args ...interface{}) []*Portal { - rows, err := pq.db.Query(query, args...) - if err != nil || rows == nil { - return nil - } - defer rows.Close() - - var portals []*Portal - for rows.Next() { - portals = append(portals, pq.New().Scan(rows)) - } - - return portals -} - -func (pq *PortalQuery) get(query string, args ...interface{}) *Portal { - return pq.New().Scan(pq.db.QueryRow(query, args...)) -} - -type Portal struct { - db *Database - log log.Logger - - Key PortalKey - Type discordgo.ChannelType - OtherUserID string - ParentID string - GuildID string - - MXID id.RoomID - - PlainName string - Name string - NameSet bool - FriendNick bool - Topic string - TopicSet bool - Avatar string - AvatarURL id.ContentURI - AvatarSet bool - Encrypted bool - InSpace id.RoomID - - FirstEventID id.EventID - - RelayWebhookID string - RelayWebhookSecret string -} - -func (p *Portal) Scan(row dbutil.Scannable) *Portal { - var otherUserID, guildID, parentID, mxid, firstEventID, relayWebhookID, relayWebhookSecret sql.NullString - var chanType int32 - var avatarURL string - - err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &chanType, &otherUserID, &guildID, &parentID, - &mxid, &p.PlainName, &p.Name, &p.NameSet, &p.FriendNick, &p.Topic, &p.TopicSet, &p.Avatar, &avatarURL, &p.AvatarSet, - &p.Encrypted, &p.InSpace, &firstEventID, &relayWebhookID, &relayWebhookSecret) - - if err != nil { - if err != sql.ErrNoRows { - p.log.Errorln("Database scan failed:", err) - panic(err) - } - - return nil - } - - p.MXID = id.RoomID(mxid.String) - p.OtherUserID = otherUserID.String - p.GuildID = guildID.String - p.ParentID = parentID.String - p.Type = discordgo.ChannelType(chanType) - p.FirstEventID = id.EventID(firstEventID.String) - p.AvatarURL, _ = id.ParseContentURI(avatarURL) - p.RelayWebhookID = relayWebhookID.String - p.RelayWebhookSecret = relayWebhookSecret.String - - return p -} - -func (p *Portal) Insert() { - query := ` - INSERT INTO portal (dcid, receiver, type, other_user_id, dc_guild_id, dc_parent_id, mxid, - plain_name, name, name_set, friend_nick, topic, topic_set, avatar, avatar_url, avatar_set, - encrypted, in_space, first_event_id, relay_webhook_id, relay_webhook_secret) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21) - ` - _, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.Type, - strPtr(p.OtherUserID), strPtr(p.GuildID), strPtr(p.ParentID), strPtr(string(p.MXID)), - p.PlainName, p.Name, p.NameSet, p.FriendNick, p.Topic, p.TopicSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet, - p.Encrypted, p.InSpace, p.FirstEventID.String(), strPtr(p.RelayWebhookID), strPtr(p.RelayWebhookSecret)) - - if err != nil { - p.log.Warnfln("Failed to insert %s: %v", p.Key, err) - panic(err) - } -} - -func (p *Portal) Update() { - query := ` - UPDATE portal - SET type=$1, other_user_id=$2, dc_guild_id=$3, dc_parent_id=$4, mxid=$5, - plain_name=$6, name=$7, name_set=$8, friend_nick=$9, topic=$10, topic_set=$11, - avatar=$12, avatar_url=$13, avatar_set=$14, encrypted=$15, in_space=$16, first_event_id=$17, - relay_webhook_id=$18, relay_webhook_secret=$19 - WHERE dcid=$20 AND receiver=$21 - ` - _, err := p.db.Exec(query, - p.Type, strPtr(p.OtherUserID), strPtr(p.GuildID), strPtr(p.ParentID), strPtr(string(p.MXID)), - p.PlainName, p.Name, p.NameSet, p.FriendNick, p.Topic, p.TopicSet, - p.Avatar, p.AvatarURL.String(), p.AvatarSet, p.Encrypted, p.InSpace, p.FirstEventID.String(), - strPtr(p.RelayWebhookID), strPtr(p.RelayWebhookSecret), - p.Key.ChannelID, p.Key.Receiver) - - if err != nil { - p.log.Warnfln("Failed to update %s: %v", p.Key, err) - panic(err) - } -} - -func (p *Portal) Delete() { - query := "DELETE FROM portal WHERE dcid=$1 AND receiver=$2" - _, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver) - if err != nil { - p.log.Warnfln("Failed to delete %s: %v", p.Key, err) - panic(err) - } -} diff --git a/database/puppet.go b/database/puppet.go deleted file mode 100644 index d6080c7..0000000 --- a/database/puppet.go +++ /dev/null @@ -1,151 +0,0 @@ -package database - -import ( - "database/sql" - - "go.mau.fi/util/dbutil" - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/id" -) - -const ( - puppetSelect = "SELECT id, name, name_set, avatar, avatar_url, avatar_set," + - " contact_info_set, global_name, username, discriminator, is_bot, is_webhook, is_application, custom_mxid, access_token, next_batch" + - " FROM puppet " -) - -type PuppetQuery struct { - db *Database - log log.Logger -} - -func (pq *PuppetQuery) New() *Puppet { - return &Puppet{ - db: pq.db, - log: pq.log, - } -} - -func (pq *PuppetQuery) Get(id string) *Puppet { - return pq.get(puppetSelect+" WHERE id=$1", id) -} - -func (pq *PuppetQuery) GetByCustomMXID(mxid id.UserID) *Puppet { - return pq.get(puppetSelect+" WHERE custom_mxid=$1", mxid) -} - -func (pq *PuppetQuery) get(query string, args ...interface{}) *Puppet { - return pq.New().Scan(pq.db.QueryRow(query, args...)) -} - -func (pq *PuppetQuery) GetAll() []*Puppet { - return pq.getAll(puppetSelect) -} - -func (pq *PuppetQuery) GetAllWithCustomMXID() []*Puppet { - return pq.getAll(puppetSelect + " WHERE custom_mxid<>''") -} - -func (pq *PuppetQuery) getAll(query string, args ...interface{}) []*Puppet { - rows, err := pq.db.Query(query, args...) - if err != nil || rows == nil { - return nil - } - defer rows.Close() - - var puppets []*Puppet - for rows.Next() { - puppets = append(puppets, pq.New().Scan(rows)) - } - - return puppets -} - -type Puppet struct { - db *Database - log log.Logger - - ID string - Name string - NameSet bool - Avatar string - AvatarURL id.ContentURI - AvatarSet bool - - ContactInfoSet bool - - GlobalName string - Username string - Discriminator string - IsBot bool - IsWebhook bool - IsApplication bool - - CustomMXID id.UserID - AccessToken string - NextBatch string -} - -func (p *Puppet) Scan(row dbutil.Scannable) *Puppet { - var avatarURL string - var customMXID, accessToken, nextBatch sql.NullString - - err := row.Scan(&p.ID, &p.Name, &p.NameSet, &p.Avatar, &avatarURL, &p.AvatarSet, &p.ContactInfoSet, - &p.GlobalName, &p.Username, &p.Discriminator, &p.IsBot, &p.IsWebhook, &p.IsApplication, &customMXID, &accessToken, &nextBatch) - - if err != nil { - if err != sql.ErrNoRows { - p.log.Errorln("Database scan failed:", err) - panic(err) - } - - return nil - } - - p.AvatarURL, _ = id.ParseContentURI(avatarURL) - p.CustomMXID = id.UserID(customMXID.String) - p.AccessToken = accessToken.String - p.NextBatch = nextBatch.String - - return p -} - -func (p *Puppet) Insert() { - query := ` - INSERT INTO puppet ( - id, name, name_set, avatar, avatar_url, avatar_set, contact_info_set, - global_name, username, discriminator, is_bot, is_webhook, is_application, - custom_mxid, access_token, next_batch - ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) - ` - _, err := p.db.Exec(query, p.ID, p.Name, p.NameSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet, p.ContactInfoSet, - p.GlobalName, p.Username, p.Discriminator, p.IsBot, p.IsWebhook, p.IsApplication, - strPtr(p.CustomMXID), strPtr(p.AccessToken), strPtr(p.NextBatch)) - - if err != nil { - p.log.Warnfln("Failed to insert %s: %v", p.ID, err) - panic(err) - } -} - -func (p *Puppet) Update() { - query := ` - UPDATE puppet SET name=$1, name_set=$2, avatar=$3, avatar_url=$4, avatar_set=$5, contact_info_set=$6, - global_name=$7, username=$8, discriminator=$9, is_bot=$10, is_webhook=$11, is_application=$12, - custom_mxid=$13, access_token=$14, next_batch=$15 - WHERE id=$16 - ` - _, err := p.db.Exec( - query, - p.Name, p.NameSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet, p.ContactInfoSet, - p.GlobalName, p.Username, p.Discriminator, p.IsBot, p.IsWebhook, p.IsApplication, - strPtr(p.CustomMXID), strPtr(p.AccessToken), strPtr(p.NextBatch), - p.ID, - ) - - if err != nil { - p.log.Warnfln("Failed to update %s: %v", p.ID, err) - panic(err) - } -} diff --git a/database/reaction.go b/database/reaction.go deleted file mode 100644 index 8727bb5..0000000 --- a/database/reaction.go +++ /dev/null @@ -1,124 +0,0 @@ -package database - -import ( - "database/sql" - "errors" - - "go.mau.fi/util/dbutil" - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/id" -) - -type ReactionQuery struct { - db *Database - log log.Logger -} - -const ( - reactionSelect = "SELECT dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name, dc_thread_id, mxid FROM reaction" -) - -func (rq *ReactionQuery) New() *Reaction { - return &Reaction{ - db: rq.db, - log: rq.log, - } -} - -func (rq *ReactionQuery) GetAllForMessage(key PortalKey, discordMessageID string) []*Reaction { - query := reactionSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_msg_id=$3" - - return rq.getAll(query, key.ChannelID, key.Receiver, discordMessageID) -} - -func (rq *ReactionQuery) getAll(query string, args ...interface{}) []*Reaction { - rows, err := rq.db.Query(query, args...) - if err != nil || rows == nil { - return nil - } - - var reactions []*Reaction - for rows.Next() { - reactions = append(reactions, rq.New().Scan(rows)) - } - - return reactions -} - -func (rq *ReactionQuery) GetByDiscordID(key PortalKey, msgID, sender, emojiName string) *Reaction { - query := reactionSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_msg_id=$3 AND dc_sender=$4 AND dc_emoji_name=$5" - - return rq.get(query, key.ChannelID, key.Receiver, msgID, sender, emojiName) -} - -func (rq *ReactionQuery) GetByMXID(mxid id.EventID) *Reaction { - query := reactionSelect + " WHERE mxid=$1" - - return rq.get(query, mxid) -} - -func (rq *ReactionQuery) get(query string, args ...interface{}) *Reaction { - row := rq.db.QueryRow(query, args...) - if row == nil { - return nil - } - - return rq.New().Scan(row) -} - -type Reaction struct { - db *Database - log log.Logger - - Channel PortalKey - MessageID string - Sender string - EmojiName string - ThreadID string - - MXID id.EventID - - FirstAttachmentID string -} - -func (r *Reaction) Scan(row dbutil.Scannable) *Reaction { - err := row.Scan(&r.Channel.ChannelID, &r.Channel.Receiver, &r.MessageID, &r.Sender, &r.EmojiName, &r.ThreadID, &r.MXID) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - r.log.Errorln("Database scan failed:", err) - panic(err) - } - return nil - } - - return r -} - -func (r *Reaction) DiscordProtoChannelID() string { - if r.ThreadID != "" { - return r.ThreadID - } else { - return r.Channel.ChannelID - } -} - -func (r *Reaction) Insert() { - query := ` - INSERT INTO reaction (dc_msg_id, dc_first_attachment_id, dc_sender, dc_emoji_name, dc_chan_id, dc_chan_receiver, dc_thread_id, mxid) - VALUES($1, $2, $3, $4, $5, $6, $7, $8) - ` - _, err := r.db.Exec(query, r.MessageID, r.FirstAttachmentID, r.Sender, r.EmojiName, r.Channel.ChannelID, r.Channel.Receiver, r.ThreadID, r.MXID) - if err != nil { - r.log.Warnfln("Failed to insert reaction for %s@%s: %v", r.MessageID, r.Channel, err) - panic(err) - } -} - -func (r *Reaction) Delete() { - query := "DELETE FROM reaction WHERE dc_msg_id=$1 AND dc_sender=$2 AND dc_emoji_name=$3" - _, err := r.db.Exec(query, r.MessageID, r.Sender, r.EmojiName) - if err != nil { - r.log.Warnfln("Failed to delete reaction for %s@%s: %v", r.MessageID, r.Channel, err) - panic(err) - } -} diff --git a/database/role.go b/database/role.go deleted file mode 100644 index 3696b51..0000000 --- a/database/role.go +++ /dev/null @@ -1,112 +0,0 @@ -package database - -import ( - "database/sql" - "errors" - - "github.com/bwmarrin/discordgo" - "go.mau.fi/util/dbutil" - log "maunium.net/go/maulogger/v2" -) - -type RoleQuery struct { - db *Database - log log.Logger -} - -// language=postgresql -const ( - roleSelect = "SELECT dc_guild_id, dcid, name, icon, mentionable, managed, hoist, color, position, permissions FROM role" - roleUpsert = ` - INSERT INTO role (dc_guild_id, dcid, name, icon, mentionable, managed, hoist, color, position, permissions) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) - ON CONFLICT (dc_guild_id, dcid) DO UPDATE - SET name=excluded.name, icon=excluded.icon, mentionable=excluded.mentionable, managed=excluded.managed, - hoist=excluded.hoist, color=excluded.color, position=excluded.position, permissions=excluded.permissions - ` - roleDelete = "DELETE FROM role WHERE dc_guild_id=$1 AND dcid=$2" -) - -func (rq *RoleQuery) New() *Role { - return &Role{ - db: rq.db, - log: rq.log, - } -} - -func (rq *RoleQuery) GetByID(guildID, dcid string) *Role { - query := roleSelect + " WHERE dc_guild_id=$1 AND dcid=$2" - return rq.New().Scan(rq.db.QueryRow(query, guildID, dcid)) -} - -func (rq *RoleQuery) DeleteByID(guildID, dcid string) { - _, err := rq.db.Exec("DELETE FROM role WHERE dc_guild_id=$1 AND dcid=$2", guildID, dcid) - if err != nil { - rq.log.Warnfln("Failed to delete %s/%s: %v", guildID, dcid, err) - panic(err) - } -} - -func (rq *RoleQuery) GetAll(guildID string) []*Role { - rows, err := rq.db.Query(roleSelect+" WHERE dc_guild_id=$1", guildID) - if err != nil { - rq.log.Errorfln("Failed to query roles of %s: %v", guildID, err) - return nil - } - - var roles []*Role - for rows.Next() { - role := rq.New().Scan(rows) - if role != nil { - roles = append(roles, role) - } - } - - return roles -} - -type Role struct { - db *Database - log log.Logger - - GuildID string - - discordgo.Role -} - -func (r *Role) Scan(row dbutil.Scannable) *Role { - var icon sql.NullString - err := row.Scan(&r.GuildID, &r.ID, &r.Name, &icon, &r.Mentionable, &r.Managed, &r.Hoist, &r.Color, &r.Position, &r.Permissions) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - r.log.Errorln("Database scan failed:", err) - panic(err) - } - - return nil - } - r.Icon = icon.String - return r -} - -func (r *Role) Upsert(txn dbutil.Execable) { - if txn == nil { - txn = r.db - } - _, err := txn.Exec(roleUpsert, r.GuildID, r.ID, r.Name, strPtr(r.Icon), r.Mentionable, r.Managed, r.Hoist, r.Color, r.Position, r.Permissions) - if err != nil { - r.log.Warnfln("Failed to insert %s/%s: %v", r.GuildID, r.ID, err) - panic(err) - } -} - -func (r *Role) Delete(txn dbutil.Execable) { - if txn == nil { - txn = r.db - } - _, err := txn.Exec(roleDelete, r.GuildID, r.Icon) - if err != nil { - r.log.Warnfln("Failed to delete %s/%s: %v", r.GuildID, r.ID, err) - panic(err) - } -} diff --git a/database/thread.go b/database/thread.go deleted file mode 100644 index 87f4127..0000000 --- a/database/thread.go +++ /dev/null @@ -1,111 +0,0 @@ -package database - -import ( - "database/sql" - "errors" - - "go.mau.fi/util/dbutil" - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/id" -) - -type ThreadQuery struct { - db *Database - log log.Logger -} - -const ( - threadSelect = "SELECT dcid, parent_chan_id, root_msg_dcid, root_msg_mxid, creation_notice_mxid FROM thread" -) - -func (tq *ThreadQuery) New() *Thread { - return &Thread{ - db: tq.db, - log: tq.log, - } -} - -func (tq *ThreadQuery) GetByDiscordID(discordID string) *Thread { - query := threadSelect + " WHERE dcid=$1" - - row := tq.db.QueryRow(query, discordID) - if row == nil { - return nil - } - - return tq.New().Scan(row) -} - -func (tq *ThreadQuery) GetByMatrixRootMsg(mxid id.EventID) *Thread { - query := threadSelect + " WHERE root_msg_mxid=$1" - - row := tq.db.QueryRow(query, mxid) - if row == nil { - return nil - } - - return tq.New().Scan(row) -} - -func (tq *ThreadQuery) GetByMatrixRootOrCreationNoticeMsg(mxid id.EventID) *Thread { - query := threadSelect + " WHERE root_msg_mxid=$1 OR creation_notice_mxid=$1" - - row := tq.db.QueryRow(query, mxid) - if row == nil { - return nil - } - - return tq.New().Scan(row) -} - -type Thread struct { - db *Database - log log.Logger - - ID string - ParentID string - - RootDiscordID string - RootMXID id.EventID - - CreationNoticeMXID id.EventID -} - -func (t *Thread) Scan(row dbutil.Scannable) *Thread { - err := row.Scan(&t.ID, &t.ParentID, &t.RootDiscordID, &t.RootMXID, &t.CreationNoticeMXID) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - t.log.Errorln("Database scan failed:", err) - panic(err) - } - return nil - } - return t -} - -func (t *Thread) Insert() { - query := "INSERT INTO thread (dcid, parent_chan_id, root_msg_dcid, root_msg_mxid, creation_notice_mxid) VALUES ($1, $2, $3, $4, $5)" - _, err := t.db.Exec(query, t.ID, t.ParentID, t.RootDiscordID, t.RootMXID, t.CreationNoticeMXID) - if err != nil { - t.log.Warnfln("Failed to insert %s@%s: %v", t.ID, t.ParentID, err) - panic(err) - } -} - -func (t *Thread) Update() { - query := "UPDATE thread SET creation_notice_mxid=$2 WHERE dcid=$1" - _, err := t.db.Exec(query, t.ID, t.CreationNoticeMXID) - if err != nil { - t.log.Warnfln("Failed to update %s@%s: %v", t.ID, t.ParentID, err) - panic(err) - } -} - -func (t *Thread) Delete() { - query := "DELETE FROM thread WHERE dcid=$1 AND parent_chan_id=$2" - _, err := t.db.Exec(query, t.ID, t.ParentID) - if err != nil { - t.log.Warnfln("Failed to delete %s@%s: %v", t.ID, t.ParentID, err) - panic(err) - } -} diff --git a/database/upgrades/00-latest-revision.sql b/database/upgrades/00-latest-revision.sql deleted file mode 100644 index 46fbb73..0000000 --- a/database/upgrades/00-latest-revision.sql +++ /dev/null @@ -1,179 +0,0 @@ --- v0 -> v23 (compatible with v19+): Latest revision - -CREATE TABLE guild ( - dcid TEXT PRIMARY KEY, - mxid TEXT UNIQUE, - plain_name TEXT NOT NULL, - name TEXT NOT NULL, - name_set BOOLEAN NOT NULL, - avatar TEXT NOT NULL, - avatar_url TEXT NOT NULL, - avatar_set BOOLEAN NOT NULL, - - bridging_mode INTEGER NOT NULL -); - -CREATE TABLE portal ( - dcid TEXT, - receiver TEXT, - other_user_id TEXT, - type INTEGER NOT NULL, - - dc_guild_id TEXT, - dc_parent_id TEXT, - -- This is not accessed by the bridge, it's only used for the portal parent foreign key. - -- Only guild channels have parents, but only DMs have a receiver field. - dc_parent_receiver TEXT NOT NULL DEFAULT '', - - mxid TEXT UNIQUE, - plain_name TEXT NOT NULL, - name TEXT NOT NULL, - name_set BOOLEAN NOT NULL, - friend_nick BOOLEAN NOT NULL, - topic TEXT NOT NULL, - topic_set BOOLEAN NOT NULL, - avatar TEXT NOT NULL, - avatar_url TEXT NOT NULL, - avatar_set BOOLEAN NOT NULL, - encrypted BOOLEAN NOT NULL, - in_space TEXT NOT NULL, - - first_event_id TEXT NOT NULL, - - relay_webhook_id TEXT, - relay_webhook_secret TEXT, - - PRIMARY KEY (dcid, receiver), - CONSTRAINT portal_parent_fkey FOREIGN KEY (dc_parent_id, dc_parent_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE, - CONSTRAINT portal_guild_fkey FOREIGN KEY (dc_guild_id) REFERENCES guild(dcid) ON DELETE CASCADE -); - -CREATE TABLE thread ( - dcid TEXT PRIMARY KEY, - parent_chan_id TEXT NOT NULL, - root_msg_dcid TEXT NOT NULL, - root_msg_mxid TEXT NOT NULL, - creation_notice_mxid TEXT NOT NULL, - -- This is also not accessed by the bridge. - receiver TEXT NOT NULL DEFAULT '', - - CONSTRAINT thread_parent_fkey FOREIGN KEY (parent_chan_id, receiver) REFERENCES portal(dcid, receiver) ON DELETE CASCADE ON UPDATE CASCADE -); - -CREATE TABLE puppet ( - id TEXT PRIMARY KEY, - - name TEXT NOT NULL, - name_set BOOLEAN NOT NULL DEFAULT false, - avatar TEXT NOT NULL, - avatar_url TEXT NOT NULL, - avatar_set BOOLEAN NOT NULL DEFAULT false, - - contact_info_set BOOLEAN NOT NULL DEFAULT false, - - global_name TEXT NOT NULL DEFAULT '', - username TEXT NOT NULL DEFAULT '', - discriminator TEXT NOT NULL DEFAULT '', - is_bot BOOLEAN NOT NULL DEFAULT false, - is_webhook BOOLEAN NOT NULL DEFAULT false, - is_application BOOLEAN NOT NULL DEFAULT false, - - custom_mxid TEXT, - access_token TEXT, - next_batch TEXT -); - -CREATE TABLE "user" ( - mxid TEXT PRIMARY KEY, - dcid TEXT UNIQUE, - - discord_token TEXT, - management_room TEXT, - space_room TEXT, - dm_space_room TEXT, - - read_state_version INTEGER NOT NULL DEFAULT 0 -); - -CREATE TABLE user_portal ( - discord_id TEXT, - user_mxid TEXT, - type TEXT NOT NULL, - in_space BOOLEAN NOT NULL, - timestamp BIGINT NOT NULL, - - PRIMARY KEY (discord_id, user_mxid), - CONSTRAINT up_user_fkey FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE -); - -CREATE TABLE message ( - dcid TEXT, - dc_attachment_id TEXT, - dc_chan_id TEXT, - dc_chan_receiver TEXT, - dc_sender TEXT NOT NULL, - timestamp BIGINT NOT NULL, - dc_edit_timestamp BIGINT NOT NULL, - dc_thread_id TEXT NOT NULL, - - mxid TEXT NOT NULL UNIQUE, - sender_mxid TEXT NOT NULL DEFAULT '', - - PRIMARY KEY (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver), - CONSTRAINT message_portal_fkey FOREIGN KEY (dc_chan_id, dc_chan_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE -); - -CREATE TABLE reaction ( - dc_chan_id TEXT, - dc_chan_receiver TEXT, - dc_msg_id TEXT, - dc_sender TEXT, - dc_emoji_name TEXT, - dc_thread_id TEXT NOT NULL, - - dc_first_attachment_id TEXT NOT NULL, - - mxid TEXT NOT NULL UNIQUE, - - PRIMARY KEY (dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name), - CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_first_attachment_id, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE -); - -CREATE TABLE role ( - dc_guild_id TEXT, - dcid TEXT, - - name TEXT NOT NULL, - icon TEXT, - - mentionable BOOLEAN NOT NULL, - managed BOOLEAN NOT NULL, - hoist BOOLEAN NOT NULL, - - color INTEGER NOT NULL, - position INTEGER NOT NULL, - permissions BIGINT NOT NULL, - - PRIMARY KEY (dc_guild_id, dcid), - CONSTRAINT role_guild_fkey FOREIGN KEY (dc_guild_id) REFERENCES guild (dcid) ON DELETE CASCADE -); - -CREATE TABLE discord_file ( - url TEXT, - encrypted BOOLEAN, - mxc TEXT NOT NULL, - - id TEXT, - emoji_name TEXT, - - size BIGINT NOT NULL, - width INTEGER, - height INTEGER, - mime_type TEXT NOT NULL, - decryption_info jsonb, - timestamp BIGINT NOT NULL, - - PRIMARY KEY (url, encrypted) -); - -CREATE INDEX discord_file_mxc_idx ON discord_file (mxc); diff --git a/database/upgrades/02-column-renames.sql b/database/upgrades/02-column-renames.sql deleted file mode 100644 index 86b0cb0..0000000 --- a/database/upgrades/02-column-renames.sql +++ /dev/null @@ -1,53 +0,0 @@ --- v2: Rename columns in message-related tables - -ALTER TABLE portal RENAME COLUMN dmuser TO other_user_id; -ALTER TABLE portal RENAME COLUMN channel_id TO dcid; - -ALTER TABLE "user" RENAME COLUMN id TO dcid; - -ALTER TABLE puppet DROP COLUMN enable_presence; -ALTER TABLE puppet DROP COLUMN enable_receipts; - -DROP TABLE message; -DROP TABLE reaction; -DROP TABLE attachment; - -CREATE TABLE message ( - dcid TEXT, - dc_chan_id TEXT, - dc_chan_receiver TEXT, - dc_sender TEXT NOT NULL, - timestamp BIGINT NOT NULL, - - mxid TEXT NOT NULL UNIQUE, - - PRIMARY KEY (dcid, dc_chan_id, dc_chan_receiver), - CONSTRAINT message_portal_fkey FOREIGN KEY (dc_chan_id, dc_chan_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE -); - -CREATE TABLE reaction ( - dc_chan_id TEXT, - dc_chan_receiver TEXT, - dc_msg_id TEXT, - dc_sender TEXT, - dc_emoji_name TEXT, - - mxid TEXT NOT NULL UNIQUE, - - PRIMARY KEY (dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name), - CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE -); - -CREATE TABLE attachment ( - dcid TEXT, - dc_msg_id TEXT, - dc_chan_id TEXT, - dc_chan_receiver TEXT, - - mxid TEXT NOT NULL UNIQUE, - - PRIMARY KEY (dcid, dc_msg_id, dc_chan_id, dc_chan_receiver), - CONSTRAINT attachment_message_fkey FOREIGN KEY (dc_msg_id, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE -); - -UPDATE portal SET receiver='' WHERE type<>1; diff --git a/database/upgrades/03-spaces.sql b/database/upgrades/03-spaces.sql deleted file mode 100644 index 79bc3c5..0000000 --- a/database/upgrades/03-spaces.sql +++ /dev/null @@ -1,73 +0,0 @@ --- v3: Store portal parent metadata for spaces -DROP TABLE guild; - -CREATE TABLE guild ( - dcid TEXT PRIMARY KEY, - mxid TEXT UNIQUE, - name TEXT NOT NULL, - name_set BOOLEAN NOT NULL, - avatar TEXT NOT NULL, - avatar_url TEXT NOT NULL, - avatar_set BOOLEAN NOT NULL, - - auto_bridge_channels BOOLEAN NOT NULL -); - -CREATE TABLE user_portal ( - discord_id TEXT, - user_mxid TEXT, - type TEXT NOT NULL, - in_space BOOLEAN NOT NULL, - timestamp BIGINT NOT NULL, - - PRIMARY KEY (discord_id, user_mxid), - CONSTRAINT up_user_fkey FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE -); - -ALTER TABLE portal ADD COLUMN dc_guild_id TEXT; -ALTER TABLE portal ADD COLUMN dc_parent_id TEXT; -ALTER TABLE portal ADD COLUMN dc_parent_receiver TEXT NOT NULL DEFAULT ''; -ALTER TABLE portal ADD CONSTRAINT portal_parent_fkey FOREIGN KEY (dc_parent_id, dc_parent_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE; -ALTER TABLE portal ADD CONSTRAINT portal_guild_fkey FOREIGN KEY (dc_guild_id) REFERENCES guild(dcid) ON DELETE CASCADE; -DELETE FROM portal WHERE type IS NULL; --- only: postgres -ALTER TABLE portal ALTER COLUMN type SET NOT NULL; - -ALTER TABLE portal ADD COLUMN in_space TEXT NOT NULL DEFAULT ''; -ALTER TABLE portal ADD COLUMN name_set BOOLEAN NOT NULL DEFAULT false; -ALTER TABLE portal ADD COLUMN topic_set BOOLEAN NOT NULL DEFAULT false; -ALTER TABLE portal ADD COLUMN avatar_set BOOLEAN NOT NULL DEFAULT false; --- only: postgres for next 5 lines -ALTER TABLE portal ALTER COLUMN in_space DROP DEFAULT; -ALTER TABLE portal ALTER COLUMN name_set DROP DEFAULT; -ALTER TABLE portal ALTER COLUMN topic_set DROP DEFAULT; -ALTER TABLE portal ALTER COLUMN avatar_set DROP DEFAULT; -ALTER TABLE portal ALTER COLUMN encrypted DROP DEFAULT; - -ALTER TABLE puppet RENAME COLUMN display_name TO name; -ALTER TABLE puppet ADD COLUMN name_set BOOLEAN NOT NULL DEFAULT false; -ALTER TABLE puppet ADD COLUMN avatar_set BOOLEAN NOT NULL DEFAULT false; --- only: postgres for next 2 lines -ALTER TABLE puppet ALTER COLUMN name_set DROP DEFAULT; -ALTER TABLE puppet ALTER COLUMN avatar_set DROP DEFAULT; - -ALTER TABLE "user" ADD COLUMN space_room TEXT; -ALTER TABLE "user" ADD COLUMN dm_space_room TEXT; -ALTER TABLE "user" RENAME COLUMN token TO discord_token; - -UPDATE message SET timestamp=timestamp*1000; - -CREATE TABLE thread ( - dcid TEXT PRIMARY KEY, - parent_chan_id TEXT NOT NULL, - root_msg_dcid TEXT NOT NULL, - root_msg_mxid TEXT NOT NULL, - -- This is also not accessed by the bridge. - receiver TEXT NOT NULL DEFAULT '', - - CONSTRAINT thread_parent_fkey FOREIGN KEY (parent_chan_id, receiver) REFERENCES portal(dcid, receiver) ON DELETE CASCADE ON UPDATE CASCADE -); - -ALTER TABLE message ADD COLUMN dc_thread_id TEXT; -ALTER TABLE attachment ADD COLUMN dc_thread_id TEXT; -ALTER TABLE reaction ADD COLUMN dc_thread_id TEXT; diff --git a/database/upgrades/04-attachment-fix.postgres.sql b/database/upgrades/04-attachment-fix.postgres.sql deleted file mode 100644 index c476afd..0000000 --- a/database/upgrades/04-attachment-fix.postgres.sql +++ /dev/null @@ -1,20 +0,0 @@ --- v4: Fix storing attachments -ALTER TABLE reaction DROP CONSTRAINT reaction_message_fkey; -ALTER TABLE attachment DROP CONSTRAINT attachment_message_fkey; -ALTER TABLE message DROP CONSTRAINT message_pkey; -ALTER TABLE message ADD COLUMN dc_attachment_id TEXT NOT NULL DEFAULT ''; -ALTER TABLE message ADD COLUMN dc_edit_index INTEGER NOT NULL DEFAULT 0; -ALTER TABLE message ALTER COLUMN dc_attachment_id DROP DEFAULT; -ALTER TABLE message ALTER COLUMN dc_edit_index DROP DEFAULT; -ALTER TABLE message ADD PRIMARY KEY (dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver); -INSERT INTO message (dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid) - SELECT message.dcid, attachment.dcid, 0, attachment.dc_chan_id, attachment.dc_chan_receiver, message.dc_sender, message.timestamp, attachment.dc_thread_id, attachment.mxid - FROM attachment LEFT JOIN message ON attachment.dc_msg_id = message.dcid; -DROP TABLE attachment; - -ALTER TABLE reaction ADD COLUMN dc_first_attachment_id TEXT NOT NULL DEFAULT ''; -ALTER TABLE reaction ALTER COLUMN dc_first_attachment_id DROP DEFAULT; -ALTER TABLE reaction ADD COLUMN _dc_first_edit_index INTEGER DEFAULT 0; -ALTER TABLE reaction ADD CONSTRAINT reaction_message_fkey - FOREIGN KEY (dc_msg_id, dc_first_attachment_id, _dc_first_edit_index, dc_chan_id, dc_chan_receiver) - REFERENCES message(dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver); diff --git a/database/upgrades/04-attachment-fix.sqlite.sql b/database/upgrades/04-attachment-fix.sqlite.sql deleted file mode 100644 index 88c4386..0000000 --- a/database/upgrades/04-attachment-fix.sqlite.sql +++ /dev/null @@ -1,45 +0,0 @@ --- v4: Fix storing attachments -CREATE TABLE new_message ( - dcid TEXT, - dc_attachment_id TEXT, - dc_edit_index INTEGER, - dc_chan_id TEXT, - dc_chan_receiver TEXT, - dc_sender TEXT NOT NULL, - timestamp BIGINT NOT NULL, - dc_thread_id TEXT, - - mxid TEXT NOT NULL UNIQUE, - - PRIMARY KEY (dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver), - CONSTRAINT message_portal_fkey FOREIGN KEY (dc_chan_id, dc_chan_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE -); -INSERT INTO new_message (dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid) - SELECT dcid, '', 0, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid FROM message; -INSERT INTO new_message (dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid) - SELECT message.dcid, attachment.dcid, 0, attachment.dc_chan_id, attachment.dc_chan_receiver, message.dc_sender, message.timestamp, attachment.dc_thread_id, attachment.mxid - FROM attachment LEFT JOIN message ON attachment.dc_msg_id = message.dcid; -DROP TABLE attachment; -DROP TABLE message; -ALTER TABLE new_message RENAME TO message; - -CREATE TABLE new_reaction ( - dc_chan_id TEXT, - dc_chan_receiver TEXT, - dc_msg_id TEXT, - dc_sender TEXT, - dc_emoji_name TEXT, - dc_thread_id TEXT, - - dc_first_attachment_id TEXT NOT NULL, - _dc_first_edit_index INTEGER NOT NULL DEFAULT 0, - - mxid TEXT NOT NULL UNIQUE, - - PRIMARY KEY (dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name), - CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_first_attachment_id, _dc_first_edit_index, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE -); -INSERT INTO new_reaction (dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name, dc_thread_id, dc_first_attachment_id, mxid) -SELECT dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name, dc_thread_id, '', mxid FROM reaction; -DROP TABLE reaction; -ALTER TABLE new_reaction RENAME TO reaction; diff --git a/database/upgrades/05-reaction-fkey-fix.sql b/database/upgrades/05-reaction-fkey-fix.sql deleted file mode 100644 index 1a02a5e..0000000 --- a/database/upgrades/05-reaction-fkey-fix.sql +++ /dev/null @@ -1,8 +0,0 @@ --- v5: Fix foreign key broken in v4 --- only: postgres - -ALTER TABLE reaction DROP CONSTRAINT reaction_message_fkey; -ALTER TABLE reaction ADD CONSTRAINT reaction_message_fkey - FOREIGN KEY (dc_msg_id, dc_first_attachment_id, _dc_first_edit_index, dc_chan_id, dc_chan_receiver) - REFERENCES message(dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver) - ON DELETE CASCADE; diff --git a/database/upgrades/06-user-read-state-version.sql b/database/upgrades/06-user-read-state-version.sql deleted file mode 100644 index 612a777..0000000 --- a/database/upgrades/06-user-read-state-version.sql +++ /dev/null @@ -1,2 +0,0 @@ --- v6: Store user read state version -ALTER TABLE "user" ADD COLUMN read_state_version INTEGER NOT NULL DEFAULT 0; diff --git a/database/upgrades/07-store-role-info.sql b/database/upgrades/07-store-role-info.sql deleted file mode 100644 index 21f6a57..0000000 --- a/database/upgrades/07-store-role-info.sql +++ /dev/null @@ -1,19 +0,0 @@ --- v7: Store role info -CREATE TABLE role ( - dc_guild_id TEXT, - dcid TEXT, - - name TEXT NOT NULL, - icon TEXT, - - mentionable BOOLEAN NOT NULL, - managed BOOLEAN NOT NULL, - hoist BOOLEAN NOT NULL, - - color INTEGER NOT NULL, - position INTEGER NOT NULL, - permissions BIGINT NOT NULL, - - PRIMARY KEY (dc_guild_id, dcid), - CONSTRAINT role_guild_fkey FOREIGN KEY (dc_guild_id) REFERENCES guild (dcid) ON DELETE CASCADE -); diff --git a/database/upgrades/08-channel-plain-name.sql b/database/upgrades/08-channel-plain-name.sql deleted file mode 100644 index 22237b6..0000000 --- a/database/upgrades/08-channel-plain-name.sql +++ /dev/null @@ -1,9 +0,0 @@ --- v8: Store plain name of channels and guilds -ALTER TABLE guild ADD COLUMN plain_name TEXT; -ALTER TABLE portal ADD COLUMN plain_name TEXT; -UPDATE guild SET plain_name=name; -UPDATE portal SET plain_name=name; -UPDATE portal SET plain_name='' WHERE type=1; --- only: postgres for next 2 lines -ALTER TABLE guild ALTER COLUMN plain_name SET NOT NULL; -ALTER TABLE portal ALTER COLUMN plain_name SET NOT NULL; diff --git a/database/upgrades/09-more-thread-data.sql b/database/upgrades/09-more-thread-data.sql deleted file mode 100644 index 461a1d4..0000000 --- a/database/upgrades/09-more-thread-data.sql +++ /dev/null @@ -1,9 +0,0 @@ --- v9: Store more info for proper thread support -ALTER TABLE thread ADD COLUMN creation_notice_mxid TEXT NOT NULL DEFAULT ''; -UPDATE message SET dc_thread_id='' WHERE dc_thread_id IS NULL; -UPDATE reaction SET dc_thread_id='' WHERE dc_thread_id IS NULL; - --- only: postgres for next 3 lines -ALTER TABLE thread ALTER COLUMN creation_notice_mxid DROP DEFAULT; -ALTER TABLE message ALTER COLUMN dc_thread_id SET NOT NULL; -ALTER TABLE reaction ALTER COLUMN dc_thread_id SET NOT NULL; diff --git a/database/upgrades/10-remove-broken-double-puppets.sql b/database/upgrades/10-remove-broken-double-puppets.sql deleted file mode 100644 index 862c917..0000000 --- a/database/upgrades/10-remove-broken-double-puppets.sql +++ /dev/null @@ -1,2 +0,0 @@ --- v10: Remove double puppet ghosts added while there was a bug in the bridge -DELETE FROM puppet WHERE id=''; diff --git a/database/upgrades/11-cache-reuploaded-files.sql b/database/upgrades/11-cache-reuploaded-files.sql deleted file mode 100644 index c32c2bc..0000000 --- a/database/upgrades/11-cache-reuploaded-files.sql +++ /dev/null @@ -1,18 +0,0 @@ --- v11: Cache files copied from Discord to Matrix -CREATE TABLE discord_file ( - url TEXT, - encrypted BOOLEAN, - - id TEXT, - mxc TEXT NOT NULL, - - size BIGINT NOT NULL, - width INTEGER, - height INTEGER, - - decryption_info jsonb, - - timestamp BIGINT NOT NULL, - - PRIMARY KEY (url, encrypted) -); diff --git a/database/upgrades/12-file-cache-mime-type.sql b/database/upgrades/12-file-cache-mime-type.sql deleted file mode 100644 index 1bdb960..0000000 --- a/database/upgrades/12-file-cache-mime-type.sql +++ /dev/null @@ -1,4 +0,0 @@ --- v12: Cache mime type for reuploaded files -ALTER TABLE discord_file ADD COLUMN mime_type TEXT NOT NULL DEFAULT ''; --- only: postgres -ALTER TABLE discord_file ALTER COLUMN mime_type DROP DEFAULT; diff --git a/database/upgrades/13-merge-emoji-and-file.postgres.sql b/database/upgrades/13-merge-emoji-and-file.postgres.sql deleted file mode 100644 index 18ef607..0000000 --- a/database/upgrades/13-merge-emoji-and-file.postgres.sql +++ /dev/null @@ -1,4 +0,0 @@ --- v13: Merge tables used for cached custom emojis and attachments -ALTER TABLE discord_file ADD CONSTRAINT mxc_unique UNIQUE (mxc); -ALTER TABLE discord_file ADD COLUMN emoji_name TEXT; -DROP TABLE emoji; diff --git a/database/upgrades/13-merge-emoji-and-file.sqlite.sql b/database/upgrades/13-merge-emoji-and-file.sqlite.sql deleted file mode 100644 index ffe1b25..0000000 --- a/database/upgrades/13-merge-emoji-and-file.sqlite.sql +++ /dev/null @@ -1,24 +0,0 @@ --- v13: Merge tables used for cached custom emojis and attachments -CREATE TABLE new_discord_file ( - url TEXT, - encrypted BOOLEAN, - mxc TEXT NOT NULL UNIQUE, - - id TEXT, - emoji_name TEXT, - - size BIGINT NOT NULL, - width INTEGER, - height INTEGER, - mime_type TEXT NOT NULL, - decryption_info jsonb, - timestamp BIGINT NOT NULL, - - PRIMARY KEY (url, encrypted) -); - -INSERT INTO new_discord_file (url, encrypted, id, mxc, size, width, height, mime_type, decryption_info, timestamp) -SELECT url, encrypted, id, mxc, size, width, height, mime_type, decryption_info, timestamp FROM discord_file; - -DROP TABLE discord_file; -ALTER TABLE new_discord_file RENAME TO discord_file; diff --git a/database/upgrades/14-guild-bridging-mode.sql b/database/upgrades/14-guild-bridging-mode.sql deleted file mode 100644 index 854d1c0..0000000 --- a/database/upgrades/14-guild-bridging-mode.sql +++ /dev/null @@ -1,7 +0,0 @@ --- v14: Add more modes of bridging guilds -ALTER TABLE guild ADD COLUMN bridging_mode INTEGER NOT NULL DEFAULT 0; -UPDATE guild SET bridging_mode=2 WHERE mxid<>''; -UPDATE guild SET bridging_mode=3 WHERE auto_bridge_channels=true; -ALTER TABLE guild DROP COLUMN auto_bridge_channels; --- only: postgres -ALTER TABLE guild ALTER COLUMN bridging_mode DROP DEFAULT; diff --git a/database/upgrades/15-portal-relay-webhook.sql b/database/upgrades/15-portal-relay-webhook.sql deleted file mode 100644 index 0035d00..0000000 --- a/database/upgrades/15-portal-relay-webhook.sql +++ /dev/null @@ -1,3 +0,0 @@ --- v15: Store relay webhook URL for portals -ALTER TABLE portal ADD COLUMN relay_webhook_id TEXT; -ALTER TABLE portal ADD COLUMN relay_webhook_secret TEXT; diff --git a/database/upgrades/16-add-contact-info.sql b/database/upgrades/16-add-contact-info.sql deleted file mode 100644 index 8595ae3..0000000 --- a/database/upgrades/16-add-contact-info.sql +++ /dev/null @@ -1,3 +0,0 @@ --- v16: Store whether custom contact info has been set for the puppet - -ALTER TABLE puppet ADD COLUMN contact_info_set BOOLEAN NOT NULL DEFAULT false; diff --git a/database/upgrades/17-dm-portal-friend-nick.sql b/database/upgrades/17-dm-portal-friend-nick.sql deleted file mode 100644 index 2c2b43c..0000000 --- a/database/upgrades/17-dm-portal-friend-nick.sql +++ /dev/null @@ -1,2 +0,0 @@ --- v17: Store whether DM portal name is a friend nickname -ALTER TABLE portal ADD COLUMN friend_nick BOOLEAN NOT NULL DEFAULT false; diff --git a/database/upgrades/18-extra-ghost-metadata.sql b/database/upgrades/18-extra-ghost-metadata.sql deleted file mode 100644 index 92677dc..0000000 --- a/database/upgrades/18-extra-ghost-metadata.sql +++ /dev/null @@ -1,4 +0,0 @@ --- v18 (compatible with v15+): Store additional metadata for ghosts -ALTER TABLE puppet ADD COLUMN username TEXT NOT NULL DEFAULT ''; -ALTER TABLE puppet ADD COLUMN discriminator TEXT NOT NULL DEFAULT ''; -ALTER TABLE puppet ADD COLUMN is_bot BOOLEAN NOT NULL DEFAULT false; diff --git a/database/upgrades/19-message-edit-ts.postgres.sql b/database/upgrades/19-message-edit-ts.postgres.sql deleted file mode 100644 index 231afa1..0000000 --- a/database/upgrades/19-message-edit-ts.postgres.sql +++ /dev/null @@ -1,15 +0,0 @@ --- v19: Replace dc_edit_index with dc_edit_timestamp --- transaction: off -BEGIN; - -ALTER TABLE reaction DROP CONSTRAINT reaction_message_fkey; -ALTER TABLE message DROP CONSTRAINT message_pkey; -ALTER TABLE message DROP COLUMN dc_edit_index; -ALTER TABLE reaction DROP COLUMN _dc_first_edit_index; -ALTER TABLE message ADD PRIMARY KEY (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver); -ALTER TABLE reaction ADD CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_first_attachment_id, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE; - -ALTER TABLE message ADD COLUMN dc_edit_timestamp BIGINT NOT NULL DEFAULT 0; -ALTER TABLE message ALTER COLUMN dc_edit_timestamp DROP DEFAULT; - -COMMIT; diff --git a/database/upgrades/19-message-edit-ts.sqlite.sql b/database/upgrades/19-message-edit-ts.sqlite.sql deleted file mode 100644 index a25f317..0000000 --- a/database/upgrades/19-message-edit-ts.sqlite.sql +++ /dev/null @@ -1,48 +0,0 @@ --- v19: Replace dc_edit_index with dc_edit_timestamp --- transaction: off -PRAGMA foreign_keys = OFF; -BEGIN; - -CREATE TABLE message_new ( - dcid TEXT, - dc_attachment_id TEXT, - dc_chan_id TEXT, - dc_chan_receiver TEXT, - dc_sender TEXT NOT NULL, - timestamp BIGINT NOT NULL, - dc_edit_timestamp BIGINT NOT NULL, - dc_thread_id TEXT NOT NULL, - - mxid TEXT NOT NULL UNIQUE, - - PRIMARY KEY (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver), - CONSTRAINT message_portal_fkey FOREIGN KEY (dc_chan_id, dc_chan_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE -); -INSERT INTO message_new (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid) - SELECT dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, 0, dc_thread_id, mxid FROM message; -DROP TABLE message; -ALTER TABLE message_new RENAME TO message; - -CREATE TABLE reaction_new ( - dc_chan_id TEXT, - dc_chan_receiver TEXT, - dc_msg_id TEXT, - dc_sender TEXT, - dc_emoji_name TEXT, - dc_thread_id TEXT NOT NULL, - - dc_first_attachment_id TEXT NOT NULL, - - mxid TEXT NOT NULL UNIQUE, - - PRIMARY KEY (dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name), - CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_first_attachment_id, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE -); -INSERT INTO reaction_new (dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name, dc_thread_id, dc_first_attachment_id, mxid) - SELECT dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name, COALESCE(dc_thread_id, ''), dc_first_attachment_id, mxid FROM reaction; -DROP TABLE reaction; -ALTER TABLE reaction_new RENAME TO reaction; - -PRAGMA foreign_key_check; -COMMIT; -PRAGMA foreign_keys = ON; diff --git a/database/upgrades/20-message-sender-mxid.sql b/database/upgrades/20-message-sender-mxid.sql deleted file mode 100644 index aa2bd65..0000000 --- a/database/upgrades/20-message-sender-mxid.sql +++ /dev/null @@ -1,2 +0,0 @@ --- v20 (compatible with v19+): Store message sender Matrix user ID -ALTER TABLE message ADD COLUMN sender_mxid TEXT NOT NULL DEFAULT ''; diff --git a/database/upgrades/21-more-puppet-info.sql b/database/upgrades/21-more-puppet-info.sql deleted file mode 100644 index 3bc374a..0000000 --- a/database/upgrades/21-more-puppet-info.sql +++ /dev/null @@ -1,3 +0,0 @@ --- v21 (compatible with v19+): Store global displayname and is webhook status for puppets -ALTER TABLE puppet ADD COLUMN global_name TEXT NOT NULL DEFAULT ''; -ALTER TABLE puppet ADD COLUMN is_webhook BOOLEAN NOT NULL DEFAULT false; diff --git a/database/upgrades/22-file-cache-duplicate-mxc.sql b/database/upgrades/22-file-cache-duplicate-mxc.sql deleted file mode 100644 index b0bac3b..0000000 --- a/database/upgrades/22-file-cache-duplicate-mxc.sql +++ /dev/null @@ -1,26 +0,0 @@ --- v22 (compatible with v19+): Allow non-unique mxc URIs in file cache -CREATE TABLE new_discord_file ( - url TEXT, - encrypted BOOLEAN, - mxc TEXT NOT NULL, - - id TEXT, - emoji_name TEXT, - - size BIGINT NOT NULL, - width INTEGER, - height INTEGER, - mime_type TEXT NOT NULL, - decryption_info jsonb, - timestamp BIGINT NOT NULL, - - PRIMARY KEY (url, encrypted) -); - -INSERT INTO new_discord_file (url, encrypted, mxc, id, emoji_name, size, width, height, mime_type, decryption_info, timestamp) -SELECT url, encrypted, mxc, id, emoji_name, size, width, height, mime_type, decryption_info, timestamp FROM discord_file; - -DROP TABLE discord_file; -ALTER TABLE new_discord_file RENAME TO discord_file; - -CREATE INDEX discord_file_mxc_idx ON discord_file (mxc); diff --git a/database/upgrades/23-puppet-is-application.sql b/database/upgrades/23-puppet-is-application.sql deleted file mode 100644 index 6279c88..0000000 --- a/database/upgrades/23-puppet-is-application.sql +++ /dev/null @@ -1,2 +0,0 @@ --- v23 (compatible with v19+): Store is application status for puppets -ALTER TABLE puppet ADD COLUMN is_application BOOLEAN NOT NULL DEFAULT false; diff --git a/database/user.go b/database/user.go deleted file mode 100644 index 763625d..0000000 --- a/database/user.go +++ /dev/null @@ -1,101 +0,0 @@ -package database - -import ( - "database/sql" - - "go.mau.fi/util/dbutil" - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/id" -) - -type UserQuery struct { - db *Database - log log.Logger -} - -func (uq *UserQuery) New() *User { - return &User{ - db: uq.db, - log: uq.log, - } -} - -func (uq *UserQuery) GetByMXID(userID id.UserID) *User { - query := `SELECT mxid, dcid, discord_token, management_room, space_room, dm_space_room, read_state_version FROM "user" WHERE mxid=$1` - return uq.New().Scan(uq.db.QueryRow(query, userID)) -} - -func (uq *UserQuery) GetByID(id string) *User { - query := `SELECT mxid, dcid, discord_token, management_room, space_room, dm_space_room, read_state_version FROM "user" WHERE dcid=$1` - return uq.New().Scan(uq.db.QueryRow(query, id)) -} - -func (uq *UserQuery) GetAllWithToken() []*User { - query := ` - SELECT mxid, dcid, discord_token, management_room, space_room, dm_space_room, read_state_version - FROM "user" WHERE discord_token IS NOT NULL - ` - rows, err := uq.db.Query(query) - if err != nil || rows == nil { - return nil - } - - var users []*User - for rows.Next() { - user := uq.New().Scan(rows) - if user != nil { - users = append(users, user) - } - } - return users -} - -type User struct { - db *Database - log log.Logger - - MXID id.UserID - DiscordID string - DiscordToken string - ManagementRoom id.RoomID - SpaceRoom id.RoomID - DMSpaceRoom id.RoomID - - ReadStateVersion int -} - -func (u *User) Scan(row dbutil.Scannable) *User { - var discordID, managementRoom, spaceRoom, dmSpaceRoom, discordToken sql.NullString - err := row.Scan(&u.MXID, &discordID, &discordToken, &managementRoom, &spaceRoom, &dmSpaceRoom, &u.ReadStateVersion) - if err != nil { - if err != sql.ErrNoRows { - u.log.Errorln("Database scan failed:", err) - panic(err) - } - return nil - } - u.DiscordID = discordID.String - u.DiscordToken = discordToken.String - u.ManagementRoom = id.RoomID(managementRoom.String) - u.SpaceRoom = id.RoomID(spaceRoom.String) - u.DMSpaceRoom = id.RoomID(dmSpaceRoom.String) - return u -} - -func (u *User) Insert() { - query := `INSERT INTO "user" (mxid, dcid, discord_token, management_room, space_room, dm_space_room, read_state_version) VALUES ($1, $2, $3, $4, $5, $6, $7)` - _, err := u.db.Exec(query, u.MXID, strPtr(u.DiscordID), strPtr(u.DiscordToken), strPtr(string(u.ManagementRoom)), strPtr(string(u.SpaceRoom)), strPtr(string(u.DMSpaceRoom)), u.ReadStateVersion) - if err != nil { - u.log.Warnfln("Failed to insert %s: %v", u.MXID, err) - panic(err) - } -} - -func (u *User) Update() { - query := `UPDATE "user" SET dcid=$1, discord_token=$2, management_room=$3, space_room=$4, dm_space_room=$5, read_state_version=$6 WHERE mxid=$7` - _, err := u.db.Exec(query, strPtr(u.DiscordID), strPtr(u.DiscordToken), strPtr(string(u.ManagementRoom)), strPtr(string(u.SpaceRoom)), strPtr(string(u.DMSpaceRoom)), u.ReadStateVersion, u.MXID) - if err != nil { - u.log.Warnfln("Failed to update %q: %v", u.MXID, err) - panic(err) - } -} diff --git a/database/userportal.go b/database/userportal.go deleted file mode 100644 index 783b83d..0000000 --- a/database/userportal.go +++ /dev/null @@ -1,140 +0,0 @@ -package database - -import ( - "database/sql" - "errors" - "time" - - "go.mau.fi/util/dbutil" - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/id" -) - -const ( - UserPortalTypeDM = "dm" - UserPortalTypeGuild = "guild" - UserPortalTypeThread = "thread" -) - -type UserPortal struct { - DiscordID string - Type string - Timestamp time.Time - InSpace bool -} - -func (up UserPortal) Scan(l log.Logger, row dbutil.Scannable) *UserPortal { - var ts int64 - err := row.Scan(&up.DiscordID, &up.Type, &ts, &up.InSpace) - if err != nil { - l.Errorln("Error scanning user portal:", err) - panic(err) - } - up.Timestamp = time.UnixMilli(ts).UTC() - return &up -} - -func (u *User) scanUserPortals(rows dbutil.Rows) []UserPortal { - var ups []UserPortal - for rows.Next() { - up := UserPortal{}.Scan(u.log, rows) - if up != nil { - ups = append(ups, *up) - } - } - return ups -} - -func (db *Database) GetUsersInPortal(channelID string) []id.UserID { - rows, err := db.Query("SELECT user_mxid FROM user_portal WHERE discord_id=$1", channelID) - if err != nil { - db.Portal.log.Errorln("Failed to get users in portal:", err) - } - var users []id.UserID - for rows.Next() { - var mxid id.UserID - err = rows.Scan(&mxid) - if err != nil { - db.Portal.log.Errorln("Failed to scan user in portal:", err) - } else { - users = append(users, mxid) - } - } - return users -} - -func (u *User) GetPortals() []UserPortal { - rows, err := u.db.Query("SELECT discord_id, type, timestamp, in_space FROM user_portal WHERE user_mxid=$1", u.MXID) - if err != nil { - u.log.Errorln("Failed to get portals:", err) - panic(err) - } - return u.scanUserPortals(rows) -} - -func (u *User) IsInSpace(discordID string) (isIn bool) { - query := `SELECT in_space FROM user_portal WHERE user_mxid=$1 AND discord_id=$2` - err := u.db.QueryRow(query, u.MXID, discordID).Scan(&isIn) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - u.log.Warnfln("Failed to scan in_space for %s/%s: %v", u.MXID, discordID, err) - panic(err) - } - return -} - -func (u *User) IsInPortal(discordID string) (isIn bool) { - query := `SELECT EXISTS(SELECT 1 FROM user_portal WHERE user_mxid=$1 AND discord_id=$2)` - err := u.db.QueryRow(query, u.MXID, discordID).Scan(&isIn) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - u.log.Warnfln("Failed to scan in_space for %s/%s: %v", u.MXID, discordID, err) - panic(err) - } - return -} - -func (u *User) MarkInPortal(portal UserPortal) { - query := ` - INSERT INTO user_portal (discord_id, type, user_mxid, timestamp, in_space) - VALUES ($1, $2, $3, $4, $5) - ON CONFLICT (discord_id, user_mxid) DO UPDATE - SET timestamp=excluded.timestamp, in_space=excluded.in_space - ` - _, err := u.db.Exec(query, portal.DiscordID, portal.Type, u.MXID, portal.Timestamp.UnixMilli(), portal.InSpace) - if err != nil { - u.log.Errorfln("Failed to insert user portal %s/%s: %v", u.MXID, portal.DiscordID, err) - panic(err) - } -} - -func (u *User) MarkNotInPortal(discordID string) { - query := `DELETE FROM user_portal WHERE user_mxid=$1 AND discord_id=$2` - _, err := u.db.Exec(query, u.MXID, discordID) - if err != nil { - u.log.Errorfln("Failed to remove user portal %s/%s: %v", u.MXID, discordID, err) - panic(err) - } -} - -func (u *User) PortalHasOtherUsers(discordID string) (hasOtherUsers bool) { - query := `SELECT COUNT(*) > 0 FROM user_portal WHERE user_mxid<>$1 AND discord_id=$2` - err := u.db.QueryRow(query, u.MXID, discordID).Scan(&hasOtherUsers) - if err != nil { - u.log.Errorfln("Failed to check if %s has users other than %s: %v", discordID, u.MXID, err) - panic(err) - } - return -} - -func (u *User) PrunePortalList(beforeTS time.Time) []UserPortal { - query := ` - DELETE FROM user_portal - WHERE user_mxid=$1 AND timestamp<$2 AND type IN ('dm', 'guild') - RETURNING discord_id, type, timestamp, in_space - ` - rows, err := u.db.Query(query, u.MXID, beforeTS.UnixMilli()) - if err != nil { - u.log.Errorln("Failed to prune user guild list:", err) - panic(err) - } - return u.scanUserPortals(rows) -} diff --git a/directmedia.go b/directmedia.go deleted file mode 100644 index 4499c1a..0000000 --- a/directmedia.go +++ /dev/null @@ -1,662 +0,0 @@ -// mautrix-discord - A Matrix-Discord puppeting bridge. -// Copyright (C) 2024 Tulir Asokan -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package main - -import ( - "context" - "crypto/sha256" - "encoding/binary" - "encoding/hex" - "errors" - "fmt" - "io" - "mime" - "mime/multipart" - "net" - "net/http" - "net/textproto" - "net/url" - "os" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/bwmarrin/discordgo" - "github.com/gorilla/mux" - "github.com/rs/zerolog" - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/federation" - "maunium.net/go/mautrix/id" - - "go.mau.fi/mautrix-discord/config" - "go.mau.fi/mautrix-discord/database" -) - -type DirectMediaAPI struct { - bridge *DiscordBridge - ks *federation.KeyServer - cfg config.DirectMedia - log zerolog.Logger - proxy http.Client - - signatureKey [32]byte - - attachmentCache map[AttachmentCacheKey]AttachmentCacheValue - attachmentCacheLock sync.Mutex -} - -type AttachmentCacheKey struct { - ChannelID uint64 - AttachmentID uint64 -} - -type AttachmentCacheValue struct { - URL string - Expiry time.Time -} - -func newDirectMediaAPI(br *DiscordBridge) *DirectMediaAPI { - if !br.Config.Bridge.DirectMedia.Enabled { - return nil - } - dma := &DirectMediaAPI{ - bridge: br, - cfg: br.Config.Bridge.DirectMedia, - log: br.ZLog.With().Str("component", "direct media").Logger(), - proxy: http.Client{ - Transport: &http.Transport{ - DialContext: (&net.Dialer{Timeout: 10 * time.Second}).DialContext, - TLSHandshakeTimeout: 10 * time.Second, - ForceAttemptHTTP2: false, - }, - Timeout: 60 * time.Second, - }, - attachmentCache: make(map[AttachmentCacheKey]AttachmentCacheValue), - } - r := br.AS.Router - - parsed, err := federation.ParseSynapseKey(dma.cfg.ServerKey) - if err != nil { - dma.log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to parse server key") - os.Exit(11) - return nil - } - dma.signatureKey = sha256.Sum256(parsed.Priv.Seed()) - dma.ks = &federation.KeyServer{ - KeyProvider: &federation.StaticServerKey{ - ServerName: dma.cfg.ServerName, - Key: parsed, - }, - WellKnownTarget: dma.cfg.WellKnownResponse, - Version: federation.ServerVersion{ - Name: br.Name, - Version: br.Version, - }, - } - if dma.ks.WellKnownTarget == "" { - dma.ks.WellKnownTarget = fmt.Sprintf("%s:443", dma.cfg.ServerName) - } - federationRouter := r.PathPrefix("/_matrix/federation").Subrouter() - mediaRouter := r.PathPrefix("/_matrix/media").Subrouter() - clientMediaRouter := r.PathPrefix("/_matrix/client/v1/media").Subrouter() - var reqIDCounter atomic.Uint64 - middleware := func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization") - log := dma.log.With(). - Str("remote_addr", r.RemoteAddr). - Str("request_path", r.URL.Path). - Uint64("req_id", reqIDCounter.Add(1)). - Logger() - next.ServeHTTP(w, r.WithContext(log.WithContext(r.Context()))) - }) - } - mediaRouter.Use(middleware) - federationRouter.Use(middleware) - clientMediaRouter.Use(middleware) - addRoutes := func(version string) { - mediaRouter.HandleFunc("/"+version+"/download/{serverName}/{mediaID}", dma.DownloadMedia).Methods(http.MethodGet) - mediaRouter.HandleFunc("/"+version+"/download/{serverName}/{mediaID}/{fileName}", dma.DownloadMedia).Methods(http.MethodGet) - mediaRouter.HandleFunc("/"+version+"/thumbnail/{serverName}/{mediaID}", dma.DownloadMedia).Methods(http.MethodGet) - mediaRouter.HandleFunc("/"+version+"/upload/{serverName}/{mediaID}", dma.UploadNotSupported).Methods(http.MethodPut) - mediaRouter.HandleFunc("/"+version+"/upload", dma.UploadNotSupported).Methods(http.MethodPost) - mediaRouter.HandleFunc("/"+version+"/create", dma.UploadNotSupported).Methods(http.MethodPost) - mediaRouter.HandleFunc("/"+version+"/config", dma.UploadNotSupported).Methods(http.MethodGet) - mediaRouter.HandleFunc("/"+version+"/preview_url", dma.PreviewURLNotSupported).Methods(http.MethodGet) - } - clientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}", dma.DownloadMedia).Methods(http.MethodGet) - clientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}/{fileName}", dma.DownloadMedia).Methods(http.MethodGet) - clientMediaRouter.HandleFunc("/thumbnail/{serverName}/{mediaID}", dma.DownloadMedia).Methods(http.MethodGet) - clientMediaRouter.HandleFunc("/upload/{serverName}/{mediaID}", dma.UploadNotSupported).Methods(http.MethodPut) - clientMediaRouter.HandleFunc("/upload", dma.UploadNotSupported).Methods(http.MethodPost) - clientMediaRouter.HandleFunc("/create", dma.UploadNotSupported).Methods(http.MethodPost) - clientMediaRouter.HandleFunc("/config", dma.UploadNotSupported).Methods(http.MethodGet) - clientMediaRouter.HandleFunc("/preview_url", dma.PreviewURLNotSupported).Methods(http.MethodGet) - addRoutes("v3") - addRoutes("r0") - addRoutes("v1") - federationRouter.HandleFunc("/v1/media/download/{mediaID}", dma.DownloadMedia).Methods(http.MethodGet) - federationRouter.HandleFunc("/v1/version", dma.ks.GetServerVersion).Methods(http.MethodGet) - mediaRouter.NotFoundHandler = http.HandlerFunc(dma.UnknownEndpoint) - mediaRouter.MethodNotAllowedHandler = http.HandlerFunc(dma.UnsupportedMethod) - federationRouter.NotFoundHandler = http.HandlerFunc(dma.UnknownEndpoint) - federationRouter.MethodNotAllowedHandler = http.HandlerFunc(dma.UnsupportedMethod) - dma.ks.Register(r) - - return dma -} - -func (dma *DirectMediaAPI) makeMXC(data MediaIDData) id.ContentURI { - return id.ContentURI{ - Homeserver: dma.cfg.ServerName, - FileID: data.Wrap().SignedString(dma.signatureKey), - } -} - -func parseExpiryTS(addr string) time.Time { - parsedURL, err := url.Parse(addr) - if err != nil { - return time.Time{} - } - tsBytes, err := hex.DecodeString(parsedURL.Query().Get("ex")) - if err != nil || len(tsBytes) != 4 { - return time.Time{} - } - parsedTS := int64(binary.BigEndian.Uint32(tsBytes)) - if parsedTS > time.Now().Unix() && parsedTS < time.Now().Add(365*24*time.Hour).Unix() { - return time.Unix(parsedTS, 0) - } - return time.Time{} -} - -func (dma *DirectMediaAPI) addAttachmentToCache(channelID uint64, att *discordgo.MessageAttachment) time.Time { - attachmentID, err := strconv.ParseUint(att.ID, 10, 64) - if err != nil { - return time.Time{} - } - expiry := parseExpiryTS(att.URL) - if expiry.IsZero() { - expiry = time.Now().Add(24 * time.Hour) - } - dma.attachmentCache[AttachmentCacheKey{ - ChannelID: channelID, - AttachmentID: attachmentID, - }] = AttachmentCacheValue{ - URL: att.URL, - Expiry: expiry, - } - return expiry -} - -func (dma *DirectMediaAPI) AttachmentMXC(channelID, messageID string, att *discordgo.MessageAttachment) (mxc id.ContentURI) { - if dma == nil { - return - } - channelIDInt, err := strconv.ParseUint(channelID, 10, 64) - if err != nil { - dma.log.Warn().Str("channel_id", channelID).Msg("Got non-integer channel ID") - return - } - messageIDInt, err := strconv.ParseUint(messageID, 10, 64) - if err != nil { - dma.log.Warn().Str("message_id", messageID).Msg("Got non-integer message ID") - return - } - attachmentIDInt, err := strconv.ParseUint(att.ID, 10, 64) - if err != nil { - dma.log.Warn().Str("attachment_id", att.ID).Msg("Got non-integer attachment ID") - return - } - dma.attachmentCacheLock.Lock() - dma.addAttachmentToCache(channelIDInt, att) - dma.attachmentCacheLock.Unlock() - return dma.makeMXC(&AttachmentMediaData{ - ChannelID: channelIDInt, - MessageID: messageIDInt, - AttachmentID: attachmentIDInt, - }) -} - -func (dma *DirectMediaAPI) EmojiMXC(emojiID, name string, animated bool) (mxc id.ContentURI) { - if dma == nil { - return - } - emojiIDInt, err := strconv.ParseUint(emojiID, 10, 64) - if err != nil { - dma.log.Warn().Str("emoji_id", emojiID).Msg("Got non-integer emoji ID") - return - } - return dma.makeMXC(&EmojiMediaData{ - EmojiMediaDataInner: EmojiMediaDataInner{ - EmojiID: emojiIDInt, - Animated: animated, - }, - Name: name, - }) -} - -func (dma *DirectMediaAPI) StickerMXC(stickerID string, format discordgo.StickerFormat) (mxc id.ContentURI) { - if dma == nil { - return - } - stickerIDInt, err := strconv.ParseUint(stickerID, 10, 64) - if err != nil { - dma.log.Warn().Str("sticker_id", stickerID).Msg("Got non-integer sticker ID") - return - } else if format > 255 || format < 0 { - dma.log.Warn().Int("format", int(format)).Msg("Got invalid sticker format") - return - } - return dma.makeMXC(&StickerMediaData{ - StickerID: stickerIDInt, - Format: byte(format), - }) -} - -func (dma *DirectMediaAPI) AvatarMXC(guildID, userID, avatarID string) (mxc id.ContentURI) { - if dma == nil { - return - } - animated := strings.HasPrefix(avatarID, "a_") - avatarIDBytes, err := hex.DecodeString(strings.TrimPrefix(avatarID, "a_")) - if err != nil { - dma.log.Warn().Str("avatar_id", avatarID).Msg("Got non-hex avatar ID") - return - } else if len(avatarIDBytes) != 16 { - dma.log.Warn().Str("avatar_id", avatarID).Msg("Got invalid avatar ID length") - return - } - avatarIDArray := [16]byte(avatarIDBytes) - userIDInt, err := strconv.ParseUint(userID, 10, 64) - if err != nil { - dma.log.Warn().Str("user_id", userID).Msg("Got non-integer user ID") - return - } - if guildID != "" { - guildIDInt, err := strconv.ParseUint(guildID, 10, 64) - if err != nil { - dma.log.Warn().Str("guild_id", guildID).Msg("Got non-integer guild ID") - return - } - return dma.makeMXC(&GuildMemberAvatarMediaData{ - GuildID: guildIDInt, - UserID: userIDInt, - AvatarID: avatarIDArray, - Animated: animated, - }) - } else { - return dma.makeMXC(&UserAvatarMediaData{ - UserID: userIDInt, - AvatarID: avatarIDArray, - Animated: animated, - }) - } -} - -type RespError struct { - Code string - Message string - Status int -} - -func (re *RespError) Error() string { - return re.Message -} - -var ErrNoUsersWithAccessFound = errors.New("no users found to fetch message") -var ErrAttachmentNotFound = errors.New("attachment not found") - -func (dma *DirectMediaAPI) fetchNewAttachmentURL(ctx context.Context, meta *AttachmentMediaData) (string, time.Time, error) { - var client *discordgo.Session - channelIDStr := strconv.FormatUint(meta.ChannelID, 10) - portal := dma.bridge.GetExistingPortalByID(database.PortalKey{ChannelID: channelIDStr}) - var users []id.UserID - if portal != nil && portal.GuildID != "" { - users = dma.bridge.DB.GetUsersInPortal(portal.GuildID) - } else { - users = dma.bridge.DB.GetUsersInPortal(channelIDStr) - } - for _, userID := range users { - user := dma.bridge.GetCachedUserByMXID(userID) - if user == nil || user.Session == nil { - continue - } - perms, err := user.Session.State.UserChannelPermissions(user.DiscordID, channelIDStr) - if err == nil && perms&discordgo.PermissionViewChannel == 0 { - continue - } - if client == nil || err == nil { - client = user.Session - if !client.IsUser { - break - } - } - } - if client == nil { - return "", time.Time{}, ErrNoUsersWithAccessFound - } - var msgs []*discordgo.Message - var err error - messageIDStr := strconv.FormatUint(meta.MessageID, 10) - if client.IsUser { - var refs []discordgo.RequestOption - if portal != nil { - refs = append(refs, discordgo.WithChannelReferer(portal.GuildID, channelIDStr)) - } - msgs, err = client.ChannelMessages(channelIDStr, 5, "", "", messageIDStr, refs...) - } else { - var msg *discordgo.Message - msg, err = client.ChannelMessage(channelIDStr, messageIDStr) - msgs = []*discordgo.Message{msg} - } - if err != nil { - return "", time.Time{}, fmt.Errorf("failed to fetch message: %w", err) - } - attachmentIDStr := strconv.FormatUint(meta.AttachmentID, 10) - var url string - var expiry time.Time - for _, item := range msgs { - for _, att := range item.Attachments { - thisExpiry := dma.addAttachmentToCache(meta.ChannelID, att) - if att.ID == attachmentIDStr { - url = att.URL - expiry = thisExpiry - } - } - } - if url == "" { - return "", time.Time{}, ErrAttachmentNotFound - } - return url, expiry, nil -} - -func (dma *DirectMediaAPI) GetEmojiInfo(contentURI id.ContentURI) *EmojiMediaData { - if dma == nil || contentURI.IsEmpty() || contentURI.Homeserver != dma.cfg.ServerName { - return nil - } - mediaID, err := ParseMediaID(contentURI.FileID, dma.signatureKey) - if err != nil { - return nil - } - emojiData, ok := mediaID.Data.(*EmojiMediaData) - if !ok { - return nil - } - return emojiData - -} - -func (dma *DirectMediaAPI) getMediaURL(ctx context.Context, encodedMediaID string) (url string, expiry time.Time, err error) { - var mediaID *MediaID - mediaID, err = ParseMediaID(encodedMediaID, dma.signatureKey) - if err != nil { - err = &RespError{ - Code: mautrix.MNotFound.ErrCode, - Message: err.Error(), - Status: http.StatusNotFound, - } - return - } - switch mediaData := mediaID.Data.(type) { - case *AttachmentMediaData: - dma.attachmentCacheLock.Lock() - defer dma.attachmentCacheLock.Unlock() - cached, ok := dma.attachmentCache[mediaData.CacheKey()] - if ok && time.Until(cached.Expiry) > 5*time.Minute { - return cached.URL, cached.Expiry, nil - } - zerolog.Ctx(ctx).Debug(). - Uint64("channel_id", mediaData.ChannelID). - Uint64("message_id", mediaData.MessageID). - Uint64("attachment_id", mediaData.AttachmentID). - Msg("Refreshing attachment URL") - url, expiry, err = dma.fetchNewAttachmentURL(ctx, mediaData) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to refresh attachment URL") - msg := "Failed to refresh attachment URL" - if errors.Is(err, ErrNoUsersWithAccessFound) { - msg = "No users found with access to the channel" - } else if errors.Is(err, ErrAttachmentNotFound) { - msg = "Attachment not found in message. Perhaps it was deleted?" - } - err = &RespError{ - Code: mautrix.MNotFound.ErrCode, - Message: msg, - Status: http.StatusNotFound, - } - } else { - zerolog.Ctx(ctx).Debug().Time("expiry", expiry).Msg("Successfully refreshed attachment URL") - } - case *EmojiMediaData: - if mediaData.Animated { - url = discordgo.EndpointEmojiAnimated(strconv.FormatUint(mediaData.EmojiID, 10)) - } else { - url = discordgo.EndpointEmoji(strconv.FormatUint(mediaData.EmojiID, 10)) - } - case *StickerMediaData: - url = discordgo.EndpointStickerImage( - strconv.FormatUint(mediaData.StickerID, 10), - discordgo.StickerFormat(mediaData.Format), - ) - case *UserAvatarMediaData: - if mediaData.Animated { - url = discordgo.EndpointUserAvatarAnimated( - strconv.FormatUint(mediaData.UserID, 10), - fmt.Sprintf("a_%x", mediaData.AvatarID), - ) - } else { - url = discordgo.EndpointUserAvatar( - strconv.FormatUint(mediaData.UserID, 10), - fmt.Sprintf("%x", mediaData.AvatarID), - ) - } - case *GuildMemberAvatarMediaData: - if mediaData.Animated { - url = discordgo.EndpointGuildMemberAvatarAnimated( - strconv.FormatUint(mediaData.GuildID, 10), - strconv.FormatUint(mediaData.UserID, 10), - fmt.Sprintf("a_%x", mediaData.AvatarID), - ) - } else { - url = discordgo.EndpointGuildMemberAvatar( - strconv.FormatUint(mediaData.GuildID, 10), - strconv.FormatUint(mediaData.UserID, 10), - fmt.Sprintf("%x", mediaData.AvatarID), - ) - } - default: - zerolog.Ctx(ctx).Error().Type("media_data_type", mediaData).Msg("Unrecognized media data struct") - err = &RespError{ - Code: "M_UNKNOWN", - Message: "Unrecognized media data struct", - Status: http.StatusInternalServerError, - } - } - return -} - -func (dma *DirectMediaAPI) proxyDownload(ctx context.Context, w http.ResponseWriter, url, fileName string) { - log := zerolog.Ctx(ctx) - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - log.Err(err).Str("url", url).Msg("Failed to create proxy request") - jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ - ErrCode: "M_UNKNOWN", - Err: "Failed to create proxy request", - }) - return - } - for key, val := range discordgo.DroidDownloadHeaders { - req.Header.Set(key, val) - } - resp, err := dma.proxy.Do(req) - defer func() { - if resp != nil && resp.Body != nil { - _ = resp.Body.Close() - } - }() - if err != nil { - log.Err(err).Str("url", url).Msg("Failed to proxy download") - jsonResponse(w, http.StatusServiceUnavailable, &mautrix.RespError{ - ErrCode: "M_UNKNOWN", - Err: "Failed to proxy download", - }) - return - } else if resp.StatusCode != http.StatusOK { - log.Warn().Str("url", url).Int("status", resp.StatusCode).Msg("Unexpected status code proxying download") - jsonResponse(w, resp.StatusCode, &mautrix.RespError{ - ErrCode: "M_UNKNOWN", - Err: "Unexpected status code proxying download", - }) - return - } - w.Header()["Content-Type"] = resp.Header["Content-Type"] - w.Header()["Content-Length"] = resp.Header["Content-Length"] - w.Header()["Last-Modified"] = resp.Header["Last-Modified"] - w.Header()["Cache-Control"] = resp.Header["Cache-Control"] - contentDisposition := "attachment" - switch resp.Header.Get("Content-Type") { - case "text/css", "text/plain", "text/csv", "application/json", "application/ld+json", "image/jpeg", "image/gif", - "image/png", "image/apng", "image/webp", "image/avif", "video/mp4", "video/webm", "video/ogg", "video/quicktime", - "audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave", "audio/wav", "audio/x-wav", - "audio/x-pn-wav", "audio/flac", "audio/x-flac", "application/pdf": - contentDisposition = "inline" - } - if fileName != "" { - contentDisposition = mime.FormatMediaType(contentDisposition, map[string]string{ - "filename": fileName, - }) - } - w.Header().Set("Content-Disposition", contentDisposition) - w.WriteHeader(http.StatusOK) - _, err = io.Copy(w, resp.Body) - if err != nil { - log.Debug().Err(err).Msg("Failed to write proxy response") - } -} - -func (dma *DirectMediaAPI) DownloadMedia(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - log := zerolog.Ctx(ctx) - isNewFederation := strings.HasPrefix(r.URL.Path, "/_matrix/federation/v1/media/download/") - vars := mux.Vars(r) - if !isNewFederation && vars["serverName"] != dma.cfg.ServerName { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - ErrCode: mautrix.MNotFound.ErrCode, - Err: fmt.Sprintf("This is a Discord media proxy for %q, other media downloads are not available here", dma.cfg.ServerName), - }) - return - } - // TODO check destination header in X-Matrix auth when isNewFederation - - url, expiresAt, err := dma.getMediaURL(ctx, vars["mediaID"]) - if err != nil { - var respError *RespError - if errors.As(err, &respError) { - jsonResponse(w, respError.Status, &mautrix.RespError{ - ErrCode: respError.Code, - Err: respError.Message, - }) - } else { - log.Err(err).Str("media_id", vars["mediaID"]).Msg("Failed to get media URL") - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - ErrCode: mautrix.MNotFound.ErrCode, - Err: "Media not found", - }) - } - return - } - if isNewFederation { - mp := multipart.NewWriter(w) - w.Header().Set("Content-Type", strings.Replace(mp.FormDataContentType(), "form-data", "mixed", 1)) - var metaPart io.Writer - metaPart, err = mp.CreatePart(textproto.MIMEHeader{ - "Content-Type": {"application/json"}, - }) - if err != nil { - log.Err(err).Msg("Failed to create multipart metadata field") - return - } - _, err = metaPart.Write([]byte(`{}`)) - if err != nil { - log.Err(err).Msg("Failed to write multipart metadata field") - return - } - _, err = mp.CreatePart(textproto.MIMEHeader{ - "Location": {url}, - }) - if err != nil { - log.Err(err).Msg("Failed to create multipart redirect field") - return - } - err = mp.Close() - if err != nil { - log.Err(err).Msg("Failed to close multipart writer") - return - } - return - } - // Proxy if the config allows proxying and the request doesn't allow redirects. - // In any other case, redirect to the Discord CDN. - if dma.cfg.AllowProxy && r.URL.Query().Get("allow_redirect") != "true" { - dma.proxyDownload(ctx, w, url, vars["fileName"]) - return - } - w.Header().Set("Location", url) - expirySeconds := (time.Until(expiresAt) - 5*time.Minute).Seconds() - if expiresAt.IsZero() { - w.Header().Set("Cache-Control", "public, max-age=31536000, immutable") - } else if expirySeconds > 0 { - cacheControl := fmt.Sprintf("public, max-age=%d, immutable", int(expirySeconds)) - w.Header().Set("Cache-Control", cacheControl) - } else { - w.Header().Set("Cache-Control", "no-store") - } - w.WriteHeader(http.StatusTemporaryRedirect) -} - -func (dma *DirectMediaAPI) UploadNotSupported(w http.ResponseWriter, r *http.Request) { - jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ - ErrCode: mautrix.MUnrecognized.ErrCode, - Err: "This bridge only supports proxying Discord media downloads and does not support media uploads.", - }) -} - -func (dma *DirectMediaAPI) PreviewURLNotSupported(w http.ResponseWriter, r *http.Request) { - jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ - ErrCode: mautrix.MUnrecognized.ErrCode, - Err: "This bridge only supports proxying Discord media downloads and does not support URL previews.", - }) -} - -func (dma *DirectMediaAPI) UnknownEndpoint(w http.ResponseWriter, r *http.Request) { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - ErrCode: mautrix.MUnrecognized.ErrCode, - Err: "Unrecognized endpoint", - }) -} - -func (dma *DirectMediaAPI) UnsupportedMethod(w http.ResponseWriter, r *http.Request) { - jsonResponse(w, http.StatusMethodNotAllowed, &mautrix.RespError{ - ErrCode: mautrix.MUnrecognized.ErrCode, - Err: "Invalid method for endpoint", - }) -} diff --git a/directmedia_id.go b/directmedia_id.go deleted file mode 100644 index 92b935a..0000000 --- a/directmedia_id.go +++ /dev/null @@ -1,287 +0,0 @@ -// mautrix-discord - A Matrix-Discord puppeting bridge. -// Copyright (C) 2024 Tulir Asokan -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package main - -import ( - "bytes" - "crypto/hmac" - "crypto/sha256" - "encoding/base64" - "encoding/binary" - "errors" - "fmt" - "io" -) - -const MediaIDPrefix = "\U0001F408DISCORD" -const MediaIDVersion = 1 - -type MediaIDClass uint8 - -const ( - MediaIDClassAttachment MediaIDClass = 1 - MediaIDClassEmoji MediaIDClass = 2 - MediaIDClassSticker MediaIDClass = 3 - MediaIDClassUserAvatar MediaIDClass = 4 - MediaIDClassGuildMemberAvatar MediaIDClass = 5 -) - -type MediaIDData interface { - Write(to io.Writer) - Read(from io.Reader) error - Size() int - Wrap() *MediaID -} - -type MediaID struct { - Version uint8 - TypeClass MediaIDClass - Data MediaIDData -} - -func ParseMediaID(id string, key [32]byte) (*MediaID, error) { - data, err := base64.RawURLEncoding.DecodeString(id) - if err != nil { - return nil, fmt.Errorf("failed to decode base64: %w", err) - } - hasher := hmac.New(sha256.New, key[:]) - checksum := data[len(data)-TruncatedHashLength:] - data = data[:len(data)-TruncatedHashLength] - hasher.Write(data) - if !hmac.Equal(checksum, hasher.Sum(nil)[:TruncatedHashLength]) { - return nil, ErrMediaIDChecksumMismatch - } - mid := &MediaID{} - err = mid.Read(bytes.NewReader(data)) - if err != nil { - return nil, fmt.Errorf("failed to parse media ID: %w", err) - } - return mid, nil -} - -const TruncatedHashLength = 16 - -func (mid *MediaID) SignedString(key [32]byte) string { - buf := bytes.NewBuffer(make([]byte, 0, mid.Size())) - mid.Write(buf) - hasher := hmac.New(sha256.New, key[:]) - hasher.Write(buf.Bytes()) - buf.Write(hasher.Sum(nil)[:TruncatedHashLength]) - return base64.RawURLEncoding.EncodeToString(buf.Bytes()) -} - -func (mid *MediaID) Write(to io.Writer) { - _, _ = to.Write([]byte(MediaIDPrefix)) - _ = binary.Write(to, binary.BigEndian, mid.Version) - _ = binary.Write(to, binary.BigEndian, mid.TypeClass) - mid.Data.Write(to) -} - -func (mid *MediaID) Size() int { - return len(MediaIDPrefix) + 2 + mid.Data.Size() + TruncatedHashLength -} - -var ( - ErrInvalidMediaID = errors.New("invalid media ID") - ErrMediaIDChecksumMismatch = errors.New("invalid checksum in media ID") - ErrUnsupportedMediaID = errors.New("unsupported media ID") -) - -func (mid *MediaID) Read(from io.Reader) error { - prefix := make([]byte, len(MediaIDPrefix)) - _, err := io.ReadFull(from, prefix) - if err != nil || !bytes.Equal(prefix, []byte(MediaIDPrefix)) { - return fmt.Errorf("%w: prefix not found", ErrInvalidMediaID) - } - versionAndClass := make([]byte, 2) - _, err = io.ReadFull(from, versionAndClass) - if err != nil { - return fmt.Errorf("%w: version and class not found", ErrInvalidMediaID) - } else if versionAndClass[0] != MediaIDVersion { - return fmt.Errorf("%w: unknown version %d", ErrUnsupportedMediaID, versionAndClass[0]) - } - switch MediaIDClass(versionAndClass[1]) { - case MediaIDClassAttachment: - mid.Data = &AttachmentMediaData{} - case MediaIDClassEmoji: - mid.Data = &EmojiMediaData{} - case MediaIDClassSticker: - mid.Data = &StickerMediaData{} - case MediaIDClassUserAvatar: - mid.Data = &UserAvatarMediaData{} - case MediaIDClassGuildMemberAvatar: - mid.Data = &GuildMemberAvatarMediaData{} - default: - return fmt.Errorf("%w: unrecognized type class %d", ErrUnsupportedMediaID, versionAndClass[1]) - } - err = mid.Data.Read(from) - if err != nil { - return fmt.Errorf("failed to parse media ID data: %w", err) - } - return nil -} - -type AttachmentMediaData struct { - ChannelID uint64 - MessageID uint64 - AttachmentID uint64 -} - -func (amd *AttachmentMediaData) Write(to io.Writer) { - _ = binary.Write(to, binary.BigEndian, amd) -} - -func (amd *AttachmentMediaData) Read(from io.Reader) (err error) { - return binary.Read(from, binary.BigEndian, amd) -} - -func (amd *AttachmentMediaData) Size() int { - return binary.Size(amd) -} - -func (amd *AttachmentMediaData) Wrap() *MediaID { - return &MediaID{ - Version: MediaIDVersion, - TypeClass: MediaIDClassAttachment, - Data: amd, - } -} - -func (amd *AttachmentMediaData) CacheKey() AttachmentCacheKey { - return AttachmentCacheKey{ - ChannelID: amd.ChannelID, - AttachmentID: amd.AttachmentID, - } -} - -type StickerMediaData struct { - StickerID uint64 - Format uint8 -} - -func (smd *StickerMediaData) Write(to io.Writer) { - _ = binary.Write(to, binary.BigEndian, smd) -} - -func (smd *StickerMediaData) Read(from io.Reader) error { - return binary.Read(from, binary.BigEndian, smd) -} - -func (smd *StickerMediaData) Size() int { - return binary.Size(smd) -} - -func (smd *StickerMediaData) Wrap() *MediaID { - return &MediaID{ - Version: MediaIDVersion, - TypeClass: MediaIDClassSticker, - Data: smd, - } -} - -type EmojiMediaDataInner struct { - EmojiID uint64 - Animated bool -} - -type EmojiMediaData struct { - EmojiMediaDataInner - Name string -} - -func (emd *EmojiMediaData) Write(to io.Writer) { - _ = binary.Write(to, binary.BigEndian, &emd.EmojiMediaDataInner) - _, _ = to.Write([]byte(emd.Name)) -} - -func (emd *EmojiMediaData) Read(from io.Reader) (err error) { - err = binary.Read(from, binary.BigEndian, &emd.EmojiMediaDataInner) - if err != nil { - return - } - name, err := io.ReadAll(from) - if err != nil { - return - } - emd.Name = string(name) - return -} - -func (emd *EmojiMediaData) Size() int { - return binary.Size(&emd.EmojiMediaDataInner) + len(emd.Name) -} - -func (emd *EmojiMediaData) Wrap() *MediaID { - return &MediaID{ - Version: MediaIDVersion, - TypeClass: MediaIDClassEmoji, - Data: emd, - } -} - -type UserAvatarMediaData struct { - UserID uint64 - Animated bool - AvatarID [16]byte -} - -func (uamd *UserAvatarMediaData) Write(to io.Writer) { - _ = binary.Write(to, binary.BigEndian, uamd) -} - -func (uamd *UserAvatarMediaData) Read(from io.Reader) error { - return binary.Read(from, binary.BigEndian, uamd) -} - -func (uamd *UserAvatarMediaData) Size() int { - return binary.Size(uamd) -} - -func (uamd *UserAvatarMediaData) Wrap() *MediaID { - return &MediaID{ - Version: MediaIDVersion, - TypeClass: MediaIDClassUserAvatar, - Data: uamd, - } -} - -type GuildMemberAvatarMediaData struct { - GuildID uint64 - UserID uint64 - Animated bool - AvatarID [16]byte -} - -func (guamd *GuildMemberAvatarMediaData) Write(to io.Writer) { - _ = binary.Write(to, binary.BigEndian, guamd) -} - -func (guamd *GuildMemberAvatarMediaData) Read(from io.Reader) error { - return binary.Read(from, binary.BigEndian, guamd) -} - -func (guamd *GuildMemberAvatarMediaData) Size() int { - return binary.Size(guamd) -} - -func (guamd *GuildMemberAvatarMediaData) Wrap() *MediaID { - return &MediaID{ - Version: MediaIDVersion, - TypeClass: MediaIDClassGuildMemberAvatar, - Data: guamd, - } -} diff --git a/discord.go b/discord.go deleted file mode 100644 index 37cddbc..0000000 --- a/discord.go +++ /dev/null @@ -1,52 +0,0 @@ -package main - -import ( - "errors" - - "github.com/bwmarrin/discordgo" -) - -func (user *User) channelIsBridgeable(channel *discordgo.Channel) bool { - switch channel.Type { - case discordgo.ChannelTypeGuildText, discordgo.ChannelTypeGuildNews: - // allowed - case discordgo.ChannelTypeDM, discordgo.ChannelTypeGroupDM: - // DMs are always bridgeable, no need for permission checks - return true - default: - // everything else is not allowed - return false - } - - log := user.log.With().Str("guild_id", channel.GuildID).Str("channel_id", channel.ID).Logger() - - member, err := user.Session.State.Member(channel.GuildID, user.DiscordID) - if errors.Is(err, discordgo.ErrStateNotFound) { - log.Debug().Msg("Fetching own membership in guild to check roles") - member, err = user.Session.GuildMember(channel.GuildID, user.DiscordID) - if err != nil { - log.Warn().Err(err).Msg("Failed to get own membership in guild from server") - } else { - err = user.Session.State.MemberAdd(member) - if err != nil { - log.Warn().Err(err).Msg("Failed to add own membership in guild to cache") - } - } - } else if err != nil { - log.Warn().Err(err).Msg("Failed to get own membership in guild from cache") - } - err = user.Session.State.ChannelAdd(channel) - if err != nil { - log.Warn().Err(err).Msg("Failed to add channel to cache") - } - perms, err := user.Session.State.UserChannelPermissions(user.DiscordID, channel.ID) - if err != nil { - log.Warn().Err(err).Msg("Failed to get permissions in channel to determine if it's bridgeable") - return true - } - log.Debug(). - Int64("permissions", perms). - Bool("view_channel", perms&discordgo.PermissionViewChannel > 0). - Msg("Computed permissions in channel") - return perms&discordgo.PermissionViewChannel > 0 -} diff --git a/docker-run.sh b/docker-run.sh index 054a636..f4a5630 100755 --- a/docker-run.sh +++ b/docker-run.sh @@ -15,7 +15,7 @@ function fixperms { } if [[ ! -f /data/config.yaml ]]; then - cp /opt/mautrix-discord/example-config.yaml /data/config.yaml + /usr/bin/mautrix-discord -c /data/config.yaml -e echo "Didn't find a config file." echo "Copied default config file to /data/config.yaml" echo "Modify that config file to your liking." diff --git a/example-config.yaml b/example-config.yaml deleted file mode 100644 index ea392bb..0000000 --- a/example-config.yaml +++ /dev/null @@ -1,381 +0,0 @@ -# Homeserver details. -homeserver: - # The address that this appservice can use to connect to the homeserver. - address: https://matrix.example.com - # The domain of the homeserver (also known as server_name, used for MXIDs, etc). - domain: example.com - - # What software is the homeserver running? - # Standard Matrix homeservers like Synapse, Dendrite and Conduit should just use "standard" here. - software: standard - # The URL to push real-time bridge status to. - # If set, the bridge will make POST requests to this URL whenever a user's discord connection state changes. - # The bridge will use the appservice as_token to authorize requests. - status_endpoint: null - # Endpoint for reporting per-message status. - message_send_checkpoint_endpoint: null - # Does the homeserver support https://github.com/matrix-org/matrix-spec-proposals/pull/2246? - async_media: false - - # Should the bridge use a websocket for connecting to the homeserver? - # The server side is currently not documented anywhere and is only implemented by mautrix-wsproxy, - # mautrix-asmux (deprecated), and hungryserv (proprietary). - websocket: false - # How often should the websocket be pinged? Pinging will be disabled if this is zero. - ping_interval_seconds: 0 - -# Application service host/registration related details. -# Changing these values requires regeneration of the registration. -appservice: - # The address that the homeserver can use to connect to this appservice. - address: http://localhost:29334 - - # The hostname and port where this appservice should listen. - hostname: 0.0.0.0 - port: 29334 - - # Database config. - database: - # The database type. "sqlite3-fk-wal" and "postgres" are supported. - type: postgres - # The database URI. - # SQLite: A raw file path is supported, but `file:?_txlock=immediate` is recommended. - # https://github.com/mattn/go-sqlite3#connection-string - # Postgres: Connection string. For example, postgres://user:password@host/database?sslmode=disable - # To connect via Unix socket, use something like postgres:///dbname?host=/var/run/postgresql - uri: postgres://user:password@host/database?sslmode=disable - # Maximum number of connections. Mostly relevant for Postgres. - max_open_conns: 20 - max_idle_conns: 2 - # Maximum connection idle time and lifetime before they're closed. Disabled if null. - # Parsed with https://pkg.go.dev/time#ParseDuration - max_conn_idle_time: null - max_conn_lifetime: null - - # The unique ID of this appservice. - id: discord - # Appservice bot details. - bot: - # Username of the appservice bot. - username: discordbot - # Display name and avatar for bot. Set to "remove" to remove display name/avatar, leave empty - # to leave display name/avatar as-is. - displayname: Discord bridge bot - avatar: mxc://maunium.net/nIdEykemnwdisvHbpxflpDlC - - # Whether or not to receive ephemeral events via appservice transactions. - # Requires MSC2409 support (i.e. Synapse 1.22+). - ephemeral_events: true - - # Should incoming events be handled asynchronously? - # This may be necessary for large public instances with lots of messages going through. - # However, messages will not be guaranteed to be bridged in the same order they were sent in. - async_transactions: false - - # Authentication tokens for AS <-> HS communication. Autogenerated; do not modify. - as_token: "This value is generated when generating the registration" - hs_token: "This value is generated when generating the registration" - -# Bridge config -bridge: - # Localpart template of MXIDs for Discord users. - # {{.}} is replaced with the internal ID of the Discord user. - username_template: discord_{{.}} - # Displayname template for Discord users. This is also used as the room name in DMs if private_chat_portal_meta is enabled. - # Available variables: - # .ID - Internal user ID - # .Username - Legacy display/username on Discord - # .GlobalName - New displayname on Discord - # .Discriminator - The 4 numbers after the name on Discord - # .Bot - Whether the user is a bot - # .System - Whether the user is an official system user - # .Webhook - Whether the user is a webhook and is not an application - # .Application - Whether the user is an application - displayname_template: '{{if .Webhook}}Webhook{{else}}{{or .GlobalName .Username}}{{if .Bot}} (bot){{end}}{{end}}' - # Displayname template for Discord channels (bridged as rooms, or spaces when type=4). - # Available variables: - # .Name - Channel name, or user displayname (pre-formatted with displayname_template) in DMs. - # .ParentName - Parent channel name (used for categories). - # .GuildName - Guild name. - # .NSFW - Whether the channel is marked as NSFW. - # .Type - Channel type (see values at https://github.com/bwmarrin/discordgo/blob/v0.25.0/structs.go#L251-L267) - channel_name_template: '{{if or (eq .Type 3) (eq .Type 4)}}{{.Name}}{{else}}#{{.Name}}{{end}}' - # Displayname template for Discord guilds (bridged as spaces). - # Available variables: - # .Name - Guild name - guild_name_template: '{{.Name}}' - # Whether to explicitly set the avatar and room name for private chat portal rooms. - # If set to `default`, this will be enabled in encrypted rooms and disabled in unencrypted rooms. - # If set to `always`, all DM rooms will have explicit names and avatars set. - # If set to `never`, DM rooms will never have names and avatars set. - private_chat_portal_meta: default - - # Publicly accessible base URL that Discord can use to reach the bridge, used for avatars in relay mode. - # If not set, avatars will not be bridged. Only the /mautrix-discord/avatar/{server}/{id}/{hash} endpoint is used on this address. - # This should not have a trailing slash, the endpoint above will be appended to the provided address. - public_address: null - # A random key used to sign the avatar URLs. The bridge will only accept requests with a valid signature. - avatar_proxy_key: generate - - portal_message_buffer: 128 - - # Number of private channel portals to create on bridge startup. - # Other portals will be created when receiving messages. - startup_private_channel_create_limit: 5 - # Should the bridge send a read receipt from the bridge bot when a message has been sent to Discord? - delivery_receipts: false - # Whether the bridge should send the message status as a custom com.beeper.message_send_status event. - message_status_events: false - # Whether the bridge should send error notices via m.notice events when a message fails to bridge. - message_error_notices: true - # Should the bridge use space-restricted join rules instead of invite-only for guild rooms? - # This can avoid unnecessary invite events in guild rooms when members are synced in. - restricted_rooms: true - # Should the bridge automatically join the user to threads on Discord when the thread is opened on Matrix? - # This only works with clients that support thread read receipts (MSC3771 added in Matrix v1.4). - autojoin_thread_on_open: true - # Should inline fields in Discord embeds be bridged as HTML tables to Matrix? - # Tables aren't supported in all clients, but are the only way to emulate the Discord inline field UI. - embed_fields_as_tables: true - # Should guild channels be muted when the portal is created? This only meant for single-user instances, - # it won't mute it for all users if there are multiple Matrix users in the same Discord guild. - mute_channels_on_create: false - # Should the bridge update the m.direct account data event when double puppeting is enabled. - # Note that updating the m.direct event is not atomic (except with mautrix-asmux) - # and is therefore prone to race conditions. - sync_direct_chat_list: false - # Set this to true to tell the bridge to re-send m.bridge events to all rooms on the next run. - # This field will automatically be changed back to false after it, except if the config file is not writable. - resend_bridge_info: false - # Should incoming custom emoji reactions be bridged as mxc:// URIs? - # If set to false, custom emoji reactions will be bridged as the shortcode instead, and the image won't be available. - custom_emoji_reactions: true - # Should the bridge attempt to completely delete portal rooms when a channel is deleted on Discord? - # If true, the bridge will try to kick Matrix users from the room. Otherwise, the bridge only makes ghosts leave. - delete_portal_on_channel_delete: false - # Should the bridge delete all portal rooms when you leave a guild on Discord? - # This only applies if the guild has no other Matrix users on this bridge instance. - delete_guild_on_leave: true - # Whether or not created rooms should have federation enabled. - # If false, created portal rooms will never be federated. - federate_rooms: true - # Prefix messages from webhooks with the profile info? This can be used along with a custom displayname_template - # to better handle webhooks that change their name all the time (like ones used by bridges). - # - # This will use the fallback mode in MSC4144, which means clients that support MSC4144 will not show the prefix - # (and will instead show the name and avatar as the message sender). - prefix_webhook_messages: true - # Bridge webhook avatars? - enable_webhook_avatars: false - # Should the bridge upload media to the Discord CDN directly before sending the message when using a user token, - # like the official client does? The other option is sending the media in the message send request as a form part - # (which is always used by bots and webhooks). - use_discord_cdn_upload: true - # Proxy for Discord connections - proxy: - # Should mxc uris copied from Discord be cached? - # This can be `never` to never cache, `unencrypted` to only cache unencrypted mxc uris, or `always` to cache everything. - # If you have a media repo that generates non-unique mxc uris, you should set this to never. - cache_media: unencrypted - # Settings for converting Discord media to custom mxc:// URIs instead of reuploading. - # More details can be found at https://docs.mau.fi/bridges/go/discord/direct-media.html - direct_media: - # Should custom mxc:// URIs be used instead of reuploading media? - enabled: false - # The server name to use for the custom mxc:// URIs. - # This server name will effectively be a real Matrix server, it just won't implement anything other than media. - # You must either set up .well-known delegation from this domain to the bridge, or proxy the domain directly to the bridge. - server_name: discord-media.example.com - # Optionally a custom .well-known response. This defaults to `server_name:443` - well_known_response: - # The bridge supports MSC3860 media download redirects and will use them if the requester supports it. - # Optionally, you can force redirects and not allow proxying at all by setting this to false. - allow_proxy: true - # Matrix server signing key to make the federation tester pass, same format as synapse's .signing.key file. - # This key is also used to sign the mxc:// URIs to ensure only the bridge can generate them. - server_key: generate - # Settings for converting animated stickers. - animated_sticker: - # Format to which animated stickers should be converted. - # disable - No conversion, send as-is (lottie JSON) - # png - converts to non-animated png (fastest) - # gif - converts to animated gif - # webm - converts to webm video, requires ffmpeg executable with vp9 codec and webm container support - # webp - converts to animated webp, requires ffmpeg executable with webp codec/container support - target: webp - # Arguments for converter. All converters take width and height. - args: - width: 320 - height: 320 - fps: 25 # only for webm, webp and gif (2, 5, 10, 20 or 25 recommended) - # Servers to always allow double puppeting from - double_puppet_server_map: - example.com: https://example.com - # Allow using double puppeting from any server with a valid client .well-known file. - double_puppet_allow_discovery: false - # Shared secrets for https://github.com/devture/matrix-synapse-shared-secret-auth - # - # If set, double puppeting will be enabled automatically for local users - # instead of users having to find an access token and run `login-matrix` - # manually. - login_shared_secret_map: - example.com: foobar - - # The prefix for commands. Only required in non-management rooms. - command_prefix: '!discord' - # Messages sent upon joining a management room. - # Markdown is supported. The defaults are listed below. - management_room_text: - # Sent when joining a room. - welcome: "Hello, I'm a Discord bridge bot." - # Sent when joining a management room and the user is already logged in. - welcome_connected: "Use `help` for help." - # Sent when joining a management room and the user is not logged in. - welcome_unconnected: "Use `help` for help or `login` to log in." - # Optional extra text sent when joining a management room. - additional_help: "" - - # Settings for backfilling messages. - backfill: - # Limits for forward backfilling. - forward_limits: - # Initial backfill (when creating portal). 0 means backfill is disabled. - # A special unlimited value is not supported, you must set a limit. Initial backfill will - # fetch all messages first before backfilling anything, so high limits can take a lot of time. - initial: - dm: 0 - channel: 0 - thread: 0 - # Missed message backfill (on startup). - # 0 means backfill is disabled, -1 means fetch all messages since last bridged message. - # When using unlimited backfill (-1), messages are backfilled as they are fetched. - # With limits, all messages up to the limit are fetched first and backfilled afterwards. - missed: - dm: 0 - channel: 0 - thread: 0 - # Maximum members in a guild to enable backfilling. Set to -1 to disable limit. - # This can be used as a rough heuristic to disable backfilling in channels that are too active. - # Currently only applies to missed message backfill. - max_guild_members: -1 - - # End-to-bridge encryption support options. - # - # See https://docs.mau.fi/bridges/general/end-to-bridge-encryption.html for more info. - encryption: - # Allow encryption, work in group chat rooms with e2ee enabled - allow: false - # Default to encryption, force-enable encryption in all portals the bridge creates - # This will cause the bridge bot to be in private chats for the encryption to work properly. - default: false - # Whether to use MSC2409/MSC3202 instead of /sync long polling for receiving encryption-related data. - # Changing this option requires updating the appservice registration file. - appservice: false - # Whether to use MSC4190 instead of appservice login to create the bridge bot device. - # Requires the homeserver to support MSC4190 and the device masquerading parts of MSC3202. - # Only relevant when using end-to-bridge encryption, required when using encryption with next-gen auth (MSC3861). - # Changing this option requires updating the appservice registration file. - msc4190: false - # Require encryption, drop any unencrypted messages. - require: false - # Enable key sharing? If enabled, key requests for rooms where users are in will be fulfilled. - # You must use a client that supports requesting keys from other users to use this feature. - allow_key_sharing: false - # Should users mentions be in the event wire content to enable the server to send push notifications? - plaintext_mentions: false - # Options for deleting megolm sessions from the bridge. - delete_keys: - # Beeper-specific: delete outbound sessions when hungryserv confirms - # that the user has uploaded the key to key backup. - delete_outbound_on_ack: false - # Don't store outbound sessions in the inbound table. - dont_store_outbound: false - # Ratchet megolm sessions forward after decrypting messages. - ratchet_on_decrypt: false - # Delete fully used keys (index >= max_messages) after decrypting messages. - delete_fully_used_on_decrypt: false - # Delete previous megolm sessions from same device when receiving a new one. - delete_prev_on_new_session: false - # Delete megolm sessions received from a device when the device is deleted. - delete_on_device_delete: false - # Periodically delete megolm sessions when 2x max_age has passed since receiving the session. - periodically_delete_expired: false - # Delete inbound megolm sessions that don't have the received_at field used for - # automatic ratcheting and expired session deletion. This is meant as a migration - # to delete old keys prior to the bridge update. - delete_outdated_inbound: false - # What level of device verification should be required from users? - # - # Valid levels: - # unverified - Send keys to all device in the room. - # cross-signed-untrusted - Require valid cross-signing, but trust all cross-signing keys. - # cross-signed-tofu - Require valid cross-signing, trust cross-signing keys on first use (and reject changes). - # cross-signed-verified - Require valid cross-signing, plus a valid user signature from the bridge bot. - # Note that creating user signatures from the bridge bot is not currently possible. - # verified - Require manual per-device verification - # (currently only possible by modifying the `trust` column in the `crypto_device` database table). - verification_levels: - # Minimum level for which the bridge should send keys to when bridging messages from WhatsApp to Matrix. - receive: unverified - # Minimum level that the bridge should accept for incoming Matrix messages. - send: unverified - # Minimum level that the bridge should require for accepting key requests. - share: cross-signed-tofu - # Options for Megolm room key rotation. These options allow you to - # configure the m.room.encryption event content. See: - # https://spec.matrix.org/v1.3/client-server-api/#mroomencryption for - # more information about that event. - rotation: - # Enable custom Megolm room key rotation settings. Note that these - # settings will only apply to rooms created after this option is - # set. - enable_custom: false - # The maximum number of milliseconds a session should be used - # before changing it. The Matrix spec recommends 604800000 (a week) - # as the default. - milliseconds: 604800000 - # The maximum number of messages that should be sent with a given a - # session before changing it. The Matrix spec recommends 100 as the - # default. - messages: 100 - - # Disable rotating keys when a user's devices change? - # You should not enable this option unless you understand all the implications. - disable_device_change_key_rotation: false - - # Settings for provisioning API - provisioning: - # Prefix for the provisioning API paths. - prefix: /_matrix/provision - # Shared secret for authentication. If set to "generate", a random secret will be generated, - # or if set to "disable", the provisioning API will be disabled. - shared_secret: generate - # Enable debug API at /debug with provisioning authentication. - debug_endpoints: false - - # Permissions for using the bridge. - # Permitted values: - # relay - Talk through the relaybot (if enabled), no access otherwise - # user - Access to use the bridge to chat with a Discord account. - # admin - User level and some additional administration tools - # Permitted keys: - # * - All Matrix users - # domain - All users on that homeserver - # mxid - Specific user - permissions: - "*": relay - "example.com": user - "@admin:example.com": admin - -# Logging config. See https://github.com/tulir/zeroconfig for details. -logging: - min_level: debug - writers: - - type: stdout - format: pretty-colored - - type: file - format: json - filename: ./logs/mautrix-discord.log - max_size: 100 - max_backups: 10 - compress: true diff --git a/formatter.go b/formatter.go deleted file mode 100644 index 2112b04..0000000 --- a/formatter.go +++ /dev/null @@ -1,260 +0,0 @@ -// mautrix-discord - A Matrix-Discord puppeting bridge. -// Copyright (C) 2023 Tulir Asokan -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package main - -import ( - "fmt" - "regexp" - "strings" - - "github.com/bwmarrin/discordgo" - "github.com/yuin/goldmark" - "github.com/yuin/goldmark/extension" - "github.com/yuin/goldmark/parser" - "github.com/yuin/goldmark/util" - "go.mau.fi/util/variationselector" - "golang.org/x/exp/slices" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/format" - "maunium.net/go/mautrix/format/mdext" - "maunium.net/go/mautrix/id" -) - -// escapeFixer is a hacky partial fix for the difference in escaping markdown, used with escapeReplacement -// -// Discord allows escaping with just one backslash, e.g. \__a__, -// but standard markdown requires both to be escaped (\_\_a__) -var escapeFixer = regexp.MustCompile(`\\(__[^_]|\*\*[^*])`) - -func escapeReplacement(s string) string { - return s[:2] + `\` + s[2:] -} - -// indentableParagraphParser is the default paragraph parser with CanAcceptIndentedLine. -// Used when disabling CodeBlockParser (as disabling it without a replacement will make indented blocks disappear). -type indentableParagraphParser struct { - parser.BlockParser -} - -var defaultIndentableParagraphParser = &indentableParagraphParser{BlockParser: parser.NewParagraphParser()} - -func (b *indentableParagraphParser) CanAcceptIndentedLine() bool { - return true -} - -var removeFeaturesExceptLinks = []any{ - parser.NewListParser(), parser.NewListItemParser(), parser.NewHTMLBlockParser(), parser.NewRawHTMLParser(), - parser.NewSetextHeadingParser(), parser.NewThematicBreakParser(), - parser.NewCodeBlockParser(), -} -var removeFeaturesAndLinks = append(removeFeaturesExceptLinks, parser.NewLinkParser()) -var fixIndentedParagraphs = goldmark.WithParserOptions(parser.WithBlockParsers(util.Prioritized(defaultIndentableParagraphParser, 500))) -var discordExtensions = goldmark.WithExtensions(extension.Strikethrough, mdext.SimpleSpoiler, mdext.DiscordUnderline, ExtDiscordEveryone, ExtDiscordTag) - -var discordRenderer = goldmark.New( - goldmark.WithParser(mdext.ParserWithoutFeatures(removeFeaturesAndLinks...)), - fixIndentedParagraphs, format.HTMLOptions, discordExtensions, -) -var discordRendererWithInlineLinks = goldmark.New( - goldmark.WithParser(mdext.ParserWithoutFeatures(removeFeaturesExceptLinks...)), - fixIndentedParagraphs, format.HTMLOptions, discordExtensions, -) - -func (portal *Portal) renderDiscordMarkdownOnlyHTMLNoUnwrap(text string, allowInlineLinks bool) string { - text = escapeFixer.ReplaceAllStringFunc(text, escapeReplacement) - - var buf strings.Builder - ctx := parser.NewContext() - ctx.Set(parserContextPortal, portal) - renderer := discordRenderer - if allowInlineLinks { - renderer = discordRendererWithInlineLinks - } - err := renderer.Convert([]byte(text), &buf, parser.WithContext(ctx)) - if err != nil { - panic(fmt.Errorf("markdown parser errored: %w", err)) - } - return buf.String() -} - -func (portal *Portal) renderDiscordMarkdownOnlyHTML(text string, allowInlineLinks bool) string { - return format.UnwrapSingleParagraph(portal.renderDiscordMarkdownOnlyHTMLNoUnwrap(text, allowInlineLinks)) -} - -const formatterContextPortalKey = "fi.mau.discord.portal" -const formatterContextAllowedMentionsKey = "fi.mau.discord.allowed_mentions" -const formatterContextInputAllowedMentionsKey = "fi.mau.discord.input_allowed_mentions" -const formatterContextInputAllowedLinkPreviewsKey = "fi.mau.discord.input_allowed_link_previews" - -func appendIfNotContains(arr []string, newItem string) []string { - for _, item := range arr { - if item == newItem { - return arr - } - } - return append(arr, newItem) -} - -func (br *DiscordBridge) pillConverter(displayname, mxid, eventID string, ctx format.Context) string { - if len(mxid) == 0 { - return displayname - } - if mxid[0] == '#' { - alias, err := br.Bot.ResolveAlias(id.RoomAlias(mxid)) - if err != nil { - return displayname - } - mxid = alias.RoomID.String() - } - if mxid[0] == '!' { - portal := br.GetPortalByMXID(id.RoomID(mxid)) - if portal != nil { - if eventID == "" { - //currentPortal := ctx[formatterContextPortalKey].(*Portal) - return fmt.Sprintf("<#%s>", portal.Key.ChannelID) - //if currentPortal.GuildID == portal.GuildID { - //} else if portal.GuildID != "" { - // return fmt.Sprintf("<#%s:%s:%s>", portal.Key.ChannelID, portal.GuildID, portal.Name) - //} else { - // // TODO is mentioning private channels possible at all? - //} - } else if msg := br.DB.Message.GetByMXID(portal.Key, id.EventID(eventID)); msg != nil { - guildID := portal.GuildID - if guildID == "" { - guildID = "@me" - } - return fmt.Sprintf("https://discord.com/channels/%s/%s/%s", guildID, msg.DiscordProtoChannelID(), msg.DiscordID) - } - } - } else if mxid[0] == '@' { - allowedMentions, _ := ctx.ReturnData[formatterContextInputAllowedMentionsKey].([]id.UserID) - if allowedMentions != nil && !slices.Contains(allowedMentions, id.UserID(mxid)) { - return displayname - } - mentions := ctx.ReturnData[formatterContextAllowedMentionsKey].(*discordgo.MessageAllowedMentions) - parsedID, ok := br.ParsePuppetMXID(id.UserID(mxid)) - if ok { - mentions.Users = appendIfNotContains(mentions.Users, parsedID) - return fmt.Sprintf("<@%s>", parsedID) - } - mentionedUser := br.GetUserByMXID(id.UserID(mxid)) - if mentionedUser != nil && mentionedUser.DiscordID != "" { - mentions.Users = appendIfNotContains(mentions.Users, mentionedUser.DiscordID) - return fmt.Sprintf("<@%s>", mentionedUser.DiscordID) - } - } - return displayname -} - -const discordLinkPattern = `https?://[^<\p{Zs}\x{feff}]*[^"'),.:;\]\p{Zs}\x{feff}]` - -// Discord links start with http:// or https://, contain at least two characters afterwards, -// don't contain < or whitespace anywhere, and don't end with "'),.:;] -// -// Zero-width whitespace is mostly in the Format category and is allowed, except \uFEFF isn't for some reason -var discordLinkRegex = regexp.MustCompile(discordLinkPattern) -var discordLinkRegexFull = regexp.MustCompile("^" + discordLinkPattern + "$") - -var discordMarkdownEscaper = strings.NewReplacer( - `\`, `\\`, - `_`, `\_`, - `*`, `\*`, - `~`, `\~`, - "`", "\\`", - `|`, `\|`, - `<`, `\<`, - `#`, `\#`, -) - -func escapeDiscordMarkdown(s string) string { - submatches := discordLinkRegex.FindAllStringIndex(s, -1) - if submatches == nil { - return discordMarkdownEscaper.Replace(s) - } - var builder strings.Builder - offset := 0 - for _, match := range submatches { - start := match[0] - end := match[1] - builder.WriteString(discordMarkdownEscaper.Replace(s[offset:start])) - builder.WriteString(s[start:end]) - offset = end - } - builder.WriteString(discordMarkdownEscaper.Replace(s[offset:])) - return builder.String() -} - -var matrixHTMLParser = &format.HTMLParser{ - TabsToSpaces: 4, - Newline: "\n", - HorizontalLine: "\n---\n", - ItalicConverter: func(s string, ctx format.Context) string { - return fmt.Sprintf("*%s*", s) - }, - UnderlineConverter: func(s string, ctx format.Context) string { - return fmt.Sprintf("__%s__", s) - }, - TextConverter: func(s string, ctx format.Context) string { - if ctx.TagStack.Has("pre") || ctx.TagStack.Has("code") { - // If we're in a code block, don't escape markdown - return s - } - return escapeDiscordMarkdown(s) - }, - SpoilerConverter: func(text, reason string, ctx format.Context) string { - if reason != "" { - return fmt.Sprintf("(%s) ||%s||", reason, text) - } - return fmt.Sprintf("||%s||", text) - }, - LinkConverter: func(text, href string, ctx format.Context) string { - linkPreviews := ctx.ReturnData[formatterContextInputAllowedLinkPreviewsKey].([]string) - allowPreview := linkPreviews == nil || slices.Contains(linkPreviews, href) - if text == href { - if !allowPreview { - return fmt.Sprintf("<%s>", text) - } - return text - } else if !discordLinkRegexFull.MatchString(href) { - return fmt.Sprintf("%s (%s)", escapeDiscordMarkdown(text), escapeDiscordMarkdown(href)) - } else if !allowPreview { - return fmt.Sprintf("[%s](<%s>)", escapeDiscordMarkdown(text), href) - } else { - return fmt.Sprintf("[%s](%s)", escapeDiscordMarkdown(text), href) - } - }, -} - -func (portal *Portal) parseMatrixHTML(content *event.MessageEventContent, allowedLinkPreviews []string) (string, *discordgo.MessageAllowedMentions) { - allowedMentions := &discordgo.MessageAllowedMentions{ - Parse: []discordgo.AllowedMentionType{}, - Users: []string{}, - RepliedUser: true, - } - if content.Format == event.FormatHTML && len(content.FormattedBody) > 0 { - ctx := format.NewContext() - ctx.ReturnData[formatterContextInputAllowedLinkPreviewsKey] = allowedLinkPreviews - ctx.ReturnData[formatterContextPortalKey] = portal - ctx.ReturnData[formatterContextAllowedMentionsKey] = allowedMentions - if content.Mentions != nil { - ctx.ReturnData[formatterContextInputAllowedMentionsKey] = content.Mentions.UserIDs - } - return variationselector.FullyQualify(matrixHTMLParser.Parse(content.FormattedBody, ctx)), allowedMentions - } else { - return variationselector.FullyQualify(escapeDiscordMarkdown(content.Body)), allowedMentions - } -} diff --git a/formatter_test.go b/formatter_test.go deleted file mode 100644 index c05f95b..0000000 --- a/formatter_test.go +++ /dev/null @@ -1,57 +0,0 @@ -// mautrix-discord - A Matrix-Discord puppeting bridge. -// Copyright (C) 2022 Tulir Asokan -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package main - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestEscapeDiscordMarkdown(t *testing.T) { - type escapeTest struct { - name string - input string - expected string - } - - tests := []escapeTest{ - {"Simple text", "Lorem ipsum dolor sit amet, consectetuer adipiscing elit.", "Lorem ipsum dolor sit amet, consectetuer adipiscing elit."}, - {"Backslash", `foo\bar`, `foo\\bar`}, - {"Underscore", `foo_bar`, `foo\_bar`}, - {"Asterisk", `foo*bar`, `foo\*bar`}, - {"Tilde", `foo~bar`, `foo\~bar`}, - {"Backtick", "foo`bar", "foo\\`bar"}, - {"Forward tick", `foo´bar`, `foo´bar`}, - {"Pipe", `foo|bar`, `foo\|bar`}, - {"Less than", `foobar`, `foo>bar`}, - {"Multiple things", `\_*~|`, `\\\_\*\~\|`}, - {"URL", `https://example.com/foo_bar`, `https://example.com/foo_bar`}, - {"Multiple URLs", `hello_world https://example.com/foo_bar *testing* https://a_b_c/*def*`, `hello\_world https://example.com/foo_bar \*testing\* https://a_b_c/*def*`}, - {"URL ends with no-break zero-width space", "https://example.com\ufefffoo_bar", "https://example.com\ufefffoo\\_bar"}, - {"URL ends with less than", `https://example.com github.com/beeper/discordgo v0.0.0-20250607214857-f23a8518ece2 +replace github.com/bwmarrin/discordgo => github.com/beeper/discordgo v0.0.0-20260429033617-4dea361a9bb6 diff --git a/go.sum b/go.sum index 51356d2..91d02b6 100644 --- a/go.sum +++ b/go.sum @@ -1,68 +1,70 @@ -github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= -github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= -github.com/beeper/discordgo v0.0.0-20250607214857-f23a8518ece2 h1:8lgTjYGSIlS90f0jiFfEC4UwxCq9FiUo4dKwjknbupQ= -github.com/beeper/discordgo v0.0.0-20250607214857-f23a8518ece2/go.mod h1:59+AOzzjmL6onAh62nuLXmn7dJCaC/owDLWbGtjTcFA= -github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= -github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= +filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= +github.com/beeper/discordgo v0.0.0-20260429033617-4dea361a9bb6 h1:L956OBNYiTXMSNzJ1cADxf395/IXxXrSqD1kC97ufjA= +github.com/beeper/discordgo v0.0.0-20260429033617-4dea361a9bb6/go.mod h1:lioivnibvB8j1KcF5TVpLdRLKCKHtcl8A03GpxRCre4= +github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= +github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= +github.com/coreos/go-systemd/v22 v22.7.0 h1:LAEzFkke61DFROc7zNLX/WA2i5J8gYqe0rSj9KI28KA= +github.com/coreos/go-systemd/v22 v22.7.0/go.mod h1:xNUYtjHu2EDXbsxz1i41wouACIwT7Ybq9o0BQhMwD0w= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/gabriel-vasile/mimetype v1.4.9 h1:5k+WDwEsD9eTLL8Tz3L0VnmVh9QxGjRmjBvAG7U/oYY= -github.com/gabriel-vasile/mimetype v1.4.9/go.mod h1:WnSQhFKJuBlRyLiKohA/2DtIlPFAbguNaG7QCHcyGok= -github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= -github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= -github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= -github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= -github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= -github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= -github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= -github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A= -github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ= +github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.21 h1:xYae+lCNBP7QuW4PUnNG61ffM4hVIfm+zUzDuSzYLGs= +github.com/mattn/go-isatty v0.0.21/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4= +github.com/mattn/go-sqlite3 v1.14.44 h1:3VSe+xafpbzsLbdr2AWlAZk9yRHiBhTBakioXaCKTF8= +github.com/mattn/go-sqlite3 v1.14.44/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= +github.com/petermattis/goid v0.0.0-20260330135022-df67b199bc81 h1:WDsQxOJDy0N1VRAjXLpi8sCEZRSGarLWQevDxpTBRrM= +github.com/petermattis/goid v0.0.0-20260330135022-df67b199bc81/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= -github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= -github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/rs/zerolog v1.35.1 h1:m7xQeoiLIiV0BCEY4Hs+j2NG4Gp2o2KPKmhnnLiazKI= +github.com/rs/zerolog v1.35.1/go.mod h1:EjML9kdfa/RMA7h/6z6pYmq1ykOuA8/mjWaEvGI+jcw= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM= +github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.7.12 h1:YwGP/rrea2/CnCtUHgjuolG/PnMxdQtPMO5PvaE2/nY= -github.com/yuin/goldmark v1.7.12/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.2.2-0.20231228160422-22fdd4bbddeb h1:Is+6vDKgINRy9KHodvi7NElxoDaWA8sc2S3cF3+QWjs= -go.mau.fi/util v0.2.2-0.20231228160422-22fdd4bbddeb/go.mod h1:tiBX6nxVSOjU89jVQ7wBh3P8KjM26Lv1k7/I5QdSvBw= -go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= -go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= -golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= -golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc h1:TS73t7x3KarrNd5qAipmspBDS1rkMcgVG/fS1aRb4Rc= -golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc= -golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= -golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= -golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= -golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE= +github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= +go.mau.fi/util v0.9.9-0.20260511124621-9241e81bdf25 h1:YPEmc+li7TF6C9AdRTcSLMb6yCHdF27/wNT7kFLIVNg= +go.mau.fi/util v0.9.9-0.20260511124621-9241e81bdf25/go.mod h1:jE9FfhbgEgAwxei6lomO9v8zdCIATcquONUu4vjRwSs= +go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= +go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= +golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f h1:W3F4c+6OLc6H2lb//N1q4WpJkhzJCK5J6kUi1NTVXfM= +golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f/go.mod h1:J1xhfL/vlindoeF/aINzNzt2Bket5bjo9sdOYzOsU80= +golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= +golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY= +golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= +golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= @@ -71,7 +73,5 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= -maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8= -maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho= -maunium.net/go/mautrix v0.16.3-0.20250810202616-6bc5698125c2 h1:8PdwIklPNHTL/tI9tG2S0Tf9UvAgRt8yZjJbjV0XIpA= -maunium.net/go/mautrix v0.16.3-0.20250810202616-6bc5698125c2/go.mod h1:gCgLw/4c1a8QsiOWTdUdXlt5cYdE0rJ9wLeZQKPD58Q= +maunium.net/go/mautrix v0.27.1-0.20260513120123-5fba7e3afae4 h1:zNC9eVAhw8FhKpM3AxNAh/iy75UEYX91uJUvqqAYlvo= +maunium.net/go/mautrix v0.27.1-0.20260513120123-5fba7e3afae4/go.mod h1:3sOGhXi3P1V6/NruTA0gujkvTypXVUraWktCuTGyDuM= diff --git a/guildportal.go b/guildportal.go deleted file mode 100644 index d7be670..0000000 --- a/guildportal.go +++ /dev/null @@ -1,335 +0,0 @@ -// mautrix-discord - A Matrix-Discord puppeting bridge. -// Copyright (C) 2022 Tulir Asokan -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package main - -import ( - "errors" - "fmt" - "sync" - - log "maunium.net/go/maulogger/v2" - "maunium.net/go/maulogger/v2/maulogadapt" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" - - "github.com/bwmarrin/discordgo" - - "go.mau.fi/mautrix-discord/config" - "go.mau.fi/mautrix-discord/database" -) - -type Guild struct { - *database.Guild - - bridge *DiscordBridge - log log.Logger - - roomCreateLock sync.Mutex -} - -func (br *DiscordBridge) loadGuild(dbGuild *database.Guild, id string, createIfNotExist bool) *Guild { - if dbGuild == nil { - if id == "" || !createIfNotExist { - return nil - } - - dbGuild = br.DB.Guild.New() - dbGuild.ID = id - dbGuild.Insert() - } - - guild := br.NewGuild(dbGuild) - - br.guildsByID[guild.ID] = guild - if guild.MXID != "" { - br.guildsByMXID[guild.MXID] = guild - } - - return guild -} - -func (br *DiscordBridge) GetGuildByMXID(mxid id.RoomID) *Guild { - br.guildsLock.Lock() - defer br.guildsLock.Unlock() - - portal, ok := br.guildsByMXID[mxid] - if !ok { - return br.loadGuild(br.DB.Guild.GetByMXID(mxid), "", false) - } - - return portal -} - -func (br *DiscordBridge) GetGuildByID(id string, createIfNotExist bool) *Guild { - br.guildsLock.Lock() - defer br.guildsLock.Unlock() - - guild, ok := br.guildsByID[id] - if !ok { - return br.loadGuild(br.DB.Guild.GetByID(id), id, createIfNotExist) - } - - return guild -} - -func (br *DiscordBridge) GetAllGuilds() []*Guild { - return br.dbGuildsToGuilds(br.DB.Guild.GetAll()) -} - -func (br *DiscordBridge) dbGuildsToGuilds(dbGuilds []*database.Guild) []*Guild { - br.guildsLock.Lock() - defer br.guildsLock.Unlock() - - output := make([]*Guild, len(dbGuilds)) - for index, dbGuild := range dbGuilds { - if dbGuild == nil { - continue - } - - guild, ok := br.guildsByID[dbGuild.ID] - if !ok { - guild = br.loadGuild(dbGuild, "", false) - } - - output[index] = guild - } - - return output -} - -func (br *DiscordBridge) NewGuild(dbGuild *database.Guild) *Guild { - guild := &Guild{ - Guild: dbGuild, - bridge: br, - log: br.Log.Sub(fmt.Sprintf("Guild/%s", dbGuild.ID)), - } - - return guild -} - -func (guild *Guild) getBridgeInfo() (string, event.BridgeEventContent) { - bridgeInfo := event.BridgeEventContent{ - BridgeBot: guild.bridge.Bot.UserID, - Creator: guild.bridge.Bot.UserID, - Protocol: event.BridgeInfoSection{ - ID: "discordgo", - DisplayName: "Discord", - AvatarURL: guild.bridge.Config.AppService.Bot.ParsedAvatar.CUString(), - ExternalURL: "https://discord.com/", - }, - Channel: event.BridgeInfoSection{ - ID: guild.ID, - DisplayName: guild.Name, - AvatarURL: guild.AvatarURL.CUString(), - }, - } - bridgeInfoStateKey := fmt.Sprintf("fi.mau.discord://discord/%s", guild.ID) - return bridgeInfoStateKey, bridgeInfo -} - -func (guild *Guild) UpdateBridgeInfo() { - if len(guild.MXID) == 0 { - guild.log.Debugln("Not updating bridge info: no Matrix room created") - return - } - guild.log.Debugln("Updating bridge info...") - stateKey, content := guild.getBridgeInfo() - _, err := guild.bridge.Bot.SendStateEvent(guild.MXID, event.StateBridge, stateKey, content) - if err != nil { - guild.log.Warnln("Failed to update m.bridge:", err) - } - // TODO remove this once https://github.com/matrix-org/matrix-doc/pull/2346 is in spec - _, err = guild.bridge.Bot.SendStateEvent(guild.MXID, event.StateHalfShotBridge, stateKey, content) - if err != nil { - guild.log.Warnln("Failed to update uk.half-shot.bridge:", err) - } -} - -func (guild *Guild) CreateMatrixRoom(user *User, meta *discordgo.Guild) error { - guild.roomCreateLock.Lock() - defer guild.roomCreateLock.Unlock() - if guild.MXID != "" { - return nil - } - guild.log.Infoln("Creating Matrix room for guild") - guild.UpdateInfo(user, meta) - - bridgeInfoStateKey, bridgeInfo := guild.getBridgeInfo() - - initialState := []*event.Event{{ - Type: event.StateBridge, - Content: event.Content{Parsed: bridgeInfo}, - StateKey: &bridgeInfoStateKey, - }, { - // TODO remove this once https://github.com/matrix-org/matrix-doc/pull/2346 is in spec - Type: event.StateHalfShotBridge, - Content: event.Content{Parsed: bridgeInfo}, - StateKey: &bridgeInfoStateKey, - }} - - if !guild.AvatarURL.IsEmpty() { - initialState = append(initialState, &event.Event{ - Type: event.StateRoomAvatar, - Content: event.Content{Parsed: &event.RoomAvatarEventContent{ - URL: guild.AvatarURL, - }}, - }) - } - - creationContent := map[string]interface{}{ - "type": event.RoomTypeSpace, - } - if !guild.bridge.Config.Bridge.FederateRooms { - creationContent["m.federate"] = false - } - - resp, err := guild.bridge.Bot.CreateRoom(&mautrix.ReqCreateRoom{ - Visibility: "private", - Name: guild.Name, - Preset: "private_chat", - InitialState: initialState, - CreationContent: creationContent, - RoomVersion: "11", - }) - if err != nil { - guild.log.Warnln("Failed to create room:", err) - return err - } - - guild.MXID = resp.RoomID - guild.NameSet = true - guild.AvatarSet = !guild.AvatarURL.IsEmpty() - guild.Update() - guild.bridge.guildsLock.Lock() - guild.bridge.guildsByMXID[guild.MXID] = guild - guild.bridge.guildsLock.Unlock() - guild.log.Infoln("Matrix room created:", guild.MXID) - - user.ensureInvited(nil, guild.MXID, false, true) - - return nil -} - -func (guild *Guild) UpdateInfo(source *User, meta *discordgo.Guild) *discordgo.Guild { - if meta.Unavailable { - guild.log.Debugfln("Ignoring unavailable guild update") - return meta - } - changed := false - changed = guild.UpdateName(meta) || changed - changed = guild.UpdateAvatar(meta.Icon) || changed - if changed { - guild.UpdateBridgeInfo() - guild.Update() - } - source.ensureInvited(nil, guild.MXID, false, false) - return meta -} - -func (guild *Guild) UpdateName(meta *discordgo.Guild) bool { - name := guild.bridge.Config.Bridge.FormatGuildName(config.GuildNameParams{ - Name: meta.Name, - }) - if guild.PlainName == meta.Name && guild.Name == name && (guild.NameSet || guild.MXID == "") { - return false - } - guild.log.Debugfln("Updating name %q -> %q", guild.Name, name) - guild.Name = name - guild.PlainName = meta.Name - guild.NameSet = false - if guild.MXID != "" { - _, err := guild.bridge.Bot.SetRoomName(guild.MXID, guild.Name) - if err != nil { - guild.log.Warnln("Failed to update room name: %s", err) - } else { - guild.NameSet = true - } - } - return true -} - -func (guild *Guild) UpdateAvatar(iconID string) bool { - if guild.Avatar == iconID && (iconID == "") == guild.AvatarURL.IsEmpty() && (guild.AvatarSet || guild.MXID == "") { - return false - } - guild.log.Debugfln("Updating avatar %q -> %q", guild.Avatar, iconID) - guild.AvatarSet = false - guild.Avatar = iconID - guild.AvatarURL = id.ContentURI{} - if guild.Avatar != "" { - // TODO direct media support - copied, err := guild.bridge.copyAttachmentToMatrix(guild.bridge.Bot, discordgo.EndpointGuildIcon(guild.ID, iconID), false, AttachmentMeta{ - AttachmentID: fmt.Sprintf("guild_avatar/%s/%s", guild.ID, iconID), - }) - if err != nil { - guild.log.Warnfln("Failed to reupload guild avatar %s: %v", iconID, err) - return true - } - guild.AvatarURL = copied.MXC - } - if guild.MXID != "" { - _, err := guild.bridge.Bot.SetRoomAvatar(guild.MXID, guild.AvatarURL) - if err != nil { - guild.log.Warnln("Failed to update room avatar:", err) - } else { - guild.AvatarSet = true - } - } - return true -} - -func (guild *Guild) cleanup() { - if guild.MXID == "" { - return - } - intent := guild.bridge.Bot - if guild.bridge.SpecVersions.Supports(mautrix.BeeperFeatureRoomYeeting) { - err := intent.BeeperDeleteRoom(guild.MXID) - if err != nil && !errors.Is(err, mautrix.MNotFound) { - guild.log.Errorfln("Failed to delete %s using hungryserv yeet endpoint: %v", guild.MXID, err) - } - return - } - guild.bridge.cleanupRoom(intent, guild.MXID, false, *maulogadapt.MauAsZero(guild.log)) -} - -func (guild *Guild) RemoveMXID() { - guild.bridge.guildsLock.Lock() - defer guild.bridge.guildsLock.Unlock() - if guild.MXID == "" { - return - } - delete(guild.bridge.guildsByMXID, guild.MXID) - guild.MXID = "" - guild.AvatarSet = false - guild.NameSet = false - guild.BridgingMode = database.GuildBridgeNothing - guild.Update() -} - -func (guild *Guild) Delete() { - guild.Guild.Delete() - guild.bridge.guildsLock.Lock() - delete(guild.bridge.guildsByID, guild.ID) - if guild.MXID != "" { - delete(guild.bridge.guildsByMXID, guild.MXID) - } - guild.bridge.guildsLock.Unlock() - -} diff --git a/main.go b/main.go deleted file mode 100644 index 5b6f635..0000000 --- a/main.go +++ /dev/null @@ -1,208 +0,0 @@ -// mautrix-discord - A Matrix-Discord puppeting bridge. -// Copyright (C) 2022 Tulir Asokan -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package main - -import ( - _ "embed" - "net/http" - "sync" - - "go.mau.fi/util/configupgrade" - "go.mau.fi/util/exsync" - "golang.org/x/sync/semaphore" - "maunium.net/go/mautrix/bridge" - "maunium.net/go/mautrix/bridge/commands" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" - - "go.mau.fi/mautrix-discord/config" - "go.mau.fi/mautrix-discord/database" -) - -// Information to find out exactly which commit the bridge was built from. -// These are filled at build time with the -X linker flag. -var ( - Tag = "unknown" - Commit = "unknown" - BuildTime = "unknown" -) - -//go:embed example-config.yaml -var ExampleConfig string - -type DiscordBridge struct { - bridge.Bridge - - Config *config.Config - DB *database.Database - - DMA *DirectMediaAPI - provisioning *ProvisioningAPI - - usersByMXID map[id.UserID]*User - usersByID map[string]*User - usersLock sync.Mutex - - managementRooms map[id.RoomID]*User - managementRoomsLock sync.Mutex - - portalsByMXID map[id.RoomID]*Portal - portalsByID map[database.PortalKey]*Portal - portalsLock sync.Mutex - - threadsByID map[string]*Thread - threadsByRootMXID map[id.EventID]*Thread - threadsByCreationNoticeMXID map[id.EventID]*Thread - threadsLock sync.Mutex - - guildsByMXID map[id.RoomID]*Guild - guildsByID map[string]*Guild - guildsLock sync.Mutex - - puppets map[string]*Puppet - puppetsByCustomMXID map[id.UserID]*Puppet - puppetsLock sync.Mutex - - attachmentTransfers *exsync.Map[attachmentKey, *exsync.ReturnableOnce[*database.File]] - parallelAttachmentSemaphore *semaphore.Weighted -} - -func (br *DiscordBridge) GetExampleConfig() string { - return ExampleConfig -} - -func (br *DiscordBridge) GetConfigPtr() interface{} { - br.Config = &config.Config{ - BaseConfig: &br.Bridge.Config, - } - br.Config.BaseConfig.Bridge = &br.Config.Bridge - return br.Config -} - -func (br *DiscordBridge) Init() { - br.CommandProcessor = commands.NewProcessor(&br.Bridge) - br.RegisterCommands() - br.EventProcessor.On(event.StateTombstone, br.HandleTombstone) - - matrixHTMLParser.PillConverter = br.pillConverter - - br.DB = database.New(br.Bridge.DB, br.Log.Sub("Database")) - discordLog = br.ZLog.With().Str("component", "discordgo").Logger() -} - -func (br *DiscordBridge) Start() { - if br.Config.Bridge.Provisioning.SharedSecret != "disable" { - br.provisioning = newProvisioningAPI(br) - } - if br.Config.Bridge.PublicAddress != "" { - br.AS.Router.HandleFunc("/mautrix-discord/avatar/{server}/{mediaID}/{checksum}", br.serveMediaProxy).Methods(http.MethodGet) - } - br.DMA = newDirectMediaAPI(br) - br.WaitWebsocketConnected() - go br.startUsers() -} - -func (br *DiscordBridge) Stop() { - for _, user := range br.usersByMXID { - if user.Session == nil { - continue - } - - br.Log.Debugln("Disconnecting", user.MXID) - user.Session.Close() - } -} - -func (br *DiscordBridge) GetIPortal(mxid id.RoomID) bridge.Portal { - p := br.GetPortalByMXID(mxid) - if p == nil { - return nil - } - return p -} - -func (br *DiscordBridge) GetIUser(mxid id.UserID, create bool) bridge.User { - p := br.GetUserByMXID(mxid) - if p == nil { - return nil - } - return p -} - -func (br *DiscordBridge) IsGhost(mxid id.UserID) bool { - _, isGhost := br.ParsePuppetMXID(mxid) - return isGhost -} - -func (br *DiscordBridge) GetIGhost(mxid id.UserID) bridge.Ghost { - p := br.GetPuppetByMXID(mxid) - if p == nil { - return nil - } - return p -} - -func (br *DiscordBridge) CreatePrivatePortal(id id.RoomID, user bridge.User, ghost bridge.Ghost) { - //TODO implement -} - -func main() { - br := &DiscordBridge{ - usersByMXID: make(map[id.UserID]*User), - usersByID: make(map[string]*User), - - managementRooms: make(map[id.RoomID]*User), - - portalsByMXID: make(map[id.RoomID]*Portal), - portalsByID: make(map[database.PortalKey]*Portal), - - threadsByID: make(map[string]*Thread), - threadsByRootMXID: make(map[id.EventID]*Thread), - threadsByCreationNoticeMXID: make(map[id.EventID]*Thread), - - guildsByID: make(map[string]*Guild), - guildsByMXID: make(map[id.RoomID]*Guild), - - puppets: make(map[string]*Puppet), - puppetsByCustomMXID: make(map[id.UserID]*Puppet), - - attachmentTransfers: exsync.NewMap[attachmentKey, *exsync.ReturnableOnce[*database.File]](), - parallelAttachmentSemaphore: semaphore.NewWeighted(3), - } - br.Bridge = bridge.Bridge{ - Name: "mautrix-discord", - URL: "https://github.com/mautrix/discord", - Description: "A Matrix-Discord puppeting bridge.", - Version: "0.7.5", - ProtocolName: "Discord", - BeeperServiceName: "discordgo", - BeeperNetworkName: "discord", - - CryptoPickleKey: "maunium.net/go/mautrix-whatsapp", - - ConfigUpgrader: &configupgrade.StructUpgrader{ - SimpleUpgrader: configupgrade.SimpleUpgrader(config.DoUpgrade), - Blocks: config.SpacedBlocks, - Base: ExampleConfig, - }, - - Child: br, - } - br.InitVersion(Tag, Commit, BuildTime) - - br.Main() -} diff --git a/pkg/connector/backfill.go b/pkg/connector/backfill.go new file mode 100644 index 0000000..95f1c2b --- /dev/null +++ b/pkg/connector/backfill.go @@ -0,0 +1,227 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + "slices" + "strconv" + + "github.com/bwmarrin/discordgo" + "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + + "go.mau.fi/mautrix-discord/pkg/discordid" +) + +var ( + _ bridgev2.BackfillingNetworkAPI = (*DiscordClient)(nil) + _ bridgev2.BackfillingNetworkAPIWithLimits = (*DiscordClient)(nil) +) + +func (d *DiscordClient) FetchMessages(ctx context.Context, fetchParams bridgev2.FetchMessagesParams) (*bridgev2.FetchMessagesResponse, error) { + if !d.IsLoggedIn() { + return nil, bridgev2.ErrNotLoggedIn + } + + parentChannelID := discordid.ParseChannelPortalID(fetchParams.Portal.ID) + channelID := parentChannelID + threadChannelID := "" + var knownThreadRootID *networkid.MessageID + + if fetchParams.ThreadRoot != "" { + thread, err := d.getThreadByRootMessageID(ctx, discordid.ParseMessageID(fetchParams.ThreadRoot)) + if err != nil { + return nil, err + } + if thread == nil { + return &bridgev2.FetchMessagesResponse{ + Messages: nil, + HasMore: false, + }, nil + } + threadChannelID = thread.ThreadChannelID + channelID = threadChannelID + threadRootID := fetchParams.ThreadRoot + knownThreadRootID = &threadRootID + } + + guildID := fetchParams.Portal.Metadata.(*discordid.PortalMetadata).GuildID + refererOpt := makeDiscordReferer(guildID, parentChannelID, threadChannelID) + + log := zerolog.Ctx(ctx).With(). + Str("action", "fetch messages"). + Str("channel_id", channelID). + Str("thread_channel_id", threadChannelID). + Int("desired_count", fetchParams.Count). + Bool("forward", fetchParams.Forward).Logger() + ctx = log.WithContext(ctx) + + var beforeID string + var afterID string + + if fetchParams.AnchorMessage != nil { + anchorID := discordid.ParseMessageID(fetchParams.AnchorMessage.ID) + + if fetchParams.Forward { + afterID = anchorID + } else { + beforeID = anchorID + } + } + + // ChannelMessages returns messages ordered from newest to oldest. + count := min(fetchParams.Count, 100) + log.Debug().Msg("Fetching channel history for backfill") + msgs, err := d.Session.ChannelMessages(channelID, count, beforeID, afterID, "", refererOpt) + if err != nil { + return nil, d.tryWrappingError(ctx, err) + } + + // Update our user cache with all of the users present in the response. This + // indirectly makes `GetUserInfo` on `DiscordClient` return the information + // we've fetched above. + cachedDiscordUserIDs := d.userCache.UpdateWithMessages(msgs) + + { + log := zerolog.Ctx(ctx).With(). + Str("action", "update ghosts via fetched messages"). + Logger() + ctx := log.WithContext(ctx) + + // Update/create all of the ghosts for the users involved. This lets us + // set a correct per-message profile on each message, even for users + // that we've never seen until now. + for _, discordUserID := range cachedDiscordUserIDs { + + ghost, err := d.connector.Bridge.GetGhostByID(ctx, discordid.MakeUserID(discordUserID)) + if err != nil { + log.Err(err).Str("ghost_id", discordUserID). + Msg("Failed to get ghost associated with message") + continue + } + ghost.UpdateInfoIfNecessary(ctx, d.UserLogin, bridgev2.RemoteEventMessage) + } + } + + converted := make([]*bridgev2.BackfillMessage, 0, len(msgs)) + provablyReadMessageCount := 0 + for _, msg := range msgs { + parsedMsgID, _ := strconv.ParseInt(msg.ID, 10, 64) + msgTs, _ := discordgo.SnowflakeTimestamp(msg.ID) + + readState := d.readStateForID(msg.ChannelID) + if readState != nil { + lastAckedMsgID, _ := strconv.ParseInt(string(readState.LastMessageID), 10, 64) + if lastAckedMsgID >= parsedMsgID { + provablyReadMessageCount += 1 + } + } + + // NOTE: For now, we aren't backfilling reactions. This is because: + // + // - Discord does not provide enough historical reaction data in the + // response from the message history endpoint to construct valid + // BackfillReactions. + // - Fetching the reaction data would be prohibitively expensive for + // messages with many reactions. Messages in large guilds can have + // tens of thousands of reactions. + // - Indicating aggregated child events[1] from BackfillMessage doesn't + // seem possible due to how portal backfilling batching currently + // works. + // + // [1]: https://spec.matrix.org/v1.16/client-server-api/#reference-relations + // + // It might be worth fetching the reaction data anyways if we observe + // a small overall number of reactions. + sender := d.makeEventSender(msg.Author) + + // Use the ghost's intent, falling back to the bridge's. + ghost, err := d.connector.Bridge.GetGhostByID(ctx, sender.Sender) + if err != nil { + log.Err(err).Msg("Failed to look up ghost while converting backfilled message") + } + var intent bridgev2.MatrixAPI + if ghost == nil { + intent = fetchParams.Portal.Bridge.Bot + } else { + intent = ghost.Intent + } + + converted = append(converted, &bridgev2.BackfillMessage{ + ID: discordid.MakeMessageID(msg.ID), + ConvertedMessage: d.connector.MsgConv.ToMatrix(ctx, fetchParams.Portal, intent, d.UserLogin, d.Session, msg, knownThreadRootID), + Sender: sender, + Timestamp: msgTs, + StreamOrder: parsedMsgID, + }) + + if fetchParams.ThreadRoot == "" && msg.Flags&discordgo.MessageFlagsHasThread != 0 { + latest := "" + if msg.Thread != nil { + latest = msg.Thread.LastMessageID + } + if latest == "" { + latest = msg.ID + } + converted[len(converted)-1].ShouldBackfillThread = true + converted[len(converted)-1].LastThreadMessage = discordid.MakeMessageID(latest) + if err := d.upsertThreadInfoFromMessage(ctx, msg); err != nil { + log.Err(err).Str("message_id", msg.ID).Msg("Failed to store thread info while backfilling") + } + } + } + // FetchMessagesResponse expects messages to always be ordered from oldest to newest. + slices.Reverse(converted) + + log.Debug(). + Int("converted_count", len(converted)). + Int("provably_read_message_count", provablyReadMessageCount). + Msg("Finished fetching and converting, returning backfill response") + + // It doesn't seem like we can express unreadness for every message, so do + // it for the entire batch. A single unread message makes the entire batch + // unread, even if some messages were actually read. + entireBatchWasProvablyRead := len(msgs) == provablyReadMessageCount + return &bridgev2.FetchMessagesResponse{ + Messages: converted, + Forward: fetchParams.Forward, + MarkRead: entireBatchWasProvablyRead, + // This might not actually be true if the channel's total number of messages is itself a multiple + // of `count`, but that's probably okay. + HasMore: len(msgs) == count, + }, nil +} + +func (d *DiscordClient) GetBackfillMaxBatchCount( + _ context.Context, + portal *bridgev2.Portal, + _ *database.BackfillTask, +) int { + backfillQueueConfig := d.connector.Bridge.Config.Backfill.Queue + + switch portal.RoomType { + case database.RoomTypeDM: + return backfillQueueConfig.GetOverride("dm") + case database.RoomTypeGroupDM: + return backfillQueueConfig.GetOverride("group_dm") + default: + return backfillQueueConfig.GetOverride("channel") + } +} diff --git a/pkg/connector/capabilities.go b/pkg/connector/capabilities.go new file mode 100644 index 0000000..b770769 --- /dev/null +++ b/pkg/connector/capabilities.go @@ -0,0 +1,166 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + + "go.mau.fi/util/ffmpeg" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/event" + + "go.mau.fi/mautrix-discord/pkg/discordid" +) + +var DiscordGeneralCaps = &bridgev2.NetworkGeneralCapabilities{ + Provisioning: bridgev2.ProvisioningCapabilities{ + ResolveIdentifier: bridgev2.ResolveIdentifierCapabilities{}, + GroupCreation: map[string]bridgev2.GroupTypeCapabilities{}, + }, +} + +func (d *DiscordConnector) GetCapabilities() *bridgev2.NetworkGeneralCapabilities { + return DiscordGeneralCaps +} + +func (d *DiscordConnector) GetBridgeInfoVersion() (info, caps int) { + return 1, 4 +} + +/*func supportedIfFFmpeg() event.CapabilitySupportLevel { + if ffmpeg.Supported() { + return event.CapLevelPartialSupport + } + return event.CapLevelRejected +}*/ + +func capID() string { + base := "fi.mau.discord.capabilities.2026_03_18" + if ffmpeg.Supported() { + return base + "+ffmpeg" + } + return base +} + +// TODO: This limit is increased depending on user subscription status (Discord Nitro). +const MaxTextLength = 2000 + +// TODO: This limit is increased depending on user subscription status (Discord Nitro). +// TODO: Verify this figure (10 MiB). +const MaxFileSize = 10485760 + +var discordCaps = &event.RoomFeatures{ + ID: capID(), + Reply: event.CapLevelFullySupported, + Reaction: event.CapLevelFullySupported, + Edit: event.CapLevelFullySupported, + Delete: event.CapLevelFullySupported, + Formatting: event.FormattingFeatureMap{ + event.FmtBold: event.CapLevelFullySupported, + event.FmtItalic: event.CapLevelFullySupported, + event.FmtStrikethrough: event.CapLevelFullySupported, + event.FmtInlineCode: event.CapLevelFullySupported, + event.FmtCodeBlock: event.CapLevelFullySupported, + event.FmtSyntaxHighlighting: event.CapLevelFullySupported, + event.FmtBlockquote: event.CapLevelFullySupported, + event.FmtInlineLink: event.CapLevelFullySupported, + event.FmtUserLink: event.CapLevelFullySupported, + event.FmtRoomLink: event.CapLevelUnsupported, // TODO: Support. + event.FmtEventLink: event.CapLevelUnsupported, // TODO: Support. + event.FmtAtRoomMention: event.CapLevelUnsupported, // TODO: Support. + event.FmtUnorderedList: event.CapLevelFullySupported, + event.FmtOrderedList: event.CapLevelFullySupported, + event.FmtListStart: event.CapLevelFullySupported, + event.FmtListJumpValue: event.CapLevelUnsupported, + event.FmtCustomEmoji: event.CapLevelUnsupported, // TODO: Support. + }, + File: event.FileFeatureMap{ + event.MsgImage: { + MimeTypes: map[string]event.CapabilitySupportLevel{ + "image/jpeg": event.CapLevelFullySupported, + "image/png": event.CapLevelFullySupported, + "image/gif": event.CapLevelFullySupported, + "image/webp": event.CapLevelFullySupported, + "image/*": event.CapLevelPartialSupport, + }, + Caption: event.CapLevelFullySupported, + MaxCaptionLength: MaxTextLength, + MaxSize: MaxFileSize, + }, + event.MsgVideo: { + MimeTypes: map[string]event.CapabilitySupportLevel{ + "video/mp4": event.CapLevelFullySupported, + "video/webm": event.CapLevelFullySupported, + "video/*": event.CapLevelPartialSupport, + }, + Caption: event.CapLevelFullySupported, + MaxCaptionLength: MaxTextLength, + MaxSize: MaxFileSize, + }, + event.MsgAudio: { + MimeTypes: map[string]event.CapabilitySupportLevel{ + "audio/mpeg": event.CapLevelFullySupported, + "audio/webm": event.CapLevelFullySupported, + "audio/wav": event.CapLevelFullySupported, + "audio/*": event.CapLevelPartialSupport, + }, + Caption: event.CapLevelFullySupported, + MaxCaptionLength: MaxTextLength, + MaxSize: MaxFileSize, + }, + event.CapMsgVoice: { + MimeTypes: map[string]event.CapabilitySupportLevel{ + "audio/ogg; codecs=opus": event.CapLevelFullySupported, + "audio/ogg": event.CapLevelFullySupported, + "audio/webm; codecs=opus": event.CapLevelFullySupported, + "audio/webm": event.CapLevelFullySupported, + "audio/*": event.CapLevelPartialSupport, + }, + Caption: event.CapLevelFullySupported, + MaxCaptionLength: MaxTextLength, + MaxSize: MaxFileSize, + }, + event.MsgFile: { + MimeTypes: map[string]event.CapabilitySupportLevel{ + "*/*": event.CapLevelFullySupported, + }, + Caption: event.CapLevelFullySupported, + MaxCaptionLength: MaxTextLength, + MaxSize: MaxFileSize, + }, + event.CapMsgGIF: { + MimeTypes: map[string]event.CapabilitySupportLevel{ + "image/gif": event.CapLevelFullySupported, + }, + Caption: event.CapLevelFullySupported, + MaxCaptionLength: MaxTextLength, + MaxSize: MaxFileSize, + }, + }, + LocationMessage: event.CapLevelUnsupported, + MaxTextLength: MaxTextLength, + Thread: event.CapLevelPartialSupport, +} + +func (d *DiscordClient) GetCapabilities(ctx context.Context, portal *bridgev2.Portal) *event.RoomFeatures { + if portal.Metadata.(*discordid.PortalMetadata).GuildID == "" { + caps := discordCaps.Clone() + caps.Thread = event.CapLevelUnsupported + return caps + } + return discordCaps +} diff --git a/pkg/connector/chatinfo.go b/pkg/connector/chatinfo.go new file mode 100644 index 0000000..f9497b9 --- /dev/null +++ b/pkg/connector/chatinfo.go @@ -0,0 +1,271 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + "fmt" + + "github.com/bwmarrin/discordgo" + "github.com/rs/zerolog" + "go.mau.fi/util/ptr" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + + "go.mau.fi/mautrix-discord/pkg/discordid" +) + +// getGuildSpaceInfo computes the [bridgev2.ChatInfo] for a guild space. +func (d *DiscordClient) getGuildSpaceInfo(_ctx context.Context, guild *discordgo.Guild) (*bridgev2.ChatInfo, error) { + selfEvtSender := d.selfEventSender() + + return &bridgev2.ChatInfo{ + Name: &guild.Name, + Topic: nil, + Members: &bridgev2.ChatMemberList{ + MemberMap: map[networkid.UserID]bridgev2.ChatMember{ + selfEvtSender.Sender: {EventSender: selfEvtSender}, + }, + // As recommended by the spec, prohibit normal events by setting + // events_default to a suitably high number. + PowerLevels: &bridgev2.PowerLevelOverrides{EventsDefault: ptr.Ptr(100)}, + }, + Avatar: d.makeAvatarForGuild(guild), + Type: ptr.Ptr(database.RoomTypeSpace), + }, nil +} + +func portalIsPrivate(p *bridgev2.Portal) bool { + return p.RoomType == database.RoomTypeDM || p.RoomType == database.RoomTypeGroupDM +} + +func channelIsPrivate(ch *discordgo.Channel) bool { + return ch.Type == discordgo.ChannelTypeDM || ch.Type == discordgo.ChannelTypeGroupDM +} + +func readableChannelType(typ discordgo.ChannelType) (desc string) { + desc = "other" + + switch typ { + case discordgo.ChannelTypeGuildText: + desc = "guild text" + case discordgo.ChannelTypeDM: + desc = "dm" + case discordgo.ChannelTypeGroupDM: + desc = "group dm" + case discordgo.ChannelTypeGuildPublicThread: + desc = "public thread" + case discordgo.ChannelTypeGuildPrivateThread: + desc = "private thread" + } + + return +} + +func (d *DiscordClient) makeAvatarForChannel(ctx context.Context, ch *discordgo.Channel) *bridgev2.Avatar { + if channelIsPrivate(ch) { + return &bridgev2.Avatar{ + ID: discordid.MakeAvatarID(ch.Icon), + Get: func(ctx context.Context) ([]byte, error) { + url := discordgo.EndpointGroupIcon(ch.ID, ch.Icon) + return httpGet(ctx, d.httpClient, url, "channel/gdm icon") + }, + Remove: ch.Icon == "", + } + } else { + if !d.connector.Config.GuildAvatarsInRoomsEnabled() { + return nil + } + + guild, err := d.Session.State.Guild(ch.GuildID) + + if err != nil || guild == nil { + zerolog.Ctx(ctx).Err(err).Msg("Couldn't look up guild in cache in order to create room avatar") + return nil + } + + return d.makeAvatarForGuild(guild) + } +} + +func (d *DiscordClient) getPrivateChannelMemberList(ch *discordgo.Channel) bridgev2.ChatMemberList { + var members bridgev2.ChatMemberList + members.IsFull = true + members.MemberMap = make(bridgev2.ChatMemberMap, len(ch.Recipients)) + + if len(ch.Recipients) > 0 { + selfEventSender := d.selfEventSender() + + // Private channels' array of participants doesn't include ourselves, + // so inject ourselves as a member. + members.MemberMap[selfEventSender.Sender] = bridgev2.ChatMember{EventSender: selfEventSender} + + for _, recipient := range ch.Recipients { + sender := d.makeEventSender(recipient) + members.MemberMap[sender.Sender] = bridgev2.ChatMember{EventSender: sender} + } + + members.TotalMemberCount = len(ch.Recipients) + } + + return members +} + +func (d *DiscordClient) getChannelNameParams(ch *discordgo.Channel) *ChannelNameParams { + params := &ChannelNameParams{ + Name: ch.Name, + Type: ch.Type, + NSFW: ch.NSFW, + IsDM: ch.Type == discordgo.ChannelTypeDM, + IsGroupDM: ch.Type == discordgo.ChannelTypeGroupDM, + IsCategory: ch.Type == discordgo.ChannelTypeGuildCategory, + IsGuildChannel: ch.GuildID != "", + } + + if ch.ParentID != "" { + parent, err := d.Session.State.Channel(ch.ParentID) + if err == nil && parent != nil { + params.ParentName = parent.Name + } + } + + if ch.GuildID != "" { + guild, err := d.Session.State.Guild(ch.GuildID) + if err == nil && guild != nil { + params.GuildName = guild.Name + } + } + + return params +} + +func (d *DiscordClient) getChannelName(ch *discordgo.Channel) *string { + if ch.Type == discordgo.ChannelTypeDM { + // Respect friend nicknames. + if len(ch.Recipients) > 0 { + if rel := d.relationshipWithUserID(ch.Recipients[0].ID); rel != nil && rel.Nickname != "" { + return &rel.Nickname + } + } else { + // Impossible? + } + + return nil + } + + name := d.connector.Config.FormatChannelName(d.getChannelNameParams(ch)) + return &name +} + +// getChannelChatInfo computes [bridgev2.ChatInfo] for a guild channel or private (DM or group DM) channel. +func (d *DiscordClient) getChannelChatInfo(ctx context.Context, ch *discordgo.Channel) (*bridgev2.ChatInfo, error) { + var roomType database.RoomType + switch ch.Type { + case discordgo.ChannelTypeGuildCategory: + roomType = database.RoomTypeSpace + case discordgo.ChannelTypeDM: + roomType = database.RoomTypeDM + case discordgo.ChannelTypeGroupDM: + roomType = database.RoomTypeGroupDM + default: + roomType = database.RoomTypeDefault + } + + var parentPortalID *networkid.PortalID + if ch.Type == discordgo.ChannelTypeGuildCategory || (ch.ParentID == "" && ch.GuildID != "") { + // Categories and uncategorized guild channels always have the guild as their parent. + parentPortalID = ptr.Ptr(discordid.MakeGuildPortalIDWithID(ch.GuildID)) + } else if ch.ParentID != "" { + // Categorized guild channels. + parentPortalID = ptr.Ptr(discordid.MakeChannelPortalIDWithID(ch.ParentID)) + } + + var memberList bridgev2.ChatMemberList + if channelIsPrivate(ch) { + memberList = d.getPrivateChannelMemberList(ch) + } else { + // TODO we're _always_ sending partial member lists for guilds; we can probably + // do better than that + selfEventSender := d.selfEventSender() + + memberList = bridgev2.ChatMemberList{ + IsFull: false, + MemberMap: map[networkid.UserID]bridgev2.ChatMember{ + selfEventSender.Sender: {EventSender: selfEventSender}, + }, + } + } + + return &bridgev2.ChatInfo{ + Name: d.getChannelName(ch), + Topic: &ch.Topic, + Avatar: d.makeAvatarForChannel(ctx, ch), + + Members: &memberList, + + Type: &roomType, + ParentID: parentPortalID, + + UserLocal: &bridgev2.UserLocalPortalInfo{ + MutedUntil: ptr.Ptr(d.channelMutedUntil(ch.GuildID, ch.ID)), + }, + CanBackfill: true, + + ExtraUpdates: func(ctx context.Context, portal *bridgev2.Portal) (changed bool) { + meta := portal.Metadata.(*discordid.PortalMetadata) + if meta.GuildID != ch.GuildID { + meta.GuildID = ch.GuildID + changed = true + } + if meta.ChannelType == nil || *meta.ChannelType != ch.Type { + meta.ChannelType = ptr.Ptr(ch.Type) + changed = true + } + + return + }, + }, nil +} + +func (d *DiscordClient) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) { + if !d.IsLoggedIn() { + return nil, bridgev2.ErrNotLoggedIn + } + + guildID := discordid.ParseGuildPortalID(portal.ID) + if guildID != "" { + // Portal is a space representing a Discord guild. + + guild, err := d.Session.State.Guild(guildID) + if err != nil { + return nil, fmt.Errorf("couldn't get guild: %w", err) + } + + return d.getGuildSpaceInfo(ctx, guild) + } else { + // Portal is to a channel of some kind (private or guild). + channelID := discordid.ParseChannelPortalID(portal.ID) + + ch, err := d.Session.State.Channel(channelID) + if err != nil { + return nil, fmt.Errorf("couldn't get channel: %w", err) + } + + return d.getChannelChatInfo(ctx, ch) + } +} diff --git a/pkg/connector/client.go b/pkg/connector/client.go new file mode 100644 index 0000000..521c771 --- /dev/null +++ b/pkg/connector/client.go @@ -0,0 +1,1065 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + "errors" + "fmt" + "io" + "iter" + "maps" + "net/http" + "regexp" + "slices" + "sync" + "sync/atomic" + "time" + + "github.com/bwmarrin/discordgo" + "github.com/gorilla/websocket" + "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/simplevent" + "maunium.net/go/mautrix/bridgev2/status" + "maunium.net/go/mautrix/event" + + "go.mau.fi/util/exmaps" + + "go.mau.fi/mautrix-discord/pkg/discordauth" + "go.mau.fi/mautrix-discord/pkg/discordid" +) + +type DiscordClient struct { + connector *DiscordConnector + UserLogin *bridgev2.UserLogin + Session *discordgo.Session + httpClient *http.Client + + stopConnecting atomic.Pointer[context.CancelFunc] + hasBegunSyncing bool + + markedOpened map[string]time.Time + markedOpenedLock sync.Mutex + + // A map of guild ID (or "" for the settings concerning private channels) + // to its corresponding UserGuildSettings. + guildSettings map[string]*discordgo.UserGuildSettings + guildSettingsLock sync.RWMutex + + // A map of resource (e.g. channel) ID to its corresponding read state. + // + // Since there can be thousands of read state entries, the map is to help + // keep lookups by channel ID speedy by avoiding constant linear searching. + readStates map[string]*discordgo.ReadState + readStatesLock sync.RWMutex + + relationshipLock sync.RWMutex + relationships map[string]*discordgo.Relationship + + userCache *UserCache + + lastSendAttemptMutex sync.Mutex + lastSendAttempt *SendAttempt +} + +func (d *DiscordConnector) LoadUserLogin(ctx context.Context, login *bridgev2.UserLogin) error { + meta := login.Metadata.(*discordid.UserLoginMetadata) + + var session *discordgo.Session + if meta.Token == "" { + login.Log.Warn().Msg("Login has no token, not setting up a session") + // Session on the UserLogin will be nil. + } else { + var err error + session, err = NewDiscordSession(ctx, meta.Token) + if err != nil { + return err + } + } + + cl := DiscordClient{ + connector: d, + UserLogin: login, + Session: session, + httpClient: d.Bridge.GetHTTPClientSettings().Compile(), + userCache: NewUserCache(session), + guildSettings: make(map[string]*discordgo.UserGuildSettings), + readStates: make(map[string]*discordgo.ReadState), + relationships: make(map[string]*discordgo.Relationship), + } + login.Client = &cl + + if session != nil { + session.RESTResponseHook = cl.tapDiscordRESTResponse + } + + return nil +} + +var _ bridgev2.NetworkAPI = (*DiscordClient)(nil) + +func (d *DiscordClient) userLoginMetadata() *discordid.UserLoginMetadata { + return d.UserLogin.Metadata.(*discordid.UserLoginMetadata) +} + +func (d *DiscordClient) Connect(ctx context.Context) { + log := zerolog.Ctx(ctx) + + lacksToken := !d.HasToken() + lacksSession := d.Session == nil + if lacksToken || lacksSession { + // (d.Session can be nil if we lacked credentials on startup.) + log.Warn().Bool("lacking_token", lacksToken). + Bool("lacking_session", lacksSession). + Msg("Refusing to connect") + + d.UserLogin.BridgeState.Send(status.BridgeState{ + StateEvent: status.StateBadCredentials, + Error: DCNotLoggedIn, + UserAction: status.UserActionRelogin, + }) + return + } + + meta := d.userLoginMetadata() + if meta.HeartbeatSession.IsExpired() { + log.Info().Msg("Heartbeat session expired, creating a new one") + meta.HeartbeatSession = discordgo.NewHeartbeatSession() + } + meta.HeartbeatSession.BumpLastUsed() + d.Session.HeartbeatSession = meta.HeartbeatSession + + d.markedOpened = make(map[string]time.Time) + + d.connectRetrying(ctx, 0) +} + +const maxGatewayConnectRetries = 5 + +// tokenInvalidated responds to Discord invalidating our token. +func (d *DiscordClient) tokenInvalidated(ctx context.Context, circumstance string) { + log := zerolog.Ctx(ctx) + log.Info().Msg("Invalidating user login") + + d.UserLogin.BridgeState.Send(status.BridgeState{ + StateEvent: status.StateBadCredentials, + Error: DCWebsocketDisconnect4004, + UserAction: status.UserActionRelogin, + }) + + props := d.baseAnalyticsProps(ctx) + props["circumstance"] = circumstance + d.UserLogin.TrackAnalytics("Discord auth invalidation", props) + + // Empty out the token. + log.Debug().Msg("Emptying token") + meta := d.UserLogin.Metadata.(*discordid.UserLoginMetadata) + meta.Token = "" + if err := d.UserLogin.Save(ctx); err != nil { + log.Err(err).Msg("Failed to save user login in order to invalidate session") + } +} + +func (d *DiscordClient) connectRetrying(ctx context.Context, retryCount int) { + retryCtx, cancel := context.WithCancel(ctx) + oldStop := d.stopConnecting.Swap(&cancel) + if oldStop != nil { + (*oldStop)() + } + + log := zerolog.Ctx(ctx).With().Int("retry_count", retryCount).Logger() + + log.Debug().Msg("Connecting to Discord") + d.UserLogin.BridgeState.Send(status.BridgeState{ + StateEvent: status.StateConnecting, + }) + + err := d.connect(ctx) + if err != nil { + log.Err(err).Msg("Couldn't connect to Discord") + + closeErr := &websocket.CloseError{} + if errors.As(err, &closeErr) && closeErr.Code == 4004 { + // Effectively the same as *discordgo.InvalidAuth, but at connect + // time. (discordgo only dispatches the synthetic InvalidAuth event + // once you've already connected successfully.) + // + // Don't retry. + d.tokenInvalidated(ctx, "when connecting") + } else if retryCount <= maxGatewayConnectRetries { + d.UserLogin.BridgeState.Send(status.BridgeState{ + StateEvent: status.StateTransientDisconnect, + Error: DCUnknownWebsocketError, + Message: err.Error(), + }) + + sleepDuration := time.Second * time.Duration(2< 0 + log.Trace(). + Int64("permissions", perms). + Bool("channel_visible", canView). + Msg("Computed visibility of guild channel") + return canView +} + +func (d *DiscordClient) makeAvatarForGuild(guild *discordgo.Guild) *bridgev2.Avatar { + return &bridgev2.Avatar{ + ID: discordid.MakeAvatarID(guild.Icon), + Get: func(ctx context.Context) ([]byte, error) { + url := discordgo.EndpointGuildIcon(guild.ID, guild.Icon) + return httpGet(ctx, d.httpClient, url, "guild icon") + }, + Remove: guild.Icon == "", + } +} + +// bridgedGuildIDs returns a set of guild IDs that should be bridged. Note that +// presence in the returned set does not imply anything about the corresponding +// portals and rooms. +func (d *DiscordClient) bridgedGuildIDs() map[string]struct{} { + meta := d.UserLogin.Metadata.(*discordid.UserLoginMetadata) + bridgingGuildIDs := map[string]struct{}{} + + // guilds that were bridged via the provisioning api + for guildID, bridged := range meta.BridgedGuildIDs { + if bridged { + bridgingGuildIDs[guildID] = struct{}{} + } + } + + // guilds that were declared in the configuration file + for _, guildID := range d.connector.Config.Guilds.BridgingGuildIDs { + bridgingGuildIDs[guildID] = struct{}{} + } + + return bridgingGuildIDs +} + +func (d *DiscordClient) syncGuilds(ctx context.Context) { + guildIDs := slices.Sorted(maps.Keys(d.bridgedGuildIDs())) + + for _, guildID := range guildIDs { + log := zerolog.Ctx(ctx).With(). + Str("guild_id", guildID). + Str("action", "sync guild"). + Logger() + + err := d.syncGuild(log.WithContext(ctx), guildID) + if err != nil { + log.Err(err).Msg("Couldn't bridge guild during sync") + } + } +} + +// deleteGuildPortalSpace queues a remote event that deletes a guild space +// (including children). +func (d *DiscordClient) deleteGuildPortalSpace(ctx context.Context, guildID string) { + log := zerolog.Ctx(ctx) + log.Info().Msg("Unbridging guild by deleting the entire space") + + d.connector.Bridge.QueueRemoteEvent(d.UserLogin, &simplevent.ChatDelete{ + EventMeta: simplevent.EventMeta{ + Type: bridgev2.RemoteEventChatDelete, + PortalKey: d.guildPortalKey(guildID), + }, + OnlyForMe: true, + Children: true, + }) +} + +// ensurePortal synchronously guarantees the existence of a portal's Matrix +// room with up-to-date chat info. +// +// If info is nil, then the chat info is fetched from the NetworkAPI. +func (d *DiscordClient) ensurePortal(ctx context.Context, key networkid.PortalKey, info *bridgev2.ChatInfo) error { + portal, err := d.connector.Bridge.GetPortalByKey(ctx, key) + if err != nil { + return fmt.Errorf("failed to get portal: %w", err) + } + + if info == nil { + info, err = d.GetChatInfo(ctx, portal) + if err != nil { + return fmt.Errorf("failed to get chat info: %w", err) + } + } + + if portal.MXID == "" { + // CreateMatrixRoom will indirectly lead to UpdateInfo being called. + if err := portal.CreateMatrixRoom(ctx, d.UserLogin, info); err != nil { + return fmt.Errorf("failed to create matrix room: %w", err) + } + } else { + portal.UpdateInfo(ctx, info, d.UserLogin, nil, time.Time{}) + } + + return nil +} + +func (d *DiscordClient) syncGuild(ctx context.Context, guildID string) error { + log := zerolog.Ctx(ctx).With(). + Str("guild_id", guildID). + Str("action", "bridge guild"). + Logger() + ctx = log.WithContext(ctx) + + guild, err := d.Session.State.Guild(guildID) + if errors.Is(err, discordgo.ErrStateNotFound) || guild == nil { + log.Err(err).Msg("Couldn't find guild, user isn't a member?") + // TODO likely left/kicked/banned from guild; nuke the portals + return errors.New("couldn't find guild in state") + } + + if err = d.syncGuildRoles(ctx, guildID, guild.Roles); err != nil { + return fmt.Errorf("failed to sync guild roles during guild sync: %w", err) + } + + // Synchronously guarantee the proper creation of the guild space portal so + // child rooms are born with the correct `m.bridge` state. + portalKey := d.guildPortalKey(guild.ID) + if err := d.ensurePortal(ctx, portalKey, nil); err != nil { + return fmt.Errorf("failed to ensure guild space portal: %w", err) + } + + visibleCategoryIDs := make(exmaps.Set[string]) + visibleChannels := make([]*discordgo.Channel, 0, len(guild.Channels)) + for _, guildCh := range guild.Channels { + // Only bridge text channels that are visible. + if guildCh.Type != discordgo.ChannelTypeGuildText || !d.canSeeGuildChannel(ctx, guildCh) { + continue + } + visibleChannels = append(visibleChannels, guildCh) + if guildCh.ParentID != "" { + visibleCategoryIDs.Add(guildCh.ParentID) + } + } + // Synchronously guarantee the proper creation of category space portals + // for the same reason that we do so for guild space portals. + // + // Note that we only care about syncing categories that contain at least + // one channel we can actually see. This matches the behavior of Discord's + // first-party clients. The permission bits on the category channel + // _itself_ are irrelevant. + for categoryID := range visibleCategoryIDs.Iter() { + category := d.channelWithID(ctx, categoryID) + if category == nil { + log.Error().Str("channel_id", categoryID).Msg("Failed to find category channel somehow, proceeding") + continue + } + + err := d.ensurePortal(ctx, d.portalKeyForChannel(category), nil) + if err != nil { + log.Err(err).Msg("Failed to ensure category space, proceeding") + // FIXME The children of this category channel will still be synced + // but with bogus `m.bridge` state. + } + } + // Now that all possible parent spaces exist, we can fan out the syncing of + // all guild channels we can see. + for _, visibleCh := range visibleChannels { + d.queueChannelResync(ctx, visibleCh) + } + + for _, thread := range guild.Threads { + err = d.upsertThreadInfoFromChannel(ctx, thread) + if err != nil { + log.Err(err).Str("thread_id", thread.ID).Msg("Failed to cache thread info during guild sync") + } + } + + d.subscribeGuild(ctx, guildID) + + return nil +} + +func (d *DiscordClient) subscribeGuild(ctx context.Context, guildID string) { + log := zerolog.Ctx(ctx) + + log.Debug().Msg("Subscribing to guild") + err := d.Session.SubscribeGuild(discordgo.GuildSubscribeData{ + GuildID: guildID, + Typing: true, + Activities: true, + Threads: true, + }) + if err != nil { + log.Warn().Err(err).Msg("Failed to subscribe to guild, proceeding") + } +} + +func httpGet(ctx context.Context, httpClient *http.Client, url, thing string) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("failed to prepare request: %w", err) + } + + resp, err := httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to download %s: %w", thing, err) + } + defer resp.Body.Close() + if resp.StatusCode > 300 { + return nil, fmt.Errorf("failed to download %s: got HTTP %d", thing, resp.StatusCode) + } + + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read %s data: %w", thing, err) + } + return data, nil +} + +func (d *DiscordClient) makeEventSenderWithID(userID string) bridgev2.EventSender { + return bridgev2.EventSender{ + IsFromMe: userID == d.Session.State.User.ID, + SenderLogin: discordid.MakeUserLoginID(userID), + Sender: discordid.MakeUserID(userID), + } +} + +func (d *DiscordClient) selfEventSender() bridgev2.EventSender { + return d.makeEventSenderWithID(d.Session.State.User.ID) +} + +func (d *DiscordClient) makeEventSender(user *discordgo.User) bridgev2.EventSender { + if user == nil { + panic("DiscordClient makeEventSender was passed a nil user") + } + + return d.makeEventSenderWithID(user.ID) +} + +func (d *DiscordClient) queueChannelResync(_ context.Context, ch *discordgo.Channel) { + d.connector.Bridge.QueueRemoteEvent(d.UserLogin, &DiscordChatResync{ + Client: d, + channel: ch, + }) +} + +func (d *DiscordClient) readStateForID(resourceID string) *discordgo.ReadState { + d.readStatesLock.RLock() + defer d.readStatesLock.RUnlock() + + return d.readStates[resourceID] +} + +func (d *DiscordClient) computeMutedUntil(muted bool, cfg *discordgo.MuteConfig) time.Time { + if !muted { + return bridgev2.Unmuted + } + + // If Muted is true but we don't have a MuteConfig, then the mute is + // indefinite. + if cfg == nil { + return event.MutedForever + } + + // Check for the explicit "forever" time window. + if cfg.SelectedTimeWindow != nil && *cfg.SelectedTimeWindow == -1 { + return event.MutedForever + } + + endTime := cfg.EndTime + if endTime == nil { + d.UserLogin.Log.Warn(). + Bool("muted", muted). + Any("mute_config", cfg). + Msg("Encountered bogus mute state, falling back to indefinite mute") + return event.MutedForever + } + return *endTime +} + +// channelMutedUntil computes an appropriate UserLocalPortalInfo.MutedUntil time +// for a given channel. +// +// This method works with private channels if an empty string is passed as the +// guild ID. +func (d *DiscordClient) channelMutedUntil(guildID string, channelID string) time.Time { + settings := d.guildSettingsForGuildID(guildID) + if settings == nil { + return bridgev2.Unmuted + } + + // TODO: Might be worth speeding this up via map. + for _, override := range settings.ChannelOverrides { + if override.ChannelID == channelID { + return d.computeMutedUntil(override.Muted, override.MuteConfig) + } + } + + return d.computeMutedUntil(settings.Muted, settings.MuteConfig) +} + +func (d *DiscordClient) guildSettingsForGuildID(guildID string) *discordgo.UserGuildSettings { + d.guildSettingsLock.RLock() + defer d.guildSettingsLock.RUnlock() + + return d.guildSettings[guildID] +} + +func (d *DiscordClient) channelWithID(ctx context.Context, channelID string) *discordgo.Channel { + if !d.IsLoggedIn() { + return nil + } + + ch, err := d.Session.State.Channel(channelID) + if err != nil { + if errors.Is(err, discordgo.ErrStateNotFound) { + return nil + } + + // Some other weird error happened. This is currently impossible but it's + // best to not rely on implementation details. + zerolog.Ctx(ctx).Err(err). + Str("channel_id", channelID). + Msg("Failed to look up channel") + return nil + } + + return ch +} + +func (d *DiscordClient) syncRemoteProfile(ctx context.Context) bool { + if !d.IsLoggedIn() { + return false + } + + log := zerolog.Ctx(ctx).With(). + Str("action", "sync remote discord profile"). + Logger() + ctx = log.WithContext(ctx) + + me := d.Session.State.User + if me == nil { + return false + } + + log.Debug().Msg("Updating remote profile if needed") + changed := false + remoteName := makeRemoteName(me) + + // Try to update our own ghost, which should upload the avatar if + // everything goes well. + ghost, err := d.connector.Bridge.GetGhostByID(ctx, discordid.MakeUserID(me.ID)) + if err != nil { + log.Err(err).Msg("Failed to get own ghost, remote profile will lack an avatar") + } else if info, err := d.GetUserInfo(ctx, ghost); err != nil { + // Shouldn't happen as the user cache shouldn't even reach out to the + // network; our own user should be there by now. + log.Err(err).Msg("Failed to get own user info") + } else { + log.Debug().Msg("Updating own ghost with user info") + ghost.UpdateInfo(ctx, info) + } + + profile := makeRemoteProfile(me, ghost) + if d.UserLogin.RemoteName != remoteName { + d.UserLogin.RemoteName = remoteName + changed = true + } + if d.UserLogin.RemoteProfile != profile { + d.UserLogin.RemoteProfile = profile + changed = true + } + + if changed { + if err := d.UserLogin.Save(ctx); err != nil { + log.Err(err).Msg("Failed to save UserLogin while updating remote profile") + } + } + return changed + // NOTE: For clients to immediately get the new remote profile, you need to + // send a bridge state. +} + +func (d *DiscordClient) wrapReceived40002(ctx context.Context, err error) error { + log := zerolog.Ctx(ctx) + log.Err(err).Msg("Received 40002 from Discord") + + props := d.baseAnalyticsProps(ctx) + props["errorMessage"] = err.Error() + d.UserLogin.TrackAnalytics("Discord account verification required", props) + + d.UserLogin.BridgeState.Send(status.BridgeState{ + StateEvent: status.StateBadCredentials, + UserAction: status.UserActionOpenNative, + Error: DCHTTP40002, + }) + + return bridgev2.WrapErrorInStatus(err). + // Tell clients to not retry. + WithStatus(event.MessageStatusFail). + WithIsCertain(true). + WithMessage(accountVerificationRequiredMessage). + WithSendNotice(true) +} + +func (d *DiscordClient) tryWrappingError(ctx context.Context, err error) error { + if err == nil { + return nil + } + + var restErr *discordgo.RESTError + + if errors.As(err, &restErr) && restErr.Message != nil { + if restErr.Message.Code == discordgo.ErrCodeActionRequiredVerifiedAccount { + return d.wrapReceived40002(ctx, err) + } + } + + return err +} + +var snowflakeish = regexp.MustCompile(`\d{17,}`) + +func redactDiscordRESTPath(path string) string { + return snowflakeish.ReplaceAllLiteralString(path, "...") +} + +func dmChannelRecipientID(ch *discordgo.Channel) *string { + if ch == nil { + return nil + } + if ch.Type != discordgo.ChannelTypeDM { + return nil + } + if len(ch.Recipients) != 1 { + return nil + } + + return &ch.Recipients[0].ID +} + +func (d *DiscordClient) baseAnalyticsProps(ctx context.Context) map[string]any { + props := make(map[string]any) + if ctx == nil { + return props + } + + ch, ok := ctx.Value(contextKeyChannel).(*discordgo.Channel) + if ok && ch != nil { + risky := false + props["channelType"] = readableChannelType(ch.Type) + + if recipientID := dmChannelRecipientID(ch); recipientID != nil { + relationshipDesc := "none" + if rel := d.relationshipWithUserID(*recipientID); rel != nil { + relationshipDesc = readableRelationshipType(rel.Type) + } else if ch.Type == discordgo.ChannelTypeDM { + // No relationship with the recipient and it's a 1:1 DM. + risky = true + } + + props["relationshipWithRecipient"] = relationshipDesc + props["risky"] = risky + } + } + + d.lastSendAttemptMutex.Lock() + if attempt := d.lastSendAttempt; attempt != nil { + props["lastInMemorySendAttemptAgeMs"] = time.Since(attempt.At).Milliseconds() + props["lastInMemorySendAttemptChannelType"] = readableChannelType(attempt.ChannelType) + if relType := attempt.RecipientRelationshipType; relType != nil { + props["lastInMemorySendAttemptRecipientRelationshipType"] = readableRelationshipType(*relType) + } + } + d.lastSendAttemptMutex.Unlock() + + return props +} + +func (d *DiscordClient) tapDiscordRESTResponse(req *http.Request, resp *http.Response, body []byte) { + // NOTE: discordgo calls this in a blocking fashion after reading the HTTP + // response from Discord, so don't block here. + ctx := context.Background() + + if d.Session != nil && !d.Session.IsUser { + return + } + + captcha := discordauth.TryUnmarshalingCaptcha(ctx, resp, body) + if captcha == nil { + return + } + + redactedEndpoint := redactDiscordRESTPath(req.URL.Path) + props := d.baseAnalyticsProps(req.Context()) + maps.Copy(props, map[string]any{ + "apiEndpoint": redactedEndpoint, + "httpMethod": req.Method, + "captchaService": string(captcha.Service), + "captchaInvisible": captcha.Invisible, + "captchaUserFlow": captcha.UserFlow, + }) + + // (This fires a goroutine under the hood so it's alright to call this from + // here.) + d.UserLogin.TrackAnalytics("Discord CAPTCHA challenge", props) +} + +func (d *DiscordClient) relationshipWithUserID(userID string) *discordgo.Relationship { + if d.Session == nil || d.Session.State == nil { + return nil + } + + d.relationshipLock.RLock() + defer d.relationshipLock.RUnlock() + + return d.relationships[userID] +} + +func (d *DiscordClient) relationshipWithDMRecipient(ch *discordgo.Channel) *discordgo.Relationship { + if ch == nil { + return nil + } + + recip := dmChannelRecipientID(ch) + if recip == nil { + return nil + } + + rel := d.relationshipWithUserID(*recip) + return rel +} + +// dmChannelForUserID finds the DM channel with the given user, if any. +func (d *DiscordClient) dmChannelForUserID(userID string) *discordgo.Channel { + if d.Session == nil || d.Session.State == nil { + return nil + } + + d.Session.State.RLock() + defer d.Session.State.RUnlock() + + for _, ch := range d.Session.State.PrivateChannels { + if len(ch.Recipients) == 1 && ch.Recipients[0].ID == userID { + return ch + } + } + + return nil +} + +func (d *DiscordClient) rebuildRelationships() { + if d.Session == nil || d.Session.State == nil { + return + } + + d.relationshipLock.Lock() + defer d.relationshipLock.Unlock() + + clear(d.relationships) + + for _, rel := range d.Session.State.Relationships { + if rel == nil { + continue + } + d.relationships[rel.ID] = rel + } +} + +func (d *DiscordClient) upsertRelationship(rel *discordgo.Relationship) { + if rel == nil { + return + } + + d.relationshipLock.Lock() + defer d.relationshipLock.Unlock() + + d.relationships[rel.ID] = rel +} + +func (d *DiscordClient) removeRelationship(userID string) { + d.relationshipLock.Lock() + defer d.relationshipLock.Unlock() + + delete(d.relationships, userID) +} diff --git a/pkg/connector/config.go b/pkg/connector/config.go new file mode 100644 index 0000000..7c672b5 --- /dev/null +++ b/pkg/connector/config.go @@ -0,0 +1,121 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + _ "embed" + "strings" + "text/template" + + "github.com/bwmarrin/discordgo" + up "go.mau.fi/util/configupgrade" + "gopkg.in/yaml.v3" +) + +//go:embed example-config.yaml +var ExampleConfig string + +const defaultChannelNameTemplate = `{{if and .IsGuildChannel (not .IsCategory)}}#{{end}}{{.Name}}` + +type Config struct { + Guilds struct { + BridgingGuildIDs []string `yaml:"bridging_guild_ids"` + } `yaml:"guilds"` + + // ChannelNameTemplate formats Matrix room names for Discord channels other + // than 1:1 DMs, which intentionally use bridgev2's ghost-derived default. + ChannelNameTemplate string `yaml:"channel_name_template"` + CustomEmojiReactions *bool `yaml:"custom_emoji_reactions"` + GuildAvatarsInRooms *bool `yaml:"guild_avatars_in_rooms"` + + ForbidDMingStrangers *bool `yaml:"forbid_dming_strangers"` + + LogWhenDroppingMessages bool `yaml:"log_when_dropping_messages"` + + channelNameTemplate *template.Template `yaml:"-"` +} + +type umConfig Config + +func (c *Config) UnmarshalYAML(node *yaml.Node) error { + err := node.Decode((*umConfig)(c)) + if err != nil { + return err + } + + if c.ChannelNameTemplate == "" { + c.ChannelNameTemplate = defaultChannelNameTemplate + } + + c.channelNameTemplate, err = template.New("channel_name").Parse(c.ChannelNameTemplate) + if err != nil { + return err + } + + return nil +} + +// ChannelNameParams describes the values available to [Config.FormatChannelName]. +// +// It intentionally includes both the raw Discord channel type and convenience +// booleans so templates can express v1-style naming rules without relying on +// numeric channel type constants. +type ChannelNameParams struct { + Name string + ParentName string + GuildName string + Type discordgo.ChannelType + NSFW bool + IsDM bool + IsGroupDM bool + IsCategory bool + IsGuildChannel bool +} + +// FormatChannelName renders [Config.ChannelNameTemplate] for non-guild-space +// channel portals. One-to-one DMs intentionally bypass this helper so bridgev2 +// can derive the room name from the other user's ghost. +func (c *Config) FormatChannelName(params *ChannelNameParams) string { + var buffer strings.Builder + _ = c.channelNameTemplate.Execute(&buffer, params) + return buffer.String() +} + +func (c Config) ForbidDMingStrangersEnabled() bool { + return c.ForbidDMingStrangers == nil || *c.ForbidDMingStrangers +} + +func (c Config) CustomEmojiReactionsEnabled() bool { + return c.CustomEmojiReactions == nil || *c.CustomEmojiReactions +} + +func (c Config) GuildAvatarsInRoomsEnabled() bool { + return c.GuildAvatarsInRooms != nil && *c.GuildAvatarsInRooms +} + +func upgradeConfig(helper up.Helper) { + helper.Copy(up.List, "guilds", "bridging_guild_ids") + helper.Copy(up.Bool, "guilds", "guild_avatars_in_rooms") + helper.Copy(up.Bool, "forbid_dming_strangers") + helper.Copy(up.Str, "channel_name_template") + helper.Copy(up.Bool, "custom_emoji_reactions") + helper.Copy(up.Bool, "log_when_dropping_messages") +} + +func (d *DiscordConnector) GetConfig() (example string, data any, upgrader up.Upgrader) { + return ExampleConfig, &d.Config, up.SimpleUpgrader(upgradeConfig) +} diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go new file mode 100644 index 0000000..c3cee8d --- /dev/null +++ b/pkg/connector/connector.go @@ -0,0 +1,95 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + "net/http" + + "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" + + "go.mau.fi/mautrix-discord/pkg/connector/discorddb" + "go.mau.fi/mautrix-discord/pkg/discordid" + "go.mau.fi/mautrix-discord/pkg/msgconv" +) + +type DiscordConnector struct { + Bridge *bridgev2.Bridge + Config Config + DB *discorddb.DiscordDB + MsgConv *msgconv.MessageConverter + attachmentCache *attachmentCache + httpClient *http.Client +} + +var ( + _ bridgev2.NetworkConnector = (*DiscordConnector)(nil) + _ bridgev2.MaxFileSizeingNetwork = (*DiscordConnector)(nil) + _ bridgev2.TransactionIDGeneratingNetwork = (*DiscordConnector)(nil) +) + +func (d *DiscordConnector) Init(bridge *bridgev2.Bridge) { + d.Bridge = bridge + d.DB = discorddb.New(bridge.DB.Database, bridge.Log.With().Str("db_section", "discord").Logger()) + d.MsgConv = msgconv.NewMessageConverter(bridge) + d.attachmentCache = NewAttachmentCache() + d.MsgConv.CacheDirectMediaAttachment = d.attachmentCache.Insert + d.httpClient = d.Bridge.GetHTTPClientSettings().Compile() +} + +func (d *DiscordConnector) SetMaxFileSize(maxSize int64) { + d.MsgConv.MaxFileSize = maxSize +} + +func (d *DiscordConnector) Start(ctx context.Context) error { + log := zerolog.Ctx(ctx) + + err := d.DB.Upgrade(ctx) + if err != nil { + log.Err(err).Msg("Failed to upgrade Discord database") + return err + } + + log.Debug().Msg("Setting up provisioning API") + + err = d.setUpProvisioningAPIs() + if err != nil { + log.Err(err).Msg("Failed to set up provisioning API, proceeding") + // Don't treat this error as fatal. + } + + return nil +} + +func (d *DiscordConnector) GetName() bridgev2.BridgeName { + return bridgev2.BridgeName{ + DisplayName: "Discord", + NetworkURL: "https://discord.com", + NetworkIcon: "mxc://maunium.net/nIdEykemnwdisvHbpxflpDlC", + NetworkID: "discord", + BeeperBridgeType: "discordgo", + DefaultPort: 29334, + } +} + +func (d *DiscordConnector) GenerateTransactionID(_ id.UserID, _ id.RoomID, _ event.Type) networkid.RawTransactionID { + return networkid.RawTransactionID(discordid.GenerateNonce()) +} diff --git a/config/config.go b/pkg/connector/dbmeta.go similarity index 63% rename from config/config.go rename to pkg/connector/dbmeta.go index d704651..6d8fbd5 100644 --- a/config/config.go +++ b/pkg/connector/dbmeta.go @@ -1,5 +1,5 @@ // mautrix-discord - A Matrix-Discord puppeting bridge. -// Copyright (C) 2022 Tulir Asokan +// Copyright (C) 2026 Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -14,22 +14,21 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . -package config +package connector import ( - "maunium.net/go/mautrix/bridge/bridgeconfig" - "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/bridgev2/database" + + "go.mau.fi/mautrix-discord/pkg/discordid" ) -type Config struct { - *bridgeconfig.BaseConfig `yaml:",inline"` - - Bridge BridgeConfig `yaml:"bridge"` -} - -func (config *Config) CanAutoDoublePuppet(userID id.UserID) bool { - _, homeserver, _ := userID.Parse() - _, hasSecret := config.Bridge.DoublePuppetConfig.SharedSecretMap[homeserver] - - return hasSecret +func (d *DiscordConnector) GetDBMetaTypes() database.MetaTypes { + return database.MetaTypes{ + Portal: func() any { + return &discordid.PortalMetadata{} + }, + UserLogin: func() any { + return &discordid.UserLoginMetadata{} + }, + } } diff --git a/pkg/connector/directmedia.go b/pkg/connector/directmedia.go new file mode 100644 index 0000000..f7c68bb --- /dev/null +++ b/pkg/connector/directmedia.go @@ -0,0 +1,205 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + "encoding/binary" + "encoding/hex" + "fmt" + "net/url" + "time" + + "github.com/bwmarrin/discordgo" + "github.com/rs/zerolog" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/mediaproxy" + + "go.mau.fi/mautrix-discord/pkg/discordid" +) + +var ( + _ bridgev2.DirectMediableNetwork = (*DiscordConnector)(nil) +) + +func (d *DiscordConnector) Download( + ctx context.Context, + mediaID networkid.MediaID, + params map[string]string, +) (mediaproxy.GetMediaResponse, error) { + info, err := discordid.ParseMediaID(mediaID) + if err != nil { + return nil, fmt.Errorf("failed to parse media id for download: %w", err) + } + + return d.downloadAttachment(ctx, info) +} + +func (d *DiscordConnector) SetUseDirectMedia() { + d.MsgConv.DirectMedia = true +} + +func (d *DiscordConnector) downloadAttachment( + ctx context.Context, + info *discordid.MediaInfo, +) (*mediaproxy.GetMediaResponseURL, error) { + url, expiresAt, err := d.resolveAttachmentURL(ctx, info) + if err != nil { + return nil, fmt.Errorf("failed to refresh attachment url for download: %w", err) + } + if expiresAt.IsZero() { + // A zero expiry becomes effectively immutable caching in mediaproxy. + // Unknown expiry is safer as no-store for now. + expiresAt = time.Now() + } + return &mediaproxy.GetMediaResponseURL{ + URL: url, + ExpiresAt: expiresAt, + }, nil +} + +func (d *DiscordConnector) resolveAttachmentURL(ctx context.Context, info *discordid.MediaInfo) (url string, expires time.Time, err error) { + if entry, ok := d.attachmentCache.Get(info.MediaInfoV1); ok { + return entry.URL, entry.Expiry, nil + } + + url, expiresAt, err := d.refreshAttachmentURL(ctx, info) + if err != nil { + return "", time.Time{}, err + } + + d.attachmentCache.Insert(info, url) + return url, expiresAt, nil +} + +func (d *DiscordConnector) refreshAttachmentURL( + ctx context.Context, + info *discordid.MediaInfo, +) (url string, expires time.Time, err error) { + log := zerolog.Ctx(ctx).With().Str("action", "refresh attachment url").Logger() + ctx = log.WithContext(ctx) + + login, err := d.Bridge.GetExistingUserLoginByID(ctx, info.UserLoginID) + if err != nil { + return "", time.Time{}, err + } else if login == nil { + return "", time.Time{}, mautrix.MNotFound.WithMessage("Direct media login not found") + } + + client, ok := login.Client.(*DiscordClient) + if !ok || client == nil || !client.IsLoggedIn() { + return "", time.Time{}, mautrix.MNotFound.WithMessage("Direct media login is not connected") + } + + channelID := info.ChannelID + messageID := info.MessageID + attachmentID := info.AttachmentID + + parentChannelID := channelID + threadChannelID := "" + threadInfo, err := d.DB.Thread.GetByThreadChannelID(ctx, string(info.UserLoginID), channelID) + if err != nil { + return "", time.Time{}, fmt.Errorf("failed to query thread info: %w", err) + } else if threadInfo != nil { + parentChannelID = threadInfo.ParentChannelID + threadChannelID = threadInfo.ThreadChannelID + } + + var requestOptions []discordgo.RequestOption + portalKey := discordid.MakeChannelPortalKey(parentChannelID, info.UserLoginID, d.Bridge.Config.SplitPortals) + portal, err := d.Bridge.GetExistingPortalByKey(ctx, portalKey) + if err != nil { + return "", time.Time{}, fmt.Errorf("failed to query portal for direct media: %w", err) + } else if portal != nil { + if meta, ok := portal.Metadata.(*discordid.PortalMetadata); ok { + requestOptions = append(requestOptions, makeDiscordReferer(meta.GuildID, parentChannelID, threadChannelID)) + } + } else if threadChannelID == "" { + // DMs still benefit from @me referers. + requestOptions = append(requestOptions, makeDiscordReferer("", parentChannelID, "")) + } + + var messages []*discordgo.Message + if client.Session.IsUser { + messages, err = client.Session.ChannelMessages(channelID, 5, "", "", messageID, requestOptions...) + } else { + var msg *discordgo.Message + msg, err = client.Session.ChannelMessage(channelID, messageID, requestOptions...) + if err == nil && msg != nil { + messages = []*discordgo.Message{msg} + } + } + if err != nil { + return "", time.Time{}, fmt.Errorf("failed to fetch direct media message: %w", err) + } + + for _, msg := range messages { + for _, att := range msg.Attachments { + if att.ID == attachmentID { + expiresAt := normalizeAttachmentExpiry(parseAttachmentExpiryFromURL(att.URL)) + // (Trace is not the default log level, so this is only visible + // in development scenarios.) + log.Trace(). + Str("channel_id", channelID). + Str("message_id", messageID). + Str("attachment_id", attachmentID). + Time("expires_at", expiresAt). + Msg("Resolved direct media attachment URL") + // TODO(skip): This is ignoring the rest of the attachments. + return att.URL, expiresAt, nil + } + } + } + + return "", time.Time{}, mautrix.MNotFound.WithMessage("Attachment not found in message") +} + +func parseAttachmentExpiryParam(ex string) time.Time { + tsBytes, err := hex.DecodeString(ex) + if err != nil || len(tsBytes) != 4 { + return time.Time{} + } + + parsedTS := int64(binary.BigEndian.Uint32(tsBytes)) + now := time.Now() + expiry := time.Unix(parsedTS, 0) + if expiry.Before(now) || expiry.After(now.Add(365*24*time.Hour)) { + // Looks to be invalid. + return time.Time{} + } + return expiry +} + +func parseAttachmentExpiryFromURL(rawURL string) time.Time { + parsedURL, err := url.Parse(rawURL) + if err != nil { + return time.Time{} + } + + return parseAttachmentExpiryParam(parsedURL.Query().Get("ex")) +} + +func normalizeAttachmentExpiry(expiry time.Time) time.Time { + // Default to a validity period of 24 hours. + if expiry.IsZero() { + return time.Now().Add(24 * time.Hour) + } + + return expiry +} diff --git a/pkg/connector/directmedia_cache.go b/pkg/connector/directmedia_cache.go new file mode 100644 index 0000000..b4a588b --- /dev/null +++ b/pkg/connector/directmedia_cache.go @@ -0,0 +1,91 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "sync" + "time" + + "go.mau.fi/mautrix-discord/pkg/discordid" +) + +const attachmentCacheLife = 5 * time.Minute + +type attachmentCacheEntry struct { + Expiry time.Time + URL string +} + +func (ce *attachmentCacheEntry) IsExpired() bool { + return time.Until(ce.Expiry) <= attachmentCacheLife +} + +// attachmentCache tracks expiring attachment URLs from Discord. An +// attachmentCache is safe for concurrent use by multiple goroutines. +type attachmentCache struct { + sync.RWMutex + cache map[discordid.MediaInfoV1]attachmentCacheEntry +} + +// TODO(skip): The cache grows in an unbounded fashion. + +func NewAttachmentCache() *attachmentCache { + return &attachmentCache{ + cache: make(map[discordid.MediaInfoV1]attachmentCacheEntry), + } +} + +func (ac *attachmentCache) Get(key discordid.MediaInfoV1) (*attachmentCacheEntry, bool) { + ac.Lock() + defer ac.Unlock() + + cached, ok := ac.cache[key] + if !ok { + return nil, false + } + + if cached.IsExpired() { + delete(ac.cache, key) + return nil, false + } + + return &cached, true +} + +func (ac *attachmentCache) Insert(info *discordid.MediaInfo, url string) { + if url == "" { + return + } + + expiry := normalizeAttachmentExpiry(parseAttachmentExpiryFromURL(url)) + + ac.Lock() + defer ac.Unlock() + + key := info.MediaInfoV1 + entry := attachmentCacheEntry{ + URL: url, + Expiry: expiry, + } + + if expiry.IsZero() || entry.IsExpired() { + delete(ac.cache, key) + return + } + + ac.cache[key] = entry +} diff --git a/pkg/connector/directmedia_test.go b/pkg/connector/directmedia_test.go new file mode 100644 index 0000000..e4d09b8 --- /dev/null +++ b/pkg/connector/directmedia_test.go @@ -0,0 +1,21 @@ +package connector + +import ( + "testing" + "time" +) + +func TestParseAttachmentExpiryParam(t *testing.T) { + losAngeles, err := time.LoadLocation("America/Los_Angeles") + if err != nil { + t.Fatalf("failed to load test timezone: %v", err) + } + + expiry := parseAttachmentExpiryParam("69be6214").In(losAngeles) + got := expiry.String() + want := "2026-03-21 02:17:08 -0700 PDT" + + if got != want { + t.Fatalf("unexpected parsed expiry: got %q, want %q", got, want) + } +} diff --git a/pkg/connector/discorddb/00-latest-schema.sql b/pkg/connector/discorddb/00-latest-schema.sql new file mode 100644 index 0000000..9ced24b --- /dev/null +++ b/pkg/connector/discorddb/00-latest-schema.sql @@ -0,0 +1,52 @@ +-- v0 -> v3 (compatible with v1+): latest schema + +-- https://docs.discord.com/developers/resources/emoji#emoji-object +CREATE TABLE custom_emoji ( + discord_id TEXT NOT NULL, + name TEXT NOT NULL, + animated BOOLEAN NOT NULL, + + mxc TEXT, + + PRIMARY KEY (discord_id) +); +CREATE INDEX custom_emoji_mxc_idx ON custom_emoji (mxc); + +CREATE TABLE role ( + discord_guild_id TEXT NOT NULL, + discord_id TEXT NOT NULL, + + name TEXT NOT NULL, + icon TEXT, + + mentionable BOOLEAN NOT NULL, + managed BOOLEAN NOT NULL, + hoist BOOLEAN NOT NULL, + + color INTEGER NOT NULL, + position INTEGER NOT NULL, + permissions BIGINT NOT NULL, + + PRIMARY KEY (discord_guild_id, discord_id) +); + +CREATE TABLE discord_thread ( + -- The ID of the UserLogin that witnessed the thread. + user_login_id TEXT NOT NULL, + + -- The ID of the thread itself. For public threads, this exactly matches the + -- ID of the message that the thread originates from. + thread_channel_id TEXT NOT NULL, + + -- The ID of the thread's "root" message. For public threads, this will + -- match `id` and therefore the message that the thread originates from. + -- For private threads, this will be NULL. + root_message_id TEXT, + + -- The Discord channel ID that the thread belongs to. + parent_channel_id TEXT NOT NULL, + + PRIMARY KEY (user_login_id, thread_channel_id) +); +CREATE UNIQUE INDEX discord_thread_user_login_root_msg_uidx +ON discord_thread (user_login_id, root_message_id); diff --git a/pkg/connector/discorddb/02-roles.sql b/pkg/connector/discorddb/02-roles.sql new file mode 100644 index 0000000..c845ff5 --- /dev/null +++ b/pkg/connector/discorddb/02-roles.sql @@ -0,0 +1,19 @@ +-- v1 -> v2 (compatible with v1+): roles + +CREATE TABLE role ( + discord_guild_id TEXT NOT NULL, + discord_id TEXT NOT NULL, + + name TEXT NOT NULL, + icon TEXT, + + mentionable BOOLEAN NOT NULL, + managed BOOLEAN NOT NULL, + hoist BOOLEAN NOT NULL, + + color INTEGER NOT NULL, + position INTEGER NOT NULL, + permissions BIGINT NOT NULL, + + PRIMARY KEY (discord_guild_id, discord_id) +); diff --git a/pkg/connector/discorddb/03-threads.sql b/pkg/connector/discorddb/03-threads.sql new file mode 100644 index 0000000..2a5fe67 --- /dev/null +++ b/pkg/connector/discorddb/03-threads.sql @@ -0,0 +1,22 @@ +-- v2 -> v3 (compatible with v1+): threads + +CREATE TABLE discord_thread ( + -- The ID of the UserLogin that witnessed the thread. + user_login_id TEXT NOT NULL, + + -- The ID of the thread itself. For public threads, this exactly matches the + -- ID of the message that the thread originates from. + thread_channel_id TEXT NOT NULL, + + -- The ID of the thread's "root" message. For public threads, this will + -- match `id` and therefore the message that the thread originates from. + -- For private threads, this will be NULL. + root_message_id TEXT, + + -- The Discord channel ID that the thread belongs to. + parent_channel_id TEXT NOT NULL, + + PRIMARY KEY (user_login_id, thread_channel_id) +); +CREATE UNIQUE INDEX discord_thread_user_login_root_msg_uidx +ON discord_thread (user_login_id, root_message_id); diff --git a/pkg/connector/discorddb/database.go b/pkg/connector/discorddb/database.go new file mode 100644 index 0000000..86f2aa9 --- /dev/null +++ b/pkg/connector/discorddb/database.go @@ -0,0 +1,60 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package discorddb + +import ( + "embed" + + "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" +) + +type DiscordDB struct { + *dbutil.Database + CustomEmoji *CustomEmojiQuery + Role *RoleQuery + Thread *ThreadQuery +} + +var table dbutil.UpgradeTable + +//go:embed *.sql +var upgrades embed.FS + +func init() { + table.RegisterFS(upgrades) +} + +func UpgradeTable() dbutil.UpgradeTable { + return table +} + +func New(db *dbutil.Database, log zerolog.Logger) *DiscordDB { + db = db.Child("discord_version", table, dbutil.ZeroLogger(log)) + return &DiscordDB{ + Database: db, + CustomEmoji: &CustomEmojiQuery{ + QueryHelper: dbutil.MakeQueryHelper(db, newCustomEmoji), + }, + Role: &RoleQuery{ + QueryHelper: dbutil.MakeQueryHelper(db, newRole), + }, + Thread: &ThreadQuery{ + QueryHelper: dbutil.MakeQueryHelper(db, newThread), + }, + } +} diff --git a/pkg/connector/discorddb/emoji.go b/pkg/connector/discorddb/emoji.go new file mode 100644 index 0000000..21f7280 --- /dev/null +++ b/pkg/connector/discorddb/emoji.go @@ -0,0 +1,81 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package discorddb + +import ( + "context" + "database/sql" + + "go.mau.fi/util/dbutil" + "maunium.net/go/mautrix/id" +) + +type CustomEmojiQuery struct { + *dbutil.QueryHelper[*CustomEmoji] +} + +type CustomEmoji struct { + ID string + Name string + Animated bool + ImageMXC id.ContentURIString +} + +func (ce *CustomEmoji) sqlVariables() []any { + return []any{ce.ID, ce.Name, ce.Animated, dbutil.StrPtr(ce.ImageMXC)} +} + +func newCustomEmoji(_ *dbutil.QueryHelper[*CustomEmoji]) *CustomEmoji { + return &CustomEmoji{} +} + +const ( + getCustomEmojiByMXCQuery = ` + SELECT discord_id, name, animated, mxc FROM custom_emoji WHERE mxc=$1 ORDER BY name + ` + getCustomEmojiByDiscordIDQuery = ` + SELECT discord_id, name, animated, mxc FROM custom_emoji WHERE discord_id=$1 ORDER BY name + ` + upsertCustomEmojiQuery = ` + INSERT INTO custom_emoji (discord_id, name, animated, mxc) + VALUES ($1, $2, $3, $4) + ON CONFLICT (discord_id) DO UPDATE + SET name = excluded.name, animated = excluded.animated, mxc = excluded.mxc + ` +) + +func (ceq *CustomEmojiQuery) GetByDiscordID(ctx context.Context, discordID string) (*CustomEmoji, error) { + return ceq.QueryOne(ctx, getCustomEmojiByDiscordIDQuery, &discordID) +} + +func (ceq *CustomEmojiQuery) GetByMXC(ctx context.Context, mxc string) (*CustomEmoji, error) { + return ceq.QueryOne(ctx, getCustomEmojiByMXCQuery, &mxc) +} + +func (ceq *CustomEmojiQuery) Put(ctx context.Context, emoji *CustomEmoji) error { + return ceq.Exec(ctx, upsertCustomEmojiQuery, emoji.sqlVariables()...) +} + +func (ce *CustomEmoji) Scan(row dbutil.Scannable) (*CustomEmoji, error) { + var imageURL sql.NullString + err := row.Scan(&ce.ID, &ce.Name, &ce.Animated, &imageURL) + if err != nil { + return nil, err + } + ce.ImageMXC = id.ContentURIString(imageURL.String) + return ce, nil +} diff --git a/pkg/connector/discorddb/role.go b/pkg/connector/discorddb/role.go new file mode 100644 index 0000000..1f9f7f7 --- /dev/null +++ b/pkg/connector/discorddb/role.go @@ -0,0 +1,151 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package discorddb + +import ( + "context" + "database/sql" + + "github.com/bwmarrin/discordgo" + "go.mau.fi/util/dbutil" +) + +type RoleQuery struct { + *dbutil.QueryHelper[*Role] +} + +type Role struct { + GuildID string + discordgo.Role +} + +func (r *Role) sqlVariables() []any { + return []any{ + r.GuildID, + r.ID, + r.Name, + dbutil.StrPtr(r.Icon), + r.Mentionable, + r.Managed, + r.Hoist, + r.Color, + r.Position, + r.Permissions, + } +} + +func newRole(_ *dbutil.QueryHelper[*Role]) *Role { + return &Role{} +} + +const ( + getRoleByIDQuery = ` + SELECT discord_guild_id, discord_id, name, icon, mentionable, managed, hoist, color, position, permissions + FROM role + WHERE discord_guild_id=$1 AND discord_id=$2 + ` + getRolesByGuildIDQuery = ` + SELECT discord_guild_id, discord_id, name, icon, mentionable, managed, hoist, color, position, permissions + FROM role + WHERE discord_guild_id=$1 + ORDER BY position DESC, discord_id + ` + upsertRoleQuery = ` + INSERT INTO role (discord_guild_id, discord_id, name, icon, mentionable, managed, hoist, color, position, permissions) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + ON CONFLICT (discord_guild_id, discord_id) DO UPDATE + SET name = excluded.name, + icon = excluded.icon, + mentionable = excluded.mentionable, + managed = excluded.managed, + hoist = excluded.hoist, + color = excluded.color, + position = excluded.position, + permissions = excluded.permissions + ` + deleteRolesByGuildIDQuery = ` + DELETE FROM role WHERE discord_guild_id=$1 + ` + deleteRoleByIDQuery = ` + DELETE FROM role WHERE discord_guild_id=$1 AND discord_id=$2 + ` +) + +func (rq *RoleQuery) GetByID(ctx context.Context, guildID, roleID string) (*Role, error) { + return rq.QueryOne(ctx, getRoleByIDQuery, &guildID, &roleID) +} + +func (rq *RoleQuery) GetByGuildID(ctx context.Context, guildID string) ([]*Role, error) { + return rq.QueryMany(ctx, getRolesByGuildIDQuery, &guildID) +} + +func (rq *RoleQuery) Put(ctx context.Context, role *Role) error { + return rq.Exec(ctx, upsertRoleQuery, role.sqlVariables()...) +} + +func (rq *RoleQuery) PutMany(ctx context.Context, roles []*Role) error { + for _, role := range roles { + if err := rq.Put(ctx, role); err != nil { + return err + } + } + return nil +} + +func (rq *RoleQuery) DeleteByGuildID(ctx context.Context, guildID string) error { + return rq.Exec(ctx, deleteRolesByGuildIDQuery, &guildID) +} + +func (rq *RoleQuery) DeleteByID(ctx context.Context, guildID, roleID string) error { + return rq.Exec(ctx, deleteRoleByIDQuery, &guildID, &roleID) +} + +func (rq *RoleQuery) ReplaceGuildRoles(ctx context.Context, guildID string, roles []*Role) error { + return rq.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error { + if err := rq.DeleteByGuildID(ctx, guildID); err != nil { + return err + } + for _, role := range roles { + role.GuildID = guildID + if err := rq.Put(ctx, role); err != nil { + return err + } + } + return nil + }) +} + +func (r *Role) Scan(row dbutil.Scannable) (*Role, error) { + var icon sql.NullString + err := row.Scan( + &r.GuildID, + &r.ID, + &r.Name, + &icon, + &r.Mentionable, + &r.Managed, + &r.Hoist, + &r.Color, + &r.Position, + &r.Permissions, + ) + if err != nil { + return nil, err + } + r.Icon = icon.String + return r, nil +} diff --git a/pkg/connector/discorddb/thread.go b/pkg/connector/discorddb/thread.go new file mode 100644 index 0000000..6a6a065 --- /dev/null +++ b/pkg/connector/discorddb/thread.go @@ -0,0 +1,106 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package discorddb + +import ( + "context" + "database/sql" + + "go.mau.fi/util/dbutil" +) + +type ThreadQuery struct { + *dbutil.QueryHelper[*Thread] +} + +type Thread struct { + UserLoginID string + ThreadChannelID string + RootMessageID string + ParentChannelID string +} + +func (t *Thread) sqlVariables() []any { + var rootMsgID *string + if t.RootMessageID != "" { + rootMsgID = &t.RootMessageID + } + return []any{ + t.UserLoginID, + t.ThreadChannelID, + rootMsgID, + t.ParentChannelID, + } +} + +func newThread(_ *dbutil.QueryHelper[*Thread]) *Thread { + return &Thread{} +} + +const ( + getThreadByChannelIDQuery = ` + SELECT user_login_id, thread_channel_id, root_message_id, parent_channel_id + FROM discord_thread + WHERE user_login_id=$1 AND thread_channel_id=$2 + ` + getThreadByRootMessageIDQuery = ` + SELECT user_login_id, thread_channel_id, root_message_id, parent_channel_id + FROM discord_thread + WHERE user_login_id=$1 AND root_message_id=$2 + ` + upsertThreadQuery = ` + INSERT INTO discord_thread (user_login_id, thread_channel_id, root_message_id, parent_channel_id) + VALUES ($1, $2, $3, $4) + ON CONFLICT (user_login_id, thread_channel_id) DO UPDATE + SET root_message_id = excluded.root_message_id, + parent_channel_id = excluded.parent_channel_id + ` + deleteThreadByChannelIDQuery = ` + DELETE FROM discord_thread WHERE user_login_id=$1 AND thread_channel_id=$2 + ` +) + +func (tq *ThreadQuery) GetByThreadChannelID(ctx context.Context, userLoginID, threadChannelID string) (*Thread, error) { + return tq.QueryOne(ctx, getThreadByChannelIDQuery, &userLoginID, &threadChannelID) +} + +func (tq *ThreadQuery) GetByRootMessageID(ctx context.Context, userLoginID, rootMessageID string) (*Thread, error) { + return tq.QueryOne(ctx, getThreadByRootMessageIDQuery, &userLoginID, &rootMessageID) +} + +func (tq *ThreadQuery) Put(ctx context.Context, thread *Thread) error { + return tq.Exec(ctx, upsertThreadQuery, thread.sqlVariables()...) +} + +func (tq *ThreadQuery) DeleteByThreadChannelID(ctx context.Context, userLoginID, threadChannelID string) error { + return tq.Exec(ctx, deleteThreadByChannelIDQuery, &userLoginID, &threadChannelID) +} + +func (t *Thread) Scan(row dbutil.Scannable) (*Thread, error) { + var rootMsgID sql.NullString + err := row.Scan( + &t.UserLoginID, + &t.ThreadChannelID, + &rootMsgID, + &t.ParentChannelID, + ) + if err != nil { + return nil, err + } + t.RootMessageID = rootMsgID.String + return t, nil +} diff --git a/pkg/connector/emoji.go b/pkg/connector/emoji.go new file mode 100644 index 0000000..9e9388a --- /dev/null +++ b/pkg/connector/emoji.go @@ -0,0 +1,106 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + "fmt" + + "github.com/bwmarrin/discordgo" + "github.com/rs/zerolog" + "maunium.net/go/mautrix/id" + + "go.mau.fi/mautrix-discord/pkg/connector/discorddb" +) + +func (d *DiscordConnector) getCustomEmojiDownloadURL(emojiID string, animated bool) (string, string) { + // TODO probably best to leverage http.DetectContentType instead of + // assuming the media type + if animated { + return discordgo.EndpointEmojiAnimated(emojiID), "image/webp" + } + // TODO think about using webp for size savings + return discordgo.EndpointEmoji(emojiID), "image/png" +} + +func (d *DiscordConnector) GetCustomEmojiByMXC(ctx context.Context, mxc string) (*discorddb.CustomEmoji, error) { + return d.DB.CustomEmoji.GetByMXC(ctx, mxc) +} + +func (d *DiscordConnector) GetCustomEmojiMXC(ctx context.Context, emojiID, name string, animated bool) (id.ContentURIString, error) { + log := zerolog.Ctx(ctx).With(). + Str("action", "get discord custom emoji"). + Str("emoji_id", emojiID). + Str("emoji_name", name). + Logger() + ctx = log.WithContext(ctx) + + dbEmoji, err := d.DB.CustomEmoji.GetByDiscordID(ctx, emojiID) + if err != nil { + return "", fmt.Errorf("failed to get custom emoji from database: %w", err) + } + + if dbEmoji != nil && dbEmoji.ImageMXC != "" { + if dbEmoji.Name != name || dbEmoji.Animated != animated { + // Make sure to save changed information. + dbEmoji.Name = name + dbEmoji.Animated = animated + + err = d.DB.CustomEmoji.Put(ctx, dbEmoji) + if err != nil { + log.Warn().Err(err).Msg("Failed to update custom emoji metadata in database") + } + } + + return dbEmoji.ImageMXC, nil + } + + // Custom emoji wasn't in the database or it lacked an MXC, so we have to + // download it. + + emojiURL, mimeType := d.getCustomEmojiDownloadURL(emojiID, animated) + data, err := httpGet(ctx, d.httpClient, emojiURL, "emoji") + if err != nil { + return "", err + } + + mxc, _, err := d.Bridge.Bot.UploadMedia(ctx, "", data, "", mimeType) + + log = log.With().Str("image_mxc", string(mxc)).Logger() + ctx = log.WithContext(ctx) + + if err != nil { + return "", fmt.Errorf("failed to upload emoji to Matrix: %w", err) + } + + if dbEmoji == nil { + dbEmoji = &discorddb.CustomEmoji{ + ID: emojiID, + } + } + + dbEmoji.Name = name + dbEmoji.Animated = animated + dbEmoji.ImageMXC = mxc + + err = d.DB.CustomEmoji.Put(ctx, dbEmoji) + if err != nil { + log.Warn().Err(err).Msg("Failed to save custom emoji") + } + + return mxc, nil +} diff --git a/pkg/connector/events_chat_resync.go b/pkg/connector/events_chat_resync.go new file mode 100644 index 0000000..af30216 --- /dev/null +++ b/pkg/connector/events_chat_resync.go @@ -0,0 +1,113 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + + "github.com/bwmarrin/discordgo" + "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + + "go.mau.fi/mautrix-discord/pkg/discordid" +) + +type DiscordChatResync struct { + Client *DiscordClient + channel *discordgo.Channel +} + +var ( + _ bridgev2.RemoteChatResyncWithInfo = (*DiscordChatResync)(nil) + _ bridgev2.RemoteChatResyncBackfill = (*DiscordChatResync)(nil) + _ bridgev2.RemoteEventThatMayCreatePortal = (*DiscordChatResync)(nil) +) + +func (d *DiscordChatResync) AddLogContext(c zerolog.Context) zerolog.Context { + c = c.Str("channel_id", d.channel.ID).Int("channel_type", int(d.channel.Type)) + return c +} + +func (d *DiscordChatResync) GetPortalKey() networkid.PortalKey { + ch := d.channel + return d.Client.portalKeyForChannel(ch) +} + +func (d *DiscordChatResync) GetSender() bridgev2.EventSender { + return bridgev2.EventSender{} +} + +func (d *DiscordChatResync) GetType() bridgev2.RemoteEventType { + return bridgev2.RemoteEventChatResync +} + +func (d *DiscordChatResync) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) { + return d.Client.GetChatInfo(ctx, portal) + +} + +func (d *DiscordChatResync) ShouldCreatePortal() bool { + return true +} + +// compareMessageIDs compares two Discord message IDs. +// +// If the first ID is lower, -1 is returned. +// If the second ID is lower, 1 is returned. +// If the IDs are equal, 0 is returned. +func compareMessageIDs(id1, id2 string) int { + if id1 == id2 { + return 0 + } + if len(id1) < len(id2) { + return -1 + } else if len(id2) < len(id1) { + return 1 + } + if id1 < id2 { + return -1 + } + return 1 +} + +func shouldBackfill(latestBridgedIDStr, latestIDFromServerStr string) bool { + return compareMessageIDs(latestBridgedIDStr, latestIDFromServerStr) == -1 +} + +func (d *DiscordChatResync) CheckNeedsBackfill(ctx context.Context, latestBridged *database.Message) (bool, error) { + log := zerolog.Ctx(ctx).With(). + Str("resyncing_channel_id", d.channel.ID). + Str("resyncing_channel_last_message_id", d.channel.LastMessageID). + Str("resyncing_guild_id", d.channel.GuildID). + Bool("has_latest_bridged", latestBridged != nil). + Logger() + + if latestBridged == nil { + needsBackfill := d.channel.LastMessageID != "" + log.Debug().Bool("needs_backfill", needsBackfill).Msg("Computed needs backfill") + return needsBackfill, nil + } + + needsBackfill := shouldBackfill( + discordid.ParseMessageID(latestBridged.ID), + d.channel.LastMessageID, + ) + log.Debug().Bool("needs_backfill", needsBackfill).Msg("Computed needs backfill") + return needsBackfill, nil +} diff --git a/pkg/connector/events_guild_resync.go b/pkg/connector/events_guild_resync.go new file mode 100644 index 0000000..cd8ccc9 --- /dev/null +++ b/pkg/connector/events_guild_resync.go @@ -0,0 +1,61 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + + "github.com/bwmarrin/discordgo" + "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" +) + +type DiscordGuildResync struct { + Client *DiscordClient + guild *discordgo.Guild + portalKey networkid.PortalKey +} + +var ( + _ bridgev2.RemoteChatResyncWithInfo = (*DiscordGuildResync)(nil) + _ bridgev2.RemoteEventThatMayCreatePortal = (*DiscordGuildResync)(nil) +) + +func (d *DiscordGuildResync) AddLogContext(c zerolog.Context) zerolog.Context { + return c.Str("guild_id", d.guild.ID).Str("guild_name", d.guild.Name) +} + +func (d *DiscordGuildResync) GetPortalKey() networkid.PortalKey { + return d.portalKey +} + +func (d *DiscordGuildResync) GetSender() bridgev2.EventSender { + return bridgev2.EventSender{} +} + +func (d *DiscordGuildResync) GetType() bridgev2.RemoteEventType { + return bridgev2.RemoteEventChatResync +} + +func (d *DiscordGuildResync) ShouldCreatePortal() bool { + return true +} + +func (d *DiscordGuildResync) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) { + return d.Client.GetChatInfo(ctx, portal) +} diff --git a/pkg/connector/example-config.yaml b/pkg/connector/example-config.yaml new file mode 100644 index 0000000..4ed7c35 --- /dev/null +++ b/pkg/connector/example-config.yaml @@ -0,0 +1,38 @@ +# Configuration options related to Discord guilds (also known as "servers"). +guilds: + # UNSTABLE: The IDs of the guilds to bridge. This is a stopgap measure + # during bridge development. If no guild IDs are specified, then no guilds + # are bridged at all. + bridging_guild_ids: [] + + # Should guild channel portals take on the guild icon as their avatars? + guild_avatars_in_rooms: false + +# Should the bridge refuse to send direct messages to recipients the user isn't +# friends with on Discord? Discord generally considers this to be a "risky" +# action. +forbid_dming_strangers: true + +# Template for Matrix room names created for Discord channels, except for 1:1 +# DMs. 1:1 DMs intentionally do not use this template as their room metadata is +# derived from the other user's ghost (when private_chat_portal_meta is enabled). +# +# Available variables: +# .Name - The Discord channel name. +# .ParentName - The parent channel/category name, if any. +# .GuildName - The guild name for guild channels. +# .Type - The raw Discord channel type. +# .NSFW - Whether the channel is marked NSFW. +# .IsDM - Whether the channel is a 1:1 DM. +# .IsGroupDM - Whether the channel is a group DM. +# .IsCategory - Whether the channel is a guild category. +# .IsGuildChannel - Whether the channel belongs to a guild. +channel_name_template: "{{if and .IsGuildChannel (not .IsCategory)}}#{{end}}{{.Name}}" + +# Should incoming custom emoji reactions be bridged as mxc:// URIs? +# If false, they are bridged as :shortcode: instead. +custom_emoji_reactions: true + +# Should we log when messages from unbridged guild channels are dropped? This +# only includes metadata such as channel and message ID. +log_when_dropping_messages: true diff --git a/pkg/connector/handlediscord.go b/pkg/connector/handlediscord.go new file mode 100644 index 0000000..3fbcb75 --- /dev/null +++ b/pkg/connector/handlediscord.go @@ -0,0 +1,949 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + "fmt" + "runtime/debug" + "slices" + "strconv" + "time" + + "github.com/bwmarrin/discordgo" + "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/simplevent" + "maunium.net/go/mautrix/bridgev2/status" + "maunium.net/go/mautrix/event" + + "go.mau.fi/util/variationselector" + + "go.mau.fi/mautrix-discord/pkg/discordid" + "go.mau.fi/mautrix-discord/pkg/router" +) + +const ( + DCNotLoggedIn status.BridgeStateErrorCode = "dc-not-logged-in" + DCWebsocketDisconnect4004 status.BridgeStateErrorCode = "dc-websocket-disconnect-4004" + DCUnknownWebsocketError status.BridgeStateErrorCode = "dc-unknown-websocket-error" + DCHTTP40002 status.BridgeStateErrorCode = "dc-http-40002" +) +const accountVerificationRequiredMessage = "You need to verify your account in the Discord app." + +func init() { + status.BridgeStateHumanErrors.Update(status.BridgeStateErrorMap{ + DCWebsocketDisconnect4004: "Please log in to your Discord account again.", + DCNotLoggedIn: "Please log in to your Discord account.", + DCHTTP40002: accountVerificationRequiredMessage, + // (For DCUnknownWebsocketError, provide a specific error message when + // sending state. If there were a generic message here, it would + // overwrite that.) + }) +} + +type DiscordEventMeta struct { + Type bridgev2.RemoteEventType + LogContext func(c zerolog.Context) zerolog.Context + route router.Route +} + +func (em *DiscordEventMeta) AddLogContext(c zerolog.Context) zerolog.Context { + if em.LogContext == nil { + return c + } + c = em.LogContext(c) + return c +} + +func (em *DiscordEventMeta) GetType() bridgev2.RemoteEventType { + return em.Type +} + +func (em *DiscordEventMeta) GetPortalKey() networkid.PortalKey { + return em.route.PortalKey +} + +func (em *DiscordEventMeta) PortalReceiverIsUncertain() bool { + return em.route.Uncertain +} + +type DiscordMessage struct { + *DiscordEventMeta + Data *discordgo.Message + Client *DiscordClient + ThreadRootID *networkid.MessageID +} + +func (m *DiscordMessage) ShouldCreatePortal() bool { + // Do not create a portal merely to bridge a message deletion or edit. + return m.Type == bridgev2.RemoteEventMessage +} + +func (m *DiscordMessage) ConvertEdit(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (*bridgev2.ConvertedEdit, error) { + log := zerolog.Ctx(ctx).With(). + Str("action", "convert discord edit").Logger() + ctx = log.WithContext(ctx) + + // FIXME don't redundantly reupload attachments + convertedEdit := m.Client.connector.MsgConv.ToMatrix( + ctx, + portal, + intent, + m.Client.UserLogin, + m.Client.Session, + m.Data, + m.ThreadRootID, + ) + + // TODO this is really gross and relies on how we assign incrementing numeric + // part ids. to return a semantically correct `ConvertedEdit` we should ditch + // this system + slices.SortStableFunc(existing, func(a *database.Message, b *database.Message) int { + ai, _ := strconv.Atoi(string(a.PartID)) + bi, _ := strconv.Atoi(string(b.PartID)) + return ai - bi + }) + + if len(convertedEdit.Parts) != len(existing) { + // FIXME support # of parts changing; triggerable by removing individual + // attachments, etc. + // + // at the very least we can make this better by handling attachments, + // which are always(?) at the end + log.Warn().Int("n_parts_existing", len(existing)).Int("n_parts_after_edit", len(convertedEdit.Parts)). + Msg("Ignoring message edit that changed number of parts") + return nil, bridgev2.ErrIgnoringRemoteEvent + } + + parts := make([]*bridgev2.ConvertedEditPart, 0, len(existing)) + for pi, part := range convertedEdit.Parts { + parts = append(parts, part.ToEditPart(existing[pi])) + } + + return &bridgev2.ConvertedEdit{ + ModifiedParts: parts, + }, nil +} + +var ( + _ bridgev2.RemoteMessage = (*DiscordMessage)(nil) + _ bridgev2.RemoteMessageWithTransactionID = (*DiscordMessage)(nil) + _ bridgev2.RemoteMessageRemove = (*DiscordMessage)(nil) + _ bridgev2.RemoteEventThatMayCreatePortal = (*DiscordMessage)(nil) + _ bridgev2.RemoteEventWithUncertainPortalReceiver = (*DiscordMessage)(nil) + _ bridgev2.RemoteEdit = (*DiscordMessage)(nil) +) + +func (m *DiscordMessage) GetTargetMessage() networkid.MessageID { + return discordid.MakeMessageID(m.Data.ID) +} + +func (m *DiscordMessage) GetTransactionID() networkid.TransactionID { + if m.Data.Nonce == "" { + return "" + } + return networkid.TransactionID(m.Data.Nonce) +} + +func (m *DiscordMessage) ConvertMessage(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI) (*bridgev2.ConvertedMessage, error) { + return m.Client.connector.MsgConv.ToMatrix(ctx, portal, intent, m.Client.UserLogin, m.Client.Session, m.Data, m.ThreadRootID), nil +} + +func (m *DiscordMessage) GetID() networkid.MessageID { + return discordid.MakeMessageID(m.Data.ID) +} + +func (m *DiscordMessage) GetSender() bridgev2.EventSender { + if m.Data.Author == nil { + // Message deletions don't have a sender associated with them. + return bridgev2.EventSender{} + } + + return m.Client.makeEventSender(m.Data.Author) +} + +func (d *DiscordClient) wrapDiscordMessage(ctx context.Context, msg *discordgo.Message, route *router.Route, typ bridgev2.RemoteEventType) DiscordMessage { + if msg == nil { + msg = &discordgo.Message{} + } + + return DiscordMessage{ + DiscordEventMeta: &DiscordEventMeta{ + Type: typ, + route: *route, + }, + Data: msg, + Client: d, + ThreadRootID: route.FromThreadRootMessageID(), + } +} + +type DiscordReaction struct { + *DiscordEventMeta + Reaction *discordgo.MessageReaction + Client *DiscordClient + + Emoji string + EmojiID networkid.EmojiID + Extra map[string]any +} + +func (r *DiscordReaction) GetSender() bridgev2.EventSender { + return r.Client.makeEventSenderWithID(r.Reaction.UserID) +} + +func (r *DiscordReaction) GetTargetMessage() networkid.MessageID { + return discordid.MakeMessageID(r.Reaction.MessageID) +} + +func (r *DiscordReaction) GetRemovedEmojiID() networkid.EmojiID { + return r.EmojiID +} + +var ( + _ bridgev2.RemoteReaction = (*DiscordReaction)(nil) + _ bridgev2.RemoteEventWithUncertainPortalReceiver = (*DiscordReaction)(nil) + _ bridgev2.RemoteReactionRemove = (*DiscordReaction)(nil) + _ bridgev2.RemoteReactionWithExtraContent = (*DiscordReaction)(nil) +) + +func (r *DiscordReaction) GetReactionEmoji() (string, networkid.EmojiID) { + return r.Emoji, r.EmojiID +} + +func (r *DiscordReaction) GetReactionExtraContent() map[string]any { + return r.Extra +} + +func (d *DiscordClient) wrapDiscordReaction(ctx context.Context, reaction *discordgo.MessageReaction, route *router.Route, beingAdded bool) (*DiscordReaction, error) { + if reaction == nil { + return nil, nil + } + evtType := bridgev2.RemoteEventReaction + if !beingAdded { + evtType = bridgev2.RemoteEventReactionRemove + } + + var matrixEmoji string + var emojiID string + var extra map[string]any + + if reaction.Emoji.ID != "" { + // A custom emoji. + emojiID = fmt.Sprintf("%s:%s", reaction.Emoji.Name, reaction.Emoji.ID) + shortcode := fmt.Sprintf(":%s:", reaction.Emoji.Name) + + extra = map[string]any{ + "fi.mau.discord.reaction": map[string]any{ + "id": reaction.Emoji.ID, + "name": reaction.Emoji.Name, + // "mxc" is added later if it's `beingAdded`. + }, + "com.beeper.reaction.shortcode": shortcode, + } + + if beingAdded { + reactionMXC, err := d.connector.GetCustomEmojiMXC( + ctx, + reaction.Emoji.ID, + reaction.Emoji.Name, + reaction.Emoji.Animated, + ) + + if err != nil || reactionMXC == "" { + zerolog.Ctx(ctx).Err(err). + Str("emoji_id", reaction.Emoji.ID). + Str("emoji_name", reaction.Emoji.Name). + Msg("Failed to get Matrix MXC for custom emoji reaction being added") + return nil, err + } + + extra["fi.mau.discord.reaction"].(map[string]any)["mxc"] = reactionMXC + + if d.connector.Config.CustomEmojiReactionsEnabled() { + matrixEmoji = string(reactionMXC) + } else { + matrixEmoji = shortcode + } + } + } else { + // A Unicode emoji. + emojiID = reaction.Emoji.Name + matrixEmoji = variationselector.Add(reaction.Emoji.Name) + } + + return &DiscordReaction{ + DiscordEventMeta: &DiscordEventMeta{ + Type: evtType, + route: *route, + }, + Reaction: reaction, + Client: d, + Emoji: matrixEmoji, + EmojiID: discordid.MakeEmojiID(emojiID), + Extra: extra, + }, nil +} + +func (d *DiscordClient) handleDiscordTyping(ctx context.Context, typing *discordgo.TypingStart, route *router.Route) { + if typing.UserID == d.Session.State.User.ID { + return + } + + log := zerolog.Ctx(ctx).With(). + Str("typing_channel_id", typing.ChannelID). + Str("typing_user_id", typing.UserID). + Str("typing_guild_id", typing.GuildID). + Logger() + ctx = log.WithContext(ctx) + + // Make sure we have this user's info in case we haven't seen them at all yet. + _ = d.userCache.Resolve(ctx, typing.UserID) + + d.UserLogin.Bridge.QueueRemoteEvent(d.UserLogin, &simplevent.Typing{ + EventMeta: simplevent.EventMeta{ + Type: bridgev2.RemoteEventTyping, + PortalKey: route.PortalKey, + Sender: d.makeEventSenderWithID(typing.UserID), + UncertainReceiver: route.Uncertain, + }, + Timeout: 12 * time.Second, + Type: bridgev2.TypingTypeText, + }) +} + +func (d *DiscordClient) handleChannelCreate(ctx context.Context, ch *discordgo.ChannelCreate) error { + log := zerolog.Ctx(ctx).With().Str("channel_id", ch.ID).Logger() + + if ch.GuildID == "" { + log.Debug().Msg("Private channel was created, creating portal") + d.queueChannelResync(ctx, ch.Channel) + } else { + log.Debug().Msg("Guild channel was created") + // FIXME(skip): Sync guild channels. Same logic as syncGuild. + } + + return nil +} + +func (d *DiscordClient) handleChannelUpdate(ctx context.Context, upd *discordgo.ChannelUpdate) error { + if upd.BeforeUpdate == nil { + // Channel doesn't exist in the discordgo's state; don't bother bridging. + return nil + } + + log := zerolog.Ctx(ctx).With().Str("action", "handle channel update").Logger() + ctx = log.WithContext(ctx) + + portalKey := d.portalKeyForChannel(upd.Channel) + portal, err := d.connector.Bridge.GetExistingPortalByKey(ctx, portalKey) + if err != nil { + return fmt.Errorf("failed to look up existing channel: %w", err) + } + if portal == nil { + // Don't bridge updates for channels we haven't actually bridged. + return nil + } + + ts := time.Now() + // Re-use main GetChatInfo logic to avoid drift. The rest of this function + // is mostly removing what didn't change. + patch, err := d.GetChatInfo(ctx, portal) + if err != nil { + return fmt.Errorf("failed to recompute chat info: %w", err) + } + + patch.Type = nil + patch.CanBackfill = false + + old := upd.BeforeUpdate + // People leaving or joining a group DM isn't expressed via CHANNEL_UPDATE. + patch.Members = nil + if upd.Name == old.Name { + patch.Name = nil + } + if upd.Topic == old.Topic { + patch.Topic = nil + } + if upd.Icon == old.Icon { + patch.Avatar = nil + } + if upd.ParentID == old.ParentID { + patch.ParentID = nil + } + + d.UserLogin.QueueRemoteEvent(&simplevent.ChatInfoChange{ + EventMeta: simplevent.EventMeta{ + Type: bridgev2.RemoteEventChatInfoChange, + PortalKey: portalKey, + Timestamp: ts, + }, + ChatInfoChange: &bridgev2.ChatInfoChange{ + ChatInfo: patch, + }, + }) + + return nil +} + +// handleChannelDelete handles a channel being deleted. This can be a guild +// channel getting "actually" deleted or a private channel getting "closed". +func (d *DiscordClient) handleChannelDelete(ctx context.Context, evt *discordgo.ChannelDelete) error { + portalKey := d.portalKeyForChannel(evt.Channel) + log := zerolog.Ctx(ctx).With(). + Str("channel_id", evt.ID). + Str("guild_id", evt.GuildID). + Stringer("deleted_channel_portal_key", portalKey).Logger() + + log.Debug().Msg("Handling channel deletion") + d.queueChatDelete(portalKey, evt.Channel.GuildID) + + return nil +} + +func (d *DiscordClient) queueChatDelete(portalKey networkid.PortalKey, deletedChannelGuildID string) { + ts := time.Now() + + onlyForMe := true + if !d.connector.Bridge.Config.SplitPortals && deletedChannelGuildID != "" { + // When split portals are disabled and a guild channel was deleted, + // then it should be deleted for everyone. + onlyForMe = false + } + + d.UserLogin.QueueRemoteEvent(&simplevent.ChatDelete{ + EventMeta: simplevent.EventMeta{ + Type: bridgev2.RemoteEventChatDelete, + PortalKey: portalKey, + Timestamp: ts, + }, + OnlyForMe: onlyForMe, + // Do not pass Children: true as deleting a guild channel category + // merely detaches the parent_id from all child channels. + // CHANNEL_UPDATE events will be dispatched for all child channels, + // which should reparent them. + }) +} + +func (d *DiscordClient) handleThreadUpdate(ctx context.Context, thread *discordgo.Channel) error { + if thread == nil || !isThread(thread) { + return nil + } + return d.upsertThreadInfoFromChannel(ctx, thread) +} + +func (d *DiscordClient) handleThreadDelete(ctx context.Context, thread *discordgo.Channel) error { + if thread == nil || thread.ID == "" { + return nil + } + return d.connector.DB.Thread.DeleteByThreadChannelID(ctx, string(d.UserLogin.ID), thread.ID) +} + +func (d *DiscordClient) queueIndividualMembershipChange( + ctx context.Context, + portalKey networkid.PortalKey, + user *discordgo.User, + membership event.Membership, + ts time.Time, +) { + log := zerolog.Ctx(ctx) + + userID := discordid.MakeUserID(user.ID) + info := d.getUserInfo(ctx, user) + + log.Debug(). + Stringer("portal_key", portalKey). + Str("moving_user_id", user.ID). + Str("membership", string(membership)). + Msg("Queueing chat info change in response to membership change") + + d.UserLogin.QueueRemoteEvent(&simplevent.ChatInfoChange{ + EventMeta: simplevent.EventMeta{ + Type: bridgev2.RemoteEventChatInfoChange, + PortalKey: portalKey, + Timestamp: ts, + }, + ChatInfoChange: &bridgev2.ChatInfoChange{ + MemberChanges: &bridgev2.ChatMemberList{ + MemberMap: bridgev2.ChatMemberMap{ + userID: bridgev2.ChatMember{ + // TODO: Can't effectively send MemberSender here to + // attribute e.g. someone getting kicked from a group + // DM because that information isn't in the gateway + // payload. Might need to wait for the corresponding + // system message. + EventSender: d.makeEventSender(user), + Membership: membership, + UserInfo: info, + }, + }, + }, + }, + }) +} + +func (d *DiscordClient) handleRecipientAdd(ctx context.Context, evt *discordgo.ChannelRecipientAdd, route *router.Route) error { + d.queueIndividualMembershipChange(ctx, route.PortalKey, evt.User, event.MembershipJoin, time.Now()) + return nil +} + +func (d *DiscordClient) handleRecipientRemove(ctx context.Context, evt *discordgo.ChannelRecipientRemove, route *router.Route) error { + d.queueIndividualMembershipChange(ctx, route.PortalKey, evt.User, event.MembershipLeave, time.Now()) + return nil +} + +func (d *DiscordClient) handleGuildMemberJoinMessage(ctx context.Context, msg *discordgo.Message, route *router.Route) { + ts := msg.Timestamp + if ts.IsZero() { + ts = time.Now() + } + d.queueIndividualMembershipChange(ctx, route.PortalKey, msg.Author, event.MembershipJoin, ts) +} + +func (d *DiscordClient) handleMessageAck(ctx context.Context, ack *discordgo.MessageAck, bridged bool, route *router.Route) { + d.readStatesLock.Lock() + zerolog.Ctx(ctx).Trace(). + Str("channel_id", ack.ChannelID). + Str("message_id", ack.MessageID). + Msg("Updating state with MESSAGE_ACK") + + // TODO: mention_count can appear in MESSAGE_ACK payloads. Update it if it's + // present and not `null`. This needs discordgo changes. (There's even more + // missing fields than this.) + d.readStates[ack.ChannelID] = &discordgo.ReadState{ + ID: ack.ChannelID, + LastMessageID: discordgo.StringOrInt(ack.MessageID), + } + d.readStatesLock.Unlock() + + if bridged { + d.UserLogin.Bridge.QueueRemoteEvent(d.UserLogin, &simplevent.Receipt{ + EventMeta: simplevent.EventMeta{ + Type: bridgev2.RemoteEventReadReceipt, + PortalKey: route.PortalKey, + Sender: d.selfEventSender(), + UncertainReceiver: route.Uncertain, + }, + LastTarget: discordid.MakeMessageID(ack.MessageID), + }) + } +} + +// channelIsBridged uses routing logic to check whether a portal (with an +// existing room) exists for a given Discord channel ID. +func (d *DiscordClient) channelIsBridged(ctx context.Context, channelID string) (bool, *router.Route) { + log := zerolog.Ctx(ctx) + + route, err := d.Route(ctx, channelID) + if err != nil { + log.Err(err).Msg("Failed to route channel when determining channel bridgedness") + return false, nil + } + existingPortal, err := d.connector.Bridge.GetExistingPortalByKey(ctx, route.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to look up existing portal when determining channel bridgedness") + return false, route + } + return existingPortal != nil && existingPortal.MXID != "", route +} + +func (d *DiscordClient) handleUserGuildSettingsUpdate(ctx context.Context, evt *discordgo.UserGuildSettingsUpdate) { + log := zerolog.Ctx(ctx) + log.Debug().Msg("Handling user guild settings update") + d.applySingleGuildSettings(evt.UserGuildSettings) +} + +func messageCtx(ctx context.Context, msg *discordgo.Message) (context.Context, *zerolog.Logger) { + if msg == nil { + return ctx, zerolog.Ctx(ctx) + } + + wipLog := zerolog.Ctx(ctx).With(). + Str("guild_id", msg.GuildID). + Str("channel_id", msg.ChannelID). + Str("message_id", msg.ID) + if msg.Author != nil { + wipLog = wipLog.Str("author_id", msg.Author.ID). + Bool("author_bot", msg.Author.Bot) + } + if msg.WebhookID != "" { + wipLog = wipLog.Str("webhook_id", msg.WebhookID) + } + log := wipLog.Logger() + + return log.WithContext(ctx), &log +} + +func (d *DiscordClient) handleDiscordStateEvent(rawEvt any) { + ctx := d.UserLogin.Bridge.BackgroundCtx + log := zerolog.Ctx(ctx) + + switch evt := rawEvt.(type) { + case *discordgo.ReadySupplemental: + log.Info(). + Int("n_lazy_private_channels", len(evt.LazyPrivateChannels)). + Msg("Received supplemental READY") + case *discordgo.Ready: + d.rebuildRelationships() + case *discordgo.RelationshipAdd: + d.upsertRelationship(evt.Relationship) + case *discordgo.RelationshipUpdate: + d.upsertRelationship(evt.Relationship) + case *discordgo.RelationshipRemove: + d.removeRelationship(evt.ID) + } +} + +func (d *DiscordClient) handleRelationshipNickChange(ctx context.Context, userID, nickname string) { + ch := d.dmChannelForUserID(userID) + if ch == nil { + return + } + + portalKey := d.portalKeyForChannel(ch) + portal, err := d.connector.Bridge.GetExistingPortalByKey(ctx, portalKey) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to look up DM portal for relationship nick change") + return + } + if portal == nil || portal.MXID == "" { + return + } + + var name *string + if nickname != "" { + name = &nickname + } else { + name = bridgev2.DefaultChatName + } + + d.UserLogin.QueueRemoteEvent(&simplevent.ChatInfoChange{ + EventMeta: simplevent.EventMeta{ + Type: bridgev2.RemoteEventChatInfoChange, + PortalKey: portalKey, + Timestamp: time.Now(), + }, + ChatInfoChange: &bridgev2.ChatInfoChange{ + ChatInfo: &bridgev2.ChatInfo{ + Name: name, + }, + }, + }) +} + +func (d *DiscordClient) handleDiscordEvent(rawEvt any) { + defer func() { + err := recover() + if err == nil { + return + } + + d.UserLogin.Log.Error(). + Bytes(zerolog.ErrorStackFieldName, debug.Stack()). + Any(zerolog.ErrorFieldName, err). + Msg("Panic in Discord event handler") + + props := d.baseAnalyticsProps(d.UserLogin.Bridge.BackgroundCtx) + props["eventType"] = fmt.Sprintf("%T", rawEvt) + props["error"] = fmt.Sprint(err) + + d.UserLogin.TrackAnalytics("Discord event handler panic", props) + }() + + log := d.UserLogin.Log.With().Str("action", "handle discord event"). + Type("event_type", rawEvt). + Logger() + ctx := log.WithContext(d.UserLogin.Bridge.BackgroundCtx) + + // NOTE: discordgo seemingly dispatches both the proper unmarshalled type + // (e.g. `*discordgo.TypingStart`) _as well as_ a "raw" *discordgo.Event + // (e.g. `*discordgo.Event` with `Type` of `TYPING_START`) for every gateway + // event. + + // NOTE: We explicitly return early from paths where we would otherwise + // QueueRemoteEvent for a portal that hasn't been bridged by the user yet. + // (Specifically, we check for an extant portal with an associated room.) + // This avoids the eager creation of stub portals that have bogus metadata + // (e.g. GuildID == "" despite being a guild channel). This is because you + // can't specify metadata upfront when a portal is implicitly created. We + // might want to rely on our metadata always being "correct" in the future. + // + // This also helps avoid excessive "Dropping event as portal doesn't exist" + // logs from Mautrix. You receive events for every guild you're in, so this + // can become noisy fast. + + switch evt := rawEvt.(type) { + case *discordgo.Ready: + log.Info(). + Int("n_dms", len(evt.PrivateChannels)). + Int("n_guilds", len(evt.Guilds)). + Int("n_merged_members", len(evt.MergedMembers)). + Int("n_relationships", len(evt.Relationships)). + Int("n_users", len(evt.Users)). + Msg("Received READY dispatch from discordgo") + + d.userCache.UpdateWithReady(evt) + d.syncRemoteProfile(ctx) + d.UserLogin.BridgeState.Send(status.BridgeState{ + StateEvent: status.StateConnected, + }) + case *discordgo.Resumed: + // (All missed gateway events have been replayed, and all subsequent + // events will be new.) + log.Info().Msg("Received RESUMED dispatch from discordgo") + d.UserLogin.BridgeState.Send(status.BridgeState{ + StateEvent: status.StateConnected, + }) + case *discordgo.InvalidAuth: + log.Warn().Msg("Got logged out of Discord due to invalid token") + d.tokenInvalidated(ctx, "while connected") + case *discordgo.TypingStart: + bridged, route := d.channelIsBridged(ctx, evt.ChannelID) + if !bridged { + return + } + d.handleDiscordTyping(ctx, evt, route) + case *discordgo.GuildCreate: + if evt.Unavailable { + break + } + if err := d.syncGuildRoles(ctx, evt.ID, evt.Roles); err != nil { + log.Err(err).Str("guild_id", evt.ID).Msg("Failed to sync guild roles from guild create event") + } + case *discordgo.GuildUpdate: + if err := d.syncGuildRoles(ctx, evt.ID, evt.Roles); err != nil { + log.Err(err).Str("guild_id", evt.ID).Msg("Failed to sync guild roles from guild update event") + } + case *discordgo.GuildRoleCreate: + roleID := "" + if evt.Role != nil { + roleID = evt.Role.ID + } + if err := d.upsertGuildRole(ctx, evt.GuildID, evt.Role); err != nil { + log.Err(err).Str("guild_id", evt.GuildID).Str("role_id", roleID).Msg("Failed to store role create event") + } + case *discordgo.GuildRoleUpdate: + roleID := "" + if evt.Role != nil { + roleID = evt.Role.ID + } + if err := d.upsertGuildRole(ctx, evt.GuildID, evt.Role); err != nil { + log.Err(err).Str("guild_id", evt.GuildID).Str("role_id", roleID).Msg("Failed to store role update event") + } + case *discordgo.GuildRoleDelete: + if err := d.connector.DB.Role.DeleteByID(ctx, evt.GuildID, evt.RoleID); err != nil { + log.Err(err).Str("guild_id", evt.GuildID).Str("role_id", evt.RoleID).Msg("Failed to delete role from database") + } + case *discordgo.ChannelCreate: + if err := d.handleChannelCreate(ctx, evt); err != nil { + log.Err(err).Msg("Failed to handle channel create") + } + case *discordgo.ChannelUpdate: + bridged, _ := d.channelIsBridged(ctx, evt.ID) + if !bridged { + return + } + err := d.handleChannelUpdate(ctx, evt) + if err != nil { + log.Err(err).Msg("Failed to handle channel update") + } + case *discordgo.ChannelDelete: + // The route computed by channelIsBridged will always be uncertain + // because the channel has already disappeared from discordgo's state. + bridged, _ := d.channelIsBridged(ctx, evt.ID) + if !bridged { + return + } + if err := d.handleChannelDelete(ctx, evt); err != nil { + log.Err(err).Msg("Failed to handle channel delete") + } + case *discordgo.ChannelRecipientAdd: + bridged, route := d.channelIsBridged(ctx, evt.ChannelID) + if !bridged { + return + } + if err := d.handleRecipientAdd(ctx, evt, route); err != nil { + log.Err(err).Msg("Failed to handle channel recipient add") + } + case *discordgo.ChannelRecipientRemove: + bridged, route := d.channelIsBridged(ctx, evt.ChannelID) + if !bridged { + return + } + if err := d.handleRecipientRemove(ctx, evt, route); err != nil { + log.Err(err).Msg("Failed to handle channel recipient remove") + } + case *discordgo.ThreadCreate: + err := d.handleThreadUpdate(ctx, evt.Channel) + if err != nil { + log.Err(err).Str("thread_id", evt.ID).Msg("Failed to handle thread create event") + } + case *discordgo.ThreadUpdate: + err := d.handleThreadUpdate(ctx, evt.Channel) + if err != nil { + log.Err(err).Str("thread_id", evt.ID).Msg("Failed to handle thread update event") + } + case *discordgo.ThreadDelete: + err := d.handleThreadDelete(ctx, evt.Channel) + if err != nil { + log.Err(err).Str("thread_id", evt.ID).Msg("Failed to handle thread delete event") + } + case *discordgo.ThreadListSync: + for _, thread := range evt.Threads { + err := d.handleThreadUpdate(ctx, thread) + if err != nil { + log.Err(err).Str("thread_id", thread.ID).Msg("Failed to handle thread in thread list sync event") + } + } + case *discordgo.MessageCreate: + if evt.Author == nil { + log.Trace().Int("message_type", int(evt.Message.Type)). + Str("guild_id", evt.GuildID). + Str("message_id", evt.ID). + Str("channel_id", evt.ChannelID). + Msg("Dropping message that lacks an author") + return + } + ctx, log := messageCtx(ctx, evt.Message) + inBridgedChannel, route := d.channelIsBridged(ctx, evt.ChannelID) + isDM := route != nil && route.FromChannel != nil && channelIsPrivate(route.FromChannel) + if !inBridgedChannel && !isDM { + if d.connector.Config.LogWhenDroppingMessages { + log.Debug(). + Str("channel_id", evt.ChannelID). + Str("message_id", evt.ID). + Bool("route_uncertain", route != nil && route.Uncertain). + Bool("from_channel_known", route != nil && route.FromChannel != nil). + Bool("from_thread_known", route != nil && route.FromThread != nil). + Msg("Dropping message for non-bridged channel") + } + return + } + + if evt.Message.Type == discordgo.MessageTypeGuildMemberJoin { + d.userCache.UpdateWithMessage(evt.Message) + d.handleGuildMemberJoinMessage(ctx, evt.Message, route) + return + } + + if err := d.upsertThreadInfoFromMessage(ctx, evt.Message); err != nil { + log.Err(err).Msg("Failed to persist thread info from message create") + } + d.userCache.UpdateWithMessage(evt.Message) + + wrappedEvt := d.wrapDiscordMessage(ctx, evt.Message, route, bridgev2.RemoteEventMessage) + d.UserLogin.Bridge.QueueRemoteEvent(d.UserLogin, &wrappedEvt) + case *discordgo.MessageUpdate: + ctx, log := messageCtx(ctx, evt.Message) + bridged, route := d.channelIsBridged(ctx, evt.ChannelID) + if !bridged { + return + } + + if err := d.upsertThreadInfoFromMessage(ctx, evt.Message); err != nil { + log.Err(err).Str("message_id", evt.ID).Msg("Failed to persist thread info from message update") + } + + wrappedEvt := d.wrapDiscordMessage(ctx, evt.Message, route, bridgev2.RemoteEventEdit) + d.UserLogin.Bridge.QueueRemoteEvent(d.UserLogin, &wrappedEvt) + case *discordgo.UserUpdate: + // The current user changed. (This is not sent out for anyone else.) + log.Info().Msg("Current user was updated") + + // discordgo does not update State.User for us. This is probably a bug. + // Do it ourselves in the meantime. + { + state := d.Session.State + state.Lock() + *d.Session.State.User = *evt.User + state.Unlock() + } + d.userCache.UpdateWithUserUpdate(evt) + + if d.syncRemoteProfile(ctx) { + // Send out a new bridge state so clients immediately get the + // updated profile. + d.UserLogin.BridgeState.Send(status.BridgeState{ + StateEvent: status.StateConnected, + }) + } + case *discordgo.MessageDelete: + ctx, _ := messageCtx(ctx, evt.Message) + bridged, route := d.channelIsBridged(ctx, evt.ChannelID) + if !bridged { + return + } + + wrappedEvt := d.wrapDiscordMessage(ctx, evt.Message, route, bridgev2.RemoteEventMessageRemove) + d.UserLogin.Bridge.QueueRemoteEvent(d.UserLogin, &wrappedEvt) + // TODO *discordgo.MessageDeleteBulk + case *discordgo.MessageReactionAdd: + bridged, route := d.channelIsBridged(ctx, evt.ChannelID) + if !bridged { + return + } + wrappedEvt, err := d.wrapDiscordReaction(ctx, evt.MessageReaction, route, true) + if err != nil { + log.Err(err).Msg("Dropping incoming reaction due to error") + } else { + d.UserLogin.Bridge.QueueRemoteEvent(d.UserLogin, wrappedEvt) + } + // TODO case *discordgo.MessageReactionRemoveAll: + // TODO case *discordgo.MessageReactionRemoveEmoji: (needs impl. in discordgo) + case *discordgo.MessageReactionRemove: + bridged, route := d.channelIsBridged(ctx, evt.ChannelID) + if !bridged { + return + } + wrappedEvt, err := d.wrapDiscordReaction(ctx, evt.MessageReaction, route, false) + if err != nil { + log.Err(err).Msg("Dropping incoming reaction removal due to error") + } else { + d.UserLogin.Bridge.QueueRemoteEvent(d.UserLogin, wrappedEvt) + } + // NOTE: Relationship updates are also handled in handleDiscordStateEvent, + // which is synchronously invoked before this one. This is to ensure + // coherence in the face of concurrency, because this method is always + // dispatched on a new goroutine. + case *discordgo.RelationshipAdd: + d.handleRelationshipNickChange(ctx, evt.ID, evt.Nickname) + case *discordgo.RelationshipUpdate: + d.handleRelationshipNickChange(ctx, evt.ID, evt.Nickname) + case *discordgo.RelationshipRemove: + d.handleRelationshipNickChange(ctx, evt.ID, "") + case *discordgo.PresenceUpdate: + return + case *discordgo.MessageAck: + bridged, route := d.channelIsBridged(ctx, evt.ChannelID) + d.handleMessageAck(ctx, evt, bridged, route) + case *discordgo.UserGuildSettingsUpdate: + d.handleUserGuildSettingsUpdate(ctx, evt) + case *discordgo.GuildDelete: + if evt.Unavailable { + log.Warn().Str("guild_id", evt.ID).Msg("Guild became unavailable") + // For now, leave the portals alone if the guild only went away due to an outage. + return + } + if err := d.connector.DB.Role.DeleteByGuildID(ctx, evt.ID); err != nil { + log.Err(err).Str("guild_id", evt.ID).Msg("Failed to delete guild roles from database") + } + d.deleteGuildPortalSpace(ctx, evt.ID) + } +} diff --git a/pkg/connector/handlematrix.go b/pkg/connector/handlematrix.go new file mode 100644 index 0000000..5a75843 --- /dev/null +++ b/pkg/connector/handlematrix.go @@ -0,0 +1,614 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + "errors" + "fmt" + "maps" + "math" + "strings" + "time" + + "github.com/bwmarrin/discordgo" + "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/event" + + "go.mau.fi/util/ptr" + "go.mau.fi/util/variationselector" + + "go.mau.fi/mautrix-discord/pkg/discordid" +) + +var ( + _ bridgev2.ReactionHandlingNetworkAPI = (*DiscordClient)(nil) + _ bridgev2.RedactionHandlingNetworkAPI = (*DiscordClient)(nil) + _ bridgev2.EditHandlingNetworkAPI = (*DiscordClient)(nil) + _ bridgev2.ReadReceiptHandlingNetworkAPI = (*DiscordClient)(nil) + _ bridgev2.TypingHandlingNetworkAPI = (*DiscordClient)(nil) + _ bridgev2.MuteHandlingNetworkAPI = (*DiscordClient)(nil) +) + +type contextKey int + +const ( + contextKeyChannel contextKey = iota +) + +type SendAttempt struct { + At time.Time + ChannelType discordgo.ChannelType + RecipientRelationshipType *discordgo.RelationshipType +} + +func (d *DiscordClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.MatrixMessage) (*bridgev2.MatrixMessageResponse, error) { + if !d.IsLoggedIn() { + return nil, bridgev2.ErrNotLoggedIn + } + + log := zerolog.Ctx(ctx).With().Str("action", "matrix message send").Logger() + ctx = log.WithContext(ctx) + + portal := msg.Portal + guildID := portal.Metadata.(*discordid.PortalMetadata).GuildID + parentChannelID := discordid.ParseChannelPortalID(portal.ID) + channelID := parentChannelID + threadChannelID := "" + threadRootRemoteID := getMatrixThreadRootRemoteMessageID(msg.ThreadRoot) + + if threadRootRemoteID != "" { + thread, err := d.getThreadByRootMessageID(ctx, threadRootRemoteID) + if err != nil { + return nil, err + } + if thread != nil { + threadChannelID = thread.ThreadChannelID + } else if guildID != "" { + var startErr error + threadChannelID, startErr = d.startThreadFromMatrix(ctx, guildID, parentChannelID, threadRootRemoteID, getThreadName(msg.Content)) + if startErr != nil { + // If creating the thread failed, try resolving it once more in case it already exists. + thread, err = d.getThreadByRootMessageID(ctx, threadRootRemoteID) + if err != nil { + return nil, err + } else if thread != nil { + threadChannelID = thread.ThreadChannelID + } else { + return nil, fmt.Errorf("failed to create Discord thread from Matrix message: %w", startErr) + } + } + } + } + if threadChannelID != "" { + channelID = threadChannelID + } + refererOpt := makeDiscordReferer(guildID, parentChannelID, threadChannelID) + + ch := d.channelWithID(ctx, channelID) + ctx = context.WithValue(ctx, contextKeyChannel, ch) + + // Perform any required screening before making any requests to Discord at + // all (message conversion does). + if err := d.screenOutgoingMessage(ctx, ch); err != nil { + return nil, err + } + + sendReq, err := d.connector.MsgConv.ToDiscord(ctx, d.Session, msg, channelID, refererOpt) + if err != nil { + return nil, err + } + + if sendReq.Reference != nil && sendReq.Reference.ChannelID == parentChannelID && threadChannelID != "" { + sendReq.Reference.ChannelID = threadChannelID + } + + if ch != nil { + var relType *discordgo.RelationshipType + if rel := d.relationshipWithDMRecipient(ch); rel != nil { + relType = &rel.Type + } + + if channelIsPrivate(ch) { + // NOTE: These analytics are so that we can get some data on what's + // causing Discord to disable/restrict/ban accounts. For message + // attempts, we only send these for DMs at the moment. + // + // (This fires a goroutine internally so it won't block.) + d.sendOutgoingMessageAttemptAnalytics(ctx, map[string]any{ + "messageFlags": sendReq.Flags, + "messageType": sendReq.Type, + "hasAttachments": len(sendReq.Attachments) > 0, + "hasEmbeds": len(sendReq.Embeds) > 0, + "isReplying": sendReq.Reference != nil && sendReq.Reference.Type == discordgo.MessageReferenceTypeDefault, + }) + } + + d.lastSendAttemptMutex.Lock() + d.lastSendAttempt = &SendAttempt{ + At: time.Now(), + ChannelType: ch.Type, + RecipientRelationshipType: relType, + } + d.lastSendAttemptMutex.Unlock() + } + + sentMsg, err := d.Session.ChannelMessageSendComplex(channelID, sendReq, refererOpt, discordgo.WithContext(ctx)) + if err != nil { + return nil, d.tryWrappingError(ctx, err) + } + sentMsgTimestamp, _ := discordgo.SnowflakeTimestamp(sentMsg.ID) + dbMessage := &database.Message{ + ID: discordid.MakeMessageID(sentMsg.ID), + SenderID: discordid.MakeUserID(sentMsg.Author.ID), + Timestamp: sentMsgTimestamp, + } + if threadRootRemoteID != "" { + dbMessage.ThreadRoot = discordid.MakeMessageID(threadRootRemoteID) + } + + return &bridgev2.MatrixMessageResponse{ + DB: dbMessage, + }, nil +} + +var errCannotDMStranger = errors.New("can't direct message a stranger") + +func (d *DiscordClient) screenOutgoingMessage(ctx context.Context, destCh *discordgo.Channel) error { + log := zerolog.Ctx(ctx) + + if d.connector.Config.ForbidDMingStrangersEnabled() { + dmRecipID := dmChannelRecipientID(destCh) + if dmRecipID != nil { + rel := d.relationshipWithUserID(*dmRecipID) + friendsWithDMRecip := rel != nil && rel.Type == discordgo.RelationshipFriend + + dmRecip := d.userCache.Resolve(ctx, *dmRecipID) + + if dmRecip != nil && !dmRecip.Bot && !friendsWithDMRecip { + loggedRelType := "none" + if rel != nil { + loggedRelType = readableRelationshipType(rel.Type) + } + log.Info(). + Str("relationship_type", loggedRelType). + Msg("Preventing direct message send to a stranger") + + return bridgev2.WrapErrorInStatus(errCannotDMStranger). + WithStatus(event.MessageStatusFail). + WithIsCertain(true). + WithMessage("You can't message users who aren't on your friends list. To continue, use the Discord app to chat or add them as a friend."). + WithSendNotice(true) + } + } + } + + return nil +} + +func (d *DiscordClient) sendOutgoingMessageAttemptAnalytics(ctx context.Context, extra map[string]any) { + props := d.baseAnalyticsProps(ctx) + maps.Copy(props, extra) + + d.UserLogin.TrackAnalytics("Discord outgoing message attempt", props) +} + +func (d *DiscordClient) HandleMatrixEdit(ctx context.Context, msg *bridgev2.MatrixEdit) error { + if !d.IsLoggedIn() { + return bridgev2.ErrNotLoggedIn + } + + log := zerolog.Ctx(ctx).With().Str("action", "matrix message edit").Logger() + ctx = log.WithContext(ctx) + + content, _ := d.connector.MsgConv.ConvertMatrixMessageContent( + ctx, + msg.Portal, + msg.Content, + // Disregard link previews for now. Discord generally allows you to + // remove individual link previews from a message though. + []string{}, + ) + + guildID := msg.Portal.Metadata.(*discordid.PortalMetadata).GuildID + parentChannelID := discordid.ParseChannelPortalID(msg.Portal.ID) + channelID := parentChannelID + threadChannelID := "" + if msg.EditTarget != nil && msg.EditTarget.ThreadRoot != "" { + thread, err := d.getThreadByRootMessageID(ctx, discordid.ParseMessageID(msg.EditTarget.ThreadRoot)) + if err != nil { + return fmt.Errorf("failed to resolve target thread for message edit: %w", err) + } else if thread != nil { + threadChannelID = thread.ThreadChannelID + channelID = threadChannelID + } + } + + _, err := d.Session.ChannelMessageEdit( + channelID, + discordid.ParseMessageID(msg.EditTarget.ID), + content, + makeDiscordReferer(guildID, parentChannelID, threadChannelID), + ) + if err != nil { + return d.tryWrappingError(ctx, err) + } + + return nil +} + +func (d *DiscordClient) PreHandleMatrixReaction(ctx context.Context, reaction *bridgev2.MatrixReaction) (bridgev2.MatrixReactionPreResponse, error) { + if !d.IsLoggedIn() { + return bridgev2.MatrixReactionPreResponse{}, bridgev2.ErrNotLoggedIn + } + + emojiID := reaction.Content.RelatesTo.Key + + // Figure out if this is a custom emoji or not. + if strings.HasPrefix(emojiID, "mxc://") { + customEmoji, err := d.connector.GetCustomEmojiByMXC(ctx, emojiID) + + if err != nil { + return bridgev2.MatrixReactionPreResponse{}, fmt.Errorf("failed to get custom emoji by mxc: %w", err) + } else if customEmoji == nil || customEmoji.ID == "" || customEmoji.Name == "" { + return bridgev2.MatrixReactionPreResponse{}, fmt.Errorf("unknown custom emoji mxc: %q", emojiID) + } + + emojiID = fmt.Sprintf("%s:%s", customEmoji.Name, customEmoji.ID) + } else { + emojiID = variationselector.FullyQualify(emojiID) + } + + return bridgev2.MatrixReactionPreResponse{ + SenderID: discordid.UserLoginIDToUserID(d.UserLogin.ID), + EmojiID: discordid.MakeEmojiID(emojiID), + }, nil +} + +func (d *DiscordClient) HandleMatrixReaction(ctx context.Context, reaction *bridgev2.MatrixReaction) (*database.Reaction, error) { + if !d.IsLoggedIn() { + return nil, bridgev2.ErrNotLoggedIn + } + + portal := reaction.Portal + meta := portal.Metadata.(*discordid.PortalMetadata) + parentChannelID := discordid.ParseChannelPortalID(portal.ID) + channelID := parentChannelID + threadChannelID := "" + if reaction.TargetMessage != nil && reaction.TargetMessage.ThreadRoot != "" { + thread, err := d.getThreadByRootMessageID(ctx, discordid.ParseMessageID(reaction.TargetMessage.ThreadRoot)) + if err != nil { + return nil, err + } else if thread != nil { + threadChannelID = thread.ThreadChannelID + channelID = threadChannelID + } + } + + return nil, d.tryWrappingError(ctx, d.Session.MessageReactionAddUser( + meta.GuildID, + channelID, + discordid.ParseMessageID(reaction.TargetMessage.ID), + discordid.ParseEmojiID(reaction.PreHandleResp.EmojiID), + makeDiscordReferer(meta.GuildID, parentChannelID, threadChannelID), + )) +} + +func (d *DiscordClient) HandleMatrixReactionRemove(ctx context.Context, removal *bridgev2.MatrixReactionRemove) error { + if !d.IsLoggedIn() { + return bridgev2.ErrNotLoggedIn + } + + removing := removal.TargetReaction + emojiID := removing.EmojiID + parentChannelID := discordid.ParseChannelPortalID(removal.Portal.ID) + channelID := parentChannelID + threadChannelID := "" + guildID := removal.Portal.Metadata.(*discordid.PortalMetadata).GuildID + targetMessage, err := d.UserLogin.Bridge.DB.Message.GetFirstPartByID(ctx, d.UserLogin.ID, removing.MessageID) + if err != nil { + return err + } + if targetMessage != nil && targetMessage.ThreadRoot != "" { + thread, err := d.getThreadByRootMessageID(ctx, discordid.ParseMessageID(targetMessage.ThreadRoot)) + if err != nil { + return err + } else if thread != nil { + threadChannelID = thread.ThreadChannelID + channelID = threadChannelID + } + } + + return d.tryWrappingError(ctx, d.Session.MessageReactionRemoveUser( + guildID, + channelID, + discordid.ParseMessageID(removing.MessageID), + discordid.ParseEmojiID(emojiID), + discordid.ParseUserLoginID(d.UserLogin.ID), + makeDiscordReferer(guildID, parentChannelID, threadChannelID), + )) +} + +func (d *DiscordClient) HandleMatrixMessageRemove(ctx context.Context, removal *bridgev2.MatrixMessageRemove) error { + if !d.IsLoggedIn() { + return bridgev2.ErrNotLoggedIn + } + + guildID := removal.Portal.Metadata.(*discordid.PortalMetadata).GuildID + parentChannelID := discordid.ParseChannelPortalID(removal.Portal.ID) + channelID := parentChannelID + threadChannelID := "" + if removal.TargetMessage != nil && removal.TargetMessage.ThreadRoot != "" { + thread, err := d.getThreadByRootMessageID(ctx, discordid.ParseMessageID(removal.TargetMessage.ThreadRoot)) + if err != nil { + return err + } else if thread != nil { + threadChannelID = thread.ThreadChannelID + channelID = threadChannelID + } + } + messageID := discordid.ParseMessageID(removal.TargetMessage.ID) + return d.tryWrappingError(ctx, d.Session.ChannelMessageDelete(channelID, messageID, makeDiscordReferer(guildID, parentChannelID, threadChannelID))) +} + +func (d *DiscordClient) HandleMatrixReadReceipt(ctx context.Context, msg *bridgev2.MatrixReadReceipt) error { + if !d.IsLoggedIn() { + return bridgev2.ErrNotLoggedIn + } + + log := msg.Portal.Log.With(). + Str("event_id", string(msg.EventID)). + Str("action", "matrix read receipt").Logger() + + guildID := msg.Portal.Metadata.(*discordid.PortalMetadata).GuildID + parentChannelID := discordid.ParseChannelPortalID(msg.Portal.ID) + threadChannelID := "" + threadRootRemoteID := "" + threadID := msg.Receipt.ThreadID + threadScoped := threadID != "" && threadID != event.ReadReceiptThreadMain + + if threadScoped { + rootMsg, err := d.UserLogin.Bridge.DB.Message.GetPartByMXID(ctx, threadID) + if err != nil { + log.Err(err).Msg("Failed to resolve thread root event from receipt") + return err + } else if rootMsg != nil { + threadRootRemoteID = discordid.ParseMessageID(rootMsg.ID) + if rootMsg.ThreadRoot != "" { + threadRootRemoteID = discordid.ParseMessageID(rootMsg.ThreadRoot) + } + thread, err := d.getThreadByRootMessageID(ctx, threadRootRemoteID) + if err != nil { + log.Err(err).Msg("Failed to resolve thread channel from thread root") + return err + } else if thread != nil { + threadChannelID = thread.ThreadChannelID + } + } + } + if threadScoped && threadRootRemoteID == "" { + log.Debug().Stringer("receipt_thread_id", threadID).Msg("Dropping thread-scoped read receipt: unknown thread root") + return nil + } + + var targetMessage *database.Message + var targetMessageID string + + // Figure out the ID of the Discord message that we'll mark as read. If the + // receipt didn't exactly correspond with a message, try finding one close + // by to use as the target. + if msg.ExactMessage != nil { + targetMessage = msg.ExactMessage + targetMessageID = discordid.ParseMessageID(targetMessage.ID) + log = log.With(). + Str("message_id", targetMessageID). + Logger() + } else { + var err error + if threadScoped && threadRootRemoteID != "" { + targetMessage, err = d.UserLogin.Bridge.DB.Message.GetLastThreadMessage(ctx, msg.Portal.PortalKey, discordid.MakeMessageID(threadRootRemoteID)) + if err != nil { + log.Err(err).Msg("Failed to find latest thread message") + return err + } + if targetMessage != nil && targetMessage.Timestamp.After(msg.ReadUpTo) { + targetMessage = nil + } + } else { + targetMessage, err = d.UserLogin.Bridge.DB.Message.GetLastPartAtOrBeforeTime(ctx, msg.Portal.PortalKey, msg.ReadUpTo) + if err != nil { + log.Err(err).Msg("Failed to find closest message part") + return err + } + } + + if targetMessage != nil { + // The read receipt didn't specify an exact message but we were able to + // find one close by. + + targetMessageID = discordid.ParseMessageID(targetMessage.ID) + log = log.With(). + Str("closest_message_id", targetMessageID). + Str("closest_event_id", targetMessage.MXID.String()). + Logger() + log.Debug(). + Msg("Read receipt target event not found, using closest message") + } else { + log.Debug().Msg("Dropping read receipt: no messages found") + return nil + } + } + + if threadScoped && targetMessage != nil { + targetMsgThreadRoot := discordid.ParseMessageID(targetMessage.ThreadRoot) + if targetMsgThreadRoot == "" { + targetMsgThreadRoot = discordid.ParseMessageID(targetMessage.ID) + } + if threadRootRemoteID != "" && targetMsgThreadRoot != threadRootRemoteID { + log.Debug(). + Str("receipt_thread_root", threadRootRemoteID). + Str("target_thread_root", targetMsgThreadRoot). + Msg("Dropping read receipt due to thread mismatch") + return nil + } + if threadChannelID == "" && targetMsgThreadRoot != "" { + thread, err := d.getThreadByRootMessageID(ctx, targetMsgThreadRoot) + if err != nil { + return err + } else if thread != nil { + threadChannelID = thread.ThreadChannelID + } + } + } + + channelID := parentChannelID + if threadChannelID != "" { + channelID = threadChannelID + } + resp, err := d.Session.ChannelMessageAckNoToken( + channelID, + targetMessageID, + makeDiscordReferer(guildID, parentChannelID, threadChannelID), + ) + if err != nil { + log.Err(err).Msg("Failed to send read receipt to Discord") + return err + } else if resp.Token != nil { + log.Debug(). + Str("unexpected_resp_token", *resp.Token). + Msg("Marked message as read on Discord (and got unexpected non-nil token)") + } else { + log.Debug().Msg("Marked message as read on Discord") + } + + return nil +} + +func (d *DiscordClient) viewingChannel(ctx context.Context, portal *bridgev2.Portal) error { + if portal.Metadata.(*discordid.PortalMetadata).GuildID != "" { + // Only private channels need this logic. + return nil + } + + d.markedOpenedLock.Lock() + defer d.markedOpenedLock.Unlock() + + channelID := discordid.ParseChannelPortalID(portal.ID) + log := zerolog.Ctx(ctx).With(). + Str("channel_id", channelID).Logger() + + lastMarkedOpenedTs := d.markedOpened[channelID] + if lastMarkedOpenedTs.IsZero() { + d.markedOpened[channelID] = time.Now() + + err := d.Session.MarkViewing(channelID) + + if err != nil { + log.Error().Err(err).Msg("Failed to mark user as viewing channel") + return err + } + + log.Trace().Msg("Marked channel as being viewed") + } else { + log.Trace().Str("channel_id", channelID). + Msg("Already marked channel as viewed, not doing so") + } + + return nil +} + +func (d *DiscordClient) HandleMatrixTyping(ctx context.Context, msg *bridgev2.MatrixTyping) error { + if !d.IsLoggedIn() { + return bridgev2.ErrNotLoggedIn + } + + log := zerolog.Ctx(ctx) + + // Don't mind if this fails. + _ = d.viewingChannel(ctx, msg.Portal) + + guildID := msg.Portal.Metadata.(*discordid.PortalMetadata).GuildID + channelID := discordid.ParseChannelPortalID(msg.Portal.ID) + err := d.Session.ChannelTyping(channelID, makeDiscordReferer(guildID, channelID, "")) + + if err != nil { + log.Warn().Err(err).Msg("Failed to mark user as typing") + return err + } + + log.Debug().Msg("Marked user as typing") + return nil +} + +func (d *DiscordClient) HandleMute(ctx context.Context, msg *bridgev2.MatrixMute) error { + if !d.IsLoggedIn() { + return bridgev2.ErrNotLoggedIn + } + + channelID := discordid.ParseChannelPortalID(msg.Portal.ID) + log := zerolog.Ctx(ctx).With(). + Str("muting_channel_id", channelID). + Int64("muting_until", msg.Content.MutedUntil). + Logger() + ctx = log.WithContext(ctx) + log.Debug().Msg("Handling Matrix mute") + + ch := d.channelWithID(ctx, channelID) + if ch == nil { + log.Error().Msg("Failed to find channel to mute") + return fmt.Errorf("failed to mute non-existent channel %s", channelID) + } + + mutedUntil := msg.Content.GetMutedUntilTime() + isMuting := mutedUntil.After(time.Now()) + override := discordgo.UserGuildSettingsChannelOverrideEdit{ + Muted: ptr.Ptr(isMuting), + } + if isMuting && mutedUntil != event.MutedForever { + // At the time of writing, arbitrary mute durations are supported by + // Discord; you aren't restricted to the official client's choices + // of 15 minutes, 1 hour, 3 hours, 8 hours, and 24 hours. + secs := int(math.Round(msg.Content.GetMuteDuration().Seconds())) + override.MuteConfig = &discordgo.MuteConfig{ + EndTime: &mutedUntil, + SelectedTimeWindow: &secs, + } + } + + overrides := make(map[string]*discordgo.UserGuildSettingsChannelOverrideEdit) + overrides[ch.ID] = &override + + edit := discordgo.UserGuildSettingsEdit{ + ChannelOverrides: overrides, + } + + log.Debug().Interface("muting_override", override).Msg("Computed channel override for mute") + + guildID := ch.GuildID + if guildID == "" { + // Target private channels properly. + guildID = "@me" + } + _, err := d.Session.UserGuildSettingsEdit(guildID, &edit) + if err != nil { + return fmt.Errorf("failed to edit guild settings in response to mute: %w", err) + + } + return nil +} diff --git a/pkg/connector/id.go b/pkg/connector/id.go new file mode 100644 index 0000000..1dd859f --- /dev/null +++ b/pkg/connector/id.go @@ -0,0 +1,56 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "github.com/bwmarrin/discordgo" + "maunium.net/go/mautrix/bridgev2/networkid" + + "go.mau.fi/mautrix-discord/pkg/discordid" +) + +func (d *DiscordClient) portalKeyForChannel(ch *discordgo.Channel) networkid.PortalKey { + switch ch.Type { + case discordgo.ChannelTypeDM: + return d.dmChannelPortalKey(ch.ID) + case discordgo.ChannelTypeGroupDM: + return d.groupDMChannelPortalKey(ch.ID) + default: + return d.guildChannelPortalKey(ch.ID) + } +} + +func (d *DiscordClient) guildChannelPortalKey(channelID string) networkid.PortalKey { + wantReceiver := d.connector.Bridge.Config.SplitPortals + return discordid.MakeChannelPortalKey(channelID, d.UserLogin.ID, wantReceiver) +} + +func (d *DiscordClient) groupDMChannelPortalKey(channelID string) networkid.PortalKey { + // Same logic as guild channels (only specify a receiver when split portals + // are enabled). + return d.guildChannelPortalKey(channelID) +} + +func (d *DiscordClient) dmChannelPortalKey(channelID string) networkid.PortalKey { + // 1:1 DMs should _always_ have a receiver. + return discordid.MakeChannelPortalKey(channelID, d.UserLogin.ID, true) +} + +func (d *DiscordClient) guildPortalKey(guildID string) networkid.PortalKey { + wantReceiver := d.connector.Bridge.Config.SplitPortals + return discordid.MakeGuildPortalKey(guildID, d.UserLogin.ID, wantReceiver) +} diff --git a/pkg/connector/login.go b/pkg/connector/login.go new file mode 100644 index 0000000..2322af5 --- /dev/null +++ b/pkg/connector/login.go @@ -0,0 +1,76 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + "fmt" + + "maunium.net/go/mautrix/bridgev2" +) + +const LoginStepIDComplete = "fi.mau.discord.login.complete" + +func (d *DiscordConnector) GetLoginFlows() []bridgev2.LoginFlow { + return []bridgev2.LoginFlow{ + { + ID: LoginFlowIDBrowser, + Name: "Browser", + Description: "Log in to your Discord account in a web browser.", + }, + { + ID: LoginFlowIDRemoteAuth, + Name: "QR Code", + Description: "Scan a QR code with the Discord mobile app to log in.", + }, + { + ID: LoginFlowIDToken, + Name: "Token", + Description: "Provide a Discord user token to connect with.", + }, + { + ID: LoginFlowIDMachine, + Name: "Email/Phone & Password", + Description: "Log in with an email or phone number and a password. Supports multi-factor authentication (e.g. TOTP, SMS, etc.)", + }, + } +} + +func (d *DiscordConnector) CreateLogin(ctx context.Context, user *bridgev2.User, flowID string) (bridgev2.LoginProcess, error) { + login := DiscordGenericLogin{ + connector: d, + User: user, + } + + switch flowID { + case LoginFlowIDToken: + return &DiscordTokenLogin{DiscordGenericLogin: &login}, nil + case LoginFlowIDRemoteAuth: + return &DiscordRemoteAuthLogin{DiscordGenericLogin: &login}, nil + case LoginFlowIDBrowser: + return &DiscordBrowserLogin{DiscordGenericLogin: &login}, nil + case LoginFlowIDMachine: + mach, err := NewDiscordMachineLogin(ctx, &login) + if err != nil { + return nil, fmt.Errorf("failed to set up discord login machine: %w", err) + } + + return mach, nil + default: + return nil, fmt.Errorf("unknown discord login flow id") + } +} diff --git a/pkg/connector/login_browser.go b/pkg/connector/login_browser.go new file mode 100644 index 0000000..43f2a9e --- /dev/null +++ b/pkg/connector/login_browser.go @@ -0,0 +1,97 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + "fmt" + + "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridgev2" +) + +const LoginFlowIDBrowser = "token" + +type DiscordBrowserLogin struct { + *DiscordGenericLogin +} + +var _ bridgev2.LoginProcessCookies = (*DiscordBrowserLogin)(nil) + +const ExtractDiscordTokenJS = ` +new Promise((resolve) => { + let mautrixDiscordTokenCheckInterval + + const iframe = document.createElement('iframe') + document.head.append(iframe) + + mautrixDiscordTokenCheckInterval = setInterval(() => { + const token = iframe.contentWindow.localStorage.token + if (token) { + resolve({ token: token.slice(1, -1) }) + clearInterval(mautrixDiscordTokenCheckInterval) + } + }, 200) +}) +` + +func (dl *DiscordBrowserLogin) Start(ctx context.Context) (*bridgev2.LoginStep, error) { + return &bridgev2.LoginStep{ + Type: bridgev2.LoginStepTypeCookies, + StepID: "fi.mau.discord.cookies", + Instructions: "Log in with Discord.", + CookiesParams: &bridgev2.LoginCookiesParams{ + URL: "https://discord.com/login", + UserAgent: "", + Fields: []bridgev2.LoginCookieField{{ + ID: "token", + Required: true, + Sources: []bridgev2.LoginCookieFieldSource{{ + Type: bridgev2.LoginCookieTypeSpecial, + Name: "fi.mau.discord.token", + }}, + }}, + ExtractJS: ExtractDiscordTokenJS, + }, + }, nil +} + +func (dl *DiscordBrowserLogin) SubmitCookies(ctx context.Context, cookies map[string]string) (*bridgev2.LoginStep, error) { + log := zerolog.Ctx(ctx) + + token := cookies["token"] + if token == "" { + log.Error().Msg("Received empty token") + return nil, fmt.Errorf("received empty token") + } + log.Debug().Msg("Logging in with submitted cookie") + + ul, err := dl.FinalizeCreatingLogin(ctx, token) + if err != nil { + return nil, fmt.Errorf("couldn't log in via browser: %w", err) + } + + return &bridgev2.LoginStep{ + Type: bridgev2.LoginStepTypeComplete, + StepID: LoginStepIDComplete, + Instructions: dl.CompleteInstructions(), + CompleteParams: &bridgev2.LoginCompleteParams{ + UserLoginID: ul.ID, + UserLogin: ul, + }, + }, nil +} diff --git a/pkg/connector/login_generic.go b/pkg/connector/login_generic.go new file mode 100644 index 0000000..b8b3fa7 --- /dev/null +++ b/pkg/connector/login_generic.go @@ -0,0 +1,104 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + "fmt" + + "github.com/bwmarrin/discordgo" + "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + + "go.mau.fi/mautrix-discord/pkg/discordid" +) + +// DiscordGenericLogin is embedded within each struct that implements +// bridgev2.LoginProcess in order to encapsulate the common behavior that needs +// to occur after procuring a valid user token. Namely, creating a gateway +// connection to Discord and an associated UserLogin to wrap things up. +// +// It also implements a baseline Cancel method that closes the gateway +// connection. +type DiscordGenericLogin struct { + User *bridgev2.User + connector *DiscordConnector + + Session *discordgo.Session + + // The Discord user we've authenticated as. This is only non-nil if + // a call to FinalizeCreatingLogin has succeeded. + DiscordUser *discordgo.User +} + +func (dl *DiscordGenericLogin) FinalizeCreatingLogin(ctx context.Context, token string) (*bridgev2.UserLogin, error) { + log := zerolog.Ctx(ctx).With().Str("action", "finalize login").Logger() + + // TODO we don't need an entire discordgo session for this as we're just + // interested in /users/@me + log.Info().Msg("Creating initial session with provided token") + session, err := NewDiscordSession(ctx, token) + if err != nil { + return nil, fmt.Errorf("couldn't create discord session: %w", err) + } + dl.Session = session + + log.Info().Msg("Requesting @me with provided token") + self, err := session.User("@me") + if err != nil { + return nil, fmt.Errorf("couldn't request self user (bad credentials?): %w", err) + } + dl.DiscordUser = self + + log.Info().Msg("Fetched @me") + ul, err := dl.User.NewLogin(ctx, &database.UserLogin{ + ID: discordid.MakeUserLoginID(self.ID), + // (This will lack an avatar. Don't want to block login finalization on + // downloading it.) + RemoteProfile: makeRemoteProfile(self, nil), + RemoteName: makeRemoteName(self), + Metadata: &discordid.UserLoginMetadata{ + Token: token, + HeartbeatSession: session.HeartbeatSession, + }, + }, &bridgev2.NewLoginParams{ + DeleteOnConflict: true, + }) + if err != nil { + dl.Cancel() + return nil, fmt.Errorf("couldn't create login during finalization: %w", err) + } + + (ul.Client.(*DiscordClient)).Connect(ctx) + + return ul, nil +} + +func (dl *DiscordGenericLogin) CompleteInstructions() string { + return fmt.Sprintf("Logged in as %s", dl.DiscordUser.Username) +} + +func (dl *DiscordGenericLogin) Cancel() { + if dl.Session != nil { + dl.User.Log.Debug().Msg("Login cancelled, closing session") + err := dl.Session.Close() + if err != nil { + dl.User.Log.Err(err).Msg("Couldn't close Discord session in response to login cancellation") + } + } +} diff --git a/pkg/connector/login_machine.go b/pkg/connector/login_machine.go new file mode 100644 index 0000000..5ac0b58 --- /dev/null +++ b/pkg/connector/login_machine.go @@ -0,0 +1,668 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "strings" + "sync" + + "github.com/bwmarrin/discordgo" + "github.com/google/uuid" + "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridgev2" + + "go.mau.fi/mautrix-discord/pkg/discordauth" +) + +const LoginFlowIDMachine = "machine" +const LoginStepIDMachineInitialCreds = "fi.mau.discord.creds" +const LoginStepIDMachineWait = "fi.mau.discord.wait" +const LoginStepIDMachineCaptcha = "fi.mau.discord.captcha" +const LoginStepIDMachineMFAMethod = "fi.mau.discord.mfa.method" +const LoginStepIDMachineMFATOTP = "fi.mau.discord.mfa.totp" +const LoginStepIDMachineMFABackup = "fi.mau.discord.mfa.backup" +const LoginStepIDMachineMFASMS = "fi.mau.discord.mfa.sms" +const InputDataFieldIDUsernameOrPhone = "username_or_phone" +const InputDataFieldIDPassword = "password" +const InputDataFieldIDMFAMethod = "mfa_method" +const InputDataFieldIDMFABackupCode = "backup_code" +const InputDataFieldIDMFASMSCode = "sms_code" +const InputDataFieldIDMFATOTPCode = "totp_code" + +type mfaOption string + +const ( + mfaSms mfaOption = "Text me a code" + mfaTotp mfaOption = "Use my authenticator app" + mfaBackup mfaOption = "Enter a backup code" +) + +// For simplicity, AuthMachine exposes a blocking, "straight-line" API: +// Prepare/Login do not yield intermediate preemption flows. Instead, they +// synchronously call back into our ChallengeHandler methods (e.g. ContinueMFA +// or SolveCaptcha) whenever user input is needed. CAPTCHA handling makes this +// especially awkward, as any request in the flow may be preempted by one or +// more CAPTCHA challenges before the original request can complete. This is +// documented in further detail in the discordauth package. +// +// Anyhow, bridgev2 is the opposite shape: login is step-based and +// request-scoped, and each provisioning request must return a LoginStep before +// its context is canceled. To bridge that mismatch, AuthMachine runs on a +// long-lived background goroutine. That worker emits signals such as "prompt +// the user", "login complete", or "login failed", and DiscordMachineLogin +// translates them into bridgev2 steps. User replies are then forwarded back to +// the worker so the synchronous AuthMachine flow can continue. Channels are +// used to bridge the gap. +// +// In practice, this means returning a dummy DisplayAndWait step to hand +// control back to bridgev2 as our Wait method drains the next signal. +// CAPTCHA challenges reuse this plumbing via LoginStepTypeCookies, dispatching +// through SubmitCookies. + +type DiscordMachineLogin struct { + *DiscordGenericLogin + Machine *discordauth.AuthMachine + + machineCtx context.Context + cancelMachine context.CancelFunc + + currentlyPending *pendingPrompt + currentlyPendingMu sync.Mutex + + signals chan machineSignal +} + +type machineSignal struct { + prompt *pendingPrompt + done *discordauth.LoginCompleted + err error +} +type pendingPrompt struct { + step *bridgev2.LoginStep + reply chan map[string]string +} + +var _ discordauth.ChallengeHandler = (*DiscordMachineLogin)(nil) +var _ bridgev2.LoginProcessUserInput = (*DiscordMachineLogin)(nil) +var _ bridgev2.LoginProcessCookies = (*DiscordMachineLogin)(nil) +var _ bridgev2.LoginProcessDisplayAndWait = (*DiscordMachineLogin)(nil) + +func NewDiscordMachineLogin(ctx context.Context, login *DiscordGenericLogin) (*DiscordMachineLogin, error) { + http := login.User.Bridge.GetHTTPClientSettings().Compile() + + launchSig, err := discordgo.NewVanillaSignature() + if err != nil { + return nil, fmt.Errorf("failed to generate launch signature: %w", err) + } + + personality := discordauth.Personality{ + UserAgent: discordgo.DroidBrowserUserAgent, + Locale: "en-US", + TimeZone: "UTC", + DebugOptions: discordauth.DefaultDebugOptions, + // TODO dedupe with droid.go in discordgo + SuperProperties: discordauth.SuperProperties{ + OS: "Windows", + Browser: "Chrome", + SystemLocale: "en-US", + HasClientMods: false, + BrowserUserAgent: discordgo.DroidBrowserUserAgent, + BrowserVersion: discordgo.DroidBrowserVersion, + OSVersion: "10", + ReleaseChannel: "stable", + ClientBuildNumber: 497254, + ClientLaunchID: uuid.NewString(), + LaunchSignature: launchSig, + ClientAppState: "focused", + }, + ExtraHeaders: map[string]string{ + "Sec-Fetch-Dest": "empty", + "Sec-Fetch-Mode": "cors", + "Sec-Fetch-Site": "same-origin", + }, + } + + ml := &DiscordMachineLogin{ + DiscordGenericLogin: login, + } + ml.Machine = discordauth.NewAuthMachine(ctx, http, &personality, ml) + return ml, nil +} + +func (d *DiscordMachineLogin) ContinueMFA(ctx context.Context, challenge *discordauth.MFAChallenge) (*discordauth.MFAContinue, error) { + log := zerolog.Ctx(ctx).With(). + Str("action", "discord machine continue mfa"). + Str("login_instance_id", challenge.LoginInstanceID). + Bool("mfa_required", challenge.MFARequired). + Bool("mfa_sms_enabled", challenge.SMSEnabled). + Bool("mfa_totp_enabled", challenge.TOTPEnabled). + Bool("mfa_backup_codes_accepted", challenge.BackupCodesAccepted). + Logger() + ctx = log.WithContext(ctx) + + log.Info().Msg("Entering MFA login flow") + + mfaOptions := make([]string, 0) + // (Reusing the identifier strings for each authenticator method from + // discordauth as the option enumeration values for the user prompt.) + if challenge.SMSEnabled { + mfaOptions = append(mfaOptions, string(mfaSms)) + } + if challenge.TOTPEnabled { + mfaOptions = append(mfaOptions, string(mfaTotp)) + } + if challenge.BackupCodesAccepted { + mfaOptions = append(mfaOptions, string(mfaBackup)) + } + + if len(mfaOptions) == 0 { + return nil, fmt.Errorf("no supported MFA methods available (WebAuthn is unimplemented)") + } + + input, err := d.promptUser(ctx, &bridgev2.LoginStep{ + Type: bridgev2.LoginStepTypeUserInput, + StepID: LoginStepIDMachineMFAMethod, + Instructions: "How do you want to verify it’s you?", + UserInputParams: &bridgev2.LoginUserInputParams{ + Fields: []bridgev2.LoginInputDataField{ + { + Type: bridgev2.LoginInputFieldTypeSelect, + ID: InputDataFieldIDMFAMethod, + Name: "Verification Method", + Options: mfaOptions, + }, + }, + }, + }) + if err != nil { + return nil, fmt.Errorf("failed to prompt for MFA method: %w", err) + } + + selectedMethod := mfaOption(input[InputDataFieldIDMFAMethod]) + + log = log.With().Str("mfa_selected_method", string(selectedMethod)).Logger() + ctx = log.WithContext(ctx) + + log.Info().Msg("User selected MFA method") + + switch selectedMethod { + case mfaBackup: + input, err := d.promptUser(ctx, &bridgev2.LoginStep{ + Type: bridgev2.LoginStepTypeUserInput, + StepID: LoginStepIDMachineMFABackup, + Instructions: "If your authenticator app is unavailable, you can sign in with a backup code. Backup codes are meant for emergencies only.", + UserInputParams: &bridgev2.LoginUserInputParams{ + Fields: []bridgev2.LoginInputDataField{ + { + Type: bridgev2.LoginInputFieldTypePassword, + ID: InputDataFieldIDMFABackupCode, + Name: "Backup code", + Description: "You won’t be able to use this backup code again.", + }, + }, + }, + }) + if err != nil { + return nil, fmt.Errorf("failed to prompt user for backup code: %w", err) + } + log.Info().Msg("Received backup code from user, proceeding") + + backupCode := strings.TrimSpace(strings.ReplaceAll( + input[InputDataFieldIDMFABackupCode], + "-", + "", + )) + return &discordauth.MFAContinue{ + Type: discordauth.AuthenticatorBackup, + MFAContinuation: discordauth.MFAContinuation{ + MFAState: challenge.MFAState, + Code: backupCode, + }, + }, nil + case mfaTotp: + input, err := d.promptUser(ctx, &bridgev2.LoginStep{ + Type: bridgev2.LoginStepTypeUserInput, + StepID: LoginStepIDMachineMFATOTP, + Instructions: "Enter the code from your authenticator app.", + UserInputParams: &bridgev2.LoginUserInputParams{ + Fields: []bridgev2.LoginInputDataField{ + { + Type: bridgev2.LoginInputFieldType2FACode, + ID: InputDataFieldIDMFATOTPCode, + Name: "Authentication code", + // TODO enforce length + Pattern: `^(\d+)$`, + }, + }, + }, + }) + if err != nil { + return nil, fmt.Errorf("failed to prompt user for TOTP code: %w", err) + } + log.Info().Msg("Received TOTP code from user, proceeding") + + totpCode := input[InputDataFieldIDMFATOTPCode] + return &discordauth.MFAContinue{ + Type: discordauth.AuthenticatorTOTP, + MFAContinuation: discordauth.MFAContinuation{ + MFAState: challenge.MFAState, + Code: totpCode, + }, + }, nil + case mfaSms: + log.Info().Msg("Requesting SMS from Discord") + _, err := challenge.RequestSMS(ctx) + + if err != nil { + log.Err(err).Msg("Failed to request SMS from Discord") + return nil, fmt.Errorf("failed to ask discord to send SMS: %w", err) + } + log.Info().Msg("Requested SMS from Discord") + + input, err := d.promptUser(ctx, &bridgev2.LoginStep{ + Type: bridgev2.LoginStepTypeUserInput, + StepID: LoginStepIDMachineMFASMS, + Instructions: "Enter the code Discord just texted you.", + UserInputParams: &bridgev2.LoginUserInputParams{ + Fields: []bridgev2.LoginInputDataField{ + { + Description: "The code might take a moment to arrive.", + ID: InputDataFieldIDMFASMSCode, + Name: "Verification code", + // TODO enforce length + Pattern: `^(\d+)$`, + Type: bridgev2.LoginInputFieldType2FACode, + }, + }, + }, + }) + if err != nil { + return nil, fmt.Errorf("failed to prompt user for SMS code: %w", err) + } + smsCode := strings.TrimSpace(input[InputDataFieldIDMFASMSCode]) + log.Info().Msg("Received SMS code from user, proceeding") + + return &discordauth.MFAContinue{ + Type: discordauth.AuthenticatorSMS, + MFAContinuation: discordauth.MFAContinuation{ + MFAState: challenge.MFAState, + Code: smsCode, + }, + }, nil + default: + return nil, fmt.Errorf("unknown mfa method %v", selectedMethod) + } +} + +type ExtractionConfig struct { + SiteKey string `json:"siteKey"` + Invisible bool `json:"invisible"` + RqData string `json:"rqdata,omitempty"` +} + +const CaptchaExtractionField = "captcha_token" + +// FIXME: This redirection stub is only necessary to work around some behavior in +// Beeper Desktop where it only attaches the event listeners that dispatch the +// injected JS snippets after the page loads completely. We can't run JavaScript +// "upon load", which is what we really want here. To get around that, we can +// load a small page that merely forces a redirection to the right origin. +// +// (The exact Discord URL we end up at here is mostly irrelevant, but it would +// be nice to avoid loading the actual SPA.) +const captchaRedirectionStub = ` +Loading +` +const captchaExtractionJSTemplate = `new Promise((res0, rej0) => { + if (!window.location.hostname.endsWith('discord.com')) { + return + } + if (window.__meow_captchaPromise) { + window.__meow_captchaPromise.then(res0, rej0) + return + } + + const CFG = %__CONFIG_REPLACEME__% + window.__meow_captchaPromise = new Promise((resolve, reject) => { + window.__meow_h = () => { + const c = document.createElement('div') + c.style.cssText = 'position:fixed;inset:0;z-index:2147483646;' + + 'background:#fff;display:flex;align-items:center;' + + 'justify-content:center;padding:2rem' + document.body.append(c) + + const id = hcaptcha.render(c, { + sitekey: CFG.siteKey, + size: CFG.invisible ? 'invisible' : 'normal', + callback: (token) => resolve({ captcha_token: token }), + 'error-callback': (e) => reject(new Error('hcaptcha: ' + e)), + 'expired-callback': () => reject(new Error('hcaptcha token expired')), + 'chalexpired-callback': () => reject(new Error('hcaptcha challenge expired')), + }) + + if (CFG.rqdata) { + hcaptcha.setData(id, {rqdata: CFG.rqdata}) + } + if (CFG.invisible) { + hcaptcha.execute(id) + } + } + + const s = document.createElement('script') + s.src = 'https://js.hcaptcha.com/1/api.js?render=explicit&onload=__meow_h&recaptchacompat=off' + s.onerror = () => reject(new Error('failed to load hcaptcha')) + document.head.append(s) + }) + + window.__meow_captchaPromise.then(res0, rej0) +})` + +func captchaExtractionJS(cap *discordauth.Captcha) (string, error) { + cfg := ExtractionConfig{ + Invisible: cap.Invisible, + } + if cap.SiteKey != nil { + cfg.SiteKey = *cap.SiteKey + } + if cap.RqData != nil { + cfg.RqData = *cap.RqData + } + + stateJSON, err := json.Marshal(cfg) + if err != nil { + return "", fmt.Errorf("failed to marshal extraction state: %w", err) + } + + return strings.Replace(captchaExtractionJSTemplate, "%__CONFIG_REPLACEME__%", string(stateJSON), 1), nil +} + +func (d *DiscordMachineLogin) SolveCaptcha(ctx context.Context, cap *discordauth.Captcha) (*discordauth.CaptchaSolution, error) { + log := cap.LogContext(zerolog.Ctx(ctx).With()).Logger() + ctx = log.WithContext(ctx) + + log.Info().Msg("Encountered CAPTCHA challenge") + + if cap.Service != discordauth.CaptchaServiceHCaptcha { + return nil, fmt.Errorf("%s captchas are currently unsupported", cap.Service) + } + + extractJS, err := captchaExtractionJS(cap) + if err != nil { + return nil, fmt.Errorf("failed to compute captcha extraction JS: %w", err) + } + log.Debug().Str("captcha_js", extractJS).Msg("Computed CAPTCHA solution extraction JS") + + dataURL := "data:text/html;base64," + base64.StdEncoding.EncodeToString([]byte(captchaRedirectionStub)) + + input, err := d.promptUser(ctx, &bridgev2.LoginStep{ + Type: bridgev2.LoginStepTypeCookies, + StepID: LoginStepIDMachineCaptcha, + Instructions: "Discord is presenting a CAPTCHA challenge.", + CookiesParams: &bridgev2.LoginCookiesParams{ + URL: dataURL, + ExtractJS: extractJS, + Fields: []bridgev2.LoginCookieField{{ + ID: CaptchaExtractionField, + Required: true, + Sources: []bridgev2.LoginCookieFieldSource{{ + Type: bridgev2.LoginCookieTypeSpecial, + Name: CaptchaExtractionField, + }}, + }}, + }, + }) + if err != nil { + return nil, fmt.Errorf("failed to prompt user to solve captcha: %w", err) + } + + solutionToken := input[CaptchaExtractionField] + if solutionToken == "" { + return nil, fmt.Errorf("extracted captcha solution is blank") + } + + return &discordauth.CaptchaSolution{ + Solution: solutionToken, + }, nil +} + +func (d *DiscordMachineLogin) Cancel() { + d.DiscordGenericLogin.Cancel() + if d.cancelMachine != nil { + d.cancelMachine() + } +} + +func credsStep() *bridgev2.LoginStep { + return &bridgev2.LoginStep{ + Type: bridgev2.LoginStepTypeUserInput, + StepID: LoginStepIDMachineInitialCreds, + UserInputParams: &bridgev2.LoginUserInputParams{ + Fields: []bridgev2.LoginInputDataField{ + { + Type: bridgev2.LoginInputFieldTypeUsername, + ID: InputDataFieldIDUsernameOrPhone, + Name: "Email or phone number", + }, + { + Type: bridgev2.LoginInputFieldTypePassword, + ID: InputDataFieldIDPassword, + Name: "Password", + }, + }, + }, + } +} + +func waitStep() *bridgev2.LoginStep { + return &bridgev2.LoginStep{ + Type: bridgev2.LoginStepTypeDisplayAndWait, + StepID: LoginStepIDMachineWait, + Instructions: "Waiting for Discord…", + DisplayAndWaitParams: &bridgev2.LoginDisplayAndWaitParams{ + Type: bridgev2.LoginDisplayTypeNothing, + }, + } +} + +func (d *DiscordMachineLogin) Start(ctx context.Context) (*bridgev2.LoginStep, error) { + return credsStep(), nil +} + +func (d *DiscordMachineLogin) SubmitCookies(ctx context.Context, cookies map[string]string) (*bridgev2.LoginStep, error) { + return d.tryDrainingPendingPrompt(ctx, cookies), nil +} + +func (d *DiscordMachineLogin) SubmitUserInput(ctx context.Context, input map[string]string) (*bridgev2.LoginStep, error) { + log := zerolog.Ctx(ctx) + + // User input was submitted as part of a prompt that the worker signaled to + // us. + step := d.tryDrainingPendingPrompt(ctx, input) + if step != nil { + return step, nil + } + + // Initial submission of the username/phone and password. + username := strings.TrimSpace(input[InputDataFieldIDUsernameOrPhone]) + password := discordauth.NewSensitive(input[InputDataFieldIDPassword]) + if username == "" { + return nil, fmt.Errorf("no username provided") + } + if password.IsZero() { + return nil, fmt.Errorf("no password provided") + } + + log.Info().Msg("Starting worker goroutine") + err := d.startWorker(ctx, &discordauth.Creds{ + Login: username, + Password: password, + }) + if err != nil { + return nil, fmt.Errorf("failed to start login worker: %w", err) + } + + return waitStep(), nil +} + +func (d *DiscordMachineLogin) tryDrainingPendingPrompt(ctx context.Context, input map[string]string) *bridgev2.LoginStep { + log := zerolog.Ctx(ctx) + + d.currentlyPendingMu.Lock() + // (Avoid holding the mutex across the channel send.) + pending := d.currentlyPending + d.currentlyPending = nil + d.currentlyPendingMu.Unlock() + + if pending == nil { + log.Debug().Msg("No pending prompt") + return nil + } + + log.Info().Str("pending_step_id", pending.step.StepID). + Msg("Received user input for pending step ID, sending reply") + pending.reply <- input + + // Go back to waiting for the worker to send a signal. + return waitStep() +} + +func (d *DiscordMachineLogin) startWorker(ctx context.Context, creds *discordauth.Creds) error { + // Act as a sort of "mailbox"; only buffer 1 signal at a time. Not + // unbuffered because it wouldn't be ideal to block the worker goroutine on + // waiting for the signal to be "consumed" per se. + d.signals = make(chan machineSignal, 1) + + // Don't want ourselves to get cancelled if the enclosing context does, but + // we do want to preserve the data inside of the context (such as logging + // stuff). + // + // Also, shadow the original context to avoid using it by accident. + ctx, d.cancelMachine = context.WithCancel(context.WithoutCancel(ctx)) + d.machineCtx = ctx + + go func() { + // It's important that these calls occur on a goroutine because + // AuthMachine methods can call into our handlers (e.g. ContinueMFA), + // which need to synchronously prompt the user, and we need both sides + // of the reply/signal channels to work in order to avoid a deadlock. + + err := d.Machine.Prepare(ctx) + if err != nil { + err = fmt.Errorf("failed to prepare login: %w", err) + _ = d.signal(d.machineCtx, machineSignal{err: err}) + return + } + + done, err := d.Machine.Login(ctx, creds) + log := zerolog.Ctx(ctx) + if err == nil { + log.Info(). + Any("required_actions", done.RequiredActions). + Msg("Login finished") + } else { + // FIXME detect bad password/username and just retry the step + // instead of failing out + log.Err(err).Msg("Login failed") + } + + // At the moment this can only error if we get canceled, and we don't + // really care about that here. Just signal so we can tell bridgev2. + _ = d.signal(d.machineCtx, machineSignal{done: done, err: err}) + }() + + return nil +} + +// signal should only be called by the background goroutine, and is used to +// control the bridgev2 login process. +func (d *DiscordMachineLogin) signal(ctx context.Context, sig machineSignal) error { + select { + case d.signals <- sig: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// promptUser should only be called by the background goroutine, and is used to +// send a [bridgev2.LoginStep] to be presented to the user. The submitted +// inputs are collected via channel and returned. +func (d *DiscordMachineLogin) promptUser(ctx context.Context, step *bridgev2.LoginStep) (map[string]string, error) { + reply := make(chan map[string]string, 1) + pending := &pendingPrompt{step, reply} + if err := d.signal(ctx, machineSignal{prompt: pending}); err != nil { + return nil, err + } + + select { + case input, ok := <-pending.reply: + if !ok { + return nil, context.Canceled + } + return input, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (d *DiscordMachineLogin) finalize(ctx context.Context, done *discordauth.LoginCompleted) (*bridgev2.LoginStep, error) { + ul, err := d.FinalizeCreatingLogin(ctx, done.Token.UnwrapSensitive()) + if err != nil { + return nil, fmt.Errorf("couldn't log in via machine: %w", err) + } + + return &bridgev2.LoginStep{ + Type: bridgev2.LoginStepTypeComplete, + StepID: LoginStepIDComplete, + CompleteParams: &bridgev2.LoginCompleteParams{ + UserLoginID: ul.ID, + UserLogin: ul, + }, + }, nil +} + +func (d *DiscordMachineLogin) Wait(ctx context.Context) (*bridgev2.LoginStep, error) { + select { + case signal := <-d.signals: + if signal.err != nil { + return nil, signal.err + } + + if signal.done != nil { + return d.finalize(ctx, signal.done) + } + + // Sanity check. + if signal.prompt == nil { + return nil, fmt.Errorf("unexpected empty prompt") + } + + // Stash the prompt that we're about to show to the user so that we + // can properly reply when mautrix calls our SubmitUserInput method. + d.currentlyPendingMu.Lock() + d.currentlyPending = signal.prompt + d.currentlyPendingMu.Unlock() + + return signal.prompt.step, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} diff --git a/pkg/connector/login_remoteauth.go b/pkg/connector/login_remoteauth.go new file mode 100644 index 0000000..e116090 --- /dev/null +++ b/pkg/connector/login_remoteauth.go @@ -0,0 +1,141 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + "fmt" + + "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridgev2" + + "go.mau.fi/mautrix-discord/pkg/remoteauth" +) + +const LoginFlowIDRemoteAuth = "qr" + +type DiscordRemoteAuthLogin struct { + *DiscordGenericLogin + + hasClosed bool + remoteAuthClient *remoteauth.Client + qrChan chan string + doneChan chan struct{} +} + +var _ bridgev2.LoginProcessDisplayAndWait = (*DiscordRemoteAuthLogin)(nil) + +func (dl *DiscordRemoteAuthLogin) Start(ctx context.Context) (*bridgev2.LoginStep, error) { + log := zerolog.Ctx(ctx) + + log.Debug().Msg("Creating new remoteauth client") + client, err := remoteauth.New() + if err != nil { + return nil, fmt.Errorf("couldn't create Discord remoteauth client: %w", err) + } + + dl.remoteAuthClient = client + + dl.qrChan = make(chan string) + dl.doneChan = make(chan struct{}) + + log.Info().Msg("Starting the QR code login process") + err = client.Dial(ctx, dl.qrChan, dl.doneChan) + if err != nil { + log.Err(err).Msg("Couldn't connect to Discord remoteauth websocket") + close(dl.qrChan) + close(dl.doneChan) + return nil, fmt.Errorf("couldn't connect to Discord remoteauth websocket: %w", err) + } + + log.Info().Msg("Waiting for QR code to be ready") + + select { + case qrCode := <-dl.qrChan: + log.Info().Int("qr_code_data_len", len(qrCode)).Msg("Received QR code, creating login step") + + return &bridgev2.LoginStep{ + Type: bridgev2.LoginStepTypeDisplayAndWait, + StepID: "fi.mau.discord.qr", + Instructions: "On your phone, find “Scan QR Code” in Discord’s settings.", + DisplayAndWaitParams: &bridgev2.LoginDisplayAndWaitParams{ + Type: bridgev2.LoginDisplayTypeQR, + Data: qrCode, + }, + }, nil + case <-ctx.Done(): + log.Debug().Msg("Cancelled while waiting for QR code") + return nil, nil + } +} + +// Wait implements bridgev2.LoginProcessDisplayAndWait. +func (dl *DiscordRemoteAuthLogin) Wait(ctx context.Context) (*bridgev2.LoginStep, error) { + if dl.doneChan == nil { + panic("can't wait for discord remoteauth without a doneChan") + } + + log := zerolog.Ctx(ctx) + + log.Debug().Msg("Waiting for remoteauth") + select { + case <-dl.doneChan: + user, err := dl.remoteAuthClient.Result() + if err != nil { + log.Err(err).Msg("Discord remoteauth failed") + return nil, fmt.Errorf("discord remoteauth failed: %w", err) + } + log.Debug().Msg("Discord remoteauth succeeded") + + return dl.finalizeSuccessfulLogin(ctx, user) + case <-ctx.Done(): + log.Debug().Msg("Cancelled while waiting for remoteauth to complete") + return nil, nil + } +} + +func (dl *DiscordRemoteAuthLogin) finalizeSuccessfulLogin(ctx context.Context, user remoteauth.User) (*bridgev2.LoginStep, error) { + ul, err := dl.FinalizeCreatingLogin(ctx, user.Token) + if err != nil { + return nil, fmt.Errorf("couldn't log in via remoteauth: %w", err) + } + + return &bridgev2.LoginStep{ + Type: bridgev2.LoginStepTypeComplete, + StepID: LoginStepIDComplete, + Instructions: dl.CompleteInstructions(), + CompleteParams: &bridgev2.LoginCompleteParams{ + UserLoginID: ul.ID, + UserLogin: ul, + }, + }, nil +} + +func (dl *DiscordRemoteAuthLogin) Cancel() { + // Tolerate multiple attempts to cancel. + if dl.hasClosed { + return + } + dl.hasClosed = true + + dl.User.Log.Debug().Msg("Discord remoteauth cancelled") + dl.DiscordGenericLogin.Cancel() + + // remoteauth.Client doesn't seem to expose a cancellation method. + close(dl.doneChan) + close(dl.qrChan) +} diff --git a/pkg/connector/login_token.go b/pkg/connector/login_token.go new file mode 100644 index 0000000..1d234ae --- /dev/null +++ b/pkg/connector/login_token.go @@ -0,0 +1,72 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + "fmt" + + "maunium.net/go/mautrix/bridgev2" +) + +const LoginFlowIDToken = "DEBUG_USERINPUT_token" + +type DiscordTokenLogin struct { + *DiscordGenericLogin +} + +var _ bridgev2.LoginProcessUserInput = (*DiscordTokenLogin)(nil) + +func (dl *DiscordTokenLogin) Start(ctx context.Context) (*bridgev2.LoginStep, error) { + return &bridgev2.LoginStep{ + Type: bridgev2.LoginStepTypeUserInput, + StepID: "fi.mau.discord.enter_token", + UserInputParams: &bridgev2.LoginUserInputParams{ + Fields: []bridgev2.LoginInputDataField{ + { + Type: bridgev2.LoginInputFieldTypePassword, + ID: "token", + Name: "Discord user account token", + // Cribbed from https://regex101.com/r/1GMR0y/1. + Pattern: `^(mfa\.[a-zA-Z0-9_-]{20,})|([a-zA-Z0-9_-]{23,}\.[a-zA-Z0-9_-]{6,7}\.[a-zA-Z0-9_-]{27,})$`, + }, + }, + }, + }, nil +} + +func (dl *DiscordTokenLogin) SubmitUserInput(ctx context.Context, input map[string]string) (*bridgev2.LoginStep, error) { + token := input["token"] + if token == "" { + return nil, fmt.Errorf("no token provided") + } + + ul, err := dl.FinalizeCreatingLogin(ctx, token) + if err != nil { + return nil, fmt.Errorf("couldn't login from token: %w", err) + } + + return &bridgev2.LoginStep{ + Type: bridgev2.LoginStepTypeComplete, + StepID: LoginStepIDComplete, + Instructions: dl.CompleteInstructions(), + CompleteParams: &bridgev2.LoginCompleteParams{ + UserLoginID: ul.ID, + UserLogin: ul, + }, + }, nil +} diff --git a/pkg/connector/provisioning.go b/pkg/connector/provisioning.go new file mode 100644 index 0000000..ea4c6fb --- /dev/null +++ b/pkg/connector/provisioning.go @@ -0,0 +1,464 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "strings" + + "github.com/bwmarrin/discordgo" + "github.com/rs/zerolog" + "go.mau.fi/util/exhttp" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2" + + "go.mau.fi/mautrix-discord/pkg/discordid" +) + +const ( + ErrCodeNotConnected = "FI.MAU.DISCORD.NOT_CONNECTED" + ErrCodeAlreadyLoggedIn = "FI.MAU.DISCORD.ALREADY_LOGGED_IN" + ErrCodeAlreadyConnected = "FI.MAU.DISCORD.ALREADY_CONNECTED" + ErrCodeConnectFailed = "FI.MAU.DISCORD.CONNECT_FAILED" + ErrCodeDisconnectFailed = "FI.MAU.DISCORD.DISCONNECT_FAILED" + ErrCodeGuildBridgeFailed = "M_UNKNOWN" + ErrCodeGuildUnbridgeFailed = "M_UNKNOWN" + ErrCodeGuildNotBridged = "FI.MAU.DISCORD.GUILD_NOT_BRIDGED" + ErrCodeLoginPrepareFailed = "FI.MAU.DISCORD.LOGIN_PREPARE_FAILED" + ErrCodeLoginConnectionFailed = "FI.MAU.DISCORD.LOGIN_CONN_FAILED" + ErrCodeLoginFailed = "FI.MAU.DISCORD.LOGIN_FAILED" + ErrCodePostLoginConnFailed = "FI.MAU.DISCORD.POST_LOGIN_CONNECTION_FAILED" +) + +type ProvisioningAPI struct { + log zerolog.Logger + connector *DiscordConnector + prov bridgev2.IProvisioningAPI +} + +func (d *DiscordConnector) setUpProvisioningAPIs() error { + c, ok := d.Bridge.Matrix.(bridgev2.MatrixConnectorWithProvisioning) + if !ok { + return errors.New("matrix connector doesn't support provisioning; not setting up") + } + + prov := c.GetProvisioning() + r := prov.GetRouter() + if r == nil { + return errors.New("matrix connector's provisioning api didn't return a router") + } + + log := d.Bridge.Log.With().Str("component", "provisioning").Logger() + p := &ProvisioningAPI{ + connector: d, + log: log, + prov: prov, + } + + // NOTE: aim to provide backwards compatibility with v1 provisioning APIs + r.HandleFunc("POST /v1/login/token", p.legacyTokenLogin) + r.HandleFunc("GET /v1/ping", p.legacyPing) + r.HandleFunc("POST /v1/logout", p.legacyLogout) + r.HandleFunc("GET /v1/guilds", p.makeHandler(p.guildsList, true)) + r.HandleFunc("POST /v1/guilds/{guildID}", p.makeHandler(p.bridgeGuild, true)) + // Unbridging doesn't touch discordgo, so it's okay to do it even when + // logged out. + r.HandleFunc("DELETE /v1/guilds/{guildID}", p.makeHandler(p.unbridgeGuild, false)) + + return nil +} + +type provHandler func(http.ResponseWriter, *http.Request, *bridgev2.UserLogin, *DiscordClient) + +func (p *ProvisioningAPI) makeHandler(handler provHandler, enforceLoggedIn bool) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + user := p.prov.GetUser(r) + + logins := user.GetUserLogins() + if len(logins) < 1 { + mautrix.RespError{ + ErrCode: ErrCodeNotConnected, + Err: "user has no logins", + }.Write(w) + return + } + + login := logins[0] + client := login.Client.(*DiscordClient) + + if !client.IsLoggedIn() && enforceLoggedIn { + mautrix.RespError{ + ErrCode: ErrCodeNotConnected, + Err: "not logged in to discord", + }.Write(w) + return + } + + handler(w, r, login, client) + } +} + +type guildEntry struct { + ID string `json:"id"` + Name string `json:"name"` + // TODO v1 uses `id.ContentURI` whereas we stuff the discord cdn url here + AvatarURL string `json:"avatar_url"` + + // new in v2: + Bridged bool `json:"bridged"` + Available bool `json:"available"` + + // legacy fields from v1: + MXID string `json:"mxid"` + AutoBridge bool `json:"auto_bridge_channels"` + BridgingMode string `json:"bridging_mode"` +} +type respGuildsList struct { + Guilds []guildEntry `json:"guilds"` +} + +func (p *ProvisioningAPI) guildsList(w http.ResponseWriter, r *http.Request, login *bridgev2.UserLogin, client *DiscordClient) { + ctx := r.Context() + p.log.Info().Str("login_id", discordid.ParseUserLoginID(login.ID)).Msg("guilds list requested via provisioning api") + + bridgedGuildIDs := client.bridgedGuildIDs() + + var resp respGuildsList + resp.Guilds = []guildEntry{} + for _, guild := range client.Session.State.Guilds { + portalKey := client.guildPortalKey(guild.ID) + portal, err := p.connector.Bridge.GetExistingPortalByKey(ctx, portalKey) + if err != nil { + p.log.Err(err). + Str("guild_id", guild.ID). + Msg("Failed to get guild portal for provisioning list") + } + + _, beingBridged := bridgedGuildIDs[guild.ID] + mxid := "" + if portal != nil && portal.MXID != "" { + mxid = portal.MXID.String() + } else if beingBridged { + // Beeper Desktop expects the space to exist by the time it receives + // our HTTP response. If it doesn't, then the space won't appear + // until the app is reloaded, and the toggle in the user interface + // won't respond to the user's click. + // + // Pre-bridgev2, we synchronously bridged guilds. However, this + // might take a while for guilds with many channels. + // + // To solve this, generate a deterministic room ID to use as the + // MXID so that it recognizes the guild as bridged, even if the + // portals haven't been created just yet. This lets us + // asynchronously bridge guilds while keeping the UI responsive. + mxid = p.connector.Bridge.Matrix.GenerateDeterministicRoomID(portalKey).String() + } + + resp.Guilds = append(resp.Guilds, guildEntry{ + // For now, have the ID exactly correspond to the portal ID. This + // practically means that the ID will begin with an asterisk (the + // guild portal ID sigil). + // + // Otherwise, Beeper Desktop will show a duplicate space for every + // guild, as it recognizes the guild returned from this HTTP + // endpoint and the actual space itself as separate "entities". + // (Despite this, they point to identical rooms.) + ID: string(discordid.MakeGuildPortalIDWithID(guild.ID)), + Name: guild.Name, + AvatarURL: discordgo.EndpointGuildIcon(guild.ID, guild.Icon), + Bridged: beingBridged, + Available: !guild.Unavailable, + + // v1 (legacy) backwards compat: + MXID: mxid, + AutoBridge: beingBridged, + BridgingMode: "everything", + }) + } + + exhttp.WriteJSONResponse(w, 200, resp) +} + +// normalizeGuildID removes the guild portal sigil from a guild ID if it's +// there. +// +// This helps facilitate code that would like to accept portal keys +// corresponding to guilds as well as plain Discord guild IDs. +func normalizeGuildID(guildID string) string { + return strings.TrimPrefix(guildID, discordid.GuildPortalKeySigil) +} + +// collectAllGuildPortals fetches all portals associated with a guild. This +// includes the guild space portal itself as well as child portals inside of +// portals that represent guild category channels. +// +// The order of the returned slice is undefined. +func (p *ProvisioningAPI) collectAllGuildPortals(ctx context.Context, guild *bridgev2.Portal) ([]*bridgev2.Portal, error) { + if guild == nil { + return nil, nil + } + + // Fetch all top-level channels and category channels. + children, err := p.connector.Bridge.GetChildPortals(ctx, guild.PortalKey) + if err != nil { + return nil, err + } + + portals := make([]*bridgev2.Portal, 0, 1+len(children)) + portals = append(portals, guild) + portals = append(portals, children...) + + // Fetch channels that are inside of categories. + for _, child := range children { + grandchildren, err := p.connector.Bridge.GetChildPortals(ctx, child.PortalKey) + if err != nil { + return nil, err + } + portals = append(portals, grandchildren...) + } + + return portals, nil +} + +func (p *ProvisioningAPI) bridgeGuild(w http.ResponseWriter, r *http.Request, login *bridgev2.UserLogin, client *DiscordClient) { + guildID := normalizeGuildID(r.PathValue("guildID")) + if guildID == "" { + mautrix.MInvalidParam.WithMessage("no guild id").Write(w) + return + } + + p.log.Info(). + Str("login_id", discordid.ParseUserLoginID(login.ID)). + Str("guild_id", guildID). + Msg("requested to bridge guild via provisioning api") + + meta := login.Metadata.(*discordid.UserLoginMetadata) + + if meta.BridgedGuildIDs == nil { + meta.BridgedGuildIDs = map[string]bool{} + } + _, alreadyBridged := meta.BridgedGuildIDs[guildID] + meta.BridgedGuildIDs[guildID] = true + + if err := login.Save(r.Context()); err != nil { + p.log.Err(err).Msg("Failed to save login after guild bridge request") + mautrix.MUnknown.WithMessage("failed to save login: %v", err).Write(w) + return + } + + go client.syncGuild(p.connector.Bridge.BackgroundCtx, guildID) + + responseStatus := 201 + if alreadyBridged { + responseStatus = 200 + } + exhttp.WriteJSONResponse(w, responseStatus, nil) +} + +// Legacy v1 provisioning endpoints for backwards compatibility with clients +// that haven't migrated to the bridgev2 provisioning API yet. + +func (p *ProvisioningAPI) legacyTokenLogin(w http.ResponseWriter, r *http.Request) { + user := p.prov.GetUser(r) + + if logins := user.GetUserLogins(); len(logins) > 0 { + for _, login := range logins { + client := login.Client.(*DiscordClient) + if client.HasToken() { + exhttp.WriteJSONResponse(w, http.StatusConflict, mautrix.RespError{ + ErrCode: ErrCodeAlreadyLoggedIn, + Err: "already logged in to Discord", + }) + return + } + } + } + + var body struct { + Token string `json:"token"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + exhttp.WriteJSONResponse(w, http.StatusBadRequest, mautrix.RespError{ + ErrCode: mautrix.MBadJSON.ErrCode, + Err: "failed to parse request body", + }) + return + } + if body.Token == "" { + exhttp.WriteJSONResponse(w, http.StatusBadRequest, mautrix.RespError{ + ErrCode: mautrix.MBadJSON.ErrCode, + Err: "missing token", + }) + return + } + + login, err := p.connector.CreateLogin(r.Context(), user, LoginFlowIDToken) + if err != nil { + p.log.Err(err).Msg("Failed to create login process") + exhttp.WriteJSONResponse(w, http.StatusInternalServerError, mautrix.RespError{ + ErrCode: ErrCodeLoginPrepareFailed, + Err: "failed to prepare login", + }) + return + } + _, err = login.Start(r.Context()) + if err != nil { + p.log.Err(err).Msg("Failed to start login process") + exhttp.WriteJSONResponse(w, http.StatusInternalServerError, mautrix.RespError{ + ErrCode: ErrCodeLoginPrepareFailed, + Err: "failed to start login", + }) + return + } + _, err = login.(bridgev2.LoginProcessUserInput).SubmitUserInput(r.Context(), map[string]string{ + "token": body.Token, + }) + if err != nil { + p.log.Err(err).Msg("Failed to submit token") + exhttp.WriteJSONResponse(w, http.StatusUnauthorized, mautrix.RespError{ + ErrCode: ErrCodePostLoginConnFailed, + Err: "failed to connect to Discord", + }) + return + } + + discordUser := login.(*DiscordTokenLogin).DiscordUser + exhttp.WriteJSONResponse(w, http.StatusOK, map[string]any{ + "success": true, + "id": discordUser.ID, + "username": discordUser.Username, + "discriminator": discordUser.Discriminator, + }) +} + +func (p *ProvisioningAPI) legacyPing(w http.ResponseWriter, r *http.Request) { + user := p.prov.GetUser(r) + + resp := map[string]any{ + "mxid": user.MXID, + "management_room": user.ManagementRoom, + } + + discord := map[string]any{ + "logged_in": false, + "connected": false, + } + + if logins := user.GetUserLogins(); len(logins) > 0 { + login := logins[0] + client := login.Client.(*DiscordClient) + discord["id"] = discordid.ParseUserLoginID(login.ID) + discord["logged_in"] = client.HasToken() + discord["connected"] = client.IsLoggedIn() + if client.Session != nil { + discord["conn"] = map[string]any{ + "last_heartbeat_ack": client.Session.LastHeartbeatAck.UnixMilli(), + "last_heartbeat_sent": client.Session.LastHeartbeatSent.UnixMilli(), + } + } + } + + resp["Discord"] = discord + exhttp.WriteJSONResponse(w, http.StatusOK, resp) +} + +func (p *ProvisioningAPI) legacyLogout(w http.ResponseWriter, r *http.Request) { + user := p.prov.GetUser(r) + logins := user.GetUserLogins() + if len(logins) == 0 { + exhttp.WriteJSONResponse(w, http.StatusOK, map[string]any{ + "success": true, + "status": "not logged in", + }) + return + } + logins[0].Logout(r.Context()) + exhttp.WriteJSONResponse(w, http.StatusOK, map[string]any{ + "success": true, + "status": "logged out successfully", + }) +} + +func (p *ProvisioningAPI) unbridgeGuild(w http.ResponseWriter, r *http.Request, login *bridgev2.UserLogin, client *DiscordClient) { + guildID := normalizeGuildID(r.PathValue("guildID")) + if guildID == "" { + mautrix.MInvalidParam.WithMessage("no guild id").Write(w) + return + } + + log := p.log.With(). + Str("login_id", discordid.ParseUserLoginID(login.ID)). + Str("guild_id", guildID). + Str("action", "unbridge guild"). + Logger() + ctx := log.WithContext(r.Context()) + + log.Info().Msg("Unbridging guild via provisioning API") + + // Immediately record user intent by committing the change to UserLogin + // metadata, even if the portal deletion we're about to attempt fails. + meta := login.Metadata.(*discordid.UserLoginMetadata) + if meta.BridgedGuildIDs != nil { + delete(meta.BridgedGuildIDs, guildID) + } + if err := login.Save(ctx); err != nil { + log.Err(err).Msg("Failed to save login after guild unbridge request") + mautrix.MUnknown.WithMessage("failed to save login: %v", err).Write(w) + return + } + + portalKey := client.guildPortalKey(guildID) + guildPortal, err := p.connector.Bridge.GetExistingPortalByKey(ctx, portalKey) + if err != nil { + log.Err(err).Msg("Failed to get guild portal") + mautrix.MUnknown.WithMessage("failed to get portal: %v", err).Write(w) + return + } + if guildPortal == nil || guildPortal.MXID == "" { + mautrix.RespError{ + ErrCode: ErrCodeGuildNotBridged, + Err: "guild is not bridged", + }.Write(w) + return + } + + deletingPortals, err := p.collectAllGuildPortals(ctx, guildPortal) + if err != nil { + log.Err(err).Msg("Failed to collect portal subtree for deletion") + mautrix.MUnknown.WithMessage("failed to collect portal subtree for deletion: %v", err).Write(w) + return + } + // DeleteManyPortals will sort by depth for us so children get deleted + // before their parents. + bridgev2.DeleteManyPortals(ctx, deletingPortals, func(portal *bridgev2.Portal, del bool, err error) { + log.Err(err). + Stringer("portal_mxid", portal.MXID). + Bool("delete_room", del). + Msg("Failed during portal cleanup") + }) + + log.Info(). + Int("deleted_portals", len(deletingPortals)). + Msg("Finished unbridging") + exhttp.WriteJSONResponse(w, 200, map[string]any{ + "success": true, + "deleted_portals": len(deletingPortals), + }) +} diff --git a/pkg/connector/role.go b/pkg/connector/role.go new file mode 100644 index 0000000..6538041 --- /dev/null +++ b/pkg/connector/role.go @@ -0,0 +1,106 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + "fmt" + + "github.com/bwmarrin/discordgo" + + "go.mau.fi/mautrix-discord/pkg/connector/discorddb" +) + +// (Used by formatter_tag.go via an interface.) +func (d *DiscordConnector) GetRoleByID(ctx context.Context, guildID, roleID string) (*discorddb.Role, error) { + return d.DB.Role.GetByID(ctx, guildID, roleID) +} + +func guildRoleChanged(oldRole *discorddb.Role, newRole *discordgo.Role) bool { + return oldRole.Name != newRole.Name || + oldRole.Icon != newRole.Icon || + oldRole.Mentionable != newRole.Mentionable || + oldRole.Managed != newRole.Managed || + oldRole.Hoist != newRole.Hoist || + oldRole.Color != newRole.Color || + oldRole.Position != newRole.Position || + oldRole.Permissions != newRole.Permissions +} + +func (d *DiscordClient) syncGuildRoles(ctx context.Context, guildID string, roles []*discordgo.Role) error { + if len(roles) == 0 { + return nil + } + + existingRoles, err := d.connector.DB.Role.GetByGuildID(ctx, guildID) + if err != nil { + return fmt.Errorf("failed to get existing guild roles: %w", err) + } + + existingRoleMap := make(map[string]*discorddb.Role, len(existingRoles)) + for _, role := range existingRoles { + existingRoleMap[role.ID] = role + } + + err = d.connector.DB.Role.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error { + for _, role := range roles { + if role == nil { + continue + } + + existingRole := existingRoleMap[role.ID] + if existingRole == nil || guildRoleChanged(existingRole, role) { + if err := d.connector.DB.Role.Put(ctx, &discorddb.Role{ + GuildID: guildID, + Role: *role, + }); err != nil { + return fmt.Errorf("failed to upsert guild role: %w", err) + } + } + + delete(existingRoleMap, role.ID) + } + + for _, removedRole := range existingRoleMap { + if err := d.connector.DB.Role.DeleteByID(ctx, guildID, removedRole.ID); err != nil { + return fmt.Errorf("failed to delete removed guild role: %w", err) + } + } + + return nil + }) + if err != nil { + return fmt.Errorf("failed to sync guild roles: %w", err) + } + + return nil +} + +func (d *DiscordClient) upsertGuildRole(ctx context.Context, guildID string, role *discordgo.Role) error { + if role == nil { + return nil + } + + if err := d.connector.DB.Role.Put(ctx, &discorddb.Role{ + GuildID: guildID, + Role: *role, + }); err != nil { + return fmt.Errorf("failed to upsert guild role: %w", err) + } + + return nil +} diff --git a/pkg/connector/router.go b/pkg/connector/router.go new file mode 100644 index 0000000..a2d25dd --- /dev/null +++ b/pkg/connector/router.go @@ -0,0 +1,127 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + + "github.com/rs/zerolog" + + "go.mau.fi/mautrix-discord/pkg/connector/discorddb" + "go.mau.fi/mautrix-discord/pkg/discordid" + "go.mau.fi/mautrix-discord/pkg/router" +) + +var _ router.Router = (*DiscordClient)(nil) + +func (d *DiscordClient) uncertainRoute(ctx context.Context, channelID string) *router.Route { + log := zerolog.Ctx(ctx) + log.Warn().Str("channel_id", channelID).Msg("Creating an uncertain route") + + return &router.Route{ + // It's generally bad to call into discordid for PortalKey + // construction since the helpers on DiscordClient ensure receiver + // correctness, but this is a bit of a special case as we're + // uncertain who the receiver should even be. + PortalKey: discordid.MakeChannelPortalKey(channelID, d.UserLogin.ID, true), + PortalChannelID: channelID, + Uncertain: true, + } +} + +// FIXME(skip): This method is infallible now, remove the error from the +// signature in the interface and refactor. +func (d *DiscordClient) Route(ctx context.Context, channelID string) (*router.Route, error) { + ch := d.channelWithID(ctx, channelID) + dbThread, err := d.connector.DB.Thread.GetByThreadChannelID( + ctx, + discordid.ParseUserLoginID(d.UserLogin.ID), + channelID, + ) + if err != nil { + // Even if we can't touch the database right now, we can try examining + // the channel from State to make a routing decision. + zerolog.Ctx(ctx).Warn(). + Err(err). + Str("channel_id", channelID). + Msg("Failed to look up potential thread channel ID, proceeding with route") + dbThread = nil + } + + // Most routes will just go to the channel the event originated from. (Not + // true for threads right now.) + r := router.Route{ + PortalChannelID: channelID, + FromChannel: ch, + FromThread: dbThread, + } + + if dbThread != nil { + // If the channel exists in the database as a thread, we immediately + // know how to be receiver-correct (i.e. we can set a correct PortalKey), + // even if the channel doesn't exist in State. + + // Threaded Discord messages need to be bridged to the Matrix room + // that portals to the _parent_ Discord channel, since we always bridge + // threads via m.thread right now. + r.PortalChannelID = dbThread.ParentChannelID + r.PortalKey = d.guildChannelPortalKey(dbThread.ParentChannelID) + + if ch == nil { + return &r, nil + } + } + + if ch == nil { + // We can't know the proper PortalKey for this channel. Return an + // uncertain route instead. + // + // TODO: Maybe we can just ask the REST API for the channel? + return d.uncertainRoute(ctx, channelID), nil + } + + if isThread(ch) { + if dbThread == nil { + // This is a thread we haven't seen before, so insert it into the database. + rootMsgID := defaultThreadRootMessageID(ch) + if upsertErr := d.upsertThreadInfo(ctx, channelID, rootMsgID, ch.ParentID); upsertErr != nil { + // Even if we can't save the thread to the database, we can still + // use the routing decision. + zerolog.Ctx(ctx).Warn(). + Err(upsertErr). + Str("thread_channel_id", channelID). + Str("parent_channel_id", ch.ParentID). + Msg("Failed to upsert newly discovered thread, proceeding with route") + } + thread := discorddb.Thread{ + UserLoginID: discordid.ParseUserLoginID(d.UserLogin.ID), + ThreadChannelID: channelID, + RootMessageID: rootMsgID, + ParentChannelID: ch.ParentID, + } + + // Duplicated from above. + r.PortalChannelID = thread.ParentChannelID + r.PortalKey = d.guildChannelPortalKey(thread.ParentChannelID) + r.FromThread = &thread + } + } else { + r.PortalKey = d.portalKeyForChannel(ch) + } + + return &r, nil +} diff --git a/pkg/connector/session.go b/pkg/connector/session.go new file mode 100644 index 0000000..4dcd561 --- /dev/null +++ b/pkg/connector/session.go @@ -0,0 +1,49 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + "fmt" + "strings" + + "github.com/bwmarrin/discordgo" + "github.com/rs/zerolog" +) + +func NewDiscordSession(ctx context.Context, token string) (*discordgo.Session, error) { + log := zerolog.Ctx(ctx) + + session, err := discordgo.New(token) + if err != nil { + return nil, fmt.Errorf("couldn't create discord session: %w", err) + } + + // Don't bother tracking things we don't care/support right now. Presences + // are especially expensive to track as they occur extremely frequently. + session.State.TrackPresences = false + session.State.TrackVoice = false + + // Set up logging. + session.LogLevel = discordgo.LogInformational + session.Logger = func(msgL, caller int, format string, a ...any) { + // FIXME(skip): Hook up zerolog properly. + log.Debug().Str("component", "discordgo").Msgf(strings.TrimSpace(format), a...) // zerolog-allow-msgf + } + + return session, nil +} diff --git a/pkg/connector/thread.go b/pkg/connector/thread.go new file mode 100644 index 0000000..091a0fb --- /dev/null +++ b/pkg/connector/thread.go @@ -0,0 +1,189 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + "fmt" + "strings" + + "github.com/bwmarrin/discordgo" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/event" + + "go.mau.fi/mautrix-discord/pkg/connector/discorddb" + "go.mau.fi/mautrix-discord/pkg/discordid" +) + +func isThread(ch *discordgo.Channel) bool { + return ch.Type == discordgo.ChannelTypeGuildPublicThread || + ch.Type == discordgo.ChannelTypeGuildPrivateThread || + ch.Type == discordgo.ChannelTypeGuildNewsThread +} + +func defaultThreadRootMessageID(ch *discordgo.Channel) string { + if ch == nil || !isThread(ch) { + return "" + } + if ch.Type == discordgo.ChannelTypeGuildPrivateThread { + return "" + } + return ch.ID +} + +func (d *DiscordClient) upsertThreadInfo(ctx context.Context, threadChannelID, rootMessageID, parentChannelID string) error { + if threadChannelID == "" || parentChannelID == "" { + return nil + } + return d.connector.DB.Thread.Put(ctx, &discorddb.Thread{ + UserLoginID: string(d.UserLogin.ID), + ThreadChannelID: threadChannelID, + RootMessageID: rootMessageID, + ParentChannelID: parentChannelID, + }) +} + +func (d *DiscordClient) upsertThreadInfoFromChannel(ctx context.Context, ch *discordgo.Channel) error { + if ch == nil || !isThread(ch) { + return nil + } + return d.upsertThreadInfo(ctx, ch.ID, defaultThreadRootMessageID(ch), ch.ParentID) +} + +func (d *DiscordClient) upsertThreadInfoFromMessage(ctx context.Context, msg *discordgo.Message) error { + if msg == nil || msg.Flags&discordgo.MessageFlagsHasThread == 0 || msg.Thread == nil { + return nil + } + threadChannelID := msg.Thread.ID + if threadChannelID == "" { + threadChannelID = msg.ID + } + parentChannelID := msg.Thread.ParentID + if parentChannelID == "" { + parentChannelID = msg.ChannelID + } + return d.upsertThreadInfo(ctx, threadChannelID, msg.ID, parentChannelID) +} + +func (d *DiscordClient) getThreadByRootMessageID(ctx context.Context, rootMessageID string) (*discorddb.Thread, error) { + if rootMessageID == "" { + return nil, nil + } + thread, err := d.connector.DB.Thread.GetByRootMessageID(ctx, string(d.UserLogin.ID), rootMessageID) + if err != nil || thread != nil { + return thread, err + } + + ch, err := d.Session.State.Channel(rootMessageID) + if err == nil && ch != nil && isThread(ch) && defaultThreadRootMessageID(ch) == rootMessageID { + if upsertErr := d.upsertThreadInfo(ctx, ch.ID, rootMessageID, ch.ParentID); upsertErr != nil { + return nil, upsertErr + } + return &discorddb.Thread{ + UserLoginID: string(d.UserLogin.ID), + ThreadChannelID: ch.ID, + RootMessageID: rootMessageID, + ParentChannelID: ch.ParentID, + }, nil + } + + return nil, nil +} + +func getMatrixThreadRootRemoteMessageID(threadRoot *database.Message) string { + if threadRoot == nil { + return "" + } + remoteID := discordid.ParseMessageID(threadRoot.ID) + if threadRoot.ThreadRoot != "" { + remoteID = discordid.ParseMessageID(threadRoot.ThreadRoot) + } + return remoteID +} + +func makeDiscordReferer(guildID, parentChannelID, threadChannelID string) discordgo.RequestOption { + if threadChannelID != "" && threadChannelID != parentChannelID { + return discordgo.WithThreadReferer(guildID, parentChannelID, threadChannelID) + } + return discordgo.WithChannelReferer(guildID, parentChannelID) +} + +func getThreadName(content *event.MessageEventContent) string { + body := "" + if content != nil { + body = content.Body + } + if len(body) == 0 { + return "thread" + } + + fields := strings.Fields(body) + var title string + for _, field := range fields { + if len(title)+len(field) < 40 { + title += field + " " + } else if len(title) == 0 { + title = field[:40] + break + } else { + break + } + } + title = strings.TrimSpace(title) + if title == "" { + return "thread" + } + return title +} + +func (d *DiscordClient) startThreadFromMatrix( + ctx context.Context, + guildID string, + parentChannelID string, + rootMessageID string, + threadName string, +) (string, error) { + if !d.IsLoggedIn() { + return "", fmt.Errorf("can't create thread without being logged into Discord") + } + + threadType := discordgo.ChannelTypeGuildPublicThread + parentCh, err := d.Session.State.Channel(parentChannelID) + if err == nil && parentCh != nil && parentCh.Type == discordgo.ChannelTypeGuildNews { + threadType = discordgo.ChannelTypeGuildNewsThread + } + + ch, err := d.Session.MessageThreadStartComplex( + parentChannelID, + rootMessageID, + &discordgo.ThreadStart{ + Name: threadName, + AutoArchiveDuration: 24 * 60, + Type: threadType, + Location: "Message", + }, + makeDiscordReferer(guildID, parentChannelID, ""), + ) + if err != nil { + return "", d.tryWrappingError(ctx, err) + } + + if upsertErr := d.upsertThreadInfo(ctx, ch.ID, rootMessageID, parentChannelID); upsertErr != nil { + return "", upsertErr + } + return ch.ID, nil +} diff --git a/pkg/connector/usercache.go b/pkg/connector/usercache.go new file mode 100644 index 0000000..d22a53b --- /dev/null +++ b/pkg/connector/usercache.go @@ -0,0 +1,175 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + "errors" + "maps" + "net/http" + "slices" + "sync" + + "github.com/bwmarrin/discordgo" + "github.com/rs/zerolog" + + "go.mau.fi/mautrix-discord/pkg/discordid" +) + +// NOTE: Not simply using `exsync.Map` because we want the lock to be held +// during HTTP requests. + +type UserCache struct { + session *discordgo.Session + cache map[string]*discordgo.User + lock sync.Mutex +} + +func NewUserCache(session *discordgo.Session) *UserCache { + return &UserCache{ + session: session, + cache: make(map[string]*discordgo.User), + } +} + +func (uc *UserCache) UpdateWithReady(ready *discordgo.Ready) { + if ready == nil { + return + } + + uc.lock.Lock() + defer uc.lock.Unlock() + + self := ready.User + uc.cache[self.ID] = self + + for _, user := range ready.Users { + uc.cache[user.ID] = user + } +} + +// UpdateWithMessage updates the user cache with the users involved in a single +// message (author, mentioned, mentioned author, etc.) +// +// The updated user IDs are returned. +func (uc *UserCache) UpdateWithMessage(msg *discordgo.Message) []string { + if msg == nil { + return []string{} + } + + // For now just forward to HandleMessages until a need for a specialized + // path makes itself known. + return uc.UpdateWithMessages([]*discordgo.Message{msg}) +} + +// UpdateWithMessages updates the user cache with the total set of users involved +// with multiple messages (authors, mentioned users, mentioned authors, etc.) +// +// The updated user IDs are returned. +func (uc *UserCache) UpdateWithMessages(msgs []*discordgo.Message) []string { + if len(msgs) == 0 { + return []string{} + } + + collectedUsers := map[string]*discordgo.User{} + for _, msg := range msgs { + collectedUsers[msg.Author.ID] = msg.Author + + referenced := msg.ReferencedMessage + if referenced != nil && referenced.Author != nil { + collectedUsers[referenced.Author.ID] = referenced.Author + } + + for _, mentioned := range msg.Mentions { + collectedUsers[mentioned.ID] = mentioned + } + + // Message snapshots lack `author` entirely and seemingly have an empty + // `mentions` array, even when the original message actually mentions + // someone. + } + + uc.lock.Lock() + defer uc.lock.Unlock() + + for _, user := range collectedUsers { + uc.cache[user.ID] = user + } + + return slices.Collect(maps.Keys(collectedUsers)) +} + +func (uc *UserCache) UpdateWithUserUpdate(update *discordgo.UserUpdate) { + if update == nil || update.User == nil { + return + } + + uc.lock.Lock() + defer uc.lock.Unlock() + + uc.cache[update.ID] = update.User +} + +// Resolve looks up a user in the cache, requesting the user from the Discord +// HTTP API if not present. +// +// If the user cannot be found, then its nonexistence is cached. This is to +// avoid excessive requests when e.g. backfilling messages from a user that has +// since been deleted since connecting. If some other error occurs, the cache +// isn't touched and nil is returned. +// +// Otherwise, the cache is updated as you'd expect. +func (uc *UserCache) Resolve(ctx context.Context, userID string) *discordgo.User { + if userID == discordid.DeletedGuildUserID { + return &discordid.DeletedGuildUser + } + + // Hopefully this isn't too contentious? + uc.lock.Lock() + defer uc.lock.Unlock() + + cachedUser, present := uc.cache[userID] + if cachedUser != nil { + return cachedUser + } else if present { + // If a `nil` is present in the map, then we already know that the user + // doesn't exist. + return nil + } + + log := zerolog.Ctx(ctx).With(). + Str("action", "resolve user"). + Str("user_id", userID).Logger() + + log.Trace().Msg("Fetching user") + user, err := uc.session.User(userID) + + var restError *discordgo.RESTError + if errors.As(err, &restError) && restError.Response.StatusCode == http.StatusNotFound { + log.Info().Msg("Tried to resolve a user that doesn't exist, caching nonexistence") + uc.cache[userID] = nil + + return nil + } else if err != nil { + log.Err(err).Msg("Failed to resolve user") + return nil + } + + uc.cache[userID] = user + + return user +} diff --git a/pkg/connector/userinfo.go b/pkg/connector/userinfo.go new file mode 100644 index 0000000..a8fb17e --- /dev/null +++ b/pkg/connector/userinfo.go @@ -0,0 +1,118 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + "fmt" + + "github.com/bwmarrin/discordgo" + "github.com/rs/zerolog" + "go.mau.fi/util/ptr" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/status" + + "go.mau.fi/mautrix-discord/pkg/discordid" +) + +func readableRelationshipType(rel discordgo.RelationshipType) (desc string) { + desc = "unknown" + + switch rel { + case discordgo.RelationshipBlocked: + desc = "blocked" + case discordgo.RelationshipFriend: + desc = "friend" + case discordgo.RelationshipIncomingFriendRequest: + desc = "recipient wants to be friends" + case discordgo.RelationshipOutgoingFriendRequest: + desc = "sender wants to be friends" + } + + return +} + +// makeRemoteName computes an appropriate value for the RemoteName field on +// [bridgev2.UserLogin]. +func makeRemoteName(u *discordgo.User) string { + return u.String() +} + +// makeRemoteProfile creates a [status.makeRemoteProfile] from a +// [discordgo.User]. A [bridgev2.Ghost] may optionally be passed to provide an +// avatar. +func makeRemoteProfile(u *discordgo.User, ghost *bridgev2.Ghost) (p status.RemoteProfile) { + p.Phone = u.Phone + p.Email = u.Email + p.Username = u.String() + p.Name = u.GlobalName + if ghost != nil { + p.Avatar = ghost.AvatarMXC + } + return +} + +func (d *DiscordClient) IsThisUser(ctx context.Context, userID networkid.UserID) bool { + // We define `UserID`s and `UserLoginID`s to be interchangeable, i.e. they map + // directly to Discord user IDs ("snowflakes"), so we can perform a direct comparison. + return userID == discordid.UserLoginIDToUserID(d.UserLogin.ID) +} + +func (d *DiscordClient) makeUserAvatar(u *discordgo.User) *bridgev2.Avatar { + url := u.AvatarURL("256") + + return &bridgev2.Avatar{ + ID: discordid.MakeAvatarID(url), + Get: func(ctx context.Context) ([]byte, error) { + return httpGet(ctx, d.httpClient, url, "user avatar") + }, + } +} + +func (d *DiscordClient) GetUserInfo(ctx context.Context, ghost *bridgev2.Ghost) (*bridgev2.UserInfo, error) { + if !d.IsLoggedIn() { + return nil, bridgev2.ErrNotLoggedIn + } + + log := zerolog.Ctx(ctx) + + if ghost.ID == "" { + log.Warn().Msg("Tried to get user info for ghost with no ID") + return nil, nil + } + + discordUserID := discordid.ParseUserID(ghost.ID) + discordUser := d.userCache.Resolve(ctx, discordUserID) + if discordUser == nil { + log.Error().Str("discord_user_id", discordUserID). + Msg("Failed to resolve user") + return nil, nil + } + + return d.getUserInfo(ctx, discordUser), nil +} + +func (d *DiscordClient) getUserInfo(ctx context.Context, user *discordgo.User) *bridgev2.UserInfo { + return &bridgev2.UserInfo{ + // FIXME clear this for webhooks (stash in ghost metadata) + Identifiers: []string{fmt.Sprintf("discord:%s", user.String())}, + Name: ptr.Ptr(user.DisplayName()), + Avatar: d.makeUserAvatar(user), + IsBot: &user.Bot, + } +} diff --git a/pkg/discordauth/NOTES.md b/pkg/discordauth/NOTES.md new file mode 100644 index 0000000..f23c8e3 --- /dev/null +++ b/pkg/discordauth/NOTES.md @@ -0,0 +1,51 @@ +## `POST /api/v9/auth/login` + +### request + +... + +### response + +#### new login location + +HTTP 400 + +```json +{ + "message": "Invalid Form Body", + "code": 50035, + "errors": { + "login": { + "_errors": [ + { + "code": "ACCOUNT_LOGIN_VERIFICATION_EMAIL", + "message": "New login location detected, please check your e-mail." + } + ] + } + } +} +``` + +## `POST /api/v9/auth/authorize-ip` + +### request + +```json +{ + "token": "..." +} +``` + +### response + +#### when link has expired + +```json +{ + "message": "Invalid authentication token", + "code": 50014 +} +``` + +- UI prompts the user to log in again to get another link diff --git a/pkg/discordauth/SEQUENCE_DIAGRAM.md b/pkg/discordauth/SEQUENCE_DIAGRAM.md new file mode 100644 index 0000000..ecab025 --- /dev/null +++ b/pkg/discordauth/SEQUENCE_DIAGRAM.md @@ -0,0 +1,81 @@ +```mermaid +sequenceDiagram + actor User + participant Bridge + participant Discord + + note over User: Login preemption flows: + + rect rgb(254 246 181 / 50%) + note over User,Discord: This flow may occur spontaneously, as a response to ANY request, even those containing a CAPTCHA solution, as well as OUTSIDE OF LOGIN FLOWS.
    As Discord can reply to a CAPTCHA solution with another CAPTCHA challenge, an implementation will likely require a loop.

    In other words: ANY HTTP arrow going from Discord to Bridge may suddenly enter this flow without prior warning. + Discord->>Bridge: HTTP 400, CAPTCHA challenge (regardless of the would-be outcome) + alt Challenge is invisible + Bridge->>Bridge: ??? + note right of Bridge: How this is handled is currently unknown. + else Challenge isn't invisible (majority of cases) + Bridge->>User: Modally present CAPTCHA challenge + end + User->>Bridge: CAPTCHA solution + Bridge->>Discord: Retry request with the same body, incorporating CAPTCHA solution in headers + end + + rect rgb(181 244 254 / 50%) + note over User,Discord: When attempting to log in from a "new location" (IP address unfamiliar to Discord), the following occurs for a login that would otherwise complete successfully (returning a user token and ID, among other data): + Discord->>Bridge: HTTP 400, error code 50035, "Invalid Form Body" + note right of Bridge: The form error code sent by Discord is "ACCOUNT_LOGIN_VERIFICATION_EMAIL".
    The message is "New login location detected, please check your e-mail." + Bridge->>User: Fail the entire log in flow. The user must authorize the IP address first, then attempt the log in again.
    As with ordinary login attempts, MFA and or CAPTCHAs may be involved. + User->>Discord: Visits the email-provided log in link. After a redirect, the page performs POST /auth/authorize-ip with an opaque token. + Discord->>Discord: The IP address is now allowed to log in to the user's account. + end + + rect rgb(200 210 255 / 50%) + note over User,Discord: If the user's Discord account is suspended, a would-be successful login attempt instead yields a "suspended user token." + Discord->>Bridge: HTTP 403, user ID and "suspended user token" + end + + note over User: Login flows: + + alt + note over User,Discord: Log in with email or phone number, and password (Creds) + User->>Bridge: Specifies an email or phone number as well as
    a password + Bridge->>Discord: POST /auth/login + + alt User does not have MFA set up (LoginCompleted) + note over Bridge: ("New location" preemption flow is possible. When skipped, the following occurs:) + Discord->>Bridge: User token, ID, locale, and theme settings + Bridge->>Bridge: Save token and log in + else User has MFA set up and it is required for log in + Discord->>Bridge: HTTP 200, which MFA methods the user has set up, and an opaque "ticket" (LoginMFARequired) + Bridge->>+User: Modally ask the user which MFA method to use + activate User + activate User + + alt Chosen MFA method: SMS + User->>-Bridge: I would like to proceed with SMS-based MFA + Bridge->>Discord: POST /auth/mfa/sms/send with the "ticket" from earlier (SMSSendRequest) + Discord->>User: Sends a short numeric code to the user via SMS + note over User: "Your Discord verification code is: 123456" + User->>Bridge: Provides the received code + Bridge->>Discord: POST /auth/mfa/sms with the code and the "ticket" (MFAContinuation) + else Chosen MFA method: TOTP + User->>-Bridge: I would like to proceed with TOTP-based MFA, providing the TOTP code + Bridge->>Discord: POST /auth/mfa/totp with the code and the "ticket" (MFAContinuation) + else Chosen MFA method: TOTP backup code + User->>-Bridge: I would like to proceed with a TOTP backup code, providing it + Bridge->>Discord: PSOT /auth/mfa/backup with the code and the "ticket" (MFAContinuation) + end + end + note over Bridge: After making a successful request to /auth/mfa/… with any MFA type, the login either succeeds or is preempted due to login location (IP address) or user suspension. + else + note over User,Discord: Log in by scanning QR code with Discord mobile app + User->>Bridge: I want to log in with a QR code + Bridge->>Discord: Connect to "remoteauth" gateway (WebSocket) + Discord->>Bridge: … + Bridge->>User: Present QR code, wait for scan + note right of User: The remainder of
    this flow is omitted for now. + else + note over User,Discord: Log in via WebAuthn (passkey, security key) + + note right of User: This flow is omitted for now. + end +``` diff --git a/pkg/discordauth/captcha.go b/pkg/discordauth/captcha.go new file mode 100644 index 0000000..f835b3e --- /dev/null +++ b/pkg/discordauth/captcha.go @@ -0,0 +1,113 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package discordauth + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/rs/zerolog" +) + +const HeaderCaptchaKey = "x-captcha-key" +const HeaderCaptchaSessionID = "x-captcha-session-id" +const HeaderCaptchaRqToken = "x-captcha-rqtoken" + +type CaptchaService string + +const ( + CaptchaServiceHCaptcha CaptchaService = "hcaptcha" + CaptchaServiceReCaptcha CaptchaService = "recaptcha" + CaptchaServiceReCaptchaEnterprise CaptchaService = "recaptcha_enterprise" +) + +// HCaptcha holds the information specific to hCaptcha within a [Captcha] +// challenge. This is only separated for organizational purposes. +type HCaptcha struct { + SiteKey *string `json:"captcha_sitekey"` + SessionID *string `json:"captcha_session_id"` // re-sent in `x-captcha-session-id` + RqData *string `json:"captcha_rqdata"` + RqToken *string `json:"captcha_rqtoken"` // re-sent in `x-captcha-rqtoken` +} + +func (hc *HCaptcha) SpotCheck() bool { + return hc.SiteKey != nil && *hc.SiteKey != "" && + hc.SessionID != nil && *hc.SessionID != "" && + hc.RqData != nil && *hc.RqData != "" && + hc.RqToken != nil && *hc.RqToken != "" +} + +func (hc *HCaptcha) UpdateHeaders(header *http.Header) { + header.Del(HeaderCaptchaSessionID) + header.Del(HeaderCaptchaRqToken) + + if hc.SessionID != nil { + header.Set(HeaderCaptchaSessionID, *hc.SessionID) + } + if hc.RqToken != nil { + header.Set(HeaderCaptchaRqToken, *hc.RqToken) + } +} + +// A CAPTCHA challenge from Discord. +// +// This may be returned from any endpoint at any time. To test for the presence +// of a captcha challenge, test the following criteria: +// +// 1. The HTTP status of the response is 400. +// +// 2. The captcha_key field is present on the root object of the response body +// when parsed as JSON. +type Captcha struct { + HCaptcha + Key []string `json:"captcha_key"` + Service CaptchaService `json:"captcha_service"` + Invisible bool `json:"should_serve_invisible"` + UserFlow *string `json:"user_flow"` // Unknown. +} + +func (c *Captcha) LogContext(ctx zerolog.Context) zerolog.Context { + return ctx. + Str("captcha_service", string(c.Service)). + Strs("captcha_key", c.Key). + Bool("captcha_invisible", c.Invisible) +} + +func TryUnmarshalingCaptcha(ctx context.Context, resp *http.Response, body []byte) *Captcha { + if resp.StatusCode != 400 { + return nil + } + + log := zerolog.Ctx(ctx) + + var challenge Captcha + + err := json.Unmarshal(body, &challenge) + if err != nil { + // We should only hit this if the JSON is malformed or something, which + // is probably worth knowing about. + log.Warn().Err(err).Msg("Failed to unmarshal potential captcha challenge") + return nil + } + + if len(challenge.Key) > 0 { + return &challenge + } + + return nil +} diff --git a/pkg/discordauth/context.go b/pkg/discordauth/context.go new file mode 100644 index 0000000..116240b --- /dev/null +++ b/pkg/discordauth/context.go @@ -0,0 +1,58 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package discordauth + +import ( + "encoding/base64" + "encoding/json" + "fmt" +) + +// EncodeBasicContextProperties creates a value for [HeaderContextProperties] +// when only the "location" key is needed. This is unlikely to suffice for most +// location types. +func EncodeBasicContextProperties(location ContextLocation) (string, error) { + encoded, err := json.Marshal(map[string]string{"location": string(location)}) + if err != nil { + return "", fmt.Errorf("failed to marshal basic context properties: %w", err) + } + + return base64.StdEncoding.EncodeToString(encoded), nil +} + +type ContextLocation string + +// This is not a comprehensive listing. +const ( + ContextLocationLogin ContextLocation = "Login" + ContextLocationRegister ContextLocation = "Register" + ContextLocationInvite ContextLocation = "Accept Invite Page" + ContextLocationVerify ContextLocation = "Verify Email" + ContextLocationDisableEmailNotifications ContextLocation = "Disable Email Notifications" + ContextLocationDisableServerHighlightNotifications ContextLocation = "Disable Server Highlight Notifications" + ContextLocationAuthorizeIp ContextLocation = "Authorize Ip" + ContextLocationRejectIp ContextLocation = "Reject Ip" + ContextLocationRejectMfa ContextLocation = "Reject MFA" + ContextLocationReport ContextLocation = "Report Illegal Content" + ContextLocationReportSecondLook ContextLocation = "Report Second Look" + ContextLocationAuthorizePayment ContextLocation = "Authorize Payment" + ContextLocationReset ContextLocation = "Reset" + ContextLocationAccountRevert ContextLocation = "Account Revert" + ContextLocationHandoff ContextLocation = "Handoff" + ContextLocationUnknown ContextLocation = "Unknown" + ContextLocationLanding ContextLocation = "Landing" +) diff --git a/pkg/discordauth/error.go b/pkg/discordauth/error.go new file mode 100644 index 0000000..09f49e7 --- /dev/null +++ b/pkg/discordauth/error.go @@ -0,0 +1,183 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package discordauth + +import ( + "encoding/json" + "fmt" + "strings" +) + +// TODO(skip): Some overlap with this and discordgo. Sort that out. + +type APIError struct { + Message string `json:"message"` + Code ErrCode `json:"code"` + + // Detailed errors. Returned for e.g. [InvalidFormBody]. + // + // The keys in this map correspond to the top-level keys that were sent in + // your request body. The value reconstructs the shape of the data that was + // sent, with arbitrary depth. + // + // For example, a request body such as + // + // { "friends": [ { "enjoys_pineapple_on_pizza": false } ] } + // + // might result in this erroneous reply: + // + // { + // ..., + // "errors": { + // "friends": { + // "0": { + // "enjoys_pineapple_on_pizza": { + // "_errors": [ + // { + // "code": "CHECK_YOUR_OPINION", + // "message": "Everybody likes pineapple on pizza. Try again." + // } + // ] + // } + // } + // } + // } + // } + // + // Notice how: + // + // - The intermediate values are always objects. Array indices are + // represented with strings. + // + // - The erroneous request value terminates in an object containing an + // array keyed under _errors. + // + // The _errors array further contains objects of shape { code, message }. + Errors map[string]json.RawMessage `json:"errors"` + + // The raw HTTP response body. + ResponseBody []byte `json:"-"` +} + +// A FormError communicates detailed error information for certain JSON field. +type FormError struct { + Code string `json:"code"` + Message string `json:"message"` +} + +type FormErrorCode string + +const ( + // AccountLoginVerificationEmail is raised when the user is logging in from + // a new IP address and must check their email for a verification link. + AccountLoginVerificationEmail FormErrorCode = "ACCOUNT_LOGIN_VERIFICATION_EMAIL" + + // InvalidLogin is raised when the username/phone or password was + // incorrect. + InvalidLogin FormErrorCode = "INVALID_LOGIN" +) + +// FormFieldErrors returns the [FormError] values associated with the +// given key that had been sent in the request. If the key present or +// the errors array is empty for whatever reason, nil is returned. +// +// NOTE/TODO: This function does not currently support accessing fields beyond +// the first level. +func (err *APIError) FormFieldErrors(key string) ([]FormError, error) { + leafMsg, ok := err.Errors[key] + if !ok { + return nil, nil + } + + type ErrorsLeaf struct { + Errors []FormError `json:"_errors"` + } + + var leaf ErrorsLeaf + if err := json.Unmarshal(leafMsg, &leaf); err != nil { + return nil, fmt.Errorf("failed to unmarshal errors leaf: %w", err) + } + + return leaf.Errors, nil +} + +var _ error = (*APIError)(nil) + +func (err APIError) Error() string { + msg := fmt.Sprintf("Discord API error %d: \"%s\"", err.Code, err.Message) + + if err.Code == InvalidFormBody && err.Errors != nil { + fieldErrors := make([]string, 0) + + for key := range err.Errors { + errors, err := err.FormFieldErrors(key) + if err != nil { + continue + } + + summaries := make([]string, 0) + for _, error := range errors { + summaries = append(summaries, fmt.Sprintf("\"%s\" (%s)", error.Message, error.Code)) + } + + fieldErrors = append(fieldErrors, fmt.Sprintf("%s: %s", key, strings.Join(summaries, ", "))) + } + + return msg + ": " + strings.Join(fieldErrors, "; ") + } + + return msg +} + +type ErrCode int + +const ( + RateLimited ErrCode = 31001 + RateLimitedResource ErrCode = 31002 + + AccountScheduledForDeletion ErrCode = 20011 + AccountDisabled ErrCode = 20013 + + Unauthorized ErrCode = 40001 + AccountVerificationNeeded ErrCode = 40002 + CloudflareBlocked ErrCode = 40333 + + InvalidFormBody ErrCode = 50035 + + MFAAlreadyEnrolled ErrCode = 60001 + MFANotEnrolled ErrCode = 60002 + MFARequired ErrCode = 60003 + MustBeVerified ErrCode = 60004 + MFAInvalidSecret ErrCode = 60005 + MFAInvalidAuthTicket ErrCode = 60006 + MFAInvalidCode ErrCode = 60008 + MFAInvalidSession ErrCode = 60009 + SMSAuthNotEnrolled ErrCode = 60010 + InvalidKey ErrCode = 60011 + SMSAuthCannotBeEnabled ErrCode = 60012 + MFARequiredForShopListings ErrCode = 60015 + MFAEmailIneligible ErrCode = 60019 + CredentialUndiscoverableOrInvalid ErrCode = 60021 + + SMSAuthUnableToSendMessage ErrCode = 70003 + SMSAuthPhoneNumberRecentlyUsedElsewhere ErrCode = 70004 + SMSAuthPhoneNumberIsVoIPOrLandline ErrCode = 70005 + SMSAuthVerificationNeeded ErrCode = 70007 + SMSAuthPhoneNumberAlreadyUsedElsewhere ErrCode = 70008 + PasswordResetLinkSentToEmail ErrCode = 70009 + SMSAuthPhoneNumberCannotBeAssociated ErrCode = 70011 +) diff --git a/pkg/discordauth/error_test.go b/pkg/discordauth/error_test.go new file mode 100644 index 0000000..3ea2d36 --- /dev/null +++ b/pkg/discordauth/error_test.go @@ -0,0 +1,43 @@ +package discordauth + +import ( + "encoding/json" + "testing" +) + +func TestFormFieldErrors_AccountLoginVerificationEmail(t *testing.T) { + body := []byte(`{ + "message": "Invalid Form Body", + "code": 50035, + "errors": { + "login": { + "_errors": [ + { + "code": "ACCOUNT_LOGIN_VERIFICATION_EMAIL", + "message": "New login location detected, please check your e-mail." + } + ] + } + } + }`) + + var apiErr APIError + if err := json.Unmarshal(body, &apiErr); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if apiErr.Code != InvalidFormBody { + t.Fatalf("expected code %d, got %d", InvalidFormBody, apiErr.Code) + } + + errs, err := apiErr.FormFieldErrors("login") + if err != nil { + t.Fatalf("FormFieldErrors returned error: %v", err) + } + if len(errs) != 1 { + t.Fatalf("expected 1 error, got %d", len(errs)) + } + if FormErrorCode(errs[0].Code) != AccountLoginVerificationEmail { + t.Fatalf("expected code %s, got %s", AccountLoginVerificationEmail, errs[0].Code) + } +} diff --git a/database/upgrades/upgrades.go b/pkg/discordauth/experiments_legacy.go similarity index 68% rename from database/upgrades/upgrades.go rename to pkg/discordauth/experiments_legacy.go index d6954d5..a28697a 100644 --- a/database/upgrades/upgrades.go +++ b/pkg/discordauth/experiments_legacy.go @@ -1,5 +1,5 @@ // mautrix-discord - A Matrix-Discord puppeting bridge. -// Copyright (C) 2022 Tulir Asokan +// Copyright (C) 2026 Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -14,19 +14,20 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . -package upgrades +package discordauth -import ( - "embed" +type Fingerprint string - "go.mau.fi/util/dbutil" -) - -var Table dbutil.UpgradeTable - -//go:embed *.sql -var rawUpgrades embed.FS - -func init() { - Table.RegisterFS(rawUpgrades) +func (f Fingerprint) HeaderValue() string { + return string(f) +} + +func (f Fingerprint) IsZero() bool { + return f == "" +} + +type ExperimentsLegacy struct { + Fingerprint Fingerprint `json:"fingerprint"` + // `json:"assignments"` + // `json:"guild_experiments"` } diff --git a/pkg/discordauth/handler.go b/pkg/discordauth/handler.go new file mode 100644 index 0000000..7c6009a --- /dev/null +++ b/pkg/discordauth/handler.go @@ -0,0 +1,42 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package discordauth + +import "context" + +// Specify a ChallengeHandler when creating an [AuthMachine] (via +// [NewAuthMachine]) to implement handling for flows that "interrupt" the login +// process, such as CAPTCHAs or MFA. +// +// In other words, this interface constitutes the essential client "hook point" +// where you may inject your own behaviors into the login flow. In these +// methods, you will likely need to update your user interface and or prompt +// the user. +// +// [AuthMachine] will call the methods on this interface as necessary at the +// correct moments and handle all of the required plumbing. +type ChallengeHandler interface { + // Discord presented a CAPTCHA. Let the user solve it and return their + // solution out of this method. + SolveCaptcha(context.Context, *Captcha) (*CaptchaSolution, error) + + // The password was accepted as part of the login, but MFA is at play. + // Inspect the MFAChallenge to see which MFA methods are permitted and + // prompt the user accordingly. This is also how you may request an SMS + // code. Once you have a code, return it via MFAContinue. + ContinueMFA(context.Context, *MFAChallenge) (*MFAContinue, error) +} diff --git a/pkg/discordauth/http.go b/pkg/discordauth/http.go new file mode 100644 index 0000000..5221116 --- /dev/null +++ b/pkg/discordauth/http.go @@ -0,0 +1,68 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package discordauth + +import ( + "context" + "fmt" + "io" + "net/http" +) + +type HTTP interface { + Do(req *http.Request) (*http.Response, error) +} + +func respIsOk(resp *http.Response) bool { + if resp == nil { + return false + } + + return resp.StatusCode >= 200 && resp.StatusCode < 300 +} + +type HTTPError struct { + body []byte + resp *http.Response +} + +func (err HTTPError) Error() string { + if err.body != nil && len(err.body) < 1_024*16 { // arbitrarily cap at 16 KiB + return fmt.Sprintf("Discord replied with HTTP %d: %s", err.resp.StatusCode, string(err.body)) + } + + return fmt.Sprintf("Discord replied with HTTP %d", err.resp.StatusCode) +} + +func refreshReq(ctx context.Context, req *http.Request) (*http.Request, error) { + var newBody io.ReadCloser + var err error + + if req.Body != nil && req.ContentLength > 0 { + newBody, err = req.GetBody() + if err != nil { + return nil, fmt.Errorf("failed to clone request body when retrying: %w", err) + } + } + req = req.Clone(ctx) + + if newBody != nil { + req.Body = newBody + } + + return req, nil +} diff --git a/pkg/discordauth/login.go b/pkg/discordauth/login.go new file mode 100644 index 0000000..04b7a0d --- /dev/null +++ b/pkg/discordauth/login.go @@ -0,0 +1,54 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package discordauth + +// Creds are some credentials that you use to initiate a login to Discord. +// +// This isn't all that is needed to log in successfully; you may have to solve +// a CAPTCHA, verify your login location, participate in MFA, etc. +type Creds struct { + GiftCodeSKUID *string `json:"gift_code_sku_id"` + Login string `json:"login"` + LoginSource *string `json:"login_source"` + Password Sensitive[string] `json:"password"` + Undelete bool `json:"undelete"` +} + +func NewCreds(emailOrPhone string, password string) *Creds { + return &Creds{ + Login: emailOrPhone, + Password: NewSensitive(password), + Undelete: false, + } +} + +// A LoginCompleted is returned from Discord when a log in flow concludes. +type LoginCompleted struct { + Token Sensitive[string] `json:"token"` + UserID string `json:"user_id"` + UserSettings UserSettings `json:"user_settings"` + RequiredActions []string `json:"required_actions"` +} + +func (lc *LoginCompleted) HasToken() bool { + return !lc.Token.IsZero() +} + +type UserSettings struct { + Locale string `json:"locale"` + Theme string `json:"theme"` +} diff --git a/pkg/discordauth/machine.go b/pkg/discordauth/machine.go new file mode 100644 index 0000000..faea10b --- /dev/null +++ b/pkg/discordauth/machine.go @@ -0,0 +1,528 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package discordauth + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "maps" + "net/http" + "slices" + "strings" + + "github.com/rs/zerolog" + "go.mau.fi/util/ptr" +) + +// An AuthMachine governs the core logic to authenticate with Discord. It is +// concerned with: +// +// - Detecting CAPTCHA challenges. +// - Sending the correct set of headers to each endpoint. +// - Stashing the necessary state in-memory and threading them into requests +// as necessary. +type AuthMachine struct { + log *zerolog.Logger + LogFilters AuthMachineLogFilters + + http HTTP + APIBase string + handler ChallengeHandler + + State AuthMachineState + + Personality *Personality +} + +type AuthMachineState struct { + Fingerprint Fingerprint +} + +type CaptchaSolution struct { + Solution string +} +type CaptchaHandler func(ctx context.Context, captcha *Captcha) (*CaptchaSolution, error) + +func NewAuthMachine(ctx context.Context, http HTTP, personality *Personality, handler ChallengeHandler) *AuthMachine { + if http == nil { + panic("http interface is required") + } + if personality == nil { + panic("personality is required") + } + if handler == nil { + panic("handler is required") + } + + log := zerolog.Ctx(ctx).With().Str("component", "discord auth").Logger() + + return &AuthMachine{ + log: &log, + + http: http, + handler: handler, + + APIBase: "https://discord.com/api/v9", + Personality: personality, + } +} + +func formatHTTPHeaderDump(prefix string, headers http.Header) string { + keys := make([]string, 0, len(headers)) + for key := range headers { + keys = append(keys, key) + } + slices.Sort(keys) + + var msg strings.Builder + msg.WriteString(prefix) + for _, key := range keys { + for _, value := range headers[key] { + msg.WriteByte('\n') + msg.WriteString(key) + msg.WriteString(": ") + msg.WriteString(value) + } + } + + return msg.String() +} + +func (am *AuthMachine) captchaRetryLoop(ctx context.Context, req *http.Request) (*http.Response, []byte, error) { + // Check if we can clone the request body. We need this since we might need + // to retry the request. + if req.GetBody == nil && req.ContentLength > 0 { + return nil, nil, fmt.Errorf("tried to make request with a body that isn't retriable") + } + + log := zerolog.Ctx(ctx) + nCaptchas := 0 + var resp *http.Response + var err error + + defer func() { + if resp == nil { + return + } + + respLogLevel := zerolog.DebugLevel + respStatusOk := respIsOk(resp) + if !respStatusOk { + respLogLevel = zerolog.ErrorLevel + } + + if am.LogFilters.EveryHTTPResponse || !respStatusOk { + // Erroneous responses are always logged. + log.WithLevel(respLogLevel). + Int("n_captchas", nCaptchas). + Int("http_status", resp.StatusCode). + Int("http_content_length", int(resp.ContentLength)). + Msg("Received response") + } + }() + + for { + if am.LogFilters.EveryHTTPRequest { + log.Debug(). + Int("n_captchas", nCaptchas). + Msg("Making request") + } + if am.LogFilters.DangerouslyLeakyHTTPHeaders { + log.Debug(). + Int("n_captchas", nCaptchas). + Msg(formatHTTPHeaderDump("Sending request headers", req.Header)) + } + + // Make the HTTP request. + resp, err = am.http.Do(req) + if err != nil { + return nil, nil, fmt.Errorf("failed to make http request: %w", err) + } + if am.LogFilters.DangerouslyLeakyHTTPHeaders { + log.Debug(). + Int("n_captchas", nCaptchas). + Msg(formatHTTPHeaderDump("Received response headers", resp.Header)) + } + + // We need to consume the entire response body so we can test for a + // CAPTCHA challenge. + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to slurp http response body: %w", err) + } + if err := resp.Body.Close(); err != nil { + log.Warn().Err(err).Msg("Failed to close response body, proceeding") + } + + captcha := TryUnmarshalingCaptcha(ctx, resp, body) + if captcha != nil { + goto solveCaptchaAndRetry + } + + if !respIsOk(resp) { + // (defer block above logs for us.) + + var apiError APIError + err := json.Unmarshal(body, &apiError) + + if err != nil || apiError.Code == 0 { + // Doesn't look like we got {"code": 00000, "message": "..."} + return nil, nil, HTTPError{body: body, resp: resp} + } else { + apiError.ResponseBody = body + return nil, nil, apiError + } + } + + // No CAPTCHA, we're good. + return resp, body, nil + + solveCaptchaAndRetry: + // We got a CAPTCHA. Invoke the handler provided by the client and + // retry with the challenge response once the CAPTCHA is completed. + + log = ptr.Ptr(captcha.LogContext(log.With()).Logger()) + log.Info().Msg("Encountered CAPTCHA challenge") + + solution, err := am.waitForCaptchaSolve(ctx, captcha) + if err != nil { + return nil, nil, fmt.Errorf("failed to wait for captcha solution: %w", err) + } + + // We're going to try the request again once we come back around in the + // loop. + req, err = refreshReq(ctx, req) + if err != nil { + return nil, nil, fmt.Errorf("failed to refresh request: %w", err) + } + // Add the solution and other CAPTCHA state to the headers. + req.Header.Set(HeaderCaptchaKey, solution.Solution) + captcha.UpdateHeaders(&req.Header) + } +} + +func (am *AuthMachine) waitForCaptchaSolve(ctx context.Context, captcha *Captcha) (*CaptchaSolution, error) { + log := zerolog.Ctx(ctx).With().Str("action", "wait for discord captcha solve").Logger() + ctx = log.WithContext(ctx) + + log.Info().Msg("Invoking CAPTCHA handler") + solution, err := am.handler.SolveCaptcha(ctx, captcha) + if err != nil { + return nil, fmt.Errorf("captcha handler failed: %w", err) + } + if solution == nil { + return nil, fmt.Errorf("captcha handler returned nil solution") + } + + return solution, nil +} + +// doHandlingCaptcha performs an HTTP request, mutating it to contain headers +// from the [Personality]. +// +// - In order to detect and respond to CAPTCHA challenges, this method buffers +// all request and response bodies into memory. +// +// - Should a CAPTCHA challenge occur, note that multiple attempts to solve the +// CAPTCHA may be necessary. +func (am *AuthMachine) doHandlingCaptcha(ctx context.Context, req *http.Request) (*http.Response, []byte, error) { + log := zerolog.Ctx(ctx).With(). + Str("http_method", req.Method). + Stringer("http_url", req.URL). + Logger() + ctx = log.WithContext(ctx) + + // Add all personality headers to the request. + personalityHeaders, err := am.Personality.Headers() + if err != nil { + return nil, nil, fmt.Errorf("failed to get personality headers: %w", err) + } + maps.Copy(req.Header, personalityHeaders) + // Set X-Debug-Options if we have one. + debugOptions := am.Personality.DebugOptions + if debugOptions != "" { + req.Header.Set(HeaderDebugOptions, debugOptions) + } + // Set X-Fingerprint if we have one. + if !am.State.Fingerprint.IsZero() { + req.Header.Set(HeaderFingerprint, am.State.Fingerprint.HeaderValue()) + } + + // Make the request, anticipating any potential CAPTCHAs. + resp, body, err := am.captchaRetryLoop(ctx, req) + if err != nil { + return nil, nil, fmt.Errorf("failed to make request: %w", err) + } + + return resp, body, err +} + +func (am *AuthMachine) performLegacyExperiments(ctx context.Context) (*ExperimentsLegacy, error) { + url := fmt.Sprintf("%s/experiments?with_guild_experiments=true", am.APIBase) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to construct legacy experiments request: %w", err) + } + + // Set X-Context-Properties. + contextProps, err := EncodeBasicContextProperties(ContextLocationLogin) + if err != nil { + return nil, fmt.Errorf("failed to encode login context properties: %w", err) + } + req.Header.Set(HeaderContextProperties, contextProps) + + _, body, err := am.doHandlingCaptcha(ctx, req) + if err != nil { + return nil, fmt.Errorf("failed to request legacy experiments: %w", err) + } + + var legacy ExperimentsLegacy + err = json.Unmarshal(body, &legacy) + if err != nil { + return nil, fmt.Errorf("failed to decode legacy experiments: %w", err) + } + + return &legacy, nil +} + +func (am *AuthMachine) performApexExperiments(ctx context.Context) (any, error) { + url := fmt.Sprintf("%s/apex/experiments?surface=2", am.APIBase) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to construct apex experiments request: %w", err) + } + + // (Apex experiments don't get `X-Context-Properties`.) + _, _, err = am.doHandlingCaptcha(ctx, req) + if err != nil { + return nil, fmt.Errorf("failed to request apex experiments: %w", err) + } + + return nil, nil +} + +// Prepare loads the login page and situates the AuthMachine with an +// experiments-related [Fingerprint]. It is important for Prepare to be called +// before [AuthMachine.Login]. +// +// Calling this method can lead to your [ChallengeHandler] being called. +func (am *AuthMachine) Prepare(ctx context.Context) error { + log := am.log.With().Str("action", "prepare discord auth machine").Logger() + ctx = log.WithContext(ctx) + log.Info().Msg("Preparing Discord auth") + + legacy, err := am.performLegacyExperiments(ctx) + if err != nil { + return fmt.Errorf("failed to perform legacy experiments: %w", err) + } + + _, err = am.performApexExperiments(ctx) + if err != nil { + return fmt.Errorf("failed to perform apex experiments: %w", err) + } + + // (Apex experiments aren't fetched with the fingerprint, so only set it + // now.) + am.State.Fingerprint = legacy.Fingerprint + if am.LogFilters.Fingerprint { + log.Info().Str("fingerprint", am.State.Fingerprint.HeaderValue()).Msg("Loaded Discord fingerprint") + } + + return nil +} + +// FIXME(skip): Load the HTML /login page before anything else so we can seed our cookies with Cloudflare stuff. +// FIXME(skip): Handle IP verification. +// FIXME(skip): Handle suspended user tokens. + +// Once you have called [AuthMachine.Prepare], Login kicks off the login +// process and doesn't return until the login is complete and a token is +// acquired, unless an error occurs at any point. +// +// CAPTCHA and MFA handling is automatically relegated to your +// [ChallengeHandler] and its methods will be called as necessary. +func (am *AuthMachine) Login(ctx context.Context, creds *Creds) (*LoginCompleted, error) { + log := zerolog.Ctx(ctx) + + if am.State.Fingerprint.IsZero() { + return nil, fmt.Errorf("can't log in without a fingerprint (forgot to call Prepare?)") + } + + firstLoginReq, err := am.POST(ctx, "/auth/login", creds) + if err != nil { + return nil, fmt.Errorf("failed to construct login request: %w", err) + } + + _, body, err := am.doHandlingCaptcha(ctx, firstLoginReq) + if err != nil { + return nil, fmt.Errorf("failed to request login: %w", err) + } + + loginResponse, err := am.handleFirstLoginResponse(ctx, body) + if err != nil { + return nil, err + } + + if am.LogFilters.SuccessfulLogin { + ev := log.Info() + if am.LogFilters.LoggedInUserID { + ev = ev.Str("user_id", loginResponse.UserID).Str("user_locale", loginResponse.UserSettings.Locale) + } + ev.Msg("Logged in successfully") + } + + return loginResponse, nil +} + +// handleFirstLoginResponse handles the response body from POSTing to +// /auth/login. This will either complete the login or begin an MFA flow. +func (am *AuthMachine) handleFirstLoginResponse(ctx context.Context, loginRespBody []byte) (*LoginCompleted, error) { + log := zerolog.Ctx(ctx) + + var completed LoginCompleted + err := json.Unmarshal(loginRespBody, &completed) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal login response: %w", err) + } + if !completed.HasToken() { + log.Debug().Msg("Response lacked a token, attempting to handle as MFA") + completedMfa, err := am.tryHandlingMFA(ctx, loginRespBody) + if err != nil { + return nil, fmt.Errorf("failed to handle potential MFA: %w", err) + } + + if completedMfa == nil || !completedMfa.HasToken() { + // Still unable to handle whatever we got as a response from POST + // /auth/login, give up. Log the response for diagnostics. + log.Error().Str("response_body", string(loginRespBody)).Msg("Received corrupted login response") + return nil, fmt.Errorf("corrupted login response") + } + return completedMfa, nil + } + + return &completed, nil +} + +func (am *AuthMachine) requestSMSCode(ctx context.Context, state *MFAState) (*SMSSendResponse, error) { + smsSendReq, err := am.POST(ctx, "/auth/mfa/sms/send", SMSSendRequest{ + Ticket: state.Ticket, + }) + if err != nil { + return nil, fmt.Errorf("failed to construct SMS send code request: %w", err) + } + smsSendReq.Header.Set("Content-Type", "application/json") + + _, body, err := am.doHandlingCaptcha(ctx, smsSendReq) + if err != nil { + return nil, fmt.Errorf("failed to request SMS code: %w", err) + } + + var resp SMSSendResponse + err = json.Unmarshal(body, &resp) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal SMS code response: %w", err) + } + + return &resp, nil +} + +func (am *AuthMachine) tryHandlingMFA(ctx context.Context, loginRespBody []byte) (*LoginCompleted, error) { + baseLog := zerolog.Ctx(ctx) + + var required LoginMFARequired + err := json.Unmarshal(loginRespBody, &required) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal mfa required: %w", err) + } + + if !required.MFARequired { + // This isn't actually a LoginMFARequired. + return nil, nil + } + + logCtx := baseLog.With(). + Str("mfa_login_instance_id", required.LoginInstanceID). + Bool("mfa_accepting_backup_codes", required.BackupCodesAccepted). + Bool("mfa_sms_enabled", required.SMSEnabled). + Bool("mfa_totp_enabled", required.TOTPEnabled). + Bool("mfa_has_webauthn_credential", required.WebAuthnCredential != nil) + if am.LogFilters.LoggedInUserID { + logCtx = logCtx.Str("user_id", required.UserID) + } + log := logCtx.Logger() + ctx = log.WithContext(ctx) + + log.Info().Msg("Need to log in with MFA") + cont, err := am.handler.ContinueMFA(ctx, &MFAChallenge{ + LoginMFARequired: &required, + RequestSMS: func(ctx context.Context) (*SMSSendResponse, error) { + // Thread the MFAState through on behalf of the client. + return am.requestSMSCode(ctx, &required.MFAState) + }, + }) + if err != nil { + return nil, fmt.Errorf("failed to continue mfa flow: %w", err) + } + if cont == nil { + return nil, fmt.Errorf("no MFA continuation returned") + } + + log.Info().Str("mfa_type", string(cont.Type)).Msg("Continuing with MFA flow") + + contReq, err := am.POST(ctx, fmt.Sprintf("/auth/mfa/%s", cont.Type), cont.MFAContinuation) + if err != nil { + return nil, fmt.Errorf("failed to construct MFA continuation request: %w", err) + } + + _, body, err := am.doHandlingCaptcha(ctx, contReq) + if err != nil { + return nil, fmt.Errorf("failed to complete MFA flow: %w", err) + } + var completed LoginCompleted + err = json.Unmarshal(body, &completed) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal completed MFA: %w", err) + } + + // Discord omits the user ID when completing the MFA flow as we already + // received it as part of LoginMFARequired. Re-add it here. + if completed.UserID == "" { + log.Trace().Msg("Fixing up MFA completion with the user ID") + completed.UserID = required.UserID + } + + return &completed, nil +} + +func (am *AuthMachine) POST(ctx context.Context, endpoint string, jsonBody any) (*http.Request, error) { + jsonBytes, err := json.Marshal(jsonBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal body for request: %w", err) + } + + url := am.APIBase + endpoint + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonBytes)) + if err != nil { + return nil, fmt.Errorf("failed to make POST request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + return req, nil +} diff --git a/pkg/discordauth/machine_log_filters.go b/pkg/discordauth/machine_log_filters.go new file mode 100644 index 0000000..dbaf732 --- /dev/null +++ b/pkg/discordauth/machine_log_filters.go @@ -0,0 +1,53 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package discordauth + +type AuthMachineLogFilters struct { + EveryHTTPRequest bool + EveryHTTPResponse bool + + SuccessfulLogin bool + LoggedInUserID bool + Fingerprint bool + + // The following fields are likely to log credentials and other sensitive + // stuff when enabled. ONLY FOR USE DURING DEVELOPMENT. + + DangerouslyLeakyHTTPHeaders bool +} + +var DefaultAuthMachineLogFilters = AuthMachineLogFilters{ + EveryHTTPRequest: true, + EveryHTTPResponse: true, + + SuccessfulLogin: true, + LoggedInUserID: false, + Fingerprint: false, + + DangerouslyLeakyHTTPHeaders: false, +} + +var LeakyDevelopmentAuthMachineLogFilters = AuthMachineLogFilters{ + EveryHTTPRequest: true, + EveryHTTPResponse: true, + + SuccessfulLogin: true, + LoggedInUserID: true, + Fingerprint: true, + + DangerouslyLeakyHTTPHeaders: true, +} diff --git a/pkg/discordauth/machine_test.go b/pkg/discordauth/machine_test.go new file mode 100644 index 0000000..293a035 --- /dev/null +++ b/pkg/discordauth/machine_test.go @@ -0,0 +1,76 @@ +package discordauth + +import ( + "context" + "errors" + "io" + "net/http" + "strings" + "testing" +) + +type testHTTPClient func(req *http.Request) (*http.Response, error) + +func (thc testHTTPClient) Do(req *http.Request) (*http.Response, error) { + return thc(req) +} + +type testChallengeHandler struct{} + +func (testChallengeHandler) SolveCaptcha(context.Context, *Captcha) (*CaptchaSolution, error) { + return &CaptchaSolution{Solution: "test-captcha-solution"}, nil +} + +func (testChallengeHandler) ContinueMFA(context.Context, *MFAChallenge) (*MFAContinue, error) { + return nil, errors.New("unexpected MFA continuation in test") +} + +func newTestPersonality() *Personality { + return &Personality{ + UserAgent: "test-agent", + Locale: "en-US", + TimeZone: "UTC", + DebugOptions: DefaultDebugOptions, + SuperProperties: SuperProperties{ + OS: "Windows", + Browser: "Chrome", + BrowserUserAgent: "test-agent", + BrowserVersion: "1.0.0.0", + OSVersion: "10", + ReleaseChannel: "stable", + ClientBuildNumber: 1, + ClientLaunchID: "launch-id", + ClientAppState: "focused", + }, + } +} + +func newResponse(status int, body string) *http.Response { + return &http.Response{ + StatusCode: status, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func TestDoHandlingCaptchaAddsDebugOptionsHeader(t *testing.T) { + var gotHeader http.Header + client := testHTTPClient(func(req *http.Request) (*http.Response, error) { + gotHeader = req.Header.Clone() + return newResponse(http.StatusOK, `{"ok":true}`), nil + }) + + am := NewAuthMachine(context.Background(), client, newTestPersonality(), testChallengeHandler{}) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://example.com/test", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + _, _, err = am.doHandlingCaptcha(context.Background(), req) + if err != nil { + t.Fatalf("doHandlingCaptcha returned error: %v", err) + } + if gotHeader.Get(HeaderDebugOptions) != "bugReporterEnabled" { + t.Fatalf("expected %s header to be set, got %q", HeaderDebugOptions, gotHeader.Get(HeaderDebugOptions)) + } +} diff --git a/pkg/discordauth/mfa.go b/pkg/discordauth/mfa.go new file mode 100644 index 0000000..878b41f --- /dev/null +++ b/pkg/discordauth/mfa.go @@ -0,0 +1,108 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package discordauth + +import "context" + +// An AuthenticatorType is what you append to "/auth/mfa/" to respond to an MFA +// challenge with that MFA method. +// +// For example, to use a TOTP code, you'd POST to "/auth/mfa/totp". +// [AuthenticatorTOTP] is "totp". +type AuthenticatorType string + +const ( + AuthenticatorTOTP AuthenticatorType = "totp" + AuthenticatorSMS AuthenticatorType = "sms" + AuthenticatorBackup AuthenticatorType = "backup" + AuthenticatorWebAuthn AuthenticatorType = "webauthn" + AuthenticatorPassword AuthenticatorType = "password" +) + +// MFAChallenge encapsulates the context regarding an in-progress MFA flow. Use +// the data from this struct to inform how you will proceed with +// authentication. +// +// This is received by a [ChallengeHandler]. +type MFAChallenge struct { + *LoginMFARequired + + // RequestSMS asks Discord to send a MFA code to the user's phone number. + // This will only work if the user has SMS MFA enabled. + RequestSMS func(context.Context) (*SMSSendResponse, error) +} + +// An [MFAContinue] combines an [AuthenticatorType] with an [MFAContinuation], +// which lets the [AuthMachine] know how to make the HTTP request to Discord. +// +// This is returned out of a [ChallengeHandler] when the client is ready to let +// the library know how to proceed with the MFA log in flow. +type MFAContinue struct { + Type AuthenticatorType + MFAContinuation +} + +// An MFAState encapsulates the essential, opaque data that is received from +// Discord when MFA is required to proceed with a log in. This data must be +// sent back as part of your MFA response ([MFAContinuation]). +// +// This struct exists solely for organizational purposes. +type MFAState struct { + Ticket Sensitive[string] `json:"ticket"` + LoginInstanceID string `json:"login_instance_id"` +} + +// A LoginMFARequired is returned from Discord's login endpoint when the +// password is accepted, but another authentication factor is required. +type LoginMFARequired struct { + MFAState + + UserID string `json:"user_id"` + MFARequired bool `json:"mfa"` // multi-factor authentication is required to log in + SMSEnabled bool `json:"sms"` // whether SMS-based MFA is enabled + BackupCodesAccepted bool `json:"backup"` // whether backup codes can be used in the response + TOTPEnabled bool `json:"totp"` + WebAuthnCredential *string `json:"webauthn"` // JSON string of {"publicKey": {"challenge": ...}} +} + +// POST an MFAContinuation to Discord upon receiving a [LoginMFARequired] and +// you have the necessary code (TOTP, SMS, backup, WebAuthn, etc.) to continue. +type MFAContinuation struct { + MFAState + + // The TOTP, SMS code, backup code, or Webauthn credential used to complete + // the MFA flow. + // + // Backup codes are displayed hyphenated in Discord's UI, which visually + // splits them in half. Discord's API will not accept backup codes with the + // hyphens intact, so they must be stripped before submission. + Code string `json:"code"` + + GiftCodeSKUID *string `json:"gift_code_sku_id"` + LoginSource *string `json:"login_source"` +} + +// POST an SMSSendRequest to Discord upon receiving a [LoginMFARequired] if SMS +// is a permitted MFA path and you'd like to send an SMS code to the user. +type SMSSendRequest struct { + Ticket Sensitive[string] `json:"ticket"` +} + +// SMSSendResponse is what Discord returns from /auth/mfa/sms/send. +type SMSSendResponse struct { + Phone string `json:"phone"` // partially redacted phone number +} diff --git a/pkg/discordauth/personality.go b/pkg/discordauth/personality.go new file mode 100644 index 0000000..19981cd --- /dev/null +++ b/pkg/discordauth/personality.go @@ -0,0 +1,110 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package discordauth + +import ( + "encoding" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + + "github.com/bwmarrin/discordgo" +) + +const HeaderDiscordLocale = "x-discord-locale" +const HeaderDiscordTimezone = "x-discord-timezone" +const HeaderSuperProperties = "x-super-properties" +const HeaderContextProperties = "x-context-properties" +const HeaderFingerprint = "x-fingerprint" +const HeaderDebugOptions = "x-debug-options" + +const DefaultDebugOptions = "bugReporterEnabled" + +// Personality encapsulates some settings that clients are likely to want to +// customize. These values are sent in nearly every HTTP request to Discord. +type Personality struct { + UserAgent string + Locale string // `x-discord-locale` + TimeZone string // `x-discord-timezone` + DebugOptions string // `x-debug-options` + SuperProperties SuperProperties // `x-super-properties` (base64) + + ExtraHeaders map[string]string +} + +func (p *Personality) Headers() (http.Header, error) { + superProps, err := p.SuperProperties.MarshalText() + if err != nil { + return nil, fmt.Errorf("failed to marshal super properties: %w", err) + } + + header := make(http.Header) + header.Set("User-Agent", p.UserAgent) + header.Set(HeaderDiscordLocale, p.Locale) + header.Set(HeaderDiscordTimezone, p.TimeZone) + header.Set(HeaderSuperProperties, string(superProps)) + + for k, v := range p.ExtraHeaders { + header.Set(k, v) + } + + return header, nil +} + +// FIXME(skip): This is missing client_heartbeat_session_id... that's only +// relevant when you have a gateway connection, though (?) + +type SuperProperties struct { + OS string `json:"os"` + Browser string `json:"browser"` + Device string `json:"device"` + SystemLocale string `json:"system_locale"` + HasClientMods bool `json:"has_client_mods"` + BrowserUserAgent string `json:"browser_user_agent"` + BrowserVersion string `json:"browser_version"` + OSVersion string `json:"os_version"` + Referrer string `json:"referrer"` + ReferringDomain string `json:"referring_domain"` + ReferrerCurrent string `json:"referrer_current"` + ReferringDomainCurrent string `json:"referring_domain_current"` + ReleaseChannel string `json:"release_channel"` + ClientBuildNumber int `json:"client_build_number"` + ClientEventSource *string `json:"client_event_source"` + ClientLaunchID string `json:"client_launch_id"` + LaunchSignature discordgo.LaunchSignature `json:"launch_signature"` + ClientAppState string `json:"client_app_state"` +} + +var _ encoding.TextMarshaler = (*SuperProperties)(nil) + +func (sp *SuperProperties) MarshalText() ([]byte, error) { + // TODO(skip): Little bit of weird looking indirection here so we don't + // recurse infinitely. Should probably just remove this, then. + type superProperties SuperProperties + spJson, err := json.Marshal((*superProperties)(sp)) + if err != nil { + return nil, err + } + + // Avoid the string() call that EncodeToString incurs. + encoding := base64.StdEncoding + buf := make([]byte, encoding.EncodedLen(len(spJson))) + encoding.Encode(buf, spJson) + + return buf, nil +} diff --git a/pkg/discordauth/personality_test.go b/pkg/discordauth/personality_test.go new file mode 100644 index 0000000..dc12ff6 --- /dev/null +++ b/pkg/discordauth/personality_test.go @@ -0,0 +1,34 @@ +package discordauth + +import ( + "encoding/base64" + "encoding/json" + "testing" +) + +func TestPersonalityHeadersEncodesSuperProperties(t *testing.T) { + personality := newTestPersonality() + + headers, err := personality.Headers() + if err != nil { + t.Fatalf("Headers returned error: %v", err) + } + + encoded := headers.Get(HeaderSuperProperties) + if encoded == "" { + t.Fatal("expected super properties header to be set") + } + + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + t.Fatalf("failed to decode super properties header: %v", err) + } + + var parsed map[string]any + if err = json.Unmarshal(decoded, &parsed); err != nil { + t.Fatalf("failed to unmarshal super properties JSON: %v", err) + } + if parsed["client_build_number"] != float64(1) { + t.Fatalf("expected client_build_number to equal 1, got %#v", parsed["client_build_number"]) + } +} diff --git a/pkg/discordauth/sensitive.go b/pkg/discordauth/sensitive.go new file mode 100644 index 0000000..f35b269 --- /dev/null +++ b/pkg/discordauth/sensitive.go @@ -0,0 +1,67 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package discordauth + +import ( + "encoding/json" + "fmt" + "io" + "reflect" +) + +// Sensitive is a trivial mitigation that guards against accidental leakage of +// important values such as passwords. It is not a security boundary and may be +// trivially unwrapped via [Sensitive.UnwrapSensitive], reflection, etc. +type Sensitive[T any] struct { + inner T +} + +var _ json.Marshaler = (*Sensitive[any])(nil) +var _ json.Unmarshaler = (*Sensitive[any])(nil) + +func NewSensitive[T any](inner T) Sensitive[T] { + return Sensitive[T]{inner} +} + +func (s Sensitive[T]) IsZero() bool { + return reflect.ValueOf(s.inner).IsZero() +} + +// UnwrapSensitive returns the sensitive data inside. +func (s Sensitive[T]) UnwrapSensitive() T { + return s.inner +} + +func (Sensitive[T]) Format(f fmt.State, verb rune) { + _, _ = io.WriteString(f, "") +} + +func (Sensitive[T]) String() string { + return "" +} + +func (Sensitive[T]) GoString() string { + return "" +} + +func (s Sensitive[T]) MarshalJSON() ([]byte, error) { + return json.Marshal(s.inner) +} + +func (s *Sensitive[T]) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &s.inner) +} diff --git a/pkg/discordid/dbmeta.go b/pkg/discordid/dbmeta.go new file mode 100644 index 0000000..d8b6ba3 --- /dev/null +++ b/pkg/discordid/dbmeta.go @@ -0,0 +1,58 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package discordid + +import ( + "github.com/bwmarrin/discordgo" + "maunium.net/go/mautrix/bridgev2/database" +) + +type PortalMetadata struct { + // The ID of the Discord guild that the channel corresponding to this portal + // belongs to. + // + // For private channels (DMs and group DMs), this will be the zero value + // (an empty string). + GuildID string `json:"guild_id"` + + // The type of Discord channel this portal corresponds to. + // + // This is omitted for guild space portals. + ChannelType *discordgo.ChannelType `json:"channel_type,omitempty"` +} + +type UserLoginMetadata struct { + Token string `json:"token"` + HeartbeatSession discordgo.HeartbeatSession `json:"heartbeat_session"` + BridgedGuildIDs map[string]bool `json:"bridged_guild_ids,omitempty"` +} + +var _ database.MetaMerger = (*UserLoginMetadata)(nil) + +func (ulm *UserLoginMetadata) CopyFrom(incoming any) { + incomingMeta, ok := incoming.(*UserLoginMetadata) + if !ok || incomingMeta == nil { + return + } + + if incomingMeta.Token != "" { + ulm.Token = incomingMeta.Token + } + ulm.HeartbeatSession = discordgo.NewHeartbeatSession() + + // Retain the BridgedGuildIDs from the existing login. +} diff --git a/pkg/discordid/id.go b/pkg/discordid/id.go new file mode 100644 index 0000000..686b4b1 --- /dev/null +++ b/pkg/discordid/id.go @@ -0,0 +1,176 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package discordid + +import ( + "strconv" + "strings" + "time" + + "github.com/bwmarrin/discordgo" + "maunium.net/go/mautrix/bridgev2/networkid" +) + +// DeletedGuildUserID is a magic user ID that is used in place of an actual user +// ID once they have deleted their account. This only applies in non-private +// (i.e. guild) contexts, such as guild channel message authors and mentions. +// +// Note that this user ID can also appear in message content as part of user +// mention markup ("<@456226577798135808>"). +const DeletedGuildUserID = "456226577798135808" + +// DeletedGuildUser is the user returned from the Discord API as a stand-in for +// users who have since deleted their account. As the name suggests, this only +// applies to fetched entities within guilds. +var DeletedGuildUser = discordgo.User{ + ID: DeletedGuildUserID, + Username: "Deleted User", + Discriminator: "0000", +} + +const DiscordEpochMillis = 1420070400000 + +// GenerateNonce creates a Discord-style snowflake nonce for message idempotency. +func GenerateNonce() string { + snowflake := (time.Now().UnixMilli() - DiscordEpochMillis) << 22 + return strconv.FormatInt(snowflake, 10) +} + +func MakeUserID(userID string) networkid.UserID { + return networkid.UserID(userID) +} + +func ParseUserID(userID networkid.UserID) string { + return string(userID) +} + +func MakeUserLoginID(userID string) networkid.UserLoginID { + return networkid.UserLoginID(userID) +} + +func ParseUserLoginID(id networkid.UserLoginID) string { + return string(id) +} + +// UserLoginIDToUserID converts a UserLoginID to a UserID. In Discord, both +// are the same underlying snowflake. +func UserLoginIDToUserID(id networkid.UserLoginID) networkid.UserID { + return networkid.UserID(id) +} + +// MakeChannelPortalKey creates a PortalKey from a Discord channel ID and the +// user login it was received from. +// +// If you can reach a DiscordClient, prefer calling the helper methods defined +// on it instead, as split portal configuration will be respected for you. +func MakeChannelPortalKey(channelID string, userLoginID networkid.UserLoginID, wantReceiver bool) (key networkid.PortalKey) { + key.ID = MakeChannelPortalIDWithID(channelID) + if wantReceiver { + key.Receiver = userLoginID + } + return +} + +func MakeChannelPortalKeyWithID(channelID string) (key networkid.PortalKey) { + key.ID = MakeChannelPortalIDWithID(channelID) + return +} + +func MakeChannelPortalIDWithID(channelID string) networkid.PortalID { + return networkid.PortalID(channelID) +} + +func ParseChannelPortalID(portalID networkid.PortalID) string { + return string(portalID) +} + +func MakeMessageID(messageID string) networkid.MessageID { + return networkid.MessageID(messageID) +} + +func ParseMessageID(messageID networkid.MessageID) string { + return string(messageID) +} + +func MakeEmojiID(emojiName string) networkid.EmojiID { + return networkid.EmojiID(emojiName) +} + +func ParseEmojiID(emojiID networkid.EmojiID) string { + return string(emojiID) +} + +func MakeAvatarID(avatar string) networkid.AvatarID { + return networkid.AvatarID(avatar) +} + +func MakePartID(attachmentID string) networkid.PartID { + return networkid.PartID(attachmentID) +} + +func ParsePartID(attachmentID string) string { + return string(attachmentID) +} + +// The string prepended to [networkid.PortalKey]s identifying spaces that +// bridge Discord guilds. +// +// Every Discord guild created before August 2017 contained a channel +// having _the same ID as the guild itself_. This channel also functioned as +// the "default channel" in that incoming members would view this channel by +// default. It was also impossible to delete. +// +// After this date, these "default channels" became deletable, and fresh guilds +// were no longer created with a channel that exactly corresponded to the guild +// ID. +// +// To accommodate Discord guilds created before this API change that have also +// never deleted the default channel, we need a way to distinguish between the +// guild and the default channel. Otherwise, we wouldn't be able to bridge both +// the channel portal as well as the guild space; their keys would conflict. +// +// "*" was chosen as the asterisk character is used to filter by guilds in +// the quick switcher (in Discord's first-party clients). +// +// For more information, see: https://discord.com/developers/docs/change-log#breaking-change-default-channels:~:text=New%20guilds%20will%20no%20longer. +const GuildPortalKeySigil = "*" + +func MakeGuildPortalIDWithID(guildID string) networkid.PortalID { + return networkid.PortalID(GuildPortalKeySigil + guildID) +} + +func MakeGuildPortalKey(guildID string, userLoginID networkid.UserLoginID, wantReceiver bool) (key networkid.PortalKey) { + key.ID = MakeGuildPortalIDWithID(guildID) + if wantReceiver { + key.Receiver = userLoginID + } + return +} + +// ParseGuildPortalID converts a [network.PortalID] pointing to a guild space +// back into the guild's ID on Discord. +// +// If the portal ID does not point to a guild, then an empty string is returned. +func ParseGuildPortalID(portalID networkid.PortalID) string { + opaque := string(portalID) + if strings.HasPrefix(opaque, GuildPortalKeySigil) { + guildID := opaque[1:] + return guildID + } + + return "" +} diff --git a/pkg/discordid/mediaid.go b/pkg/discordid/mediaid.go new file mode 100644 index 0000000..bcfa789 --- /dev/null +++ b/pkg/discordid/mediaid.go @@ -0,0 +1,166 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package discordid + +import ( + "encoding/binary" + "fmt" + "strconv" + + "maunium.net/go/mautrix/bridgev2/networkid" +) + +type DirectMediaType byte + +const ( + DirectMediaTypeV1 DirectMediaType = 1 + + encodedSnowflakeSize = 8 + encodedMediaIDV1Size = 1 + 4*encodedSnowflakeSize +) + +func (dmt DirectMediaType) isSupported() bool { + switch dmt { + case DirectMediaTypeV1: + return true + } + return false +} + +type MediaInfoV1 struct { + UserLoginID networkid.UserLoginID + ChannelID string + MessageID string + AttachmentID string +} + +type MediaInfo struct { + Type DirectMediaType + MediaInfoV1 +} + +func NewMediaInfoV1(userLoginID networkid.UserLoginID, channelID, messageID, attachmentID string) MediaInfo { + return MediaInfo{ + Type: DirectMediaTypeV1, + MediaInfoV1: MediaInfoV1{ + UserLoginID: userLoginID, + ChannelID: channelID, + MessageID: messageID, + AttachmentID: attachmentID, + }, + } +} + +func (mi *MediaInfo) Encode() ([]byte, error) { + buf := make([]byte, 1, encodedMediaIDV1Size) + buf[0] = byte(mi.Type) + + appendSnowflake := func(what, snowflakeStr string) error { + snowflake, err := strconv.ParseUint(snowflakeStr, 10, 64) + if err != nil { + return fmt.Errorf("invalid %s: %w", what, err) + } + + buf = binary.BigEndian.AppendUint64(buf, snowflake) + return nil + } + + if err := appendSnowflake("user login id", ParseUserLoginID(mi.UserLoginID)); err != nil { + return nil, err + } + if err := appendSnowflake("channel id", mi.ChannelID); err != nil { + return nil, err + } + if err := appendSnowflake("message id", mi.MessageID); err != nil { + return nil, err + } + if err := appendSnowflake("attachment id", mi.AttachmentID); err != nil { + return nil, err + } + + return buf, nil +} + +func ParseMediaID(mediaID networkid.MediaID) (*MediaInfo, error) { + var info MediaInfo + + ptr := 0 + read := func(size int, what string) ([]byte, error) { + if len(mediaID) < ptr+size { + return nil, fmt.Errorf("media ID too short (%d bytes) to read %d byte %s starting at byte %d", len(mediaID), size, what, ptr) + } + b := mediaID[ptr : ptr+size] + ptr += size + return b, nil + } + readOne := func(what string) (byte, error) { + b, err := read(1, what) + if err != nil { + return 0, err + } + return b[0], nil + } + readSnowflake := func(what string) (string, error) { + snowflakeBytes, err := read(encodedSnowflakeSize, what) + if err != nil { + return "", err + } + + snowflake := binary.BigEndian.Uint64(snowflakeBytes) + return strconv.FormatUint(snowflake, 10), nil + } + + mediaType, err := readOne("media type") + if err != nil { + return nil, err + } + info.Type = DirectMediaType(mediaType) + + if !info.Type.isSupported() { + return nil, fmt.Errorf("unrecognized media type %d", info.Type) + } + + userLoginID, err := readSnowflake("user login id") + info.UserLoginID = networkid.UserLoginID(userLoginID) + if err != nil { + return nil, err + } + + channelID, err := readSnowflake("channel id") + info.ChannelID = channelID + if err != nil { + return nil, err + } + + messageID, err := readSnowflake("message id") + info.MessageID = messageID + if err != nil { + return nil, err + } + + attachmentID, err := readSnowflake("attachment id") + info.AttachmentID = attachmentID + if err != nil { + return nil, err + } + + if ptr != len(mediaID) { + return nil, fmt.Errorf("media ID has %d trailing bytes", len(mediaID)-ptr) + } + + return &info, nil +} diff --git a/pkg/discordid/mediaid_test.go b/pkg/discordid/mediaid_test.go new file mode 100644 index 0000000..4fbe14a --- /dev/null +++ b/pkg/discordid/mediaid_test.go @@ -0,0 +1,104 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package discordid + +import "testing" + +func TestMediaIDRoundTrip(t *testing.T) { + testCases := []struct { + name string + userLoginID string + channelID string + messageID string + attachmentID string + }{ + { + name: "single digit", + userLoginID: "1", + channelID: "2", + messageID: "3", + attachmentID: "4", + }, + { + name: "mixed short lengths", + userLoginID: "12", + channelID: "345", + messageID: "6789", + attachmentID: "12345", + }, + { + name: "discord sized", + userLoginID: "12345678901234567", + channelID: "234567890123456789", + messageID: "345678901234567890", + attachmentID: "456789012345678901", + }, + { + name: "nineteen digits", + userLoginID: "1000000000000000000", + channelID: "1000000000000000001", + messageID: "1000000000000000002", + attachmentID: "1000000000000000003", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + want := NewMediaInfoV1( + MakeUserLoginID(tc.userLoginID), + tc.channelID, + tc.messageID, + tc.attachmentID, + ) + + encoded, err := want.Encode() + if err != nil { + t.Fatalf("Encode() failed: %v", err) + } + if len(encoded) != encodedMediaIDV1Size { + t.Fatalf("Encode() returned %d bytes, want %d", len(encoded), encodedMediaIDV1Size) + } + + got, err := ParseMediaID(encoded) + if err != nil { + t.Fatalf("ParseMediaID() failed: %v", err) + } + if *got != want { + t.Fatalf("roundtrip mismatch:\n got: %#v\n want: %#v", *got, want) + } + }) + } +} + +func TestParseMediaIDRejectsTruncatedData(t *testing.T) { + info := NewMediaInfoV1( + MakeUserLoginID("123456789012345678"), + "223456789012345678", + "323456789012345678", + "423456789012345678", + ) + + encoded, err := info.Encode() + if err != nil { + t.Fatalf("Encode() returned error: %v", err) + } + + _, err = ParseMediaID(encoded[:len(encoded)-1]) + if err == nil { + t.Fatal("ParseMediaID() unexpectedly succeeded for truncated data") + } +} diff --git a/pkg/msgconv/attachments.go b/pkg/msgconv/attachments.go new file mode 100644 index 0000000..281b3fd --- /dev/null +++ b/pkg/msgconv/attachments.go @@ -0,0 +1,135 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package msgconv + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + + "github.com/bwmarrin/discordgo" + "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type ReuploadedAttachment struct { + MXC id.ContentURIString + File *event.EncryptedFileInfo + Size int + MimeType string +} + +func (mc *MessageConverter) ReuploadUnknownMedia( + ctx context.Context, + url string, + allowEncryption bool, +) (*ReuploadedAttachment, error) { + return mc.ReuploadMedia(ctx, url, "", "", -1, allowEncryption) +} + +func mib(size int64) float64 { + return float64(size) / 1024 / 1024 +} + +func (mc *MessageConverter) ReuploadMedia( + ctx context.Context, + downloadURL string, + mimeType string, + fileName string, + estimatedSize int, + allowEncryption bool, +) (*ReuploadedAttachment, error) { + sess := ctx.Value(contextKeyDiscordClient).(*discordgo.Session) + httpClient := sess.Client + intent := ctx.Value(contextKeyIntent).(bridgev2.MatrixAPI) + var roomID id.RoomID + if allowEncryption { + roomID = ctx.Value(contextKeyPortal).(*bridgev2.Portal).MXID + } + + req, err := http.NewRequest(http.MethodGet, downloadURL, nil) + if err != nil { + return nil, err + } + if sess.IsUser { + for key, value := range discordgo.DroidDownloadHeaders { + req.Header.Set(key, value) + } + } + + resp, err := httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode > 300 { + errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + logEvt := zerolog.Ctx(ctx).Error(). + Str("media_url", downloadURL). + Int("status_code", resp.StatusCode) + if json.Valid(errBody) { + logEvt.RawJSON("error_json", errBody) + } else { + logEvt.Bytes("error_body", errBody) + } + logEvt.Msg("Media download failed") + return nil, fmt.Errorf("%w: unexpected status code %d", bridgev2.ErrMediaDownloadFailed, resp.StatusCode) + } else if resp.ContentLength > mc.MaxFileSize { + return nil, fmt.Errorf("%w (%.2f MiB > %.2f MiB)", bridgev2.ErrMediaTooLarge, mib(resp.ContentLength), mib(mc.MaxFileSize)) + } + + requireFile := mimeType == "" + var size int64 + mxc, file, err := intent.UploadMediaStream(ctx, roomID, int64(estimatedSize), requireFile, func(file io.Writer) (*bridgev2.FileStreamResult, error) { + var mbe *http.MaxBytesError + size, err = io.Copy(file, http.MaxBytesReader(nil, resp.Body, mc.MaxFileSize)) + if err != nil { + if errors.As(err, &mbe) { + return nil, fmt.Errorf("%w (over %.2f MiB)", bridgev2.ErrMediaTooLarge, mib(mc.MaxFileSize)) + } + return nil, err + } + if mimeType == "" { + mimeBuf := make([]byte, 512) + n, err := file.(*os.File).ReadAt(mimeBuf, 0) + if err != nil && !errors.Is(err, io.EOF) { + return nil, fmt.Errorf("couldn't read file for mime detection: %w", err) + } + mimeType = http.DetectContentType(mimeBuf[:n]) + } + return &bridgev2.FileStreamResult{ + FileName: fileName, + MimeType: mimeType, + }, nil + }) + if err != nil { + return nil, err + } + + return &ReuploadedAttachment{ + Size: int(size), + MXC: mxc, + File: file, + MimeType: mimeType, + }, nil +} diff --git a/pkg/msgconv/embed.go b/pkg/msgconv/embed.go new file mode 100644 index 0000000..0bb6921 --- /dev/null +++ b/pkg/msgconv/embed.go @@ -0,0 +1,97 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package msgconv + +import ( + "regexp" + + "github.com/bwmarrin/discordgo" +) + +type BridgeEmbedType int + +const ( + EmbedUnknown BridgeEmbedType = iota + EmbedRich + EmbedLinkPreview + EmbedVideo +) + +const discordLinkPattern = `https?://[^<\p{Zs}\x{feff}]*[^"'),.:;\]\p{Zs}\x{feff}]` + +// Discord links start with http:// or https://, contain at least two characters afterwards, +// don't contain < or whitespace anywhere, and don't end with "'),.:;] +// +// Zero-width whitespace is mostly in the Format category and is allowed, except \uFEFF isn't for some reason +var discordLinkRegex = regexp.MustCompile(discordLinkPattern) +var discordLinkRegexFull = regexp.MustCompile("^" + discordLinkPattern + "$") + +func isActuallyLinkPreview(embed *discordgo.MessageEmbed) bool { + // Sending YouTube links creates a video embed, but we want to bridge it as a URL preview, + // so this is a hacky way to detect those. + return embed.Video != nil && embed.Video.ProxyURL == "" +} + +// isPlainGifMessage returns whether a Discord message consists entirely of a +// link to a GIF-like animated image. A single embed must also be present on the +// message. +// +// This helps replicate Discord first-party client behavior, where the link is +// hidden when these same conditions are fulfilled. +func isPlainGifMessage(msg *discordgo.Message) bool { + if len(msg.Embeds) != 1 { + return false + } + embed := msg.Embeds[0] + isGifVideo := embed.Type == discordgo.EmbedTypeGifv && embed.Video != nil + isGifImage := embed.Type == discordgo.EmbedTypeImage && embed.Image == nil && embed.Thumbnail != nil && embed.Title == "" + contentIsOnlyURL := msg.Content == embed.URL || discordLinkRegexFull.MatchString(msg.Content) + return contentIsOnlyURL && (isGifVideo || isGifImage) +} + +// getEmbedType determines how a Discord embed should be bridged to Matrix by +// returning a BridgeEmbedType. +func getEmbedType(msg *discordgo.Message, embed *discordgo.MessageEmbed) BridgeEmbedType { + switch embed.Type { + case discordgo.EmbedTypeLink, discordgo.EmbedTypeArticle: + return EmbedLinkPreview + case discordgo.EmbedTypeVideo: + if isActuallyLinkPreview(embed) { + return EmbedLinkPreview + } + return EmbedVideo + case discordgo.EmbedTypeGifv: + return EmbedVideo + case discordgo.EmbedTypeImage: + if msg != nil && isPlainGifMessage(msg) { + return EmbedVideo + } else if embed.Image == nil && embed.Thumbnail != nil { + return EmbedLinkPreview + } + return EmbedRich + case discordgo.EmbedTypeRich: + return EmbedRich + default: + return EmbedUnknown + } +} + +var hackyReplyPattern = regexp.MustCompile(`^\*\*\[Replying to]\(https://discord.com/channels/(\d+)/(\d+)/(\d+)\)`) + +func isReplyEmbed(embed *discordgo.MessageEmbed) bool { + return hackyReplyPattern.MatchString(embed.Description) +} diff --git a/pkg/msgconv/formatter.go b/pkg/msgconv/formatter.go new file mode 100644 index 0000000..88a859f --- /dev/null +++ b/pkg/msgconv/formatter.go @@ -0,0 +1,132 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package msgconv + +import ( + "fmt" + "regexp" + "strings" + + "github.com/yuin/goldmark" + "github.com/yuin/goldmark/extension" + "github.com/yuin/goldmark/parser" + "github.com/yuin/goldmark/util" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/format" + "maunium.net/go/mautrix/format/mdext" +) + +// escapeFixer is a hacky partial fix for the difference in escaping markdown, used with escapeReplacement +// +// Discord allows escaping with just one backslash, e.g. \__a__, +// but standard markdown requires both to be escaped (\_\_a__) +var escapeFixer = regexp.MustCompile(`\\(__[^_]|\*\*[^*])`) + +func escapeReplacement(s string) string { + return s[:2] + `\` + s[2:] +} + +// indentableParagraphParser is the default paragraph parser with CanAcceptIndentedLine. +// Used when disabling CodeBlockParser (as disabling it without a replacement will make indented blocks disappear). +type indentableParagraphParser struct { + parser.BlockParser +} + +var defaultIndentableParagraphParser = &indentableParagraphParser{BlockParser: parser.NewParagraphParser()} + +func (b *indentableParagraphParser) CanAcceptIndentedLine() bool { + return true +} + +var removeFeaturesExceptLinks = []any{ + parser.NewListParser(), parser.NewListItemParser(), parser.NewHTMLBlockParser(), parser.NewRawHTMLParser(), + parser.NewSetextHeadingParser(), parser.NewThematicBreakParser(), + parser.NewCodeBlockParser(), +} +var removeFeaturesAndLinks = append(removeFeaturesExceptLinks, parser.NewLinkParser()) +var fixIndentedParagraphs = goldmark.WithParserOptions(parser.WithBlockParsers(util.Prioritized(defaultIndentableParagraphParser, 500))) +var discordExtensions = goldmark.WithExtensions(extension.Strikethrough, mdext.SimpleSpoiler, mdext.DiscordUnderline, ExtDiscordEveryone, ExtDiscordTag) + +var discordRenderer = goldmark.New( + goldmark.WithParser(mdext.ParserWithoutFeatures(removeFeaturesAndLinks...)), + fixIndentedParagraphs, format.HTMLOptions, discordExtensions, +) +var discordRendererWithInlineLinks = goldmark.New( + goldmark.WithParser(mdext.ParserWithoutFeatures(removeFeaturesExceptLinks...)), + fixIndentedParagraphs, format.HTMLOptions, discordExtensions, +) + +// renderDiscordMarkdownOnlyHTML converts Discord-flavored Markdown text to HTML. +// +// After conversion, if the text is surrounded by a single outermost paragraph +// tag, it is unwrapped. +func (mc *MessageConverter) renderDiscordMarkdownOnlyHTML(portal *bridgev2.Portal, source *bridgev2.UserLogin, text string, allowInlineLinks bool) string { + return format.UnwrapSingleParagraph(mc.renderDiscordMarkdownOnlyHTMLNoUnwrap(portal, source, text, allowInlineLinks)) +} + +// renderDiscordMarkdownOnlyHTMLNoUnwrap converts Discord-flavored Markdown text to HTML. +func (mc *MessageConverter) renderDiscordMarkdownOnlyHTMLNoUnwrap(portal *bridgev2.Portal, source *bridgev2.UserLogin, text string, allowInlineLinks bool) string { + text = escapeFixer.ReplaceAllStringFunc(text, escapeReplacement) + + var buf strings.Builder + ctx := parser.NewContext() + ctx.Set(parserContextPortal, portal) + ctx.Set(parserContextUserLogin, source) + renderer := discordRenderer + if allowInlineLinks { + renderer = discordRendererWithInlineLinks + } + err := renderer.Convert([]byte(text), &buf, parser.WithContext(ctx)) + if err != nil { + panic(fmt.Errorf("markdown parser errored: %w", err)) + } + return buf.String() +} + +const formatterContextPortalKey = "fi.mau.discord.portal" +const formatterContextAllowedMentionsKey = "fi.mau.discord.allowed_mentions" +const formatterContextInputAllowedMentionsKey = "fi.mau.discord.input_allowed_mentions" +const formatterContextInputAllowedLinkPreviewsKey = "fi.mau.discord.input_allowed_link_previews" + +var discordMarkdownEscaper = strings.NewReplacer( + `\`, `\\`, + `_`, `\_`, + `*`, `\*`, + `~`, `\~`, + "`", "\\`", + `|`, `\|`, + `<`, `\<`, + `#`, `\#`, +) + +func escapeDiscordMarkdown(s string) string { + submatches := discordLinkRegex.FindAllStringIndex(s, -1) + if submatches == nil { + return discordMarkdownEscaper.Replace(s) + } + var builder strings.Builder + offset := 0 + for _, match := range submatches { + start := match[0] + end := match[1] + builder.WriteString(discordMarkdownEscaper.Replace(s[offset:start])) + builder.WriteString(s[start:end]) + offset = end + } + builder.WriteString(discordMarkdownEscaper.Replace(s[offset:])) + return builder.String() +} diff --git a/formatter_everyone.go b/pkg/msgconv/formatter_everyone.go similarity index 98% rename from formatter_everyone.go rename to pkg/msgconv/formatter_everyone.go index b1aed5a..6a2195f 100644 --- a/formatter_everyone.go +++ b/pkg/msgconv/formatter_everyone.go @@ -1,5 +1,5 @@ // mautrix-discord - A Matrix-Discord puppeting bridge. -// Copyright (C) 2023 Tulir Asokan +// Copyright (C) 2026 Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -14,7 +14,7 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . -package main +package msgconv import ( "fmt" diff --git a/formatter_tag.go b/pkg/msgconv/formatter_tag.go similarity index 66% rename from formatter_tag.go rename to pkg/msgconv/formatter_tag.go index fb7f741..559b73f 100644 --- a/formatter_tag.go +++ b/pkg/msgconv/formatter_tag.go @@ -1,5 +1,5 @@ // mautrix-discord - A Matrix-Discord puppeting bridge. -// Copyright (C) 2022 Tulir Asokan +// Copyright (C) 2026 Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -14,30 +14,37 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . -package main +package msgconv import ( + "context" "fmt" + "html" "math" "regexp" "strconv" "strings" "time" + "github.com/rs/zerolog" "github.com/yuin/goldmark" "github.com/yuin/goldmark/ast" "github.com/yuin/goldmark/parser" "github.com/yuin/goldmark/renderer" "github.com/yuin/goldmark/text" "github.com/yuin/goldmark/util" + "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/id" - "go.mau.fi/mautrix-discord/database" + "go.mau.fi/mautrix-discord/pkg/connector/discorddb" + "go.mau.fi/mautrix-discord/pkg/discordid" + "go.mau.fi/mautrix-discord/pkg/router" ) type astDiscordTag struct { ast.BaseInline - portal *Portal + source *bridgev2.UserLogin + portal *bridgev2.Portal id int64 } @@ -136,6 +143,15 @@ func (n *astDiscordCustomEmoji) String() string { type discordTagParser struct{} +type customEmojiMXCProvider interface { + GetCustomEmojiMXC(ctx context.Context, emojiID, name string, animated bool) (id.ContentURIString, error) +} + +// (This interface is to avoid an import cycle.) +type roleInfoProvider interface { + GetRoleByID(ctx context.Context, guildID, roleID string) (*discorddb.Role, error) +} + // Regex to match everything in https://discord.com/developers/docs/reference#message-formatting var discordTagRegex = regexp.MustCompile(`<(a?:\w+:|@[!&]?|#|t:)(\d+)(?::([tTdDfFR])|(\d+):(.+?))?>`) var defaultDiscordTagParser = &discordTagParser{} @@ -145,9 +161,11 @@ func (s *discordTagParser) Trigger() []byte { } var parserContextPortal = parser.NewContextKey() +var parserContextUserLogin = parser.NewContextKey() func (s *discordTagParser) Parse(parent ast.Node, block text.Reader, pc parser.Context) ast.Node { - portal := pc.Get(parserContextPortal).(*Portal) + portal := pc.Get(parserContextPortal).(*bridgev2.Portal) + source := pc.Get(parserContextUserLogin).(*bridgev2.UserLogin) //before := block.PrecendingCharacter() line, _ := block.PeekLine() match := discordTagRegex.FindSubmatch(line) @@ -161,7 +179,7 @@ func (s *discordTagParser) Parse(parent ast.Node, block text.Reader, pc parser.C if err != nil { return nil } - tag := astDiscordTag{id: id, portal: portal} + tag := astDiscordTag{id: id, source: source, portal: portal} tagName := string(match[1]) switch { case tagName == "@": @@ -261,50 +279,100 @@ func (r *discordTagHTMLRenderer) renderDiscordMention(w util.BufWriter, source [ if !entering { return } + + log := zerolog.DefaultContextLogger.With().Str("action", "render discord mention").Logger() + ctx := log.WithContext(context.TODO()) + switch node := n.(type) { case *astDiscordUserMention: var mxid id.UserID var name string - if puppet := node.portal.bridge.GetPuppetByID(strconv.FormatInt(node.id, 10)); puppet != nil { - mxid = puppet.MXID - name = puppet.Name + discordUserID := strconv.FormatInt(node.id, 10) + bridge := node.portal.Bridge + + if ghost, _ := bridge.GetGhostByID(ctx, discordid.MakeUserID(discordUserID)); ghost != nil { + // TODO: Provide some kind of config option for this in the future. + // msgconv being in its own package means we can't just reach into + // the config. For now, avoid. + // + // if ghost.Name == "" { + // ghost.UpdateInfoIfNecessary(ctx, node.source, bridgev2.RemoteEventUnknown) + // } + mxid = ghost.Intent.GetMXID() + name = ghost.Name } - if user := node.portal.bridge.GetUserByID(strconv.FormatInt(node.id, 10)); user != nil { - mxid = user.MXID + if discordUserID == discordid.ParseUserLoginID(node.source.ID) { + // Mentioning ourselves. + mxid = node.source.UserMXID + } else if ul := node.portal.Bridge.GetCachedUserLoginByID(discordid.MakeUserLoginID(discordUserID)); ul != nil { + // If the Discord user mentioned corresponds to someone else logged + // into the bridge, prefer their "real" MXID instead of the + // ghost's. + mxid = ul.UserMXID + } + + if mxid != "" { if name == "" { - name = user.MXID.Localpart() + name = fmt.Sprintf("@%d", node.id) } + _, _ = fmt.Fprintf(w, `%s`, mxid.URI().MatrixToURL(), html.EscapeString(name)) + } else { + _, _ = fmt.Fprintf(w, "<@%d>", node.id) } - _, _ = fmt.Fprintf(w, `%s`, mxid.URI().MatrixToURL(), name) return case *astDiscordRoleMention: - role := node.portal.bridge.DB.Role.GetByID(node.portal.GuildID, strconv.FormatInt(node.id, 10)) - if role != nil { - _, _ = fmt.Fprintf(w, `@%s`, role.Color, role.Name) - return + meta, _ := node.portal.Metadata.(*discordid.PortalMetadata) + if meta != nil && meta.GuildID != "" { + if provider, ok := node.portal.Bridge.Network.(roleInfoProvider); ok { + role, roleErr := provider.GetRoleByID(ctx, meta.GuildID, strconv.FormatInt(node.id, 10)) + if roleErr != nil { + node.portal.Log.Warn(). + Err(roleErr). + Str("guild_id", meta.GuildID). + Int64("role_id", node.id). + Msg("Failed to resolve role while rendering mention") + } else if role != nil { + _, _ = fmt.Fprintf(w, `@%s`, role.Color, html.EscapeString(role.Name)) + return + } + } } case *astDiscordChannelMention: - portal := node.portal.bridge.GetExistingPortalByID(database.PortalKey{ - ChannelID: strconv.FormatInt(node.id, 10), - Receiver: "", - }) - if portal != nil { - if portal.MXID != "" { - _, _ = fmt.Fprintf(w, `%s`, portal.MXID.URI(portal.bridge.AS.HomeserverDomain).MatrixToURL(), portal.Name) - } else { - _, _ = w.WriteString(portal.Name) + rtr, ok := node.source.Client.(router.Router) + + if ok { + var r *router.Route + mentionedChannelID := strconv.FormatInt(node.id, 10) + r, err = rtr.Route(ctx, mentionedChannelID) + + if err == nil && !r.Uncertain { + if portal, _ := node.portal.Bridge.GetExistingPortalByKey(ctx, r.PortalKey); portal != nil { + if portal.MXID != "" { + _, _ = fmt.Fprintf(w, `%s`, portal.MXID.URI(portal.Bridge.Matrix.ServerName()).MatrixToURL(), html.EscapeString(portal.Name)) + } else { + _, _ = w.WriteString(portal.Name) + } + return + } + } else if err != nil { + node.portal.Log.Err(err).Msg("Failed to route mentioned channel") } - return } case *astDiscordCustomEmoji: - reactionMXC := node.portal.getEmojiMXCByDiscordID(strconv.FormatInt(node.id, 10), node.name, node.animated) - if !reactionMXC.IsEmpty() { - attrs := "data-mx-emoticon" - if node.animated { - attrs += " data-mau-animated-emoji" + if resolver, ok := node.portal.Bridge.Network.(customEmojiMXCProvider); ok { + reactionMXC, resolveErr := resolver.GetCustomEmojiMXC(ctx, strconv.FormatInt(node.id, 10), node.name, node.animated) + + if resolveErr != nil { + node.portal.Log.Warn().Err(resolveErr).Int64("emoji_id", node.id).Msg("Failed to resolve custom emoji while rendering message") + } else if reactionMXC != "" { + attrs := "data-mx-emoticon" + if node.animated { + attrs += " data-mau-animated-emoji" + } + + _, _ = fmt.Fprintf(w, `%[2]s`, string(reactionMXC), node.name, attrs) + return } - _, _ = fmt.Fprintf(w, `%[2]s`, reactionMXC.String(), node.name, attrs) - return } case *astDiscordTimestamp: ts := time.Unix(node.timestamp, 0).UTC() diff --git a/pkg/msgconv/from-discord.go b/pkg/msgconv/from-discord.go new file mode 100644 index 0000000..51b03e1 --- /dev/null +++ b/pkg/msgconv/from-discord.go @@ -0,0 +1,826 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package msgconv + +import ( + "context" + "fmt" + "html" + "net/url" + "path" + "strconv" + "strings" + "time" + + "github.com/bwmarrin/discordgo" + "github.com/rs/zerolog" + "go.mau.fi/util/exmaps" + "go.mau.fi/util/ptr" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" + + "go.mau.fi/mautrix-discord/pkg/discordid" + "go.mau.fi/mautrix-discord/pkg/router" +) + +type contextKey int + +const ( + contextKeyPortal contextKey = iota + contextKeyIntent + contextKeyUserLogin + contextKeyDiscordClient +) + +// ToMatrix bridges a Discord message to Matrix. +// +// This method expects ghost information to be up-to-date. +func (mc *MessageConverter) ToMatrix( + ctx context.Context, + portal *bridgev2.Portal, + intent bridgev2.MatrixAPI, + source *bridgev2.UserLogin, + session *discordgo.Session, + msg *discordgo.Message, + knownThreadRoot *networkid.MessageID, +) *bridgev2.ConvertedMessage { + ctx = context.WithValue(ctx, contextKeyUserLogin, source) + ctx = context.WithValue(ctx, contextKeyIntent, intent) + ctx = context.WithValue(ctx, contextKeyPortal, portal) + ctx = context.WithValue(ctx, contextKeyDiscordClient, session) + predictedLength := len(msg.Attachments) + len(msg.StickerItems) + if msg.Content != "" { + predictedLength++ + } + parts := make([]*bridgev2.ConvertedMessagePart, 0, predictedLength) + if textPart := mc.renderDiscordTextMessage(ctx, intent, portal, msg, source); textPart != nil { + parts = append(parts, textPart) + } + + ctx = zerolog.Ctx(ctx).With(). + Str("action", "convert discord message to matrix"). + Str("message_id", msg.ID). + Logger().WithContext(ctx) + log := zerolog.Ctx(ctx) + handledIDs := make(exmaps.Set[string]) + + for _, att := range msg.Attachments { + if !handledIDs.Add(att.ID) { + continue + } + + log := log.With().Str("attachment_id", att.ID).Logger() + mediaInfo := discordid.NewMediaInfoV1(source.ID, msg.ChannelID, msg.ID, att.ID) + if part := mc.renderDiscordAttachment(log.WithContext(ctx), att, &mediaInfo); part != nil { + parts = append(parts, part) + } + } + + for _, sticker := range msg.StickerItems { + if !handledIDs.Add(sticker.ID) { + continue + } + + log := log.With().Str("sticker_id", sticker.ID).Logger() + if part := mc.renderDiscordSticker(log.WithContext(ctx), sticker); part != nil { + parts = append(parts, part) + } + } + + for i, embed := range msg.Embeds { + // Ignore non-video embeds, they're handled in convertDiscordTextMessage + if getEmbedType(msg, embed) != EmbedVideo { + continue + } + // Discord deduplicates embeds by URL. It makes things easier for us too. + if !handledIDs.Add(embed.URL) { + continue + } + + log := log.With(). + Str("computed_embed_type", "video"). + Str("embed_type", string(embed.Type)). + Int("embed_index", i). + Logger() + part := mc.renderDiscordVideoEmbed(log.WithContext(ctx), embed) + if part != nil { + parts = append(parts, part) + } + } + + if len(parts) == 0 && msg.Thread != nil { + parts = append(parts, &bridgev2.ConvertedMessagePart{Type: event.EventMessage, Content: &event.MessageEventContent{ + MsgType: event.MsgText, + Body: fmt.Sprintf("Created a thread: %s", msg.Thread.Name), + }}) + } + + // TODO(skip): Add extra metadata. + // for _, part := range parts { + // puppet.addWebhookMeta(part, msg) + // puppet.addMemberMeta(part, msg) + // } + + sender := discordid.MakeUserID(msg.Author.ID) + var pmp event.BeeperPerMessageProfile + ghost, err := portal.Bridge.GetGhostByID(ctx, sender) + if err != nil { + log.Err(err).Msg("Failed to get ghost for per-message profile") + } else { + pmp.ID = string(ghost.Intent.GetMXID()) + pmp.Displayname = ghost.Name + if ghost.AvatarMXC != "" { + pmp.AvatarURL = &ghost.AvatarMXC + } + } + + // Assign incrementing part IDs. + for i, part := range parts { + part.ID = networkid.PartID(strconv.Itoa(i)) + + // Beeper clients support backfilling backwards (scrolling up to load + // more messages). Adding per-message profiles to every part helps them + // present the right message authorship information even when a + // membership event isn't present. + part.Content.BeeperPerMessageProfile = &pmp + } + + converted := &bridgev2.ConvertedMessage{Parts: parts} + if knownThreadRoot != nil { + threadRoot := *knownThreadRoot + converted.ThreadRoot = &threadRoot + } + + // TODO This is sorta gross; it might be worth bundling these parameters + // into a struct. + mc.addReplyToConvertedMessage( + ctx, + converted, + source, + msg, + ) + + return converted +} + +const forwardTemplateHTML = `
    +

    ↷ Forwarded

    +%s +

    %s

    +
    ` + +const msgInteractionTemplateHTML = `
    +%s used /%s +
    ` + +const msgComponentTemplateHTML = `

    This message contains interactive elements. Use the Discord app to interact with the message.

    ` + +func (mc *MessageConverter) addReplyToConvertedMessage( + ctx context.Context, + converted *bridgev2.ConvertedMessage, + source *bridgev2.UserLogin, + msg *discordgo.Message, +) { + ref := msg.MessageReference + if ref == nil || ref.Type != discordgo.MessageReferenceTypeDefault { + return + } + + log := zerolog.Ctx(ctx).With(). + Str("referenced_channel_id", ref.ChannelID). + Str("referenced_guild_id", ref.GuildID). + Str("referenced_message_id", ref.MessageID).Logger() + ctx = log.WithContext(ctx) + + targetMessageID := discordid.MakeMessageID(ref.MessageID) + converted.ReplyTo = &networkid.MessageOptionalPartID{ + MessageID: targetMessageID, + // This needs to point to a valid part. Since we assign part ids + // counting upwards from zero, default to it. + PartID: ptr.Ptr(networkid.PartID("0")), + } + if msg.ReferencedMessage != nil { + // ReferencedMessage will be nil if Discord's backend didn't feel like + // fetching the message or if the message has been deleted. + converted.ReplyToUser = discordid.MakeUserID(msg.ReferencedMessage.Author.ID) + } + + // Try to provide a more correct ReplyTo.PartID if the message is already + // in the database. This won't be the case for e.g. initial backfill. + targetMatrixMsg, err := mc.Bridge.DB.Message.GetFirstPartByID(ctx, source.ID, targetMessageID) + if err != nil { + log.Warn().Err(err).Msg("Failed to query database for first message part; proceeding") + return + } + if targetMatrixMsg == nil { + log.Debug().Msg("Couldn't find a first message part for reply target; proceeding") + return + } + converted.ReplyTo.PartID = &targetMatrixMsg.PartID +} + +func (mc *MessageConverter) renderDiscordTextMessage(ctx context.Context, intent bridgev2.MatrixAPI, portal *bridgev2.Portal, msg *discordgo.Message, source *bridgev2.UserLogin) *bridgev2.ConvertedMessagePart { + log := zerolog.Ctx(ctx) + switch msg.Type { + case discordgo.MessageTypeCall: + return &bridgev2.ConvertedMessagePart{Type: event.EventMessage, Content: &event.MessageEventContent{ + MsgType: event.MsgEmote, + // TODO: Use ghost name instead? + Body: fmt.Sprintf("(%s started a call. Use the Discord app to answer.)", msg.Author.String()), + }} + case discordgo.MessageTypeGuildMemberJoin: + // This is only used for backfilled user join notices (e.g. "Good to + // see you, [username]."). Live user joins are handled as membership + // changes in the guild system channel and aren't even converted. + return &bridgev2.ConvertedMessagePart{Type: event.EventMessage, Content: &event.MessageEventContent{ + MsgType: event.MsgEmote, + Body: "(Joined the server.)", + }} + } + + var htmlParts []string + + if msg.Interaction != nil { + ghost, err := mc.Bridge.GetGhostByID(ctx, discordid.MakeUserID(msg.Interaction.User.ID)) + // TODO(skip): Try doing ghost.UpdateInfoIfNecessary. + if err == nil { + htmlParts = append(htmlParts, fmt.Sprintf(msgInteractionTemplateHTML, ghost.Intent.GetMXID(), ghost.Name, msg.Interaction.Name)) + } else { + log.Err(err).Msg("Couldn't get ghost by ID while bridging interaction") + } + } + + if msg.Content != "" && !isPlainGifMessage(msg) { + // Bridge basic text messages. + htmlParts = append(htmlParts, mc.renderDiscordMarkdownOnlyHTML(portal, source, msg.Content, true)) + } else if msg.MessageReference != nil && + msg.MessageReference.Type == discordgo.MessageReferenceTypeForward && + len(msg.MessageSnapshots) > 0 && + msg.MessageSnapshots[0].Message != nil { + // Bridge forwarded messages. + htmlParts = append(htmlParts, mc.forwardedMessageHTMLPart(ctx, portal, source, msg)) + } + + previews := make([]*event.BeeperLinkPreview, 0) + for i, embed := range msg.Embeds { + if i == 0 && msg.MessageReference == nil && isReplyEmbed(embed) { + continue + } + + with := log.With(). + Str("embed_type", string(embed.Type)). + Int("embed_index", i) + + switch getEmbedType(msg, embed) { + case EmbedRich: + log := with.Str("computed_embed_type", "rich").Logger() + htmlParts = append(htmlParts, mc.renderDiscordRichEmbed(log.WithContext(ctx), source, embed)) + case EmbedLinkPreview: + log := with.Str("computed_embed_type", "link preview").Logger() + previews = append(previews, mc.renderDiscordLinkEmbed(log.WithContext(ctx), embed)) + case EmbedVideo: + // Video embeds are handled as separate messages via renderDiscordVideoEmbed. + default: + log := with.Logger() + log.Warn().Msg("Unknown embed type in message") + } + } + + if len(msg.Components) > 0 { + htmlParts = append(htmlParts, msgComponentTemplateHTML) + } + + if len(htmlParts) == 0 { + return nil + } + + fullHTML := strings.Join(htmlParts, "\n") + if !msg.MentionEveryone { + fullHTML = strings.ReplaceAll(fullHTML, "@room", "@\u2063ro\u2063om") + } + + content := format.HTMLToContent(fullHTML) + extraContent := map[string]any{ + "com.beeper.linkpreviews": previews, + } + + return &bridgev2.ConvertedMessagePart{Type: event.EventMessage, Content: &content, Extra: extraContent} +} + +func (mc *MessageConverter) forwardedMessageOrigLink(ctx context.Context, source *bridgev2.UserLogin, msg *discordgo.Message, msgTSText string) (string, error) { + router, ok := source.Client.(router.Router) + if !ok { + return "", fmt.Errorf("network api can't route") // impossible? + } + route, err := router.Route(ctx, msg.MessageReference.ChannelID) + if err != nil { + return "", fmt.Errorf("couldn't route forwarded message: %w", err) + } + + forwardedFromPortal, err := mc.Bridge.DB.Portal.GetByKey(ctx, route.PortalKey) + if err != nil { + return "", fmt.Errorf("couldn't get containing portal for forwarded message: %w", err) + } + if forwardedFromPortal == nil { + return "", fmt.Errorf("containing portal for forwarded message doesn't exist") + } + + origMessage, err := mc.Bridge.DB.Message.GetFirstPartByID(ctx, source.ID, discordid.MakeMessageID(msg.MessageReference.MessageID)) + if err != nil { + return "", fmt.Errorf("couldn't get forwarded message from db: %w", err) + } + + if origMessage != nil { + // We've bridged the message that was forwarded, so we can link to it directly. + return fmt.Sprintf( + `#%s • %s`, + forwardedFromPortal.MXID.EventURI(origMessage.MXID, mc.Bridge.Matrix.ServerName()), + forwardedFromPortal.Name, + msgTSText, + ), nil + } + + if forwardedFromPortal.MXID != "" { + // We don't have the message but we have the portal (and it has a room), + // so link to that. + return fmt.Sprintf( + `#%s • %s`, + forwardedFromPortal.MXID.URI(mc.Bridge.Matrix.ServerName()), + forwardedFromPortal.Name, + msgTSText, + ), nil + } else if forwardedFromPortal.Name != "" { + // We only have the name of the portal. + return fmt.Sprintf("%s • %s", forwardedFromPortal.Name, msgTSText), nil + } + + // Give up if we don't have the message nor any portal information. + return "", fmt.Errorf("couldn't resolve forwarded message link") +} + +func (mc *MessageConverter) forwardedMessageHTMLPart(ctx context.Context, portal *bridgev2.Portal, source *bridgev2.UserLogin, msg *discordgo.Message) string { + log := zerolog.Ctx(ctx) + + forwardedHTML := mc.renderDiscordMarkdownOnlyHTMLNoUnwrap(portal, source, msg.MessageSnapshots[0].Message.Content, true) + msgTSText := msg.MessageSnapshots[0].Message.Timestamp.Format("2006-01-02 15:04 MST") + origLink, err := mc.forwardedMessageOrigLink(ctx, source, msg, msgTSText) + if err != nil { + log.Err(err).Msg("Failed to render original link to forwarded message, using generic placeholder") + origLink = fmt.Sprintf("unknown channel • %s", msgTSText) + } + + return fmt.Sprintf(forwardTemplateHTML, forwardedHTML, origLink) +} + +func mediaFailedMessage(err error) *event.MessageEventContent { + return &event.MessageEventContent{ + Body: fmt.Sprintf("Failed to bridge media: %v", err), + MsgType: event.MsgNotice, + } +} + +func (mc *MessageConverter) renderDiscordVideoEmbed(ctx context.Context, embed *discordgo.MessageEmbed) *bridgev2.ConvertedMessagePart { + var proxyURL string + if embed.Video != nil { + proxyURL = embed.Video.ProxyURL + } else if embed.Thumbnail != nil { + proxyURL = embed.Thumbnail.ProxyURL + } else { + zerolog.Ctx(ctx).Warn().Str("embed_url", embed.URL).Msg("No video or thumbnail proxy URL found in embed") + return &bridgev2.ConvertedMessagePart{ + Type: event.EventMessage, + Content: &event.MessageEventContent{ + Body: "Failed to bridge media: no video or thumbnail proxy URL found in embed", + MsgType: event.MsgNotice, + }, + } + } + + reupload, err := mc.ReuploadUnknownMedia(ctx, proxyURL, true) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to copy video embed to Matrix") + return &bridgev2.ConvertedMessagePart{ + Type: event.EventMessage, + Content: mediaFailedMessage(err), + } + } + + content := &event.MessageEventContent{ + Body: embed.URL, + URL: reupload.MXC, + File: reupload.File, + Info: &event.FileInfo{ + MimeType: reupload.MimeType, + Size: reupload.Size, + }, + } + + if embed.Video != nil { + content.MsgType = event.MsgVideo + content.Info.Width = embed.Video.Width + content.Info.Height = embed.Video.Height + } else { + content.MsgType = event.MsgImage + content.Info.Width = embed.Thumbnail.Width + content.Info.Height = embed.Thumbnail.Height + } + + extra := map[string]any{} + if content.MsgType == event.MsgVideo && embed.Type == discordgo.EmbedTypeGifv { + extra["info"] = map[string]any{ + "fi.mau.discord.gifv": true, + "fi.mau.gif": true, + "fi.mau.loop": true, + "fi.mau.autoplay": true, + "fi.mau.hide_controls": true, + "fi.mau.no_audio": true, + } + } + + return &bridgev2.ConvertedMessagePart{ + Type: event.EventMessage, + Content: content, + Extra: extra, + } +} + +func (mc *MessageConverter) renderDiscordSticker(ctx context.Context, sticker *discordgo.StickerItem) *bridgev2.ConvertedMessagePart { + var mime string + switch sticker.FormatType { + case discordgo.StickerFormatTypePNG: + mime = "image/png" + case discordgo.StickerFormatTypeAPNG: + mime = "image/apng" + case discordgo.StickerFormatTypeLottie: + mime = "video/lottie+json" + case discordgo.StickerFormatTypeGIF: + mime = "image/gif" + default: + zerolog.Ctx(ctx).Warn(). + Int("sticker_format", int(sticker.FormatType)). + Str("sticker_id", sticker.ID). + Msg("Unknown sticker format") + } + + // TODO(skip): Support direct media. + reupload, err := mc.ReuploadMedia(ctx, sticker.URL(), mime, "", -1, true) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to copy sticker to Matrix") + return &bridgev2.ConvertedMessagePart{ + Type: event.EventMessage, + Content: mediaFailedMessage(err), + } + } + + content := &event.MessageEventContent{ + Body: sticker.Name, // TODO(skip): Find description from somewhere? + Info: &event.FileInfo{ + MimeType: reupload.MimeType, + Size: reupload.Size, + }, + } + content.URL, content.File = reupload.MXC, reupload.File + cleanupConvertedStickerInfo(content) + + return &bridgev2.ConvertedMessagePart{ + Type: event.EventSticker, + Content: content, + } +} + +const DiscordStickerSize = 160 + +func cleanupConvertedStickerInfo(content *event.MessageEventContent) { + if content.Info == nil { + return + } + + if content.Info.Width == 0 && content.Info.Height == 0 { + content.Info.Width = DiscordStickerSize + content.Info.Height = DiscordStickerSize + } else if content.Info.Width > DiscordStickerSize || content.Info.Height > DiscordStickerSize { + if content.Info.Width > content.Info.Height { + content.Info.Height /= content.Info.Width / DiscordStickerSize + content.Info.Width = DiscordStickerSize + } else if content.Info.Width < content.Info.Height { + content.Info.Width /= content.Info.Height / DiscordStickerSize + content.Info.Height = DiscordStickerSize + } else { + content.Info.Width = DiscordStickerSize + content.Info.Height = DiscordStickerSize + } + } +} + +const ( + embedHTMLWrapper = `
    %s
    ` + embedHTMLWrapperColor = `
    %s
    ` + embedHTMLAuthorWithImage = `

     %s

    ` + embedHTMLAuthorPlain = `

    %s

    ` + embedHTMLAuthorLink = `%s` + embedHTMLTitleWithLink = `

    %s

    ` + embedHTMLTitlePlain = `

    %s

    ` + embedHTMLDescription = `

    %s

    ` + embedHTMLFieldName = `%s` + embedHTMLFieldValue = `%s` + embedHTMLFields = `%s%s
    ` + embedHTMLLinearField = `

    %s
    %s

    ` + embedHTMLImage = `

    ` + embedHTMLFooterWithImage = `` + embedHTMLFooterPlain = `` + embedHTMLFooterOnlyDate = `` + embedHTMLDate = `` + embedFooterDateSeparator = ` • ` +) + +func (mc *MessageConverter) renderDiscordRichEmbed(ctx context.Context, source *bridgev2.UserLogin, embed *discordgo.MessageEmbed) string { + log := zerolog.Ctx(ctx) + var htmlParts []string + if embed.Author != nil { + var authorHTML string + authorNameHTML := html.EscapeString(embed.Author.Name) + if embed.Author.URL != "" { + authorNameHTML = fmt.Sprintf(embedHTMLAuthorLink, embed.Author.URL, authorNameHTML) + } + authorHTML = fmt.Sprintf(embedHTMLAuthorPlain, authorNameHTML) + if embed.Author.ProxyIconURL != "" { + reupload, err := mc.ReuploadUnknownMedia(ctx, embed.Author.ProxyIconURL, false) + + if err != nil { + log.Warn().Err(err).Msg("Failed to reupload author icon in embed") + } else { + authorHTML = fmt.Sprintf(embedHTMLAuthorWithImage, reupload.MXC, authorNameHTML) + } + } + htmlParts = append(htmlParts, authorHTML) + } + + portal := ctx.Value(contextKeyPortal).(*bridgev2.Portal) + if embed.Title != "" { + var titleHTML string + baseTitleHTML := mc.renderDiscordMarkdownOnlyHTML(portal, source, embed.Title, false) + if embed.URL != "" { + titleHTML = fmt.Sprintf(embedHTMLTitleWithLink, html.EscapeString(embed.URL), baseTitleHTML) + } else { + titleHTML = fmt.Sprintf(embedHTMLTitlePlain, baseTitleHTML) + } + htmlParts = append(htmlParts, titleHTML) + } + + if embed.Description != "" { + htmlParts = append(htmlParts, fmt.Sprintf(embedHTMLDescription, mc.renderDiscordMarkdownOnlyHTML(portal, source, embed.Description, true))) + } + + for i := 0; i < len(embed.Fields); i++ { + item := embed.Fields[i] + // TODO(skip): Port EmbedFieldsAsTables. + if false { + splitItems := []*discordgo.MessageEmbedField{item} + if item.Inline && len(embed.Fields) > i+1 && embed.Fields[i+1].Inline { + splitItems = append(splitItems, embed.Fields[i+1]) + i++ + if len(embed.Fields) > i+1 && embed.Fields[i+1].Inline { + splitItems = append(splitItems, embed.Fields[i+1]) + i++ + } + } + headerParts := make([]string, len(splitItems)) + contentParts := make([]string, len(splitItems)) + for j, splitItem := range splitItems { + headerParts[j] = fmt.Sprintf(embedHTMLFieldName, mc.renderDiscordMarkdownOnlyHTML(portal, source, splitItem.Name, false)) + contentParts[j] = fmt.Sprintf(embedHTMLFieldValue, mc.renderDiscordMarkdownOnlyHTML(portal, source, splitItem.Value, true)) + } + htmlParts = append(htmlParts, fmt.Sprintf(embedHTMLFields, strings.Join(headerParts, ""), strings.Join(contentParts, ""))) + } else { + htmlParts = append(htmlParts, fmt.Sprintf(embedHTMLLinearField, + strconv.FormatBool(item.Inline), + mc.renderDiscordMarkdownOnlyHTML(portal, source, item.Name, false), + mc.renderDiscordMarkdownOnlyHTML(portal, source, item.Value, true), + )) + } + } + + if embed.Image != nil { + reupload, err := mc.ReuploadUnknownMedia(ctx, embed.Image.ProxyURL, false) + if err != nil { + log.Warn().Err(err).Msg("Failed to reupload image in embed") + } else { + htmlParts = append(htmlParts, fmt.Sprintf(embedHTMLImage, reupload.MXC)) + } + } + + var embedDateHTML string + if embed.Timestamp != "" { + formattedTime := embed.Timestamp + parsedTS, err := time.Parse(time.RFC3339, embed.Timestamp) + if err != nil { + log.Warn().Err(err).Msg("Failed to parse timestamp in embed") + } else { + formattedTime = parsedTS.Format(discordTimestampStyle('F').Format()) + } + embedDateHTML = fmt.Sprintf(embedHTMLDate, embed.Timestamp, formattedTime) + } + + if embed.Footer != nil { + var footerHTML string + var datePart string + if embedDateHTML != "" { + datePart = embedFooterDateSeparator + embedDateHTML + } + footerHTML = fmt.Sprintf(embedHTMLFooterPlain, html.EscapeString(embed.Footer.Text), datePart) + if embed.Footer.ProxyIconURL != "" { + reupload, err := mc.ReuploadUnknownMedia(ctx, embed.Footer.ProxyIconURL, false) + + if err != nil { + log.Warn().Err(err).Msg("Failed to reupload footer icon in embed") + } else { + footerHTML = fmt.Sprintf(embedHTMLFooterWithImage, reupload.MXC, html.EscapeString(embed.Footer.Text), datePart) + } + } + htmlParts = append(htmlParts, footerHTML) + } else if embed.Timestamp != "" { + htmlParts = append(htmlParts, fmt.Sprintf(embedHTMLFooterOnlyDate, embedDateHTML)) + } + + if len(htmlParts) == 0 { + return "" + } + + compiledHTML := strings.Join(htmlParts, "") + if embed.Color != 0 { + compiledHTML = fmt.Sprintf(embedHTMLWrapperColor, embed.Color, compiledHTML) + } else { + compiledHTML = fmt.Sprintf(embedHTMLWrapper, compiledHTML) + } + return compiledHTML +} + +func (mc *MessageConverter) renderDiscordLinkEmbedImage( + ctx context.Context, url string, width, height int, preview *event.BeeperLinkPreview, +) { + reupload, err := mc.ReuploadUnknownMedia(ctx, url, true) + if err != nil { + zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to reupload image in URL preview, ignoring") + return + } + + if width != 0 || height != 0 { + preview.ImageWidth = event.IntOrString(width) + preview.ImageHeight = event.IntOrString(height) + } + preview.ImageSize = event.IntOrString(reupload.Size) + preview.ImageType = reupload.MimeType + preview.ImageURL, preview.ImageEncryption = reupload.MXC, reupload.File +} + +func (mc *MessageConverter) renderDiscordLinkEmbed(ctx context.Context, embed *discordgo.MessageEmbed) *event.BeeperLinkPreview { + var preview event.BeeperLinkPreview + preview.MatchedURL = embed.URL + preview.Title = embed.Title + preview.Description = embed.Description + if embed.Image != nil { + mc.renderDiscordLinkEmbedImage(ctx, embed.Image.ProxyURL, embed.Image.Width, embed.Image.Height, &preview) + } else if embed.Thumbnail != nil { + mc.renderDiscordLinkEmbedImage(ctx, embed.Thumbnail.ProxyURL, embed.Thumbnail.Width, embed.Thumbnail.Height, &preview) + } + return &preview +} + +func attachmentFileName(att *discordgo.MessageAttachment) (string, error) { + fileName := att.Filename + if fileName == "" { + parsedURL, err := url.Parse(att.URL) + if err != nil { + return "", fmt.Errorf("couldn't parse URL to detect attachment file name: %w", err) + } + fileName = path.Base(parsedURL.Path) + } + return fileName, nil +} + +func conversionFailedPart(err error) *bridgev2.ConvertedMessagePart { + return &bridgev2.ConvertedMessagePart{ + Type: event.EventMessage, + Content: mediaFailedMessage(err), + } +} + +func (mc *MessageConverter) renderDiscordAttachment( + ctx context.Context, + att *discordgo.MessageAttachment, + mediaInfo *discordid.MediaInfo, +) *bridgev2.ConvertedMessagePart { + log := zerolog.Ctx(ctx) + + fileName, err := attachmentFileName(att) + if err != nil { + return conversionFailedPart(err) + } + + // (The rest of this function can adjust these fields.) + content := &event.MessageEventContent{ + Body: fileName, + Info: &event.FileInfo{ + Width: att.Width, + Height: att.Height, + MimeType: att.ContentType, + Size: att.Size, + }, + } + + if mc.DirectMedia { + if mc.CacheDirectMediaAttachment != nil { + mc.CacheDirectMediaAttachment(mediaInfo, att.URL) + } + + mediaID, err := mediaInfo.Encode() + if err != nil { + log.Err(err).Msg("Failed to encode direct media ID") + return conversionFailedPart(err) + } + mxc, err := mc.Bridge.Matrix.GenerateContentURI(ctx, mediaID) + if err != nil { + log.Err(err).Msg("Failed to generate content URI") + return conversionFailedPart(err) + } + log.Trace().Str("direct_media_mxc", string(mxc)).Msg("Generated direct media MXC") + content.URL = mxc + } else { + reupload, err := mc.ReuploadMedia(ctx, att.URL, att.ContentType, att.Filename, att.Size, true) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to copy attachment to Matrix") + return &bridgev2.ConvertedMessagePart{ + Type: event.EventMessage, + Content: mediaFailedMessage(err), + } + } + content.Info.MimeType = reupload.MimeType + content.Info.Size = reupload.Size + content.URL = reupload.MXC + content.File = reupload.File + } + + var extra = make(map[string]any) + + if strings.HasPrefix(att.Filename, "SPOILER_") { + extra["page.codeberg.everypizza.msc4193.spoiler"] = true + } + + if att.Description != "" { + content.Body = att.Description + content.FileName = fileName + } + + switch strings.ToLower(strings.Split(content.Info.MimeType, "/")[0]) { + case "audio": + content.MsgType = event.MsgAudio + if att.Waveform != nil { + // Bridge a voice message. + + // TODO convert waveform + extra["org.matrix.msc1767.audio"] = map[string]any{ + "duration": int(att.DurationSeconds * 1000), + } + extra["org.matrix.msc3245.voice"] = map[string]any{} + } + case "image": + content.MsgType = event.MsgImage + case "video": + content.MsgType = event.MsgVideo + default: + content.MsgType = event.MsgFile + } + + if content.Info.Width == 0 && content.Info.Height == 0 { + content.Info.Width = att.Width + content.Info.Height = att.Height + } + + part := &bridgev2.ConvertedMessagePart{ + // TODO: Do this eventually. Edits and replies currently make certain assumptions + // about how part IDs are formed to make this safe. + // + // ID: discordid.MakePartID(att.ID), + Type: event.EventMessage, + Content: content, + Extra: extra, + } + + return part +} diff --git a/pkg/msgconv/from-matrix.go b/pkg/msgconv/from-matrix.go new file mode 100644 index 0000000..79b868b --- /dev/null +++ b/pkg/msgconv/from-matrix.go @@ -0,0 +1,222 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package msgconv + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + + "github.com/bwmarrin/discordgo" + "github.com/rs/zerolog" + "go.mau.fi/util/variationselector" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" + + "go.mau.fi/mautrix-discord/pkg/discordid" +) + +func parseAllowedLinkPreviews(raw map[string]any) []string { + if raw == nil { + return nil + } + linkPreviews, ok := raw["com.beeper.linkpreviews"].([]any) + if !ok { + return nil + } + allowedLinkPreviews := make([]string, 0, len(linkPreviews)) + for _, preview := range linkPreviews { + previewMap, ok := preview.(map[string]any) + if !ok { + continue + } + matchedURL, _ := previewMap["matched_url"].(string) + if matchedURL != "" { + allowedLinkPreviews = append(allowedLinkPreviews, matchedURL) + } + } + return allowedLinkPreviews +} + +func uploadDiscordAttachment(cli *http.Client, url string, data []byte, contentType string) error { + req, err := http.NewRequest(http.MethodPut, url, bytes.NewReader(data)) + if err != nil { + return err + } + + for key, value := range discordgo.DroidBaseHeaders { + req.Header.Set(key, value) + } + if contentType == "" { + contentType = "application/octet-stream" + } + req.Header.Set("Content-Type", contentType) + req.Header.Set("Referer", "https://discord.com/") + req.Header.Set("Sec-Fetch-Dest", "empty") + req.Header.Set("Sec-Fetch-Mode", "cors") + req.Header.Set("Sec-Fetch-Site", "cross-site") + + resp, err := cli.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode > 300 { + respData, _ := io.ReadAll(resp.Body) + return fmt.Errorf("unexpected status %d: %s", resp.StatusCode, respData) + } + return nil +} + +// ToDiscord converts a Matrix message into a discordgo.MessageSend that is appropriate +// for bridging the message to Discord. +func (mc *MessageConverter) ToDiscord( + ctx context.Context, + session *discordgo.Session, + msg *bridgev2.MatrixMessage, + channelID string, + refererOpt discordgo.RequestOption, +) (*discordgo.MessageSend, error) { + ctx = context.WithValue(ctx, contextKeyPortal, msg.Portal) + ctx = context.WithValue(ctx, contextKeyDiscordClient, session) + var req discordgo.MessageSend + if msg.InputTransactionID != "" { + req.Nonce = string(msg.InputTransactionID) + } else { + req.Nonce = discordid.GenerateNonce() + } + log := zerolog.Ctx(ctx) + + if msg.ReplyTo != nil { + req.Reference = &discordgo.MessageReference{ + ChannelID: discordid.ParseChannelPortalID(msg.ReplyTo.Room.ID), + MessageID: discordid.ParseMessageID(msg.ReplyTo.ID), + } + } + + content := msg.Content + + convertMatrix := func() { + // TODO: Handle (silent) replies. + // + // NOTE: Real users should never send allowed_mentions (except for + // silent replies). + // + // Since we only support real users at the moment, always ignore the + // returned allowed mentions. + req.Content, _ = mc.ConvertMatrixMessageContent(ctx, msg.Portal, content, parseAllowedLinkPreviews(msg.Event.Content.Raw)) + if content.MsgType == event.MsgEmote { + req.Content = fmt.Sprintf("_%s_", req.Content) + } + } + + switch content.MsgType { + case event.MsgText, event.MsgEmote, event.MsgNotice: + convertMatrix() + case event.MsgAudio, event.MsgFile, event.MsgImage, event.MsgVideo: + mediaData, err := mc.Bridge.Bot.DownloadMedia(ctx, content.URL, content.File) + if err != nil { + log.Err(err).Msg("Failed to download Matrix attachment for bridging") + return nil, bridgev2.ErrMediaDownloadFailed + } + + filename := content.Body + hasCaption := content.FileName != "" && content.FileName != content.Body + if content.FileName != "" { + filename = content.FileName + } + isSpoiler := msg.Event.Content.Raw["page.codeberg.everypizza.msc4193.spoiler"] == true + + var voiceMeta *discordVoiceMetadata + if content.MsgType == event.MsgAudio && !hasCaption && !isSpoiler { + voiceMeta = getDiscordVoiceMetadata(content) + } + + if hasCaption { + convertMatrix() + } + if isSpoiler { + filename = "SPOILER_" + filename + } + + // TODO: Support attachments for relay/webhook. (A branch was removed here.) + att := &discordgo.MessageAttachment{ + ID: "0", + Filename: filename, + } + if voiceMeta != nil { + flags := int(discordgo.MessageFlagsIsVoiceMessage) + req.Flags = &flags + att.ContentType = voiceMeta.ContentType + att.DurationSeconds = voiceMeta.DurationSeconds + att.Waveform = voiceMeta.Waveform + } + + uploadID := mc.NextDiscordUploadID() + log.Debug().Str("upload_id", uploadID).Msg("Preparing attachment") + filePrep := &discordgo.FilePrepare{ + Size: len(mediaData), + Name: att.Filename, + ID: uploadID, + } + prep, err := session.ChannelAttachmentCreate(channelID, &discordgo.ReqPrepareAttachments{ + Files: []*discordgo.FilePrepare{filePrep}, + }, refererOpt) + + if err != nil { + log.Err(err).Msg("Failed to create attachment in preparation for attachment reupload") + return nil, bridgev2.ErrMediaReuploadFailed + } + + prepared := prep.Attachments[0] + att.UploadedFilename = prepared.UploadFilename + + err = uploadDiscordAttachment(session.Client, prepared.UploadURL, mediaData, att.OriginalContentType) + if err != nil { + log.Err(err).Msg("Failed to reupload Discord attachment after preparing") + return nil, bridgev2.ErrMediaReuploadFailed + } + + req.Attachments = append(req.Attachments, att) + } + + return &req, nil +} + +func (mc *MessageConverter) ConvertMatrixMessageContent(ctx context.Context, portal *bridgev2.Portal, content *event.MessageEventContent, allowedLinkPreviews []string) (string, *discordgo.MessageAllowedMentions) { + allowedMentions := &discordgo.MessageAllowedMentions{ + Parse: []discordgo.AllowedMentionType{}, + Users: []string{}, + RepliedUser: true, + } + + if content.Format == event.FormatHTML && len(content.FormattedBody) > 0 { + ctx := format.NewContext(ctx) + ctx.ReturnData[formatterContextInputAllowedLinkPreviewsKey] = allowedLinkPreviews + ctx.ReturnData[formatterContextPortalKey] = portal + ctx.ReturnData[formatterContextAllowedMentionsKey] = allowedMentions + if content.Mentions != nil { + ctx.ReturnData[formatterContextInputAllowedMentionsKey] = content.Mentions.UserIDs + } + return variationselector.FullyQualify(mc.HTMLParser.Parse(content.FormattedBody, ctx)), allowedMentions + } else { + return variationselector.FullyQualify(escapeDiscordMarkdown(content.Body)), allowedMentions + } +} diff --git a/pkg/msgconv/msgconv.go b/pkg/msgconv/msgconv.go new file mode 100644 index 0000000..4bf2d05 --- /dev/null +++ b/pkg/msgconv/msgconv.go @@ -0,0 +1,169 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package msgconv + +import ( + "context" + "fmt" + "math/rand" + "slices" + "strconv" + "sync/atomic" + + "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/format" + "maunium.net/go/mautrix/id" + + "go.mau.fi/mautrix-discord/pkg/discordid" +) + +type MessageConverter struct { + Bridge *bridgev2.Bridge + DirectMedia bool + + CacheDirectMediaAttachment func(info *discordid.MediaInfo, discordURL string) + + HTMLParser *format.HTMLParser + + nextDiscordUploadID atomic.Int32 + + MaxFileSize int64 +} + +func NewMessageConverter(bridge *bridgev2.Bridge) *MessageConverter { + mc := &MessageConverter{ + Bridge: bridge, + MaxFileSize: 50 * 1024 * 1024, + } + mc.HTMLParser = &format.HTMLParser{ + TabsToSpaces: 4, + Newline: "\n", + HorizontalLine: "\n---\n", + PillConverter: mc.convertPill, + ItalicConverter: func(s string, ctx format.Context) string { + return fmt.Sprintf("*%s*", s) + }, + UnderlineConverter: func(s string, ctx format.Context) string { + return fmt.Sprintf("__%s__", s) + }, + TextConverter: func(s string, ctx format.Context) string { + if ctx.TagStack.Has("pre") || ctx.TagStack.Has("code") { + // If we're in a code block, don't escape markdown + return s + } + return escapeDiscordMarkdown(s) + }, + SpoilerConverter: func(text, reason string, ctx format.Context) string { + if reason != "" { + return fmt.Sprintf("(%s) ||%s||", reason, text) + } + return fmt.Sprintf("||%s||", text) + }, + LinkConverter: func(text, href string, ctx format.Context) string { + linkPreviews := ctx.ReturnData[formatterContextInputAllowedLinkPreviewsKey].([]string) + allowPreview := linkPreviews == nil || slices.Contains(linkPreviews, href) + if text == href { + if !allowPreview { + return fmt.Sprintf("<%s>", text) + } + return text + } else if !discordLinkRegexFull.MatchString(href) { + return fmt.Sprintf("%s (%s)", escapeDiscordMarkdown(text), escapeDiscordMarkdown(href)) + } else if !allowPreview { + return fmt.Sprintf("[%s](<%s>)", escapeDiscordMarkdown(text), href) + } else { + return fmt.Sprintf("[%s](%s)", escapeDiscordMarkdown(text), href) + } + }, + } + + mc.nextDiscordUploadID.Store(rand.Int31n(100)) + + return mc +} + +// resolveMentionedDiscordUserID tries to translate a mentioned user MXID to a +// Discord user ID, regardless of whether a ghost (remote Discord user) or +// Matrix user was mentioned. +func (mc *MessageConverter) resolveMentionedDiscordUserID( + ctx context.Context, + portal *bridgev2.Portal, + mxid id.UserID, +) (string, error) { + if ghostID, ok := mc.Bridge.Matrix.ParseGhostMXID(mxid); ok { + // A ghost was mentioned, so we can extract the Discord user ID + // directly from the MXID. + return discordid.ParseUserID(ghostID), nil + } + // The rest of this method is handling for when a "real" Matrix user was + // mentioned. This can be the user themselves or someone else in the portal + // (when split rooms are not in play). + + user, err := mc.Bridge.GetExistingUserByMXID(ctx, mxid) + if err != nil { + return "", err + } else if user == nil { + return "", nil + } + + login, _, err := portal.FindPreferredLogin(ctx, user, false) + if err != nil { + return "", err + } else if login == nil { + return "", nil + } + + return discordid.ParseUserLoginID(login.ID), nil +} + +func (mc *MessageConverter) convertPill(displayname, mxid, eventID string, ctx format.Context) string { + if len(mxid) == 0 || mxid[0] != '@' { + // Behave like mautrix-whatsapp. + return format.DefaultPillConverter(displayname, mxid, eventID, ctx) + } + + allowedMentions, _ := ctx.ReturnData[formatterContextInputAllowedMentionsKey].([]id.UserID) + portal := ctx.ReturnData[formatterContextPortalKey].(*bridgev2.Portal) + + mentionedUserID := id.UserID(mxid) + log := zerolog.Ctx(ctx.Ctx).With(). + Str("mentioned_mxid", mxid). + Str("mentioned_displayname", displayname). + Str("event_id", eventID). + Logger() + + if !slices.Contains(allowedMentions, mentionedUserID) { + return displayname + } + + mentionedDiscordUserID, err := mc.resolveMentionedDiscordUserID(ctx.Ctx, portal, mentionedUserID) + if err != nil { + log.Err(err).Msg("Failed to resolve the corresponding Discord user ID for the mentioned user, falling back to display name") + return displayname + } else if mentionedDiscordUserID == "" { + log.Error().Msg("Failed to find a corresponding Discord user ID for the mentioned user, falling back to display name") + return displayname + } + + return fmt.Sprintf("<@%s>", mentionedDiscordUserID) +} + +func (mc *MessageConverter) NextDiscordUploadID() string { + val := mc.nextDiscordUploadID.Add(2) + return strconv.Itoa(int(val)) +} diff --git a/pkg/msgconv/voicemsg.go b/pkg/msgconv/voicemsg.go new file mode 100644 index 0000000..c5f7db6 --- /dev/null +++ b/pkg/msgconv/voicemsg.go @@ -0,0 +1,124 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package msgconv + +import ( + "strings" + + "maunium.net/go/mautrix/event" +) + +type discordVoiceMetadata struct { + ContentType string + DurationSeconds float64 + Waveform []byte +} + +// NOTE these waveform calculations are pure conjecture; i.e. they aren't +// modeled after what first-party clients actually do + +func waveformBuckets(durationMs int) int { + targetLength := (durationMs + 99) / 100 // like math.Ceil(durationMs / 100) + targetLength = max(min(targetLength, 256), 1) // clamp to [1,256] + return targetLength +} + +func downsampleWaveform(waveform []int, buckets int) []int { + if len(waveform) <= buckets { + return waveform + } + + samples := make([]int, 0, buckets) + for i := range buckets { + start := i * len(waveform) / buckets + end := (i + 1) * len(waveform) / buckets + if end <= start { + end = start + 1 + } + + maxVal := waveform[start] + for _, sample := range waveform[start+1 : end] { + if sample > maxVal { + maxVal = sample + } + } + samples = append(samples, maxVal) + } + + return samples +} + +func matrixWaveformToDiscord(samples []int, durationMs int) []byte { + if len(samples) == 0 { + return nil + } + + samples = downsampleWaveform(samples, waveformBuckets(durationMs)) + + maxVal := 0 + for _, sample := range samples { + if sample > maxVal { + maxVal = sample + } + } + + clampedSamples := make([]byte, len(samples)) + for i, sample := range samples { + if maxVal > 256 { + sample /= 4 + } + if sample < 0 { + sample = 0 + } + if sample > 255 { + sample = 255 + } + clampedSamples[i] = byte(sample) + } + + return clampedSamples +} + +func getDiscordVoiceMetadata(content *event.MessageEventContent) *discordVoiceMetadata { + if content.MSC3245Voice == nil || content.MSC1767Audio == nil { + return nil + } + + mimeType := strings.TrimSpace(content.Info.MimeType) + if !strings.HasPrefix(strings.ToLower(mimeType), "audio/") { + return nil + } + + durationMs := content.Info.Duration + if durationMs == 0 { + durationMs = content.MSC1767Audio.Duration + } + if durationMs <= 0 { + return nil + } + + waveform := matrixWaveformToDiscord(content.MSC1767Audio.Waveform, durationMs) + if len(waveform) == 0 { + return nil + } + + return &discordVoiceMetadata{ + ContentType: mimeType, + DurationSeconds: float64(durationMs) / 1000, + Waveform: waveform, + } +} diff --git a/remoteauth/README.md b/pkg/remoteauth/README.md similarity index 100% rename from remoteauth/README.md rename to pkg/remoteauth/README.md diff --git a/remoteauth/client.go b/pkg/remoteauth/client.go similarity index 100% rename from remoteauth/client.go rename to pkg/remoteauth/client.go diff --git a/remoteauth/clientpackets.go b/pkg/remoteauth/clientpackets.go similarity index 100% rename from remoteauth/clientpackets.go rename to pkg/remoteauth/clientpackets.go diff --git a/remoteauth/serverpackets.go b/pkg/remoteauth/serverpackets.go similarity index 98% rename from remoteauth/serverpackets.go rename to pkg/remoteauth/serverpackets.go index b7376d3..5e44037 100644 --- a/remoteauth/serverpackets.go +++ b/pkg/remoteauth/serverpackets.go @@ -103,6 +103,7 @@ func (h *serverHello) process(client *Client) error { ticker := time.NewTicker(time.Duration(h.HeartbeatInterval) * time.Millisecond) go func() { defer ticker.Stop() + //lint:ignore S1000 - for { select { // case <-client.ctx.Done(): @@ -126,7 +127,7 @@ func (h *serverHello) process(client *Client) error { <-time.After(duration) client.Lock() - client.err = fmt.Errorf("Timed out after %s", duration) + client.err = fmt.Errorf("timed out after %s", duration) client.close() client.Unlock() }() diff --git a/remoteauth/user.go b/pkg/remoteauth/user.go similarity index 100% rename from remoteauth/user.go rename to pkg/remoteauth/user.go diff --git a/pkg/router/router.go b/pkg/router/router.go new file mode 100644 index 0000000..ed44988 --- /dev/null +++ b/pkg/router/router.go @@ -0,0 +1,79 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2026 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package router + +import ( + "context" + + "github.com/bwmarrin/discordgo" + "go.mau.fi/util/ptr" + "maunium.net/go/mautrix/bridgev2/networkid" + + "go.mau.fi/mautrix-discord/pkg/connector/discorddb" + "go.mau.fi/mautrix-discord/pkg/discordid" +) + +type Router interface { + // Route embodies the core logic that determines where a Discord event from a + // certain channel is ultimately bridged to on Matrix. + // + // This is significant for threads; routing a thread channel ID currently + // redirects to the portal corresponding to the parent channel on Discord, + // because we bridge threads via m.thread. That means the Matrix messages need + // to go in the parent channel. The routing logic performs this resolution. + // + // Another way to think about this mechanism is that it hides the concern of + // constructing the correct PortalKey in response to something that happened + // in a Discord channel. It handles details like threading for you. + Route(ctx context.Context, channelID string) (*Route, error) +} + +// How and where a Discord event should be bridged to Matrix. +type Route struct { + // The key of the portal that the event should be "routed" to. This is + // almost like a destination. + PortalKey networkid.PortalKey + + // The corresponding Discord channel ID of the portal that PortalKey points + // to. + PortalChannelID string + + // Whether or not we're certain about the receiver of the PortalKey. In + // practice, this will only be true if we can't find the channel in state. + // If this is true, then FromChannel and FromThread will always be nil. + Uncertain bool + + // The Discord channel that the event originated from. This can be nil + // despite the channel actually existing if it wasn't found in state. + FromChannel *discordgo.Channel + + // Non-nil if Channel is a thread and the thread was found in state. + FromThread *discorddb.Thread +} + +func (r *Route) FromThreadRootMessageID() *networkid.MessageID { + if r.FromThread == nil { + return nil + } + + rootMsgID := r.FromThread.RootMessageID + if rootMsgID == "" { + return nil + } + + return ptr.Ptr(discordid.MakeMessageID(rootMsgID)) +} diff --git a/portal.go b/portal.go deleted file mode 100644 index 3a5db83..0000000 --- a/portal.go +++ /dev/null @@ -1,2631 +0,0 @@ -package main - -import ( - "bytes" - "context" - "crypto/hmac" - "crypto/sha256" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "reflect" - "regexp" - "slices" - "strconv" - "strings" - "sync" - "syscall" - "time" - - "github.com/bwmarrin/discordgo" - "github.com/gabriel-vasile/mimetype" - "github.com/gorilla/mux" - "github.com/rs/zerolog" - "go.mau.fi/util/exsync" - "go.mau.fi/util/variationselector" - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridge" - "maunium.net/go/mautrix/bridge/bridgeconfig" - "maunium.net/go/mautrix/bridge/status" - "maunium.net/go/mautrix/crypto/attachment" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" - - "go.mau.fi/mautrix-discord/config" - "go.mau.fi/mautrix-discord/database" -) - -type portalDiscordMessage struct { - msg interface{} - user *User - - thread *Thread -} - -type portalMatrixMessage struct { - evt *event.Event - user *User -} - -var relayClient, _ = discordgo.New("") - -type Portal struct { - *database.Portal - - Parent *Portal - Guild *Guild - - bridge *DiscordBridge - log zerolog.Logger - - roomCreateLock sync.Mutex - encryptLock sync.Mutex - - discordMessages chan portalDiscordMessage - matrixMessages chan portalMatrixMessage - - recentMessages *exsync.RingBuffer[string, *discordgo.Message] - - commands map[string]*discordgo.ApplicationCommand - commandsLock sync.RWMutex - - forwardBackfillLock sync.Mutex - - currentlyTyping []id.UserID - currentlyTypingLock sync.Mutex -} - -const recentMessageBufferSize = 32 - -var _ bridge.Portal = (*Portal)(nil) -var _ bridge.ReadReceiptHandlingPortal = (*Portal)(nil) -var _ bridge.MembershipHandlingPortal = (*Portal)(nil) -var _ bridge.TypingPortal = (*Portal)(nil) - -//var _ bridge.MetaHandlingPortal = (*Portal)(nil) -//var _ bridge.DisappearingPortal = (*Portal)(nil) - -func (portal *Portal) IsEncrypted() bool { - return portal.Encrypted -} - -func (portal *Portal) MarkEncrypted() { - portal.Encrypted = true - portal.Update() -} - -func (portal *Portal) ReceiveMatrixEvent(user bridge.User, evt *event.Event) { - if user.GetPermissionLevel() >= bridgeconfig.PermissionLevelUser || portal.RelayWebhookID != "" { - portal.matrixMessages <- portalMatrixMessage{user: user.(*User), evt: evt} - } -} - -var ( - portalCreationDummyEvent = event.Type{Type: "fi.mau.dummy.portal_created", Class: event.MessageEventType} -) - -func (br *DiscordBridge) loadPortal(dbPortal *database.Portal, key *database.PortalKey, chanType discordgo.ChannelType) *Portal { - if dbPortal == nil { - if key == nil || chanType < 0 { - return nil - } - - dbPortal = br.DB.Portal.New() - dbPortal.Key = *key - dbPortal.Type = chanType - dbPortal.Insert() - } - - portal := br.NewPortal(dbPortal) - - br.portalsByID[portal.Key] = portal - if portal.MXID != "" { - br.portalsByMXID[portal.MXID] = portal - } - - if portal.GuildID != "" { - portal.Guild = portal.bridge.GetGuildByID(portal.GuildID, true) - } - if portal.ParentID != "" { - parentKey := database.NewPortalKey(portal.ParentID, "") - var ok bool - portal.Parent, ok = br.portalsByID[parentKey] - if !ok { - portal.Parent = br.loadPortal(br.DB.Portal.GetByID(parentKey), nil, -1) - } - } - - return portal -} - -func (br *DiscordBridge) GetPortalByMXID(mxid id.RoomID) *Portal { - br.portalsLock.Lock() - defer br.portalsLock.Unlock() - - portal, ok := br.portalsByMXID[mxid] - if !ok { - return br.loadPortal(br.DB.Portal.GetByMXID(mxid), nil, -1) - } - - return portal -} - -func (user *User) GetPortalByMeta(meta *discordgo.Channel) *Portal { - return user.GetPortalByID(meta.ID, meta.Type) -} - -func (user *User) GetExistingPortalByID(id string) *Portal { - return user.bridge.GetExistingPortalByID(database.NewPortalKey(id, user.DiscordID)) -} - -func (user *User) GetPortalByID(id string, chanType discordgo.ChannelType) *Portal { - return user.bridge.GetPortalByID(database.NewPortalKey(id, user.DiscordID), chanType) -} - -func (user *User) FindPrivateChatWith(userID string) *Portal { - user.bridge.portalsLock.Lock() - defer user.bridge.portalsLock.Unlock() - dbPortal := user.bridge.DB.Portal.FindPrivateChatBetween(userID, user.DiscordID) - if dbPortal == nil { - return nil - } - existing, ok := user.bridge.portalsByID[dbPortal.Key] - if ok { - return existing - } - return user.bridge.loadPortal(dbPortal, nil, discordgo.ChannelTypeDM) -} - -func (br *DiscordBridge) GetExistingPortalByID(key database.PortalKey) *Portal { - br.portalsLock.Lock() - defer br.portalsLock.Unlock() - portal, ok := br.portalsByID[key] - if !ok { - if key.Receiver != "" { - portal, ok = br.portalsByID[database.NewPortalKey(key.ChannelID, "")] - } - if !ok { - return br.loadPortal(br.DB.Portal.GetByID(key), nil, -1) - } - } - - return portal -} - -func (br *DiscordBridge) GetPortalByID(key database.PortalKey, chanType discordgo.ChannelType) *Portal { - br.portalsLock.Lock() - defer br.portalsLock.Unlock() - if chanType != discordgo.ChannelTypeDM { - key.Receiver = "" - } - - portal, ok := br.portalsByID[key] - if !ok { - return br.loadPortal(br.DB.Portal.GetByID(key), &key, chanType) - } - - return portal -} - -func (br *DiscordBridge) GetAllPortals() []*Portal { - return br.dbPortalsToPortals(br.DB.Portal.GetAll()) -} - -func (br *DiscordBridge) GetAllPortalsInGuild(guildID string) []*Portal { - return br.dbPortalsToPortals(br.DB.Portal.GetAllInGuild(guildID)) -} - -func (br *DiscordBridge) GetAllIPortals() (iportals []bridge.Portal) { - portals := br.GetAllPortals() - iportals = make([]bridge.Portal, len(portals)) - for i, portal := range portals { - iportals[i] = portal - } - return iportals -} - -func (br *DiscordBridge) GetDMPortalsWith(otherUserID string) []*Portal { - return br.dbPortalsToPortals(br.DB.Portal.FindPrivateChatsWith(otherUserID)) -} - -func (br *DiscordBridge) dbPortalsToPortals(dbPortals []*database.Portal) []*Portal { - br.portalsLock.Lock() - defer br.portalsLock.Unlock() - - output := make([]*Portal, len(dbPortals)) - for index, dbPortal := range dbPortals { - if dbPortal == nil { - continue - } - - portal, ok := br.portalsByID[dbPortal.Key] - if !ok { - portal = br.loadPortal(dbPortal, nil, -1) - } - - output[index] = portal - } - - return output -} - -func (br *DiscordBridge) NewPortal(dbPortal *database.Portal) *Portal { - portal := &Portal{ - Portal: dbPortal, - bridge: br, - log: br.ZLog.With(). - Str("channel_id", dbPortal.Key.ChannelID). - Str("channel_receiver", dbPortal.Key.Receiver). - Str("room_id", dbPortal.MXID.String()). - Logger(), - - discordMessages: make(chan portalDiscordMessage, br.Config.Bridge.PortalMessageBuffer), - matrixMessages: make(chan portalMatrixMessage, br.Config.Bridge.PortalMessageBuffer), - - recentMessages: exsync.NewRingBuffer[string, *discordgo.Message](recentMessageBufferSize), - - commands: make(map[string]*discordgo.ApplicationCommand), - } - - go portal.messageLoop() - - return portal -} - -func (portal *Portal) messageLoop() { - for { - select { - case msg := <-portal.matrixMessages: - portal.handleMatrixMessages(msg) - case msg := <-portal.discordMessages: - portal.handleDiscordMessages(msg) - } - } -} - -func (portal *Portal) IsPrivateChat() bool { - return portal.Type == discordgo.ChannelTypeDM -} - -func (portal *Portal) MainIntent() *appservice.IntentAPI { - if portal.IsPrivateChat() && portal.OtherUserID != "" { - return portal.bridge.GetPuppetByID(portal.OtherUserID).DefaultIntent() - } - - return portal.bridge.Bot -} - -type CustomBridgeInfoContent struct { - event.BridgeEventContent - RoomType string `json:"com.beeper.room_type,omitempty"` - RoomTypeV2 string `json:"com.beeper.room_type.v2,omitempty"` -} - -func init() { - event.TypeMap[event.StateBridge] = reflect.TypeOf(CustomBridgeInfoContent{}) - event.TypeMap[event.StateHalfShotBridge] = reflect.TypeOf(CustomBridgeInfoContent{}) -} - -func (portal *Portal) getBridgeInfo() (string, CustomBridgeInfoContent) { - bridgeInfo := event.BridgeEventContent{ - BridgeBot: portal.bridge.Bot.UserID, - Creator: portal.MainIntent().UserID, - Protocol: event.BridgeInfoSection{ - ID: "discordgo", - DisplayName: "Discord", - AvatarURL: portal.bridge.Config.AppService.Bot.ParsedAvatar.CUString(), - ExternalURL: "https://discord.com/", - }, - Channel: event.BridgeInfoSection{ - ID: portal.Key.ChannelID, - DisplayName: portal.Name, - }, - } - var bridgeInfoStateKey string - if portal.GuildID == "" { - bridgeInfoStateKey = fmt.Sprintf("fi.mau.discord://discord/dm/%s", portal.Key.ChannelID) - bridgeInfo.Channel.ExternalURL = fmt.Sprintf("https://discord.com/channels/@me/%s", portal.Key.ChannelID) - } else { - bridgeInfo.Network = &event.BridgeInfoSection{ - ID: portal.GuildID, - } - if portal.Guild != nil { - bridgeInfo.Network.DisplayName = portal.Guild.Name - bridgeInfo.Network.AvatarURL = portal.Guild.AvatarURL.CUString() - // TODO is it possible to find the URL? - } - bridgeInfoStateKey = fmt.Sprintf("fi.mau.discord://discord/%s/%s", portal.GuildID, portal.Key.ChannelID) - bridgeInfo.Channel.ExternalURL = fmt.Sprintf("https://discord.com/channels/%s/%s", portal.GuildID, portal.Key.ChannelID) - } - var roomType string - if portal.Type == discordgo.ChannelTypeDM || portal.Type == discordgo.ChannelTypeGroupDM { - roomType = "dm" - } - var roomTypeV2 string - if portal.Type == discordgo.ChannelTypeDM { - roomTypeV2 = "dm" - } else if portal.Type == discordgo.ChannelTypeGroupDM { - roomTypeV2 = "group_dm" - } - - return bridgeInfoStateKey, CustomBridgeInfoContent{bridgeInfo, roomType, roomTypeV2} -} - -func (portal *Portal) UpdateBridgeInfo() { - if len(portal.MXID) == 0 { - portal.log.Debug().Msg("Not updating bridge info: no Matrix room created") - return - } - portal.log.Debug().Msg("Updating bridge info...") - stateKey, content := portal.getBridgeInfo() - _, err := portal.MainIntent().SendStateEvent(portal.MXID, event.StateBridge, stateKey, content) - if err != nil { - portal.log.Warn().Err(err).Msg("Failed to update m.bridge") - } - // TODO remove this once https://github.com/matrix-org/matrix-doc/pull/2346 is in spec - _, err = portal.MainIntent().SendStateEvent(portal.MXID, event.StateHalfShotBridge, stateKey, content) - if err != nil { - portal.log.Warn().Err(err).Msg("Failed to update uk.half-shot.bridge") - } -} - -func (portal *Portal) shouldSetDMRoomMetadata() bool { - return !portal.IsPrivateChat() || - portal.bridge.Config.Bridge.PrivateChatPortalMeta == "always" || - (portal.IsEncrypted() && portal.bridge.Config.Bridge.PrivateChatPortalMeta != "never") -} - -func (portal *Portal) GetEncryptionEventContent() (evt *event.EncryptionEventContent) { - evt = &event.EncryptionEventContent{Algorithm: id.AlgorithmMegolmV1} - if rot := portal.bridge.Config.Bridge.Encryption.Rotation; rot.EnableCustom { - evt.RotationPeriodMillis = rot.Milliseconds - evt.RotationPeriodMessages = rot.Messages - } - return -} - -func (portal *Portal) CreateMatrixRoom(user *User, channel *discordgo.Channel) error { - portal.roomCreateLock.Lock() - defer portal.roomCreateLock.Unlock() - if portal.MXID != "" { - portal.ensureUserInvited(user, false) - return nil - } - portal.log.Info().Msg("Creating Matrix room for channel") - - channel = portal.UpdateInfo(user, channel) - if channel == nil { - return fmt.Errorf("didn't find channel metadata") - } - - intent := portal.MainIntent() - if err := intent.EnsureRegistered(); err != nil { - return err - } - - bridgeInfoStateKey, bridgeInfo := portal.getBridgeInfo() - initialState := []*event.Event{{ - Type: event.StateBridge, - Content: event.Content{Parsed: bridgeInfo}, - StateKey: &bridgeInfoStateKey, - }, { - // TODO remove this once https://github.com/matrix-org/matrix-doc/pull/2346 is in spec - Type: event.StateHalfShotBridge, - Content: event.Content{Parsed: bridgeInfo}, - StateKey: &bridgeInfoStateKey, - }} - - var invite []id.UserID - - if portal.bridge.Config.Bridge.Encryption.Default { - initialState = append(initialState, &event.Event{ - Type: event.StateEncryption, - Content: event.Content{ - Parsed: portal.GetEncryptionEventContent(), - }, - }) - portal.Encrypted = true - - if portal.IsPrivateChat() { - invite = append(invite, portal.bridge.Bot.UserID) - } - } - - if !portal.AvatarURL.IsEmpty() && portal.shouldSetDMRoomMetadata() { - initialState = append(initialState, &event.Event{ - Type: event.StateRoomAvatar, - Content: event.Content{Parsed: &event.RoomAvatarEventContent{ - URL: portal.AvatarURL, - }}, - }) - portal.AvatarSet = true - } else { - portal.AvatarSet = false - } - - creationContent := make(map[string]interface{}) - if portal.Type == discordgo.ChannelTypeGuildCategory { - creationContent["type"] = event.RoomTypeSpace - } - if !portal.bridge.Config.Bridge.FederateRooms { - creationContent["m.federate"] = false - } - spaceID := portal.ExpectedSpaceID() - if spaceID != "" { - spaceIDStr := spaceID.String() - initialState = append(initialState, &event.Event{ - Type: event.StateSpaceParent, - StateKey: &spaceIDStr, - Content: event.Content{Parsed: &event.SpaceParentEventContent{ - Via: []string{portal.bridge.AS.HomeserverDomain}, - Canonical: true, - }}, - }) - } - if portal.bridge.Config.Bridge.RestrictedRooms && portal.Guild != nil && portal.Guild.MXID != "" { - // TODO don't do this for private channels in guilds - initialState = append(initialState, &event.Event{ - Type: event.StateJoinRules, - Content: event.Content{Parsed: &event.JoinRulesEventContent{ - JoinRule: event.JoinRuleRestricted, - Allow: []event.JoinRuleAllow{{ - RoomID: portal.Guild.MXID, - Type: event.JoinRuleAllowRoomMembership, - }}, - }}, - }) - } - - req := &mautrix.ReqCreateRoom{ - Visibility: "private", - Name: portal.Name, - Topic: portal.Topic, - Invite: invite, - Preset: "private_chat", - IsDirect: portal.IsPrivateChat(), - InitialState: initialState, - CreationContent: creationContent, - RoomVersion: "11", - } - if !portal.shouldSetDMRoomMetadata() && !portal.FriendNick { - req.Name = "" - } - - var backfillStarted bool - portal.forwardBackfillLock.Lock() - defer func() { - if !backfillStarted { - portal.log.Debug().Msg("Backfill wasn't started, unlocking forward backfill lock") - portal.forwardBackfillLock.Unlock() - } - }() - - resp, err := intent.CreateRoom(req) - if err != nil { - portal.log.Warn().Err(err).Msg("Failed to create room") - return err - } - - portal.NameSet = len(req.Name) > 0 - portal.TopicSet = len(req.Topic) > 0 - portal.MXID = resp.RoomID - portal.log = portal.bridge.ZLog.With(). - Str("channel_id", portal.Key.ChannelID). - Str("channel_receiver", portal.Key.Receiver). - Str("room_id", portal.MXID.String()). - Logger() - portal.bridge.portalsLock.Lock() - portal.bridge.portalsByMXID[portal.MXID] = portal - portal.bridge.portalsLock.Unlock() - portal.Update() - portal.log.Info().Msg("Matrix room created") - - if portal.Encrypted && portal.IsPrivateChat() { - err = portal.bridge.Bot.EnsureJoined(portal.MXID, appservice.EnsureJoinedParams{BotOverride: portal.MainIntent().Client}) - if err != nil { - portal.log.Err(err).Msg("Failed to ensure bridge bot is joined to encrypted private chat portal") - } - } - - if portal.GuildID == "" { - user.addPrivateChannelToSpace(portal) - } else { - portal.updateSpace(user) - } - portal.ensureUserInvited(user, true) - user.syncChatDoublePuppetDetails(portal, true) - - portal.syncParticipants(user, channel.Recipients) - - if portal.IsPrivateChat() { - puppet := user.bridge.GetPuppetByID(portal.Key.Receiver) - - chats := map[id.UserID][]id.RoomID{puppet.MXID: {portal.MXID}} - user.updateDirectChats(chats) - } - - firstEventResp, err := portal.MainIntent().SendMessageEvent(portal.MXID, portalCreationDummyEvent, struct{}{}) - if err != nil { - portal.log.Err(err).Msg("Failed to send dummy event to mark portal creation") - } else { - portal.FirstEventID = firstEventResp.EventID - portal.Update() - } - - go portal.forwardBackfillInitial(user, nil) - backfillStarted = true - - return nil -} - -func (portal *Portal) handleDiscordMessages(msg portalDiscordMessage) { - if portal.MXID == "" { - msgCreate, ok := msg.msg.(*discordgo.MessageCreate) - if !ok { - portal.log.Warn().Msg("Can't create Matrix room from non new message event") - return - } - - portal.log.Debug(). - Str("message_id", msgCreate.ID). - Msg("Creating Matrix room from incoming message") - if err := portal.CreateMatrixRoom(msg.user, nil); err != nil { - portal.log.Err(err).Msg("Failed to create portal room") - return - } - } - portal.forwardBackfillLock.Lock() - defer portal.forwardBackfillLock.Unlock() - - switch convertedMsg := msg.msg.(type) { - case *discordgo.MessageCreate: - portal.handleDiscordMessageCreate(msg.user, convertedMsg.Message, msg.thread) - case *discordgo.MessageUpdate: - portal.handleDiscordMessageUpdate(msg.user, convertedMsg.Message) - case *discordgo.MessageDelete: - portal.handleDiscordMessageDelete(msg.user, convertedMsg.Message) - case *discordgo.MessageDeleteBulk: - portal.handleDiscordMessageDeleteBulk(msg.user, convertedMsg.Messages) - case *discordgo.MessageReactionAdd: - portal.handleDiscordReaction(msg.user, convertedMsg.MessageReaction, true, msg.thread, convertedMsg.Member) - case *discordgo.MessageReactionRemove: - portal.handleDiscordReaction(msg.user, convertedMsg.MessageReaction, false, msg.thread, nil) - default: - portal.log.Warn().Type("message_type", msg.msg).Msg("Unknown message type in handleDiscordMessages") - } -} - -func (portal *Portal) ensureUserInvited(user *User, ignoreCache bool) bool { - return user.ensureInvited(portal.MainIntent(), portal.MXID, portal.IsPrivateChat(), ignoreCache) -} - -func (portal *Portal) markMessageHandled(discordID string, authorID string, timestamp time.Time, threadID string, senderMXID id.UserID, parts []database.MessagePart) *database.Message { - msg := portal.bridge.DB.Message.New() - msg.Channel = portal.Key - msg.DiscordID = discordID - msg.SenderID = authorID - msg.Timestamp = timestamp - msg.ThreadID = threadID - msg.SenderMXID = senderMXID - msg.MassInsertParts(parts) - msg.MXID = parts[0].MXID - msg.AttachmentID = parts[0].AttachmentID - return msg -} - -func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Message, thread *Thread) { - switch msg.Type { - case discordgo.MessageTypeChannelNameChange, discordgo.MessageTypeChannelIconChange, discordgo.MessageTypeChannelPinnedMessage: - // These are handled via channel updates - return - } - - log := portal.log.With(). - Str("message_id", msg.ID). - Int("message_type", int(msg.Type)). - Str("author_id", msg.Author.ID). - Str("action", "discord message create"). - Logger() - ctx := log.WithContext(context.Background()) - - portal.recentMessages.Push(msg.ID, msg) - - existing := portal.bridge.DB.Message.GetByDiscordID(portal.Key, msg.ID) - if existing != nil { - log.Debug().Msg("Dropping duplicate message") - return - } - - handlingStartTime := time.Now() - puppet := portal.bridge.GetPuppetByID(msg.Author.ID) - puppet.UpdateInfo(user, msg.Author, msg) - intent := puppet.IntentFor(portal) - - var discordThreadID string - var threadRootEvent, lastThreadEvent id.EventID - if thread != nil { - discordThreadID = thread.ID - threadRootEvent = thread.RootMXID - lastThreadEvent = threadRootEvent - lastInThread := portal.bridge.DB.Message.GetLastInThread(portal.Key, thread.ID) - if lastInThread != nil { - lastThreadEvent = lastInThread.MXID - } - } - replyTo := portal.getReplyTarget(user, discordThreadID, msg.MessageReference, msg.Embeds, false) - mentions := portal.convertDiscordMentions(msg, true) - - ts, _ := discordgo.SnowflakeTimestamp(msg.ID) - parts := portal.convertDiscordMessage(ctx, puppet, intent, msg) - dbParts := make([]database.MessagePart, 0, len(parts)) - eventIDs := zerolog.Dict() - for i, part := range parts { - if (replyTo != nil || threadRootEvent != "") && part.Content.RelatesTo == nil { - part.Content.RelatesTo = &event.RelatesTo{} - } - if threadRootEvent != "" { - part.Content.RelatesTo.SetThread(threadRootEvent, lastThreadEvent) - } - if replyTo != nil { - part.Content.RelatesTo.SetReplyTo(replyTo.EventID) - if replyTo.UnstableRoomID != "" { - part.Content.RelatesTo.InReplyTo.UnstableRoomID = replyTo.UnstableRoomID - } - // Only set reply for first event - replyTo = nil - } - - part.Content.Mentions = mentions - // Only set mentions for first event, but keep empty object for rest - mentions = &event.Mentions{} - - resp, err := portal.sendMatrixMessage(intent, part.Type, part.Content, part.Extra, ts.UnixMilli()) - if err != nil { - log.Err(err). - Int("part_index", i). - Str("attachment_id", part.AttachmentID). - Msg("Failed to send part of message to Matrix") - continue - } - lastThreadEvent = resp.EventID - dbParts = append(dbParts, database.MessagePart{AttachmentID: part.AttachmentID, MXID: resp.EventID}) - eventIDs.Str(part.AttachmentID, resp.EventID.String()) - } - - log = log.With().Dur("handling_time", time.Since(handlingStartTime)).Logger() - if len(parts) == 0 { - log.Warn().Msg("Unhandled message") - } else if len(dbParts) == 0 { - log.Warn().Msg("All parts of message failed to send to Matrix") - } else { - log.Debug().Dict("event_ids", eventIDs).Msg("Finished handling Discord message") - firstDBMessage := portal.markMessageHandled(msg.ID, msg.Author.ID, ts, discordThreadID, intent.UserID, dbParts) - if msg.Flags == discordgo.MessageFlagsHasThread { - portal.bridge.threadFound(ctx, user, firstDBMessage, msg.ID, msg.Thread) - } - } -} - -var hackyReplyPattern = regexp.MustCompile(`^\*\*\[Replying to]\(https://discord.com/channels/(\d+)/(\d+)/(\d+)\)`) - -func isReplyEmbed(embed *discordgo.MessageEmbed) bool { - return hackyReplyPattern.MatchString(embed.Description) -} - -func (portal *Portal) getReplyTarget(source *User, threadID string, ref *discordgo.MessageReference, embeds []*discordgo.MessageEmbed, allowNonExistent bool) *event.InReplyTo { - if ref == nil && len(embeds) > 0 { - match := hackyReplyPattern.FindStringSubmatch(embeds[0].Description) - if match != nil && match[1] == portal.GuildID && (match[2] == portal.Key.ChannelID || match[2] == threadID) { - ref = &discordgo.MessageReference{ - MessageID: match[3], - ChannelID: match[2], - GuildID: match[1], - } - } - } - if ref == nil { - return nil - } - // TODO add config option for cross-room replies - crossRoomReplies := portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry - - targetPortal := portal - if ref.ChannelID != portal.Key.ChannelID && ref.ChannelID != threadID && crossRoomReplies { - targetPortal = portal.bridge.GetExistingPortalByID(database.PortalKey{ChannelID: ref.ChannelID, Receiver: source.DiscordID}) - if targetPortal == nil { - return nil - } - } - replyToMsg := portal.bridge.DB.Message.GetByDiscordID(targetPortal.Key, ref.MessageID) - if len(replyToMsg) > 0 { - if !crossRoomReplies { - return &event.InReplyTo{EventID: replyToMsg[0].MXID} - } - return &event.InReplyTo{ - EventID: replyToMsg[0].MXID, - UnstableRoomID: targetPortal.MXID, - } - } else if allowNonExistent { - return &event.InReplyTo{ - EventID: targetPortal.deterministicEventID(ref.MessageID, ""), - UnstableRoomID: targetPortal.MXID, - } - } - return nil -} - -const JoinThreadReaction = "join thread" - -func (portal *Portal) sendThreadCreationNotice(ctx context.Context, thread *Thread) { - thread.creationNoticeLock.Lock() - defer thread.creationNoticeLock.Unlock() - if thread.CreationNoticeMXID != "" { - return - } - creationNotice := "Thread created. React to this message with \"join thread\" to join the thread on Discord." - if portal.bridge.Config.Bridge.AutojoinThreadOnOpen { - creationNotice = "Thread created. Opening this thread will auto-join you to it on Discord." - } - log := zerolog.Ctx(ctx) - resp, err := portal.sendMatrixMessage(portal.MainIntent(), event.EventMessage, &event.MessageEventContent{ - Body: creationNotice, - MsgType: event.MsgNotice, - RelatesTo: (&event.RelatesTo{}).SetThread(thread.RootMXID, thread.RootMXID), - }, nil, time.Now().UnixMilli()) - if err != nil { - log.Err(err).Msg("Failed to send thread creation notice") - return - } - portal.bridge.threadsLock.Lock() - thread.CreationNoticeMXID = resp.EventID - portal.bridge.threadsByCreationNoticeMXID[resp.EventID] = thread - portal.bridge.threadsLock.Unlock() - thread.Update() - log.Debug(). - Str("creation_notice_mxid", thread.CreationNoticeMXID.String()). - Msg("Sent thread creation notice") - - resp, err = portal.MainIntent().SendMessageEvent(portal.MXID, event.EventReaction, &event.ReactionEventContent{ - RelatesTo: event.RelatesTo{ - Type: event.RelAnnotation, - EventID: thread.CreationNoticeMXID, - Key: JoinThreadReaction, - }, - }) - if err != nil { - log.Err(err).Msg("Failed to send prefilled reaction to thread creation notice") - } else { - log.Debug(). - Str("reaction_event_id", resp.EventID.String()). - Str("creation_notice_mxid", thread.CreationNoticeMXID.String()). - Msg("Sent prefilled reaction to thread creation notice") - } -} - -func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Message) { - log := portal.log.With(). - Str("message_id", msg.ID). - Str("action", "discord message update"). - Logger() - ctx := log.WithContext(context.Background()) - if portal.MXID == "" { - log.Warn().Msg("handle message called without a valid portal") - return - } - - existing := portal.bridge.DB.Message.GetByDiscordID(portal.Key, msg.ID) - if existing == nil { - log.Warn().Msg("Dropping update of unknown message") - return - } - if msg.EditedTimestamp != nil && !msg.EditedTimestamp.After(existing[0].EditTimestamp) { - log.Debug(). - Time("received_edit_ts", *msg.EditedTimestamp). - Time("db_edit_ts", existing[0].EditTimestamp). - Msg("Dropping update of message with older or equal edit timestamp") - return - } - - if msg.Flags == discordgo.MessageFlagsHasThread { - portal.bridge.threadFound(ctx, user, existing[0], msg.ID, msg.Thread) - } - - if msg.Author == nil { - creationMessage, ok := portal.recentMessages.Get(msg.ID) - if !ok { - log.Debug().Msg("Dropping edit with no author of non-recent message") - return - } else if creationMessage.Type == discordgo.MessageTypeCall { - log.Debug().Msg("Dropping edit with of call message") - return - } - log.Debug().Msg("Found original message in cache for edit without author") - if len(msg.Embeds) > 0 { - creationMessage.Embeds = msg.Embeds - } - if len(msg.Attachments) > 0 { - creationMessage.Attachments = msg.Attachments - } - if len(msg.Components) > 0 { - creationMessage.Components = msg.Components - } - // TODO are there other fields that need copying? - msg = creationMessage - } else { - portal.recentMessages.Replace(msg.ID, msg) - } - if msg.Author.ID == portal.RelayWebhookID { - log.Debug(). - Str("message_id", msg.ID). - Str("author_id", msg.Author.ID). - Msg("Dropping edit from relay webhook") - return - } - - puppet := portal.bridge.GetPuppetByID(msg.Author.ID) - intent := puppet.IntentFor(portal) - - redactions := zerolog.Dict() - attachmentMap := map[string]*database.Message{} - for _, existingPart := range existing { - if existingPart.AttachmentID != "" { - attachmentMap[existingPart.AttachmentID] = existingPart - } - } - for _, remainingAttachment := range msg.Attachments { - if _, found := attachmentMap[remainingAttachment.ID]; found { - delete(attachmentMap, remainingAttachment.ID) - } - } - for _, remainingSticker := range msg.StickerItems { - if _, found := attachmentMap[remainingSticker.ID]; found { - delete(attachmentMap, remainingSticker.ID) - } - } - for _, remainingEmbed := range msg.Embeds { - // Other types of embeds are sent inline with the text message part - if getEmbedType(nil, remainingEmbed) != EmbedVideo { - continue - } - embedID := "video_" + remainingEmbed.URL - if _, found := attachmentMap[embedID]; found { - delete(attachmentMap, embedID) - } - } - for _, deletedAttachment := range attachmentMap { - resp, err := intent.RedactEvent(portal.MXID, deletedAttachment.MXID) - if err != nil { - log.Err(err). - Str("event_id", deletedAttachment.MXID.String()). - Msg("Failed to redact attachment") - } else { - redactions.Str(deletedAttachment.AttachmentID, resp.EventID.String()) - } - deletedAttachment.Delete() - } - - var converted *ConvertedMessage - // Slightly hacky special case: messages with gif links will get an embed with the gif. - // The link isn't rendered on Discord, so just edit the link message into a gif message on Matrix too. - if isPlainGifMessage(msg) { - converted = portal.convertDiscordVideoEmbed(ctx, intent, msg.Embeds[0]) - } else { - converted = portal.convertDiscordTextMessage(ctx, intent, msg) - } - if converted == nil { - log.Debug(). - Bool("has_message_on_matrix", existing[0].AttachmentID == ""). - Bool("has_text_on_discord", len(msg.Content) > 0). - Msg("Dropping non-text edit") - return - } - puppet.addWebhookMeta(converted, msg) - puppet.addMemberMeta(converted, msg) - converted.Content.Mentions = portal.convertDiscordMentions(msg, false) - converted.Content.SetEdit(existing[0].MXID) - // Never actually mention new users of edits, only include mentions inside m.new_content - converted.Content.Mentions = &event.Mentions{} - if converted.Extra != nil { - converted.Extra = map[string]any{ - "m.new_content": converted.Extra, - } - } - - var editTS int64 - if msg.EditedTimestamp != nil { - editTS = msg.EditedTimestamp.UnixMilli() - } - // TODO figure out some way to deduplicate outgoing edits - resp, err := portal.sendMatrixMessage(intent, event.EventMessage, converted.Content, converted.Extra, editTS) - if err != nil { - log.Err(err).Msg("Failed to send edit to Matrix") - return - } - - portal.sendDeliveryReceipt(resp.EventID) - - if msg.EditedTimestamp != nil { - existing[0].UpdateEditTimestamp(*msg.EditedTimestamp) - } - log.Debug(). - Str("event_id", resp.EventID.String()). - Dict("redacted_attachments", redactions). - Msg("Finished handling Discord edit") -} - -func (portal *Portal) handleDiscordMessageDelete(user *User, msg *discordgo.Message) { - lastResp := portal.redactAllParts(portal.MainIntent(), msg.ID) - if lastResp != "" { - portal.sendDeliveryReceipt(lastResp) - } -} - -func (portal *Portal) handleDiscordMessageDeleteBulk(user *User, messages []string) { - intent := portal.MainIntent() - var lastResp id.EventID - for _, msgID := range messages { - newLastResp := portal.redactAllParts(intent, msgID) - if newLastResp != "" { - lastResp = newLastResp - } - } - if lastResp != "" { - portal.sendDeliveryReceipt(lastResp) - } -} - -func (portal *Portal) redactAllParts(intent *appservice.IntentAPI, msgID string) (lastResp id.EventID) { - existing := portal.bridge.DB.Message.GetByDiscordID(portal.Key, msgID) - for _, dbMsg := range existing { - resp, err := intent.RedactEvent(portal.MXID, dbMsg.MXID) - if err != nil { - portal.log.Err(err). - Str("message_id", msgID). - Str("event_id", dbMsg.MXID.String()). - Msg("Failed to redact Matrix message") - } else if resp != nil && resp.EventID != "" { - lastResp = resp.EventID - } - dbMsg.Delete() - } - return -} - -func (portal *Portal) handleDiscordTyping(evt *discordgo.TypingStart) { - puppet := portal.bridge.GetPuppetByID(evt.UserID) - if puppet.Name == "" { - // Puppet hasn't been synced yet - return - } - log := portal.log.With(). - Str("ghost_mxid", puppet.MXID.String()). - Str("action", "discord typing"). - Logger() - intent := puppet.IntentFor(portal) - err := intent.EnsureJoined(portal.MXID) - if err != nil { - log.Warn().Err(err).Msg("Failed to ensure ghost is joined for typing notification") - return - } - _, err = intent.UserTyping(portal.MXID, true, 12*time.Second) - if err != nil { - log.Warn().Err(err).Msg("Failed to send typing notification to Matrix") - } -} - -func (portal *Portal) syncParticipant(source *User, participant *discordgo.User, remove bool) { - puppet := portal.bridge.GetPuppetByID(participant.ID) - puppet.UpdateInfo(source, participant, nil) - log := portal.log.With(). - Str("participant_id", participant.ID). - Str("ghost_mxid", puppet.MXID.String()). - Logger() - - user := portal.bridge.GetUserByID(participant.ID) - if user != nil { - log.Debug().Msg("Ensuring Matrix user is invited or joined to room") - portal.ensureUserInvited(user, false) - } - - if remove { - _, err := puppet.DefaultIntent().LeaveRoom(portal.MXID) - if err != nil { - log.Warn().Err(err).Msg("Failed to make ghost leave room after member remove event") - } - } else if user == nil || !puppet.IntentFor(portal).IsCustomPuppet { - if err := puppet.IntentFor(portal).EnsureJoined(portal.MXID); err != nil { - log.Warn().Err(err).Msg("Failed to add ghost to room") - } - } -} - -func (portal *Portal) syncParticipants(source *User, participants []*discordgo.User) { - for _, participant := range participants { - puppet := portal.bridge.GetPuppetByID(participant.ID) - puppet.UpdateInfo(source, participant, nil) - - var user *User - if participant.ID != portal.OtherUserID { - user = portal.bridge.GetUserByID(participant.ID) - if user != nil { - portal.ensureUserInvited(user, false) - } - } - - if user == nil || !puppet.IntentFor(portal).IsCustomPuppet { - if err := puppet.IntentFor(portal).EnsureJoined(portal.MXID); err != nil { - portal.log.Warn().Err(err). - Str("participant_id", participant.ID). - Msg("Failed to add ghost to room") - } - } - } -} - -func (portal *Portal) encrypt(intent *appservice.IntentAPI, content *event.Content, eventType event.Type) (event.Type, error) { - if !portal.Encrypted || portal.bridge.Crypto == nil { - return eventType, nil - } - intent.AddDoublePuppetValue(content) - // TODO maybe the locking should be inside mautrix-go? - portal.encryptLock.Lock() - err := portal.bridge.Crypto.Encrypt(portal.MXID, eventType, content) - portal.encryptLock.Unlock() - if err != nil { - return eventType, fmt.Errorf("failed to encrypt event: %w", err) - } - return event.EventEncrypted, nil -} - -func (portal *Portal) sendMatrixMessage(intent *appservice.IntentAPI, eventType event.Type, content *event.MessageEventContent, extraContent map[string]interface{}, timestamp int64) (*mautrix.RespSendEvent, error) { - wrappedContent := event.Content{Parsed: content, Raw: extraContent} - var err error - eventType, err = portal.encrypt(intent, &wrappedContent, eventType) - if err != nil { - return nil, err - } - - _, _ = intent.UserTyping(portal.MXID, false, 0) - if timestamp == 0 { - return intent.SendMessageEvent(portal.MXID, eventType, &wrappedContent) - } else { - return intent.SendMassagedMessageEvent(portal.MXID, eventType, &wrappedContent, timestamp) - } -} - -func (portal *Portal) handleMatrixMessages(msg portalMatrixMessage) { - portal.forwardBackfillLock.Lock() - defer portal.forwardBackfillLock.Unlock() - switch msg.evt.Type { - case event.EventMessage, event.EventSticker: - portal.handleMatrixMessage(msg.user, msg.evt) - case event.EventRedaction: - portal.handleMatrixRedaction(msg.user, msg.evt) - case event.EventReaction: - portal.handleMatrixReaction(msg.user, msg.evt) - default: - portal.log.Warn().Str("event_type", msg.evt.Type.Type).Msg("Unknown event type in handleMatrixMessages") - } -} - -const discordEpoch = 1420070400000 - -func generateNonce() string { - snowflake := (time.Now().UnixMilli() - discordEpoch) << 22 - // Nonce snowflakes don't have internal IDs or increments - return strconv.FormatInt(snowflake, 10) -} - -func (portal *Portal) getEvent(mxid id.EventID) (*event.Event, error) { - evt, err := portal.MainIntent().GetEvent(portal.MXID, mxid) - if err != nil { - return nil, err - } - _ = evt.Content.ParseRaw(evt.Type) - if evt.Type == event.EventEncrypted { - decryptedEvt, err := portal.bridge.Crypto.Decrypt(evt) - if err != nil { - return nil, fmt.Errorf("failed to decrypt event: %w", err) - } else { - evt = decryptedEvt - } - } - return evt, nil -} - -func genThreadName(evt *event.Event) string { - body := evt.Content.AsMessage().Body - if len(body) == 0 { - return "thread" - } - fields := strings.Fields(body) - var title string - for _, field := range fields { - if len(title)+len(field) < 40 { - title += field - title += " " - continue - } - if len(title) == 0 { - title = field[:40] - } - break - } - return title -} - -func (portal *Portal) startThreadFromMatrix(sender *User, threadRoot id.EventID) (string, error) { - rootEvt, err := portal.getEvent(threadRoot) - if err != nil { - return "", fmt.Errorf("failed to get root event: %w", err) - } - threadName := genThreadName(rootEvt) - - existingMsg := portal.bridge.DB.Message.GetByMXID(portal.Key, threadRoot) - if existingMsg == nil { - return "", fmt.Errorf("unknown root event") - } else if existingMsg.ThreadID != "" { - return "", fmt.Errorf("root event is already in a thread") - } else { - var ch *discordgo.Channel - ch, err = sender.Session.MessageThreadStartComplex(portal.Key.ChannelID, existingMsg.DiscordID, &discordgo.ThreadStart{ - Name: threadName, - AutoArchiveDuration: 24 * 60, - Type: discordgo.ChannelTypeGuildPublicThread, - Location: "Message", - }, portal.RefererOptIfUser(sender.Session, "")...) - if err != nil { - return "", fmt.Errorf("error starting thread: %v", err) - } - portal.log.Debug(). - Str("thread_root_mxid", threadRoot.String()). - Str("thread_id", ch.ID). - Msg("Created Discord thread") - portal.bridge.GetThreadByID(existingMsg.DiscordID, existingMsg) - return ch.ID, nil - } -} - -func (portal *Portal) sendErrorMessage(evt *event.Event, msgType, message string, confirmed bool) id.EventID { - if !portal.bridge.Config.Bridge.MessageErrorNotices { - return "" - } - certainty := "may not have been" - if confirmed { - certainty = "was not" - } - if portal.RelayWebhookSecret != "" { - message = strings.ReplaceAll(message, portal.RelayWebhookSecret, "") - } - content := &event.MessageEventContent{ - MsgType: event.MsgNotice, - Body: fmt.Sprintf("\u26a0 Your %s %s bridged: %v", msgType, certainty, message), - } - relatable, ok := evt.Content.Parsed.(event.Relatable) - if ok && relatable.OptionalGetRelatesTo().GetThreadParent() != "" { - content.GetRelatesTo().SetThread(relatable.OptionalGetRelatesTo().GetThreadParent(), evt.ID) - } - resp, err := portal.sendMatrixMessage(portal.MainIntent(), event.EventMessage, content, nil, 0) - if err != nil { - portal.log.Warn().Err(err).Msg("Failed to send bridging error message") - return "" - } - return resp.EventID -} - -var ( - errUnknownMsgType = errors.New("unknown msgtype") - errUnexpectedParsedContentType = errors.New("unexpected parsed content type") - errUserNotReceiver = errors.New("user is not portal receiver") - errUserNotLoggedIn = errors.New("user is not logged in and portal doesn't have webhook") - errUnknownEditTarget = errors.New("unknown edit target") - errUnknownRelationType = errors.New("unknown relation type") - errTargetNotFound = errors.New("target event not found") - errUnknownEmoji = errors.New("unknown emoji") - errCantStartThread = errors.New("can't create thread without being logged into Discord") -) - -func errorToStatusReason(err error) (reason event.MessageStatusReason, status event.MessageStatus, isCertain, sendNotice bool, humanMessage string, checkpointError error) { - var restErr *discordgo.RESTError - switch { - case errors.Is(err, errUnknownMsgType), - errors.Is(err, errUnknownRelationType), - errors.Is(err, errUnexpectedParsedContentType), - errors.Is(err, errUnknownEmoji), - errors.Is(err, id.InvalidContentURI), - errors.Is(err, attachment.UnsupportedVersion), - errors.Is(err, attachment.UnsupportedAlgorithm), - errors.Is(err, errCantStartThread): - return event.MessageStatusUnsupported, event.MessageStatusFail, true, true, "", nil - case errors.Is(err, attachment.HashMismatch), - errors.Is(err, attachment.InvalidKey), - errors.Is(err, attachment.InvalidInitVector): - return event.MessageStatusUndecryptable, event.MessageStatusFail, true, true, "", nil - case errors.Is(err, errUserNotReceiver), errors.Is(err, errUserNotLoggedIn): - return event.MessageStatusNoPermission, event.MessageStatusFail, true, false, "", nil - case errors.Is(err, errUnknownEditTarget): - return event.MessageStatusGenericError, event.MessageStatusFail, true, false, "", nil - case errors.Is(err, errTargetNotFound): - return event.MessageStatusGenericError, event.MessageStatusFail, true, false, "", nil - case errors.As(err, &restErr): - if restErr.Message != nil && (restErr.Message.Code != 0 || len(restErr.Message.Message) > 0) { - reason, humanMessage = restErrorToStatusReason(restErr.Message) - status = event.MessageStatusFail - isCertain = true - sendNotice = true - checkpointError = fmt.Errorf("HTTP %d: %d: %s", restErr.Response.StatusCode, restErr.Message.Code, restErr.Message.Message) - if len(restErr.Message.Errors) > 0 { - jsonExtraErrors, _ := json.Marshal(restErr.Message.Errors) - checkpointError = fmt.Errorf("%w (%s)", checkpointError, jsonExtraErrors) - } - return - } else if restErr.Response.StatusCode == http.StatusBadRequest && bytes.HasPrefix(restErr.ResponseBody, []byte(`{"captcha_key"`)) { - return event.MessageStatusGenericError, event.MessageStatusRetriable, true, true, "Captcha error", errors.New("captcha required") - } else if restErr.Response != nil && (restErr.Response.StatusCode == http.StatusServiceUnavailable || restErr.Response.StatusCode == http.StatusBadGateway || restErr.Response.StatusCode == http.StatusGatewayTimeout) { - return event.MessageStatusGenericError, event.MessageStatusRetriable, true, true, fmt.Sprintf("HTTP %s", restErr.Response.Status), fmt.Errorf("HTTP %d", restErr.Response.StatusCode) - } - fallthrough - case errors.Is(err, context.DeadlineExceeded): - return event.MessageStatusTooOld, event.MessageStatusRetriable, false, true, "", context.DeadlineExceeded - case strings.HasSuffix(err.Error(), "(Client.Timeout exceeded while awaiting headers)"): - return event.MessageStatusTooOld, event.MessageStatusRetriable, false, true, "", errors.New("HTTP request timed out") - case errors.Is(err, syscall.ECONNRESET): - return event.MessageStatusGenericError, event.MessageStatusRetriable, false, true, "", errors.New("connection reset") - default: - return event.MessageStatusGenericError, event.MessageStatusRetriable, false, true, "", nil - } -} - -func restErrorToStatusReason(msg *discordgo.APIErrorMessage) (reason event.MessageStatusReason, humanMessage string) { - switch msg.Code { - case discordgo.ErrCodeRequestEntityTooLarge: - return event.MessageStatusUnsupported, "Attachment is too large" - case discordgo.ErrCodeUnknownEmoji: - return event.MessageStatusUnsupported, "Unsupported emoji" - case discordgo.ErrCodeMissingPermissions, discordgo.ErrCodeMissingAccess: - return event.MessageStatusUnsupported, "You don't have the permissions to do that" - case discordgo.ErrCodeCannotSendMessagesToThisUser: - return event.MessageStatusUnsupported, "You can't send messages to this user" - case discordgo.ErrCodeCannotSendMessagesInVoiceChannel: - return event.MessageStatusUnsupported, "You can't send messages in a non-text channel" - case discordgo.ErrCodeInvalidFormBody: - contentErrs := msg.Errors["content"].Errors - if len(contentErrs) == 1 && contentErrs[0].Code == "BASE_TYPE_MAX_LENGTH" { - return event.MessageStatusUnsupported, "Message is too long: " + contentErrs[0].Message - } - } - return event.MessageStatusGenericError, fmt.Sprintf("%d: %s", msg.Code, msg.Message) -} - -func (portal *Portal) sendStatusEvent(evtID id.EventID, err error) { - if !portal.bridge.Config.Bridge.MessageStatusEvents { - return - } - intent := portal.bridge.Bot - if !portal.Encrypted { - // Bridge bot isn't present in unencrypted DMs - intent = portal.MainIntent() - } - stateKey, _ := portal.getBridgeInfo() - content := event.BeeperMessageStatusEventContent{ - Network: stateKey, - RelatesTo: event.RelatesTo{ - Type: event.RelReference, - EventID: evtID, - }, - Status: event.MessageStatusSuccess, - } - if err == nil { - content.Status = event.MessageStatusSuccess - } else { - var checkpointErr error - content.Reason, content.Status, _, _, content.Message, checkpointErr = errorToStatusReason(err) - if checkpointErr != nil { - content.Error = checkpointErr.Error() - } else { - content.Error = err.Error() - } - } - _, err = intent.SendMessageEvent(portal.MXID, event.BeeperMessageStatus, &content) - if err != nil { - portal.log.Err(err).Str("event_id", evtID.String()).Msg("Failed to send message status event") - } -} - -func (portal *Portal) sendMessageMetrics(evt *event.Event, err error, part string) { - var msgType string - switch evt.Type { - case event.EventMessage, event.EventSticker: - msgType = "message" - case event.EventReaction: - msgType = "reaction" - case event.EventRedaction: - msgType = "redaction" - default: - msgType = "unknown event" - } - level := zerolog.DebugLevel - if err != nil && part != "Ignoring" { - level = zerolog.ErrorLevel - } - logEvt := portal.log.WithLevel(level). - Str("action", "send matrix message metrics"). - Str("event_type", evt.Type.Type). - Str("event_id", evt.ID.String()). - Str("sender", evt.Sender.String()) - if evt.Type == event.EventRedaction { - logEvt.Str("redacts", evt.Redacts.String()) - } - if err != nil { - logEvt.Err(err). - Str("result", fmt.Sprintf("%s event", part)). - Msg("Matrix event not handled") - reason, statusCode, isCertain, sendNotice, humanMessage, checkpointErr := errorToStatusReason(err) - if checkpointErr == nil { - checkpointErr = err - } - checkpointStatus := status.ReasonToCheckpointStatus(reason, statusCode) - portal.bridge.SendMessageCheckpoint(evt, status.MsgStepRemote, checkpointErr, checkpointStatus, 0) - if sendNotice { - if humanMessage == "" { - humanMessage = err.Error() - } - portal.sendErrorMessage(evt, msgType, humanMessage, isCertain) - } - portal.sendStatusEvent(evt.ID, err) - } else { - logEvt.Err(err).Msg("Matrix event handled successfully") - portal.sendDeliveryReceipt(evt.ID) - portal.bridge.SendMessageSuccessCheckpoint(evt, status.MsgStepRemote, 0) - portal.sendStatusEvent(evt.ID, nil) - } -} - -func (br *DiscordBridge) serveMediaProxy(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - mxc := id.ContentURI{ - Homeserver: vars["server"], - FileID: vars["mediaID"], - } - checksum, err := base64.RawURLEncoding.DecodeString(vars["checksum"]) - if err != nil || len(checksum) != 32 { - w.WriteHeader(http.StatusNotFound) - return - } - _, expectedChecksum := br.hashMediaProxyURL(mxc) - if !hmac.Equal(checksum, expectedChecksum) { - w.WriteHeader(http.StatusNotFound) - return - } - reader, err := br.Bot.Download(mxc) - if err != nil { - br.ZLog.Warn().Err(err).Msg("Failed to download media to proxy") - w.WriteHeader(http.StatusInternalServerError) - return - } - buf := make([]byte, 32*1024) - n, err := io.ReadFull(reader, buf) - if err != nil && (!errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF)) { - br.ZLog.Warn().Err(err).Msg("Failed to read first part of media to proxy") - w.WriteHeader(http.StatusBadGateway) - return - } - w.Header().Add("Content-Type", http.DetectContentType(buf[:n])) - if n < len(buf) { - w.Header().Add("Content-Length", strconv.Itoa(n)) - } - w.WriteHeader(http.StatusOK) - _, err = w.Write(buf[:n]) - if err != nil { - return - } - if n >= len(buf) { - _, _ = io.CopyBuffer(w, reader, buf) - } -} - -func (br *DiscordBridge) hashMediaProxyURL(mxc id.ContentURI) (string, []byte) { - path := fmt.Sprintf("/mautrix-discord/avatar/%s/%s/", mxc.Homeserver, mxc.FileID) - checksum := hmac.New(sha256.New, []byte(br.Config.Bridge.AvatarProxyKey)) - checksum.Write([]byte(path)) - return path, checksum.Sum(nil) -} - -func (br *DiscordBridge) makeMediaProxyURL(mxc id.ContentURI) string { - if br.Config.Bridge.PublicAddress == "" { - return "" - } - path, checksum := br.hashMediaProxyURL(mxc) - return br.Config.Bridge.PublicAddress + path + base64.RawURLEncoding.EncodeToString(checksum) -} - -func (portal *Portal) getRelayUserMeta(sender *User) (name, avatarURL string) { - member := portal.bridge.StateStore.GetMember(portal.MXID, sender.MXID) - name = member.Displayname - if name == "" { - name = sender.MXID.String() - } - mxc := member.AvatarURL.ParseOrIgnore() - if !mxc.IsEmpty() && portal.bridge.Config.Bridge.PublicAddress != "" { - avatarURL = portal.bridge.makeMediaProxyURL(mxc) - } - return -} - -const replyEmbedMaxLines = 1 -const replyEmbedMaxChars = 72 - -func cutBody(body string) string { - lines := strings.Split(strings.TrimSpace(body), "\n") - var output string - for i, line := range lines { - if i >= replyEmbedMaxLines { - output += " […]" - break - } - if i > 0 { - output += "\n" - } - output += line - if len(output) > replyEmbedMaxChars { - output = output[:replyEmbedMaxChars] + "…" - break - } - } - return output -} - -func (portal *Portal) convertReplyMessageToEmbed(eventID id.EventID, url string) (*discordgo.MessageEmbed, error) { - evt, err := portal.getEvent(eventID) - if err != nil { - return nil, fmt.Errorf("failed to get reply target event: %w", err) - } - content, ok := evt.Content.Parsed.(*event.MessageEventContent) - if !ok { - return nil, fmt.Errorf("unsupported event type %s / %T", evt.Type.String(), evt.Content.Parsed) - } - content.RemoveReplyFallback() - var targetUser string - - puppet := portal.bridge.GetPuppetByMXID(evt.Sender) - if puppet != nil { - targetUser = fmt.Sprintf("<@%s>", puppet.ID) - } else if user := portal.bridge.GetUserByMXID(evt.Sender); user != nil && user.DiscordID != "" { - targetUser = fmt.Sprintf("<@%s>", user.DiscordID) - } else if member := portal.bridge.StateStore.GetMember(portal.MXID, evt.Sender); member != nil && member.Displayname != "" { - targetUser = member.Displayname - } else { - targetUser = evt.Sender.String() - } - body := escapeDiscordMarkdown(cutBody(content.Body)) - body = fmt.Sprintf("**[Replying to](%s) %s**\n%s", url, targetUser, body) - embed := &discordgo.MessageEmbed{Description: body} - return embed, nil -} - -func (portal *Portal) RefererOpt(threadID string) discordgo.RequestOption { - if threadID != "" && threadID != portal.Key.ChannelID { - return discordgo.WithThreadReferer(portal.GuildID, portal.Key.ChannelID, threadID) - } - return discordgo.WithChannelReferer(portal.GuildID, portal.Key.ChannelID) -} - -func (portal *Portal) RefererOptIfUser(sess *discordgo.Session, threadID string) []discordgo.RequestOption { - if sess == nil || !sess.IsUser { - return nil - } - return []discordgo.RequestOption{portal.RefererOpt(threadID)} -} - -func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) { - if portal.IsPrivateChat() && sender.DiscordID != portal.Key.Receiver { - go portal.sendMessageMetrics(evt, errUserNotReceiver, "Ignoring") - return - } - - content, ok := evt.Content.Parsed.(*event.MessageEventContent) - if !ok { - go portal.sendMessageMetrics(evt, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed), "Ignoring") - return - } - - channelID := portal.Key.ChannelID - sess := sender.Session - if sess == nil && portal.RelayWebhookID == "" { - go portal.sendMessageMetrics(evt, errUserNotLoggedIn, "Ignoring") - return - } - isWebhookSend := sess == nil - var threadID string - - if editMXID := content.GetRelatesTo().GetReplaceID(); editMXID != "" && content.NewContent != nil { - edits := portal.bridge.DB.Message.GetByMXID(portal.Key, editMXID) - if edits != nil { - newContentRaw, _ := evt.Content.Raw["m.new_content"].(map[string]any) - discordContent, allowedMentions := portal.parseMatrixHTML(content.NewContent, parseAllowedLinkPreviews(newContentRaw)) - var err error - var msg *discordgo.Message - if !isWebhookSend { - // TODO save edit in message table - msg, err = sess.ChannelMessageEdit(edits.DiscordProtoChannelID(), edits.DiscordID, discordContent) - } else { - msg, err = relayClient.WebhookMessageEdit(portal.RelayWebhookID, portal.RelayWebhookSecret, edits.DiscordID, &discordgo.WebhookEdit{ - Content: &discordContent, - AllowedMentions: allowedMentions, - }) - } - go portal.sendMessageMetrics(evt, err, "Failed to edit") - if msg.EditedTimestamp != nil { - edits.UpdateEditTimestamp(*msg.EditedTimestamp) - } - } else { - go portal.sendMessageMetrics(evt, fmt.Errorf("%w %s", errUnknownEditTarget, editMXID), "Ignoring") - } - return - } else if threadRoot := content.GetRelatesTo().GetThreadParent(); threadRoot != "" { - existingThread := portal.bridge.GetThreadByRootMXID(threadRoot) - if existingThread != nil { - threadID = existingThread.ID - existingThread.initialBackfillAttempted = true - } else { - if isWebhookSend { - // TODO start thread with bot? - go portal.sendMessageMetrics(evt, errCantStartThread, "Dropping") - return - } - var err error - threadID, err = portal.startThreadFromMatrix(sender, threadRoot) - if err != nil { - portal.log.Warn().Err(err). - Str("thread_root_mxid", threadRoot.String()). - Msg("Failed to start thread from Matrix") - } - } - } - if threadID != "" { - channelID = threadID - } - - var sendReq discordgo.MessageSend - - var description string - if evt.Type == event.EventSticker { - content.MsgType = event.MsgImage - if mimeData := mimetype.Lookup(content.Info.MimeType); mimeData != nil { - description = content.Body - content.Body = "sticker" + mimeData.Extension() - } - } - - replyToMXID := content.RelatesTo.GetNonFallbackReplyTo() - var replyToUser id.UserID - if replyToMXID != "" { - replyTo := portal.bridge.DB.Message.GetByMXID(portal.Key, replyToMXID) - if replyTo != nil && replyTo.ThreadID == threadID { - replyToUser = replyTo.SenderMXID - if isWebhookSend { - messageURL := fmt.Sprintf("https://discord.com/channels/%s/%s/%s", portal.GuildID, channelID, replyTo.DiscordID) - embed, err := portal.convertReplyMessageToEmbed(replyTo.MXID, messageURL) - if err != nil { - portal.log.Warn().Err(err).Msg("Failed to convert reply message to embed for webhook send") - } else if embed != nil { - sendReq.Embeds = []*discordgo.MessageEmbed{embed} - } - } else { - sendReq.Reference = &discordgo.MessageReference{ - ChannelID: channelID, - MessageID: replyTo.DiscordID, - } - } - } - } - switch content.MsgType { - case event.MsgText, event.MsgEmote, event.MsgNotice: - sendReq.Content, sendReq.AllowedMentions = portal.parseMatrixHTML(content, parseAllowedLinkPreviews(evt.Content.Raw)) - if content.MsgType == event.MsgEmote { - sendReq.Content = fmt.Sprintf("_%s_", sendReq.Content) - } - case event.MsgAudio, event.MsgFile, event.MsgImage, event.MsgVideo: - data, err := downloadMatrixAttachment(portal.MainIntent(), content) - if err != nil { - go portal.sendMessageMetrics(evt, err, "Error downloading media in") - return - } - filename := content.Body - if content.FileName != "" && content.FileName != content.Body { - filename = content.FileName - sendReq.Content, sendReq.AllowedMentions = portal.parseMatrixHTML(content, parseAllowedLinkPreviews(evt.Content.Raw)) - } - - if evt.Content.Raw["page.codeberg.everypizza.msc4193.spoiler"] == true { - filename = "SPOILER_" + filename - } - - if portal.bridge.Config.Bridge.UseDiscordCDNUpload && !isWebhookSend && sess.IsUser { - att := &discordgo.MessageAttachment{ - ID: "0", - Filename: filename, - Description: description, - } - sendReq.Attachments = []*discordgo.MessageAttachment{att} - prep, err := sender.Session.ChannelAttachmentCreate(channelID, &discordgo.ReqPrepareAttachments{ - Files: []*discordgo.FilePrepare{{ - Size: len(data), - Name: att.Filename, - ID: sender.NextDiscordUploadID(), - }}, - }, portal.RefererOpt(threadID)) - if err != nil { - go portal.sendMessageMetrics(evt, err, "Error preparing to reupload media in") - return - } - prepared := prep.Attachments[0] - att.UploadedFilename = prepared.UploadFilename - err = uploadDiscordAttachment(sender.Session.Client, prepared.UploadURL, data) - if err != nil { - go portal.sendMessageMetrics(evt, err, "Error reuploading media in") - return - } - } else { - sendReq.Files = []*discordgo.File{{ - Name: filename, - ContentType: content.Info.MimeType, - Reader: bytes.NewReader(data), - }} - } - default: - go portal.sendMessageMetrics(evt, fmt.Errorf("%w %q", errUnknownMsgType, content.MsgType), "Ignoring") - return - } - silentReply := content.Mentions != nil && replyToMXID != "" && - (len(content.Mentions.UserIDs) == 0 || (replyToUser != "" && !slices.Contains(content.Mentions.UserIDs, replyToUser))) - if silentReply && sendReq.AllowedMentions != nil { - sendReq.AllowedMentions.RepliedUser = false - } - if !isWebhookSend { - // AllowedMentions must not be set for real users, and it's also not that useful for personal bots. - // It's only important for relaying, where the webhook may have higher permissions than the user on Matrix. - if silentReply { - sendReq.AllowedMentions = &discordgo.MessageAllowedMentions{ - Parse: []discordgo.AllowedMentionType{discordgo.AllowedMentionTypeUsers, discordgo.AllowedMentionTypeRoles, discordgo.AllowedMentionTypeEveryone}, - RepliedUser: false, - } - } else { - sendReq.AllowedMentions = nil - } - } else if strings.Contains(sendReq.Content, "@everyone") || strings.Contains(sendReq.Content, "@here") { - powerLevels, err := portal.MainIntent().PowerLevels(portal.MXID) - if err != nil { - portal.log.Warn().Err(err). - Str("user_id", sender.MXID.String()). - Msg("Failed to get power levels to check if user can use @everyone") - } else if powerLevels.GetUserLevel(sender.MXID) >= powerLevels.Notifications.Room() { - sendReq.AllowedMentions.Parse = append(sendReq.AllowedMentions.Parse, discordgo.AllowedMentionTypeEveryone) - } - } - sendReq.Nonce = generateNonce() - var msg *discordgo.Message - var err error - if !isWebhookSend { - msg, err = sess.ChannelMessageSendComplex(channelID, &sendReq, portal.RefererOptIfUser(sess, threadID)...) - } else { - username, avatarURL := portal.getRelayUserMeta(sender) - msg, err = relayClient.WebhookThreadExecute(portal.RelayWebhookID, portal.RelayWebhookSecret, true, threadID, &discordgo.WebhookParams{ - Content: sendReq.Content, - Username: username, - AvatarURL: avatarURL, - Files: sendReq.Files, - Components: sendReq.Components, - Embeds: sendReq.Embeds, - AllowedMentions: sendReq.AllowedMentions, - }) - } - sender.handlePossible40002(err) - go portal.sendMessageMetrics(evt, err, "Error sending") - if msg != nil { - dbMsg := portal.bridge.DB.Message.New() - dbMsg.Channel = portal.Key - dbMsg.DiscordID = msg.ID - if len(msg.Attachments) > 0 { - dbMsg.AttachmentID = msg.Attachments[0].ID - } - dbMsg.MXID = evt.ID - if sess != nil { - dbMsg.SenderID = sender.DiscordID - } else { - dbMsg.SenderID = portal.RelayWebhookID - } - dbMsg.SenderMXID = sender.MXID - dbMsg.Timestamp, _ = discordgo.SnowflakeTimestamp(msg.ID) - dbMsg.ThreadID = threadID - dbMsg.Insert() - } -} - -func parseAllowedLinkPreviews(raw map[string]any) []string { - if raw == nil { - return nil - } - linkPreviews, ok := raw["com.beeper.linkpreviews"].([]any) - if !ok { - return nil - } - allowedLinkPreviews := make([]string, 0, len(linkPreviews)) - for _, preview := range linkPreviews { - previewMap, ok := preview.(map[string]any) - if !ok { - continue - } - matchedURL, _ := previewMap["matched_url"].(string) - if matchedURL != "" { - allowedLinkPreviews = append(allowedLinkPreviews, matchedURL) - } - } - return allowedLinkPreviews -} - -func (portal *Portal) sendDeliveryReceipt(eventID id.EventID) { - if portal.bridge.Config.Bridge.DeliveryReceipts { - err := portal.bridge.Bot.MarkRead(portal.MXID, eventID) - if err != nil { - portal.log.Warn().Err(err). - Str("event_id", eventID.String()). - Msg("Failed to send delivery receipt") - } - } -} - -func (portal *Portal) HandleMatrixLeave(brSender bridge.User) { - sender := brSender.(*User) - if portal.IsPrivateChat() && sender.DiscordID == portal.Key.Receiver { - portal.log.Debug().Msg("User left private chat portal, cleaning up and deleting...") - portal.cleanup(false) - portal.RemoveMXID() - } else { - portal.cleanupIfEmpty() - } -} - -func (portal *Portal) HandleMatrixKick(brSender bridge.User, brTarget bridge.Ghost) {} -func (portal *Portal) HandleMatrixInvite(brSender bridge.User, brTarget bridge.Ghost) {} - -func (portal *Portal) Delete() { - portal.Portal.Delete() - portal.bridge.portalsLock.Lock() - delete(portal.bridge.portalsByID, portal.Key) - if portal.MXID != "" { - delete(portal.bridge.portalsByMXID, portal.MXID) - } - portal.bridge.portalsLock.Unlock() -} - -func (portal *Portal) cleanupIfEmpty() { - if portal.MXID == "" { - return - } - - users, err := portal.getMatrixUsers() - if err != nil { - portal.log.Err(err).Msg("Failed to get Matrix user list to determine if portal needs to be cleaned up") - return - } - - if len(users) == 0 { - portal.log.Info().Msg("Room seems to be empty, cleaning up...") - portal.cleanup(false) - portal.RemoveMXID() - } -} - -func (portal *Portal) RemoveMXID() { - portal.bridge.portalsLock.Lock() - defer portal.bridge.portalsLock.Unlock() - if portal.MXID == "" { - return - } - delete(portal.bridge.portalsByMXID, portal.MXID) - portal.MXID = "" - portal.log = portal.bridge.ZLog.With(). - Str("channel_id", portal.Key.ChannelID). - Str("channel_receiver", portal.Key.Receiver). - Str("room_id", portal.MXID.String()). - Logger() - portal.AvatarSet = false - portal.NameSet = false - portal.TopicSet = false - portal.Encrypted = false - portal.InSpace = "" - portal.FirstEventID = "" - portal.Update() - portal.bridge.DB.Message.DeleteAll(portal.Key) -} - -func (portal *Portal) cleanup(puppetsOnly bool) { - if portal.MXID == "" { - return - } - intent := portal.MainIntent() - if portal.bridge.SpecVersions.Supports(mautrix.BeeperFeatureRoomYeeting) { - err := intent.BeeperDeleteRoom(portal.MXID) - if err != nil && !errors.Is(err, mautrix.MNotFound) { - portal.log.Err(err).Msg("Failed to delete room using hungryserv yeet endpoint") - } - return - } - - if portal.IsPrivateChat() { - _, err := portal.MainIntent().LeaveRoom(portal.MXID) - if err != nil { - portal.log.Warn().Err(err).Msg("Failed to leave private chat portal with main intent") - } - return - } - - portal.bridge.cleanupRoom(intent, portal.MXID, puppetsOnly, portal.log) -} - -func (br *DiscordBridge) cleanupRoom(intent *appservice.IntentAPI, mxid id.RoomID, puppetsOnly bool, log zerolog.Logger) { - members, err := intent.JoinedMembers(mxid) - if err != nil { - log.Err(err).Msg("Failed to get portal members for cleanup") - return - } - - for member := range members.Joined { - if member == intent.UserID { - continue - } - - puppet := br.GetPuppetByMXID(member) - if puppet != nil { - _, err = puppet.DefaultIntent().LeaveRoom(mxid) - if err != nil { - log.Err(err).Msg("Error leaving as puppet while cleaning up portal") - } - } else if !puppetsOnly { - _, err = intent.KickUser(mxid, &mautrix.ReqKickUser{UserID: member, Reason: "Deleting portal"}) - if err != nil { - log.Err(err).Msg("Error kicking user while cleaning up portal") - } - } - } - - _, err = intent.LeaveRoom(mxid) - if err != nil { - log.Err(err).Msg("Error leaving with main intent while cleaning up portal") - } -} - -func (portal *Portal) getMatrixUsers() ([]id.UserID, error) { - members, err := portal.MainIntent().JoinedMembers(portal.MXID) - if err != nil { - return nil, fmt.Errorf("failed to get member list: %w", err) - } - - var users []id.UserID - for userID := range members.Joined { - _, isPuppet := portal.bridge.ParsePuppetMXID(userID) - if !isPuppet && userID != portal.bridge.Bot.UserID { - users = append(users, userID) - } - } - - return users, nil -} - -func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) { - if portal.IsPrivateChat() && sender.DiscordID != portal.Key.Receiver { - go portal.sendMessageMetrics(evt, errUserNotReceiver, "Ignoring") - return - } else if !sender.IsLoggedIn() { - //go portal.sendMessageMetrics(evt, errReactionUserNotLoggedIn, "Ignoring") - return - } - - reaction := evt.Content.AsReaction() - if reaction.RelatesTo.Type != event.RelAnnotation { - go portal.sendMessageMetrics(evt, fmt.Errorf("%w %s", errUnknownRelationType, reaction.RelatesTo.Type), "Ignoring") - return - } - - if reaction.RelatesTo.Key == JoinThreadReaction { - thread := portal.bridge.GetThreadByRootOrCreationNoticeMXID(reaction.RelatesTo.EventID) - if thread == nil { - go portal.sendMessageMetrics(evt, errTargetNotFound, "Ignoring thread join") - return - } - thread.Join(sender) - return - } - - msg := portal.bridge.DB.Message.GetByMXID(portal.Key, reaction.RelatesTo.EventID) - if msg == nil { - go portal.sendMessageMetrics(evt, errTargetNotFound, "Ignoring") - return - } - - firstMsg := msg - if msg.AttachmentID != "" { - firstMsg = portal.bridge.DB.Message.GetFirstByDiscordID(portal.Key, msg.DiscordID) - // TODO should the emoji be rerouted to the first message if it's different? - } - - // Figure out if this is a custom emoji or not. - emojiID := reaction.RelatesTo.Key - if strings.HasPrefix(emojiID, "mxc://") { - uri, _ := id.ParseContentURI(emojiID) - emojiInfo := portal.bridge.DMA.GetEmojiInfo(uri) - if emojiInfo != nil { - emojiID = fmt.Sprintf("%s:%d", emojiInfo.Name, emojiInfo.EmojiID) - } else if emojiFile := portal.bridge.DB.File.GetEmojiByMXC(uri); emojiFile != nil && emojiFile.ID != "" && emojiFile.EmojiName != "" { - emojiID = fmt.Sprintf("%s:%s", emojiFile.EmojiName, emojiFile.ID) - } else { - go portal.sendMessageMetrics(evt, fmt.Errorf("%w %s", errUnknownEmoji, emojiID), "Ignoring") - return - } - } else { - emojiID = variationselector.FullyQualify(emojiID) - } - - existing := portal.bridge.DB.Reaction.GetByDiscordID(portal.Key, msg.DiscordID, sender.DiscordID, emojiID) - if existing != nil { - portal.log.Debug(). - Str("event_id", evt.ID.String()). - Str("existing_reaction_mxid", existing.MXID.String()). - Msg("Dropping duplicate Matrix reaction") - go portal.sendMessageMetrics(evt, nil, "") - return - } - - err := sender.Session.MessageReactionAddUser(portal.GuildID, msg.DiscordProtoChannelID(), msg.DiscordID, emojiID) - go portal.sendMessageMetrics(evt, err, "Error sending") - if err == nil { - dbReaction := portal.bridge.DB.Reaction.New() - dbReaction.Channel = portal.Key - dbReaction.MessageID = msg.DiscordID - dbReaction.FirstAttachmentID = firstMsg.AttachmentID - dbReaction.Sender = sender.DiscordID - dbReaction.EmojiName = emojiID - dbReaction.ThreadID = msg.ThreadID - dbReaction.MXID = evt.ID - dbReaction.Insert() - } -} - -func (portal *Portal) handleDiscordReaction(user *User, reaction *discordgo.MessageReaction, add bool, thread *Thread, member *discordgo.Member) { - puppet := portal.bridge.GetPuppetByID(reaction.UserID) - if member != nil { - puppet.UpdateInfo(user, member.User, nil) - } - intent := puppet.IntentFor(portal) - - log := portal.log.With(). - Str("message_id", reaction.MessageID). - Str("author_id", reaction.UserID). - Bool("add", add). - Str("action", "discord reaction"). - Logger() - - var discordID string - var matrixReaction string - - if reaction.Emoji.ID != "" { - reactionMXC := portal.getEmojiMXCByDiscordID(reaction.Emoji.ID, reaction.Emoji.Name, reaction.Emoji.Animated) - if reactionMXC.IsEmpty() { - return - } - matrixReaction = reactionMXC.String() - discordID = fmt.Sprintf("%s:%s", reaction.Emoji.Name, reaction.Emoji.ID) - } else { - discordID = reaction.Emoji.Name - matrixReaction = variationselector.Add(reaction.Emoji.Name) - } - - // Find the message that we're working with. - message := portal.bridge.DB.Message.GetByDiscordID(portal.Key, reaction.MessageID) - if message == nil { - log.Debug().Msg("Failed to add reaction to message: message not found") - return - } - - // Lookup an existing reaction - existing := portal.bridge.DB.Reaction.GetByDiscordID(portal.Key, message[0].DiscordID, reaction.UserID, discordID) - if !add { - if existing == nil { - log.Debug().Msg("Failed to remove reaction: reaction not found") - return - } - - resp, err := intent.RedactEvent(portal.MXID, existing.MXID) - if err != nil { - log.Err(err).Msg("Failed to remove reaction") - } else { - go portal.sendDeliveryReceipt(resp.EventID) - } - - existing.Delete() - return - } else if existing != nil { - log.Debug().Msg("Ignoring duplicate reaction") - return - } - - content := event.ReactionEventContent{ - RelatesTo: event.RelatesTo{ - EventID: message[0].MXID, - Type: event.RelAnnotation, - Key: matrixReaction, - }, - } - extraContent := map[string]any{} - if reaction.Emoji.ID != "" { - extraContent["fi.mau.discord.reaction"] = map[string]any{ - "id": reaction.Emoji.ID, - "name": reaction.Emoji.Name, - "mxc": matrixReaction, - } - wrappedShortcode := fmt.Sprintf(":%s:", reaction.Emoji.Name) - extraContent["com.beeper.reaction.shortcode"] = wrappedShortcode - if !portal.bridge.Config.Bridge.CustomEmojiReactions { - content.RelatesTo.Key = wrappedShortcode - } - } - - resp, err := intent.SendMessageEvent(portal.MXID, event.EventReaction, &event.Content{ - Parsed: &content, - Raw: extraContent, - }) - if err != nil { - log.Err(err).Msg("Failed to send reaction") - return - } - - if existing == nil { - dbReaction := portal.bridge.DB.Reaction.New() - dbReaction.Channel = portal.Key - dbReaction.MessageID = message[0].DiscordID - dbReaction.FirstAttachmentID = message[0].AttachmentID - dbReaction.Sender = reaction.UserID - dbReaction.EmojiName = discordID - dbReaction.MXID = resp.EventID - if thread != nil { - dbReaction.ThreadID = thread.ID - } - dbReaction.Insert() - portal.sendDeliveryReceipt(dbReaction.MXID) - } -} - -func (portal *Portal) handleMatrixRedaction(sender *User, evt *event.Event) { - if portal.IsPrivateChat() && sender.DiscordID != portal.Key.Receiver { - go portal.sendMessageMetrics(evt, errUserNotReceiver, "Ignoring") - return - } - - sess := sender.Session - if sess == nil && portal.RelayWebhookID == "" { - go portal.sendMessageMetrics(evt, errUserNotLoggedIn, "Ignoring") - return - } - - message := portal.bridge.DB.Message.GetByMXID(portal.Key, evt.Redacts) - if message != nil { - var err error - // TODO add support for deleting individual attachments from messages - if sess != nil { - err = sess.ChannelMessageDelete(message.DiscordProtoChannelID(), message.DiscordID, portal.RefererOptIfUser(sess, message.ThreadID)...) - } else { - // TODO pre-validate that the message was sent by the webhook? - err = relayClient.WebhookMessageDelete(portal.RelayWebhookID, portal.RelayWebhookSecret, message.DiscordID) - } - go portal.sendMessageMetrics(evt, err, "Error sending") - if err == nil { - message.Delete() - } - return - } - - if sess != nil { - reaction := portal.bridge.DB.Reaction.GetByMXID(evt.Redacts) - if reaction != nil && reaction.Channel == portal.Key { - err := sess.MessageReactionRemoveUser(portal.GuildID, reaction.DiscordProtoChannelID(), reaction.MessageID, reaction.EmojiName, reaction.Sender) - go portal.sendMessageMetrics(evt, err, "Error sending") - if err == nil { - reaction.Delete() - } - return - } - } - - go portal.sendMessageMetrics(evt, errTargetNotFound, "Ignoring") -} - -func (portal *Portal) HandleMatrixReadReceipt(brUser bridge.User, eventID id.EventID, receipt event.ReadReceipt) { - sender := brUser.(*User) - if sender.Session == nil { - return - } - var thread *Thread - discordThreadID := "" - if receipt.ThreadID != "" && receipt.ThreadID != event.ReadReceiptThreadMain { - thread = portal.bridge.GetThreadByRootMXID(receipt.ThreadID) - if thread != nil { - discordThreadID = thread.ID - } - } - log := portal.log.With(). - Str("sender", brUser.GetMXID().String()). - Str("event_id", eventID.String()). - Str("action", "matrix read receipt"). - Str("discord_thread_id", discordThreadID). - Logger() - if thread != nil { - if portal.bridge.Config.Bridge.AutojoinThreadOnOpen { - thread.Join(sender) - } - if eventID == thread.CreationNoticeMXID { - log.Debug().Msg("Dropping read receipt for thread creation notice") - return - } - } - if !sender.Session.IsUser { - // Drop read receipts from bot users (after checking for the thread auto-join stuff) - return - } - msg := portal.bridge.DB.Message.GetByMXID(portal.Key, eventID) - if msg == nil { - msg = portal.bridge.DB.Message.GetClosestBefore(portal.Key, discordThreadID, receipt.Timestamp) - if msg == nil { - log.Debug().Msg("Dropping read receipt: no messages found") - return - } else { - log = log.With(). - Str("closest_event_id", msg.MXID.String()). - Str("closest_message_id", msg.DiscordID). - Logger() - log.Debug().Msg("Read receipt target event not found, using closest message") - } - } else { - log = log.With(). - Str("message_id", msg.DiscordID). - Logger() - } - if receipt.ThreadID != "" && msg.ThreadID != discordThreadID { - log.Debug(). - Str("receipt_thread_event_id", receipt.ThreadID.String()). - Str("message_discord_thread_id", msg.ThreadID). - Msg("Dropping read receipt: thread ID mismatch") - return - } - resp, err := sender.Session.ChannelMessageAckNoToken(msg.DiscordProtoChannelID(), msg.DiscordID, portal.RefererOpt(msg.DiscordProtoChannelID())) - if err != nil { - log.Err(err).Msg("Failed to send read receipt to Discord") - } else if resp.Token != nil { - log.Debug(). - Str("unexpected_resp_token", *resp.Token). - Msg("Marked message as read on Discord (and got unexpected non-nil token)") - } else { - log.Debug().Msg("Marked message as read on Discord") - } -} - -func typingDiff(prev, new []id.UserID) (started []id.UserID) { -OuterNew: - for _, userID := range new { - for _, previousUserID := range prev { - if userID == previousUserID { - continue OuterNew - } - } - started = append(started, userID) - } - return -} - -func (portal *Portal) HandleMatrixTyping(newTyping []id.UserID) { - portal.currentlyTypingLock.Lock() - defer portal.currentlyTypingLock.Unlock() - startedTyping := typingDiff(portal.currentlyTyping, newTyping) - portal.currentlyTyping = newTyping - for _, userID := range startedTyping { - user := portal.bridge.GetUserByMXID(userID) - if user != nil && user.Session != nil { - user.ViewingChannel(portal) - err := user.Session.ChannelTyping(portal.Key.ChannelID, portal.RefererOptIfUser(user.Session, "")...) - if err != nil { - portal.log.Warn().Err(err). - Str("user_id", user.MXID.String()). - Msg("Failed to mark user as typing") - } else { - portal.log.Debug(). - Str("user_id", user.MXID.String()). - Msg("Marked user as typing") - } - } - } -} - -func (portal *Portal) UpdateName(meta *discordgo.Channel) bool { - var parentName, guildName string - if portal.Parent != nil { - parentName = portal.Parent.PlainName - } - if portal.Guild != nil { - guildName = portal.Guild.PlainName - } - plainNameChanged := portal.PlainName != meta.Name - portal.PlainName = meta.Name - return portal.UpdateNameDirect(portal.bridge.Config.Bridge.FormatChannelName(config.ChannelNameParams{ - Name: meta.Name, - ParentName: parentName, - GuildName: guildName, - NSFW: meta.NSFW, - Type: meta.Type, - }), false) || plainNameChanged -} - -func (portal *Portal) UpdateNameDirect(name string, isFriendNick bool) bool { - if portal.FriendNick && !isFriendNick { - return false - } else if portal.Name == name && (portal.NameSet || portal.MXID == "" || (!portal.shouldSetDMRoomMetadata() && !isFriendNick)) { - return false - } - portal.log.Debug(). - Str("old_name", portal.Name). - Str("new_name", name). - Msg("Updating portal name") - portal.Name = name - portal.NameSet = false - portal.updateRoomName() - return true -} - -func (portal *Portal) updateRoomName() { - if portal.MXID != "" && (portal.shouldSetDMRoomMetadata() || portal.FriendNick) { - _, err := portal.MainIntent().SetRoomName(portal.MXID, portal.Name) - if err != nil { - portal.log.Err(err).Msg("Failed to update room name") - } else { - portal.NameSet = true - } - } -} - -func (portal *Portal) UpdateAvatarFromPuppet(puppet *Puppet) bool { - if portal.Avatar == puppet.Avatar && portal.AvatarURL == puppet.AvatarURL && (puppet.Avatar == "" || portal.AvatarSet || portal.MXID == "" || !portal.shouldSetDMRoomMetadata()) { - return false - } - portal.log.Debug(). - Str("old_avatar_id", portal.Avatar). - Str("new_avatar_id", puppet.Avatar). - Msg("Updating avatar from puppet") - portal.Avatar = puppet.Avatar - portal.AvatarURL = puppet.AvatarURL - portal.AvatarSet = false - portal.updateRoomAvatar() - return true -} - -func (portal *Portal) UpdateGroupDMAvatar(iconID string) bool { - if portal.Avatar == iconID && (iconID == "") == portal.AvatarURL.IsEmpty() && (iconID == "" || portal.AvatarSet || portal.MXID == "") { - return false - } - portal.log.Debug(). - Str("old_avatar_id", portal.Avatar). - Str("new_avatar_id", portal.Avatar). - Msg("Updating group DM avatar") - portal.Avatar = iconID - portal.AvatarSet = false - portal.AvatarURL = id.ContentURI{} - if portal.Avatar != "" { - // TODO direct media support - copied, err := portal.bridge.copyAttachmentToMatrix(portal.MainIntent(), discordgo.EndpointGroupIcon(portal.Key.ChannelID, portal.Avatar), false, AttachmentMeta{ - AttachmentID: fmt.Sprintf("private_channel_avatar/%s/%s", portal.Key.ChannelID, iconID), - }) - if err != nil { - portal.log.Err(err).Str("avatar_id", iconID).Msg("Failed to reupload channel avatar") - return true - } - portal.AvatarURL = copied.MXC - } - portal.updateRoomAvatar() - return true -} - -func (portal *Portal) updateRoomAvatar() { - if portal.MXID == "" || portal.AvatarURL.IsEmpty() || !portal.shouldSetDMRoomMetadata() { - return - } - _, err := portal.MainIntent().SetRoomAvatar(portal.MXID, portal.AvatarURL) - if err != nil { - portal.log.Err(err).Msg("Failed to update room avatar") - } else { - portal.AvatarSet = true - } -} - -func (portal *Portal) UpdateTopic(topic string) bool { - if portal.Topic == topic && (portal.TopicSet || portal.MXID == "") { - return false - } - portal.log.Debug(). - Str("old_topic", portal.Topic). - Str("new_topic", topic). - Msg("Updating portal topic") - portal.Topic = topic - portal.TopicSet = false - portal.updateRoomTopic() - return true -} - -func (portal *Portal) updateRoomTopic() { - if portal.MXID != "" { - _, err := portal.MainIntent().SetRoomTopic(portal.MXID, portal.Topic) - if err != nil { - portal.log.Err(err).Msg("Failed to update room topic") - } else { - portal.TopicSet = true - } - } -} - -func (portal *Portal) removeFromSpace() { - if portal.InSpace == "" { - return - } - - log := portal.log.With().Str("space_mxid", portal.InSpace.String()).Logger() - log.Debug().Msg("Removing room from space") - _, err := portal.MainIntent().SendStateEvent(portal.MXID, event.StateSpaceParent, portal.InSpace.String(), struct{}{}) - if err != nil { - log.Warn().Err(err).Msg("Failed to clear m.space.parent event in room") - } - _, err = portal.bridge.Bot.SendStateEvent(portal.InSpace, event.StateSpaceChild, portal.MXID.String(), struct{}{}) - if err != nil { - log.Warn().Err(err).Msg("Failed to clear m.space.child event in space") - } - portal.InSpace = "" -} - -func (portal *Portal) addToSpace(mxid id.RoomID) bool { - if portal.InSpace == mxid { - return false - } - portal.removeFromSpace() - if mxid == "" { - return true - } - - log := portal.log.With().Str("space_mxid", mxid.String()).Logger() - _, err := portal.MainIntent().SendStateEvent(portal.MXID, event.StateSpaceParent, mxid.String(), &event.SpaceParentEventContent{ - Via: []string{portal.bridge.AS.HomeserverDomain}, - Canonical: true, - }) - if err != nil { - log.Warn().Err(err).Msg("Failed to set m.space.parent event in room") - } - - _, err = portal.bridge.Bot.SendStateEvent(mxid, event.StateSpaceChild, portal.MXID.String(), &event.SpaceChildEventContent{ - Via: []string{portal.bridge.AS.HomeserverDomain}, - // TODO order - }) - if err != nil { - log.Warn().Err(err).Msg("Failed to set m.space.child event in space") - } else { - portal.InSpace = mxid - } - return true -} - -func (portal *Portal) UpdateParent(parentID string) bool { - if portal.ParentID == parentID { - return false - } - portal.log.Debug(). - Str("old_parent_id", portal.ParentID). - Str("new_parent_id", parentID). - Msg("Updating parent ID") - portal.ParentID = parentID - if portal.ParentID != "" { - portal.Parent = portal.bridge.GetPortalByID(database.NewPortalKey(parentID, ""), discordgo.ChannelTypeGuildCategory) - } else { - portal.Parent = nil - } - return true -} - -func (portal *Portal) ExpectedSpaceID() id.RoomID { - if portal.Parent != nil { - return portal.Parent.MXID - } else if portal.Guild != nil { - return portal.Guild.MXID - } - return "" -} - -func (portal *Portal) updateSpace(source *User) bool { - if portal.MXID == "" { - return false - } - if portal.Parent != nil { - if portal.Parent.MXID != "" { - portal.log.Warn().Str("parent_id", portal.ParentID).Msg("Parent portal has no Matrix room, creating...") - err := portal.Parent.CreateMatrixRoom(source, nil) - if err != nil { - portal.log.Err(err).Str("parent_id", portal.ParentID).Msg("Failed to create Matrix room for parent") - return false - } - } - return portal.addToSpace(portal.Parent.MXID) - } else if portal.Guild != nil { - return portal.addToSpace(portal.Guild.MXID) - } - return false -} - -func (portal *Portal) UpdateInfo(source *User, meta *discordgo.Channel) *discordgo.Channel { - changed := false - - log := portal.log.With(). - Str("action", "update info"). - Str("through_user_mxid", source.MXID.String()). - Str("through_user_dcid", source.DiscordID). - Logger() - - if meta == nil { - log.Debug().Msg("UpdateInfo called without metadata, fetching from user's state cache") - meta, _ = source.Session.State.Channel(portal.Key.ChannelID) - if meta == nil { - log.Warn().Msg("No metadata found in state cache, fetching from server via user") - var err error - meta, err = source.Session.Channel(portal.Key.ChannelID) - if err != nil { - log.Err(err).Msg("Failed to fetch meta via user") - return nil - } - } - } - - if portal.Type != meta.Type { - log.Warn(). - Int("old_type", int(portal.Type)). - Int("new_type", int(meta.Type)). - Msg("Portal type changed") - portal.Type = meta.Type - changed = true - } - if portal.OtherUserID == "" && portal.IsPrivateChat() { - if len(meta.Recipients) == 0 { - var err error - meta, err = source.Session.Channel(meta.ID) - if err != nil { - log.Err(err).Msg("Failed to fetch DM channel info to find other user ID") - } - } - if len(meta.Recipients) > 0 { - portal.OtherUserID = meta.Recipients[0].ID - log.Info().Str("other_user_id", portal.OtherUserID).Msg("Found other user ID") - changed = true - } - } - if meta.GuildID != "" && portal.GuildID == "" { - portal.GuildID = meta.GuildID - portal.Guild = portal.bridge.GetGuildByID(portal.GuildID, true) - changed = true - } - - switch portal.Type { - case discordgo.ChannelTypeDM: - if portal.OtherUserID != "" { - puppet := portal.bridge.GetPuppetByID(portal.OtherUserID) - changed = portal.UpdateAvatarFromPuppet(puppet) || changed - if rel, ok := source.relationships[portal.OtherUserID]; ok && rel.Nickname != "" { - portal.FriendNick = true - changed = portal.UpdateNameDirect(rel.Nickname, true) || changed - } else { - portal.FriendNick = false - changed = portal.UpdateNameDirect(puppet.Name, false) || changed - } - } - if portal.MXID != "" { - portal.syncParticipants(source, meta.Recipients) - } - case discordgo.ChannelTypeGroupDM: - changed = portal.UpdateGroupDMAvatar(meta.Icon) || changed - if portal.MXID != "" { - portal.syncParticipants(source, meta.Recipients) - } - fallthrough - default: - changed = portal.UpdateName(meta) || changed - if portal.MXID != "" { - portal.ensureUserInvited(source, false) - } - } - changed = portal.UpdateTopic(meta.Topic) || changed - changed = portal.UpdateParent(meta.ParentID) || changed - // Private channels are added to the space in User.handlePrivateChannel - if portal.GuildID != "" && portal.MXID != "" && portal.ExpectedSpaceID() != portal.InSpace { - changed = portal.updateSpace(source) || changed - } - if changed { - portal.UpdateBridgeInfo() - portal.Update() - } - return meta -} - -func (br *DiscordBridge) HandleTombstone(evt *event.Event) { - if evt.StateKey == nil || *evt.StateKey != "" { - return - } - content, ok := evt.Content.Parsed.(*event.TombstoneEventContent) - if !ok { - return - } - defer br.MatrixHandler.TrackEventDuration(evt.Type)() - portal := br.GetPortalByMXID(evt.RoomID) - if portal == nil { - return - } - logEvt := portal.log.Debug(). - Stringer("sender", evt.Sender). - Stringer("replacement_room", content.ReplacementRoom). - Str("body", content.Body) - if content.ReplacementRoom == "" { - logEvt.Msg("Received tombstone event with no replacement room, cleaning up portal") - portal.cleanup(true) - portal.RemoveMXID() - return - } - logEvt.Msg("Received tombstone event, joining new room") - _, err := br.Bot.JoinRoom(content.ReplacementRoom.String(), evt.Sender.Homeserver(), nil) - if err != nil { - portal.log.Err(err).Msg("Failed to join replacement room") - return - } - _, err = br.Bot.State(content.ReplacementRoom) - if err != nil { - portal.log.Err(err).Msg("Failed to get state of replacement room") - return - } - - encrypted := br.AS.StateStore.IsEncrypted(portal.MXID) - br.portalsLock.Lock() - defer br.portalsLock.Unlock() - if portal.MXID != evt.RoomID { - portal.log.Warn(). - Stringer("old_mxid", evt.RoomID). - Stringer("new_mxid", portal.MXID). - Msg("Portal MXID changed while processing tombstone event, not updating") - return - } - _, alreadyAPortal := br.portalsByMXID[content.ReplacementRoom] - if alreadyAPortal { - portal.log.Warn(). - Stringer("replacement_room", content.ReplacementRoom). - Msg("Replacement room is already a portal, not updating") - return - } - delete(portal.bridge.portalsByMXID, portal.MXID) - portal.MXID = content.ReplacementRoom - portal.bridge.portalsByMXID[portal.MXID] = portal - portal.log = portal.bridge.ZLog.With(). - Str("channel_id", portal.Key.ChannelID). - Str("channel_receiver", portal.Key.Receiver). - Str("room_id", portal.MXID.String()). - Logger() - portal.AvatarSet = false - portal.NameSet = false - portal.TopicSet = false - portal.Encrypted = encrypted - portal.InSpace = "" - portal.FirstEventID = "" - portal.Update() - portal.log.Info().Msg("Followed tombstone and updated portal MXID") - portal.UpdateBridgeInfo() -} diff --git a/portal_convert.go b/portal_convert.go deleted file mode 100644 index 6823e2c..0000000 --- a/portal_convert.go +++ /dev/null @@ -1,779 +0,0 @@ -// mautrix-discord - A Matrix-Discord puppeting bridge. -// Copyright (C) 2023 Tulir Asokan -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package main - -import ( - "context" - "crypto/sha256" - "encoding/hex" - "fmt" - "html" - "strconv" - "strings" - "time" - - "github.com/bwmarrin/discordgo" - "github.com/rs/zerolog" - "golang.org/x/exp/slices" - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/format" - "maunium.net/go/mautrix/id" - - "go.mau.fi/mautrix-discord/database" -) - -type ConvertedMessage struct { - AttachmentID string - - Type event.Type - Content *event.MessageEventContent - Extra map[string]any -} - -func (portal *Portal) createMediaFailedMessage(bridgeErr error) *event.MessageEventContent { - return &event.MessageEventContent{ - Body: fmt.Sprintf("Failed to bridge media: %v", bridgeErr), - MsgType: event.MsgNotice, - } -} - -const DiscordStickerSize = 160 - -func (portal *Portal) convertDiscordFile(ctx context.Context, typeName string, intent *appservice.IntentAPI, id, url string, content *event.MessageEventContent) *event.MessageEventContent { - meta := AttachmentMeta{AttachmentID: id, MimeType: content.Info.MimeType} - if typeName == "sticker" && content.Info.MimeType == "application/json" { - meta.Converter = portal.bridge.convertLottie - } - dbFile, err := portal.bridge.copyAttachmentToMatrix(intent, url, portal.Encrypted, meta) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to copy attachment to Matrix") - return portal.createMediaFailedMessage(err) - } - if typeName == "sticker" && content.Info.MimeType == "application/json" { - content.Info.MimeType = dbFile.MimeType - } - content.Info.Size = dbFile.Size - if content.Info.Width == 0 && content.Info.Height == 0 { - content.Info.Width = dbFile.Width - content.Info.Height = dbFile.Height - } - if dbFile.DecryptionInfo != nil { - content.File = &event.EncryptedFileInfo{ - EncryptedFile: *dbFile.DecryptionInfo, - URL: dbFile.MXC.CUString(), - } - } else { - content.URL = dbFile.MXC.CUString() - } - return content -} - -func (portal *Portal) cleanupConvertedStickerInfo(content *event.MessageEventContent) { - if content.Info == nil { - return - } - if content.Info.Width == 0 && content.Info.Height == 0 { - content.Info.Width = DiscordStickerSize - content.Info.Height = DiscordStickerSize - } else if content.Info.Width > DiscordStickerSize || content.Info.Height > DiscordStickerSize { - if content.Info.Width > content.Info.Height { - content.Info.Height /= content.Info.Width / DiscordStickerSize - content.Info.Width = DiscordStickerSize - } else if content.Info.Width < content.Info.Height { - content.Info.Width /= content.Info.Height / DiscordStickerSize - content.Info.Height = DiscordStickerSize - } else { - content.Info.Width = DiscordStickerSize - content.Info.Height = DiscordStickerSize - } - } -} - -func (portal *Portal) convertDiscordSticker(ctx context.Context, intent *appservice.IntentAPI, sticker *discordgo.StickerItem) *ConvertedMessage { - var mime string - switch sticker.FormatType { - case discordgo.StickerFormatTypePNG: - mime = "image/png" - case discordgo.StickerFormatTypeAPNG: - mime = "image/apng" - case discordgo.StickerFormatTypeLottie: - mime = "application/json" - case discordgo.StickerFormatTypeGIF: - mime = "image/gif" - default: - zerolog.Ctx(ctx).Warn(). - Int("sticker_format", int(sticker.FormatType)). - Str("sticker_id", sticker.ID). - Msg("Unknown sticker format") - } - content := &event.MessageEventContent{ - Body: sticker.Name, // TODO find description from somewhere? - Info: &event.FileInfo{ - MimeType: mime, - }, - } - - mxc := portal.bridge.DMA.StickerMXC(sticker.ID, sticker.FormatType) - // TODO add config option to use direct media even for lottie stickers - if mxc.IsEmpty() && mime != "application/json" { - content = portal.convertDiscordFile(ctx, "sticker", intent, sticker.ID, sticker.URL(), content) - } else { - content.URL = mxc.CUString() - } - portal.cleanupConvertedStickerInfo(content) - return &ConvertedMessage{ - AttachmentID: sticker.ID, - Type: event.EventSticker, - Content: content, - } -} - -func (portal *Portal) convertDiscordAttachment(ctx context.Context, intent *appservice.IntentAPI, messageID string, att *discordgo.MessageAttachment) *ConvertedMessage { - content := &event.MessageEventContent{ - Body: att.Filename, - Info: &event.FileInfo{ - Height: att.Height, - MimeType: att.ContentType, - Width: att.Width, - - // This gets overwritten later after the file is uploaded to the homeserver - Size: att.Size, - }, - } - - var extra = make(map[string]any) - - if strings.HasPrefix(att.Filename, "SPOILER_") { - extra["page.codeberg.everypizza.msc4193.spoiler"] = true - } - - if att.Description != "" { - content.Body = att.Description - content.FileName = att.Filename - } - - switch strings.ToLower(strings.Split(att.ContentType, "/")[0]) { - case "audio": - content.MsgType = event.MsgAudio - if att.Waveform != nil { - // TODO convert waveform - extra["org.matrix.msc1767.audio"] = map[string]any{ - "duration": int(att.DurationSeconds * 1000), - } - extra["org.matrix.msc3245.voice"] = map[string]any{} - } - case "image": - content.MsgType = event.MsgImage - case "video": - content.MsgType = event.MsgVideo - default: - content.MsgType = event.MsgFile - } - mxc := portal.bridge.DMA.AttachmentMXC(portal.Key.ChannelID, messageID, att) - if mxc.IsEmpty() { - content = portal.convertDiscordFile(ctx, "attachment", intent, att.ID, att.URL, content) - } else { - content.URL = mxc.CUString() - } - return &ConvertedMessage{ - AttachmentID: att.ID, - Type: event.EventMessage, - Content: content, - Extra: extra, - } -} - -func (portal *Portal) convertDiscordVideoEmbed(ctx context.Context, intent *appservice.IntentAPI, embed *discordgo.MessageEmbed) *ConvertedMessage { - attachmentID := fmt.Sprintf("video_%s", embed.URL) - var proxyURL string - if embed.Video != nil { - proxyURL = embed.Video.ProxyURL - } else if embed.Thumbnail != nil { - proxyURL = embed.Thumbnail.ProxyURL - } else { - zerolog.Ctx(ctx).Warn().Str("embed_url", embed.URL).Msg("No video or thumbnail proxy URL found in embed") - return &ConvertedMessage{ - AttachmentID: attachmentID, - Type: event.EventMessage, - Content: &event.MessageEventContent{ - Body: "Failed to bridge media: no video or thumbnail proxy URL found in embed", - MsgType: event.MsgNotice, - }, - } - } - dbFile, err := portal.bridge.copyAttachmentToMatrix(intent, proxyURL, portal.Encrypted, NoMeta) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to copy video embed to Matrix") - return &ConvertedMessage{ - AttachmentID: attachmentID, - Type: event.EventMessage, - Content: portal.createMediaFailedMessage(err), - } - } - - content := &event.MessageEventContent{ - Body: embed.URL, - Info: &event.FileInfo{ - MimeType: dbFile.MimeType, - Size: dbFile.Size, - }, - } - if embed.Video != nil { - content.MsgType = event.MsgVideo - content.Info.Width = embed.Video.Width - content.Info.Height = embed.Video.Height - } else { - content.MsgType = event.MsgImage - content.Info.Width = embed.Thumbnail.Width - content.Info.Height = embed.Thumbnail.Height - } - if content.Info.Width == 0 && content.Info.Height == 0 { - content.Info.Width = dbFile.Width - content.Info.Height = dbFile.Height - } - if dbFile.DecryptionInfo != nil { - content.File = &event.EncryptedFileInfo{ - EncryptedFile: *dbFile.DecryptionInfo, - URL: dbFile.MXC.CUString(), - } - } else { - content.URL = dbFile.MXC.CUString() - } - extra := map[string]any{} - if content.MsgType == event.MsgVideo && embed.Type == discordgo.EmbedTypeGifv { - extra["info"] = map[string]any{ - "fi.mau.discord.gifv": true, - "fi.mau.gif": true, - "fi.mau.loop": true, - "fi.mau.autoplay": true, - "fi.mau.hide_controls": true, - "fi.mau.no_audio": true, - } - } - return &ConvertedMessage{ - AttachmentID: attachmentID, - Type: event.EventMessage, - Content: content, - Extra: extra, - } -} - -func (portal *Portal) convertDiscordMessage(ctx context.Context, puppet *Puppet, intent *appservice.IntentAPI, msg *discordgo.Message) []*ConvertedMessage { - predictedLength := len(msg.Attachments) + len(msg.StickerItems) - if msg.Content != "" { - predictedLength++ - } - parts := make([]*ConvertedMessage, 0, predictedLength) - if textPart := portal.convertDiscordTextMessage(ctx, intent, msg); textPart != nil { - parts = append(parts, textPart) - } - log := zerolog.Ctx(ctx) - handledIDs := make(map[string]struct{}) - for _, att := range msg.Attachments { - if _, handled := handledIDs[att.ID]; handled { - continue - } - handledIDs[att.ID] = struct{}{} - log := log.With().Str("attachment_id", att.ID).Logger() - if part := portal.convertDiscordAttachment(log.WithContext(ctx), intent, msg.ID, att); part != nil { - parts = append(parts, part) - } - } - for _, sticker := range msg.StickerItems { - if _, handled := handledIDs[sticker.ID]; handled { - continue - } - handledIDs[sticker.ID] = struct{}{} - log := log.With().Str("sticker_id", sticker.ID).Logger() - if part := portal.convertDiscordSticker(log.WithContext(ctx), intent, sticker); part != nil { - parts = append(parts, part) - } - } - for i, embed := range msg.Embeds { - // Ignore non-video embeds, they're handled in convertDiscordTextMessage - if getEmbedType(msg, embed) != EmbedVideo { - continue - } - // Discord deduplicates embeds by URL. It makes things easier for us too. - if _, handled := handledIDs[embed.URL]; handled { - continue - } - handledIDs[embed.URL] = struct{}{} - log := log.With(). - Str("computed_embed_type", "video"). - Str("embed_type", string(embed.Type)). - Int("embed_index", i). - Logger() - part := portal.convertDiscordVideoEmbed(log.WithContext(ctx), intent, embed) - if part != nil { - parts = append(parts, part) - } - } - if len(parts) == 0 && msg.Thread != nil { - parts = append(parts, &ConvertedMessage{Type: event.EventMessage, Content: &event.MessageEventContent{ - MsgType: event.MsgText, - Body: fmt.Sprintf("Created a thread: %s", msg.Thread.Name), - }}) - } - for _, part := range parts { - puppet.addWebhookMeta(part, msg) - puppet.addMemberMeta(part, msg) - } - return parts -} - -func (puppet *Puppet) addMemberMeta(part *ConvertedMessage, msg *discordgo.Message) { - if msg.Member == nil { - return - } - if part.Extra == nil { - part.Extra = make(map[string]any) - } - var avatarURL id.ContentURI - var discordAvatarURL string - if msg.Member.Avatar != "" { - var err error - avatarURL, discordAvatarURL, err = puppet.bridge.reuploadUserAvatar(puppet.DefaultIntent(), msg.GuildID, msg.Author.ID, msg.Member.Avatar) - if err != nil { - puppet.log.Warn().Err(err). - Str("avatar_id", msg.Member.Avatar). - Msg("Failed to reupload guild user avatar") - } - } - part.Extra["fi.mau.discord.guild_member_metadata"] = map[string]any{ - "nick": msg.Member.Nick, - "avatar_id": msg.Member.Avatar, - "avatar_url": discordAvatarURL, - "avatar_mxc": avatarURL.String(), - } - if msg.Member.Nick != "" || !avatarURL.IsEmpty() { - perMessageProfile := map[string]any{ - "id": fmt.Sprintf("%s_%s", msg.GuildID, msg.Author.ID), - "displayname": msg.Member.Nick, - "avatar_url": avatarURL.String(), - } - if msg.Member.Nick == "" { - perMessageProfile["displayname"] = puppet.Name - } - if avatarURL.IsEmpty() { - perMessageProfile["avatar_url"] = puppet.AvatarURL.String() - } - part.Extra["com.beeper.per_message_profile"] = perMessageProfile - } -} - -func (puppet *Puppet) addWebhookMeta(part *ConvertedMessage, msg *discordgo.Message) { - if msg.WebhookID == "" { - return - } - if part.Extra == nil { - part.Extra = make(map[string]any) - } - var avatarURL id.ContentURI - if msg.Author.Avatar != "" { - var err error - avatarURL, _, err = puppet.bridge.reuploadUserAvatar(puppet.DefaultIntent(), "", msg.Author.ID, msg.Author.Avatar) - if err != nil { - puppet.log.Warn().Err(err). - Str("avatar_id", msg.Author.Avatar). - Msg("Failed to reupload webhook avatar") - } - } - part.Extra["fi.mau.discord.webhook_metadata"] = map[string]any{ - "id": msg.WebhookID, - "name": msg.Author.Username, - "avatar_id": msg.Author.Avatar, - "avatar_url": msg.Author.AvatarURL(""), - "avatar_mxc": avatarURL.String(), - } - profileID := sha256.Sum256(fmt.Appendf(nil, "%s:%s", msg.Author.Username, msg.Author.Avatar)) - hasFallback := false - if msg.ApplicationID == "" && - puppet.bridge.Config.Bridge.PrefixWebhookMessages && - (part.Content.MsgType == event.MsgText || part.Content.MsgType == event.MsgNotice || (part.Content.FileName != "" && part.Content.FileName != part.Content.Body)) { - part.Content.EnsureHasHTML() - part.Content.Body = fmt.Sprintf("%s: %s", msg.Author.Username, part.Content.Body) - part.Content.FormattedBody = fmt.Sprintf("%s: %s", html.EscapeString(msg.Author.Username), part.Content.FormattedBody) - hasFallback = true - } - part.Extra["com.beeper.per_message_profile"] = map[string]any{ - "id": hex.EncodeToString(profileID[:]), - "avatar_url": avatarURL.String(), - "displayname": msg.Author.Username, - "has_fallback": hasFallback, - } -} - -const ( - embedHTMLWrapper = `
    %s
    ` - embedHTMLWrapperColor = `
    %s
    ` - embedHTMLAuthorWithImage = `

     %s

    ` - embedHTMLAuthorPlain = `

    %s

    ` - embedHTMLAuthorLink = `%s` - embedHTMLTitleWithLink = `

    %s

    ` - embedHTMLTitlePlain = `

    %s

    ` - embedHTMLDescription = `

    %s

    ` - embedHTMLFieldName = `%s` - embedHTMLFieldValue = `%s` - embedHTMLFields = `%s%s
    ` - embedHTMLLinearField = `

    %s
    %s

    ` - embedHTMLImage = `

    ` - embedHTMLFooterWithImage = `` - embedHTMLFooterPlain = `` - embedHTMLFooterOnlyDate = `` - embedHTMLDate = `` - embedFooterDateSeparator = ` • ` -) - -func (portal *Portal) convertDiscordRichEmbed(ctx context.Context, intent *appservice.IntentAPI, embed *discordgo.MessageEmbed, msgID string, index int) string { - log := zerolog.Ctx(ctx) - var htmlParts []string - if embed.Author != nil { - var authorHTML string - authorNameHTML := html.EscapeString(embed.Author.Name) - if embed.Author.URL != "" { - authorNameHTML = fmt.Sprintf(embedHTMLAuthorLink, embed.Author.URL, authorNameHTML) - } - authorHTML = fmt.Sprintf(embedHTMLAuthorPlain, authorNameHTML) - if embed.Author.ProxyIconURL != "" { - dbFile, err := portal.bridge.copyAttachmentToMatrix(intent, embed.Author.ProxyIconURL, false, NoMeta) - if err != nil { - log.Warn().Err(err).Msg("Failed to reupload author icon in embed") - } else { - authorHTML = fmt.Sprintf(embedHTMLAuthorWithImage, dbFile.MXC, authorNameHTML) - } - } - htmlParts = append(htmlParts, authorHTML) - } - if embed.Title != "" { - var titleHTML string - baseTitleHTML := portal.renderDiscordMarkdownOnlyHTML(embed.Title, false) - if embed.URL != "" { - titleHTML = fmt.Sprintf(embedHTMLTitleWithLink, html.EscapeString(embed.URL), baseTitleHTML) - } else { - titleHTML = fmt.Sprintf(embedHTMLTitlePlain, baseTitleHTML) - } - htmlParts = append(htmlParts, titleHTML) - } - if embed.Description != "" { - htmlParts = append(htmlParts, fmt.Sprintf(embedHTMLDescription, portal.renderDiscordMarkdownOnlyHTML(embed.Description, true))) - } - for i := 0; i < len(embed.Fields); i++ { - item := embed.Fields[i] - if portal.bridge.Config.Bridge.EmbedFieldsAsTables { - splitItems := []*discordgo.MessageEmbedField{item} - if item.Inline && len(embed.Fields) > i+1 && embed.Fields[i+1].Inline { - splitItems = append(splitItems, embed.Fields[i+1]) - i++ - if len(embed.Fields) > i+1 && embed.Fields[i+1].Inline { - splitItems = append(splitItems, embed.Fields[i+1]) - i++ - } - } - headerParts := make([]string, len(splitItems)) - contentParts := make([]string, len(splitItems)) - for j, splitItem := range splitItems { - headerParts[j] = fmt.Sprintf(embedHTMLFieldName, portal.renderDiscordMarkdownOnlyHTML(splitItem.Name, false)) - contentParts[j] = fmt.Sprintf(embedHTMLFieldValue, portal.renderDiscordMarkdownOnlyHTML(splitItem.Value, true)) - } - htmlParts = append(htmlParts, fmt.Sprintf(embedHTMLFields, strings.Join(headerParts, ""), strings.Join(contentParts, ""))) - } else { - htmlParts = append(htmlParts, fmt.Sprintf(embedHTMLLinearField, - strconv.FormatBool(item.Inline), - portal.renderDiscordMarkdownOnlyHTML(item.Name, false), - portal.renderDiscordMarkdownOnlyHTML(item.Value, true), - )) - } - } - if embed.Image != nil { - dbFile, err := portal.bridge.copyAttachmentToMatrix(intent, embed.Image.ProxyURL, false, NoMeta) - if err != nil { - log.Warn().Err(err).Msg("Failed to reupload image in embed") - } else { - htmlParts = append(htmlParts, fmt.Sprintf(embedHTMLImage, dbFile.MXC)) - } - } - var embedDateHTML string - if embed.Timestamp != "" { - formattedTime := embed.Timestamp - parsedTS, err := time.Parse(time.RFC3339, embed.Timestamp) - if err != nil { - log.Warn().Err(err).Msg("Failed to parse timestamp in embed") - } else { - formattedTime = parsedTS.Format(discordTimestampStyle('F').Format()) - } - embedDateHTML = fmt.Sprintf(embedHTMLDate, embed.Timestamp, formattedTime) - } - if embed.Footer != nil { - var footerHTML string - var datePart string - if embedDateHTML != "" { - datePart = embedFooterDateSeparator + embedDateHTML - } - footerHTML = fmt.Sprintf(embedHTMLFooterPlain, html.EscapeString(embed.Footer.Text), datePart) - if embed.Footer.ProxyIconURL != "" { - dbFile, err := portal.bridge.copyAttachmentToMatrix(intent, embed.Footer.ProxyIconURL, false, NoMeta) - if err != nil { - log.Warn().Err(err).Msg("Failed to reupload footer icon in embed") - } else { - footerHTML = fmt.Sprintf(embedHTMLFooterWithImage, dbFile.MXC, html.EscapeString(embed.Footer.Text), datePart) - } - } - htmlParts = append(htmlParts, footerHTML) - } else if embed.Timestamp != "" { - htmlParts = append(htmlParts, fmt.Sprintf(embedHTMLFooterOnlyDate, embedDateHTML)) - } - - if len(htmlParts) == 0 { - return "" - } - - compiledHTML := strings.Join(htmlParts, "") - if embed.Color != 0 { - compiledHTML = fmt.Sprintf(embedHTMLWrapperColor, embed.Color, compiledHTML) - } else { - compiledHTML = fmt.Sprintf(embedHTMLWrapper, compiledHTML) - } - return compiledHTML -} - -type BeeperLinkPreview struct { - mautrix.RespPreviewURL - MatchedURL string `json:"matched_url"` - ImageEncryption *event.EncryptedFileInfo `json:"beeper:image:encryption,omitempty"` -} - -func (portal *Portal) convertDiscordLinkEmbedImage(ctx context.Context, intent *appservice.IntentAPI, url string, width, height int, preview *BeeperLinkPreview) { - dbFile, err := portal.bridge.copyAttachmentToMatrix(intent, url, portal.Encrypted, NoMeta) - if err != nil { - zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to reupload image in URL preview") - return - } - if width != 0 || height != 0 { - preview.ImageWidth = width - preview.ImageHeight = height - } else { - preview.ImageWidth = dbFile.Width - preview.ImageHeight = dbFile.Height - } - preview.ImageSize = dbFile.Size - preview.ImageType = dbFile.MimeType - if dbFile.Encrypted { - preview.ImageEncryption = &event.EncryptedFileInfo{ - EncryptedFile: *dbFile.DecryptionInfo, - URL: dbFile.MXC.CUString(), - } - } else { - preview.ImageURL = dbFile.MXC.CUString() - } -} - -func (portal *Portal) convertDiscordLinkEmbedToBeeper(ctx context.Context, intent *appservice.IntentAPI, embed *discordgo.MessageEmbed) *BeeperLinkPreview { - var preview BeeperLinkPreview - preview.MatchedURL = embed.URL - preview.Title = embed.Title - preview.Description = embed.Description - if embed.Image != nil { - portal.convertDiscordLinkEmbedImage(ctx, intent, embed.Image.ProxyURL, embed.Image.Width, embed.Image.Height, &preview) - } else if embed.Thumbnail != nil { - portal.convertDiscordLinkEmbedImage(ctx, intent, embed.Thumbnail.ProxyURL, embed.Thumbnail.Width, embed.Thumbnail.Height, &preview) - } - return &preview -} - -const msgInteractionTemplateHTML = `
    -%s used /%s -
    ` - -const msgComponentTemplateHTML = `

    This message contains interactive elements. Use the Discord app to interact with the message.

    ` - -type BridgeEmbedType int - -const ( - EmbedUnknown BridgeEmbedType = iota - EmbedRich - EmbedLinkPreview - EmbedVideo -) - -func isActuallyLinkPreview(embed *discordgo.MessageEmbed) bool { - // Sending YouTube links creates a video embed, but we want to bridge it as a URL preview, - // so this is a hacky way to detect those. - return embed.Video != nil && embed.Video.ProxyURL == "" -} - -func getEmbedType(msg *discordgo.Message, embed *discordgo.MessageEmbed) BridgeEmbedType { - switch embed.Type { - case discordgo.EmbedTypeLink, discordgo.EmbedTypeArticle: - return EmbedLinkPreview - case discordgo.EmbedTypeVideo: - if isActuallyLinkPreview(embed) { - return EmbedLinkPreview - } - return EmbedVideo - case discordgo.EmbedTypeGifv: - return EmbedVideo - case discordgo.EmbedTypeImage: - if msg != nil && isPlainGifMessage(msg) { - return EmbedVideo - } else if embed.Image == nil && embed.Thumbnail != nil { - return EmbedLinkPreview - } - return EmbedRich - case discordgo.EmbedTypeRich: - return EmbedRich - default: - return EmbedUnknown - } -} - -func isPlainGifMessage(msg *discordgo.Message) bool { - if len(msg.Embeds) != 1 { - return false - } - embed := msg.Embeds[0] - isGifVideo := embed.Type == discordgo.EmbedTypeGifv && embed.Video != nil - isGifImage := embed.Type == discordgo.EmbedTypeImage && embed.Image == nil && embed.Thumbnail != nil && embed.Title == "" - contentIsOnlyURL := msg.Content == embed.URL || discordLinkRegexFull.MatchString(msg.Content) - return contentIsOnlyURL && (isGifVideo || isGifImage) -} - -func (portal *Portal) convertDiscordMentions(msg *discordgo.Message, syncGhosts bool) *event.Mentions { - var matrixMentions event.Mentions - for _, mention := range msg.Mentions { - puppet := portal.bridge.GetPuppetByID(mention.ID) - if syncGhosts { - puppet.UpdateInfo(nil, mention, nil) - } - user := portal.bridge.GetUserByID(mention.ID) - if user != nil { - matrixMentions.UserIDs = append(matrixMentions.UserIDs, user.MXID) - } else { - matrixMentions.UserIDs = append(matrixMentions.UserIDs, puppet.MXID) - } - } - slices.Sort(matrixMentions.UserIDs) - matrixMentions.UserIDs = slices.Compact(matrixMentions.UserIDs) - if msg.MentionEveryone { - matrixMentions.Room = true - } - return &matrixMentions -} - -const forwardTemplateHTML = `
    -

    ↷ Forwarded

    -%s -

    %s

    -
    ` - -func (portal *Portal) convertDiscordTextMessage(ctx context.Context, intent *appservice.IntentAPI, msg *discordgo.Message) *ConvertedMessage { - log := zerolog.Ctx(ctx) - if msg.Type == discordgo.MessageTypeCall { - return &ConvertedMessage{Type: event.EventMessage, Content: &event.MessageEventContent{ - MsgType: event.MsgEmote, - Body: "started a call", - }} - } else if msg.Type == discordgo.MessageTypeGuildMemberJoin { - return &ConvertedMessage{Type: event.EventMessage, Content: &event.MessageEventContent{ - MsgType: event.MsgEmote, - Body: "joined the server", - }} - } - var htmlParts []string - if msg.Interaction != nil { - puppet := portal.bridge.GetPuppetByID(msg.Interaction.User.ID) - puppet.UpdateInfo(nil, msg.Interaction.User, nil) - htmlParts = append(htmlParts, fmt.Sprintf(msgInteractionTemplateHTML, puppet.MXID, puppet.Name, msg.Interaction.Name)) - } - if msg.Content != "" && !isPlainGifMessage(msg) { - htmlParts = append(htmlParts, portal.renderDiscordMarkdownOnlyHTML(msg.Content, true)) - } else if msg.MessageReference != nil && - msg.MessageReference.Type == discordgo.MessageReferenceTypeForward && - len(msg.MessageSnapshots) > 0 && - msg.MessageSnapshots[0].Message != nil { - forwardedHTML := portal.renderDiscordMarkdownOnlyHTMLNoUnwrap(msg.MessageSnapshots[0].Message.Content, true) - msgTSText := msg.MessageSnapshots[0].Message.Timestamp.Format("2006-01-02 15:04 MST") - origLink := fmt.Sprintf("unknown channel • %s", msgTSText) - forwardedFromPortal := portal.bridge.GetExistingPortalByID(database.NewPortalKey(msg.MessageReference.ChannelID, "")) - if forwardedFromPortal != nil { - origMessage := portal.bridge.DB.Message.GetFirstByDiscordID(forwardedFromPortal.Key, msg.MessageReference.MessageID) - if origMessage != nil { - origLink = fmt.Sprintf( - `#%s • %s`, - forwardedFromPortal.MXID.EventURI(origMessage.MXID, portal.bridge.AS.HomeserverDomain), - forwardedFromPortal.PlainName, - msgTSText, - ) - } else if forwardedFromPortal.MXID != "" { - origLink = fmt.Sprintf( - `#%s • %s`, - forwardedFromPortal.MXID.URI(portal.bridge.AS.HomeserverDomain), - forwardedFromPortal.PlainName, - msgTSText, - ) - } else if forwardedFromPortal.PlainName != "" { - origLink = fmt.Sprintf("%s • %s", forwardedFromPortal.PlainName, msgTSText) - } - } - - htmlParts = append(htmlParts, fmt.Sprintf(forwardTemplateHTML, forwardedHTML, origLink)) - } - previews := make([]*BeeperLinkPreview, 0) - for i, embed := range msg.Embeds { - if i == 0 && msg.MessageReference == nil && isReplyEmbed(embed) { - continue - } - with := log.With(). - Str("embed_type", string(embed.Type)). - Int("embed_index", i) - switch getEmbedType(msg, embed) { - case EmbedRich: - log := with.Str("computed_embed_type", "rich").Logger() - htmlParts = append(htmlParts, portal.convertDiscordRichEmbed(log.WithContext(ctx), intent, embed, msg.ID, i)) - case EmbedLinkPreview: - log := with.Str("computed_embed_type", "link preview").Logger() - previews = append(previews, portal.convertDiscordLinkEmbedToBeeper(log.WithContext(ctx), intent, embed)) - case EmbedVideo: - // Ignore video embeds, they're handled as separate messages - default: - log := with.Logger() - log.Warn().Msg("Unknown embed type in message") - } - } - - if len(msg.Components) > 0 { - htmlParts = append(htmlParts, msgComponentTemplateHTML) - } - - if len(htmlParts) == 0 { - return nil - } - - fullHTML := strings.Join(htmlParts, "\n") - if !msg.MentionEveryone { - fullHTML = strings.ReplaceAll(fullHTML, "@room", "@\u2063ro\u2063om") - } - - content := format.HTMLToContent(fullHTML) - extraContent := map[string]any{ - "com.beeper.linkpreviews": previews, - } - - return &ConvertedMessage{Type: event.EventMessage, Content: &content, Extra: extraContent} -} diff --git a/provisioning.go b/provisioning.go deleted file mode 100644 index c9ff3ab..0000000 --- a/provisioning.go +++ /dev/null @@ -1,552 +0,0 @@ -package main - -import ( - "bufio" - "context" - "encoding/json" - "errors" - "net" - "net/http" - _ "net/http/pprof" - "strings" - "time" - - "github.com/gorilla/mux" - "github.com/gorilla/websocket" - log "maunium.net/go/maulogger/v2" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/bridge/bridgeconfig" - "maunium.net/go/mautrix/id" - - "go.mau.fi/mautrix-discord/database" - "go.mau.fi/mautrix-discord/remoteauth" -) - -const ( - SecWebSocketProtocol = "com.gitlab.beeper.discord" -) - -const ( - ErrCodeNotConnected = "FI.MAU.DISCORD.NOT_CONNECTED" - ErrCodeAlreadyLoggedIn = "FI.MAU.DISCORD.ALREADY_LOGGED_IN" - ErrCodeAlreadyConnected = "FI.MAU.DISCORD.ALREADY_CONNECTED" - ErrCodeConnectFailed = "FI.MAU.DISCORD.CONNECT_FAILED" - ErrCodeDisconnectFailed = "FI.MAU.DISCORD.DISCONNECT_FAILED" - ErrCodeGuildBridgeFailed = "M_UNKNOWN" - ErrCodeGuildUnbridgeFailed = "M_UNKNOWN" - ErrCodeGuildNotBridged = "FI.MAU.DISCORD.GUILD_NOT_BRIDGED" - ErrCodeLoginPrepareFailed = "FI.MAU.DISCORD.LOGIN_PREPARE_FAILED" - ErrCodeLoginConnectionFailed = "FI.MAU.DISCORD.LOGIN_CONN_FAILED" - ErrCodeLoginFailed = "FI.MAU.DISCORD.LOGIN_FAILED" - ErrCodePostLoginConnFailed = "FI.MAU.DISCORD.POST_LOGIN_CONNECTION_FAILED" -) - -type ProvisioningAPI struct { - bridge *DiscordBridge - log log.Logger -} - -func newProvisioningAPI(br *DiscordBridge) *ProvisioningAPI { - p := &ProvisioningAPI{ - bridge: br, - log: br.Log.Sub("Provisioning"), - } - - prefix := br.Config.Bridge.Provisioning.Prefix - - p.log.Debugln("Enabling provisioning API at", prefix) - - r := br.AS.Router.PathPrefix(prefix).Subrouter() - - r.Use(p.authMiddleware) - - r.HandleFunc("/v1/disconnect", p.disconnect).Methods(http.MethodPost) - r.HandleFunc("/v1/ping", p.ping).Methods(http.MethodGet) - r.HandleFunc("/v1/login/qr", p.qrLogin).Methods(http.MethodGet) - r.HandleFunc("/v1/login/token", p.tokenLogin).Methods(http.MethodPost) - r.HandleFunc("/v1/logout", p.logout).Methods(http.MethodPost) - r.HandleFunc("/v1/reconnect", p.reconnect).Methods(http.MethodPost) - - r.HandleFunc("/v1/guilds", p.guildsList).Methods(http.MethodGet) - r.HandleFunc("/v1/guilds/{guildID}", p.guildsBridge).Methods(http.MethodPost) - r.HandleFunc("/v1/guilds/{guildID}", p.guildsUnbridge).Methods(http.MethodDelete) - - if p.bridge.Config.Bridge.Provisioning.DebugEndpoints { - p.log.Debugln("Enabling debug API at /debug") - r := p.bridge.AS.Router.PathPrefix("/debug").Subrouter() - r.Use(p.authMiddleware) - r.PathPrefix("/pprof").Handler(http.DefaultServeMux) - } - - return p -} - -func jsonResponse(w http.ResponseWriter, status int, response interface{}) { - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(status) - _ = json.NewEncoder(w).Encode(response) -} - -// Response structs -type Response struct { - Success bool `json:"success"` - Status string `json:"status"` -} - -type Error struct { - Success bool `json:"success"` - Error string `json:"error"` - ErrCode string `json:"errcode"` -} - -// Wrapped http.ResponseWriter to capture the status code -type responseWrap struct { - http.ResponseWriter - statusCode int -} - -var _ http.Hijacker = (*responseWrap)(nil) - -func (rw *responseWrap) WriteHeader(statusCode int) { - rw.ResponseWriter.WriteHeader(statusCode) - rw.statusCode = statusCode -} - -func (rw *responseWrap) Hijack() (net.Conn, *bufio.ReadWriter, error) { - hijacker, ok := rw.ResponseWriter.(http.Hijacker) - if !ok { - return nil, nil, errors.New("response does not implement http.Hijacker") - } - return hijacker.Hijack() -} - -// Middleware -func (p *ProvisioningAPI) authMiddleware(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - auth := r.Header.Get("Authorization") - - // Special case the login endpoint to use the discord qrcode auth - if auth == "" && strings.HasSuffix(r.URL.Path, "/login") { - authParts := strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ",") - for _, part := range authParts { - part = strings.TrimSpace(part) - if strings.HasPrefix(part, SecWebSocketProtocol+"-") { - auth = part[len(SecWebSocketProtocol+"-"):] - - break - } - } - } else if strings.HasPrefix(auth, "Bearer ") { - auth = auth[len("Bearer "):] - } - - if auth != p.bridge.Config.Bridge.Provisioning.SharedSecret { - jsonResponse(w, http.StatusUnauthorized, map[string]interface{}{ - "error": "Invalid auth token", - "errcode": mautrix.MUnknownToken.ErrCode, - }) - - return - } - - userID := r.URL.Query().Get("user_id") - user := p.bridge.GetUserByMXID(id.UserID(userID)) - - start := time.Now() - wWrap := &responseWrap{w, 200} - h.ServeHTTP(wWrap, r.WithContext(context.WithValue(r.Context(), "user", user))) - duration := time.Now().Sub(start).Seconds() - - p.log.Infofln("%s %s from %s took %.2f seconds and returned status %d", r.Method, r.URL.Path, user.MXID, duration, wWrap.statusCode) - }) -} - -// websocket upgrader -var upgrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, - Subprotocols: []string{SecWebSocketProtocol}, -} - -// Handlers -func (p *ProvisioningAPI) disconnect(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user").(*User) - - if !user.Connected() { - jsonResponse(w, http.StatusConflict, Error{ - Error: "You're not connected to discord", - ErrCode: ErrCodeNotConnected, - }) - return - } - - if err := user.Disconnect(); err != nil { - p.log.Errorfln("Failed to disconnect %s: %v", user.MXID, err) - jsonResponse(w, http.StatusInternalServerError, Error{ - Error: "Failed to disconnect from discord", - ErrCode: ErrCodeDisconnectFailed, - }) - } else { - jsonResponse(w, http.StatusOK, Response{ - Success: true, - Status: "Disconnected from Discord", - }) - } -} - -type respPing struct { - Discord struct { - ID string `json:"id,omitempty"` - LoggedIn bool `json:"logged_in"` - Connected bool `json:"connected"` - Conn struct { - LastHeartbeatAck int64 `json:"last_heartbeat_ack,omitempty"` - LastHeartbeatSent int64 `json:"last_heartbeat_sent,omitempty"` - } `json:"conn"` - } - MXID id.UserID `json:"mxid"` - ManagementRoom id.RoomID `json:"management_room"` -} - -func (p *ProvisioningAPI) ping(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user").(*User) - - resp := respPing{ - MXID: user.MXID, - ManagementRoom: user.ManagementRoom, - } - resp.Discord.LoggedIn = user.IsLoggedIn() - resp.Discord.Connected = user.Connected() - resp.Discord.ID = user.DiscordID - if user.Session != nil { - resp.Discord.Conn.LastHeartbeatAck = user.Session.LastHeartbeatAck.UnixMilli() - resp.Discord.Conn.LastHeartbeatSent = user.Session.LastHeartbeatSent.UnixMilli() - } - jsonResponse(w, http.StatusOK, resp) -} - -func (p *ProvisioningAPI) logout(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user").(*User) - var msg string - if user.DiscordID != "" { - msg = "Logged out successfully." - } else { - msg = "User wasn't logged in." - } - user.Logout(false) - jsonResponse(w, http.StatusOK, Response{true, msg}) -} - -func (p *ProvisioningAPI) qrLogin(w http.ResponseWriter, r *http.Request) { - userID := r.URL.Query().Get("user_id") - user := p.bridge.GetUserByMXID(id.UserID(userID)) - - c, err := upgrader.Upgrade(w, r, nil) - if err != nil { - p.log.Errorln("Failed to upgrade connection to websocket:", err) - return - } - - log := p.log.Sub("QRLogin").Sub(user.MXID.String()) - - defer func() { - err := c.Close() - if err != nil { - log.Debugln("Error closing websocket:", err) - } - }() - - go func() { - // Read everything so SetCloseHandler() works - for { - _, _, err := c.ReadMessage() - if err != nil { - break - } - } - }() - - ctx, cancel := context.WithCancel(context.Background()) - c.SetCloseHandler(func(code int, text string) error { - log.Debugfln("Login websocket closed (%d), cancelling login", code) - cancel() - return nil - }) - - if user.IsLoggedIn() { - _ = c.WriteJSON(Error{ - Error: "You're already logged into Discord", - ErrCode: ErrCodeAlreadyLoggedIn, - }) - return - } - - client, err := remoteauth.New() - if err != nil { - log.Errorln("Failed to prepare login:", err) - _ = c.WriteJSON(Error{ - Error: "Failed to prepare login", - ErrCode: ErrCodeLoginPrepareFailed, - }) - return - } - - qrChan := make(chan string) - doneChan := make(chan struct{}) - - log.Debugln("Started login via provisioning API") - - err = client.Dial(ctx, qrChan, doneChan) - if err != nil { - log.Errorln("Failed to connect to Discord login websocket:", err) - close(qrChan) - close(doneChan) - _ = c.WriteJSON(Error{ - Error: "Failed to connect to Discord login websocket", - ErrCode: ErrCodeLoginConnectionFailed, - }) - return - } - - for { - select { - case qrCode, ok := <-qrChan: - if !ok { - continue - } - err = c.WriteJSON(map[string]interface{}{ - "code": qrCode, - "timeout": 120, // TODO: move this to the library or something - }) - if err != nil { - log.Errorln("Failed to write QR code to websocket:", err) - } - case <-doneChan: - var discordUser remoteauth.User - discordUser, err = client.Result() - if err != nil { - log.Errorln("Discord login websocket returned error:", err) - _ = c.WriteJSON(Error{ - Error: "Failed to log in", - ErrCode: ErrCodeLoginFailed, - }) - return - } - - log.Infofln("Logged in as %s#%s (%s)", discordUser.Username, discordUser.Discriminator, discordUser.UserID) - - if err = user.Login(discordUser.Token); err != nil { - log.Errorln("Failed to connect after logging in:", err) - _ = c.WriteJSON(Error{ - Error: "Failed to connect to Discord after logging in", - ErrCode: ErrCodePostLoginConnFailed, - }) - return - } - - err = c.WriteJSON(respLogin{ - Success: true, - ID: user.DiscordID, - Username: discordUser.Username, - Discriminator: discordUser.Discriminator, - }) - if err != nil { - log.Errorln("Failed to write login success to websocket:", err) - } - return - case <-ctx.Done(): - return - } - } -} - -type respLogin struct { - Success bool `json:"success"` - ID string `json:"id"` - Username string `json:"username"` - Discriminator string `json:"discriminator"` -} - -type reqTokenLogin struct { - Token string `json:"token"` -} - -func (p *ProvisioningAPI) tokenLogin(w http.ResponseWriter, r *http.Request) { - userID := r.URL.Query().Get("user_id") - user := p.bridge.GetUserByMXID(id.UserID(userID)) - log := p.log.Sub("TokenLogin").Sub(user.MXID.String()) - if user.IsLoggedIn() { - jsonResponse(w, http.StatusConflict, Error{ - Error: "You're already logged into Discord", - ErrCode: ErrCodeAlreadyLoggedIn, - }) - return - } - var body reqTokenLogin - if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - log.Errorln("Failed to parse login request:", err) - jsonResponse(w, http.StatusBadRequest, Error{ - Error: "Failed to parse request body", - ErrCode: mautrix.MBadJSON.ErrCode, - }) - return - } - if err := user.Login(body.Token); err != nil { - log.Errorln("Failed to connect with provided token:", err) - jsonResponse(w, http.StatusUnauthorized, Error{ - Error: "Failed to connect to Discord", - ErrCode: ErrCodePostLoginConnFailed, - }) - return - } - log.Infoln("Successfully logged in") - jsonResponse(w, http.StatusOK, respLogin{ - Success: true, - ID: user.DiscordID, - Username: user.Session.State.User.Username, - Discriminator: user.Session.State.User.Discriminator, - }) -} - -func (p *ProvisioningAPI) reconnect(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user").(*User) - - if user.Connected() { - jsonResponse(w, http.StatusConflict, Error{ - Error: "You're already connected to discord", - ErrCode: ErrCodeAlreadyConnected, - }) - - return - } - - if err := user.Connect(); err != nil { - jsonResponse(w, http.StatusInternalServerError, Error{ - Error: "Failed to connect to discord", - ErrCode: ErrCodeConnectFailed, - }) - } else { - jsonResponse(w, http.StatusOK, Response{ - Success: true, - Status: "Connected to Discord", - }) - } -} - -type guildEntry struct { - ID string `json:"id"` - Name string `json:"name"` - AvatarURL id.ContentURI `json:"avatar_url"` - MXID id.RoomID `json:"mxid"` - AutoBridge bool `json:"auto_bridge_channels"` - BridgingMode string `json:"bridging_mode"` -} - -type respGuildsList struct { - Guilds []guildEntry `json:"guilds"` -} - -func (p *ProvisioningAPI) guildsList(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user").(*User) - - var resp respGuildsList - resp.Guilds = []guildEntry{} - for _, userGuild := range user.GetPortals() { - guild := p.bridge.GetGuildByID(userGuild.DiscordID, false) - if guild == nil { - continue - } - resp.Guilds = append(resp.Guilds, guildEntry{ - ID: guild.ID, - Name: guild.PlainName, - AvatarURL: guild.AvatarURL, - MXID: guild.MXID, - AutoBridge: guild.BridgingMode == database.GuildBridgeEverything, - BridgingMode: guild.BridgingMode.String(), - }) - } - - jsonResponse(w, http.StatusOK, resp) -} - -type reqBridgeGuild struct { - AutoCreateChannels bool `json:"auto_create_channels"` -} - -type respBridgeGuild struct { - Success bool `json:"success"` - MXID id.RoomID `json:"mxid"` -} - -func (p *ProvisioningAPI) guildsBridge(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user").(*User) - guildID := mux.Vars(r)["guildID"] - - var body reqBridgeGuild - if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - p.log.Errorln("Failed to parse bridge request:", err) - jsonResponse(w, http.StatusBadRequest, Error{ - Error: "Failed to parse request body", - ErrCode: mautrix.MBadJSON.ErrCode, - }) - return - } - - guild := user.bridge.GetGuildByID(guildID, false) - if guild == nil { - jsonResponse(w, http.StatusNotFound, Error{ - Error: "Guild not found", - ErrCode: mautrix.MNotFound.ErrCode, - }) - return - } - alreadyExists := guild.MXID == "" - if err := user.bridgeGuild(guildID, body.AutoCreateChannels); err != nil { - p.log.Errorfln("Error bridging %s: %v", guildID, err) - jsonResponse(w, http.StatusInternalServerError, Error{ - Error: "Internal error while trying to bridge guild", - ErrCode: ErrCodeGuildBridgeFailed, - }) - } else if alreadyExists { - jsonResponse(w, http.StatusOK, respBridgeGuild{ - Success: true, - MXID: guild.MXID, - }) - } else { - jsonResponse(w, http.StatusCreated, respBridgeGuild{ - Success: true, - MXID: guild.MXID, - }) - } -} - -func (p *ProvisioningAPI) guildsUnbridge(w http.ResponseWriter, r *http.Request) { - guildID := mux.Vars(r)["guildID"] - user := r.Context().Value("user").(*User) - if user.PermissionLevel < bridgeconfig.PermissionLevelAdmin { - jsonResponse(w, http.StatusForbidden, Error{ - Error: "Only bridge admins can unbridge guilds", - ErrCode: mautrix.MForbidden.ErrCode, - }) - } else if guild := user.bridge.GetGuildByID(guildID, false); guild == nil { - jsonResponse(w, http.StatusNotFound, Error{ - Error: "Guild not found", - ErrCode: mautrix.MNotFound.ErrCode, - }) - } else if guild.BridgingMode == database.GuildBridgeNothing && guild.MXID == "" { - jsonResponse(w, http.StatusNotFound, Error{ - Error: "That guild is not bridged", - ErrCode: ErrCodeGuildNotBridged, - }) - } else if err := user.unbridgeGuild(guildID); err != nil { - p.log.Errorfln("Error unbridging %s: %v", guildID, err) - jsonResponse(w, http.StatusInternalServerError, Error{ - Error: "Internal error while trying to unbridge guild", - ErrCode: ErrCodeGuildUnbridgeFailed, - }) - } else { - w.WriteHeader(http.StatusNoContent) - } -} diff --git a/puppet.go b/puppet.go deleted file mode 100644 index ca6489e..0000000 --- a/puppet.go +++ /dev/null @@ -1,386 +0,0 @@ -package main - -import ( - "fmt" - "regexp" - "strings" - "sync" - - "github.com/bwmarrin/discordgo" - "github.com/rs/zerolog" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridge" - "maunium.net/go/mautrix/id" - - "go.mau.fi/mautrix-discord/database" -) - -type Puppet struct { - *database.Puppet - - bridge *DiscordBridge - log zerolog.Logger - - MXID id.UserID - - customIntent *appservice.IntentAPI - customUser *User - - syncLock sync.Mutex -} - -var _ bridge.Ghost = (*Puppet)(nil) -var _ bridge.GhostWithProfile = (*Puppet)(nil) - -func (puppet *Puppet) GetMXID() id.UserID { - return puppet.MXID -} - -var userIDRegex *regexp.Regexp - -func (br *DiscordBridge) NewPuppet(dbPuppet *database.Puppet) *Puppet { - return &Puppet{ - Puppet: dbPuppet, - bridge: br, - log: br.ZLog.With().Str("discord_user_id", dbPuppet.ID).Logger(), - - MXID: br.FormatPuppetMXID(dbPuppet.ID), - } -} - -func (br *DiscordBridge) ParsePuppetMXID(mxid id.UserID) (string, bool) { - if userIDRegex == nil { - pattern := fmt.Sprintf( - "^@%s:%s$", - br.Config.Bridge.FormatUsername("([0-9]+)"), - br.Config.Homeserver.Domain, - ) - - userIDRegex = regexp.MustCompile(pattern) - } - - match := userIDRegex.FindStringSubmatch(string(mxid)) - if len(match) == 2 { - return match[1], true - } - - return "", false -} - -func (br *DiscordBridge) GetPuppetByMXID(mxid id.UserID) *Puppet { - discordID, ok := br.ParsePuppetMXID(mxid) - if !ok { - return nil - } - - return br.GetPuppetByID(discordID) -} - -func (br *DiscordBridge) GetPuppetByID(id string) *Puppet { - br.puppetsLock.Lock() - defer br.puppetsLock.Unlock() - - puppet, ok := br.puppets[id] - if !ok { - dbPuppet := br.DB.Puppet.Get(id) - if dbPuppet == nil { - dbPuppet = br.DB.Puppet.New() - dbPuppet.ID = id - dbPuppet.Insert() - } - - puppet = br.NewPuppet(dbPuppet) - br.puppets[puppet.ID] = puppet - } - - return puppet -} - -func (br *DiscordBridge) GetPuppetByCustomMXID(mxid id.UserID) *Puppet { - br.puppetsLock.Lock() - defer br.puppetsLock.Unlock() - - puppet, ok := br.puppetsByCustomMXID[mxid] - if !ok { - dbPuppet := br.DB.Puppet.GetByCustomMXID(mxid) - if dbPuppet == nil { - return nil - } - - puppet = br.NewPuppet(dbPuppet) - br.puppets[puppet.ID] = puppet - br.puppetsByCustomMXID[puppet.CustomMXID] = puppet - } - - return puppet -} - -func (br *DiscordBridge) GetAllPuppetsWithCustomMXID() []*Puppet { - return br.dbPuppetsToPuppets(br.DB.Puppet.GetAllWithCustomMXID()) -} - -func (br *DiscordBridge) GetAllPuppets() []*Puppet { - return br.dbPuppetsToPuppets(br.DB.Puppet.GetAll()) -} - -func (br *DiscordBridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet) []*Puppet { - br.puppetsLock.Lock() - defer br.puppetsLock.Unlock() - - output := make([]*Puppet, len(dbPuppets)) - for index, dbPuppet := range dbPuppets { - if dbPuppet == nil { - continue - } - - puppet, ok := br.puppets[dbPuppet.ID] - if !ok { - puppet = br.NewPuppet(dbPuppet) - br.puppets[dbPuppet.ID] = puppet - - if dbPuppet.CustomMXID != "" { - br.puppetsByCustomMXID[dbPuppet.CustomMXID] = puppet - } - } - - output[index] = puppet - } - - return output -} - -func (br *DiscordBridge) FormatPuppetMXID(did string) id.UserID { - return id.NewUserID( - br.Config.Bridge.FormatUsername(did), - br.Config.Homeserver.Domain, - ) -} - -func (puppet *Puppet) GetDisplayname() string { - return puppet.Name -} - -func (puppet *Puppet) GetAvatarURL() id.ContentURI { - return puppet.AvatarURL -} - -func (puppet *Puppet) DefaultIntent() *appservice.IntentAPI { - return puppet.bridge.AS.Intent(puppet.MXID) -} - -func (puppet *Puppet) IntentFor(portal *Portal) *appservice.IntentAPI { - if puppet.customIntent == nil || (portal.Key.Receiver != "" && portal.Key.Receiver != puppet.ID) { - return puppet.DefaultIntent() - } - - return puppet.customIntent -} - -func (puppet *Puppet) CustomIntent() *appservice.IntentAPI { - if puppet == nil { - return nil - } - return puppet.customIntent -} - -func (puppet *Puppet) updatePortalMeta(meta func(portal *Portal)) { - for _, portal := range puppet.bridge.GetDMPortalsWith(puppet.ID) { - // Get room create lock to prevent races between receiving contact info and room creation. - portal.roomCreateLock.Lock() - meta(portal) - portal.roomCreateLock.Unlock() - } -} - -func (puppet *Puppet) UpdateName(info *discordgo.User) bool { - newName := puppet.bridge.Config.Bridge.FormatDisplayname(info, puppet.IsWebhook, puppet.IsApplication) - if puppet.Name == newName && puppet.NameSet { - return false - } - puppet.Name = newName - puppet.NameSet = false - err := puppet.DefaultIntent().SetDisplayName(newName) - if err != nil { - puppet.log.Warn().Err(err).Msg("Failed to update displayname") - } else { - go puppet.updatePortalMeta(func(portal *Portal) { - if portal.UpdateNameDirect(puppet.Name, false) { - portal.Update() - portal.UpdateBridgeInfo() - } - }) - puppet.NameSet = true - } - return true -} - -func (br *DiscordBridge) reuploadUserAvatar(intent *appservice.IntentAPI, guildID, userID, avatarID string) (id.ContentURI, string, error) { - var downloadURL string - if guildID == "" { - if strings.HasPrefix(avatarID, "a_") { - downloadURL = discordgo.EndpointUserAvatarAnimated(userID, avatarID) - } else { - downloadURL = discordgo.EndpointUserAvatar(userID, avatarID) - } - } else { - if strings.HasPrefix(avatarID, "a_") { - downloadURL = discordgo.EndpointGuildMemberAvatarAnimated(guildID, userID, avatarID) - } else { - downloadURL = discordgo.EndpointGuildMemberAvatar(guildID, userID, avatarID) - } - } - url := br.DMA.AvatarMXC(guildID, userID, avatarID) - if !url.IsEmpty() { - return url, downloadURL, nil - } - copied, err := br.copyAttachmentToMatrix(intent, downloadURL, false, AttachmentMeta{ - AttachmentID: fmt.Sprintf("avatar/%s/%s/%s", guildID, userID, avatarID), - }) - if err != nil { - return id.ContentURI{}, downloadURL, err - } - return copied.MXC, downloadURL, nil -} - -func (puppet *Puppet) UpdateAvatar(info *discordgo.User) bool { - avatarID := info.Avatar - if puppet.IsWebhook && !puppet.bridge.Config.Bridge.EnableWebhookAvatars { - avatarID = "" - } - if puppet.Avatar == avatarID && puppet.AvatarSet { - return false - } - avatarChanged := avatarID != puppet.Avatar - puppet.Avatar = avatarID - puppet.AvatarSet = false - puppet.AvatarURL = id.ContentURI{} - - if puppet.Avatar != "" && (puppet.AvatarURL.IsEmpty() || avatarChanged) { - url, _, err := puppet.bridge.reuploadUserAvatar(puppet.DefaultIntent(), "", info.ID, puppet.Avatar) - if err != nil { - puppet.log.Warn().Err(err).Str("avatar_id", puppet.Avatar).Msg("Failed to reupload user avatar") - return true - } - puppet.AvatarURL = url - } - - err := puppet.DefaultIntent().SetAvatarURL(puppet.AvatarURL) - if err != nil { - puppet.log.Warn().Err(err).Msg("Failed to update avatar") - } else { - go puppet.updatePortalMeta(func(portal *Portal) { - if portal.UpdateAvatarFromPuppet(puppet) { - portal.Update() - portal.UpdateBridgeInfo() - } - }) - puppet.AvatarSet = true - } - return true -} - -func (puppet *Puppet) UpdateInfo(source *User, info *discordgo.User, message *discordgo.Message) { - puppet.syncLock.Lock() - defer puppet.syncLock.Unlock() - - if info == nil || len(info.Username) == 0 || len(info.Discriminator) == 0 { - if puppet.Name != "" || source == nil { - return - } - var err error - puppet.log.Debug().Str("source_user", source.DiscordID).Msg("Fetching info through user to update puppet") - info, err = source.Session.User(puppet.ID) - if err != nil { - puppet.log.Error().Err(err).Str("source_user", source.DiscordID).Msg("Failed to fetch info through user") - return - } - } - - err := puppet.DefaultIntent().EnsureRegistered() - if err != nil { - puppet.log.Error().Err(err).Msg("Failed to ensure registered") - } - - changed := false - if message != nil { - if message.WebhookID != "" && message.ApplicationID == "" && !puppet.IsWebhook { - puppet.log.Debug(). - Str("message_id", message.ID). - Str("webhook_id", message.WebhookID). - Msg("Found webhook ID in message, marking ghost as a webhook") - puppet.IsWebhook = true - changed = true - } - if message.ApplicationID != "" && !puppet.IsApplication { - puppet.log.Debug(). - Str("message_id", message.ID). - Str("application_id", message.ApplicationID). - Msg("Found application ID in message, marking ghost as an application") - puppet.IsApplication = true - puppet.IsWebhook = false - changed = true - } - } - changed = puppet.UpdateContactInfo(info) || changed - changed = puppet.UpdateName(info) || changed - changed = puppet.UpdateAvatar(info) || changed - if changed { - puppet.Update() - } -} - -func (puppet *Puppet) UpdateContactInfo(info *discordgo.User) bool { - changed := false - if puppet.Username != info.Username { - puppet.Username = info.Username - changed = true - } - if puppet.GlobalName != info.GlobalName { - puppet.GlobalName = info.GlobalName - changed = true - } - if puppet.Discriminator != info.Discriminator { - puppet.Discriminator = info.Discriminator - changed = true - } - if puppet.IsBot != info.Bot { - puppet.IsBot = info.Bot - changed = true - } - if (changed && !puppet.IsWebhook) || !puppet.ContactInfoSet { - puppet.ContactInfoSet = false - puppet.ResendContactInfo() - return true - } - return false -} - -func (puppet *Puppet) ResendContactInfo() { - if !puppet.bridge.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) || puppet.ContactInfoSet { - return - } - discordUsername := puppet.Username - if puppet.Discriminator != "0" { - discordUsername += "#" + puppet.Discriminator - } - contactInfo := map[string]any{ - "com.beeper.bridge.identifiers": []string{ - fmt.Sprintf("discord:%s", discordUsername), - }, - "com.beeper.bridge.remote_id": puppet.ID, - "com.beeper.bridge.service": puppet.bridge.BeeperServiceName, - "com.beeper.bridge.network": puppet.bridge.BeeperNetworkName, - "com.beeper.bridge.is_network_bot": puppet.IsBot, - } - if puppet.IsWebhook { - contactInfo["com.beeper.bridge.identifiers"] = []string{} - } - err := puppet.DefaultIntent().BeeperUpdateProfile(contactInfo) - if err != nil { - puppet.log.Warn().Err(err).Msg("Failed to store custom contact info in profile") - } else { - puppet.ContactInfoSet = true - } -} diff --git a/thread.go b/thread.go deleted file mode 100644 index 6e6aa7b..0000000 --- a/thread.go +++ /dev/null @@ -1,161 +0,0 @@ -package main - -import ( - "context" - "sync" - "time" - - "github.com/bwmarrin/discordgo" - "github.com/rs/zerolog" - "golang.org/x/exp/slices" - "maunium.net/go/mautrix/id" - - "go.mau.fi/mautrix-discord/database" -) - -type Thread struct { - *database.Thread - Parent *Portal - - creationNoticeLock sync.Mutex - initialBackfillAttempted bool -} - -func (br *DiscordBridge) GetThreadByID(id string, root *database.Message) *Thread { - br.threadsLock.Lock() - defer br.threadsLock.Unlock() - thread, ok := br.threadsByID[id] - if !ok { - return br.loadThread(br.DB.Thread.GetByDiscordID(id), id, root) - } - return thread -} - -func (br *DiscordBridge) GetThreadByRootMXID(mxid id.EventID) *Thread { - br.threadsLock.Lock() - defer br.threadsLock.Unlock() - thread, ok := br.threadsByRootMXID[mxid] - if !ok { - return br.loadThread(br.DB.Thread.GetByMatrixRootMsg(mxid), "", nil) - } - return thread -} - -func (br *DiscordBridge) GetThreadByRootOrCreationNoticeMXID(mxid id.EventID) *Thread { - br.threadsLock.Lock() - defer br.threadsLock.Unlock() - thread, ok := br.threadsByRootMXID[mxid] - if !ok { - thread, ok = br.threadsByCreationNoticeMXID[mxid] - if !ok { - return br.loadThread(br.DB.Thread.GetByMatrixRootOrCreationNoticeMsg(mxid), "", nil) - } - } - return thread -} - -func (br *DiscordBridge) loadThread(dbThread *database.Thread, id string, root *database.Message) *Thread { - if dbThread == nil { - if root == nil { - return nil - } - dbThread = br.DB.Thread.New() - dbThread.ID = id - dbThread.RootDiscordID = root.DiscordID - dbThread.RootMXID = root.MXID - dbThread.ParentID = root.Channel.ChannelID - dbThread.Insert() - } - thread := &Thread{ - Thread: dbThread, - } - thread.Parent = br.GetExistingPortalByID(database.NewPortalKey(thread.ParentID, "")) - br.threadsByID[thread.ID] = thread - br.threadsByRootMXID[thread.RootMXID] = thread - if thread.CreationNoticeMXID != "" { - br.threadsByCreationNoticeMXID[thread.CreationNoticeMXID] = thread - } - return thread -} - -func (br *DiscordBridge) threadFound(ctx context.Context, source *User, rootMessage *database.Message, id string, metadata *discordgo.Channel) { - thread := br.GetThreadByID(id, rootMessage) - log := zerolog.Ctx(ctx) - log.Debug().Msg("Marked message as thread root") - if thread.CreationNoticeMXID == "" { - thread.Parent.sendThreadCreationNotice(ctx, thread) - } - // TODO member_ids_preview is probably not guaranteed to contain the source user - if source != nil && metadata != nil && slices.Contains(metadata.MemberIDsPreview, source.DiscordID) && !source.IsInPortal(thread.ID) { - source.MarkInPortal(database.UserPortal{ - DiscordID: thread.ID, - Type: database.UserPortalTypeThread, - Timestamp: time.Now(), - }) - if metadata.MessageCount > 0 { - go thread.maybeInitialBackfill(source) - } else { - thread.initialBackfillAttempted = true - } - } -} - -func (thread *Thread) maybeInitialBackfill(source *User) { - if thread.initialBackfillAttempted || thread.Parent.bridge.Config.Bridge.Backfill.Limits.Initial.Thread == 0 { - return - } - thread.Parent.forwardBackfillLock.Lock() - if thread.Parent.bridge.DB.Message.GetLastInThread(thread.Parent.Key, thread.ID) != nil { - thread.Parent.forwardBackfillLock.Unlock() - return - } - thread.Parent.forwardBackfillInitial(source, thread) -} - -func (thread *Thread) RefererOpt() discordgo.RequestOption { - return discordgo.WithThreadReferer(thread.Parent.GuildID, thread.ParentID, thread.ID) -} - -func (thread *Thread) Join(user *User) { - if user.IsInPortal(thread.ID) { - return - } - log := user.log.With().Str("thread_id", thread.ID).Str("channel_id", thread.ParentID).Logger() - log.Debug().Msg("Joining thread") - - var doBackfill, backfillStarted bool - if !thread.initialBackfillAttempted && thread.Parent.bridge.Config.Bridge.Backfill.Limits.Initial.Thread > 0 { - thread.Parent.forwardBackfillLock.Lock() - lastMessage := thread.Parent.bridge.DB.Message.GetLastInThread(thread.Parent.Key, thread.ID) - if lastMessage != nil { - thread.Parent.forwardBackfillLock.Unlock() - } else { - doBackfill = true - defer func() { - if !backfillStarted { - thread.Parent.forwardBackfillLock.Unlock() - } - }() - } - } - - var err error - if user.Session.IsUser { - err = user.Session.ThreadJoin(thread.ID, discordgo.WithLocationParam(discordgo.ThreadJoinLocationContextMenu), thread.RefererOpt()) - } else { - err = user.Session.ThreadJoin(thread.ID) - } - if err != nil { - log.Error().Err(err).Msg("Error joining thread") - } else { - user.MarkInPortal(database.UserPortal{ - DiscordID: thread.ID, - Type: database.UserPortalTypeThread, - Timestamp: time.Now(), - }) - if doBackfill { - go thread.Parent.forwardBackfillInitial(user, thread) - backfillStarted = true - } - } -} diff --git a/user.go b/user.go deleted file mode 100644 index f209b33..0000000 --- a/user.go +++ /dev/null @@ -1,1526 +0,0 @@ -package main - -import ( - "context" - "crypto/tls" - "errors" - "fmt" - "math/rand" - "net/http" - "net/url" - "os" - "runtime/debug" - "sort" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/bwmarrin/discordgo" - "github.com/gorilla/websocket" - "github.com/rs/zerolog" - "go.mau.fi/util/dbutil" - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridge" - "maunium.net/go/mautrix/bridge/bridgeconfig" - "maunium.net/go/mautrix/bridge/status" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/pushrules" - - "go.mau.fi/mautrix-discord/database" -) - -var ( - ErrNotConnected = errors.New("not connected") - ErrNotLoggedIn = errors.New("not logged in") -) - -type User struct { - *database.User - - sync.Mutex - - bridge *DiscordBridge - log zerolog.Logger - - PermissionLevel bridgeconfig.PermissionLevel - - spaceCreateLock sync.Mutex - spaceMembershipChecked bool - dmSpaceMembershipChecked bool - - Session *discordgo.Session - - BridgeState *bridge.BridgeStateQueue - bridgeStateLock sync.Mutex - wasDisconnected bool - wasLoggedOut bool - - markedOpened map[string]time.Time - markedOpenedLock sync.Mutex - - pendingInteractions map[string]*WrappedCommandEvent - pendingInteractionsLock sync.Mutex - - nextDiscordUploadID atomic.Int32 - - relationships map[string]*discordgo.Relationship -} - -func (user *User) GetRemoteID() string { - return user.DiscordID -} - -func (user *User) GetRemoteName() string { - if user.Session != nil && user.Session.State != nil && user.Session.State.User != nil { - if user.Session.State.User.Discriminator == "0" { - return fmt.Sprintf("@%s", user.Session.State.User.Username) - } - return fmt.Sprintf("%s#%s", user.Session.State.User.Username, user.Session.State.User.Discriminator) - } - return user.DiscordID -} - -var discordLog zerolog.Logger - -func discordToZeroLevel(level int) zerolog.Level { - switch level { - case discordgo.LogError: - return zerolog.ErrorLevel - case discordgo.LogWarning: - return zerolog.WarnLevel - case discordgo.LogInformational: - return zerolog.InfoLevel - case discordgo.LogDebug: - fallthrough - default: - return zerolog.DebugLevel - } -} - -func init() { - discordgo.Logger = func(msgL, caller int, format string, a ...interface{}) { - discordLog.WithLevel(discordToZeroLevel(msgL)).Caller(caller+1).Msgf(strings.TrimSpace(format), a...) // zerolog-allow-msgf - } -} - -func (user *User) GetPermissionLevel() bridgeconfig.PermissionLevel { - return user.PermissionLevel -} - -func (user *User) GetManagementRoomID() id.RoomID { - return user.ManagementRoom -} - -func (user *User) GetMXID() id.UserID { - return user.MXID -} - -func (user *User) GetCommandState() map[string]interface{} { - return nil -} - -func (user *User) GetIDoublePuppet() bridge.DoublePuppet { - p := user.bridge.GetPuppetByCustomMXID(user.MXID) - if p == nil || p.CustomIntent() == nil { - return nil - } - return p -} - -func (user *User) GetIGhost() bridge.Ghost { - if user.DiscordID == "" { - return nil - } - p := user.bridge.GetPuppetByID(user.DiscordID) - if p == nil { - return nil - } - return p -} - -var _ bridge.User = (*User)(nil) - -func (br *DiscordBridge) loadUser(dbUser *database.User, mxid *id.UserID) *User { - if dbUser == nil { - if mxid == nil { - return nil - } - dbUser = br.DB.User.New() - dbUser.MXID = *mxid - dbUser.Insert() - } - - user := br.NewUser(dbUser) - br.usersByMXID[user.MXID] = user - if user.DiscordID != "" { - br.usersByID[user.DiscordID] = user - } - if user.ManagementRoom != "" { - br.managementRoomsLock.Lock() - br.managementRooms[user.ManagementRoom] = user - br.managementRoomsLock.Unlock() - } - return user -} - -func (br *DiscordBridge) GetUserByMXID(userID id.UserID) *User { - if userID == br.Bot.UserID || br.IsGhost(userID) { - return nil - } - br.usersLock.Lock() - defer br.usersLock.Unlock() - - user, ok := br.usersByMXID[userID] - if !ok { - return br.loadUser(br.DB.User.GetByMXID(userID), &userID) - } - return user -} - -func (br *DiscordBridge) GetUserByID(id string) *User { - br.usersLock.Lock() - defer br.usersLock.Unlock() - - user, ok := br.usersByID[id] - if !ok { - return br.loadUser(br.DB.User.GetByID(id), nil) - } - return user -} - -func (br *DiscordBridge) GetCachedUserByID(id string) *User { - br.usersLock.Lock() - defer br.usersLock.Unlock() - return br.usersByID[id] -} - -func (br *DiscordBridge) GetCachedUserByMXID(userID id.UserID) *User { - br.usersLock.Lock() - defer br.usersLock.Unlock() - return br.usersByMXID[userID] -} - -func (br *DiscordBridge) NewUser(dbUser *database.User) *User { - user := &User{ - User: dbUser, - bridge: br, - log: br.ZLog.With().Str("user_id", string(dbUser.MXID)).Logger(), - - markedOpened: make(map[string]time.Time), - PermissionLevel: br.Config.Bridge.Permissions.Get(dbUser.MXID), - - pendingInteractions: make(map[string]*WrappedCommandEvent), - - relationships: make(map[string]*discordgo.Relationship), - } - user.nextDiscordUploadID.Store(rand.Int31n(100)) - user.BridgeState = br.NewBridgeStateQueue(user) - return user -} - -func (br *DiscordBridge) getAllUsersWithToken() []*User { - br.usersLock.Lock() - defer br.usersLock.Unlock() - - dbUsers := br.DB.User.GetAllWithToken() - users := make([]*User, len(dbUsers)) - - for idx, dbUser := range dbUsers { - user, ok := br.usersByMXID[dbUser.MXID] - if !ok { - user = br.loadUser(dbUser, nil) - } - users[idx] = user - } - return users -} - -func (br *DiscordBridge) startUsers() { - br.ZLog.Debug().Msg("Starting users") - - usersWithToken := br.getAllUsersWithToken() - for _, u := range usersWithToken { - go u.startupTryConnect(0) - } - if len(usersWithToken) == 0 { - br.SendGlobalBridgeState(status.BridgeState{StateEvent: status.StateUnconfigured}.Fill(nil)) - } - - br.ZLog.Debug().Msg("Starting custom puppets") - for _, customPuppet := range br.GetAllPuppetsWithCustomMXID() { - go func(puppet *Puppet) { - br.ZLog.Debug().Str("user_id", puppet.CustomMXID.String()).Msg("Starting custom puppet") - - if err := puppet.StartCustomMXID(true); err != nil { - puppet.log.Error().Err(err).Msg("Failed to start custom puppet") - } - }(customPuppet) - } -} - -func (user *User) startupTryConnect(retryCount int) { - user.BridgeState.Send(status.BridgeState{StateEvent: status.StateConnecting}) - err := user.Connect() - if err != nil { - user.log.Error().Err(err).Msg("Error connecting on startup") - closeErr := &websocket.CloseError{} - if errors.As(err, &closeErr) && closeErr.Code == 4004 { - user.invalidAuthHandler(nil) - } else if retryCount < 6 { - user.BridgeState.Send(status.BridgeState{StateEvent: status.StateTransientDisconnect, Error: "dc-unknown-websocket-error", Message: err.Error()}) - retryInSeconds := 2 << retryCount - user.log.Debug().Int("retry_in_seconds", retryInSeconds).Msg("Sleeping and retrying connection") - time.Sleep(time.Duration(retryInSeconds) * time.Second) - user.startupTryConnect(retryCount + 1) - } else { - user.BridgeState.Send(status.BridgeState{StateEvent: status.StateUnknownError, Error: "dc-unknown-websocket-error", Message: err.Error()}) - } - } -} - -func (user *User) SetManagementRoom(roomID id.RoomID) { - user.bridge.managementRoomsLock.Lock() - defer user.bridge.managementRoomsLock.Unlock() - - existing, ok := user.bridge.managementRooms[roomID] - if ok { - existing.ManagementRoom = "" - existing.Update() - } - - user.ManagementRoom = roomID - user.bridge.managementRooms[user.ManagementRoom] = user - user.Update() -} - -func (user *User) getSpaceRoom(ptr *id.RoomID, name, topic string, parent id.RoomID) id.RoomID { - if len(*ptr) > 0 { - return *ptr - } - user.spaceCreateLock.Lock() - defer user.spaceCreateLock.Unlock() - if len(*ptr) > 0 { - return *ptr - } - - initialState := []*event.Event{{ - Type: event.StateRoomAvatar, - Content: event.Content{ - Parsed: &event.RoomAvatarEventContent{ - URL: user.bridge.Config.AppService.Bot.ParsedAvatar, - }, - }, - }} - - if parent != "" { - parentIDStr := parent.String() - initialState = append(initialState, &event.Event{ - Type: event.StateSpaceParent, - StateKey: &parentIDStr, - Content: event.Content{ - Parsed: &event.SpaceParentEventContent{ - Canonical: true, - Via: []string{user.bridge.AS.HomeserverDomain}, - }, - }, - }) - } - - resp, err := user.bridge.Bot.CreateRoom(&mautrix.ReqCreateRoom{ - Visibility: "private", - Name: name, - Topic: topic, - InitialState: initialState, - CreationContent: map[string]interface{}{ - "type": event.RoomTypeSpace, - }, - PowerLevelOverride: &event.PowerLevelsEventContent{ - Users: map[id.UserID]int{ - user.bridge.Bot.UserID: 9001, - user.MXID: 50, - }, - }, - RoomVersion: "11", - }) - - if err != nil { - user.log.Error().Err(err).Msg("Failed to auto-create space room") - } else { - *ptr = resp.RoomID - user.Update() - user.ensureInvited(nil, *ptr, false, true) - - if parent != "" { - _, err = user.bridge.Bot.SendStateEvent(parent, event.StateSpaceChild, resp.RoomID.String(), &event.SpaceChildEventContent{ - Via: []string{user.bridge.AS.HomeserverDomain}, - Order: " 0000", - }) - if err != nil { - user.log.Error().Err(err). - Str("created_space_id", resp.RoomID.String()). - Str("parent_space_id", parent.String()). - Msg("Failed to add created space room to parent space") - } - } - } - return *ptr -} - -func (user *User) GetSpaceRoom() id.RoomID { - return user.getSpaceRoom(&user.SpaceRoom, "Discord", "Your Discord bridged chats", "") -} - -func (user *User) GetDMSpaceRoom() id.RoomID { - return user.getSpaceRoom(&user.DMSpaceRoom, "Direct Messages", "Your Discord direct messages", user.GetSpaceRoom()) -} - -func (user *User) ViewingChannel(portal *Portal) bool { - if portal.GuildID != "" || !user.Session.IsUser { - return false - } - user.markedOpenedLock.Lock() - defer user.markedOpenedLock.Unlock() - ts := user.markedOpened[portal.Key.ChannelID] - // TODO is there an expiry time? - if ts.IsZero() { - user.markedOpened[portal.Key.ChannelID] = time.Now() - err := user.Session.MarkViewing(portal.Key.ChannelID) - if err != nil { - user.log.Error().Err(err). - Str("channel_id", portal.Key.ChannelID). - Msg("Failed to mark user as viewing channel") - } - return true - } - return false -} - -func (user *User) mutePortal(intent *appservice.IntentAPI, portal *Portal, unmute bool) { - if len(portal.MXID) == 0 || !user.bridge.Config.Bridge.MuteChannelsOnCreate { - return - } - var err error - if unmute { - user.log.Debug().Str("room_id", portal.MXID.String()).Msg("Unmuting portal") - err = intent.DeletePushRule("global", pushrules.RoomRule, string(portal.MXID)) - } else { - user.log.Debug().Str("room_id", portal.MXID.String()).Msg("Muting portal") - err = intent.PutPushRule("global", pushrules.RoomRule, string(portal.MXID), &mautrix.ReqPutPushRule{ - Actions: []pushrules.PushActionType{pushrules.ActionDontNotify}, - }) - } - if err != nil && !errors.Is(err, mautrix.MNotFound) { - user.log.Warn().Err(err). - Str("room_id", portal.MXID.String()). - Msg("Failed to update push rule through double puppet") - } -} - -func (user *User) syncChatDoublePuppetDetails(portal *Portal, justCreated bool) { - doublePuppetIntent := portal.bridge.GetPuppetByCustomMXID(user.MXID).CustomIntent() - if doublePuppetIntent == nil || portal.MXID == "" { - return - } - - // TODO sync mute status properly - if portal.GuildID != "" && user.bridge.Config.Bridge.MuteChannelsOnCreate && justCreated { - user.mutePortal(doublePuppetIntent, portal, false) - } -} - -func (user *User) NextDiscordUploadID() string { - val := user.nextDiscordUploadID.Add(2) - return strconv.Itoa(int(val)) -} - -func (user *User) Login(token string) error { - user.bridgeStateLock.Lock() - user.wasLoggedOut = false - user.bridgeStateLock.Unlock() - user.DiscordToken = token - var err error - const maxRetries = 3 -Loop: - for i := 0; i < maxRetries; i++ { - err = user.Connect() - if err == nil { - user.Update() - return nil - } - user.log.Error().Err(err).Msg("Error connecting for login") - closeErr := &websocket.CloseError{} - errors.As(err, &closeErr) - switch closeErr.Code { - case 4004, 4010, 4011, 4012, 4013, 4014: - break Loop - case 4000: - fallthrough - default: - if i < maxRetries-1 { - time.Sleep(time.Duration(i+1) * 2 * time.Second) - } - } - } - user.DiscordToken = "" - return err -} - -func (user *User) IsLoggedIn() bool { - user.Lock() - defer user.Unlock() - - return user.DiscordToken != "" -} - -func (user *User) Logout(isOverwriting bool) { - user.Lock() - defer user.Unlock() - - if user.DiscordID != "" { - puppet := user.bridge.GetPuppetByID(user.DiscordID) - if puppet.CustomMXID != "" { - err := puppet.SwitchCustomMXID("", "") - if err != nil { - user.log.Warn().Err(err).Msg("Failed to disable custom puppet while logging out of Discord") - } - } - } - - if user.Session != nil { - if err := user.Session.Close(); err != nil { - user.log.Warn().Err(err).Msg("Error closing session") - } - } - - user.Session = nil - user.DiscordToken = "" - user.ReadStateVersion = 0 - if !isOverwriting { - user.bridge.usersLock.Lock() - if user.bridge.usersByID[user.DiscordID] == user { - delete(user.bridge.usersByID, user.DiscordID) - } - user.bridge.usersLock.Unlock() - } - user.DiscordID = "" - user.Update() - user.log.Info().Msg("User logged out") -} - -func (user *User) Connected() bool { - user.Lock() - defer user.Unlock() - - return user.Session != nil -} - -const BotIntents = discordgo.IntentGuilds | - discordgo.IntentGuildMessages | - discordgo.IntentGuildMessageReactions | - discordgo.IntentGuildMessageTyping | - discordgo.IntentGuildBans | - discordgo.IntentGuildEmojis | - discordgo.IntentGuildIntegrations | - discordgo.IntentGuildInvites | - //discordgo.IntentGuildVoiceStates | - //discordgo.IntentGuildScheduledEvents | - discordgo.IntentDirectMessages | - discordgo.IntentDirectMessageTyping | - discordgo.IntentDirectMessageTyping | - // Privileged intents - discordgo.IntentMessageContent | - //discordgo.IntentGuildPresences | - discordgo.IntentGuildMembers - -func (user *User) Connect() error { - user.Lock() - defer user.Unlock() - - if user.DiscordToken == "" { - return ErrNotLoggedIn - } - - user.log.Debug().Msg("Connecting to discord") - - session, err := discordgo.New(user.DiscordToken) - if err != nil { - return err - } - if user.bridge.Config.Bridge.Proxy != "" { - u, _ := url.Parse(user.bridge.Config.Bridge.Proxy) - tlsConf := &tls.Config{ - InsecureSkipVerify: os.Getenv("DISCORD_SKIP_TLS_VERIFICATION") == "true", - } - session.Client.Transport = &http.Transport{ - Proxy: http.ProxyURL(u), - TLSClientConfig: tlsConf, - ForceAttemptHTTP2: true, - } - session.Dialer.Proxy = http.ProxyURL(u) - session.Dialer.TLSClientConfig = tlsConf - } - // TODO move to config - if os.Getenv("DISCORD_DEBUG") == "1" { - session.LogLevel = discordgo.LogDebug - } else { - session.LogLevel = discordgo.LogInformational - } - userDiscordLog := user.log.With().Str("component", "discordgo").Logger() - session.Logger = func(msgL, caller int, format string, a ...interface{}) { - userDiscordLog.WithLevel(discordToZeroLevel(msgL)).Caller(caller+1).Msgf(strings.TrimSpace(format), a...) // zerolog-allow-msgf - } - if !session.IsUser { - session.Identify.Intents = BotIntents - } - session.EventHandler = user.eventHandlerSync - - if session.IsUser { - err = session.LoadMainPage(context.TODO()) - if err != nil { - user.log.Warn().Err(err).Msg("Failed to load main page") - } - } - - user.Session = session - - for { - err = user.Session.Open() - if errors.Is(err, discordgo.ErrImmediateDisconnect) { - user.log.Warn().Err(err).Msg("Retrying initial connection in 5 seconds") - time.Sleep(5 * time.Second) - continue - } - return err - } -} - -func (user *User) eventHandlerSync(rawEvt any) { - go user.eventHandler(rawEvt) -} - -func (user *User) eventHandler(rawEvt any) { - defer func() { - err := recover() - if err != nil { - user.log.Error(). - Bytes(zerolog.ErrorStackFieldName, debug.Stack()). - Any(zerolog.ErrorFieldName, err). - Msg("Panic in Discord event handler") - } - }() - switch evt := rawEvt.(type) { - case *discordgo.Ready: - user.readyHandler(evt) - case *discordgo.Resumed: - user.resumeHandler(evt) - case *discordgo.Connect: - user.connectedHandler(evt) - case *discordgo.Disconnect: - user.disconnectedHandler(evt) - case *discordgo.InvalidAuth: - user.invalidAuthHandler(evt) - case *discordgo.GuildCreate: - user.guildCreateHandler(evt) - case *discordgo.GuildDelete: - user.guildDeleteHandler(evt) - case *discordgo.GuildUpdate: - user.guildUpdateHandler(evt) - case *discordgo.GuildRoleCreate: - user.discordRoleToDB(evt.GuildID, evt.Role, nil, nil) - case *discordgo.GuildRoleUpdate: - user.discordRoleToDB(evt.GuildID, evt.Role, nil, nil) - case *discordgo.GuildRoleDelete: - user.bridge.DB.Role.DeleteByID(evt.GuildID, evt.RoleID) - case *discordgo.ChannelCreate: - user.channelCreateHandler(evt) - case *discordgo.ChannelDelete: - user.channelDeleteHandler(evt) - case *discordgo.ChannelUpdate: - user.channelUpdateHandler(evt) - case *discordgo.ChannelRecipientAdd: - user.channelRecipientAdd(evt) - case *discordgo.ChannelRecipientRemove: - user.channelRecipientRemove(evt) - case *discordgo.RelationshipAdd: - user.relationshipAddHandler(evt) - case *discordgo.RelationshipRemove: - user.relationshipRemoveHandler(evt) - case *discordgo.RelationshipUpdate: - user.relationshipUpdateHandler(evt) - case *discordgo.MessageCreate: - user.pushPortalMessage(evt, "message create", evt.ChannelID, evt.GuildID) - case *discordgo.MessageDelete: - user.pushPortalMessage(evt, "message delete", evt.ChannelID, evt.GuildID) - case *discordgo.MessageDeleteBulk: - user.pushPortalMessage(evt, "bulk message delete", evt.ChannelID, evt.GuildID) - case *discordgo.MessageUpdate: - user.pushPortalMessage(evt, "message update", evt.ChannelID, evt.GuildID) - case *discordgo.MessageReactionAdd: - user.pushPortalMessage(evt, "reaction add", evt.ChannelID, evt.GuildID) - case *discordgo.MessageReactionRemove: - user.pushPortalMessage(evt, "reaction remove", evt.ChannelID, evt.GuildID) - case *discordgo.MessageAck: - user.messageAckHandler(evt) - case *discordgo.TypingStart: - user.typingStartHandler(evt) - case *discordgo.InteractionSuccess: - user.interactionSuccessHandler(evt) - case *discordgo.ThreadListSync: - user.threadListSyncHandler(evt) - case *discordgo.Event: - // Ignore - default: - user.log.Debug().Type("event_type", evt).Msg("Unhandled event") - } -} - -func (user *User) Disconnect() error { - user.Lock() - defer user.Unlock() - if user.Session == nil { - return ErrNotConnected - } - - user.log.Info().Msg("Disconnecting session manually") - if err := user.Session.Close(); err != nil { - return err - } - user.Session = nil - return nil -} - -func (user *User) getGuildBridgingMode(guildID string) database.GuildBridgingMode { - if guildID == "" { - return database.GuildBridgeEverything - } - guild := user.bridge.GetGuildByID(guildID, false) - if guild == nil { - return database.GuildBridgeNothing - } - return guild.BridgingMode -} - -type ChannelSlice []*discordgo.Channel - -func (s ChannelSlice) Len() int { - return len(s) -} - -func (s ChannelSlice) Less(i, j int) bool { - if s[i].Position != 0 || s[j].Position != 0 { - return s[i].Position < s[j].Position - } - return compareMessageIDs(s[i].LastMessageID, s[j].LastMessageID) == 1 -} - -func (s ChannelSlice) Swap(i, j int) { - s[i], s[j] = s[j], s[i] -} - -func (user *User) readyHandler(r *discordgo.Ready) { - user.log.Debug().Msg("Discord connection ready") - user.bridgeStateLock.Lock() - user.wasLoggedOut = false - user.bridgeStateLock.Unlock() - - if user.DiscordID != r.User.ID { - user.bridge.usersLock.Lock() - user.DiscordID = r.User.ID - if previousUser, ok := user.bridge.usersByID[user.DiscordID]; ok && previousUser != user { - user.log.Warn(). - Str("previous_user_id", previousUser.MXID.String()). - Msg("Another user is logged in with same Discord ID, logging them out") - // TODO send notice? - previousUser.Logout(true) - } - user.bridge.usersByID[user.DiscordID] = user - user.bridge.usersLock.Unlock() - user.Update() - } - user.BridgeState.Send(status.BridgeState{StateEvent: status.StateBackfilling}) - user.tryAutomaticDoublePuppeting() - - for _, relationship := range r.Relationships { - user.relationships[relationship.ID] = relationship - } - - updateTS := time.Now() - portalsInSpace := make(map[string]bool) - for _, guild := range user.GetPortals() { - portalsInSpace[guild.DiscordID] = guild.InSpace - } - for _, guild := range r.Guilds { - user.handleGuild(guild, updateTS, portalsInSpace[guild.ID]) - } - // The private channel list doesn't seem to be sorted by default, so sort it by message IDs (highest=newest first) - sort.Sort(ChannelSlice(r.PrivateChannels)) - for i, ch := range r.PrivateChannels { - portal := user.GetPortalByMeta(ch) - user.handlePrivateChannel(portal, ch, updateTS, i < user.bridge.Config.Bridge.PrivateChannelCreateLimit, portalsInSpace[portal.Key.ChannelID]) - } - user.PrunePortalList(updateTS) - - if r.ReadState != nil && r.ReadState.Version > user.ReadStateVersion { - // TODO can we figure out which read states are actually new? - for _, entry := range r.ReadState.Entries { - user.messageAckHandler(&discordgo.MessageAck{ - MessageID: string(entry.LastMessageID), - ChannelID: entry.ID, - }) - } - user.ReadStateVersion = r.ReadState.Version - user.Update() - } - - go user.subscribeGuilds(2 * time.Second) - - user.BridgeState.Send(status.BridgeState{StateEvent: status.StateConnected}) -} - -func (user *User) subscribeGuilds(delay time.Duration) { - if !user.Session.IsUser { - return - } - for _, guildMeta := range user.Session.State.Guilds { - guild := user.bridge.GetGuildByID(guildMeta.ID, false) - if guild != nil && guild.MXID != "" { - user.log.Debug().Str("guild_id", guild.ID).Msg("Subscribing to guild") - dat := discordgo.GuildSubscribeData{ - GuildID: guild.ID, - Typing: true, - Activities: true, - Threads: true, - } - err := user.Session.SubscribeGuild(dat) - if err != nil { - user.log.Warn().Err(err).Str("guild_id", guild.ID).Msg("Failed to subscribe to guild") - } - time.Sleep(delay) - } - } -} - -func (user *User) resumeHandler(_ *discordgo.Resumed) { - user.log.Debug().Msg("Discord connection resumed") - user.subscribeGuilds(0 * time.Second) -} - -func (user *User) addPrivateChannelToSpace(portal *Portal) bool { - if portal.MXID == "" { - return false - } - _, err := user.bridge.Bot.SendStateEvent(user.GetDMSpaceRoom(), event.StateSpaceChild, portal.MXID.String(), &event.SpaceChildEventContent{ - Via: []string{user.bridge.AS.HomeserverDomain}, - }) - if err != nil { - user.log.Error().Err(err). - Str("room_id", portal.MXID.String()). - Msg("Failed to add DMM room to user DM space") - return false - } else { - return true - } -} - -func (user *User) relationshipAddHandler(r *discordgo.RelationshipAdd) { - user.log.Debug().Interface("relationship", r.Relationship).Msg("Relationship added") - user.relationships[r.ID] = r.Relationship - user.handleRelationshipChange(r.ID, r.Nickname) -} - -func (user *User) relationshipUpdateHandler(r *discordgo.RelationshipUpdate) { - user.log.Debug().Interface("relationship", r.Relationship).Msg("Relationship update") - user.relationships[r.ID] = r.Relationship - user.handleRelationshipChange(r.ID, r.Nickname) -} - -func (user *User) relationshipRemoveHandler(r *discordgo.RelationshipRemove) { - user.log.Debug().Str("other_user_id", r.ID).Msg("Relationship removed") - delete(user.relationships, r.ID) - user.handleRelationshipChange(r.ID, "") -} - -func (user *User) handleRelationshipChange(userID, nickname string) { - puppet := user.bridge.GetPuppetByID(userID) - portal := user.FindPrivateChatWith(userID) - if portal == nil || puppet == nil { - return - } - - updated := portal.FriendNick == (nickname != "") - portal.FriendNick = nickname != "" - if nickname != "" { - updated = portal.UpdateNameDirect(nickname, true) - } else if portal.Name != puppet.Name { - if portal.shouldSetDMRoomMetadata() { - updated = portal.UpdateNameDirect(puppet.Name, false) - } else if portal.NameSet { - _, err := portal.MainIntent().SendStateEvent(portal.MXID, event.StateRoomName, "", map[string]any{}) - if err != nil { - portal.log.Warn().Err(err).Msg("Failed to clear room name after friend nickname was removed") - } else { - portal.log.Debug().Msg("Cleared room name after friend nickname was removed") - portal.NameSet = false - portal.Update() - updated = true - } - } - } - if !updated { - portal.Update() - } -} - -func (user *User) handlePrivateChannel(portal *Portal, meta *discordgo.Channel, timestamp time.Time, create, isInSpace bool) { - if create && portal.MXID == "" { - err := portal.CreateMatrixRoom(user, meta) - if err != nil { - user.log.Error().Err(err). - Str("channel_id", portal.Key.ChannelID). - Msg("Failed to create portal for private channel in create handler") - } - } else { - portal.UpdateInfo(user, meta) - portal.ForwardBackfillMissed(user, meta.LastMessageID, nil) - } - user.MarkInPortal(database.UserPortal{ - DiscordID: portal.Key.ChannelID, - Type: database.UserPortalTypeDM, - Timestamp: timestamp, - InSpace: isInSpace || user.addPrivateChannelToSpace(portal), - }) -} - -func (user *User) addGuildToSpace(guild *Guild, isInSpace bool, timestamp time.Time) bool { - if len(guild.MXID) > 0 && !isInSpace { - _, err := user.bridge.Bot.SendStateEvent(user.GetSpaceRoom(), event.StateSpaceChild, guild.MXID.String(), &event.SpaceChildEventContent{ - Via: []string{user.bridge.AS.HomeserverDomain}, - }) - if err != nil { - user.log.Error().Err(err). - Str("guild_space_id", guild.MXID.String()). - Msg("Failed to add guild space to user space") - } else { - isInSpace = true - } - } - user.MarkInPortal(database.UserPortal{ - DiscordID: guild.ID, - Type: database.UserPortalTypeGuild, - Timestamp: timestamp, - InSpace: isInSpace, - }) - return isInSpace -} - -func (user *User) discordRoleToDB(guildID string, role *discordgo.Role, dbRole *database.Role, txn dbutil.Execable) bool { - var changed bool - if dbRole == nil { - dbRole = user.bridge.DB.Role.New() - dbRole.ID = role.ID - dbRole.GuildID = guildID - changed = true - } else { - changed = dbRole.Name != role.Name || - dbRole.Icon != role.Icon || - dbRole.Mentionable != role.Mentionable || - dbRole.Managed != role.Managed || - dbRole.Hoist != role.Hoist || - dbRole.Color != role.Color || - dbRole.Position != role.Position || - dbRole.Permissions != role.Permissions - } - dbRole.Role = *role - if changed { - dbRole.Upsert(txn) - } - return changed -} - -func (user *User) handleGuildRoles(guildID string, newRoles []*discordgo.Role) { - existingRoles := user.bridge.DB.Role.GetAll(guildID) - existingRoleMap := make(map[string]*database.Role, len(existingRoles)) - for _, role := range existingRoles { - existingRoleMap[role.ID] = role - } - txn, err := user.bridge.DB.Begin() - if err != nil { - user.log.Error().Err(err).Msg("Failed to start transaction for guild role sync") - panic(err) - } - for _, role := range newRoles { - user.discordRoleToDB(guildID, role, existingRoleMap[role.ID], txn) - delete(existingRoleMap, role.ID) - } - for _, removeRole := range existingRoleMap { - removeRole.Delete(txn) - } - err = txn.Commit() - if err != nil { - user.log.Error().Err(err).Msg("Failed to commit guild role sync transaction") - rollbackErr := txn.Rollback() - if rollbackErr != nil { - user.log.Error().Err(rollbackErr).Msg("Failed to rollback errored guild role sync transaction") - } - panic(err) - } -} - -func (user *User) handleGuild(meta *discordgo.Guild, timestamp time.Time, isInSpace bool) { - guild := user.bridge.GetGuildByID(meta.ID, true) - guild.UpdateInfo(user, meta) - if len(meta.Channels) > 0 { - for _, ch := range meta.Channels { - if !user.channelIsBridgeable(ch) { - continue - } - portal := user.GetPortalByMeta(ch) - if guild.BridgingMode >= database.GuildBridgeEverything && portal.MXID == "" { - err := portal.CreateMatrixRoom(user, ch) - if err != nil { - user.log.Error().Err(err). - Str("guild_id", guild.ID). - Str("channel_id", ch.ID). - Msg("Failed to create portal for guild channel in guild handler") - } - } else { - portal.UpdateInfo(user, ch) - if user.bridge.Config.Bridge.Backfill.MaxGuildMembers < 0 || meta.MemberCount < user.bridge.Config.Bridge.Backfill.MaxGuildMembers { - portal.ForwardBackfillMissed(user, ch.LastMessageID, nil) - } - } - } - } - if len(meta.Roles) > 0 { - user.handleGuildRoles(meta.ID, meta.Roles) - } - user.addGuildToSpace(guild, isInSpace, timestamp) -} - -func (user *User) connectedHandler(_ *discordgo.Connect) { - user.bridgeStateLock.Lock() - defer user.bridgeStateLock.Unlock() - user.log.Debug().Msg("Connected to Discord") - if user.wasDisconnected { - user.wasDisconnected = false - user.BridgeState.Send(status.BridgeState{StateEvent: status.StateConnected}) - } -} - -func (user *User) disconnectedHandler(_ *discordgo.Disconnect) { - user.bridgeStateLock.Lock() - defer user.bridgeStateLock.Unlock() - if user.wasLoggedOut { - user.log.Debug().Msg("Disconnected from Discord (not updating bridge state as user was just logged out)") - return - } - user.log.Debug().Msg("Disconnected from Discord") - user.wasDisconnected = true - user.BridgeState.Send(status.BridgeState{StateEvent: status.StateTransientDisconnect, Error: "dc-transient-disconnect", Message: "Temporarily disconnected from Discord, trying to reconnect"}) -} - -func (user *User) invalidAuthHandler(_ *discordgo.InvalidAuth) { - user.bridgeStateLock.Lock() - defer user.bridgeStateLock.Unlock() - user.log.Info().Msg("Got logged out from Discord due to invalid token") - user.wasLoggedOut = true - user.BridgeState.Send(status.BridgeState{StateEvent: status.StateBadCredentials, Error: "dc-websocket-disconnect-4004", Message: "Discord access token is no longer valid, please log in again"}) - go user.Logout(false) -} - -func (user *User) handlePossible40002(err error) bool { - var restErr *discordgo.RESTError - if !errors.As(err, &restErr) || restErr.Message == nil || restErr.Message.Code != discordgo.ErrCodeActionRequiredVerifiedAccount { - return false - } - user.BridgeState.Send(status.BridgeState{StateEvent: status.StateBadCredentials, Error: "dc-http-40002", Message: restErr.Message.Message}) - return true -} - -func (user *User) guildCreateHandler(g *discordgo.GuildCreate) { - user.log.Info(). - Str("guild_id", g.ID). - Str("name", g.Name). - Bool("unavailable", g.Unavailable). - Msg("Got guild create event") - user.handleGuild(g.Guild, time.Now(), false) -} - -func (user *User) guildDeleteHandler(g *discordgo.GuildDelete) { - if g.Unavailable { - user.log.Info().Str("guild_id", g.ID).Msg("Ignoring guild delete event with unavailable flag") - return - } - user.log.Info().Str("guild_id", g.ID).Msg("Got guild delete event") - user.MarkNotInPortal(g.ID) - guild := user.bridge.GetGuildByID(g.ID, false) - if guild == nil || guild.MXID == "" { - return - } - if user.bridge.Config.Bridge.DeleteGuildOnLeave && !user.PortalHasOtherUsers(g.ID) { - user.log.Debug().Str("guild_id", g.ID).Msg("No other users in guild, cleaning up all portals") - err := user.unbridgeGuild(g.ID) - if err != nil { - user.log.Warn().Err(err).Msg("Failed to unbridge guild that was deleted") - } - } -} - -func (user *User) guildUpdateHandler(g *discordgo.GuildUpdate) { - user.log.Debug().Str("guild_id", g.ID).Msg("Got guild update event") - user.handleGuild(g.Guild, time.Now(), user.IsInSpace(g.ID)) -} - -func (user *User) threadListSyncHandler(t *discordgo.ThreadListSync) { - for _, meta := range t.Threads { - log := user.log.With(). - Str("action", "thread list sync"). - Str("guild_id", t.GuildID). - Str("parent_id", meta.ParentID). - Str("thread_id", meta.ID). - Logger() - ctx := log.WithContext(context.Background()) - thread := user.bridge.GetThreadByID(meta.ID, nil) - if thread == nil { - msg := user.bridge.DB.Message.GetByDiscordID(database.NewPortalKey(meta.ParentID, ""), meta.ID) - if len(msg) == 0 { - log.Debug().Msg("Found unknown thread in thread list sync and don't have message") - } else { - log.Debug().Msg("Found unknown thread in thread list sync for existing message, creating thread") - user.bridge.threadFound(ctx, user, msg[0], meta.ID, meta) - } - } else { - thread.Parent.ForwardBackfillMissed(user, meta.LastMessageID, thread) - } - } -} - -func (user *User) channelCreateHandler(c *discordgo.ChannelCreate) { - if user.getGuildBridgingMode(c.GuildID) < database.GuildBridgeEverything { - user.log.Debug(). - Str("guild_id", c.GuildID).Str("channel_id", c.ID). - Msg("Ignoring channel create event in unbridged guild") - return - } - user.log.Info(). - Str("guild_id", c.GuildID).Str("channel_id", c.ID). - Msg("Got channel create event") - portal := user.GetPortalByMeta(c.Channel) - if portal.MXID != "" { - return - } - if c.GuildID == "" { - user.handlePrivateChannel(portal, c.Channel, time.Now(), true, user.IsInSpace(portal.Key.String())) - } else if user.channelIsBridgeable(c.Channel) { - err := portal.CreateMatrixRoom(user, c.Channel) - if err != nil { - user.log.Error().Err(err). - Str("guild_id", c.GuildID).Str("channel_id", c.ID). - Msg("Error creating Matrix room after channel create event") - } - } else { - user.log.Debug(). - Str("guild_id", c.GuildID).Str("channel_id", c.ID). - Msg("Got channel create event, but it's not bridgeable, ignoring") - } -} - -func (user *User) channelDeleteHandler(c *discordgo.ChannelDelete) { - portal := user.GetExistingPortalByID(c.ID) - if portal == nil { - user.log.Debug(). - Str("guild_id", c.GuildID).Str("channel_id", c.ID). - Msg("Ignoring channel delete event of unknown channel") - return - } - user.log.Info(). - Str("guild_id", c.GuildID).Str("channel_id", c.ID). - Msg("Got channel delete event, cleaning up portal") - portal.Delete() - portal.cleanup(!user.bridge.Config.Bridge.DeletePortalOnChannelDelete) - if c.GuildID == "" { - user.MarkNotInPortal(portal.Key.ChannelID) - } - user.log.Debug(). - Str("guild_id", c.GuildID).Str("channel_id", c.ID). - Msg("Completed cleaning up channel") -} - -func (user *User) channelUpdateHandler(c *discordgo.ChannelUpdate) { - portal := user.GetPortalByMeta(c.Channel) - if c.GuildID == "" { - user.handlePrivateChannel(portal, c.Channel, time.Now(), true, user.IsInSpace(portal.Key.String())) - } else if user.channelIsBridgeable(c.Channel) { - portal.UpdateInfo(user, c.Channel) - } -} - -func (user *User) channelRecipientAdd(c *discordgo.ChannelRecipientAdd) { - portal := user.GetExistingPortalByID(c.ChannelID) - if portal != nil { - portal.syncParticipant(user, c.User, false) - } -} - -func (user *User) channelRecipientRemove(c *discordgo.ChannelRecipientRemove) { - portal := user.GetExistingPortalByID(c.ChannelID) - if portal != nil { - portal.syncParticipant(user, c.User, true) - } -} - -func (user *User) findPortal(channelID string) (*Portal, *Thread) { - portal := user.GetExistingPortalByID(channelID) - if portal != nil { - return portal, nil - } - thread := user.bridge.GetThreadByID(channelID, nil) - if thread != nil && thread.Parent != nil { - return thread.Parent, thread - } - if !user.Session.IsUser { - channel, _ := user.Session.State.Channel(channelID) - if channel == nil { - user.log.Debug().Str("channel_id", channelID).Msg("Fetching info of unknown channel to handle message") - var err error - channel, err = user.Session.Channel(channelID) - if err != nil { - user.log.Warn().Err(err).Str("channel_id", channelID).Msg("Failed to get info of unknown channel") - } else { - user.log.Debug().Str("channel_id", channelID).Msg("Got info for channel to handle message") - _ = user.Session.State.ChannelAdd(channel) - } - } - if channel != nil && user.channelIsBridgeable(channel) { - user.log.Debug().Str("channel_id", channelID).Msg("Creating portal and updating info to handle message") - portal = user.GetPortalByMeta(channel) - if channel.GuildID == "" { - user.handlePrivateChannel(portal, channel, time.Now(), false, false) - } else { - user.log.Warn(). - Str("channel_id", channel.ID).Str("guild_id", channel.GuildID). - Msg("Unexpected unknown guild channel") - } - return portal, nil - } - } - return nil, nil -} - -func (user *User) pushPortalMessage(msg interface{}, typeName, channelID, guildID string) { - if user.getGuildBridgingMode(guildID) <= database.GuildBridgeNothing { - // If guild bridging mode is nothing, don't even check if the portal exists - return - } - - portal, thread := user.findPortal(channelID) - if portal == nil { - user.log.Debug(). - Str("discord_event", typeName). - Str("guild_id", guildID). - Str("channel_id", channelID). - Msg("Dropping event in unknown channel") - return - } - if mode := user.getGuildBridgingMode(portal.GuildID); mode <= database.GuildBridgeNothing || (portal.MXID == "" && mode <= database.GuildBridgeIfPortalExists) { - return - } - - wrappedMsg := portalDiscordMessage{ - msg: msg, - user: user, - thread: thread, - } - select { - case portal.discordMessages <- wrappedMsg: - default: - user.log.Warn(). - Str("discord_event", typeName). - Str("guild_id", guildID). - Str("channel_id", channelID). - Msg("Portal message buffer is full") - portal.discordMessages <- wrappedMsg - } -} - -type CustomReadReceipt struct { - Timestamp int64 `json:"ts,omitempty"` - DoublePuppetSource string `json:"fi.mau.double_puppet_source,omitempty"` -} - -type CustomReadMarkers struct { - mautrix.ReqSetReadMarkers - ReadExtra CustomReadReceipt `json:"com.beeper.read.extra"` - FullyReadExtra CustomReadReceipt `json:"com.beeper.fully_read.extra"` -} - -func (user *User) makeReadMarkerContent(eventID id.EventID) *CustomReadMarkers { - var extra CustomReadReceipt - extra.DoublePuppetSource = user.bridge.Name - return &CustomReadMarkers{ - ReqSetReadMarkers: mautrix.ReqSetReadMarkers{ - Read: eventID, - FullyRead: eventID, - }, - ReadExtra: extra, - FullyReadExtra: extra, - } -} - -func (user *User) messageAckHandler(m *discordgo.MessageAck) { - portal := user.GetExistingPortalByID(m.ChannelID) - if portal == nil || portal.MXID == "" { - return - } - dp := user.GetIDoublePuppet() - if dp == nil { - return - } - msg := user.bridge.DB.Message.GetLastByDiscordID(portal.Key, m.MessageID) - if msg == nil { - user.log.Debug(). - Str("channel_id", m.ChannelID).Str("message_id", m.MessageID). - Msg("Dropping message ack event for unknown message") - return - } - err := dp.CustomIntent().SetReadMarkers(portal.MXID, user.makeReadMarkerContent(msg.MXID)) - if err != nil { - user.log.Error().Err(err). - Str("event_id", msg.MXID.String()).Str("message_id", msg.DiscordID). - Msg("Failed to mark event as read") - } else { - user.log.Debug(). - Str("event_id", msg.MXID.String()).Str("message_id", msg.DiscordID). - Msg("Marked event as read after Discord message ack") - if user.ReadStateVersion < m.Version { - user.ReadStateVersion = m.Version - // TODO maybe don't update every time? - user.Update() - } - } -} - -func (user *User) typingStartHandler(t *discordgo.TypingStart) { - if t.UserID == user.DiscordID { - return - } - portal := user.GetExistingPortalByID(t.ChannelID) - if portal == nil || portal.MXID == "" { - return - } - targetUser := user.bridge.GetCachedUserByID(t.UserID) - if targetUser != nil { - return - } - portal.handleDiscordTyping(t) -} - -func (user *User) interactionSuccessHandler(s *discordgo.InteractionSuccess) { - user.pendingInteractionsLock.Lock() - defer user.pendingInteractionsLock.Unlock() - ce, ok := user.pendingInteractions[s.Nonce] - if !ok { - user.log.Debug().Str("nonce", s.Nonce).Str("id", s.ID).Msg("Got interaction success for unknown interaction") - } else { - user.log.Debug().Str("nonce", s.Nonce).Str("id", s.ID).Msg("Got interaction success for pending interaction") - ce.React("✅") - delete(user.pendingInteractions, s.Nonce) - } -} - -func (user *User) ensureInvited(intent *appservice.IntentAPI, roomID id.RoomID, isDirect, ignoreCache bool) bool { - if roomID == "" { - return false - } - if intent == nil { - intent = user.bridge.Bot - } - if !ignoreCache && intent.StateStore.IsInvited(roomID, user.MXID) { - return true - } - ret := false - - inviteContent := event.Content{ - Parsed: &event.MemberEventContent{ - Membership: event.MembershipInvite, - IsDirect: isDirect, - }, - Raw: map[string]interface{}{}, - } - - customPuppet := user.bridge.GetPuppetByCustomMXID(user.MXID) - if customPuppet != nil && customPuppet.CustomIntent() != nil { - inviteContent.Raw["fi.mau.will_auto_accept"] = true - } - - _, err := intent.SendStateEvent(roomID, event.StateMember, user.MXID.String(), &inviteContent) - - var httpErr mautrix.HTTPError - if err != nil && errors.As(err, &httpErr) && httpErr.RespError != nil && strings.Contains(httpErr.RespError.Err, "is already in the room") { - user.bridge.StateStore.SetMembership(roomID, user.MXID, event.MembershipJoin) - ret = true - } else if err != nil { - user.log.Error().Err(err).Str("room_id", roomID.String()).Msg("Failed to invite user to room") - } else { - ret = true - } - - if customPuppet != nil && customPuppet.CustomIntent() != nil { - err = customPuppet.CustomIntent().EnsureJoined(roomID, appservice.EnsureJoinedParams{IgnoreCache: true}) - if err != nil { - user.log.Warn().Err(err).Str("room_id", roomID.String()).Msg("Failed to auto-join room") - ret = false - } else { - ret = true - } - } - - return ret -} - -func (user *User) getDirectChats() map[id.UserID][]id.RoomID { - chats := map[id.UserID][]id.RoomID{} - - privateChats := user.bridge.DB.Portal.FindPrivateChatsOf(user.DiscordID) - for _, portal := range privateChats { - if portal.MXID != "" { - puppetMXID := user.bridge.FormatPuppetMXID(portal.Key.Receiver) - - chats[puppetMXID] = []id.RoomID{portal.MXID} - } - } - - return chats -} - -func (user *User) updateDirectChats(chats map[id.UserID][]id.RoomID) { - if !user.bridge.Config.Bridge.SyncDirectChatList { - return - } - - puppet := user.bridge.GetPuppetByMXID(user.MXID) - if puppet == nil { - return - } - - intent := puppet.CustomIntent() - if intent == nil { - return - } - - method := http.MethodPatch - if chats == nil { - chats = user.getDirectChats() - method = http.MethodPut - } - - user.log.Debug().Msg("Updating m.direct list on homeserver") - - var err error - if user.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareAsmux { - urlPath := intent.BuildURL(mautrix.ClientURLPath{"unstable", "com.beeper.asmux", "dms"}) - _, err = intent.MakeFullRequest(mautrix.FullRequest{ - Method: method, - URL: urlPath, - Headers: http.Header{"X-Asmux-Auth": {user.bridge.AS.Registration.AppToken}}, - RequestJSON: chats, - }) - } else { - existingChats := map[id.UserID][]id.RoomID{} - - err = intent.GetAccountData(event.AccountDataDirectChats.Type, &existingChats) - if err != nil { - user.log.Warn().Err(err).Msg("Failed to get m.direct event to update it") - return - } - - for userID, rooms := range existingChats { - if _, ok := user.bridge.ParsePuppetMXID(userID); !ok { - // This is not a ghost user, include it in the new list - chats[userID] = rooms - } else if _, ok := chats[userID]; !ok && method == http.MethodPatch { - // This is a ghost user, but we're not replacing the whole list, so include it too - chats[userID] = rooms - } - } - - err = intent.SetAccountData(event.AccountDataDirectChats.Type, &chats) - } - - if err != nil { - user.log.Warn().Err(err).Msg("Failed to update m.direct event") - } -} - -func (user *User) bridgeGuild(guildID string, everything bool) error { - guild := user.bridge.GetGuildByID(guildID, false) - if guild == nil { - return errors.New("guild not found") - } - meta, _ := user.Session.State.Guild(guildID) - err := guild.CreateMatrixRoom(user, meta) - if err != nil { - return err - } - log := user.log.With().Str("guild_id", guild.ID).Logger() - user.addGuildToSpace(guild, false, time.Now()) - for _, ch := range meta.Channels { - portal := user.GetPortalByMeta(ch) - if (everything && user.channelIsBridgeable(ch)) || ch.Type == discordgo.ChannelTypeGuildCategory { - err = portal.CreateMatrixRoom(user, ch) - if err != nil { - log.Error().Err(err).Str("channel_id", ch.ID). - Msg("Failed to create room for guild channel while bridging guild") - } - } - } - if everything { - guild.BridgingMode = database.GuildBridgeEverything - } else { - guild.BridgingMode = database.GuildBridgeCreateOnMessage - } - guild.Update() - - if user.Session.IsUser { - log.Debug().Msg("Subscribing to guild after bridging") - err = user.Session.SubscribeGuild(discordgo.GuildSubscribeData{ - GuildID: guild.ID, - Typing: true, - Activities: true, - Threads: true, - }) - if err != nil { - log.Warn().Err(err).Msg("Failed to subscribe to guild") - } - } - - return nil -} - -func (user *User) unbridgeGuild(guildID string) error { - if user.PermissionLevel < bridgeconfig.PermissionLevelAdmin && user.PortalHasOtherUsers(guildID) { - return errors.New("only bridge admins can unbridge guilds with other users") - } - guild := user.bridge.GetGuildByID(guildID, false) - if guild == nil { - return errors.New("guild not found") - } - guild.roomCreateLock.Lock() - defer guild.roomCreateLock.Unlock() - if guild.BridgingMode == database.GuildBridgeNothing && guild.MXID == "" { - return errors.New("that guild is not bridged") - } - guild.BridgingMode = database.GuildBridgeNothing - guild.Update() - for _, portal := range user.bridge.GetAllPortalsInGuild(guild.ID) { - portal.cleanup(false) - portal.RemoveMXID() - } - guild.cleanup() - guild.RemoveMXID() - return nil -}