mirror of
https://github.com/mautrix/signal.git
synced 2026-05-15 05:36:53 -04:00
431 lines
16 KiB
Go
431 lines
16 KiB
Go
// mautrix-signal - A Matrix-signal puppeting bridge.
|
|
// Copyright (C) 2025 Tulir Asokan
|
|
//
|
|
// This program is free software: you can redistribute it and/or modify
|
|
// it under the terms of the GNU Affero General Public License as published by
|
|
// the Free Software Foundation, either version 3 of the License, or
|
|
// (at your option) any later version.
|
|
//
|
|
// This program is distributed in the hope that it will be useful,
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
// GNU Affero General Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU Affero General Public License
|
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
package signalmeow
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"maps"
|
|
"net/http"
|
|
"slices"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/rs/zerolog"
|
|
"go.mau.fi/util/exslices"
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
"go.mau.fi/mautrix-signal/pkg/libsignalgo"
|
|
signalpb "go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf"
|
|
"go.mau.fi/mautrix-signal/pkg/signalmeow/store"
|
|
"go.mau.fi/mautrix-signal/pkg/signalmeow/types"
|
|
"go.mau.fi/mautrix-signal/pkg/signalmeow/web"
|
|
)
|
|
|
|
// SenderKeyMaxAge is how long our own sender key for a group may be reused.
// sendToGroupWithSenderKey deletes an older key and generates a fresh one.
const SenderKeyMaxAge = 2 * 7 * 24 * time.Hour // two weeks

// contextKey is the unexported key type for values this package stores in a
// context.Context, so they cannot collide with keys defined elsewhere.
type contextKey int

const (
	// contextKeyEncryptionLock is set to true by sendToGroupWithSenderKey on
	// the context it passes down while it holds cli.encryptionLock.
	contextKeyEncryptionLock contextKey = iota
)
|
|
|
|
func (cli *Client) ResetSenderKey(ctx context.Context, groupID types.GroupIdentifier) (uuid.UUID, error) {
|
|
cli.encryptionLock.Lock()
|
|
defer cli.encryptionLock.Unlock()
|
|
info, err := cli.Store.SenderKeyStore.GetSenderKeyInfo(ctx, groupID)
|
|
if err != nil {
|
|
return uuid.Nil, fmt.Errorf("failed to get sender key info: %w", err)
|
|
} else if info == nil {
|
|
return uuid.Nil, nil
|
|
} else if myAddress, err := cli.Store.ACIServiceID().Address(uint(cli.Store.DeviceID)); err != nil {
|
|
return uuid.Nil, fmt.Errorf("failed to get own address: %w", err)
|
|
} else if err = cli.Store.SenderKeyStore.DeleteSenderKey(ctx, myAddress, info.DistributionID); err != nil {
|
|
return info.DistributionID, fmt.Errorf("failed to delete sender key: %w", err)
|
|
} else if err = cli.Store.SenderKeyStore.DeleteSenderKeyInfo(ctx, groupID); err != nil {
|
|
return info.DistributionID, fmt.Errorf("failed to delete sender key info: %w", err)
|
|
}
|
|
return info.DistributionID, nil
|
|
}
|
|
|
|
func (cli *Client) sendToGroupWithSenderKey(
|
|
ctx context.Context,
|
|
groupID *libsignalgo.GroupIdentifier,
|
|
allRecipients []libsignalgo.ServiceID,
|
|
sec SendEndorsementCache,
|
|
content *signalpb.Content,
|
|
messageTimestamp uint64,
|
|
retries int,
|
|
) (*GroupMessageSendResult, error) {
|
|
if retries >= 3 {
|
|
return cli.sendToGroup(ctx, allRecipients, content, messageTimestamp, nil, groupID)
|
|
}
|
|
myAddress, err := cli.Store.ACIServiceID().Address(uint(cli.Store.DeviceID))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get own address: %w", err)
|
|
}
|
|
log := zerolog.Ctx(ctx)
|
|
|
|
cli.encryptionLock.Lock()
|
|
unlocked := false
|
|
doUnlock := func() {
|
|
if !unlocked {
|
|
unlocked = true
|
|
cli.encryptionLock.Unlock()
|
|
}
|
|
}
|
|
defer doUnlock()
|
|
ctx = context.WithValue(ctx, contextKeyEncryptionLock, true)
|
|
result := &GroupMessageSendResult{
|
|
SuccessfullySentTo: make([]SuccessfulSendResult, 0),
|
|
FailedToSendTo: make([]FailedSendResult, 0),
|
|
}
|
|
|
|
groupIDStr := types.GroupIdentifier(groupID.String())
|
|
deviceIDs, senderKeyRecipients, fallbackRecipients := cli.getDevicesIDs(ctx, allRecipients, sec, result)
|
|
if len(senderKeyRecipients) == 0 {
|
|
doUnlock()
|
|
log.Debug().Msg("No sender key recipients, falling back to normal send")
|
|
return cli.sendToGroup(ctx, allRecipients, content, messageTimestamp, result, groupID)
|
|
}
|
|
ski, err := cli.Store.SenderKeyStore.GetSenderKeyInfo(ctx, groupIDStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get sender key info: %w", err)
|
|
} else if ski == nil || time.Since(ski.CreatedAt) > SenderKeyMaxAge {
|
|
if ski != nil && time.Since(ski.CreatedAt) > SenderKeyMaxAge {
|
|
log.Debug().Any("old_sender_key_info", ski).Msg("Sender key expired, creating new one")
|
|
err = cli.Store.SenderKeyStore.DeleteSenderKey(ctx, myAddress, ski.DistributionID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to delete old sender key: %w", err)
|
|
}
|
|
} else {
|
|
log.Debug().Msg("No existing sender key, creating new one")
|
|
}
|
|
ski = &store.SenderKeyInfo{
|
|
DistributionID: uuid.New(),
|
|
CreatedAt: time.Now(),
|
|
SharedWith: make(map[libsignalgo.ServiceID][]int),
|
|
}
|
|
} else {
|
|
log.Debug().Any("sender_key_info", ski).Msg("Reusing existing sender key")
|
|
}
|
|
xak, devicesAddedTo, removedDevices := diffRecipients(ski.SharedWith, deviceIDs)
|
|
if len(removedDevices) > 0 {
|
|
log.Debug().
|
|
Any("removed_devices", removedDevices).
|
|
Msg("Resetting sender key due to recipient device changes")
|
|
devicesAddedTo = slices.Collect(maps.Keys(deviceIDs))
|
|
err = cli.Store.SenderKeyStore.DeleteSenderKey(ctx, myAddress, ski.DistributionID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to delete old sender key: %w", err)
|
|
}
|
|
}
|
|
if len(devicesAddedTo) > 0 {
|
|
log.Debug().
|
|
Any("devices_added_to", devicesAddedTo).
|
|
Msg("Sending sender key distribution message to users with new devices")
|
|
skdm, err := libsignalgo.NewSenderKeyDistributionMessage(ctx, myAddress, ski.DistributionID, cli.Store.SenderKeyStore)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create sender key distribution message: %w", err)
|
|
}
|
|
skdmBytes, err := skdm.Serialize()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to serialize sender key distribution message: %w", err)
|
|
}
|
|
var needsRetry bool
|
|
for _, recipient := range devicesAddedTo {
|
|
log := log.With().Str("subaction", "skdm").Stringer("recipient_id", recipient).Logger()
|
|
_, err = cli.sendContent(log.WithContext(ctx), recipient, messageTimestamp, &signalpb.Content{
|
|
SenderKeyDistributionMessage: skdmBytes,
|
|
}, 0, true, groupID, nil)
|
|
if errors.Is(err, ErrDevicesChanged) || errors.Is(err, ErrUnregisteredUser) {
|
|
log.Warn().Err(err).Msg("Failed to send sender key distribution message due to device changes, will retry")
|
|
needsRetry = true
|
|
} else if err != nil {
|
|
log.Err(err).Msg("Failed to send sender key distribution message")
|
|
fallbackRecipients = append(fallbackRecipients, recipient)
|
|
delete(deviceIDs, recipient)
|
|
senderKeyRecipients = slices.DeleteFunc(senderKeyRecipients, func(tuple store.SessionAddressTuple) bool {
|
|
return tuple.ServiceID == recipient
|
|
})
|
|
} else {
|
|
log.Debug().Msg("Successfully sent sender key distribution message")
|
|
ski.SharedWith[recipient] = deviceIDs[recipient].DeviceIDs
|
|
}
|
|
}
|
|
err = cli.Store.SenderKeyStore.PutSenderKeyInfo(ctx, groupIDStr, ski)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to store updated sender key info: %w", err)
|
|
}
|
|
if needsRetry {
|
|
doUnlock()
|
|
return cli.sendToGroupWithSenderKey(ctx, groupID, allRecipients, sec, content, messageTimestamp, retries+1)
|
|
}
|
|
}
|
|
ssCiphertext, err := cli.encryptWithSenderKey(ctx, groupID, ski.DistributionID, myAddress, senderKeyRecipients, content)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for recipientID := range ski.SharedWith {
|
|
cli.addSendCache(recipientID, groupIDStr, messageTimestamp, content)
|
|
}
|
|
header := http.Header{}
|
|
header.Set("Content-Type", string(web.ContentTypeMultiRecipientMessage))
|
|
if sec.SendEndorsement != nil {
|
|
wantedEndorsements := make([]libsignalgo.GroupSendEndorsement, 0, len(deviceIDs))
|
|
for serviceID := range deviceIDs {
|
|
endorsement, ok := sec.MemberEndorsements[serviceID]
|
|
if !ok {
|
|
return nil, fmt.Errorf("missing group send endorsement for service ID %s", serviceID.String())
|
|
}
|
|
wantedEndorsements = append(wantedEndorsements, endorsement)
|
|
}
|
|
combinedEndorsement, err := libsignalgo.GroupSendEndorsementCombine(wantedEndorsements...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to combine group send endorsements: %w", err)
|
|
}
|
|
groupSendToken, err := sec.GetTokenWith(combinedEndorsement)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create group send full token: %w", err)
|
|
}
|
|
header.Set("Group-Send-Token", groupSendToken.String())
|
|
} else {
|
|
header.Set("Unidentified-Access-Key", xak.String())
|
|
}
|
|
path := fmt.Sprintf(
|
|
"/v1/messages/multi_recipient?ts=%d&urgent=%t&online=false",
|
|
messageTimestamp, isUrgent(content),
|
|
)
|
|
log.Debug().
|
|
Any("recipients", ski.SharedWith).
|
|
Any("fallback_recipients", fallbackRecipients).
|
|
Msg("Sending multi-recipient message with sender key")
|
|
resp, err := cli.UnauthedWS.SendRequest(ctx, http.MethodPut, path, ssCiphertext, header)
|
|
switch resp.GetStatus() {
|
|
case 200:
|
|
var respData MultiRecipient200Response
|
|
err = json.Unmarshal(resp.Body, &respData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal 200 response: %w", err)
|
|
}
|
|
log.Debug().
|
|
Any("response_data", respData).
|
|
Msg("Got successful multi-recipient send response")
|
|
for serviceID := range deviceIDs {
|
|
if slices.Contains(respData.UUIDs404, serviceID) {
|
|
err = cli.Store.ACISessionStore.RemoveAllSessionsForServiceID(ctx, serviceID)
|
|
if err != nil {
|
|
log.Err(err).Stringer("recipient_id", serviceID).
|
|
Msg("Failed to remove sessions after 404")
|
|
}
|
|
cli.Store.RecipientStore.MarkUnregistered(ctx, serviceID, true)
|
|
result.FailedToSendTo = append(result.FailedToSendTo, FailedSendResult{
|
|
Recipient: serviceID,
|
|
Error: fmt.Errorf("multi-recipient send 404"),
|
|
})
|
|
} else {
|
|
result.SuccessfullySentTo = append(result.SuccessfullySentTo, SuccessfulSendResult{
|
|
Recipient: serviceID,
|
|
Unidentified: true,
|
|
})
|
|
}
|
|
}
|
|
doUnlock()
|
|
// Send with fallback for any recipients that couldn't do sender key, plus our own sync copy
|
|
return cli.sendToGroup(ctx, fallbackRecipients, content, messageTimestamp, result, groupID)
|
|
case 401, 404:
|
|
log.Warn().Uint32("status_code", resp.GetStatus()).
|
|
Msg("Multi-recipient send failed, falling back to normal send")
|
|
doUnlock()
|
|
// Fall back to normal send for all recipients
|
|
return cli.sendToGroup(ctx, allRecipients, content, messageTimestamp, nil, groupID)
|
|
case 409, 410:
|
|
log.Warn().Uint32("status_code", resp.GetStatus()).
|
|
Msg("Multi-recipient send failed due to outdated device list, refreshing and retrying")
|
|
err = cli.handleMultiRecipient409410Response(ctx, resp)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
doUnlock()
|
|
// Retry recursively after fixing device lists
|
|
return cli.sendToGroupWithSenderKey(ctx, groupID, allRecipients, sec, content, messageTimestamp, retries+1)
|
|
default:
|
|
return nil, fmt.Errorf("unexpected status code %d in multi-recipient send", resp.GetStatus())
|
|
}
|
|
}
|
|
|
|
func (cli *Client) encryptWithSenderKey(
|
|
ctx context.Context,
|
|
groupID *libsignalgo.GroupIdentifier,
|
|
distributionID uuid.UUID,
|
|
myAddress *libsignalgo.Address,
|
|
senderKeyRecipients []store.SessionAddressTuple,
|
|
content *signalpb.Content,
|
|
) ([]byte, error) {
|
|
plaintext, err := proto.Marshal(content)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal content: %w", err)
|
|
}
|
|
plaintext, err = addPadding(3, plaintext)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to add padding: %w", err)
|
|
}
|
|
ciphertext, err := libsignalgo.GroupEncrypt(ctx, plaintext, myAddress, distributionID, cli.Store.SenderKeyStore)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to encrypt group message: %w", err)
|
|
}
|
|
cert, err := cli.senderCertificate(ctx, false)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get sender certificate: %w", err)
|
|
}
|
|
usmc, err := libsignalgo.NewUnidentifiedSenderMessageContent(ciphertext, cert, getContentHint(content), groupID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create unidentified sender message content: %w", err)
|
|
}
|
|
ssCiphertext, err := libsignalgo.SealedSenderMultiRecipientEncrypt(ctx, usmc, senderKeyRecipients, cli.Store.ACIIdentityStore)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create sealed sender multi-recipient message: %w", err)
|
|
}
|
|
return ssCiphertext, nil
|
|
}
|
|
|
|
func diffRecipients(
|
|
prevDevices map[libsignalgo.ServiceID][]int,
|
|
newDevices map[libsignalgo.ServiceID]senderKeySendMeta,
|
|
) (
|
|
xak *libsignalgo.AccessKey,
|
|
devicesAddedTo []libsignalgo.ServiceID,
|
|
globalRemovedDevices map[libsignalgo.ServiceID][]int,
|
|
) {
|
|
collector := make(map[libsignalgo.ServiceID]uint8, max(len(prevDevices), len(newDevices)))
|
|
for key := range prevDevices {
|
|
collector[key] |= 0b01
|
|
}
|
|
for key := range newDevices {
|
|
collector[key] |= 0b10
|
|
}
|
|
globalRemovedDevices = make(map[libsignalgo.ServiceID][]int)
|
|
for serviceID, mask := range collector {
|
|
if mask != 0b01 {
|
|
xak = xak.Xor(newDevices[serviceID].AccessKey)
|
|
}
|
|
switch mask {
|
|
case 0b01:
|
|
// Someone left the group
|
|
globalRemovedDevices[serviceID] = prevDevices[serviceID]
|
|
case 0b10:
|
|
// Someone was added to the group
|
|
devicesAddedTo = append(devicesAddedTo, serviceID)
|
|
case 0b11:
|
|
removedDevices, addedDevices := exslices.Diff(prevDevices[serviceID], newDevices[serviceID].DeviceIDs)
|
|
if len(removedDevices) > 0 {
|
|
// Device was removed
|
|
globalRemovedDevices[serviceID] = removedDevices
|
|
} else if len(addedDevices) > 0 {
|
|
// User got new devices
|
|
devicesAddedTo = append(devicesAddedTo, serviceID)
|
|
}
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
type senderKeySendMeta struct {
|
|
DeviceIDs []int
|
|
AccessKey *libsignalgo.AccessKey
|
|
}
|
|
|
|
func (cli *Client) getDevicesIDs(
|
|
ctx context.Context,
|
|
recipients []libsignalgo.ServiceID,
|
|
sendEndorsement SendEndorsementCache,
|
|
result *GroupMessageSendResult,
|
|
) (
|
|
map[libsignalgo.ServiceID]senderKeySendMeta,
|
|
[]store.SessionAddressTuple,
|
|
[]libsignalgo.ServiceID,
|
|
) {
|
|
log := zerolog.Ctx(ctx)
|
|
out := make(map[libsignalgo.ServiceID]senderKeySendMeta)
|
|
senderKeyRecipients := make([]store.SessionAddressTuple, 0, len(recipients))
|
|
fallbackRecipients := make([]libsignalgo.ServiceID, 0)
|
|
for _, recipient := range recipients {
|
|
if recipient == cli.Store.ACIServiceID() {
|
|
// We'll send a sync copy to ourselves, not sender key and no need to include in fallback recipients either
|
|
continue
|
|
}
|
|
fallbackRecipients = append(fallbackRecipients, recipient)
|
|
if recipient.Type != libsignalgo.ServiceIDTypeACI {
|
|
continue
|
|
}
|
|
_, hasEndorsement := sendEndorsement.MemberEndorsements[recipient]
|
|
if !hasEndorsement {
|
|
continue
|
|
}
|
|
profileKey, err := cli.Store.RecipientStore.LoadProfileKey(ctx, recipient.UUID)
|
|
if err != nil {
|
|
log.Err(err).Stringer("recipient_id", recipient.UUID).Msg("Failed to get profile key")
|
|
continue
|
|
} else if profileKey == nil {
|
|
log.Debug().Stringer("recipient_id", recipient.UUID).Msg("No profile key for recipient")
|
|
continue
|
|
}
|
|
accessKey, err := profileKey.DeriveAccessKey()
|
|
if err != nil {
|
|
log.Err(err).Stringer("recipient_id", recipient.UUID).Msg("Failed to derive access key")
|
|
continue
|
|
}
|
|
sessions, err := cli.Store.ACISessionStore.AllSessionsForServiceID(ctx, recipient)
|
|
if err == nil && len(sessions) == 0 {
|
|
// No sessions, make one with prekey
|
|
err = cli.FetchAndProcessPreKey(ctx, recipient, -1)
|
|
if errors.Is(err, ErrUnregisteredUser) {
|
|
fallbackRecipients = fallbackRecipients[:len(fallbackRecipients)-1]
|
|
result.FailedToSendTo = append(result.FailedToSendTo, FailedSendResult{
|
|
Recipient: recipient,
|
|
Error: err,
|
|
})
|
|
log.Debug().
|
|
Stringer("recipient_id", recipient).
|
|
Msg("Recipient is not registered, won't try to send")
|
|
continue
|
|
} else if err != nil {
|
|
log.Warn().Err(err).Stringer("recipient_id", recipient.UUID).Msg("Failed to fetch keys for recipient")
|
|
continue
|
|
}
|
|
sessions, err = cli.Store.ACISessionStore.AllSessionsForServiceID(ctx, recipient)
|
|
}
|
|
if err != nil {
|
|
log.Err(err).Stringer("recipient_id", recipient.UUID).Msg("Failed to get sessions for recipient")
|
|
continue
|
|
} else if len(sessions) == 0 {
|
|
log.Debug().Stringer("recipient_id", recipient.UUID).Msg("No sessions for recipient after fetching keys")
|
|
continue
|
|
}
|
|
fallbackRecipients = fallbackRecipients[:len(fallbackRecipients)-1]
|
|
out[recipient] = senderKeySendMeta{
|
|
DeviceIDs: exslices.CastFunc(sessions, func(from store.SessionAddressTuple) int {
|
|
return from.DeviceID
|
|
}),
|
|
AccessKey: accessKey,
|
|
}
|
|
senderKeyRecipients = append(senderKeyRecipients, sessions...)
|
|
}
|
|
return out, senderKeyRecipients, fallbackRecipients
|
|
}
|