mirror of
https://github.com/mjl-/mox.git
synced 2025-01-01 03:13:48 +03:00
1214 lines
32 KiB
Go
1214 lines
32 KiB
Go
package bstore
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
)
|
|
|
|
// The convention for handling errors on a Query: methods that return a bool
// will have set q.err using q.error(), which does cleanup. If a method returns
// an error, q.error was not yet called, but usually will be called immediately
// after with the returned err, possibly first adding context.
|
|
|
|
// Query selects data for Go struct T based on filters, sorting, limits. The
|
|
// query is completed by calling an operation, such as Count, Get, List,
|
|
// Update, Delete, etc.
|
|
//
|
|
// Record selection functions like FilterEqual and Limit return the (modified)
|
|
// query itself, allowing chaining of calls.
|
|
//
|
|
// Queries are automatically closed after their operation, with two exceptions:
|
|
// After using Next and NextID on a query that did not yet return a non-nil
|
|
// error, you must call Close.
|
|
//
|
|
// A Query is not safe for concurrent use.
|
|
type Query[T any] struct {
|
|
ctx context.Context
|
|
ctxDone <-chan struct{} // ctx.Done(), kept here for fast access.
|
|
st storeType // Of T.
|
|
pkType reflect.Type // Shortcut for st.Current.Fields[0].
|
|
xtx *Tx // If nil, a new transaction is automatically created from db. Using a tx goes through tx() one exists.
|
|
xdb *DB // If not nil, xtx was created to execute the operation and is when the operation finishes (also on error).
|
|
err error // If set, returned by operations. For indicating failed filters, or that an operation has finished.
|
|
xfilterIDs *filterIDs[T] // Kept separately from filters because these filters make us use the PK without further index planning.
|
|
xfilters []filter[T]
|
|
xorders []order
|
|
|
|
// If 0, no limit has been set. Otherwise the number of IDs or records to
|
|
// read. Set by limit calls, and set to 1 for an Exists.
|
|
xlimit int
|
|
|
|
// Set when Next/NextID is called. We prevent non-Next calls after that moment.
|
|
// Should prevent unexpected results for users.
|
|
nextOnly bool
|
|
|
|
gatherIDs reflect.Value // Pointer to slice to pktype, for PKs of updated/deleted records.
|
|
gathers *[]T // For full records of updated/deleted records.
|
|
|
|
exec *exec[T]
|
|
|
|
stats Stats
|
|
}
|
|
|
|
// filter is a sentinel interface, implemented by all filter types, for sanity
// checking.
type filter[T any] interface {
	filter()
}
|
|
|
|
// filter by one or more IDs.
|
|
type filterIDs[T any] struct {
|
|
ids map[any]struct{} // Regular values.
|
|
pks [][]byte // Packed values.
|
|
}
|
|
|
|
func (filterIDs[T]) filter() {}
|
|
|
|
type filterFn[T any] struct {
|
|
fn func(value T) bool // Function to call for each record, returning true if the record is selected.
|
|
}
|
|
|
|
func (filterFn[T]) filter() {}
|
|
|
|
// Filter on field equality.
|
|
type filterEqual[T any] struct {
|
|
field field
|
|
rvalue reflect.Value
|
|
}
|
|
|
|
func (filterEqual[T]) filter() {}
|
|
|
|
// Filter on field non-equality.
|
|
type filterNotEqual[T any] struct {
|
|
field field
|
|
rvalue reflect.Value
|
|
}
|
|
|
|
func (filterNotEqual[T]) filter() {}
|
|
|
|
// Like filterEqual, but for one or more values.
|
|
type filterIn[T any] struct {
|
|
field field
|
|
rvalues []reflect.Value
|
|
}
|
|
|
|
func (filterIn[T]) filter() {}
|
|
|
|
// Like filterNonEqual, but for one or more values.
|
|
type filterNotIn[T any] struct {
|
|
field field
|
|
rvalues []reflect.Value
|
|
}
|
|
|
|
func (filterNotIn[T]) filter() {}
|
|
|
|
// For matching one of the values in a field that is a slice of the same type.
|
|
type filterInSlice[T any] struct {
|
|
field field // Of field type, a slice.
|
|
rvalue reflect.Value
|
|
}
|
|
|
|
func (filterInSlice[T]) filter() {}
|
|
|
|
// compareOp is the comparison operator of a filterCompare.
type compareOp byte

