1
0
Fork 0
mirror of https://github.com/mautrix/signal.git synced 2026-05-15 05:36:53 -04:00
mautrix-signal/pkg/signalmeow/senderkey.go

431 lines
16 KiB
Go

// mautrix-signal - A Matrix-signal puppeting bridge.
// Copyright (C) 2025 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package signalmeow
import (
"context"
"encoding/json"
"errors"
"fmt"
"maps"
"net/http"
"slices"
"time"
"github.com/google/uuid"
"github.com/rs/zerolog"
"go.mau.fi/util/exslices"
"google.golang.org/protobuf/proto"
"go.mau.fi/mautrix-signal/pkg/libsignalgo"
signalpb "go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf"
"go.mau.fi/mautrix-signal/pkg/signalmeow/store"
"go.mau.fi/mautrix-signal/pkg/signalmeow/types"
"go.mau.fi/mautrix-signal/pkg/signalmeow/web"
)
// SenderKeyMaxAge is how long our sender key for a group is reused. When the
// stored key is older than this, sendToGroupWithSenderKey deletes it and starts
// a new distribution. Two weeks.
const SenderKeyMaxAge = 2 * 7 * 24 * time.Hour

// contextKey is an unexported type for context value keys owned by this
// package, so they cannot collide with context keys defined elsewhere.
type contextKey int

// Context value keys used by this package.
const (
	// contextKeyEncryptionLock is stored (with value true) in the context by
	// sendToGroupWithSenderKey right after it acquires cli.encryptionLock.
	// NOTE(review): presumably read by callees to know the lock is already
	// held — confirm against its readers.
	contextKeyEncryptionLock contextKey = iota
)
// ResetSenderKey discards our sender key for the given group, if one exists.
// The next sender key send to the group then finds no sender key info and
// creates a fresh distribution.
//
// It returns the distribution ID of the discarded key, or uuid.Nil when the
// group had no sender key info. If deleting the key or its info fails, the
// distribution ID is still returned together with the error.
func (cli *Client) ResetSenderKey(ctx context.Context, groupID types.GroupIdentifier) (uuid.UUID, error) {
	cli.encryptionLock.Lock()
	defer cli.encryptionLock.Unlock()

	info, err := cli.Store.SenderKeyStore.GetSenderKeyInfo(ctx, groupID)
	if err != nil {
		return uuid.Nil, fmt.Errorf("failed to get sender key info: %w", err)
	}
	if info == nil {
		// Nothing to reset.
		return uuid.Nil, nil
	}

	ownAddress, err := cli.Store.ACIServiceID().Address(uint(cli.Store.DeviceID))
	if err != nil {
		return uuid.Nil, fmt.Errorf("failed to get own address: %w", err)
	}

	distributionID := info.DistributionID
	if err = cli.Store.SenderKeyStore.DeleteSenderKey(ctx, ownAddress, distributionID); err != nil {
		return distributionID, fmt.Errorf("failed to delete sender key: %w", err)
	}
	if err = cli.Store.SenderKeyStore.DeleteSenderKeyInfo(ctx, groupID); err != nil {
		return distributionID, fmt.Errorf("failed to delete sender key info: %w", err)
	}
	return distributionID, nil
}
// sendToGroupWithSenderKey sends content to a group. It uses a sender key
// (one multi-recipient sealed sender request) for every recipient that
// supports it, and falls back to cli.sendToGroup for everyone else.
//
// Outline:
//  1. getDevicesIDs splits recipients into sender key and fallback recipients.
//  2. The group's sender key info is loaded. A new one is created if it is
//     missing or older than SenderKeyMaxAge. The key is rotated if any
//     recipient device it was shared with is gone.
//  3. A sender key distribution message (SKDM) is sent to recipients that
//     don't have the current key yet. Recipients whose SKDM fails are moved
//     to the fallback list.
//  4. The content is encrypted once and PUT to the multi-recipient endpoint.
//  5. Fallback recipients (and our own sync copy) are sent via sendToGroup.
//
// Device-list mismatches cause a recursive retry with retries+1. These are
// 409/410 responses, or ErrDevicesChanged/ErrUnregisteredUser while sending
// SKDMs. From the third retry on, the whole message is sent via sendToGroup.
//
// cli.encryptionLock is held for the sender key part of the send. While it is
// held, ctx carries contextKeyEncryptionLock. Every call made after the lock
// is released uses the caller's original context instead.
func (cli *Client) sendToGroupWithSenderKey(
	ctx context.Context,
	groupID *libsignalgo.GroupIdentifier,
	allRecipients []libsignalgo.ServiceID,
	sec SendEndorsementCache,
	content *signalpb.Content,
	messageTimestamp uint64,
	retries int,
) (*GroupMessageSendResult, error) {
	if retries >= 3 {
		return cli.sendToGroup(ctx, allRecipients, content, messageTimestamp, nil, groupID)
	}
	myAddress, err := cli.Store.ACIServiceID().Address(uint(cli.Store.DeviceID))
	if err != nil {
		return nil, fmt.Errorf("failed to get own address: %w", err)
	}
	log := zerolog.Ctx(ctx)
	cli.encryptionLock.Lock()
	unlocked := false
	doUnlock := func() {
		if !unlocked {
			unlocked = true
			cli.encryptionLock.Unlock()
		}
	}
	defer doUnlock()
	// unlockedCtx is the caller's context without the lock marker. Anything
	// called after doUnlock() must use it. Passing the marked ctx would tell
	// callees the lock is still held when it no longer is.
	unlockedCtx := ctx
	ctx = context.WithValue(ctx, contextKeyEncryptionLock, true)

	result := &GroupMessageSendResult{
		SuccessfullySentTo: make([]SuccessfulSendResult, 0),
		FailedToSendTo:     make([]FailedSendResult, 0),
	}
	groupIDStr := types.GroupIdentifier(groupID.String())
	deviceIDs, senderKeyRecipients, fallbackRecipients := cli.getDevicesIDs(ctx, allRecipients, sec, result)
	if len(senderKeyRecipients) == 0 {
		doUnlock()
		log.Debug().Msg("No sender key recipients, falling back to normal send")
		return cli.sendToGroup(unlockedCtx, allRecipients, content, messageTimestamp, result, groupID)
	}

	ski, err := cli.Store.SenderKeyStore.GetSenderKeyInfo(ctx, groupIDStr)
	if err != nil {
		return nil, fmt.Errorf("failed to get sender key info: %w", err)
	}
	switch {
	case ski == nil:
		log.Debug().Msg("No existing sender key, creating new one")
	case time.Since(ski.CreatedAt) > SenderKeyMaxAge:
		log.Debug().Any("old_sender_key_info", ski).Msg("Sender key expired, creating new one")
		err = cli.Store.SenderKeyStore.DeleteSenderKey(ctx, myAddress, ski.DistributionID)
		if err != nil {
			return nil, fmt.Errorf("failed to delete old sender key: %w", err)
		}
		ski = nil
	default:
		log.Debug().Any("sender_key_info", ski).Msg("Reusing existing sender key")
	}
	if ski == nil {
		ski = &store.SenderKeyInfo{
			DistributionID: uuid.New(),
			CreatedAt:      time.Now(),
			SharedWith:     make(map[libsignalgo.ServiceID][]int),
		}
	} else if ski.SharedWith == nil {
		// Guard against writing to a nil map below.
		ski.SharedWith = make(map[libsignalgo.ServiceID][]int)
	}

	// The combined access key returned by diffRecipients is not used: the
	// recipient set can still shrink below if sending an SKDM fails.
	_, devicesAddedTo, removedDevices := diffRecipients(ski.SharedWith, deviceIDs)
	if len(removedDevices) > 0 {
		log.Debug().
			Any("removed_devices", removedDevices).
			Msg("Resetting sender key due to recipient device changes")
		err = cli.Store.SenderKeyStore.DeleteSenderKey(ctx, myAddress, ski.DistributionID)
		if err != nil {
			return nil, fmt.Errorf("failed to delete old sender key: %w", err)
		}
		// The new key hasn't been shared with anyone yet. Without clearing
		// this, removed recipients would stay in SharedWith and cause another
		// reset on every following send.
		ski.SharedWith = make(map[libsignalgo.ServiceID][]int)
		devicesAddedTo = slices.Collect(maps.Keys(deviceIDs))
	}
	if len(devicesAddedTo) > 0 {
		log.Debug().
			Any("devices_added_to", devicesAddedTo).
			Msg("Sending sender key distribution message to users with new devices")
		skdm, err := libsignalgo.NewSenderKeyDistributionMessage(ctx, myAddress, ski.DistributionID, cli.Store.SenderKeyStore)
		if err != nil {
			return nil, fmt.Errorf("failed to create sender key distribution message: %w", err)
		}
		skdmBytes, err := skdm.Serialize()
		if err != nil {
			return nil, fmt.Errorf("failed to serialize sender key distribution message: %w", err)
		}
		var needsRetry bool
		for _, recipient := range devicesAddedTo {
			log := log.With().Str("subaction", "skdm").Stringer("recipient_id", recipient).Logger()
			_, err = cli.sendContent(log.WithContext(ctx), recipient, messageTimestamp, &signalpb.Content{
				SenderKeyDistributionMessage: skdmBytes,
			}, 0, true, groupID, nil)
			if errors.Is(err, ErrDevicesChanged) || errors.Is(err, ErrUnregisteredUser) {
				log.Warn().Err(err).Msg("Failed to send sender key distribution message due to device changes, will retry")
				needsRetry = true
			} else if err != nil {
				log.Err(err).Msg("Failed to send sender key distribution message")
				// Move the recipient from the sender key send to the fallback send.
				fallbackRecipients = append(fallbackRecipients, recipient)
				delete(deviceIDs, recipient)
				senderKeyRecipients = slices.DeleteFunc(senderKeyRecipients, func(tuple store.SessionAddressTuple) bool {
					return tuple.ServiceID == recipient
				})
			} else {
				log.Debug().Msg("Successfully sent sender key distribution message")
				ski.SharedWith[recipient] = deviceIDs[recipient].DeviceIDs
			}
		}
		err = cli.Store.SenderKeyStore.PutSenderKeyInfo(ctx, groupIDStr, ski)
		if err != nil {
			return nil, fmt.Errorf("failed to store updated sender key info: %w", err)
		}
		if needsRetry {
			doUnlock()
			return cli.sendToGroupWithSenderKey(unlockedCtx, groupID, allRecipients, sec, content, messageTimestamp, retries+1)
		}
	}
	if len(senderKeyRecipients) == 0 {
		// Every SKDM failed, so nobody is left for the multi-recipient send.
		doUnlock()
		log.Debug().Msg("No sender key recipients left after distributing sender key, falling back to normal send")
		return cli.sendToGroup(unlockedCtx, fallbackRecipients, content, messageTimestamp, result, groupID)
	}

	ssCiphertext, err := cli.encryptWithSenderKey(ctx, groupID, ski.DistributionID, myAddress, senderKeyRecipients, content)
	if err != nil {
		return nil, err
	}
	// Only the recipients actually included in this sender key send. Anyone
	// whose SKDM failed is handled by the fallback send.
	for recipientID := range deviceIDs {
		cli.addSendCache(recipientID, groupIDStr, messageTimestamp, content)
	}
	header := http.Header{}
	header.Set("Content-Type", string(web.ContentTypeMultiRecipientMessage))
	if sec.SendEndorsement != nil {
		wantedEndorsements := make([]libsignalgo.GroupSendEndorsement, 0, len(deviceIDs))
		for serviceID := range deviceIDs {
			endorsement, ok := sec.MemberEndorsements[serviceID]
			if !ok {
				return nil, fmt.Errorf("missing group send endorsement for service ID %s", serviceID.String())
			}
			wantedEndorsements = append(wantedEndorsements, endorsement)
		}
		combinedEndorsement, err := libsignalgo.GroupSendEndorsementCombine(wantedEndorsements...)
		if err != nil {
			return nil, fmt.Errorf("failed to combine group send endorsements: %w", err)
		}
		groupSendToken, err := sec.GetTokenWith(combinedEndorsement)
		if err != nil {
			return nil, fmt.Errorf("failed to create group send full token: %w", err)
		}
		header.Set("Group-Send-Token", groupSendToken.String())
	} else {
		// XOR of the access keys of exactly the recipients that remain in this
		// send (computed after SKDM failures were removed from deviceIDs).
		var xak *libsignalgo.AccessKey
		for _, meta := range deviceIDs {
			xak = xak.Xor(meta.AccessKey)
		}
		header.Set("Unidentified-Access-Key", xak.String())
	}
	path := fmt.Sprintf(
		"/v1/messages/multi_recipient?ts=%d&urgent=%t&online=false",
		messageTimestamp, isUrgent(content),
	)
	log.Debug().
		Any("recipients", ski.SharedWith).
		Any("fallback_recipients", fallbackRecipients).
		Msg("Sending multi-recipient message with sender key")
	resp, err := cli.UnauthedWS.SendRequest(ctx, http.MethodPut, path, ssCiphertext, header)
	if err != nil {
		return nil, fmt.Errorf("failed to send multi-recipient message: %w", err)
	}
	switch resp.GetStatus() {
	case 200:
		var respData MultiRecipient200Response
		err = json.Unmarshal(resp.Body, &respData)
		if err != nil {
			return nil, fmt.Errorf("failed to unmarshal 200 response: %w", err)
		}
		log.Debug().
			Any("response_data", respData).
			Msg("Got successful multi-recipient send response")
		for serviceID := range deviceIDs {
			if slices.Contains(respData.UUIDs404, serviceID) {
				err = cli.Store.ACISessionStore.RemoveAllSessionsForServiceID(ctx, serviceID)
				if err != nil {
					log.Err(err).Stringer("recipient_id", serviceID).
						Msg("Failed to remove sessions after 404")
				}
				cli.Store.RecipientStore.MarkUnregistered(ctx, serviceID, true)
				result.FailedToSendTo = append(result.FailedToSendTo, FailedSendResult{
					Recipient: serviceID,
					Error:     fmt.Errorf("multi-recipient send 404"),
				})
			} else {
				result.SuccessfullySentTo = append(result.SuccessfullySentTo, SuccessfulSendResult{
					Recipient:    serviceID,
					Unidentified: true,
				})
			}
		}
		doUnlock()
		// Send with fallback for any recipients that couldn't do sender key, plus our own sync copy
		return cli.sendToGroup(unlockedCtx, fallbackRecipients, content, messageTimestamp, result, groupID)
	case 401, 404:
		log.Warn().Uint32("status_code", resp.GetStatus()).
			Msg("Multi-recipient send failed, falling back to normal send")
		doUnlock()
		// Fall back to normal send for all recipients
		return cli.sendToGroup(unlockedCtx, allRecipients, content, messageTimestamp, nil, groupID)
	case 409, 410:
		log.Warn().Uint32("status_code", resp.GetStatus()).
			Msg("Multi-recipient send failed due to outdated device list, refreshing and retrying")
		err = cli.handleMultiRecipient409410Response(ctx, resp)
		if err != nil {
			return nil, err
		}
		doUnlock()
		// Retry recursively after fixing device lists
		return cli.sendToGroupWithSenderKey(unlockedCtx, groupID, allRecipients, sec, content, messageTimestamp, retries+1)
	default:
		return nil, fmt.Errorf("unexpected status code %d in multi-recipient send", resp.GetStatus())
	}
}
// encryptWithSenderKey builds the body of a multi-recipient sender key send.
//
// Steps:
//  1. Serialize content and pad it with addPadding.
//  2. Encrypt it with our sender key for distributionID.
//  3. Wrap the result in unidentified-sender message content carrying our
//     sender certificate, content hint and groupID.
//  4. Seal that once for every session in senderKeyRecipients.
//
// The sealing is checked against cli.Store.ACIIdentityStore.
func (cli *Client) encryptWithSenderKey(
	ctx context.Context,
	groupID *libsignalgo.GroupIdentifier,
	distributionID uuid.UUID,
	myAddress *libsignalgo.Address,
	senderKeyRecipients []store.SessionAddressTuple,
	content *signalpb.Content,
) ([]byte, error) {
	serialized, err := proto.Marshal(content)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal content: %w", err)
	}
	padded, err := addPadding(3, serialized)
	if err != nil {
		return nil, fmt.Errorf("failed to add padding: %w", err)
	}

	groupCiphertext, err := libsignalgo.GroupEncrypt(ctx, padded, myAddress, distributionID, cli.Store.SenderKeyStore)
	if err != nil {
		return nil, fmt.Errorf("failed to encrypt group message: %w", err)
	}

	senderCert, err := cli.senderCertificate(ctx, false)
	if err != nil {
		return nil, fmt.Errorf("failed to get sender certificate: %w", err)
	}
	hint := getContentHint(content)
	envelope, err := libsignalgo.NewUnidentifiedSenderMessageContent(groupCiphertext, senderCert, hint, groupID)
	if err != nil {
		return nil, fmt.Errorf("failed to create unidentified sender message content: %w", err)
	}

	sealed, err := libsignalgo.SealedSenderMultiRecipientEncrypt(ctx, envelope, senderKeyRecipients, cli.Store.ACIIdentityStore)
	if err != nil {
		return nil, fmt.Errorf("failed to create sealed sender multi-recipient message: %w", err)
	}
	return sealed, nil
}
// diffRecipients compares the devices a sender key was previously shared with
// (prevDevices) against the current recipients (newDevices).
//
// It returns:
//   - xak: the XOR of the access keys of every recipient in newDevices.
//   - devicesAddedTo: recipients that are new, or that gained devices without
//     losing any.
//   - globalRemovedDevices: for recipients no longer present, all their
//     previous device IDs. For recipients that lost devices, only the lost
//     device IDs.
//
// The order of devicesAddedTo is unspecified (map iteration order).
func diffRecipients(
	prevDevices map[libsignalgo.ServiceID][]int,
	newDevices map[libsignalgo.ServiceID]senderKeySendMeta,
) (
	xak *libsignalgo.AccessKey,
	devicesAddedTo []libsignalgo.ServiceID,
	globalRemovedDevices map[libsignalgo.ServiceID][]int,
) {
	globalRemovedDevices = make(map[libsignalgo.ServiceID][]int)

	// Recipients that were shared with before but are gone now, e.g. someone
	// left the group.
	for serviceID, oldDeviceIDs := range prevDevices {
		if _, stillPresent := newDevices[serviceID]; !stillPresent {
			globalRemovedDevices[serviceID] = oldDeviceIDs
		}
	}

	for serviceID, meta := range newDevices {
		xak = xak.Xor(meta.AccessKey)

		oldDeviceIDs, sharedBefore := prevDevices[serviceID]
		if !sharedBefore {
			// Someone was added to the group.
			devicesAddedTo = append(devicesAddedTo, serviceID)
			continue
		}
		lost, gained := exslices.Diff(oldDeviceIDs, meta.DeviceIDs)
		switch {
		case len(lost) > 0:
			// A device was removed; this takes precedence over new devices.
			globalRemovedDevices[serviceID] = lost
		case len(gained) > 0:
			// The user got new devices.
			devicesAddedTo = append(devicesAddedTo, serviceID)
		}
	}
	return xak, devicesAddedTo, globalRemovedDevices
}
// senderKeySendMeta describes one sender key recipient, as produced by
// getDevicesIDs and consumed when sending with a sender key.
type senderKeySendMeta struct {
	// DeviceIDs are the device IDs of all sessions we have with the recipient.
	DeviceIDs []int
	// AccessKey is derived from the recipient's stored profile key. The access
	// keys of all recipients are XORed together for the
	// Unidentified-Access-Key header.
	AccessKey *libsignalgo.AccessKey
}
// getDevicesIDs works out which recipients can receive a sender key message.
//
// Our own ACI is skipped entirely, because a sync copy is sent separately.
//
// Every other recipient qualifies for sender key only if all of these hold:
//   - it is an ACI;
//   - it has an entry in sendEndorsement.MemberEndorsements;
//   - it has a stored profile key from which an access key can be derived;
//   - it has at least one session (a prekey fetch is attempted when it has
//     none).
//
// Recipients that do not qualify are returned as fallback recipients.
// Recipients whose prekey fetch reports ErrUnregisteredUser are appended to
// result.FailedToSendTo and appear in neither list.
//
// It returns:
//   - the device IDs and access key per sender key recipient;
//   - the sessions of all sender key recipients, concatenated;
//   - the fallback recipients.
func (cli *Client) getDevicesIDs(
	ctx context.Context,
	recipients []libsignalgo.ServiceID,
	sendEndorsement SendEndorsementCache,
	result *GroupMessageSendResult,
) (
	map[libsignalgo.ServiceID]senderKeySendMeta,
	[]store.SessionAddressTuple,
	[]libsignalgo.ServiceID,
) {
	log := zerolog.Ctx(ctx)

	// resolve checks a single recipient. A nil meta means the recipient must
	// use a normal send. unregistered=true means the recipient was recorded as
	// failed and must not be sent to at all.
	resolve := func(recipient libsignalgo.ServiceID) (meta *senderKeySendMeta, sessions []store.SessionAddressTuple, unregistered bool) {
		if recipient.Type != libsignalgo.ServiceIDTypeACI {
			return nil, nil, false
		}
		if _, endorsed := sendEndorsement.MemberEndorsements[recipient]; !endorsed {
			return nil, nil, false
		}
		profileKey, err := cli.Store.RecipientStore.LoadProfileKey(ctx, recipient.UUID)
		if err != nil {
			log.Err(err).Stringer("recipient_id", recipient.UUID).Msg("Failed to get profile key")
			return nil, nil, false
		}
		if profileKey == nil {
			log.Debug().Stringer("recipient_id", recipient.UUID).Msg("No profile key for recipient")
			return nil, nil, false
		}
		accessKey, err := profileKey.DeriveAccessKey()
		if err != nil {
			log.Err(err).Stringer("recipient_id", recipient.UUID).Msg("Failed to derive access key")
			return nil, nil, false
		}

		sessions, err = cli.Store.ACISessionStore.AllSessionsForServiceID(ctx, recipient)
		if err == nil && len(sessions) == 0 {
			// No sessions yet: establish them from a freshly fetched prekey.
			err = cli.FetchAndProcessPreKey(ctx, recipient, -1)
			switch {
			case errors.Is(err, ErrUnregisteredUser):
				result.FailedToSendTo = append(result.FailedToSendTo, FailedSendResult{
					Recipient: recipient,
					Error:     err,
				})
				log.Debug().
					Stringer("recipient_id", recipient).
					Msg("Recipient is not registered, won't try to send")
				return nil, nil, true
			case err != nil:
				log.Warn().Err(err).Stringer("recipient_id", recipient.UUID).Msg("Failed to fetch keys for recipient")
				return nil, nil, false
			}
			sessions, err = cli.Store.ACISessionStore.AllSessionsForServiceID(ctx, recipient)
		}
		if err != nil {
			log.Err(err).Stringer("recipient_id", recipient.UUID).Msg("Failed to get sessions for recipient")
			return nil, nil, false
		}
		if len(sessions) == 0 {
			log.Debug().Stringer("recipient_id", recipient.UUID).Msg("No sessions for recipient after fetching keys")
			return nil, nil, false
		}

		deviceIDs := make([]int, 0, len(sessions))
		for _, session := range sessions {
			deviceIDs = append(deviceIDs, session.DeviceID)
		}
		return &senderKeySendMeta{DeviceIDs: deviceIDs, AccessKey: accessKey}, sessions, false
	}

	metaByRecipient := make(map[libsignalgo.ServiceID]senderKeySendMeta)
	allSessions := make([]store.SessionAddressTuple, 0, len(recipients))
	fallback := make([]libsignalgo.ServiceID, 0)
	for _, recipient := range recipients {
		if recipient == cli.Store.ACIServiceID() {
			// Our own account gets a sync copy instead, so it is neither a
			// sender key recipient nor a fallback recipient.
			continue
		}
		meta, sessions, unregistered := resolve(recipient)
		switch {
		case unregistered:
			// Already recorded in result.FailedToSendTo; skip entirely.
		case meta == nil:
			fallback = append(fallback, recipient)
		default:
			metaByRecipient[recipient] = *meta
			allSessions = append(allSessions, sessions...)
		}
	}
	return metaByRecipient, allSessions, fallback
}