2024-03-22 21:24:30 +02:00
// mautrix-signal - A Matrix-signal puppeting bridge.
// Copyright (C) 2023 Scott Weber
// Copyright (C) 2024 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"github.com/google/uuid"
2025-01-30 18:11:06 +02:00
"github.com/rs/zerolog"
2024-03-22 21:24:30 +02:00
"go.mau.fi/util/dbutil"
"go.mau.fi/mautrix-signal/pkg/libsignalgo"
"go.mau.fi/mautrix-signal/pkg/signalmeow/types"
)
type RecipientStore interface {
LoadProfileKey ( ctx context . Context , theirACI uuid . UUID ) ( * libsignalgo . ProfileKey , error )
StoreProfileKey ( ctx context . Context , theirACI uuid . UUID , key libsignalgo . ProfileKey ) error
MyProfileKey ( ctx context . Context ) ( * libsignalgo . ProfileKey , error )
LoadAndUpdateRecipient ( ctx context . Context , aci , pni uuid . UUID , updater RecipientUpdaterFunc ) ( * types . Recipient , error )
2025-11-20 13:29:38 +02:00
IsBlocked ( ctx context . Context , aci uuid . UUID ) ( bool , error )
2024-03-22 21:24:30 +02:00
LoadRecipientByE164 ( ctx context . Context , e164 string ) ( * types . Recipient , error )
StoreRecipient ( ctx context . Context , recipient * types . Recipient ) error
UpdateRecipientE164 ( ctx context . Context , aci , pni uuid . UUID , e164 string ) ( * types . Recipient , error )
2024-06-20 17:29:55 +03:00
2025-11-28 16:51:51 +02:00
IsUnregistered ( ctx context . Context , serviceID libsignalgo . ServiceID ) bool
MarkUnregistered ( ctx context . Context , serviceID libsignalgo . ServiceID , unregistered bool )
2024-06-20 17:29:55 +03:00
LoadAllContacts ( ctx context . Context ) ( [ ] * types . Recipient , error )
2024-03-22 21:24:30 +02:00
}
var _ RecipientStore = ( * sqlStore ) ( nil )
const (
getAllRecipientsQuery = `
SELECT
aci_uuid ,
pni_uuid ,
e164_number ,
contact_name ,
contact_avatar_hash ,
2025-09-01 16:40:35 +03:00
nickname ,
2024-03-22 21:24:30 +02:00
profile_key ,
profile_name ,
profile_about ,
profile_about_emoji ,
profile_avatar_path ,
2024-03-25 15:18:25 +02:00
profile_fetched_at ,
2025-11-20 13:29:38 +02:00
needs_pni_signature ,
2025-12-11 17:53:48 +02:00
blocked ,
whitelisted
2024-03-22 21:24:30 +02:00
FROM signalmeow_recipients
WHERE account_id = $ 1
`
2024-06-20 17:29:55 +03:00
getAllRecipientsWithNameOrPhoneQuery = getAllRecipientsQuery + ` AND (contact_name <> '' OR profile_name <> '' OR e164_number <> '') `
getRecipientByACIQuery = getAllRecipientsQuery + ` AND aci_uuid = $2 `
getRecipientByPNIQuery = getAllRecipientsQuery + ` AND pni_uuid = $2 `
getRecipientByACIOrPNIQuery = getAllRecipientsQuery + ` AND (($2<>'00000000-0000-0000-0000-000000000000' AND aci_uuid = $2) OR ($3<>'00000000-0000-0000-0000-000000000000' AND pni_uuid = $3)) `
getRecipientByPhoneQuery = getAllRecipientsQuery + ` AND e164_number = $2 `
deleteRecipientByPNIQuery = ` DELETE FROM signalmeow_recipients WHERE account_id = $1 AND pni_uuid = $2 `
upsertACIRecipientQuery = `
2024-03-22 21:24:30 +02:00
INSERT INTO signalmeow_recipients (
account_id ,
aci_uuid ,
pni_uuid ,
e164_number ,
contact_name ,
contact_avatar_hash ,
2025-09-01 16:40:35 +03:00
nickname ,
2024-03-22 21:24:30 +02:00
profile_key ,
profile_name ,
profile_about ,
profile_about_emoji ,
profile_avatar_path ,
2024-03-25 15:18:25 +02:00
profile_fetched_at ,
2025-11-20 13:29:38 +02:00
needs_pni_signature ,
2025-12-11 17:53:48 +02:00
blocked ,
whitelisted
2024-03-22 21:24:30 +02:00
)
2025-12-11 17:53:48 +02:00
VALUES ( $ 1 , $ 2 , $ 3 , $ 4 , $ 5 , $ 6 , $ 7 , $ 8 , $ 9 , $ 10 , $ 11 , $ 12 , $ 13 , $ 14 , $ 15 , $ 16 )
2024-03-22 21:24:30 +02:00
ON CONFLICT ( account_id , aci_uuid ) DO UPDATE SET
pni_uuid = excluded . pni_uuid ,
e164_number = excluded . e164_number ,
contact_name = excluded . contact_name ,
contact_avatar_hash = excluded . contact_avatar_hash ,
2025-09-01 16:40:35 +03:00
nickname = excluded . nickname ,
2024-03-22 21:24:30 +02:00
profile_key = excluded . profile_key ,
profile_name = excluded . profile_name ,
profile_about = excluded . profile_about ,
profile_about_emoji = excluded . profile_about_emoji ,
profile_avatar_path = excluded . profile_avatar_path ,
2024-03-25 15:18:25 +02:00
profile_fetched_at = excluded . profile_fetched_at ,
2025-11-20 13:29:38 +02:00
needs_pni_signature = excluded . needs_pni_signature ,
2025-12-11 17:53:48 +02:00
blocked = excluded . blocked ,
whitelisted = excluded . whitelisted
2024-03-22 21:24:30 +02:00
`
upsertPNIRecipientQuery = `
INSERT INTO signalmeow_recipients (
account_id ,
pni_uuid ,
e164_number ,
contact_name ,
contact_avatar_hash
)
VALUES ( $ 1 , $ 2 , $ 3 , $ 4 , $ 5 )
ON CONFLICT ( account_id , pni_uuid ) DO UPDATE SET
e164_number = excluded . e164_number ,
contact_name = excluded . contact_name ,
contact_avatar_hash = excluded . contact_avatar_hash
`
)
func scanRecipient ( row dbutil . Scannable ) ( * types . Recipient , error ) {
var recipient types . Recipient
var aci , pni uuid . NullUUID
var profileKey [ ] byte
var profileFetchedAt sql . NullInt64
err := row . Scan (
& aci ,
& pni ,
& recipient . E164 ,
& recipient . ContactName ,
& recipient . ContactAvatar . Hash ,
2025-09-01 16:40:35 +03:00
& recipient . Nickname ,
2024-03-22 21:24:30 +02:00
& profileKey ,
& recipient . Profile . Name ,
& recipient . Profile . About ,
& recipient . Profile . AboutEmoji ,
& recipient . Profile . AvatarPath ,
& profileFetchedAt ,
2024-03-25 15:18:25 +02:00
& recipient . NeedsPNISignature ,
2025-11-20 13:29:38 +02:00
& recipient . Blocked ,
2025-12-11 17:53:48 +02:00
& recipient . Whitelisted ,
2024-03-22 21:24:30 +02:00
)
if errors . Is ( err , sql . ErrNoRows ) {
return nil , nil
} else if err != nil {
return nil , err
}
recipient . ACI = aci . UUID
recipient . PNI = pni . UUID
if profileFetchedAt . Valid {
recipient . Profile . FetchedAt = time . UnixMilli ( profileFetchedAt . Int64 )
}
if len ( profileKey ) == libsignalgo . ProfileKeyLength {
recipient . Profile . Key = libsignalgo . ProfileKey ( profileKey )
}
return & recipient , err
}
func ( s * sqlStore ) LoadRecipientByACI ( ctx context . Context , theirUUID uuid . UUID ) ( * types . Recipient , error ) {
return scanRecipient ( s . db . QueryRow ( ctx , getRecipientByACIQuery , s . AccountID , theirUUID ) )
}
func ( s * sqlStore ) LoadRecipientByPNI ( ctx context . Context , theirUUID uuid . UUID ) ( * types . Recipient , error ) {
return scanRecipient ( s . db . QueryRow ( ctx , getRecipientByPNIQuery , s . AccountID , theirUUID ) )
}
type RecipientUpdaterFunc func ( recipient * types . Recipient ) ( changed bool , err error )
func ( s * sqlStore ) mergeRecipients ( ctx context . Context , first , second * types . Recipient , updater RecipientUpdaterFunc ) ( * types . Recipient , error ) {
if first . ACI == uuid . Nil {
first , second = second , first
}
first . PNI = second . PNI
2025-01-30 18:11:06 +02:00
zerolog . Ctx ( ctx ) . Debug ( ) .
Stringer ( "aci" , first . ACI ) .
Stringer ( "pni" , first . PNI ) .
Msg ( "Merging recipient entries in database" )
2024-03-22 21:24:30 +02:00
if second . E164 != "" {
first . E164 = second . E164
}
if first . ContactName == "" {
first . ContactName = second . ContactName
}
if first . ContactAvatar . Hash == "" {
first . ContactAvatar = second . ContactAvatar
}
_ , err := updater ( first )
if err != nil {
return first , fmt . Errorf ( "failed to run updater function: %w" , err )
}
err = s . DeleteRecipientByPNI ( ctx , first . PNI )
if err != nil {
return first , fmt . Errorf ( "failed to delete duplicate PNI row: %w" , err )
}
err = s . StoreRecipient ( ctx , first )
if err != nil {
return first , fmt . Errorf ( "failed to store merged row: %w" , err )
}
return first , nil
}
func ( s * sqlStore ) LoadAndUpdateRecipient ( ctx context . Context , aci , pni uuid . UUID , updater RecipientUpdaterFunc ) ( outRecipient * types . Recipient , outErr error ) {
if aci == uuid . Nil && pni == uuid . Nil {
return nil , fmt . Errorf ( "no ACI or PNI provided in LoadAndUpdateRecipient call" )
}
if updater == nil {
updater = func ( recipient * types . Recipient ) ( bool , error ) {
return false , nil
}
}
2025-11-20 13:29:38 +02:00
defer func ( ) {
if outRecipient != nil && outRecipient . ACI != uuid . Nil && outErr == nil {
s . blockCacheLock . Lock ( )
s . blockCache [ outRecipient . ACI ] = outRecipient . Blocked
s . blockCacheLock . Unlock ( )
}
} ( )
2025-02-04 11:48:31 +02:00
if ctx . Value ( contextKeyContactLock ) == nil {
s . contactLock . Lock ( )
defer s . contactLock . Unlock ( )
}
2024-03-22 21:24:30 +02:00
outErr = s . db . DoTxn ( ctx , nil , func ( ctx context . Context ) error {
var entries [ ] * types . Recipient
var err error
if aci != uuid . Nil && pni != uuid . Nil {
query := getRecipientByACIOrPNIQuery
if s . db . Dialect == dbutil . Postgres {
query += " FOR UPDATE"
}
entries , err = dbutil . ConvertRowFn [ * types . Recipient ] ( scanRecipient ) .
NewRowIter ( s . db . Query ( ctx , query , s . AccountID , aci , pni ) ) .
AsList ( )
} else if aci != uuid . Nil {
var entry * types . Recipient
entry , err = s . LoadRecipientByACI ( ctx , aci )
if entry != nil {
entries = [ ] * types . Recipient { entry }
}
} else if pni != uuid . Nil {
var entry * types . Recipient
entry , err = s . LoadRecipientByPNI ( ctx , pni )
if entry != nil {
entries = [ ] * types . Recipient { entry }
}
} else {
panic ( "impossible case" )
}
if err != nil {
return err
} else if len ( entries ) > 2 {
return fmt . Errorf ( "got more than two recipient rows for ACI %s and PNI %s" , aci , pni )
} else if len ( entries ) < 2 {
if len ( entries ) == 0 {
outRecipient = & types . Recipient {
ACI : aci ,
PNI : pni ,
}
} else {
outRecipient = entries [ 0 ]
}
changed , err := updater ( outRecipient )
if err != nil {
return fmt . Errorf ( "failed to run updater function: %w" , err )
}
2024-06-25 21:28:08 +03:00
// SQL only supports one ON CONFLICT clause, which means StoreRecipient will key on the ACI if it's present.
// If we're adding an ACI to a PNI row, just delete the PNI row first to avoid conflicts on the PNI key.
if outRecipient . PNI != uuid . Nil && outRecipient . ACI == uuid . Nil && aci != uuid . Nil {
2025-01-30 18:11:06 +02:00
zerolog . Ctx ( ctx ) . Debug ( ) .
Stringer ( "aci" , outRecipient . ACI ) .
Stringer ( "pni" , outRecipient . PNI ) .
Msg ( "Deleting old PNI-only row before inserting row with both IDs" )
2024-06-25 21:28:08 +03:00
err = s . DeleteRecipientByPNI ( ctx , outRecipient . PNI )
if err != nil {
return fmt . Errorf ( "failed to delete old PNI row: %w" , err )
}
}
2024-03-22 21:24:30 +02:00
if outRecipient . PNI == uuid . Nil && pni != uuid . Nil {
outRecipient . PNI = pni
changed = true
}
if outRecipient . ACI == uuid . Nil && aci != uuid . Nil {
outRecipient . ACI = aci
changed = true
}
if changed || len ( entries ) == 0 {
2025-01-30 18:11:06 +02:00
zerolog . Ctx ( ctx ) . Trace ( ) .
Stringer ( "aci" , outRecipient . ACI ) .
Stringer ( "pni" , outRecipient . PNI ) .
Msg ( "Saving recipient row" )
2024-03-22 21:24:30 +02:00
err = s . StoreRecipient ( ctx , outRecipient )
if err != nil {
return fmt . Errorf ( "failed to store updated recipient row: %w" , err )
}
}
return nil
} else if outRecipient , err = s . mergeRecipients ( ctx , entries [ 0 ] , entries [ 1 ] , updater ) ; err != nil {
return fmt . Errorf ( "failed to merge recipient rows for ACI %s and PNI %s: %w" , aci , pni , err )
} else {
return nil
}
} )
return
}
2025-11-20 13:29:38 +02:00
func ( s * sqlStore ) IsBlocked ( ctx context . Context , aci uuid . UUID ) ( bool , error ) {
s . blockCacheLock . RLock ( )
cachedVal , ok := s . blockCache [ aci ]
s . blockCacheLock . RUnlock ( )
if ok {
return cachedVal , nil
}
recipient , err := s . LoadAndUpdateRecipient ( ctx , aci , uuid . Nil , nil )
if err != nil {
return false , err
}
return recipient . Blocked , nil
}
2024-03-22 21:24:30 +02:00
func ( s * sqlStore ) UpdateRecipientE164 ( ctx context . Context , aci , pni uuid . UUID , e164 string ) ( * types . Recipient , error ) {
return s . LoadAndUpdateRecipient ( ctx , aci , pni , func ( recipient * types . Recipient ) ( bool , error ) {
if recipient . E164 != e164 {
recipient . E164 = e164
return true , nil
}
return false , nil
} )
}
func ( s * sqlStore ) LoadRecipientByE164 ( ctx context . Context , e164 string ) ( * types . Recipient , error ) {
return scanRecipient ( s . db . QueryRow ( ctx , getRecipientByPhoneQuery , s . AccountID , e164 ) )
}
2024-06-20 17:29:55 +03:00
func ( s * sqlStore ) LoadAllContacts ( ctx context . Context ) ( [ ] * types . Recipient , error ) {
rows , err := s . db . Query ( ctx , getAllRecipientsWithNameOrPhoneQuery , s . AccountID )
return dbutil . NewRowIterWithError ( rows , scanRecipient , err ) . AsList ( )
}
2024-03-22 21:24:30 +02:00
func ( s * sqlStore ) DeleteRecipientByPNI ( ctx context . Context , pni uuid . UUID ) error {
_ , err := s . db . Exec ( ctx , deleteRecipientByPNIQuery , s . AccountID , pni )
return err
}
func nullableUUID ( u uuid . UUID ) uuid . NullUUID {
return uuid . NullUUID { UUID : u , Valid : u != uuid . Nil }
}
func ( s * sqlStore ) StoreRecipient ( ctx context . Context , recipient * types . Recipient ) ( err error ) {
if recipient . ACI != uuid . Nil {
_ , err = s . db . Exec (
ctx ,
upsertACIRecipientQuery ,
s . AccountID ,
recipient . ACI ,
nullableUUID ( recipient . PNI ) ,
recipient . E164 ,
recipient . ContactName ,
recipient . ContactAvatar . Hash ,
2025-09-01 16:40:35 +03:00
recipient . Nickname ,
2024-03-22 21:24:30 +02:00
recipient . Profile . Key . Slice ( ) ,
recipient . Profile . Name ,
recipient . Profile . About ,
recipient . Profile . AboutEmoji ,
recipient . Profile . AvatarPath ,
dbutil . UnixMilliPtr ( recipient . Profile . FetchedAt ) ,
2024-03-25 15:18:25 +02:00
recipient . NeedsPNISignature ,
2025-11-20 13:29:38 +02:00
recipient . Blocked ,
2025-12-11 17:53:48 +02:00
recipient . Whitelisted ,
2024-03-22 21:24:30 +02:00
)
2025-11-20 13:29:38 +02:00
s . blockCacheLock . Lock ( )
s . blockCache [ recipient . ACI ] = recipient . Blocked
s . blockCacheLock . Unlock ( )
2024-03-22 21:24:30 +02:00
} else if recipient . PNI != uuid . Nil {
_ , err = s . db . Exec (
ctx ,
upsertPNIRecipientQuery ,
s . AccountID ,
recipient . PNI ,
recipient . E164 ,
recipient . ContactName ,
recipient . ContactAvatar . Hash ,
)
} else {
err = fmt . Errorf ( "no ACI or PNI provided in StoreRecipient call" )
}
return
}
2025-11-28 16:51:51 +02:00
const (
isUnregisteredQuery = ` SELECT 1 FROM signalmeow_unregistered_users WHERE aci_uuid=$1 `
markUnregisteredQuery = ` INSERT INTO signalmeow_unregistered_users (aci_uuid) VALUES ($1) ON CONFLICT (aci_uuid) DO NOTHING `
markRegisteredQuery = ` DELETE FROM signalmeow_unregistered_users WHERE aci_uuid=$1 `
)
func ( s * sqlStore ) IsUnregistered ( ctx context . Context , serviceID libsignalgo . ServiceID ) ( unregistered bool ) {
if serviceID . Type != libsignalgo . ServiceIDTypeACI {
return
}
_ = s . db . QueryRow ( ctx , isUnregisteredQuery , serviceID . UUID ) . Scan ( & unregistered )
return
}
func ( s * sqlStore ) MarkUnregistered ( ctx context . Context , serviceID libsignalgo . ServiceID , unregistered bool ) {
if serviceID . Type != libsignalgo . ServiceIDTypeACI {
return
}
var err error
if unregistered {
_ , err = s . db . Exec ( ctx , markUnregisteredQuery , serviceID . UUID )
} else {
_ , err = s . db . Exec ( ctx , markRegisteredQuery , serviceID . UUID )
}
if err != nil {
zerolog . Ctx ( ctx ) . Err ( err ) .
Stringer ( "service_id" , serviceID ) .
Bool ( "unregistered" , unregistered ) .
Msg ( "Failed to mark recipient as unregistered" )
}
}