mirror of
https://github.com/mautrix/whatsapp.git
synced 2026-05-16 02:26:53 -04:00
120 lines
3.5 KiB
Go
120 lines
3.5 KiB
Go
package wadb
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/lib/pq"
|
|
"go.mau.fi/util/dbutil"
|
|
"maunium.net/go/mautrix/bridgev2/networkid"
|
|
"maunium.net/go/mautrix/id"
|
|
)
|
|
|
|
type PollOptionQuery struct {
|
|
BridgeID networkid.BridgeID
|
|
*dbutil.Database
|
|
}
|
|
|
|
// pollOption is one row of whatsapp_poll_option_id: a poll option ID paired
// with its fixed-size 32-byte hash.
type pollOption struct {
	// id is the value of the opt_id column.
	id string
	// hash is the value of the opt_hash column.
	hash [32]byte
}
|
|
|
|
const (
	// putPollOptionBaseQuery is the single-row INSERT statement handed to
	// dbutil.NewMassInsertBuilder. $1 and $2 are the bridge ID and the event
	// ID, $3 and $4 are the option ID and the option hash.
	putPollOptionBaseQuery = `
		INSERT INTO whatsapp_poll_option_id (bridge_id, msg_mxid, opt_id, opt_hash)
		VALUES ($1, $2, $3, $4)
	`

	// getPollOptionQueryArrayPlaceholder is the array comparison used by the
	// select queries below. On non-Postgres dialects getPollOptions replaces
	// it with getPollOptionQuerySQLiteArrayTemplate.
	getPollOptionQueryArrayPlaceholder = " = ANY($3)"
	// getPollOptionQuerySQLiteArrayTemplate is the replacement for the array
	// placeholder; %s receives a comma-separated list of "?" markers.
	getPollOptionQuerySQLiteArrayTemplate = " IN (%s)"

	// Both select queries are composed from the placeholder constant rather
	// than repeating its text, so the substring that getPollOptions searches
	// for is guaranteed to be present.

	// getPollOptionIDsByHashesQuery selects the rows of one message whose
	// opt_hash is in the given list.
	getPollOptionIDsByHashesQuery = "SELECT opt_id, opt_hash FROM whatsapp_poll_option_id WHERE bridge_id=$1 AND msg_mxid=$2 AND opt_hash" + getPollOptionQueryArrayPlaceholder
	// getPollOptionHashesByIDsQuery selects the rows of one message whose
	// opt_id is in the given list.
	getPollOptionHashesByIDsQuery = "SELECT opt_id, opt_hash FROM whatsapp_poll_option_id WHERE bridge_id=$1 AND msg_mxid=$2 AND opt_id" + getPollOptionQueryArrayPlaceholder
)
|
|
|
|
func init() {
|
|
if strings.ReplaceAll(getPollOptionIDsByHashesQuery, getPollOptionQueryArrayPlaceholder, "meow") == getPollOptionIDsByHashesQuery {
|
|
panic("Array select query placeholder not found")
|
|
}
|
|
if strings.ReplaceAll(getPollOptionHashesByIDsQuery, getPollOptionQueryArrayPlaceholder, "meow") == getPollOptionIDsByHashesQuery {
|
|
panic("Array select query placeholder not found")
|
|
}
|
|
}
|
|
|
|
var pollOptionInserter = dbutil.NewMassInsertBuilder[*pollOption, [2]any](
|
|
putPollOptionBaseQuery, "($1, $2, $%d, $%d)",
|
|
)
|
|
|
|
func (poq *PollOptionQuery) Put(ctx context.Context, mxid id.EventID, opts map[[32]byte]string) error {
|
|
if len(opts) == 0 {
|
|
return nil
|
|
}
|
|
optArray := make([]*pollOption, len(opts))
|
|
i := 0
|
|
for hash, optID := range opts {
|
|
optArray[i] = &pollOption{id: optID, hash: hash}
|
|
i++
|
|
}
|
|
query, args := pollOptionInserter.Build([2]any{poq.BridgeID, mxid}, optArray)
|
|
_, err := poq.Exec(ctx, query, args...)
|
|
return err
|
|
}
|
|
|
|
func (poq *PollOptionQuery) GetIDs(ctx context.Context, mxid id.EventID, hashes [][]byte) (map[[32]byte]string, error) {
|
|
return getPollOptions(
|
|
poq, ctx, mxid, getPollOptionIDsByHashesQuery, hashes,
|
|
func(t *pollOption) ([32]byte, string) { return t.hash, t.id },
|
|
)
|
|
}
|
|
|
|
func (poq *PollOptionQuery) GetHashes(ctx context.Context, mxid id.EventID, ids []string) (map[string][32]byte, error) {
|
|
return getPollOptions(
|
|
poq, ctx, mxid, getPollOptionHashesByIDsQuery, ids,
|
|
func(t *pollOption) (string, [32]byte) { return t.id, t.hash },
|
|
)
|
|
}
|
|
|
|
func getPollOptions[LookupKey any, Key comparable, Value any](
|
|
poq *PollOptionQuery,
|
|
ctx context.Context,
|
|
mxid id.EventID,
|
|
query string,
|
|
things []LookupKey,
|
|
getKeyValue func(option *pollOption) (Key, Value),
|
|
) (map[Key]Value, error) {
|
|
var args []any
|
|
if poq.Dialect == dbutil.Postgres {
|
|
args = []any{poq.BridgeID, mxid, pq.Array(things)}
|
|
} else {
|
|
query = strings.ReplaceAll(query, getPollOptionQueryArrayPlaceholder, fmt.Sprintf(getPollOptionQuerySQLiteArrayTemplate, strings.TrimSuffix(strings.Repeat("?,", len(things)), ",")))
|
|
args = make([]any, len(things)+2)
|
|
args[0] = poq.BridgeID
|
|
args[1] = mxid
|
|
for i, thing := range things {
|
|
args[i+2] = thing
|
|
}
|
|
}
|
|
return dbutil.RowIterAsMap(
|
|
dbutil.ConvertRowFn[*pollOption](scanPollOption).NewRowIter(poq.Query(ctx, query, args...)),
|
|
getKeyValue,
|
|
)
|
|
}
|
|
|
|
func scanPollOption(row dbutil.Scannable) (*pollOption, error) {
|
|
var optionID string
|
|
var hash []byte
|
|
err := row.Scan(&optionID, &hash)
|
|
if err != nil {
|
|
return nil, err
|
|
} else if len(hash) != 32 {
|
|
return nil, fmt.Errorf("invalid hash length: %d", len(hash))
|
|
}
|
|
return &pollOption{
|
|
id: optionID,
|
|
hash: [32]byte(hash),
|
|
}, nil
|
|
}
|
|
|
|
func (po *pollOption) GetMassInsertValues() [2]any {
|
|
return [2]any{po.id, po.hash[:]}
|
|
}
|