1
0
Fork 0
mirror of https://github.com/mautrix/whatsapp.git synced 2026-05-16 10:33:14 -04:00
mautrix-whatsapp/pkg/connector/wadb/polloption.go

120 lines
3.5 KiB
Go
Raw Permalink Normal View History

package wadb
import (
"context"
"fmt"
"strings"
"github.com/lib/pq"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/bridgev2/networkid"
"maunium.net/go/mautrix/id"
)
// PollOptionQuery runs the whatsapp_poll_option_id queries defined in this
// file on behalf of a single bridge.
type PollOptionQuery struct {
// BridgeID is sent as the bridge_id parameter ($1) of every query issued
// through this type.
BridgeID networkid.BridgeID
// Database is embedded so its Exec/Query methods and Dialect field can be
// used directly on the PollOptionQuery.
*dbutil.Database
}
// pollOption is one (opt_id, opt_hash) pair as stored in the
// whatsapp_poll_option_id table.
type pollOption struct {
// id is the value of the opt_id column.
id string
// hash is the value of the opt_hash column; scanPollOption rejects rows
// whose stored hash is not exactly 32 bytes long.
hash [32]byte
}
const (
// putPollOptionBaseQuery inserts a single poll option row. It is the base
// query handed to the mass insert builder (pollOptionInserter).
putPollOptionBaseQuery = `
INSERT INTO whatsapp_poll_option_id (bridge_id, msg_mxid, opt_id, opt_hash)
VALUES ($1, $2, $3, $4)
`
// getPollOptionIDsByHashesQuery selects the options of one message whose
// hash is in the array passed as $3.
getPollOptionIDsByHashesQuery = "SELECT opt_id, opt_hash FROM whatsapp_poll_option_id WHERE bridge_id=$1 AND msg_mxid=$2 AND opt_hash = ANY($3)"
// getPollOptionHashesByIDsQuery selects the options of one message whose
// ID is in the array passed as $3.
getPollOptionHashesByIDsQuery = "SELECT opt_id, opt_hash FROM whatsapp_poll_option_id WHERE bridge_id=$1 AND msg_mxid=$2 AND opt_id = ANY($3)"
// getPollOptionQuerySQLiteArrayTemplate is the fmt template that replaces
// the array placeholder on non-Postgres dialects; %s receives a
// comma-separated list of "?" placeholders.
getPollOptionQuerySQLiteArrayTemplate = " IN (%s)"
// getPollOptionQueryArrayPlaceholder is the exact substring of the two
// select queries above that getPollOptions swaps out on non-Postgres
// dialects. init panics if a query does not contain it.
getPollOptionQueryArrayPlaceholder = " = ANY($3)"
)
// init verifies at startup that both array select queries contain
// getPollOptionQueryArrayPlaceholder.
//
// getPollOptions rewrites the queries for non-Postgres dialects by textually
// replacing that placeholder, so a query that lacks it would be sent
// unmodified and fail at runtime. Panicking here surfaces such a programming
// error immediately instead.
func init() {
	arrayQueries := []string{
		getPollOptionIDsByHashesQuery,
		getPollOptionHashesByIDsQuery,
	}
	for _, query := range arrayQueries {
		// Each query is checked against itself; comparing against a different
		// query would make the check pass unconditionally.
		if !strings.Contains(query, getPollOptionQueryArrayPlaceholder) {
			panic("Array select query placeholder not found")
		}
	}
}
// pollOptionInserter is used by Put to build the insert query and argument
// list for a batch of poll options. $1 and $2 are the two static parameters
// shared by every row (Put passes the bridge ID and the event ID); the %d
// verbs are presumably filled in by dbutil with the per-row parameter indexes
// for the values returned by GetMassInsertValues.
var pollOptionInserter = dbutil.NewMassInsertBuilder[*pollOption, [2]any](
putPollOptionBaseQuery, "($1, $2, $%d, $%d)",
)
// Put stores the given poll options for the message identified by mxid.
//
// opts maps each option hash to its option ID. When opts is empty nothing is
// executed and nil is returned. Otherwise all options are written with a
// single query built by pollOptionInserter, and the error from executing that
// query is returned as-is.
func (poq *PollOptionQuery) Put(ctx context.Context, mxid id.EventID, opts map[[32]byte]string) error {
	if len(opts) == 0 {
		return nil
	}

	rows := make([]*pollOption, 0, len(opts))
	for optHash, optID := range opts {
		rows = append(rows, &pollOption{id: optID, hash: optHash})
	}

	// The bridge ID and event ID are identical for every inserted row.
	staticParams := [2]any{poq.BridgeID, mxid}
	query, args := pollOptionInserter.Build(staticParams, rows)

	_, err := poq.Exec(ctx, query, args...)
	return err
}
// GetIDs looks up the stored options of the message mxid whose hashes are in
// hashes and returns them as a map from option hash to option ID.
func (poq *PollOptionQuery) GetIDs(ctx context.Context, mxid id.EventID, hashes [][]byte) (map[[32]byte]string, error) {
	// Key the result by hash, with the option ID as the value.
	hashToID := func(opt *pollOption) ([32]byte, string) {
		return opt.hash, opt.id
	}
	return getPollOptions(poq, ctx, mxid, getPollOptionIDsByHashesQuery, hashes, hashToID)
}
// GetHashes looks up the stored options of the message mxid whose IDs are in
// ids and returns them as a map from option ID to option hash.
func (poq *PollOptionQuery) GetHashes(ctx context.Context, mxid id.EventID, ids []string) (map[string][32]byte, error) {
	// Key the result by option ID, with the hash as the value.
	idToHash := func(opt *pollOption) (string, [32]byte) {
		return opt.id, opt.hash
	}
	return getPollOptions(poq, ctx, mxid, getPollOptionHashesByIDsQuery, ids, idToHash)
}
// getPollOptions runs one of the array select queries for the message mxid
// and collects the resulting rows into a map.
//
// query must contain getPollOptionQueryArrayPlaceholder. On Postgres, things
// is passed as a single array argument ($3). On any other dialect the
// placeholder is replaced with an " IN (?,?,...)" list and each element of
// things is passed as its own argument. getKeyValue chooses which field of
// each scanned row becomes the map key and which becomes the value.
func getPollOptions[LookupKey any, Key comparable, Value any](
	poq *PollOptionQuery,
	ctx context.Context,
	mxid id.EventID,
	query string,
	things []LookupKey,
	getKeyValue func(option *pollOption) (Key, Value),
) (map[Key]Value, error) {
	// The bridge ID and event ID are always the first two arguments.
	args := make([]any, 0, len(things)+2)
	args = append(args, poq.BridgeID, mxid)

	if poq.Dialect == dbutil.Postgres {
		args = append(args, pq.Array(things))
	} else {
		// One "?" placeholder and one argument per lookup value.
		placeholders := make([]string, len(things))
		for i, thing := range things {
			placeholders[i] = "?"
			args = append(args, thing)
		}
		inClause := fmt.Sprintf(getPollOptionQuerySQLiteArrayTemplate, strings.Join(placeholders, ","))
		query = strings.ReplaceAll(query, getPollOptionQueryArrayPlaceholder, inClause)
	}

	scanRow := dbutil.ConvertRowFn[*pollOption](scanPollOption)
	iter := scanRow.NewRowIter(poq.Query(ctx, query, args...))
	return dbutil.RowIterAsMap(iter, getKeyValue)
}
// scanPollOption reads one (opt_id, opt_hash) row into a pollOption.
//
// It returns the scan error unchanged if scanning fails, and an
// "invalid hash length" error if the stored hash is not exactly 32 bytes.
func scanPollOption(row dbutil.Scannable) (*pollOption, error) {
	var (
		optID   string
		rawHash []byte
	)
	if err := row.Scan(&optID, &rawHash); err != nil {
		return nil, err
	}
	if len(rawHash) != 32 {
		return nil, fmt.Errorf("invalid hash length: %d", len(rawHash))
	}
	// The length check above guarantees the slice-to-array conversion is safe.
	opt := &pollOption{id: optID, hash: [32]byte(rawHash)}
	return opt, nil
}
// GetMassInsertValues returns the per-row insert values for this option: the
// option ID followed by the hash as a byte slice.
func (po *pollOption) GetMassInsertValues() [2]any {
	var values [2]any
	values[0] = po.id
	// Slice the array so the hash is passed as []byte, not [32]byte.
	values[1] = po.hash[:]
	return values
}