1
0
Fork 0
mirror of https://github.com/mautrix/signal.git synced 2026-05-15 05:36:53 -04:00
mautrix-signal/pkg/signalmeow/web/signalwebsocket.go

716 lines
22 KiB
Go

// mautrix-signal - A Matrix-signal puppeting bridge.
// Copyright (C) 2023 Scott Weber
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package web
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/coder/websocket"
"github.com/rs/zerolog"
"go.mau.fi/util/exsync"
signalpb "go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf"
"go.mau.fi/mautrix-signal/pkg/signalmeow/wspb"
)
// Keepalive tuning used by the per-connection ping loop.
var (
	// WebsocketPingInterval is how often a keepalive ping is sent.
	WebsocketPingInterval = 30 * time.Second
	// WebsocketPingTimeout bounds how long a single ping may take.
	WebsocketPingTimeout = 20 * time.Second
	// WebsocketPingTimeoutLimit is the number of consecutive failed pings after
	// which the websocket is closed.
	WebsocketPingTimeoutLimit = 5
)

// Websocket endpoint paths on the Signal API host.
const (
	// WebsocketProvisioningPath is the provisioning websocket path.
	WebsocketProvisioningPath = "/v1/websocket/provisioning/"
	// WebsocketPath is the main websocket path dialed by the connect loop.
	WebsocketPath = "/v1/websocket/"
)
type (
	// SimpleResponse is what a RequestHandlerFunc returns for a request the
	// server pushed over the websocket. Status is sent back to the server.
	// WriteCallback, when non-nil, is invoked with the time the response write
	// started, once the write has succeeded.
	SimpleResponse struct {
		Status        int
		WriteCallback func(time.Time)
	}

	// RequestHandlerFunc handles one incoming server request. A non-nil
	// *SimpleResponse is queued as the reply; a returned error is only logged
	// and no reply is sent.
	RequestHandlerFunc func(context.Context, *signalpb.WebSocketRequestMessage) (*SimpleResponse, error)
)
// SignalWebsocket is one logical websocket to the Signal API. Connect starts a
// loop that (re)establishes the underlying connection and multiplexes outgoing
// requests, responses to server requests and keepalive pings over it; Close
// tears it down.
type SignalWebsocket struct {
	// ws is the current connection. It is stored once a dial succeeds and is
	// only cleared by Close; it is not reset when a connection drops.
	ws atomic.Pointer[websocket.Conn]
	// basicAuth is embedded in the websocket URL and added to outgoing request
	// headers; may be nil.
	basicAuth *url.Userinfo
	// sendChannel (unbuffered) feeds writeLoop. Closed and set to nil under
	// closeLock when the connect loop exits.
	sendChannel chan SignalWebsocketSendMessage
	// statusChannel carries connection status updates to the caller of
	// Connect. Closed and set to nil when the connect loop exits.
	statusChannel chan SignalWebsocketConnectionStatus
	// closeLock is read-locked by pushOutgoing while it uses sendChannel and
	// write-locked by the connect loop's cleanup while closing the channels.
	closeLock sync.RWMutex
	// closeEvt is set when the connect loop exits; Close waits on it.
	closeEvt *exsync.Event
	// closeCalled is set by Close so the read loop also cancels the reconnect loop.
	closeCalled atomic.Bool
	// cancel cancels the whole connect loop's context.
	cancel atomic.Pointer[context.CancelFunc]
	// cancelConn cancels only the current connection's context (with a cause);
	// used by ForceReconnect.
	cancelConn atomic.Pointer[context.CancelCauseFunc]
}
// NewSignalWebsocket allocates a SignalWebsocket using the given credentials
// (which may be nil). Nothing is dialed until Connect is called.
func NewSignalWebsocket(basicAuth *url.Userinfo) *SignalWebsocket {
	sw := new(SignalWebsocket)
	sw.basicAuth = basicAuth
	sw.sendChannel = make(chan SignalWebsocketSendMessage)
	sw.statusChannel = make(chan SignalWebsocketConnectionStatus)
	sw.closeEvt = exsync.NewEvent()
	return sw
}
// SignalWebsocketConnectionEvent identifies the kind of status update delivered
// on the channel returned by SignalWebsocket.Connect.
type SignalWebsocketConnectionEvent int

