From d95cb01d83c841fdd3d4fc20e79e9d7e53f4697d Mon Sep 17 00:00:00 2001 From: Adam Van Ymeren Date: Wed, 8 Apr 2026 15:01:27 -0700 Subject: [PATCH] Cancel Signal startup sync on disconnect --- pkg/connector/chatsync.go | 181 +++++++++++++++++--------------- pkg/connector/client.go | 55 ++++++++-- pkg/connector/client_test.go | 196 +++++++++++++++++++++++++++++++++++ pkg/signalmeow/backup.go | 6 +- 4 files changed, 343 insertions(+), 95 deletions(-) create mode 100644 pkg/connector/client_test.go diff --git a/pkg/connector/chatsync.go b/pkg/connector/chatsync.go index 5211270..cadddf3 100644 --- a/pkg/connector/chatsync.go +++ b/pkg/connector/chatsync.go @@ -29,6 +29,7 @@ import ( "go.mau.fi/mautrix-signal/pkg/libsignalgo" "go.mau.fi/mautrix-signal/pkg/signalid" "go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf/backuppb" + "go.mau.fi/mautrix-signal/pkg/signalmeow/store" "go.mau.fi/mautrix-signal/pkg/signalmeow/types" ) @@ -65,95 +66,109 @@ func (s *SignalClient) syncChats(ctx context.Context) { } zerolog.Ctx(ctx).Info().Int("chat_count", len(chats)).Msg("Fetched chats to sync from database") for _, chat := range chats { - recipient, err := s.Client.Store.BackupStore.GetBackupRecipient(ctx, chat.RecipientId) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get recipient for chat") - continue - } - resyncEvt := &simplevent.ChatResync{ - EventMeta: simplevent.EventMeta{ - Type: bridgev2.RemoteEventChatResync, - LogContext: func(c zerolog.Context) zerolog.Context { - return c. - Int("message_count", chat.TotalMessages). - Uint64("backup_chat_id", chat.Id). - Uint64("backup_recipient_id", chat.RecipientId) - }, - CreatePortal: true, - }, - LatestMessageTS: time.UnixMilli(int64(chat.LatestMessageID)), - } - switch dest := recipient.Destination.(type) { - case *backuppb.Recipient_Contact: - aci := tryCastUUID(dest.Contact.GetAci()) - pni := tryCastUUID(dest.Contact.GetPni()) - if chat.TotalMessages == 0 { - zerolog.Ctx(ctx).Debug(). - Stringer("aci", aci). - Stringer("pni", pni). - Uint64("e164", dest.Contact.GetE164()). - Msg("Skipping direct chat with no messages and deleting data") - err = s.Client.Store.BackupStore.DeleteBackupChat(ctx, chat.Id) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to delete chat from backup store") - } - continue - } - processedRecipient, err := s.Client.Store.RecipientStore.LoadAndUpdateRecipient(ctx, aci, pni, nil) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get full recipient data") - continue - } - dmInfo := s.makeCreateDMResponse(ctx, processedRecipient, chat) - resyncEvt.PortalKey = dmInfo.PortalKey - resyncEvt.ChatInfo = dmInfo.PortalInfo - case *backuppb.Recipient_Self: - processedRecipient, err := s.Client.Store.RecipientStore.LoadAndUpdateRecipient(ctx, s.Client.Store.ACI, uuid.Nil, nil) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get full recipient data") - continue - } - dmInfo := s.makeCreateDMResponse(ctx, processedRecipient, chat) - resyncEvt.PortalKey = dmInfo.PortalKey - resyncEvt.ChatInfo = dmInfo.PortalInfo - case *backuppb.Recipient_Group: - if len(dest.Group.MasterKey) != libsignalgo.GroupMasterKeyLength { - continue - } - rawGroupID, err := libsignalgo.GroupMasterKey(dest.Group.MasterKey).GroupIdentifier() - if err != nil { - zerolog.Ctx(ctx).Err(err). - Uint64("recipient_id", recipient.Id). - Msg("Failed to get group identifier from master key") - continue - } - groupID := types.GroupIdentifier(base64.StdEncoding.EncodeToString(rawGroupID[:])) - groupInfo, err := s.getGroupInfo(ctx, groupID, dest.Group.GetSnapshot().GetVersion(), chat) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get full group info") - continue - } - resyncEvt.PortalKey = s.makePortalKey(string(groupID)) - resyncEvt.ChatInfo = groupInfo - default: - zerolog.Ctx(ctx).Debug(). - Type("destination_type", dest). - Uint64("backup_chat_id", chat.Id). - Uint64("backup_recipient_id", chat.RecipientId). - Msg("Ignoring and deleting chat with unsupported destination type") - err = s.Client.Store.BackupStore.DeleteBackupChat(ctx, chat.Id) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to delete chat from backup store") - } - continue - } - if !s.UserLogin.QueueRemoteEvent(resyncEvt).Success { + if !s.syncChat(ctx, chat) { return } } + // TODO if Save fails, ChatsSynced remains true in memory even though it wasn't persisted. + // Fixing that properly likely needs a broader metadata mutation/rollback pattern. s.UserLogin.Metadata.(*signalid.UserLoginMetadata).ChatsSynced = true err = s.UserLogin.Save(ctx) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to save user login metadata after syncing chats") } } + +func (s *SignalClient) syncChat(ctx context.Context, chat *store.BackupChat) bool { + recipient, err := s.Client.Store.BackupStore.GetBackupRecipient(ctx, chat.RecipientId) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get recipient for chat") + return ctx.Err() == nil + } else if recipient == nil { + zerolog.Ctx(ctx).Debug(). + Uint64("backup_chat_id", chat.Id). + Uint64("backup_recipient_id", chat.RecipientId). + Msg("Skipping chat with missing backup recipient") + return true + } + resyncEvt := &simplevent.ChatResync{ + EventMeta: simplevent.EventMeta{ + Type: bridgev2.RemoteEventChatResync, + LogContext: func(c zerolog.Context) zerolog.Context { + return c. + Int("message_count", chat.TotalMessages). + Uint64("backup_chat_id", chat.Id). + Uint64("backup_recipient_id", chat.RecipientId) + }, + CreatePortal: true, + }, + LatestMessageTS: time.UnixMilli(int64(chat.LatestMessageID)), + } + switch dest := recipient.Destination.(type) { + case *backuppb.Recipient_Contact: + aci := tryCastUUID(dest.Contact.GetAci()) + pni := tryCastUUID(dest.Contact.GetPni()) + if chat.TotalMessages == 0 { + zerolog.Ctx(ctx).Debug(). + Stringer("aci", aci). + Stringer("pni", pni). + Uint64("e164", dest.Contact.GetE164()). + Msg("Skipping direct chat with no messages and deleting data") + err = s.Client.Store.BackupStore.DeleteBackupChat(ctx, chat.Id) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to delete chat from backup store") + return ctx.Err() == nil + } + return true + } + processedRecipient, err := s.Client.Store.RecipientStore.LoadAndUpdateRecipient(ctx, aci, pni, nil) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get full recipient data") + return ctx.Err() == nil + } + dmInfo := s.makeCreateDMResponse(ctx, processedRecipient, chat) + resyncEvt.PortalKey = dmInfo.PortalKey + resyncEvt.ChatInfo = dmInfo.PortalInfo + case *backuppb.Recipient_Self: + processedRecipient, err := s.Client.Store.RecipientStore.LoadAndUpdateRecipient(ctx, s.Client.Store.ACI, uuid.Nil, nil) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get full recipient data") + return ctx.Err() == nil + } + dmInfo := s.makeCreateDMResponse(ctx, processedRecipient, chat) + resyncEvt.PortalKey = dmInfo.PortalKey + resyncEvt.ChatInfo = dmInfo.PortalInfo + case *backuppb.Recipient_Group: + if len(dest.Group.MasterKey) != libsignalgo.GroupMasterKeyLength { + return true + } + rawGroupID, err := libsignalgo.GroupMasterKey(dest.Group.MasterKey).GroupIdentifier() + if err != nil { + zerolog.Ctx(ctx).Err(err). + Uint64("recipient_id", recipient.Id). + Msg("Failed to get group identifier from master key") + return true + } + groupID := types.GroupIdentifier(base64.StdEncoding.EncodeToString(rawGroupID[:])) + groupInfo, err := s.getGroupInfo(ctx, groupID, dest.Group.GetSnapshot().GetVersion(), chat) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get full group info") + return ctx.Err() == nil + } + resyncEvt.PortalKey = s.makePortalKey(string(groupID)) + resyncEvt.ChatInfo = groupInfo + default: + zerolog.Ctx(ctx).Debug(). + Type("destination_type", dest). + Uint64("backup_chat_id", chat.Id). + Uint64("backup_recipient_id", chat.RecipientId). + Msg("Ignoring and deleting chat with unsupported destination type") + err = s.Client.Store.BackupStore.DeleteBackupChat(ctx, chat.Id) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to delete chat from backup store") + return ctx.Err() == nil + } + return true + } + return s.UserLogin.QueueRemoteEvent(resyncEvt).Success +} diff --git a/pkg/connector/client.go b/pkg/connector/client.go index e224c17..30cc7cc 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -18,7 +18,9 @@ package connector import ( "context" + "errors" "fmt" + "sync/atomic" "time" "github.com/rs/zerolog" @@ -39,6 +41,7 @@ type SignalClient struct { Ghost *bridgev2.Ghost queueEmptyWaiter *exsync.Event + lifecycleCancel atomic.Pointer[context.CancelFunc] } var ( @@ -75,6 +78,7 @@ func (s *SignalClient) RegisterPushNotifications(ctx context.Context, pushType b } func (s *SignalClient) LogoutRemote(ctx context.Context) { + s.cancelLifecycleContext() if s.Client == nil { return } @@ -182,6 +186,7 @@ func (s *SignalClient) bridgeStateLoop(statusChan <-chan signalmeow.SignalConnec } else { s.UserLogin.BridgeState.Send(status.BridgeState{StateEvent: status.StateBadCredentials, Message: err.Error()}) } + s.cancelLifecycleContext() err = s.Client.ClearKeysAndDisconnect(context.TODO()) if err != nil { s.UserLogin.Log.Error().Err(err).Msg("Failed to clear keys and disconnect") @@ -206,15 +211,6 @@ func (s *SignalClient) bridgeStateLoop(statusChan <-chan signalmeow.SignalConnec } } -func (s *SignalClient) Connect(ctx context.Context) { - if s.Client == nil { - s.UserLogin.BridgeState.Send(status.BridgeState{StateEvent: status.StateBadCredentials, Message: "You're not logged into Signal"}) - return - } - s.updateRemoteProfile(ctx, false) - s.tryConnect(ctx, 0, true) -} - func (s *SignalClient) ConnectBackground(ctx context.Context, _ *bridgev2.ConnectBackgroundParams) error { s.queueEmptyWaiter.Clear() ch, unauthCh, err := s.Client.StartWebsockets(ctx) @@ -271,6 +267,7 @@ func (s *SignalClient) ConnectBackground(ctx context.Context, _ *bridgev2.Connec } func (s *SignalClient) Disconnect() { + s.cancelLifecycleContext() if s.Client == nil { return } @@ -281,7 +278,7 @@ func (s *SignalClient) Disconnect() { } func (s *SignalClient) postLoginConnect() { - ctx := s.UserLogin.Log.WithContext(context.Background()) + ctx := s.newLifecycleContext(s.UserLogin.Log.WithContext(s.UserLogin.Bridge.BackgroundCtx)) // TODO it would be more proper to only connect after syncing, // but currently syncing will fetch group info online, so it has to be connected. s.tryConnect(ctx, 0, false) @@ -300,11 +297,19 @@ func (s *SignalClient) postLoginConnect() { } func (s *SignalClient) tryConnect(ctx context.Context, retryCount int, doSync bool) { + if ctx.Err() != nil { + zerolog.Ctx(ctx).Debug().Err(ctx.Err()).Msg("Context canceled before starting receive loops") + return + } if retryCount == 0 { s.UserLogin.BridgeState.Send(status.BridgeState{StateEvent: status.StateConnecting}) } ch, err := s.Client.StartReceiveLoops(ctx) if err != nil { + if contextStopped(ctx, err) { + zerolog.Ctx(ctx).Debug().Err(err).Msg("Context canceled while starting receive loops") + return + } zerolog.Ctx(ctx).Err(err).Msg("Failed to start receive loops") s.UserLogin.BridgeState.Send(status.BridgeState{StateEvent: status.StateTransientDisconnect, Error: "unknown-websocket-error", Message: err.Error()}) retryInSeconds := 2 << retryCount @@ -315,7 +320,7 @@ func (s *SignalClient) tryConnect(ctx context.Context, retryCount int, doSync bo select { case <-time.After(time.Duration(retryInSeconds) * time.Second): case <-ctx.Done(): - zerolog.Ctx(ctx).Info().Msg("Context canceled, exit tryConnect") + zerolog.Ctx(ctx).Debug().Msg("Context canceled, exit tryConnect") return } s.tryConnect(ctx, retryCount+1, doSync) @@ -333,3 +338,31 @@ func (s *SignalClient) IsLoggedIn() bool { } return s.Client.IsLoggedIn() } + +func (s *SignalClient) Connect(ctx context.Context) { + if s.Client == nil { + s.UserLogin.BridgeState.Send(status.BridgeState{StateEvent: status.StateBadCredentials, Message: "You're not logged into Signal"}) + return + } + ctx = s.newLifecycleContext(s.UserLogin.Log.WithContext(ctx)) + s.updateRemoteProfile(ctx, false) + s.tryConnect(ctx, 0, true) +} + +func (s *SignalClient) newLifecycleContext(parent context.Context) context.Context { + ctx, cancel := context.WithCancel(parent) + if oldCancel := s.lifecycleCancel.Swap(&cancel); oldCancel != nil { + (*oldCancel)() + } + return ctx +} + +func (s *SignalClient) cancelLifecycleContext() { + if cancel := s.lifecycleCancel.Swap(nil); cancel != nil { + (*cancel)() + } +} + +func contextStopped(ctx context.Context, err error) bool { + return ctx.Err() != nil || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) +} diff --git a/pkg/connector/client_test.go b/pkg/connector/client_test.go new file mode 100644 index 0000000..c5d0b11 --- /dev/null +++ b/pkg/connector/client_test.go @@ -0,0 +1,196 @@ +package connector + +import ( + "context" + "testing" + "time" + + "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridgev2" + bridgev2database "maunium.net/go/mautrix/bridgev2/database" + + "go.mau.fi/mautrix-signal/pkg/libsignalgo" + "go.mau.fi/mautrix-signal/pkg/signalid" + "go.mau.fi/mautrix-signal/pkg/signalmeow" + "go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf/backuppb" + signalstore "go.mau.fi/mautrix-signal/pkg/signalmeow/store" + "go.mau.fi/mautrix-signal/pkg/signalmeow/types" +) + +func TestLifecycleContextReplacementCancelsPrevious(t *testing.T) { + client := &SignalClient{} + + firstCtx := client.newLifecycleContext(context.Background()) + secondCtx := client.newLifecycleContext(context.Background()) + + select { + case <-firstCtx.Done(): + default: + t.Fatal("expected previous lifecycle context to be canceled") + } + select { + case <-secondCtx.Done(): + t.Fatal("expected current lifecycle context to remain active") + default: + } + + client.cancelLifecycleContext() + + select { + case <-secondCtx.Done(): + case <-time.After(time.Second): + t.Fatal("expected lifecycle context to be canceled") + } +} + +func TestSyncChatsStopsOnContextCancellation(t *testing.T) { + recipientLookupStarted := make(chan struct{}) + backupStore := &backupStoreStub{ + getBackupChatsFn: func(context.Context) ([]*signalstore.BackupChat, error) { + return []*signalstore.BackupChat{{ + Chat: &backuppb.Chat{ + Id: 1, + RecipientId: 2, + }, + }}, nil + }, + getBackupRecipientFn: func(ctx context.Context, recipientID uint64) (*backuppb.Recipient, error) { + close(recipientLookupStarted) + <-ctx.Done() + return nil, ctx.Err() + }, + } + + client := &SignalClient{ + UserLogin: newTestUserLogin(), + Client: &signalmeow.Client{ + Store: &signalstore.Device{ + BackupStore: backupStore, + }, + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + done := make(chan struct{}) + go func() { + client.syncChats(ctx) + close(done) + }() + + select { + case <-recipientLookupStarted: + case <-time.After(time.Second): + t.Fatal("timed out waiting for backup recipient lookup") + } + cancel() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("syncChats did not exit after context cancellation") + } + if client.UserLogin.Metadata.(*signalid.UserLoginMetadata).ChatsSynced { + t.Fatal("expected chat sync to stop before marking metadata as synced") + } +} + +func TestSyncChatSkipsMissingBackupRecipient(t *testing.T) { + backupStore := &backupStoreStub{ + getBackupRecipientFn: func(context.Context, uint64) (*backuppb.Recipient, error) { + return nil, nil + }, + } + + client := &SignalClient{ + Client: &signalmeow.Client{ + Store: &signalstore.Device{ + BackupStore: backupStore, + }, + }, + } + + ok := client.syncChat(context.Background(), &signalstore.BackupChat{ + Chat: &backuppb.Chat{ + Id: 1, + RecipientId: 2, + }, + }) + + if !ok { + t.Fatal("expected missing backup recipient to be skipped") + } +} + +func newTestUserLogin() *bridgev2.UserLogin { + return &bridgev2.UserLogin{ + UserLogin: &bridgev2database.UserLogin{ + Metadata: &signalid.UserLoginMetadata{}, + }, + Log: zerolog.Nop(), + } +} + +type backupStoreStub struct { + getBackupChatsFn func(context.Context) ([]*signalstore.BackupChat, error) + getBackupRecipientFn func(context.Context, uint64) (*backuppb.Recipient, error) + deleteBackupChatFn func(context.Context, uint64) error +} + +func (b *backupStoreStub) AddBackupRecipient(context.Context, *backuppb.Recipient) error { + return nil +} + +func (b *backupStoreStub) AddBackupChat(context.Context, *backuppb.Chat) error { + return nil +} + +func (b *backupStoreStub) AddBackupChatItem(context.Context, *backuppb.ChatItem) error { + return nil +} + +func (b *backupStoreStub) RecalculateChatCounts(context.Context) error { + return nil +} + +func (b *backupStoreStub) ClearBackup(context.Context) error { + return nil +} + +func (b *backupStoreStub) GetBackupRecipient(ctx context.Context, recipientID uint64) (*backuppb.Recipient, error) { + if b.getBackupRecipientFn != nil { + return b.getBackupRecipientFn(ctx, recipientID) + } + return nil, nil +} + +func (b *backupStoreStub) GetBackupChatByUserID(context.Context, libsignalgo.ServiceID) (*signalstore.BackupChat, error) { + return nil, nil +} + +func (b *backupStoreStub) GetBackupChatByGroupID(context.Context, types.GroupIdentifier) (*signalstore.BackupChat, error) { + return nil, nil +} + +func (b *backupStoreStub) GetBackupChats(ctx context.Context) ([]*signalstore.BackupChat, error) { + if b.getBackupChatsFn != nil { + return b.getBackupChatsFn(ctx) + } + return nil, nil +} + +func (b *backupStoreStub) GetBackupChatItems(context.Context, uint64, time.Time, bool, int) ([]*backuppb.ChatItem, error) { + return nil, nil +} + +func (b *backupStoreStub) DeleteBackupChat(ctx context.Context, chatID uint64) error { + if b.deleteBackupChatFn != nil { + return b.deleteBackupChatFn(ctx, chatID) + } + return nil +} + +func (b *backupStoreStub) DeleteBackupChatItems(context.Context, uint64, time.Time) error { + return nil +} diff --git a/pkg/signalmeow/backup.go b/pkg/signalmeow/backup.go index fcbfff0..f003e19 100644 --- a/pkg/signalmeow/backup.go +++ b/pkg/signalmeow/backup.go @@ -282,7 +282,11 @@ func (cli *Client) WaitForTransfer(ctx context.Context) (*TransferArchiveMetadat } reqDuration := time.Since(reqStart) if reqDuration < reqTimeout-10*time.Second { - time.Sleep(15 * time.Second) + select { + case <-time.After(15 * time.Second): + case <-ctx.Done(): + return nil, ctx.Err() + } } } }