// Comparison operators, used by FilterGreater, FilterGreaterEqual, FilterLess
// and FilterLessEqual.
const (
	opGreater      compareOp = 0 // >
	opGreaterEqual compareOp = 1 // >=
	opLess         compareOp = 2 // <
	opLessEqual    compareOp = 3 // <=
)
|
|
|
|
// filter by comparison.
|
|
type filterCompare[T any] struct {
|
|
field field
|
|
op compareOp
|
|
value reflect.Value
|
|
}
|
|
|
|
func (filterCompare[T]) filter() {}
|
|
|
|
// ordering of result.
|
|
type order struct {
|
|
field field
|
|
asc bool
|
|
}
|
|
|
|
// pair holds the key of a record, with its data and parsed value loaded
// lazily. Callers that only need IDs then never parse a full record, and when
// going through an index the full record doesn't have to be fetched at all.
type pair[T any] struct {
	bk    []byte // Key of the record.
	bv    []byte // Record data; nil until loaded.
	value *T     // Parsed form of bv; nil until parsed.
}
|
|
|
|
// Value returns a fully parsed record. It first fetches the record data if not
|
|
// yet present.
|
|
func (p *pair[T]) Value(e *exec[T]) (T, error) {
|
|
var zero T
|
|
if p.value != nil {
|
|
return *p.value, nil
|
|
}
|
|
if p.bv == nil {
|
|
e.q.stats.Records.Get++
|
|
p.bv = e.rb.Get(p.bk)
|
|
if p.bv == nil {
|
|
return zero, fmt.Errorf("%w: no data for key", ErrStore)
|
|
}
|
|
}
|
|
var v T
|
|
err := e.q.st.parseFull(reflect.ValueOf(&v).Elem(), p.bk, p.bv)
|
|
if err != nil {
|
|
return zero, err
|
|
}
|
|
p.value = &v
|
|
return v, nil
|
|
}
|
|
|
|
// QueryDB returns a new Query for type T. When an operation on the query is
|
|
// executed, a read-only/writable transaction is created as appropriate for the
|
|
// operation.
|
|
func QueryDB[T any](ctx context.Context, db *DB) *Query[T] {
|
|
// We lock db for storeTypes. We keep it locked until Query is done.
|
|
db.typesMutex.RLock()
|
|
q := &Query[T]{xdb: db}
|
|
q.init(ctx, db)
|
|
return q
|
|
}
|
|
|
|
// QueryTx returns a new Query that operates on type T using transaction tx.
|
|
// The context of the transaction is used for the query.
|
|
func QueryTx[T any](tx *Tx) *Query[T] {
|
|
// note: Since we are in a transaction, we already hold an rlock on the
|
|
// db types.
|
|
q := &Query[T]{xtx: tx}
|
|
if tx.err != nil {
|
|
q.err = tx.err
|
|
return q
|
|
}
|
|
q.init(tx.ctx, tx.db)
|
|
return q
|
|
}
|
|
|
|
func (q *Query[T]) ctxErr() error {
|
|
select {
|
|
case <-q.ctxDone:
|
|
return q.ctx.Err()
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Stats returns the current statistics for this query. When a query finishes,
|
|
// its stats are added to those of its transaction. When a transaction
|
|
// finishes, its stats are added to those of its database.
|
|
func (q *Query[T]) Stats() Stats {
|
|
return q.stats
|
|
}
|
|
|
|
func (q *Query[T]) init(ctx context.Context, db *DB) {
|
|
var v T
|
|
t := reflect.TypeOf(v)
|
|
if t.Kind() != reflect.Struct {
|
|
q.errorf("%w: type must be struct, not pointer or other type", ErrType)
|
|
return
|
|
}
|
|
q.st, q.err = db.storeType(t)
|
|
if q.err == nil {
|
|
q.stats.LastType = q.st.Name
|
|
q.pkType = q.st.Current.Fields[0].structField.Type
|
|
}
|
|
q.ctx = ctx
|
|
q.ctxDone = ctx.Done()
|
|
if err := q.ctxErr(); q.err == nil && err != nil {
|
|
q.err = err
|
|
}
|
|
}
|
|
|
|
func (q *Query[T]) tx(write bool) (*Tx, error) {
|
|
if q.xtx == nil {
|
|
if q.xdb == nil {
|
|
q.errorf("%w: missing db and tx: use QueryDB or QueryTx to make a new Query", ErrParam)
|
|
return nil, q.err
|
|
}
|
|
tx, err := q.xdb.bdb.Begin(write)
|
|
if err != nil {
|
|
q.error(err)
|
|
return nil, q.err
|
|
}
|
|
q.xtx = &Tx{ctx: q.ctx, db: q.xdb, btx: tx}
|
|
if write {
|
|
q.stats.Writes++
|
|
} else {
|
|
q.stats.Reads++
|
|
}
|
|
}
|
|
return q.xtx, nil
|
|
}
|
|
|
|
// error sets an error for the query, to be returned when next operations are executed.
|
|
// All Query instances go through this function for proper rollback and/or runlock
|
|
// as needed. If a query finished successfully, ErrFinished is set.
|
|
func (q *Query[T]) error(err error) {
|
|
if q.xtx != nil && q.xdb != nil {
|
|
txerr := q.xtx.btx.Rollback()
|
|
if sanityChecks && txerr != nil {
|
|
panic(fmt.Sprintf("xtx rollback: %v", txerr))
|
|
}
|
|
q.dbAddStats()
|
|
q.xtx = nil
|
|
}
|
|
if q.xdb != nil {
|
|
q.xdb.typesMutex.RUnlock()
|
|
q.xdb = nil
|
|
}
|
|
if q.xtx != nil {
|
|
q.txAddStats()
|
|
}
|
|
// This is the only place besides init that sets an error on query.
|
|
q.err = err
|
|
}
|
|
|
|
// errorf calls error with a formatted error.
|
|
func (q *Query[T]) errorf(format string, args ...any) {
|
|
q.error(fmt.Errorf(format, args...))
|
|
}
|
|
|
|
// Close closes a Query. Must always be called for Queries on which Next or
|
|
// NextID was called. Other operations call Close themselves.
|
|
func (q *Query[T]) Close() error {
|
|
var err error
|
|
if q.xtx != nil && q.xdb != nil {
|
|
err = q.xtx.btx.Rollback()
|
|
q.dbAddStats()
|
|
q.xtx = nil
|
|
}
|
|
q.error(ErrFinished)
|
|
return err
|
|
}
|
|
|
|
// txAddStats adds stats to a transaction that Query did not create.
|
|
func (q *Query[T]) txAddStats() {
|
|
q.xtx.stats.add(q.stats)
|
|
q.stats = Stats{}
|
|
}
|
|
|
|
// dbAddStats adds stats to the database directly, because Query created the
|
|
// transaction and the tx is never exposed, so no need to go through it.
|
|
func (q *Query[T]) dbAddStats() {
|
|
q.xdb.statsMutex.Lock()
|
|
q.xdb.stats.add(q.stats)
|
|
q.xdb.statsMutex.Unlock()
|
|
q.stats = Stats{}
|
|
}
|
|
|
|
// Operations that will do database operations get a defer call to this finish
|
|
// function, to ensure we also close transactions that we made.
|
|
func (q *Query[T]) finish(rerr *error) {
|
|
if q.xtx != nil && q.xdb != nil {
|
|
if *rerr == nil && q.xtx.btx.Writable() {
|
|
if err := q.xtx.btx.Commit(); err != nil {
|
|
*rerr = err
|
|
}
|
|
} else if err := q.xtx.btx.Rollback(); err != nil && sanityChecks {
|
|
panic(fmt.Errorf("rolling back: %v", err))
|
|
}
|
|
q.dbAddStats()
|
|
q.xtx = nil
|
|
}
|
|
x := recover()
|
|
if x != nil {
|
|
q.errorf("%v", x)
|
|
panic(x)
|
|
}
|
|
q.error(ErrFinished)
|
|
}
|
|
|
|
// checkNotNext is called by all operations except Next and NextID to ensure
|
|
// that the user does not mix Next/NextID with regular operations.
|
|
func (q *Query[T]) checkNotNext() {
|
|
if q.err == nil && q.nextOnly {
|
|
q.errorf("%w: can only use further Next calls", ErrParam)
|
|
}
|
|
}
|
|
|
|
func (q *Query[T]) checkErr() bool {
|
|
if q.err == nil && q.xtx == nil && q.xdb == nil {
|
|
// Probably the result of using a Query zero value.
|
|
q.errorf("%w: invalid query, use QueryDB or QueryTx to make a query", ErrParam)
|
|
}
|
|
if q.err == nil {
|
|
if err := q.ctxErr(); err != nil {
|
|
q.err = err
|
|
}
|
|
}
|
|
return q.err == nil
|
|
}
|
|
|
|
func (q *Query[T]) addFilter(f filter[T]) {
|
|
q.xfilters = append(q.xfilters, f)
|
|
}
|
|
|
|
// nextKey returns the key and optionally value for the next matching record.
|
|
// If there is no more matching record, ErrAbsent is returned and the query
|
|
// finished. ErrAbsent should be set on the query by the calling operation if
|
|
// appropriate (but not for Update/Delete, because it would prevent further
|
|
// operations on the query and its transaction).
|
|
//
|
|
// The actual work is handled by executing a query plan. One is created on the
|
|
// first call, and the nextKey is forwarded to the plan execution thereafter.
|
|
//
|
|
// write indicates if a writable tx needs to be created (if any) for the
|
|
// operation that is initiating this data selection.
|
|
//
|
|
// value indicates if a full record should be parsed and returned, as opposed
|
|
// to only the PK. Some callers only care about the IDs of records, which can
|
|
// be handled more efficiently when going through an index.
|
|
func (q *Query[T]) nextKey(write, value bool) ([]byte, T, error) {
|
|
if q.exec == nil {
|
|
p, err := q.selectPlan()
|
|
if err != nil {
|
|
q.error(err)
|
|
var zero T
|
|
return nil, zero, err
|
|
}
|
|
// log.Printf("plan %#v", p)
|
|
q.exec = p.exec(q)
|
|
}
|
|
return q.exec.nextKey(write, value)
|
|
}
|
|
|
|
// fetch the PK of the next selected record, and parse into pkv.
|
|
func (q *Query[T]) nextID(write bool, pkv reflect.Value) error {
|
|
bk, _, err := q.nextKey(write, false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return parsePK(pkv, bk)
|
|
}
|
|
|
|
// foreachKey calls fn on each selected record. If value is set, fn's v is set,
|
|
// otherwise the zero value.
|
|
func (q *Query[T]) foreachKey(write, value bool, fn func(bk []byte, v T) error) error {
|
|
if q.err != nil {
|
|
return q.err
|
|
}
|
|
for {
|
|
bk, v, err := q.nextKey(write, value)
|
|
if err == ErrAbsent {
|
|
return nil
|
|
} else if err != nil {
|
|
return err
|
|
} else if err := fn(bk, v); err == StopForEach {
|
|
q.error(ErrFinished)
|
|
return nil
|
|
} else if err != nil {
|
|
q.error(err)
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
// foreachID calls fn with the primary key value for each selected record.
|
|
func (q *Query[T]) foreachID(write bool, fn func(pkv any) error) error {
|
|
if q.err != nil {
|
|
return q.err
|
|
}
|
|
v := reflect.New(q.pkType).Elem()
|
|
for {
|
|
err := q.nextID(write, v)
|
|
if err == ErrAbsent {
|
|
return nil
|
|
} else if err != nil {
|
|
return err
|
|
} else if err := fn(v.Interface()); err != nil {
|
|
q.error(err)
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
// lookup field name in the current typeVersion.
|
|
func (q *Query[T]) lookupField(name string) (field, bool) {
|
|
for _, ff := range q.st.Current.Fields {
|
|
if ff.Name == name {
|
|
return ff, true
|
|
}
|
|
}
|
|
q.errorf("%w: unknown field %q", ErrParam, name)
|
|
return field{}, false
|
|
}
|
|
|
|
// Kinds that can be converted without loss of precision, identity is not in here.
|
|
type convertKinds struct{ from, to kind }
|
|
|
|
var convertFieldKinds = map[convertKinds]struct{}{
|
|
{kindInt8, kindInt16}: {},
|
|
{kindInt8, kindInt32}: {},
|
|
{kindInt8, kindInt64}: {},
|
|
{kindInt8, kindInt}: {},
|
|
{kindInt16, kindInt32}: {},
|
|
{kindInt16, kindInt64}: {},
|
|
{kindInt16, kindInt}: {},
|
|
{kindInt32, kindInt}: {},
|
|
{kindInt32, kindInt64}: {},
|
|
{kindInt, kindInt32}: {},
|
|
{kindInt, kindInt64}: {},
|
|
|
|
{kindUint8, kindUint16}: {},
|
|
{kindUint8, kindUint32}: {},
|
|
{kindUint8, kindUint64}: {},
|
|
{kindUint8, kindUint}: {},
|
|
{kindUint16, kindUint32}: {},
|
|
{kindUint16, kindUint64}: {},
|
|
{kindUint16, kindUint}: {},
|
|
{kindUint32, kindUint}: {},
|
|
{kindUint32, kindUint64}: {},
|
|
{kindUint, kindUint32}: {},
|
|
{kindUint, kindUint64}: {},
|
|
|
|
{kindFloat32, kindFloat64}: {},
|
|
}
|
|
|
|
// Check type of value for field and return a reflect value that can directly be set on the field.
|
|
// If the field is a pointer, we allow non-pointers and convert them.
|
|
// We require value to be of a type that can be converted without loss of precision to the type of field.
|
|
func (q *Query[T]) prepareValue(fname string, ft fieldType, st reflect.Type, rv reflect.Value) (reflect.Value, bool) {
|
|
if !rv.IsValid() {
|
|
q.errorf("%w: invalid value", ErrParam)
|
|
return rv, false
|
|
}
|
|
// Quick check first.
|
|
t := rv.Type()
|
|
if t == st {
|
|
return rv, true
|
|
}
|
|
if !ft.Ptr && rv.Kind() == reflect.Ptr {
|
|
q.errorf("%w: cannot set ptr value to nonptr field", ErrParam)
|
|
return rv, false
|
|
}
|
|
|
|
k, err := typeKind(t)
|
|
if err != nil {
|
|
q.errorf("%w: type of field: %s", ErrParam, err)
|
|
return reflect.Value{}, false
|
|
}
|
|
if _, ok := convertFieldKinds[convertKinds{k, ft.Kind}]; !ok && k != ft.Kind {
|
|
q.errorf("%w: got %v for field %q, need %v", ErrParam, rv.Type(), fname, ft.Kind)
|
|
return reflect.Value{}, false
|
|
}
|
|
if k != ft.Kind {
|
|
dt := st
|
|
if ft.Ptr {
|
|
dt = dt.Elem()
|
|
}
|
|
rv = rv.Convert(dt)
|
|
}
|
|
if ft.Ptr && rv.Kind() != reflect.Ptr {
|
|
nv := reflect.New(st.Elem())
|
|
nv.Elem().Set(rv)
|
|
rv = nv
|
|
}
|
|
return rv, true
|
|
}
|
|
|
|
// checkPK checks if t is the type of the current typeVersion's PK, and returns
|
|
// a userfriendly error message otherwise.
|
|
func (q *Query[T]) checkPK(t reflect.Type) bool {
|
|
if t != q.pkType {
|
|
q.errorf("%w: id type was %s, must be %s", ErrParam, t, q.pkType)
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
// FilterID selects the records with primary key id, which must be of the type
|
|
// of T's primary key.
|
|
func (q *Query[T]) FilterID(id any) *Query[T] {
|
|
if !q.checkErr() {
|
|
return q
|
|
}
|
|
kv := reflect.ValueOf(id)
|
|
if !q.checkPK(kv.Type()) {
|
|
return q
|
|
}
|
|
pk, err := packPK(kv)
|
|
if err != nil {
|
|
q.error(err)
|
|
return q
|
|
}
|
|
|
|
if q.xfilterIDs != nil {
|
|
// Intersection of this ID with the previous IDs. Either it is this single ID or the list becomes empty.
|
|
if _, ok := q.xfilterIDs.ids[id]; !ok {
|
|
q.xfilterIDs = &filterIDs[T]{map[any]struct{}{}, [][]byte{}}
|
|
return q
|
|
}
|
|
}
|
|
q.xfilterIDs = &filterIDs[T]{map[any]struct{}{id: {}}, [][]byte{pk}}
|
|
return q
|
|
}
|
|
|
|
// FilterIDs selects the records with a primary key that is in ids. Ids must be
|
|
// a slice of T's primary key type.
|
|
func (q *Query[T]) FilterIDs(ids any) *Query[T] {
|
|
if !q.checkErr() {
|
|
return q
|
|
}
|
|
kv := reflect.ValueOf(ids)
|
|
if kv.Kind() != reflect.Slice {
|
|
q.errorf("%w: ids must be slice of %v, not %T", ErrParam, q.pkType, ids)
|
|
return q
|
|
}
|
|
if !q.checkPK(kv.Type().Elem()) {
|
|
return q
|
|
}
|
|
|
|
n := kv.Len()
|
|
pks := make([][]byte, 0, n)
|
|
var prevIDs map[any]struct{}
|
|
if q.xfilterIDs != nil {
|
|
prevIDs = q.xfilterIDs.ids // We use this to check intersection.
|
|
}
|
|
// todo: should we fail for a zero PK?
|
|
nids := map[any]struct{}{}
|
|
for i := 0; i < n; i++ {
|
|
rev := kv.Index(i)
|
|
ev := rev.Interface()
|
|
if _, ok := prevIDs[ev]; !ok && prevIDs != nil {
|
|
continue
|
|
}
|
|
nids[ev] = struct{}{}
|
|
pk, err := packPK(rev)
|
|
if err != nil {
|
|
q.error(err)
|
|
return q
|
|
}
|
|
pks = append(pks, pk)
|
|
}
|
|
q.xfilterIDs = &filterIDs[T]{nids, pks}
|
|
return q
|
|
}
|
|
|
|
// FilterFn calls fn for each record selected so far. If fn returns true, the
|
|
// record is kept for further filters and finally the operation.
|
|
func (q *Query[T]) FilterFn(fn func(value T) bool) *Query[T] {
|
|
if !q.checkErr() {
|
|
return q
|
|
}
|
|
if fn == nil {
|
|
q.errorf("%w: nil fn", ErrParam)
|
|
return q
|
|
}
|
|
q.addFilter(filterFn[T]{fn})
|
|
return q
|
|
}
|
|
|
|
// gatherNonzeroFields returns fields and values that are non-zero. Used for
|
|
// Update and FilterNonzero.
|
|
//
|
|
// allowID indicates if the primary key is allowed to be nonzero (not for
|
|
// Updates).
|
|
//
|
|
// At least one field must be nonzero.
|
|
func gatherNonzeroFields(tv *typeVersion, rv reflect.Value, allowID bool) ([]field, []reflect.Value, error) {
|
|
var fields []field
|
|
var values []reflect.Value
|
|
|
|
for i, f := range tv.Fields {
|
|
fv := rv.FieldByIndex(f.structField.Index)
|
|
if f.Type.isZero(fv) {
|
|
continue
|
|
}
|
|
if i == 0 && !allowID {
|
|
return nil, nil, fmt.Errorf("%w: primary key must be zero", ErrParam)
|
|
}
|
|
fields = append(fields, f)
|
|
values = append(values, fv)
|
|
}
|
|
if len(fields) == 0 {
|
|
return nil, nil, fmt.Errorf("%w: must have at least one nonzero field", ErrParam)
|
|
}
|
|
return fields, values, nil
|
|
}
|
|
|
|
// FilterNonzero gathers the nonzero fields from value, and selects records that
|
|
// have equal values for those fields. At least one value must be nonzero. If a
|
|
// value comes from an external source, e.g. user input, make sure it is not
|
|
// the zero value.
|
|
//
|
|
// Keep in mind that filtering on an embed/anonymous field looks at individual
|
|
// fields in the embedded field for non-zeroness, not at the embed field as a whole.
|
|
func (q *Query[T]) FilterNonzero(value T) *Query[T] {
|
|
if !q.checkErr() {
|
|
return q
|
|
}
|
|
fields, values, err := gatherNonzeroFields(q.st.Current, reflect.ValueOf(value), true)
|
|
if err != nil {
|
|
q.error(err)
|
|
return q
|
|
}
|
|
for i, f := range fields {
|
|
if f.Name == q.st.Current.Fields[0].Name {
|
|
q.FilterID(values[i].Interface())
|
|
} else {
|
|
q.addFilter(filterEqual[T]{f, values[i]})
|
|
}
|
|
}
|
|
return q
|
|
}
|
|
|
|
// FilterEqual selects records that have one of values for fieldName.
|
|
//
|
|
// Note: Value must be a compatible type for comparison with fieldName. Go
|
|
// constant numbers become ints, which are not compatible with uint or float
|
|
// types.
|
|
func (q *Query[T]) FilterEqual(fieldName string, values ...any) *Query[T] {
|
|
q.filterEqual(fieldName, values, false)
|
|
return q
|
|
}
|
|
|
|
// FilterNotEqual selects records that do not have any of values for fieldName.
|
|
func (q *Query[T]) FilterNotEqual(fieldName string, values ...any) *Query[T] {
|
|
q.filterEqual(fieldName, values, true)
|
|
return q
|
|
}
|
|
|
|
func (q *Query[T]) filterEqual(fieldName string, values []any, not bool) {
|
|
if !q.checkErr() {
|
|
return
|
|
}
|
|
ff, ok := q.lookupField(fieldName)
|
|
if !ok {
|
|
return
|
|
}
|
|
if len(values) == 0 {
|
|
q.errorf("%w: need at least one value for (not) equal", ErrParam)
|
|
return
|
|
}
|
|
if ff.Type.Ptr {
|
|
q.errorf("%w: cannot compare pointer values", ErrParam)
|
|
return
|
|
}
|
|
if len(values) == 1 {
|
|
rv, ok := q.prepareValue(ff.Name, ff.Type, ff.structField.Type, reflect.ValueOf(values[0]))
|
|
if !ok {
|
|
return
|
|
}
|
|
if not {
|
|
q.addFilter(filterNotEqual[T]{ff, rv})
|
|
} else {
|
|
q.addFilter(filterEqual[T]{ff, rv})
|
|
}
|
|
return
|
|
}
|
|
rvs := make([]reflect.Value, len(values))
|
|
for i, value := range values {
|
|
rv, ok := q.prepareValue(ff.Name, ff.Type, ff.structField.Type, reflect.ValueOf(value))
|
|
if !ok {
|
|
return
|
|
}
|
|
rvs[i] = rv
|
|
}
|
|
if not {
|
|
q.addFilter(filterNotIn[T]{ff, rvs})
|
|
} else {
|
|
q.addFilter(filterIn[T]{ff, rvs})
|
|
}
|
|
}
|
|
|
|
// FilterIn selects records that have one of values of the string slice fieldName.
|
|
//
|
|
// If fieldName has an index, it is used to select rows.
|
|
//
|
|
// Note: Value must be a compatible type for comparison with the elements of
|
|
// fieldName. Go constant numbers become ints, which are not compatible with uint
|
|
// or float types.
|
|
func (q *Query[T]) FilterIn(fieldName string, value any) *Query[T] {
|
|
if !q.checkErr() {
|
|
return q
|
|
}
|
|
ff, ok := q.lookupField(fieldName)
|
|
if !ok {
|
|
return q
|
|
}
|
|
if ff.Type.Ptr {
|
|
q.errorf("%w: cannot compare pointer values", ErrParam)
|
|
return q
|
|
}
|
|
if ff.Type.Kind != kindSlice {
|
|
q.errorf("%w: field for FilterIn must be a slice", ErrParam)
|
|
return q
|
|
}
|
|
et := ff.Type.ListElem
|
|
if et.Ptr {
|
|
q.errorf("%w: cannot compare element pointer values", ErrParam)
|
|
return q
|
|
}
|
|
rv, ok := q.prepareValue(ff.Name, *et, ff.structField.Type.Elem(), reflect.ValueOf(value))
|
|
if !ok {
|
|
return q
|
|
}
|
|
q.addFilter(filterInSlice[T]{ff, rv})
|
|
return q
|
|
}
|
|
|
|
// FilterGreater selects records that have fieldName > value.
|
|
//
|
|
// Note: Value must be a compatible type for comparison with fieldName. Go
|
|
// constant numbers become ints, which are not compatible with uint or float
|
|
// types.
|
|
func (q *Query[T]) FilterGreater(fieldName string, value any) *Query[T] {
|
|
return q.filterCompare(fieldName, opGreater, reflect.ValueOf(value))
|
|
}
|
|
|
|
// FilterGreaterEqual selects records that have fieldName >= value.
|
|
func (q *Query[T]) FilterGreaterEqual(fieldName string, value any) *Query[T] {
|
|
return q.filterCompare(fieldName, opGreaterEqual, reflect.ValueOf(value))
|
|
}
|
|
|
|
// FilterLess selects records that have fieldName < value.
|
|
func (q *Query[T]) FilterLess(fieldName string, value any) *Query[T] {
|
|
return q.filterCompare(fieldName, opLess, reflect.ValueOf(value))
|
|
}
|
|
|
|
// FilterLessEqual selects records that have fieldName <= value.
|
|
func (q *Query[T]) FilterLessEqual(fieldName string, value any) *Query[T] {
|
|
return q.filterCompare(fieldName, opLessEqual, reflect.ValueOf(value))
|
|
}
|
|
|
|
func (q *Query[T]) filterCompare(fieldName string, op compareOp, value reflect.Value) *Query[T] {
|
|
if !q.checkErr() {
|
|
return q
|
|
}
|
|
ff, ok := q.lookupField(fieldName)
|
|
if !ok {
|
|
return q
|
|
}
|
|
if !comparable(ff.Type) {
|
|
q.errorf("%w: cannot compare %s", ErrParam, ff.Type.Kind)
|
|
return q
|
|
}
|
|
rv, ok := q.prepareValue(ff.Name, ff.Type, ff.structField.Type, value)
|
|
if !ok {
|
|
return q
|
|
}
|
|
q.addFilter(filterCompare[T]{ff, op, rv})
|
|
return q
|
|
}
|
|
|
|
// Limit stops selecting records after the first n records.
|
|
// Can only be called once. n must be > 1.
|
|
func (q *Query[T]) Limit(n int) *Query[T] {
|
|
if !q.checkErr() {
|
|
return q
|
|
}
|
|
if n <= 0 {
|
|
q.errorf("%w: limit must be >= 1", ErrParam)
|
|
return q
|
|
}
|
|
if q.xlimit > 0 {
|
|
q.errorf("%w: already have a limit", ErrParam)
|
|
return q
|
|
}
|
|
q.xlimit = n
|
|
return q
|
|
}
|
|
|
|
// SortAsc sorts the selected records by fieldNames in ascending order.
|
|
// Additional orderings can be added by more calls to SortAsc or SortDesc.
|
|
func (q *Query[T]) SortAsc(fieldNames ...string) *Query[T] {
|
|
return q.order(fieldNames, true)
|
|
}
|
|
|
|
// SortDesc sorts the selected records by fieldNames in descending order.
|
|
// Additional orderings can be added by more calls to SortAsc or SortDesc.
|
|
func (q *Query[T]) SortDesc(fieldNames ...string) *Query[T] {
|
|
return q.order(fieldNames, false)
|
|
}
|
|
|
|
func (q *Query[T]) order(fieldNames []string, asc bool) *Query[T] {
|
|
if !q.checkErr() {
|
|
return q
|
|
}
|
|
if len(fieldNames) == 0 {
|
|
q.errorf("%w: sort fieldNames must be non-empty", ErrParam)
|
|
return q
|
|
}
|
|
for _, name := range fieldNames {
|
|
ff, ok := q.lookupField(name)
|
|
if !ok {
|
|
return q
|
|
}
|
|
if !comparable(ff.Type) {
|
|
q.errorf("%w: cannot sort by unorderable %q", ErrParam, name)
|
|
return q
|
|
}
|
|
q.xorders = append(q.xorders, order{ff, asc})
|
|
}
|
|
return q
|
|
}
|
|
|
|
// Gather causes an Update or Delete operation to return the values of the
|
|
// affect records into l. For Update, the updated records are returned.
|
|
func (q *Query[T]) Gather(l *[]T) *Query[T] {
|
|
if !q.checkErr() {
|
|
return q
|
|
}
|
|
if l == nil {
|
|
q.errorf("%w: l must be non-nil", ErrParam)
|
|
return q
|
|
}
|
|
if q.gathers != nil {
|
|
q.errorf("%w: can only have one Gather", ErrParam)
|
|
return q
|
|
}
|
|
q.gathers = l
|
|
return q
|
|
}
|
|
|
|
// GatherIDs causes an Update or Delete operation to return the primary keys of
|
|
// affected records into ids, which must be a pointer to a slice of T's
|
|
// primary key.
|
|
func (q *Query[T]) GatherIDs(ids any) *Query[T] {
|
|
if !q.checkErr() {
|
|
return q
|
|
}
|
|
if ids == nil {
|
|
q.errorf("%w: ids must be non-nil", ErrParam)
|
|
return q
|
|
}
|
|
rv := reflect.ValueOf(ids)
|
|
t := rv.Type()
|
|
if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Slice || t.Elem().Elem() != q.pkType {
|
|
q.errorf("%w: ids must be pointer to slice of %v, not %T", ErrParam, q.pkType, ids)
|
|
return q
|
|
}
|
|
if q.gatherIDs.IsValid() {
|
|
q.errorf("%w: can only have one GatherIDs", ErrParam)
|
|
return q
|
|
}
|
|
q.gatherIDs = rv
|
|
return q
|
|
}
|
|
|
|
func (q *Query[T]) gather(v T, rv reflect.Value) {
|
|
if q.gathers != nil {
|
|
*q.gathers = append(*q.gathers, v)
|
|
}
|
|
if q.gatherIDs.IsValid() {
|
|
ridv := rv.FieldByIndex(q.st.Current.Fields[0].structField.Index)
|
|
l := q.gatherIDs.Elem()
|
|
nl := reflect.Append(l, ridv)
|
|
l.Set(nl)
|
|
}
|
|
}
|
|
|
|
// Err returns if an error is set on the query. Can happen for invalid filters or
|
|
// canceled contexts.
|
|
// Finished queries return ErrFinished.
|
|
func (q *Query[T]) Err() error {
|
|
q.checkErr()
|
|
return q.err
|
|
}
|
|
|
|
// Delete removes the selected records, returning how many were deleted.
|
|
//
|
|
// See Gather and GatherIDs for collecting the deleted records or IDs.
|
|
func (q *Query[T]) Delete() (deleted int, rerr error) {
|
|
defer q.finish(&rerr)
|
|
q.checkNotNext()
|
|
if !q.checkErr() {
|
|
return 0, q.err
|
|
}
|
|
|
|
n := 0
|
|
err := q.foreachKey(true, true, func(bk []byte, ov T) error {
|
|
n++
|
|
rov := reflect.ValueOf(ov)
|
|
q.gather(ov, rov)
|
|
q.stats.Delete++
|
|
return q.xtx.delete(q.exec.rb, q.st, bk, rov)
|
|
})
|
|
return n, err
|
|
}
|
|
|
|
// Get returns the single selected record.
|
|
//
|
|
// ErrMultiple is returned if multiple records were selected.
|
|
// ErrAbsent is returned if no record was selected.
|
|
func (q *Query[T]) Get() (value T, rerr error) {
|
|
defer q.finish(&rerr)
|
|
q.checkNotNext()
|
|
if !q.checkErr() {
|
|
var zero T
|
|
return zero, q.err
|
|
}
|
|
|
|
if _, v, err := q.nextKey(false, true); err != nil {
|
|
return v, err
|
|
} else if _, _, err := q.nextKey(false, false); err == nil {
|
|
return v, ErrMultiple
|
|
} else {
|
|
return v, nil
|
|
}
|
|
}
|
|
|
|
// Count returns the number of selected records.
|
|
func (q *Query[T]) Count() (n int, rerr error) {
|
|
defer q.finish(&rerr)
|
|
q.checkNotNext()
|
|
if !q.checkErr() {
|
|
return 0, q.err
|
|
}
|
|
|
|
err := q.foreachKey(false, false, func(kb []byte, unused T) error {
|
|
n++
|
|
return nil
|
|
})
|
|
return n, err
|
|
}
|
|
|
|
// List returns all selected records.
|
|
// On success with zero selected records, List returns the empty list.
|
|
func (q *Query[T]) List() (list []T, rerr error) {
|
|
defer q.finish(&rerr)
|
|
q.checkNotNext()
|
|
if !q.checkErr() {
|
|
return nil, q.err
|
|
}
|
|
|
|
l := []T{}
|
|
err := q.foreachKey(false, true, func(unused []byte, v T) error {
|
|
l = append(l, v)
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return l, nil
|
|
}
|
|
|
|
// UpdateNonzero updates all selected records with the non-zero fields from
|
|
// value, returning the number of records updated.
|
|
//
|
|
// Recall that false, 0, "" are all zero values. Use UpdateField or
|
|
// UpdateFields to set fields to zero their value. This is especially relevant
|
|
// if the field value comes from an external source, e.g. user input.
|
|
//
|
|
// See Gather and GatherIDs for collecting the updated records or IDs.
|
|
//
|
|
// Keep in mind that updating on an embed/anonymous field looks at individual
|
|
// fields in the embedded field for non-zeroness, not at the embed field as a whole.
|
|
func (q *Query[T]) UpdateNonzero(value T) (updated int, rerr error) {
|
|
defer q.finish(&rerr)
|
|
q.checkNotNext()
|
|
if !q.checkErr() {
|
|
return 0, q.err
|
|
}
|
|
|
|
fields, values, err := gatherNonzeroFields(q.st.Current, reflect.ValueOf(value), false)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
sfl := make([]reflect.StructField, len(fields))
|
|
for i, f := range fields {
|
|
sfl[i] = f.structField
|
|
}
|
|
return q.update(sfl, values)
|
|
}
|
|
|
|
// UpdateField calls UpdateFields for fieldName and value.
|
|
func (q *Query[T]) UpdateField(fieldName string, value any) (updated int, rerr error) {
|
|
return q.UpdateFields(map[string]any{fieldName: value})
|
|
}
|
|
|
|
// UpdateFields updates all selected records, setting fields named by the map
|
|
// keys of fieldValues to the corresponding map value and returning the number
|
|
// of records updated.
|
|
//
|
|
// See Gather and GatherIDs for collecting the updated records or IDs.
|
|
//
|
|
// Entire embed fields can be updated, as well as their individual embedded
|
|
// fields.
|
|
func (q *Query[T]) UpdateFields(fieldValues map[string]any) (updated int, rerr error) {
|
|
defer q.finish(&rerr)
|
|
q.checkNotNext()
|
|
if !q.checkErr() {
|
|
return 0, q.err
|
|
}
|
|
|
|
if len(fieldValues) == 0 {
|
|
return 0, fmt.Errorf("%w: must update at least one field", ErrParam)
|
|
}
|
|
|
|
fields := make([]reflect.StructField, 0, len(fieldValues))
|
|
values := make([]reflect.Value, 0, len(fieldValues))
|
|
next:
|
|
for name, value := range fieldValues {
|
|
for i, f := range q.st.Current.Fields {
|
|
if f.Name != name {
|
|
continue
|
|
}
|
|
if i == 0 {
|
|
return 0, fmt.Errorf("%w: cannot update primary key", ErrParam)
|
|
}
|
|
rv, ok := q.prepareValue(f.Name, f.Type, f.structField.Type, reflect.ValueOf(value))
|
|
if !ok {
|
|
return 0, q.err
|
|
}
|
|
fields = append(fields, f.structField)
|
|
values = append(values, rv)
|
|
continue next
|
|
}
|
|
for _, ef := range q.st.Current.embedFields {
|
|
if ef.Name != name {
|
|
continue
|
|
}
|
|
rv, ok := q.prepareValue(ef.Name, ef.Type, ef.structField.Type, reflect.ValueOf(value))
|
|
if !ok {
|
|
return 0, q.err
|
|
}
|
|
fields = append(fields, ef.structField)
|
|
values = append(values, rv)
|
|
continue next
|
|
}
|
|
return 0, fmt.Errorf("%w: unknown field %q", ErrParam, name)
|
|
}
|
|
return q.update(fields, values)
|
|
}
|
|
|
|
// update sets fields[i] to values[i] on every selected record and writes each
// modified record back through q.xtx.update. It returns the number of records
// visited, together with the first error encountered.
//
// n is incremented before a record is modified and written, so on error it
// includes the record whose write failed. NOTE(review): values are assumed to
// be prepared by the caller with types assignable to the fields; reflect's Set
// panics otherwise.
func (q *Query[T]) update(fields []reflect.StructField, values []reflect.Value) (int, error) {
	n := 0
	// Scratch value reused for every record. It receives a copy of the record
	// before modification and is passed to q.xtx.update next to the new value
	// (presumably as the old value, e.g. for index maintenance).
	ov := reflect.New(q.st.Type).Elem()
	// NOTE(review): the first argument presumably requests write mode (the
	// read-only List/ForEach pass false); confirm against foreachKey.
	err := q.foreachKey(true, true, func(bk []byte, v T) error {
		n++
		// rv addresses the callback's copy v, so the field sets below modify v.
		rv := reflect.ValueOf(&v).Elem()
		// Snapshot the unmodified record before any field is changed.
		ov.Set(rv)
		for i, sf := range fields {
			frv := rv.FieldByIndex(sf.Index)
			frv.Set(values[i])
		}
		// Called with the already-updated record; presumably feeds Gather/GatherIDs.
		q.gather(v, rv)
		// Counted before the write is attempted.
		q.stats.Update++
		return q.xtx.update(q.exec.rb, q.st, rv, ov, bk)
	})
	return n, err
}
|
|
|
|
// IDs sets idsptr to the primary keys of selected records. Idptrs must be a
|
|
// slice of T's primary key type.
|
|
func (q *Query[T]) IDs(idsptr any) (rerr error) {
|
|
defer q.finish(&rerr)
|
|
q.checkNotNext()
|
|
if !q.checkErr() {
|
|
return q.err
|
|
}
|
|
|
|
if idsptr == nil {
|
|
return fmt.Errorf("%w: idsptr must not be nil", ErrParam)
|
|
}
|
|
rv := reflect.ValueOf(idsptr)
|
|
if rv.Type().Kind() != reflect.Ptr || rv.Type().Elem().Kind() != reflect.Slice || rv.Type().Elem().Elem() != q.pkType {
|
|
return fmt.Errorf("%w: idsptr must be a ptr to slice of %v, not %T", ErrParam, q.pkType, idsptr)
|
|
}
|
|
|
|
s := reflect.MakeSlice(rv.Type().Elem(), 0, 0)
|
|
err := q.foreachID(false, func(pkv any) error {
|
|
s = reflect.Append(s, reflect.ValueOf(pkv))
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rv.Elem().Set(s)
|
|
return nil
|
|
}
|
|
|
|
// todo: should we have an iteration object that we can call Next and NextID on?
|
|
|
|
// Next fetches the next record, moving the cursor forward.
|
|
//
|
|
// ErrAbsent is returned if no more records match.
|
|
//
|
|
// Automatically created transactions are read-only.
|
|
//
|
|
// Close must be called on a Query on which Next or NextID was called and that
|
|
// is not yet finished, i.e. has not yet returned an error (including
|
|
// ErrAbsent).
|
|
func (q *Query[T]) Next() (value T, rerr error) {
|
|
// note: no q.finish preamble because caller iterates over result themselves.
|
|
if !q.checkErr() {
|
|
var zero T
|
|
return zero, q.err
|
|
}
|
|
|
|
q.nextOnly = true
|
|
_, v, err := q.nextKey(false, true)
|
|
if err == ErrAbsent {
|
|
q.error(err)
|
|
}
|
|
return v, err
|
|
}
|
|
|
|
// NextID is like Next, but only fetches the primary key of the next matching
|
|
// record, storing it in idptr.
|
|
func (q *Query[T]) NextID(idptr any) (rerr error) {
|
|
// note: no q.finish preamble because caller iterates over result themselves.
|
|
if !q.checkErr() {
|
|
return q.err
|
|
}
|
|
|
|
q.nextOnly = true
|
|
rpkv := reflect.ValueOf(idptr)
|
|
if idptr == nil {
|
|
q.errorf("%w: idptr must be non-nil", ErrParam)
|
|
return q.err
|
|
}
|
|
t := rpkv.Type()
|
|
if t.Kind() != reflect.Ptr || t.Elem() != q.pkType {
|
|
return fmt.Errorf("%w: value must be ptr to %v, not %v", ErrParam, q.pkType, t)
|
|
}
|
|
err := q.nextID(false, rpkv.Elem())
|
|
if err == ErrAbsent {
|
|
q.error(err)
|
|
}
|
|
return err
|
|
}
|
|
|
|
// Exists returns whether any record was selected.
|
|
func (q *Query[T]) Exists() (exists bool, rerr error) {
|
|
defer q.finish(&rerr)
|
|
q.checkNotNext()
|
|
if !q.checkErr() {
|
|
return false, q.err
|
|
}
|
|
|
|
q.xlimit = 1
|
|
_, _, err := q.nextKey(false, false)
|
|
if err == ErrAbsent {
|
|
return false, nil
|
|
}
|
|
return err == nil, err
|
|
}
|
|
|
|
// StopForEach is an error value that, if returned by the function passed to
|
|
// Query.ForEach, stops further iterations.
|
|
var StopForEach error = errors.New("stop foreach")
|
|
|
|
// ForEach calls fn on each selected record.
|
|
// If fn returns StopForEach, ForEach stops iterating, so no longer calls fn,
|
|
// and returns nil.
|
|
func (q *Query[T]) ForEach(fn func(value T) error) (rerr error) {
|
|
defer q.finish(&rerr)
|
|
q.checkNotNext()
|
|
if !q.checkErr() {
|
|
return q.err
|
|
}
|
|
|
|
return q.foreachKey(false, true, func(bk []byte, v T) error {
|
|
return fn(v)
|
|
})
|
|
}
|