mirror of
https://github.com/mautrix/signal.git
synced 2026-05-14 13:16:54 -04:00
Compare commits
1 commit
main
...
adam/codex
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d95cb01d83 |
4 changed files with 341 additions and 93 deletions
|
|
@ -29,6 +29,7 @@ import (
|
||||||
"go.mau.fi/mautrix-signal/pkg/libsignalgo"
|
"go.mau.fi/mautrix-signal/pkg/libsignalgo"
|
||||||
"go.mau.fi/mautrix-signal/pkg/signalid"
|
"go.mau.fi/mautrix-signal/pkg/signalid"
|
||||||
"go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf/backuppb"
|
"go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf/backuppb"
|
||||||
|
"go.mau.fi/mautrix-signal/pkg/signalmeow/store"
|
||||||
"go.mau.fi/mautrix-signal/pkg/signalmeow/types"
|
"go.mau.fi/mautrix-signal/pkg/signalmeow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -65,95 +66,109 @@ func (s *SignalClient) syncChats(ctx context.Context) {
|
||||||
}
|
}
|
||||||
zerolog.Ctx(ctx).Info().Int("chat_count", len(chats)).Msg("Fetched chats to sync from database")
|
zerolog.Ctx(ctx).Info().Int("chat_count", len(chats)).Msg("Fetched chats to sync from database")
|
||||||
for _, chat := range chats {
|
for _, chat := range chats {
|
||||||
recipient, err := s.Client.Store.BackupStore.GetBackupRecipient(ctx, chat.RecipientId)
|
if !s.syncChat(ctx, chat) {
|
||||||
if err != nil {
|
|
||||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to get recipient for chat")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
resyncEvt := &simplevent.ChatResync{
|
|
||||||
EventMeta: simplevent.EventMeta{
|
|
||||||
Type: bridgev2.RemoteEventChatResync,
|
|
||||||
LogContext: func(c zerolog.Context) zerolog.Context {
|
|
||||||
return c.
|
|
||||||
Int("message_count", chat.TotalMessages).
|
|
||||||
Uint64("backup_chat_id", chat.Id).
|
|
||||||
Uint64("backup_recipient_id", chat.RecipientId)
|
|
||||||
},
|
|
||||||
CreatePortal: true,
|
|
||||||
},
|
|
||||||
LatestMessageTS: time.UnixMilli(int64(chat.LatestMessageID)),
|
|
||||||
}
|
|
||||||
switch dest := recipient.Destination.(type) {
|
|
||||||
case *backuppb.Recipient_Contact:
|
|
||||||
aci := tryCastUUID(dest.Contact.GetAci())
|
|
||||||
pni := tryCastUUID(dest.Contact.GetPni())
|
|
||||||
if chat.TotalMessages == 0 {
|
|
||||||
zerolog.Ctx(ctx).Debug().
|
|
||||||
Stringer("aci", aci).
|
|
||||||
Stringer("pni", pni).
|
|
||||||
Uint64("e164", dest.Contact.GetE164()).
|
|
||||||
Msg("Skipping direct chat with no messages and deleting data")
|
|
||||||
err = s.Client.Store.BackupStore.DeleteBackupChat(ctx, chat.Id)
|
|
||||||
if err != nil {
|
|
||||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to delete chat from backup store")
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
processedRecipient, err := s.Client.Store.RecipientStore.LoadAndUpdateRecipient(ctx, aci, pni, nil)
|
|
||||||
if err != nil {
|
|
||||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to get full recipient data")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
dmInfo := s.makeCreateDMResponse(ctx, processedRecipient, chat)
|
|
||||||
resyncEvt.PortalKey = dmInfo.PortalKey
|
|
||||||
resyncEvt.ChatInfo = dmInfo.PortalInfo
|
|
||||||
case *backuppb.Recipient_Self:
|
|
||||||
processedRecipient, err := s.Client.Store.RecipientStore.LoadAndUpdateRecipient(ctx, s.Client.Store.ACI, uuid.Nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to get full recipient data")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
dmInfo := s.makeCreateDMResponse(ctx, processedRecipient, chat)
|
|
||||||
resyncEvt.PortalKey = dmInfo.PortalKey
|
|
||||||
resyncEvt.ChatInfo = dmInfo.PortalInfo
|
|
||||||
case *backuppb.Recipient_Group:
|
|
||||||
if len(dest.Group.MasterKey) != libsignalgo.GroupMasterKeyLength {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
rawGroupID, err := libsignalgo.GroupMasterKey(dest.Group.MasterKey).GroupIdentifier()
|
|
||||||
if err != nil {
|
|
||||||
zerolog.Ctx(ctx).Err(err).
|
|
||||||
Uint64("recipient_id", recipient.Id).
|
|
||||||
Msg("Failed to get group identifier from master key")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
groupID := types.GroupIdentifier(base64.StdEncoding.EncodeToString(rawGroupID[:]))
|
|
||||||
groupInfo, err := s.getGroupInfo(ctx, groupID, dest.Group.GetSnapshot().GetVersion(), chat)
|
|
||||||
if err != nil {
|
|
||||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to get full group info")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
resyncEvt.PortalKey = s.makePortalKey(string(groupID))
|
|
||||||
resyncEvt.ChatInfo = groupInfo
|
|
||||||
default:
|
|
||||||
zerolog.Ctx(ctx).Debug().
|
|
||||||
Type("destination_type", dest).
|
|
||||||
Uint64("backup_chat_id", chat.Id).
|
|
||||||
Uint64("backup_recipient_id", chat.RecipientId).
|
|
||||||
Msg("Ignoring and deleting chat with unsupported destination type")
|
|
||||||
err = s.Client.Store.BackupStore.DeleteBackupChat(ctx, chat.Id)
|
|
||||||
if err != nil {
|
|
||||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to delete chat from backup store")
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !s.UserLogin.QueueRemoteEvent(resyncEvt).Success {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// TODO if Save fails, ChatsSynced remains true in memory even though it wasn't persisted.
|
||||||
|
// Fixing that properly likely needs a broader metadata mutation/rollback pattern.
|
||||||
s.UserLogin.Metadata.(*signalid.UserLoginMetadata).ChatsSynced = true
|
s.UserLogin.Metadata.(*signalid.UserLoginMetadata).ChatsSynced = true
|
||||||
err = s.UserLogin.Save(ctx)
|
err = s.UserLogin.Save(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to save user login metadata after syncing chats")
|
zerolog.Ctx(ctx).Err(err).Msg("Failed to save user login metadata after syncing chats")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SignalClient) syncChat(ctx context.Context, chat *store.BackupChat) bool {
|
||||||
|
recipient, err := s.Client.Store.BackupStore.GetBackupRecipient(ctx, chat.RecipientId)
|
||||||
|
if err != nil {
|
||||||
|
zerolog.Ctx(ctx).Err(err).Msg("Failed to get recipient for chat")
|
||||||
|
return ctx.Err() == nil
|
||||||
|
} else if recipient == nil {
|
||||||
|
zerolog.Ctx(ctx).Debug().
|
||||||
|
Uint64("backup_chat_id", chat.Id).
|
||||||
|
Uint64("backup_recipient_id", chat.RecipientId).
|
||||||
|
Msg("Skipping chat with missing backup recipient")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
resyncEvt := &simplevent.ChatResync{
|
||||||
|
EventMeta: simplevent.EventMeta{
|
||||||
|
Type: bridgev2.RemoteEventChatResync,
|
||||||
|
LogContext: func(c zerolog.Context) zerolog.Context {
|
||||||
|
return c.
|
||||||
|
Int("message_count", chat.TotalMessages).
|
||||||
|
Uint64("backup_chat_id", chat.Id).
|
||||||
|
Uint64("backup_recipient_id", chat.RecipientId)
|
||||||
|
},
|
||||||
|
CreatePortal: true,
|
||||||
|
},
|
||||||
|
LatestMessageTS: time.UnixMilli(int64(chat.LatestMessageID)),
|
||||||
|
}
|
||||||
|
switch dest := recipient.Destination.(type) {
|
||||||
|
case *backuppb.Recipient_Contact:
|
||||||
|
aci := tryCastUUID(dest.Contact.GetAci())
|
||||||
|
pni := tryCastUUID(dest.Contact.GetPni())
|
||||||
|
if chat.TotalMessages == 0 {
|
||||||
|
zerolog.Ctx(ctx).Debug().
|
||||||
|
Stringer("aci", aci).
|
||||||
|
Stringer("pni", pni).
|
||||||
|
Uint64("e164", dest.Contact.GetE164()).
|
||||||
|
Msg("Skipping direct chat with no messages and deleting data")
|
||||||
|
err = s.Client.Store.BackupStore.DeleteBackupChat(ctx, chat.Id)
|
||||||
|
if err != nil {
|
||||||
|
zerolog.Ctx(ctx).Err(err).Msg("Failed to delete chat from backup store")
|
||||||
|
return ctx.Err() == nil
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
processedRecipient, err := s.Client.Store.RecipientStore.LoadAndUpdateRecipient(ctx, aci, pni, nil)
|
||||||
|
if err != nil {
|
||||||
|
zerolog.Ctx(ctx).Err(err).Msg("Failed to get full recipient data")
|
||||||
|
return ctx.Err() == nil
|
||||||
|
}
|
||||||
|
dmInfo := s.makeCreateDMResponse(ctx, processedRecipient, chat)
|
||||||
|
resyncEvt.PortalKey = dmInfo.PortalKey
|
||||||
|
resyncEvt.ChatInfo = dmInfo.PortalInfo
|
||||||
|
case *backuppb.Recipient_Self:
|
||||||
|
processedRecipient, err := s.Client.Store.RecipientStore.LoadAndUpdateRecipient(ctx, s.Client.Store.ACI, uuid.Nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
zerolog.Ctx(ctx).Err(err).Msg("Failed to get full recipient data")
|
||||||
|
return ctx.Err() == nil
|
||||||
|
}
|
||||||
|
dmInfo := s.makeCreateDMResponse(ctx, processedRecipient, chat)
|
||||||
|
resyncEvt.PortalKey = dmInfo.PortalKey
|
||||||
|
resyncEvt.ChatInfo = dmInfo.PortalInfo
|
||||||
|
case *backuppb.Recipient_Group:
|
||||||
|
if len(dest.Group.MasterKey) != libsignalgo.GroupMasterKeyLength {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
rawGroupID, err := libsignalgo.GroupMasterKey(dest.Group.MasterKey).GroupIdentifier()
|
||||||
|
if err != nil {
|
||||||
|
zerolog.Ctx(ctx).Err(err).
|
||||||
|
Uint64("recipient_id", recipient.Id).
|
||||||
|
Msg("Failed to get group identifier from master key")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
groupID := types.GroupIdentifier(base64.StdEncoding.EncodeToString(rawGroupID[:]))
|
||||||
|
groupInfo, err := s.getGroupInfo(ctx, groupID, dest.Group.GetSnapshot().GetVersion(), chat)
|
||||||
|
if err != nil {
|
||||||
|
zerolog.Ctx(ctx).Err(err).Msg("Failed to get full group info")
|
||||||
|
return ctx.Err() == nil
|
||||||
|
}
|
||||||
|
resyncEvt.PortalKey = s.makePortalKey(string(groupID))
|
||||||
|
resyncEvt.ChatInfo = groupInfo
|
||||||
|
default:
|
||||||
|
zerolog.Ctx(ctx).Debug().
|
||||||
|
Type("destination_type", dest).
|
||||||
|
Uint64("backup_chat_id", chat.Id).
|
||||||
|
Uint64("backup_recipient_id", chat.RecipientId).
|
||||||
|
Msg("Ignoring and deleting chat with unsupported destination type")
|
||||||
|
err = s.Client.Store.BackupStore.DeleteBackupChat(ctx, chat.Id)
|
||||||
|
if err != nil {
|
||||||
|
zerolog.Ctx(ctx).Err(err).Msg("Failed to delete chat from backup store")
|
||||||
|
return ctx.Err() == nil
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return s.UserLogin.QueueRemoteEvent(resyncEvt).Success
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,9 @@ package connector
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
|
@ -39,6 +41,7 @@ type SignalClient struct {
|
||||||
Ghost *bridgev2.Ghost
|
Ghost *bridgev2.Ghost
|
||||||
|
|
||||||
queueEmptyWaiter *exsync.Event
|
queueEmptyWaiter *exsync.Event
|
||||||
|
lifecycleCancel atomic.Pointer[context.CancelFunc]
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
@ -75,6 +78,7 @@ func (s *SignalClient) RegisterPushNotifications(ctx context.Context, pushType b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SignalClient) LogoutRemote(ctx context.Context) {
|
func (s *SignalClient) LogoutRemote(ctx context.Context) {
|
||||||
|
s.cancelLifecycleContext()
|
||||||
if s.Client == nil {
|
if s.Client == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -182,6 +186,7 @@ func (s *SignalClient) bridgeStateLoop(statusChan <-chan signalmeow.SignalConnec
|
||||||
} else {
|
} else {
|
||||||
s.UserLogin.BridgeState.Send(status.BridgeState{StateEvent: status.StateBadCredentials, Message: err.Error()})
|
s.UserLogin.BridgeState.Send(status.BridgeState{StateEvent: status.StateBadCredentials, Message: err.Error()})
|
||||||
}
|
}
|
||||||
|
s.cancelLifecycleContext()
|
||||||
err = s.Client.ClearKeysAndDisconnect(context.TODO())
|
err = s.Client.ClearKeysAndDisconnect(context.TODO())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.UserLogin.Log.Error().Err(err).Msg("Failed to clear keys and disconnect")
|
s.UserLogin.Log.Error().Err(err).Msg("Failed to clear keys and disconnect")
|
||||||
|
|
@ -206,15 +211,6 @@ func (s *SignalClient) bridgeStateLoop(statusChan <-chan signalmeow.SignalConnec
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SignalClient) Connect(ctx context.Context) {
|
|
||||||
if s.Client == nil {
|
|
||||||
s.UserLogin.BridgeState.Send(status.BridgeState{StateEvent: status.StateBadCredentials, Message: "You're not logged into Signal"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.updateRemoteProfile(ctx, false)
|
|
||||||
s.tryConnect(ctx, 0, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SignalClient) ConnectBackground(ctx context.Context, _ *bridgev2.ConnectBackgroundParams) error {
|
func (s *SignalClient) ConnectBackground(ctx context.Context, _ *bridgev2.ConnectBackgroundParams) error {
|
||||||
s.queueEmptyWaiter.Clear()
|
s.queueEmptyWaiter.Clear()
|
||||||
ch, unauthCh, err := s.Client.StartWebsockets(ctx)
|
ch, unauthCh, err := s.Client.StartWebsockets(ctx)
|
||||||
|
|
@ -271,6 +267,7 @@ func (s *SignalClient) ConnectBackground(ctx context.Context, _ *bridgev2.Connec
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SignalClient) Disconnect() {
|
func (s *SignalClient) Disconnect() {
|
||||||
|
s.cancelLifecycleContext()
|
||||||
if s.Client == nil {
|
if s.Client == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -281,7 +278,7 @@ func (s *SignalClient) Disconnect() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SignalClient) postLoginConnect() {
|
func (s *SignalClient) postLoginConnect() {
|
||||||
ctx := s.UserLogin.Log.WithContext(context.Background())
|
ctx := s.newLifecycleContext(s.UserLogin.Log.WithContext(s.UserLogin.Bridge.BackgroundCtx))
|
||||||
// TODO it would be more proper to only connect after syncing,
|
// TODO it would be more proper to only connect after syncing,
|
||||||
// but currently syncing will fetch group info online, so it has to be connected.
|
// but currently syncing will fetch group info online, so it has to be connected.
|
||||||
s.tryConnect(ctx, 0, false)
|
s.tryConnect(ctx, 0, false)
|
||||||
|
|
@ -300,11 +297,19 @@ func (s *SignalClient) postLoginConnect() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SignalClient) tryConnect(ctx context.Context, retryCount int, doSync bool) {
|
func (s *SignalClient) tryConnect(ctx context.Context, retryCount int, doSync bool) {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
zerolog.Ctx(ctx).Debug().Err(ctx.Err()).Msg("Context canceled before starting receive loops")
|
||||||
|
return
|
||||||
|
}
|
||||||
if retryCount == 0 {
|
if retryCount == 0 {
|
||||||
s.UserLogin.BridgeState.Send(status.BridgeState{StateEvent: status.StateConnecting})
|
s.UserLogin.BridgeState.Send(status.BridgeState{StateEvent: status.StateConnecting})
|
||||||
}
|
}
|
||||||
ch, err := s.Client.StartReceiveLoops(ctx)
|
ch, err := s.Client.StartReceiveLoops(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if contextStopped(ctx, err) {
|
||||||
|
zerolog.Ctx(ctx).Debug().Err(err).Msg("Context canceled while starting receive loops")
|
||||||
|
return
|
||||||
|
}
|
||||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to start receive loops")
|
zerolog.Ctx(ctx).Err(err).Msg("Failed to start receive loops")
|
||||||
s.UserLogin.BridgeState.Send(status.BridgeState{StateEvent: status.StateTransientDisconnect, Error: "unknown-websocket-error", Message: err.Error()})
|
s.UserLogin.BridgeState.Send(status.BridgeState{StateEvent: status.StateTransientDisconnect, Error: "unknown-websocket-error", Message: err.Error()})
|
||||||
retryInSeconds := 2 << retryCount
|
retryInSeconds := 2 << retryCount
|
||||||
|
|
@ -315,7 +320,7 @@ func (s *SignalClient) tryConnect(ctx context.Context, retryCount int, doSync bo
|
||||||
select {
|
select {
|
||||||
case <-time.After(time.Duration(retryInSeconds) * time.Second):
|
case <-time.After(time.Duration(retryInSeconds) * time.Second):
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
zerolog.Ctx(ctx).Info().Msg("Context canceled, exit tryConnect")
|
zerolog.Ctx(ctx).Debug().Msg("Context canceled, exit tryConnect")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.tryConnect(ctx, retryCount+1, doSync)
|
s.tryConnect(ctx, retryCount+1, doSync)
|
||||||
|
|
@ -333,3 +338,31 @@ func (s *SignalClient) IsLoggedIn() bool {
|
||||||
}
|
}
|
||||||
return s.Client.IsLoggedIn()
|
return s.Client.IsLoggedIn()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SignalClient) Connect(ctx context.Context) {
|
||||||
|
if s.Client == nil {
|
||||||
|
s.UserLogin.BridgeState.Send(status.BridgeState{StateEvent: status.StateBadCredentials, Message: "You're not logged into Signal"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx = s.newLifecycleContext(s.UserLogin.Log.WithContext(ctx))
|
||||||
|
s.updateRemoteProfile(ctx, false)
|
||||||
|
s.tryConnect(ctx, 0, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SignalClient) newLifecycleContext(parent context.Context) context.Context {
|
||||||
|
ctx, cancel := context.WithCancel(parent)
|
||||||
|
if oldCancel := s.lifecycleCancel.Swap(&cancel); oldCancel != nil {
|
||||||
|
(*oldCancel)()
|
||||||
|
}
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SignalClient) cancelLifecycleContext() {
|
||||||
|
if cancel := s.lifecycleCancel.Swap(nil); cancel != nil {
|
||||||
|
(*cancel)()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func contextStopped(ctx context.Context, err error) bool {
|
||||||
|
return ctx.Err() != nil || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)
|
||||||
|
}
|
||||||
|
|
|
||||||
196
pkg/connector/client_test.go
Normal file
196
pkg/connector/client_test.go
Normal file
|
|
@ -0,0 +1,196 @@
|
||||||
|
package connector
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"maunium.net/go/mautrix/bridgev2"
|
||||||
|
bridgev2database "maunium.net/go/mautrix/bridgev2/database"
|
||||||
|
|
||||||
|
"go.mau.fi/mautrix-signal/pkg/libsignalgo"
|
||||||
|
"go.mau.fi/mautrix-signal/pkg/signalid"
|
||||||
|
"go.mau.fi/mautrix-signal/pkg/signalmeow"
|
||||||
|
"go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf/backuppb"
|
||||||
|
signalstore "go.mau.fi/mautrix-signal/pkg/signalmeow/store"
|
||||||
|
"go.mau.fi/mautrix-signal/pkg/signalmeow/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLifecycleContextReplacementCancelsPrevious(t *testing.T) {
|
||||||
|
client := &SignalClient{}
|
||||||
|
|
||||||
|
firstCtx := client.newLifecycleContext(context.Background())
|
||||||
|
secondCtx := client.newLifecycleContext(context.Background())
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-firstCtx.Done():
|
||||||
|
default:
|
||||||
|
t.Fatal("expected previous lifecycle context to be canceled")
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-secondCtx.Done():
|
||||||
|
t.Fatal("expected current lifecycle context to remain active")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
client.cancelLifecycleContext()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-secondCtx.Done():
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("expected lifecycle context to be canceled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncChatsStopsOnContextCancellation(t *testing.T) {
|
||||||
|
recipientLookupStarted := make(chan struct{})
|
||||||
|
backupStore := &backupStoreStub{
|
||||||
|
getBackupChatsFn: func(context.Context) ([]*signalstore.BackupChat, error) {
|
||||||
|
return []*signalstore.BackupChat{{
|
||||||
|
Chat: &backuppb.Chat{
|
||||||
|
Id: 1,
|
||||||
|
RecipientId: 2,
|
||||||
|
},
|
||||||
|
}}, nil
|
||||||
|
},
|
||||||
|
getBackupRecipientFn: func(ctx context.Context, recipientID uint64) (*backuppb.Recipient, error) {
|
||||||
|
close(recipientLookupStarted)
|
||||||
|
<-ctx.Done()
|
||||||
|
return nil, ctx.Err()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &SignalClient{
|
||||||
|
UserLogin: newTestUserLogin(),
|
||||||
|
Client: &signalmeow.Client{
|
||||||
|
Store: &signalstore.Device{
|
||||||
|
BackupStore: backupStore,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
client.syncChats(ctx)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-recipientLookupStarted:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("timed out waiting for backup recipient lookup")
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("syncChats did not exit after context cancellation")
|
||||||
|
}
|
||||||
|
if client.UserLogin.Metadata.(*signalid.UserLoginMetadata).ChatsSynced {
|
||||||
|
t.Fatal("expected chat sync to stop before marking metadata as synced")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncChatSkipsMissingBackupRecipient(t *testing.T) {
|
||||||
|
backupStore := &backupStoreStub{
|
||||||
|
getBackupRecipientFn: func(context.Context, uint64) (*backuppb.Recipient, error) {
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &SignalClient{
|
||||||
|
Client: &signalmeow.Client{
|
||||||
|
Store: &signalstore.Device{
|
||||||
|
BackupStore: backupStore,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ok := client.syncChat(context.Background(), &signalstore.BackupChat{
|
||||||
|
Chat: &backuppb.Chat{
|
||||||
|
Id: 1,
|
||||||
|
RecipientId: 2,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected missing backup recipient to be skipped")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestUserLogin() *bridgev2.UserLogin {
|
||||||
|
return &bridgev2.UserLogin{
|
||||||
|
UserLogin: &bridgev2database.UserLogin{
|
||||||
|
Metadata: &signalid.UserLoginMetadata{},
|
||||||
|
},
|
||||||
|
Log: zerolog.Nop(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type backupStoreStub struct {
|
||||||
|
getBackupChatsFn func(context.Context) ([]*signalstore.BackupChat, error)
|
||||||
|
getBackupRecipientFn func(context.Context, uint64) (*backuppb.Recipient, error)
|
||||||
|
deleteBackupChatFn func(context.Context, uint64) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *backupStoreStub) AddBackupRecipient(context.Context, *backuppb.Recipient) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *backupStoreStub) AddBackupChat(context.Context, *backuppb.Chat) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *backupStoreStub) AddBackupChatItem(context.Context, *backuppb.ChatItem) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *backupStoreStub) RecalculateChatCounts(context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *backupStoreStub) ClearBackup(context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *backupStoreStub) GetBackupRecipient(ctx context.Context, recipientID uint64) (*backuppb.Recipient, error) {
|
||||||
|
if b.getBackupRecipientFn != nil {
|
||||||
|
return b.getBackupRecipientFn(ctx, recipientID)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *backupStoreStub) GetBackupChatByUserID(context.Context, libsignalgo.ServiceID) (*signalstore.BackupChat, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *backupStoreStub) GetBackupChatByGroupID(context.Context, types.GroupIdentifier) (*signalstore.BackupChat, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *backupStoreStub) GetBackupChats(ctx context.Context) ([]*signalstore.BackupChat, error) {
|
||||||
|
if b.getBackupChatsFn != nil {
|
||||||
|
return b.getBackupChatsFn(ctx)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *backupStoreStub) GetBackupChatItems(context.Context, uint64, time.Time, bool, int) ([]*backuppb.ChatItem, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *backupStoreStub) DeleteBackupChat(ctx context.Context, chatID uint64) error {
|
||||||
|
if b.deleteBackupChatFn != nil {
|
||||||
|
return b.deleteBackupChatFn(ctx, chatID)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *backupStoreStub) DeleteBackupChatItems(context.Context, uint64, time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -282,7 +282,11 @@ func (cli *Client) WaitForTransfer(ctx context.Context) (*TransferArchiveMetadat
|
||||||
}
|
}
|
||||||
reqDuration := time.Since(reqStart)
|
reqDuration := time.Since(reqStart)
|
||||||
if reqDuration < reqTimeout-10*time.Second {
|
if reqDuration < reqTimeout-10*time.Second {
|
||||||
time.Sleep(15 * time.Second)
|
select {
|
||||||
|
case <-time.After(15 * time.Second):
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue