2023-12-17 15:54:35 +02:00
// mautrix-signal - A Matrix-signal puppeting bridge.
// Copyright (C) 2023 Scott Weber
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
2024-01-05 13:44:41 +02:00
package store
2023-05-31 16:39:09 -04:00
import (
"context"
"database/sql"
"errors"
2024-01-04 01:06:45 +02:00
"fmt"
2023-05-31 16:39:09 -04:00
2024-01-04 01:06:45 +02:00
"go.mau.fi/util/dbutil"
2023-05-31 16:39:09 -04:00
"go.mau.fi/mautrix-signal/pkg/libsignalgo"
)
2024-03-19 19:15:30 +02:00
// Compile-time assertion that *scopedSQLStore satisfies SessionStore.
var _ SessionStore = (*scopedSQLStore)(nil)

// SQL statements used by scopedSQLStore's session methods. Unless noted
// otherwise, each statement is filtered by our account ID ($1) and our
// service ID ($2).
const (
	// loadSessionQuery selects the single session for one remote service ID
	// ($3) and device ID ($4).
	loadSessionQuery = `SELECT their_service_id, their_device_id, record FROM signalmeow_sessions WHERE account_id=$1 AND service_id=$2 AND their_service_id=$3 AND their_device_id=$4`
	// allSessionsQuery selects the sessions for every stored device of one
	// remote service ID ($3).
	allSessionsQuery = `SELECT their_service_id, their_device_id, record FROM signalmeow_sessions WHERE account_id=$1 AND service_id=$2 AND their_service_id=$3`
	// storeSessionQuery upserts a serialized record ($5). On a key conflict it
	// overwrites the existing record.
	storeSessionQuery = `
		INSERT INTO signalmeow_sessions (account_id, service_id, their_service_id, their_device_id, record)
		VALUES ($1, $2, $3, $4, $5)
		ON CONFLICT (account_id, service_id, their_service_id, their_device_id) DO UPDATE SET record=excluded.record
	`
	// removeSessionQuery deletes the single session for one remote service ID
	// ($3) and device ID ($4).
	removeSessionQuery = `DELETE FROM signalmeow_sessions WHERE account_id=$1 AND service_id=$2 AND their_service_id=$3 AND their_device_id=$4`
	// removeSessionsForRecipientQuery deletes every session with a remote
	// service ID ($2). It filters by account only, NOT by our service_id.
	removeSessionsForRecipientQuery = `DELETE FROM signalmeow_sessions WHERE account_id=$1 AND their_service_id=$2`
	// deleteAllSessionsQuery deletes every session of the account, regardless
	// of service_id.
	deleteAllSessionsQuery = `DELETE FROM signalmeow_sessions WHERE account_id=$1`
)

type SessionAddressTuple = libsignalgo . SessionAddressTuple
2024-03-19 19:15:30 +02:00
type SessionStore interface {
libsignalgo . SessionStore
2024-03-20 18:43:32 +02:00
ServiceScopedStore
2024-03-15 15:30:44 +02:00
// AllSessionsForServiceID returns all sessions for the given service ID.
2025-11-27 16:53:59 +02:00
AllSessionsForServiceID ( ctx context . Context , theirID libsignalgo . ServiceID ) ( [ ] SessionAddressTuple , error )
2023-05-31 16:39:09 -04:00
// RemoveSession removes the session for the given address.
2024-01-04 01:06:45 +02:00
RemoveSession ( ctx context . Context , address * libsignalgo . Address ) error
2025-11-27 16:53:59 +02:00
RemoveAllSessionsForServiceID ( ctx context . Context , theirID libsignalgo . ServiceID ) error
2023-09-20 14:28:09 -04:00
// RemoveAllSessions removes all sessions for our ACI UUID
RemoveAllSessions ( ctx context . Context ) error
2023-05-31 16:39:09 -04:00
}
2025-11-27 16:53:59 +02:00
func scanSessionRecord ( row dbutil . Scannable ) ( tuple SessionAddressTuple , err error ) {
var rawServiceID string
var rawRecord [ ] byte
err = row . Scan ( & rawServiceID , & tuple . DeviceID , & rawRecord )
2023-05-31 16:39:09 -04:00
if errors . Is ( err , sql . ErrNoRows ) {
2025-11-27 16:53:59 +02:00
err = nil
2023-05-31 16:39:09 -04:00
} else if err != nil {
2025-11-27 16:53:59 +02:00
// return error as-is
} else if tuple . Record , err = libsignalgo . DeserializeSessionRecord ( rawRecord ) ; err != nil {
err = fmt . Errorf ( "failed to deserialize session record: %w" , err )
} else if tuple . ServiceID , err = libsignalgo . ServiceIDFromString ( rawServiceID ) ; err != nil {
err = fmt . Errorf ( "failed to parse service ID: %w" , err )
} else if tuple . Address , err = tuple . ServiceID . Address ( uint ( tuple . DeviceID ) ) ; err != nil {
err = fmt . Errorf ( "failed to construct address: %w" , err )
2023-05-31 16:39:09 -04:00
}
2025-11-27 16:53:59 +02:00
return
2023-05-31 16:39:09 -04:00
}
2024-03-19 19:15:30 +02:00
func ( s * scopedSQLStore ) RemoveSession ( ctx context . Context , address * libsignalgo . Address ) error {
2024-03-15 15:30:44 +02:00
theirServiceID , err := address . Name ( )
2023-05-31 16:39:09 -04:00
if err != nil {
2024-03-15 15:30:44 +02:00
return fmt . Errorf ( "failed to get their service ID: %w" , err )
2023-05-31 16:39:09 -04:00
}
2024-01-04 01:06:45 +02:00
deviceID , err := address . DeviceID ( )
2023-05-31 16:39:09 -04:00
if err != nil {
2024-01-04 01:06:45 +02:00
return fmt . Errorf ( "failed to get their device ID: %w" , err )
2023-05-31 16:39:09 -04:00
}
2024-03-19 19:15:30 +02:00
_ , err = s . db . Exec ( ctx , removeSessionQuery , s . AccountID , s . ServiceID , theirServiceID , deviceID )
2023-05-31 16:39:09 -04:00
return err
}
2025-11-27 16:53:59 +02:00
func ( s * scopedSQLStore ) AllSessionsForServiceID ( ctx context . Context , theirID libsignalgo . ServiceID ) ( [ ] SessionAddressTuple , error ) {
2024-03-19 19:15:30 +02:00
rows , err := s . db . Query ( ctx , allSessionsQuery , s . AccountID , s . ServiceID , theirID )
2023-05-31 16:39:09 -04:00
if err != nil {
2025-11-27 16:53:59 +02:00
return nil , err
2023-05-31 16:39:09 -04:00
}
2025-11-27 16:53:59 +02:00
return dbutil . NewRowIterWithError ( rows , scanSessionRecord , err ) . AsList ( )
}
func ( s * scopedSQLStore ) RemoveAllSessionsForServiceID ( ctx context . Context , theirID libsignalgo . ServiceID ) error {
_ , err := s . db . Exec ( ctx , removeSessionsForRecipientQuery , s . AccountID , theirID )
return err
2023-05-31 16:39:09 -04:00
}
2024-03-19 19:15:30 +02:00
func ( s * scopedSQLStore ) LoadSession ( ctx context . Context , address * libsignalgo . Address ) ( * libsignalgo . SessionRecord , error ) {
2024-03-15 15:30:44 +02:00
theirServiceID , err := address . Name ( )
2023-05-31 16:39:09 -04:00
if err != nil {
2024-03-15 15:30:44 +02:00
return nil , fmt . Errorf ( "failed to get their service ID: %w" , err )
2023-05-31 16:39:09 -04:00
}
2024-01-04 01:06:45 +02:00
deviceID , err := address . DeviceID ( )
2023-05-31 16:39:09 -04:00
if err != nil {
2024-01-04 01:06:45 +02:00
return nil , fmt . Errorf ( "failed to get their device ID: %w" , err )
2023-05-31 16:39:09 -04:00
}
2025-11-27 16:53:59 +02:00
tuple , err := scanSessionRecord ( s . db . QueryRow ( ctx , loadSessionQuery , s . AccountID , s . ServiceID , theirServiceID , deviceID ) )
return tuple . Record , err
2023-05-31 16:39:09 -04:00
}
2024-03-19 19:15:30 +02:00
func ( s * scopedSQLStore ) StoreSession ( ctx context . Context , address * libsignalgo . Address , record * libsignalgo . SessionRecord ) error {
2024-03-15 15:30:44 +02:00
theirServiceID , err := address . Name ( )
2023-05-31 16:39:09 -04:00
if err != nil {
2024-03-15 15:30:44 +02:00
return fmt . Errorf ( "failed to get their service ID: %w" , err )
2023-05-31 16:39:09 -04:00
}
2024-01-04 01:06:45 +02:00
deviceID , err := address . DeviceID ( )
2023-05-31 16:39:09 -04:00
if err != nil {
2024-01-04 01:06:45 +02:00
return fmt . Errorf ( "failed to get their device ID: %w" , err )
2023-05-31 16:39:09 -04:00
}
serialized , err := record . Serialize ( )
if err != nil {
2024-01-04 01:06:45 +02:00
return fmt . Errorf ( "failed to serialize session record: %w" , err )
2023-05-31 16:39:09 -04:00
}
2024-03-19 19:15:30 +02:00
_ , err = s . db . Exec ( ctx , storeSessionQuery , s . AccountID , s . ServiceID , theirServiceID , deviceID , serialized )
2023-05-31 16:39:09 -04:00
return err
}
2023-09-20 14:28:09 -04:00
2024-03-19 19:15:30 +02:00
func ( s * scopedSQLStore ) RemoveAllSessions ( ctx context . Context ) error {
_ , err := s . db . Exec ( ctx , deleteAllSessionsQuery , s . AccountID )
2023-09-20 14:28:09 -04:00
return err
}