1
0
Fork 0
mirror of https://github.com/mautrix/signal.git synced 2026-05-14 13:16:54 -04:00

Compare commits

...

1 commit

Author SHA1 Message Date
Adam Van Ymeren
d95cb01d83 Cancel Signal startup sync on disconnect 2026-04-08 15:01:27 -07:00
4 changed files with 341 additions and 93 deletions

View file

@ -29,6 +29,7 @@ import (
"go.mau.fi/mautrix-signal/pkg/libsignalgo"
"go.mau.fi/mautrix-signal/pkg/signalid"
"go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf/backuppb"
"go.mau.fi/mautrix-signal/pkg/signalmeow/store"
"go.mau.fi/mautrix-signal/pkg/signalmeow/types"
)
@ -65,95 +66,109 @@ func (s *SignalClient) syncChats(ctx context.Context) {
}
zerolog.Ctx(ctx).Info().Int("chat_count", len(chats)).Msg("Fetched chats to sync from database")
for _, chat := range chats {
recipient, err := s.Client.Store.BackupStore.GetBackupRecipient(ctx, chat.RecipientId)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get recipient for chat")
continue
}
resyncEvt := &simplevent.ChatResync{
EventMeta: simplevent.EventMeta{
Type: bridgev2.RemoteEventChatResync,
LogContext: func(c zerolog.Context) zerolog.Context {
return c.
Int("message_count", chat.TotalMessages).
Uint64("backup_chat_id", chat.Id).
Uint64("backup_recipient_id", chat.RecipientId)
},
CreatePortal: true,
},
LatestMessageTS: time.UnixMilli(int64(chat.LatestMessageID)),
}
switch dest := recipient.Destination.(type) {
case *backuppb.Recipient_Contact:
aci := tryCastUUID(dest.Contact.GetAci())
pni := tryCastUUID(dest.Contact.GetPni())
if chat.TotalMessages == 0 {
zerolog.Ctx(ctx).Debug().
Stringer("aci", aci).
Stringer("pni", pni).
Uint64("e164", dest.Contact.GetE164()).
Msg("Skipping direct chat with no messages and deleting data")
err = s.Client.Store.BackupStore.DeleteBackupChat(ctx, chat.Id)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to delete chat from backup store")
}
continue
}
processedRecipient, err := s.Client.Store.RecipientStore.LoadAndUpdateRecipient(ctx, aci, pni, nil)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get full recipient data")
continue
}
dmInfo := s.makeCreateDMResponse(ctx, processedRecipient, chat)
resyncEvt.PortalKey = dmInfo.PortalKey
resyncEvt.ChatInfo = dmInfo.PortalInfo
case *backuppb.Recipient_Self:
processedRecipient, err := s.Client.Store.RecipientStore.LoadAndUpdateRecipient(ctx, s.Client.Store.ACI, uuid.Nil, nil)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get full recipient data")
continue
}
dmInfo := s.makeCreateDMResponse(ctx, processedRecipient, chat)
resyncEvt.PortalKey = dmInfo.PortalKey
resyncEvt.ChatInfo = dmInfo.PortalInfo
case *backuppb.Recipient_Group:
if len(dest.Group.MasterKey) != libsignalgo.GroupMasterKeyLength {
continue
}
rawGroupID, err := libsignalgo.GroupMasterKey(dest.Group.MasterKey).GroupIdentifier()
if err != nil {
zerolog.Ctx(ctx).Err(err).
Uint64("recipient_id", recipient.Id).
Msg("Failed to get group identifier from master key")
continue
}
groupID := types.GroupIdentifier(base64.StdEncoding.EncodeToString(rawGroupID[:]))
groupInfo, err := s.getGroupInfo(ctx, groupID, dest.Group.GetSnapshot().GetVersion(), chat)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get full group info")
continue
}
resyncEvt.PortalKey = s.makePortalKey(string(groupID))
resyncEvt.ChatInfo = groupInfo
default:
zerolog.Ctx(ctx).Debug().
Type("destination_type", dest).
Uint64("backup_chat_id", chat.Id).
Uint64("backup_recipient_id", chat.RecipientId).
Msg("Ignoring and deleting chat with unsupported destination type")
err = s.Client.Store.BackupStore.DeleteBackupChat(ctx, chat.Id)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to delete chat from backup store")
}
continue
}
if !s.UserLogin.QueueRemoteEvent(resyncEvt).Success {
if !s.syncChat(ctx, chat) {
return
}
}
// TODO if Save fails, ChatsSynced remains true in memory even though it wasn't persisted.
// Fixing that properly likely needs a broader metadata mutation/rollback pattern.
s.UserLogin.Metadata.(*signalid.UserLoginMetadata).ChatsSynced = true
err = s.UserLogin.Save(ctx)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to save user login metadata after syncing chats")
}
}
func (s *SignalClient) syncChat(ctx context.Context, chat *store.BackupChat) bool {
recipient, err := s.Client.Store.BackupStore.GetBackupRecipient(ctx, chat.RecipientId)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get recipient for chat")
return ctx.Err() == nil
} else if recipient == nil {
zerolog.Ctx(ctx).Debug().
Uint64("backup_chat_id", chat.Id).
Uint64("backup_recipient_id", chat.RecipientId).
Msg("Skipping chat with missing backup recipient")
return true
}
resyncEvt := &simplevent.ChatResync{
EventMeta: simplevent.EventMeta{
Type: bridgev2.RemoteEventChatResync,
LogContext: func(c zerolog.Context) zerolog.Context {
return c.
Int("message_count", chat.TotalMessages).
Uint64("backup_chat_id", chat.Id).
Uint64("backup_recipient_id", chat.RecipientId)
},
CreatePortal: true,
},
LatestMessageTS: time.UnixMilli(int64(chat.LatestMessageID)),
}
switch dest := recipient.Destination.(type) {
case *backuppb.Recipient_Contact:
aci := tryCastUUID(dest.Contact.GetAci())
pni := tryCastUUID(dest.Contact.GetPni())
if chat.TotalMessages == 0 {
zerolog.Ctx(ctx).Debug().
Stringer("aci", aci).
Stringer("pni", pni).
Uint64("e164", dest.Contact.GetE164()).
Msg("Skipping direct chat with no messages and deleting data")
err = s.Client.Store.BackupStore.DeleteBackupChat(ctx, chat.Id)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to delete chat from backup store")
return ctx.Err() == nil
}
return true
}
processedRecipient, err := s.Client.Store.RecipientStore.LoadAndUpdateRecipient(ctx, aci, pni, nil)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get full recipient data")
return ctx.Err() == nil
}
dmInfo := s.makeCreateDMResponse(ctx, processedRecipient, chat)
resyncEvt.PortalKey = dmInfo.PortalKey
resyncEvt.ChatInfo = dmInfo.PortalInfo
case *backuppb.Recipient_Self:
processedRecipient, err := s.Client.Store.RecipientStore.LoadAndUpdateRecipient(ctx, s.Client.Store.ACI, uuid.Nil, nil)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get full recipient data")
return ctx.Err() == nil
}
dmInfo := s.makeCreateDMResponse(ctx, processedRecipient, chat)
resyncEvt.PortalKey = dmInfo.PortalKey
resyncEvt.ChatInfo = dmInfo.PortalInfo
case *backuppb.Recipient_Group:
if len(dest.Group.MasterKey) != libsignalgo.GroupMasterKeyLength {
return true
}
rawGroupID, err := libsignalgo.GroupMasterKey(dest.Group.MasterKey).GroupIdentifier()
if err != nil {
zerolog.Ctx(ctx).Err(err).
Uint64("recipient_id", recipient.Id).
Msg("Failed to get group identifier from master key")
return true
}
groupID := types.GroupIdentifier(base64.StdEncoding.EncodeToString(rawGroupID[:]))
groupInfo, err := s.getGroupInfo(ctx, groupID, dest.Group.GetSnapshot().GetVersion(), chat)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get full group info")
return ctx.Err() == nil
}
resyncEvt.PortalKey = s.makePortalKey(string(groupID))
resyncEvt.ChatInfo = groupInfo
default:
zerolog.Ctx(ctx).Debug().
Type("destination_type", dest).
Uint64("backup_chat_id", chat.Id).
Uint64("backup_recipient_id", chat.RecipientId).
Msg("Ignoring and deleting chat with unsupported destination type")
err = s.Client.Store.BackupStore.DeleteBackupChat(ctx, chat.Id)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to delete chat from backup store")
return ctx.Err() == nil
}
return true
}
return s.UserLogin.QueueRemoteEvent(resyncEvt).Success
}

View file

@ -18,7 +18,9 @@ package connector
import (
"context"
"errors"
"fmt"
"sync/atomic"
"time"
"github.com/rs/zerolog"
@ -39,6 +41,7 @@ type SignalClient struct {
Ghost *bridgev2.Ghost
queueEmptyWaiter *exsync.Event
lifecycleCancel atomic.Pointer[context.CancelFunc]
}
var (
@ -75,6 +78,7 @@ func (s *SignalClient) RegisterPushNotifications(ctx context.Context, pushType b
}
func (s *SignalClient) LogoutRemote(ctx context.Context) {
s.cancelLifecycleContext()
if s.Client == nil {
return
}
@ -182,6 +186,7 @@ func (s *SignalClient) bridgeStateLoop(statusChan <-chan signalmeow.SignalConnec
} else {
s.UserLogin.BridgeState.Send(status.BridgeState{StateEvent: status.StateBadCredentials, Message: err.Error()})
}
s.cancelLifecycleContext()
err = s.Client.ClearKeysAndDisconnect(context.TODO())
if err != nil {
s.UserLogin.Log.Error().Err(err).Msg("Failed to clear keys and disconnect")
@ -206,15 +211,6 @@ func (s *SignalClient) bridgeStateLoop(statusChan <-chan signalmeow.SignalConnec
}
}
func (s *SignalClient) Connect(ctx context.Context) {
if s.Client == nil {
s.UserLogin.BridgeState.Send(status.BridgeState{StateEvent: status.StateBadCredentials, Message: "You're not logged into Signal"})
return
}
s.updateRemoteProfile(ctx, false)
s.tryConnect(ctx, 0, true)
}
func (s *SignalClient) ConnectBackground(ctx context.Context, _ *bridgev2.ConnectBackgroundParams) error {
s.queueEmptyWaiter.Clear()
ch, unauthCh, err := s.Client.StartWebsockets(ctx)
@ -271,6 +267,7 @@ func (s *SignalClient) ConnectBackground(ctx context.Context, _ *bridgev2.Connec
}
func (s *SignalClient) Disconnect() {
s.cancelLifecycleContext()
if s.Client == nil {
return
}
@ -281,7 +278,7 @@ func (s *SignalClient) Disconnect() {
}
func (s *SignalClient) postLoginConnect() {
ctx := s.UserLogin.Log.WithContext(context.Background())
ctx := s.newLifecycleContext(s.UserLogin.Log.WithContext(s.UserLogin.Bridge.BackgroundCtx))
// TODO it would be more proper to only connect after syncing,
// but currently syncing will fetch group info online, so it has to be connected.
s.tryConnect(ctx, 0, false)
@ -300,11 +297,19 @@ func (s *SignalClient) postLoginConnect() {
}
func (s *SignalClient) tryConnect(ctx context.Context, retryCount int, doSync bool) {
if ctx.Err() != nil {
zerolog.Ctx(ctx).Debug().Err(ctx.Err()).Msg("Context canceled before starting receive loops")
return
}
if retryCount == 0 {
s.UserLogin.BridgeState.Send(status.BridgeState{StateEvent: status.StateConnecting})
}
ch, err := s.Client.StartReceiveLoops(ctx)
if err != nil {
if contextStopped(ctx, err) {
zerolog.Ctx(ctx).Debug().Err(err).Msg("Context canceled while starting receive loops")
return
}
zerolog.Ctx(ctx).Err(err).Msg("Failed to start receive loops")
s.UserLogin.BridgeState.Send(status.BridgeState{StateEvent: status.StateTransientDisconnect, Error: "unknown-websocket-error", Message: err.Error()})
retryInSeconds := 2 << retryCount
@ -315,7 +320,7 @@ func (s *SignalClient) tryConnect(ctx context.Context, retryCount int, doSync bo
select {
case <-time.After(time.Duration(retryInSeconds) * time.Second):
case <-ctx.Done():
zerolog.Ctx(ctx).Info().Msg("Context canceled, exit tryConnect")
zerolog.Ctx(ctx).Debug().Msg("Context canceled, exit tryConnect")
return
}
s.tryConnect(ctx, retryCount+1, doSync)
@ -333,3 +338,31 @@ func (s *SignalClient) IsLoggedIn() bool {
}
return s.Client.IsLoggedIn()
}
func (s *SignalClient) Connect(ctx context.Context) {
if s.Client == nil {
s.UserLogin.BridgeState.Send(status.BridgeState{StateEvent: status.StateBadCredentials, Message: "You're not logged into Signal"})
return
}
ctx = s.newLifecycleContext(s.UserLogin.Log.WithContext(ctx))
s.updateRemoteProfile(ctx, false)
s.tryConnect(ctx, 0, true)
}
func (s *SignalClient) newLifecycleContext(parent context.Context) context.Context {
ctx, cancel := context.WithCancel(parent)
if oldCancel := s.lifecycleCancel.Swap(&cancel); oldCancel != nil {
(*oldCancel)()
}
return ctx
}
func (s *SignalClient) cancelLifecycleContext() {
if cancel := s.lifecycleCancel.Swap(nil); cancel != nil {
(*cancel)()
}
}
func contextStopped(ctx context.Context, err error) bool {
return ctx.Err() != nil || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)
}

View file

@ -0,0 +1,196 @@
package connector
import (
"context"
"testing"
"time"
"github.com/rs/zerolog"
"maunium.net/go/mautrix/bridgev2"
bridgev2database "maunium.net/go/mautrix/bridgev2/database"
"go.mau.fi/mautrix-signal/pkg/libsignalgo"
"go.mau.fi/mautrix-signal/pkg/signalid"
"go.mau.fi/mautrix-signal/pkg/signalmeow"
"go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf/backuppb"
signalstore "go.mau.fi/mautrix-signal/pkg/signalmeow/store"
"go.mau.fi/mautrix-signal/pkg/signalmeow/types"
)
func TestLifecycleContextReplacementCancelsPrevious(t *testing.T) {
client := &SignalClient{}
firstCtx := client.newLifecycleContext(context.Background())
secondCtx := client.newLifecycleContext(context.Background())
select {
case <-firstCtx.Done():
default:
t.Fatal("expected previous lifecycle context to be canceled")
}
select {
case <-secondCtx.Done():
t.Fatal("expected current lifecycle context to remain active")
default:
}
client.cancelLifecycleContext()
select {
case <-secondCtx.Done():
case <-time.After(time.Second):
t.Fatal("expected lifecycle context to be canceled")
}
}
func TestSyncChatsStopsOnContextCancellation(t *testing.T) {
recipientLookupStarted := make(chan struct{})
backupStore := &backupStoreStub{
getBackupChatsFn: func(context.Context) ([]*signalstore.BackupChat, error) {
return []*signalstore.BackupChat{{
Chat: &backuppb.Chat{
Id: 1,
RecipientId: 2,
},
}}, nil
},
getBackupRecipientFn: func(ctx context.Context, recipientID uint64) (*backuppb.Recipient, error) {
close(recipientLookupStarted)
<-ctx.Done()
return nil, ctx.Err()
},
}
client := &SignalClient{
UserLogin: newTestUserLogin(),
Client: &signalmeow.Client{
Store: &signalstore.Device{
BackupStore: backupStore,
},
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
done := make(chan struct{})
go func() {
client.syncChats(ctx)
close(done)
}()
select {
case <-recipientLookupStarted:
case <-time.After(time.Second):
t.Fatal("timed out waiting for backup recipient lookup")
}
cancel()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("syncChats did not exit after context cancellation")
}
if client.UserLogin.Metadata.(*signalid.UserLoginMetadata).ChatsSynced {
t.Fatal("expected chat sync to stop before marking metadata as synced")
}
}
func TestSyncChatSkipsMissingBackupRecipient(t *testing.T) {
backupStore := &backupStoreStub{
getBackupRecipientFn: func(context.Context, uint64) (*backuppb.Recipient, error) {
return nil, nil
},
}
client := &SignalClient{
Client: &signalmeow.Client{
Store: &signalstore.Device{
BackupStore: backupStore,
},
},
}
ok := client.syncChat(context.Background(), &signalstore.BackupChat{
Chat: &backuppb.Chat{
Id: 1,
RecipientId: 2,
},
})
if !ok {
t.Fatal("expected missing backup recipient to be skipped")
}
}
func newTestUserLogin() *bridgev2.UserLogin {
return &bridgev2.UserLogin{
UserLogin: &bridgev2database.UserLogin{
Metadata: &signalid.UserLoginMetadata{},
},
Log: zerolog.Nop(),
}
}
type backupStoreStub struct {
getBackupChatsFn func(context.Context) ([]*signalstore.BackupChat, error)
getBackupRecipientFn func(context.Context, uint64) (*backuppb.Recipient, error)
deleteBackupChatFn func(context.Context, uint64) error
}
func (b *backupStoreStub) AddBackupRecipient(context.Context, *backuppb.Recipient) error {
return nil
}
func (b *backupStoreStub) AddBackupChat(context.Context, *backuppb.Chat) error {
return nil
}
func (b *backupStoreStub) AddBackupChatItem(context.Context, *backuppb.ChatItem) error {
return nil
}
func (b *backupStoreStub) RecalculateChatCounts(context.Context) error {
return nil
}
func (b *backupStoreStub) ClearBackup(context.Context) error {
return nil
}
func (b *backupStoreStub) GetBackupRecipient(ctx context.Context, recipientID uint64) (*backuppb.Recipient, error) {
if b.getBackupRecipientFn != nil {
return b.getBackupRecipientFn(ctx, recipientID)
}
return nil, nil
}
func (b *backupStoreStub) GetBackupChatByUserID(context.Context, libsignalgo.ServiceID) (*signalstore.BackupChat, error) {
return nil, nil
}
func (b *backupStoreStub) GetBackupChatByGroupID(context.Context, types.GroupIdentifier) (*signalstore.BackupChat, error) {
return nil, nil
}
func (b *backupStoreStub) GetBackupChats(ctx context.Context) ([]*signalstore.BackupChat, error) {
if b.getBackupChatsFn != nil {
return b.getBackupChatsFn(ctx)
}
return nil, nil
}
func (b *backupStoreStub) GetBackupChatItems(context.Context, uint64, time.Time, bool, int) ([]*backuppb.ChatItem, error) {
return nil, nil
}
func (b *backupStoreStub) DeleteBackupChat(ctx context.Context, chatID uint64) error {
if b.deleteBackupChatFn != nil {
return b.deleteBackupChatFn(ctx, chatID)
}
return nil
}
func (b *backupStoreStub) DeleteBackupChatItems(context.Context, uint64, time.Time) error {
return nil
}

View file

@ -282,7 +282,11 @@ func (cli *Client) WaitForTransfer(ctx context.Context) (*TransferArchiveMetadat
}
reqDuration := time.Since(reqStart)
if reqDuration < reqTimeout-10*time.Second {
time.Sleep(15 * time.Second)
select {
case <-time.After(15 * time.Second):
case <-ctx.Done():
return nil, ctx.Err()
}
}
}
}