const (
	SignalWebsocketConnectionEventConnecting SignalWebsocketConnectionEvent = iota // Implicit to catch default value (0), doesn't get sent
	SignalWebsocketConnectionEventConnected
	SignalWebsocketConnectionEventDisconnected
	SignalWebsocketConnectionEventLoggedOut
	SignalWebsocketConnectionEventError
	SignalWebsocketConnectionEventFatalError
	SignalWebsocketConnectionEventCleanShutdown
)

// mapping from SignalWebsocketConnectionEvent to its string representation
var signalWebsocketConnectionEventNames = map[SignalWebsocketConnectionEvent]string{
	SignalWebsocketConnectionEventConnecting:    "SignalWebsocketConnectionEventConnecting",
	SignalWebsocketConnectionEventConnected:     "SignalWebsocketConnectionEventConnected",
	SignalWebsocketConnectionEventDisconnected:  "SignalWebsocketConnectionEventDisconnected",
	SignalWebsocketConnectionEventLoggedOut:     "SignalWebsocketConnectionEventLoggedOut",
	SignalWebsocketConnectionEventError:         "SignalWebsocketConnectionEventError",
	SignalWebsocketConnectionEventFatalError:    "SignalWebsocketConnectionEventFatalError",
	SignalWebsocketConnectionEventCleanShutdown: "SignalWebsocketConnectionEventCleanShutdown",
}

// String implements fmt.Stringer. Values without a registered name render as
// "SignalWebsocketConnectionEvent(N)" rather than an empty string, so they
// remain visible in logs.
func (s SignalWebsocketConnectionEvent) String() string {
	if name, ok := signalWebsocketConnectionEventNames[s]; ok {
		return name
	}
	// Format the underlying int explicitly so fmt never calls back into String.
	return fmt.Sprintf("SignalWebsocketConnectionEvent(%d)", int(s))
}

// SignalWebsocketConnectionStatus is a single status update: the event kind
// plus the error that caused it, if any.
type SignalWebsocketConnectionStatus struct {
	Event SignalWebsocketConnectionEvent
	Err   error
}
// IsConnected reports whether a websocket connection is currently stored.
//
// NOTE(review): the stored connection is only cleared by Close, not when a
// connection drops, so this can still report true while connectLoop is
// reconnecting.
func (s *SignalWebsocket) IsConnected() bool {
	conn := s.ws.Load()
	return conn != nil
}
// Close shuts the websocket down. It marks the socket as closing, closes the
// current connection with a normal-closure status, and cancels the connect
// loop. It then blocks until the connect loop has exited. Calling it on a nil
// receiver is a no-op. The returned error comes from closing the connection.
//
// NOTE(review): closeEvt is only set by connectLoop's deferred cleanup in this
// file, so Close appears to block forever if Connect was never called. Confirm
// whether callers rely on that never happening.
func (s *SignalWebsocket) Close() (err error) {
	if s == nil {
		return nil
	}
	s.closeCalled.Store(true)

	conn := s.ws.Swap(nil)
	if conn != nil {
		err = conn.Close(websocket.StatusNormalClosure, "")
	}

	cancelLoop := s.cancel.Swap(nil)
	if cancelLoop != nil {
		(*cancelLoop)()
	}

	<-s.closeEvt.GetChan()
	return err
}
// Connect starts the connect loop in the background. It returns the channel on
// which connection status updates are delivered; that channel is closed when
// the loop exits. Cancelling ctx, or calling Close, stops the loop.
func (s *SignalWebsocket) Connect(ctx context.Context, requestHandler RequestHandlerFunc) chan SignalWebsocketConnectionStatus {
	// Grab the channel before starting the loop. connectLoop's cleanup sets
	// s.statusChannel to nil, so reading it after the goroutine starts races
	// with an early exit (e.g. an already-cancelled ctx). That race could hand
	// the caller a nil channel that blocks forever.
	statusChannel := s.statusChannel
	go s.connectLoop(ctx, requestHandler)
	return statusChannel
}
// pushStatus hands a status update to whoever is reading statusChannel. It
// gives up silently when ctx is done. If nobody receives the update within
// 5 seconds, it logs an error and drops the update.
func (s *SignalWebsocket) pushStatus(ctx context.Context, status SignalWebsocketConnectionEvent, err error) {
	update := SignalWebsocketConnectionStatus{Event: status, Err: err}
	deadline := time.After(5 * time.Second)
	select {
	case s.statusChannel <- update:
	case <-ctx.Done():
	case <-deadline:
		zerolog.Ctx(ctx).Error().Msg("Status channel didn't accept status")
	}
}
// pushOutgoing hands a message to writeLoop via the unbuffered sendChannel.
// It fails if ctx is already done, if the channel has been torn down, or if
// the connect loop exits (closeEvt) before the hand-off happens. Holding the
// read lock keeps connectLoop's cleanup from closing sendChannel mid-send.
func (s *SignalWebsocket) pushOutgoing(ctx context.Context, send SignalWebsocketSendMessage) error {
	if err := ctx.Err(); err != nil {
		return err
	}

	s.closeLock.RLock()
	defer s.closeLock.RUnlock()

	queue := s.sendChannel
	if queue == nil {
		return errors.New("connection is not open")
	}

	select {
	case queue <- send:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	case <-s.closeEvt.GetChan():
		return errors.New("connection closed before send could be queued")
	}
}
// ErrForcedReconnect is the cancellation cause used by ForceReconnect. The
// connect loop recognises it and reconnects without its usual delay.
var ErrForcedReconnect = errors.New("forced reconnect")

// ForceReconnect drops the current connection so the connect loop dials a new
// one. It does nothing on a nil receiver or before any connection has been
// set up.
func (s *SignalWebsocket) ForceReconnect() {
	if s == nil {
		return
	}
	if cancelConn := s.cancelConn.Load(); cancelConn != nil {
		(*cancelConn)(ErrForcedReconnect)
	}
}
// connectLoop owns the lifetime of the websocket. It (re)connects with
// backoff, runs the read, write and ping goroutines for each connection, and
// reports progress on statusChannel. A separate goroutine dispatches incoming
// server requests to requestHandler across reconnections.
//
// The loop exits when ctx is cancelled (including via Close) or when the
// server rejects the dial with a non-retryable status: 403, or any other
// non-5xx status. On exit it sets closeEvt and closes incomingRequestChan,
// statusChannel and sendChannel.
func (s *SignalWebsocket) connectLoop(
	ctx context.Context,
	requestHandler RequestHandlerFunc,
) {
	log := zerolog.Ctx(ctx).With().
		Str("loop", "signal_websocket_connect_loop").
		Logger()
	ctx, cancel := context.WithCancel(ctx)
	s.cancel.Store(&cancel)
	incomingRequestChan := make(chan *signalpb.WebSocketRequestMessage, 256)
	defer func() {
		// Set closeEvt first. pushOutgoing callers blocked in their select hold
		// closeLock for reading and only return once it fires (or their own
		// ctx ends). Taking the write lock before that could stall indefinitely.
		s.closeEvt.Set()
		cancel()
		s.closeLock.Lock()
		defer s.closeLock.Unlock()
		close(incomingRequestChan)
		close(s.statusChannel)
		close(s.sendChannel)
		incomingRequestChan = nil
		s.statusChannel = nil
		s.sendChannel = nil
	}()

	const initialBackoff = 10 * time.Second
	const backoffIncrement = 5 * time.Second
	const maxBackoff = 60 * time.Second

	if s.ws.Load() != nil {
		panic("Already connected")
	}

	// First set up request handler loop. This exists outside of the
	// connection loops because we want to maintain it across reconnections
	go func() {
		for {
			select {
			case <-ctx.Done():
				log.Info().Msg("ctx done, stopping request loop")
				return
			case request, ok := <-incomingRequestChan:
				if !ok {
					// Main connection loop must have closed, so we should stop
					log.Info().Msg("incomingRequestChan closed, stopping request loop")
					return
				}
				if request == nil {
					log.Fatal().Msg("Received nil request")
				}
				if requestHandler == nil {
					log.Fatal().Msg("Received request but no handler")
				}

				// Handle the request with the request handler function
				response, err := requestHandler(ctx, request)

				if err != nil {
					log.Err(err).Uint64("request_id", request.GetId()).Msg("Error handling request")
				} else if response != nil {
					err = s.pushOutgoing(ctx, SignalWebsocketSendMessage{
						RequestMessage:  request,
						ResponseMessage: response,
					})
					if err != nil {
						log.Err(err).Uint64("request_id", request.GetId()).Msg("Error queuing response message")
					}
				} else {
					log.Warn().Uint64("request_id", request.GetId()).Msg("Request handler didn't return a response nor an error")
				}
			}
		}
	}()

	// Main connection loop - if there's a problem with anything just
	// kill everything (including the websocket) and build it all up again
	backoff := initialBackoff
	retrying := false
	errorCount := 0
	isFirstConnect := true

	wsURL := (&url.URL{
		Scheme: "wss",
		Host:   APIHostname,
		Path:   WebsocketPath,
		User:   s.basicAuth,
	}).String()

	for {
		if retrying {
			if backoff > maxBackoff {
				backoff = maxBackoff
			}
			log.Warn().Dur("backoff", backoff).Msg("Failed to connect, waiting to retry...")
			select {
			case <-time.After(backoff):
			case <-ctx.Done():
			}
			backoff += backoffIncrement
		} else if !isFirstConnect && s.basicAuth != nil {
			// Authenticated sockets wait between reconnects even after a clean
			// shutdown (forced reconnects reset isFirstConnect to skip this).
			select {
			case <-time.After(initialBackoff):
			case <-ctx.Done():
			}
		}
		if ctx.Err() != nil {
			log.Info().Msg("ctx done, stopping connection loop")
			return
		}
		isFirstConnect = false

		ws, resp, err := OpenWebsocket(ctx, wsURL)
		if resp != nil {
			if resp.StatusCode != 101 {
				// Server didn't want to open websocket
				if resp.StatusCode >= 500 {
					// We can try again if it's a 5xx
					s.pushStatus(ctx, SignalWebsocketConnectionEventDisconnected, fmt.Errorf("5xx opening websocket: %v", resp.Status))
				} else if resp.StatusCode == 403 {
					// We are logged out, so we should stop trying to reconnect
					s.pushStatus(ctx, SignalWebsocketConnectionEventLoggedOut, fmt.Errorf("403 opening websocket, we are logged out"))
					return // NOT RETRYING, KILLING THE CONNECTION LOOP
				} else if resp.StatusCode > 0 && resp.StatusCode < 500 {
					// Unexpected status code
					s.pushStatus(ctx, SignalWebsocketConnectionEventFatalError, fmt.Errorf("unexpected status opening websocket: %v", resp.Status))
					return // NOT RETRYING, KILLING THE CONNECTION LOOP
				} else {
					// Something is very wrong
					s.pushStatus(ctx, SignalWebsocketConnectionEventError, fmt.Errorf("unexpected error opening websocket: %v", resp.Status))
				}
				// Retry the connection
				retrying = true
				continue
			}
		}
		if err != nil {
			// Unexpected error opening websocket
			if backoff < maxBackoff {
				s.pushStatus(ctx, SignalWebsocketConnectionEventDisconnected, fmt.Errorf("transient error opening websocket: %w", err))
			} else {
				s.pushStatus(ctx, SignalWebsocketConnectionEventError, fmt.Errorf("continuing error opening websocket: %w", err))
			}
			retrying = true
			continue
		}

		// Successfully connected. Publish the connection before announcing it,
		// so a status consumer that checks IsConnected sees it.
		s.ws.Store(ws)
		s.pushStatus(ctx, SignalWebsocketConnectionEventConnected, nil)
		retrying = false
		backoff = initialBackoff

		responseChannels := exsync.NewMap[uint64, chan *signalpb.WebSocketResponseMessage]()
		loopCtx, loopCancel := context.WithCancelCause(ctx)
		s.cancelConn.Store(&loopCancel)
		var wg sync.WaitGroup
		wg.Add(3)

		// Read loop (for reading incoming reqeusts and responses to outgoing requests)
		go func() {
			defer wg.Done()
			err := readLoop(loopCtx, ws, incomingRequestChan, responseChannels)
			// Don't want to put an err into loopCancel if we don't have one
			if err != nil {
				err = fmt.Errorf("error in readLoop: %w", err)
			}
			if s.closeCalled.Load() {
				// Exit during Close() so cancel the reconnect loop as well
				cancel()
			}
			loopCancel(err)
			log.Info().Msg("readLoop exited")
		}()

		// Write loop (for sending outgoing requests and responses to incoming requests)
		go func() {
			defer wg.Done()
			err := writeLoop(loopCtx, ws, s.sendChannel, responseChannels)
			// Don't want to put an err into loopCancel if we don't have one
			if err != nil {
				err = fmt.Errorf("error in writeLoop: %w", err)
			}
			loopCancel(err)
			log.Info().Msg("writeLoop exited")
		}()

		// Ping loop (send a keepalive Ping every 30s)
		go func() {
			defer wg.Done()
			ticker := time.NewTicker(WebsocketPingInterval)
			defer ticker.Stop()

			pingTimeoutCount := 0
			for {
				select {
				case <-ticker.C:
					pingCtx, cancel := context.WithTimeout(loopCtx, WebsocketPingTimeout)
					err := ws.Ping(pingCtx)
					cancel()
					if err != nil {
						pingTimeoutCount++
						log.Err(err).Msg("Failed to send ping")
						if pingTimeoutCount >= WebsocketPingTimeoutLimit {
							log.Warn().Msg("Ping timeout count exceeded, closing websocket")
							err = ws.Close(websocket.StatusNormalClosure, "Ping timeout")
							if err != nil {
								log.Err(err).Msg("Error closing websocket after ping timeout")
							}
							return
						}
					} else if pingTimeoutCount > 0 {
						pingTimeoutCount = 0
						log.Debug().Msg("Recovered from ping error")
					} else {
						log.Trace().Msg("Sent keepalive")
					}
				case <-loopCtx.Done():
					return
				}
			}
		}()

		// Wait for read or write or ping loop to exit (which means there was an error)
		log.Debug().Msg("Finished preparing connection, waiting for loop context to finish")
		<-loopCtx.Done()
		ctxCauseErr := context.Cause(loopCtx)
		log.Debug().AnErr("ctx_cause_err", ctxCauseErr).Msg("Read or write loop exited")
		if ctxCauseErr == nil || errors.Is(ctxCauseErr, context.Canceled) {
			s.pushStatus(ctx, SignalWebsocketConnectionEventCleanShutdown, nil)
		} else {
			errorCount++
			s.pushStatus(ctx, SignalWebsocketConnectionEventDisconnected, ctxCauseErr)
			if errors.Is(ctxCauseErr, ErrForcedReconnect) {
				// Skip the delay for forced reconnects
				// TODO should the delay be lowered globally?
				isFirstConnect = true
			}
		}

		// Clean up. Closing the socket unblocks readLoop; cancelling loopCtx
		// stops writeLoop and the ping loop.
		ws.Close(websocket.StatusGoingAway, "Going away")
		loopCancel(nil)
		// Wait for all three goroutines BEFORE failing the outstanding requests.
		// Until writeLoop returns it may still register a new response channel
		// (its select can pick a queued send even after loopCtx is done). A
		// channel registered after SwapData would never be answered or closed,
		// leaving its sender blocked; it could also write into the nil map
		// swapped in below.
		wg.Wait()
		// Closing the channels gives pending sendRequestInternal calls a nil
		// response, which makes them retry on the next connection.
		for _, responseChannel := range responseChannels.SwapData(nil) {
			close(responseChannel)
		}
		log.Debug().Msg("Finished websocket cleanup")

		// NOTE(review): errorCount is never reset, so this also trips after 500
		// disconnects spread over a long uptime. Consider resetting it once a
		// connection has been stable for a while.
		if errorCount > 500 {
			// Something is really wrong, we better panic.
			// This is a last defense against a runaway error loop,
			// like the WS continually closing and reconnecting
			log.Fatal().Int("error_count", errorCount).Msg("Too many errors, panicking")
		}
	}
}
// readLoop reads messages from ws until a read fails, ctx is done, or the
// server sends something malformed.
//   - Server-initiated requests are forwarded to incomingRequestChan.
//   - Responses to our own requests are delivered to the matching channel
//     popped from responseChannels, which is then closed.
//
// It returns nil only when the server closes the socket with a normal-closure
// status. Every other exit returns an error.
func readLoop(
	ctx context.Context,
	ws *websocket.Conn,
	incomingRequestChan chan *signalpb.WebSocketRequestMessage,
	responseChannels *exsync.Map[uint64, chan *signalpb.WebSocketResponseMessage],
) error {
	log := zerolog.Ctx(ctx).With().
		Str("loop", "signal_websocket_read_loop").
		Logger()
	for {
		if ctx.Err() != nil {
			return ctx.Err()
		}
		msg := &signalpb.WebSocketMessage{}
		//ctx, _ := context.WithTimeout(ctx, 10*time.Second) // For testing
		err := wspb.Read(ctx, ws, msg)
		if err != nil {
			if errors.Is(err, context.Canceled) {
				log.Info().Msg("readLoop context canceled")
			} else if websocket.CloseStatus(err) == websocket.StatusNormalClosure {
				log.Info().Msg("readLoop received StatusNormalClosure")
				return nil
			}
			return fmt.Errorf("error reading message: %w", err)
		}

		if msg.Type == nil {
			return errors.New("received message with no type")
		}
		switch *msg.Type {
		case signalpb.WebSocketMessage_REQUEST:
			if msg.Request == nil {
				return errors.New("received request message with no request")
			}
			// Use the nil-safe getters. These arguments are evaluated even when
			// trace logging is disabled, so a dereference of a missing field
			// would panic.
			log.Trace().
				Uint64("request_id", msg.Request.GetId()).
				Str("request_verb", msg.Request.GetVerb()).
				Str("request_path", msg.Request.GetPath()).
				Msg("Received WS request")
			// Don't block forever if the request handler loop has stopped
			// draining the channel while this connection is being torn down.
			select {
			case incomingRequestChan <- msg.Request:
			case <-ctx.Done():
				return ctx.Err()
			}
		case signalpb.WebSocketMessage_RESPONSE:
			// A malformed response from the server ends this connection, like a
			// malformed request does, instead of terminating the process.
			if msg.Response == nil {
				return errors.New("received response message with no response")
			}
			if msg.Response.Id == nil {
				return errors.New("received response message with no id")
			}
			responseChannel, ok := responseChannels.Pop(*msg.Response.Id)
			if !ok {
				log.Warn().
					Uint64("response_id", *msg.Response.Id).
					Msg("Received response with unknown id")
				continue
			}
			logEvt := log.Debug().
				Uint64("response_id", msg.Response.GetId()).
				Uint32("response_status", msg.Response.GetStatus()).
				Str("response_message", msg.Response.GetMessage())
			if log.GetLevel() == zerolog.TraceLevel || len(msg.Response.Body) < 256 {
				logEvt.Strs("response_headers", msg.Response.Headers)
				if json.Valid(msg.Response.Body) {
					logEvt.RawJSON("response_body", msg.Response.Body)
				} else {
					logEvt.Str("response_body", base64.StdEncoding.EncodeToString(msg.Response.Body))
				}
			}
			logEvt.Msg("Received WS response")
			// The popped channel belongs to exactly one waiter and was created
			// with capacity 1 by sendRequestInternal, so this send does not block.
			responseChannel <- msg.Response
			close(responseChannel)
		case signalpb.WebSocketMessage_UNKNOWN:
			return fmt.Errorf("received message with unknown type: %v", *msg.Type)
		default:
			return fmt.Errorf("received message with actually unknown type: %v", *msg.Type)
		}
	}
}
// SignalWebsocketSendMessage is one item queued for writeLoop. It is either an
// outgoing request (RequestMessage + ResponseChannel) or a response to a server
// request (RequestMessage + ResponseMessage). writeLoop returns an error for an
// item matching neither shape.
type SignalWebsocketSendMessage struct {
	// Populate if we're sending a request:
	// RequestTime is when the request was first issued. If it is non-zero,
	// writeLoop logs the elapsed time and refuses (returns an error) once the
	// request is more than a minute old.
	RequestTime time.Time
	// ResponseChannel receives the server's response. It is closed without a
	// value if the connection ends before a response arrives.
	ResponseChannel chan *signalpb.WebSocketResponseMessage
	// Populate if we're sending a response:
	ResponseMessage *SimpleResponse
	// Populate this for request AND response
	// For a request, writeLoop overwrites its Id; for a response, its Id is the
	// server request id being answered.
	RequestMessage *signalpb.WebSocketRequestMessage
}
// writeLoop writes messages from sendChannel onto ws until ctx is done, the
// channel closes, or a write fails.
//   - Outgoing requests get a connection-local id (1, 2, ...). Their response
//     channel is registered in responseChannels before the write.
//   - Responses to server requests are built with CreateWSResponse; the
//     caller's WriteCallback runs after a successful write.
//
// Cancellation with context.Canceled is treated as a clean exit (nil).
func writeLoop(
	ctx context.Context,
	ws *websocket.Conn,
	sendChannel chan SignalWebsocketSendMessage,
	responseChannels *exsync.Map[uint64, chan *signalpb.WebSocketResponseMessage],
) error {
	log := zerolog.Ctx(ctx).With().
		Str("loop", "signal_websocket_write_loop").
		Logger()
	for i := uint64(1); ; i++ {
		select {
		case <-ctx.Done():
			if err := ctx.Err(); err != nil && !errors.Is(err, context.Canceled) {
				return err
			}
			return nil
		case request, ok := <-sendChannel:
			if !ok {
				return errors.New("send channel closed")
			}
			if request.RequestMessage != nil && request.ResponseChannel != nil {
				msgType := signalpb.WebSocketMessage_REQUEST
				message := &signalpb.WebSocketMessage{
					Type:    &msgType,
					Request: request.RequestMessage,
				}
				// Point Id at a per-message copy, not at the loop counter itself.
				// Under pre-Go 1.22 loop semantics every request would share (and
				// see the increments of) a single i.
				id := i
				request.RequestMessage.Id = &id
				responseChannels.Set(id, request.ResponseChannel)
				if !request.RequestTime.IsZero() {
					elapsed := time.Since(request.RequestTime)
					if elapsed > 1*time.Minute {
						return fmt.Errorf("request too old (%v), not sending", elapsed)
					} else if elapsed > 10*time.Second {
						log.Warn().
							Uint64("request_id", id).
							Str("request_verb", request.RequestMessage.GetVerb()).
							Str("request_path", request.RequestMessage.GetPath()).
							Dur("elapsed", elapsed).
							Msg("Sending WS request")
					} else {
						log.Debug().
							Uint64("request_id", id).
							Str("request_verb", request.RequestMessage.GetVerb()).
							Str("request_path", request.RequestMessage.GetPath()).
							Dur("elapsed", elapsed).
							Msg("Sending WS request")
					}
				}
				err := wspb.Write(ctx, ws, message)
				if err != nil {
					if ctxErr := ctx.Err(); ctxErr != nil && !errors.Is(ctxErr, context.Canceled) {
						return ctxErr
					}
					return fmt.Errorf("error writing request message: %w", err)
				}
			} else if request.RequestMessage != nil && request.ResponseMessage != nil {
				// GetId is nil-safe; a server request without an id is answered as id 0
				// instead of panicking here.
				requestID := request.RequestMessage.GetId()
				message := CreateWSResponse(ctx, requestID, request.ResponseMessage.Status)
				log.Debug().
					Uint64("request_id", requestID).
					Int("response_status", request.ResponseMessage.Status).
					Msg("Sending WS response")
				writeStartTime := time.Now()
				err := wspb.Write(ctx, ws, message)
				if err != nil {
					return fmt.Errorf("error writing response message: %w", err)
				}
				if request.ResponseMessage.WriteCallback != nil {
					request.ResponseMessage.WriteCallback(writeStartTime)
				}
			} else {
				return fmt.Errorf("invalid request: %+v", request)
			}
		}
	}
}
// SendRequest sends an HTTP-style request over the websocket and waits for the
// matching response. Header names are lowercased, and each value becomes its
// own "name:value" entry. If body is non-nil and no Content-Type header was
// supplied, "content-type:application/json" is added. A nil receiver returns
// an error.
func (s *SignalWebsocket) SendRequest(
	ctx context.Context,
	method,
	path string,
	body []byte,
	headers http.Header,
) (*signalpb.WebSocketResponseMessage, error) {
	if s == nil {
		return nil, errors.New("websocket is nil")
	}
	// Reserve capacity only. The slice must start empty because entries are
	// appended; a non-zero length would emit blank header strings.
	headerArray := make([]string, 0, len(headers)+1)
	var hasContentType bool
	for key, values := range headers {
		lowerKey := strings.ToLower(key)
		if lowerKey == "content-type" {
			hasContentType = true
		}
		for _, value := range values {
			headerArray = append(headerArray, lowerKey+":"+value)
		}
	}
	if !hasContentType && body != nil {
		headerArray = append(headerArray, "content-type:application/json")
	}
	return s.sendRequestInternal(ctx, &signalpb.WebSocketRequestMessage{
		Verb:    &method,
		Path:    &path,
		Body:    body,
		Headers: headerArray,
	}, time.Now(), 0)
}
// sendRequestInternal queues request and waits for its response. When the
// connection drops before a response arrives, the response channel is closed
// and yields nil; the request is then retried until retryCount reaches 3.
// startTime is the original request time. writeLoop uses it to refuse requests
// older than one minute. Waiting stops as soon as ctx is done.
//
// A nil response to a self-delete (DELETE /v1/devices/...) is returned as
// (nil, nil) instead of being retried.
func (s *SignalWebsocket) sendRequestInternal(
	ctx context.Context,
	request *signalpb.WebSocketRequestMessage,
	startTime time.Time,
	retryCount int,
) (*signalpb.WebSocketResponseMessage, error) {
	// Add the auth header once. Retries resend the same request object, so
	// appending inside the retry path would duplicate it.
	if s.basicAuth != nil {
		// NOTE(review): url.Userinfo.String() yields percent-escaped "user:pass",
		// not the base64 form HTTP Basic auth normally uses (compare
		// CreateWSRequest). Confirm what the server expects here.
		request.Headers = append(request.Headers, "authorization:Basic "+s.basicAuth.String())
	}
	isSelfDelete := request.GetVerb() == http.MethodDelete && strings.HasPrefix(request.GetPath(), "/v1/devices/")

	for ; ; retryCount++ {
		responseChannel := make(chan *signalpb.WebSocketResponseMessage, 1)
		err := s.pushOutgoing(ctx, SignalWebsocketSendMessage{
			RequestMessage:  request,
			ResponseChannel: responseChannel,
			RequestTime:     startTime,
		})
		if err != nil {
			return nil, err
		}

		var response *signalpb.WebSocketResponseMessage
		select {
		case response = <-responseChannel:
		case <-ctx.Done():
			return nil, ctx.Err()
		}
		if response != nil || isSelfDelete {
			return response, nil
		}

		// If out of retries, return error no matter what
		if retryCount >= 3 {
			if ctx.Err() != nil {
				return nil, fmt.Errorf("retried 3 times, giving up: %w", ctx.Err())
			}
			return nil, errors.New("retried 3 times, giving up")
		}
		// If the context was cancelled with a "Took too long" cause, don't retry.
		// ctx.Err() alone is only Canceled/DeadlineExceeded, so check the cause.
		if cause := context.Cause(ctx); cause != nil && strings.Contains(cause.Error(), "Took too long") {
			return nil, ctx.Err()
		}
		zerolog.Ctx(ctx).Warn().Int("retry_count", retryCount).Msg("Received nil response, retrying recursively")
	}
}
// OpenWebsocket dials url with SignalHTTPClient, sending the Signal
// User-Agent and X-Signal-Agent headers. On success it raises the connection's
// read limit to 1 MiB. The *http.Response is passed through from
// websocket.Dial, so callers can inspect its status code when the dial fails.
func OpenWebsocket(ctx context.Context, url string) (*websocket.Conn, *http.Response, error) {
	header := http.Header{}
	header.Set("User-Agent", UserAgent)
	header.Set("X-Signal-Agent", SignalAgent)

	conn, resp, err := websocket.Dial(ctx, url, &websocket.DialOptions{
		HTTPClient: SignalHTTPClient,
		HTTPHeader: header,
	})
	if conn != nil {
		// Increase read limit to 1MB from default of 32KB
		conn.SetReadLimit(1 << 20)
	}
	return conn, resp, err
}
// CreateWSResponse builds a RESPONSE websocket message that answers the
// server request with the given id.
//   - 200 is sent with message "OK" and 400 with "Unknown".
//   - Any other valid HTTP status (100-599) uses its standard reason phrase,
//     or "Unknown" if it has none.
//   - Out-of-range status codes are a programmer error and still abort.
func CreateWSResponse(ctx context.Context, id uint64, status int) *signalpb.WebSocketMessage {
	var message string
	switch {
	case status == 200:
		message = "OK"
	case status == 400:
		message = "Unknown"
	case status >= 100 && status <= 599:
		message = http.StatusText(status)
		if message == "" {
			message = "Unknown"
		}
	default:
		// A status outside the HTTP range cannot be represented meaningfully
		// (and a negative value would wrap when converted to uint32).
		zerolog.Ctx(ctx).Fatal().Int("status", status).Msg("Error creating response, invalid status code")
		return nil
	}

	msgType := signalpb.WebSocketMessage_RESPONSE
	status32 := uint32(status)
	return &signalpb.WebSocketMessage{
		Type: &msgType,
		Response: &signalpb.WebSocketResponseMessage{
			Id:      &id,
			Message: &message,
			Status:  &status32,
			Headers: []string{},
		},
	}
}
// CreateWSRequest builds a websocket request message with a JSON content-type
// header. When both username and password are non-nil, it also adds an HTTP
// Basic authorization header containing base64("username:password").
func CreateWSRequest(method string, path string, body []byte, username *string, password *string) *signalpb.WebSocketRequestMessage {
	headers := []string{"content-type:application/json; charset=utf-8"}
	if username != nil && password != nil {
		credentials := base64.StdEncoding.EncodeToString([]byte(*username + ":" + *password))
		headers = append(headers, "authorization:Basic "+credentials)
	}
	return &signalpb.WebSocketRequestMessage{
		Verb:    &method,
		Path:    &path,
		Body:    body,
		Headers: headers,
	}
}