mox/vendor/github.com/mjl-/bstore/query.go

1246 lines
33 KiB
Go

package bstore
import (
"context"
"errors"
"fmt"
"reflect"
)
// The convention for handling errors on a Query: methods that return a bool
// will have set q.err using q.error(), which does cleanup. If a method returns
// an error, q.error was not yet called, but usually will be called immediately
// after with the returned err, possibly first adding context.

// Query selects data for Go struct T based on filters, sorting, limits. The
// query is completed by calling an operation, such as Count, Get, List,
// Update, Delete, etc.
//
// Record selection functions like FilterEqual and Limit return the (modified)
// query itself, allowing chaining of calls.
//
// Queries are automatically closed after their operation, with two exceptions:
// After using Next and NextID on a query that did not yet return a non-nil
// error, you must call Close.
//
// A Query is not safe for concurrent use.
type Query[T any] struct {
	ctx     context.Context
	ctxDone <-chan struct{} // ctx.Done(), kept here for fast access.
	st      storeType       // Of T.
	pkType  reflect.Type    // Shortcut for the struct field type of st.Current.Fields[0].
	xtx     *Tx             // If nil, a new transaction is automatically created from xdb by tx(), which also returns an existing one.
	xdb     *DB             // If not nil, xtx is created by the query to execute the operation and is committed or rolled back when the operation finishes (also on error).
	err     error           // If set, returned by operations. For indicating failed filters, or that an operation has finished.

	xfilterIDs *filterIDs[T] // Kept separately from filters because these filters make us use the PK without further index planning.
	xfilters   []filter[T]
	xorders    []order

	// If 0, no limit has been set. Otherwise the number of IDs or records to
	// read. Set by limit calls, and set to 1 for an Exists.
	xlimit int

	// Set when Next/NextID is called. We prevent non-Next calls after that moment.
	// Should prevent unexpected results for users.
	nextOnly bool

	gatherIDs reflect.Value // Pointer to slice to pktype, for PKs of updated/deleted records.
	gathers   *[]T          // For full records of updated/deleted records.

	exec *exec[T]

	stats Stats
}
// filter is a sentinel interface implemented by all filter types, for sanity
// checking that only filters are added to a Query.
type filter[T any] interface {
	filter()
}
// filterIDs filters by one or more IDs (primary keys).
type filterIDs[T any] struct {
	ids map[any]struct{} // Regular values.
	pks [][]byte         // Packed values.
}

// filter marks filterIDs as a filter.
func (filterIDs[T]) filter() {}
// filterFn filters records with a caller-supplied function.
type filterFn[T any] struct {
	fn func(value T) bool // Function to call for each record, returning true if the record is selected.
}

// filter marks filterFn as a filter.
func (filterFn[T]) filter() {}
// filterEqual filters on equality of a field with a single value.
type filterEqual[T any] struct {
	field  field
	rvalue reflect.Value
}

// filter marks filterEqual as a filter.
func (filterEqual[T]) filter() {}
// filterNotEqual filters on non-equality of a field with a single value.
type filterNotEqual[T any] struct {
	field  field
	rvalue reflect.Value
}

// filter marks filterNotEqual as a filter.
func (filterNotEqual[T]) filter() {}
// filterIn is like filterEqual, but for one or more values.
type filterIn[T any] struct {
	field   field
	rvalues []reflect.Value
}

// filter marks filterIn as a filter.
func (filterIn[T]) filter() {}
// filterNotIn is like filterNotEqual, but for one or more values.
type filterNotIn[T any] struct {
	field   field
	rvalues []reflect.Value
}

// filter marks filterNotIn as a filter.
func (filterNotIn[T]) filter() {}
// filterInSlice is for matching a value against the elements of a field that
// is a slice of the value's type.
type filterInSlice[T any] struct {
	field  field // Of field type, a slice.
	rvalue reflect.Value
}

// filter marks filterInSlice as a filter.
func (filterInSlice[T]) filter() {}
// compareOp is the comparison operator used by filterCompare.
type compareOp byte

// Comparison operators, set by FilterGreater, FilterGreaterEqual, FilterLess
// and FilterLessEqual respectively.
const (
	opGreater compareOp = iota
	opGreaterEqual
	opLess
	opLessEqual
)
// filterCompare filters by comparing a field against a value with op.
type filterCompare[T any] struct {
	field field
	op    compareOp
	value reflect.Value
}

// filter marks filterCompare as a filter.
func (filterCompare[T]) filter() {}
// order is a single ordering of the result: by field, ascending if asc is set,
// descending otherwise.
type order struct {
	field field
	asc   bool
}
// pair represents a primary key with lazily loaded record data. When the user
// only cares about IDs we don't have to parse the full record. And if we go
// through an index we don't have to fetch the full record either.
type pair[T any] struct {
	bk    []byte
	bv    []byte // If nil, data must be loaded.
	value *T     // If not nil, the parsed form of bv.
}
// Value returns the fully parsed record of the pair. A previously parsed
// record is returned as is. Otherwise the record data is fetched through e if
// not yet present, then parsed and cached on the pair.
func (p *pair[T]) Value(e *exec[T]) (T, error) {
	if p.value != nil {
		return *p.value, nil
	}

	var zero T
	if p.bv == nil {
		e.q.stats.Records.Get++
		if p.bv = e.rb.Get(p.bk); p.bv == nil {
			return zero, fmt.Errorf("%w: no data for key", ErrStore)
		}
	}

	parsed := new(T)
	if err := e.q.st.parseFull(reflect.ValueOf(parsed).Elem(), p.bk, p.bv); err != nil {
		return zero, err
	}
	p.value = parsed
	return *parsed, nil
}
// QueryDB returns a new Query for type T. A transaction is only created once
// an operation on the query is executed, read-only or writable as appropriate
// for that operation.
func QueryDB[T any](ctx context.Context, db *DB) *Query[T] {
	// The read lock on the db types is taken here and held until the Query is
	// done.
	db.typesMutex.RLock()

	q := new(Query[T])
	q.xdb = db
	q.init(ctx, db)
	return q
}
// QueryTx returns a new Query that operates on type T using transaction tx.
// The context of the transaction is used for the query. If the transaction
// already has an error, the query is returned with that error set.
func QueryTx[T any](tx *Tx) *Query[T] {
	// note: Being inside a transaction, an rlock on the db types is already held.
	q := &Query[T]{xtx: tx}
	if err := tx.err; err != nil {
		q.err = err
	} else {
		q.init(tx.ctx, tx.db)
	}
	return q
}
// ctxErr returns the error of the query context if the context is done, and
// nil otherwise. It does not block.
func (q *Query[T]) ctxErr() error {
	select {
	case <-q.ctxDone:
	default:
		return nil
	}
	return q.ctx.Err()
}
// Stats returns a copy of the statistics currently held by this query. When a
// query finishes, its stats are added to those of its transaction. When a
// transaction finishes, its stats are added to those of its database.
func (q *Query[T]) Stats() Stats {
	current := q.stats
	return current
}
// init prepares the query for type T: it verifies T is a struct, looks up the
// store type for T in db, and records ctx on the query.
//
// Failures are stored on the query and returned by later operations. If T is
// not a struct type, ErrType is set through q.errorf and the context is not
// recorded.
func (q *Query[T]) init(ctx context.Context, db *DB) {
	var v T
	t := reflect.TypeOf(v)
	// reflect.TypeOf returns nil for the zero value of an interface type (e.g.
	// Query[any]). Calling Kind on a nil Type panics, so check for nil first.
	if t == nil || t.Kind() != reflect.Struct {
		q.errorf("%w: type must be struct, not pointer or other type", ErrType)
		return
	}
	q.st, q.err = db.storeType(t)
	if q.err == nil {
		q.stats.LastType = q.st.Name
		q.pkType = q.st.Current.Fields[0].structField.Type
	}
	q.ctx = ctx
	q.ctxDone = ctx.Done()
	// An earlier error (from storeType) takes precedence over a context error.
	if err := q.ctxErr(); q.err == nil && err != nil {
		q.err = err
	}
}
// tx returns the transaction for the query. If the query has none yet, a new
// transaction is started on the query's db, writable if write is set, and the
// Reads/Writes stats counter is incremented.
//
// On failure, the error has already been set on the query with q.error.
func (q *Query[T]) tx(write bool) (*Tx, error) {
	if q.xtx != nil {
		return q.xtx, nil
	}
	if q.xdb == nil {
		q.errorf("%w: missing db and tx: use QueryDB or QueryTx to make a new Query", ErrParam)
		return nil, q.err
	}

	btx, err := q.xdb.bdb.Begin(write)
	if err != nil {
		q.error(err)
		return nil, q.err
	}
	q.xtx = &Tx{ctx: q.ctx, db: q.xdb, btx: btx}

	switch {
	case write:
		q.stats.Writes++
	default:
		q.stats.Reads++
	}
	return q.xtx, nil
}
// error sets an error for the query, to be returned when next operations are executed.
// All Query instances go through this function for proper rollback and/or runlock
// as needed. If a query finished successfully, ErrFinished is set.
func (q *Query[T]) error(err error) {
	// Both xtx and xdb set: the transaction was created by this query, so roll
	// it back and add the stats directly to the database.
	if q.xtx != nil && q.xdb != nil {
		txerr := q.xtx.btx.Rollback()
		if sanityChecks && txerr != nil {
			panic(fmt.Sprintf("xtx rollback: %v", txerr))
		}
		q.dbAddStats()
		q.xtx = nil
	}
	// Release the read lock on the db types that QueryDB took. Clearing xdb
	// prevents unlocking twice on a later call.
	if q.xdb != nil {
		q.xdb.typesMutex.RUnlock()
		q.xdb = nil
	}
	// A transaction still set at this point was not created by the query. It is
	// left open, only the stats are handed over.
	if q.xtx != nil {
		q.txAddStats()
	}
	// This is the only place besides init that sets an error on query.
	q.err = err
}
// errorf formats an error with fmt.Errorf, so %w can be used for wrapping, and
// sets it on the query through error.
func (q *Query[T]) errorf(format string, args ...any) {
	err := fmt.Errorf(format, args...)
	q.error(err)
}
// Close closes a Query. Must always be called for Queries on which Next or
// NextID was called. Other operations call Close themselves.
//
// If the query created its own transaction, that transaction is rolled back
// and the error of the rollback is returned. The query is marked as finished
// with ErrFinished.
func (q *Query[T]) Close() error {
	var rollbackErr error
	if ownTx := q.xtx != nil && q.xdb != nil; ownTx {
		rollbackErr = q.xtx.btx.Rollback()
		q.dbAddStats()
		q.xtx = nil
	}
	q.error(ErrFinished)
	return rollbackErr
}
// txAddStats hands the stats of the query over to its transaction, for a
// transaction that Query did not create, and resets the stats of the query.
func (q *Query[T]) txAddStats() {
	tx := q.xtx
	tx.stats.add(q.stats)
	q.stats = Stats{}
}
// dbAddStats adds the stats of the query to the database directly, under the
// stats mutex of the database, and resets the stats of the query. Used when
// Query created the transaction: the tx is never exposed, so there is no need
// to go through it.
func (q *Query[T]) dbAddStats() {
	db := q.xdb
	db.statsMutex.Lock()
	db.stats.add(q.stats)
	db.statsMutex.Unlock()
	q.stats = Stats{}
}
// Operations that will do database operations get a defer call to this finish
// function, to ensure we also close transactions that we made.
//
// A transaction created by the query is committed only if it is writable, the
// operation returned no error (*rerr is nil) and the operation did not panic.
// In all other cases it is rolled back. A commit error is stored in *rerr.
//
// A panic from the operation is re-raised after cleanup. Otherwise the query
// is marked as finished with ErrFinished.
func (q *Query[T]) finish(rerr *error) {
	// recover only works when called directly by the deferred function. It is
	// called before deciding on commit/rollback, so a panic halfway through an
	// operation never results in a commit of partially applied changes.
	x := recover()

	if q.xtx != nil && q.xdb != nil {
		if x == nil && *rerr == nil && q.xtx.btx.Writable() {
			if err := q.xtx.btx.Commit(); err != nil {
				*rerr = err
			}
		} else if err := q.xtx.btx.Rollback(); err != nil && sanityChecks {
			panic(fmt.Errorf("rolling back: %v", err))
		}
		q.dbAddStats()
		q.xtx = nil
	}
	if x != nil {
		// Set an error on the query for cleanup, then continue panicking.
		q.errorf("%v", x)
		panic(x)
	}
	q.error(ErrFinished)
}
// checkNotNext is called by all operations except Next and NextID to ensure
// that the user does not mix Next/NextID with regular operations. If Next or
// NextID was used and the query has no error yet, an ErrParam error is set.
func (q *Query[T]) checkNotNext() {
	if q.err != nil || !q.nextOnly {
		return
	}
	q.errorf("%w: can only use further Next calls", ErrParam)
}
// checkErr returns whether the query is still usable, i.e. has no error set.
//
// A query without db and tx, probably a Query zero value, gets an ErrParam
// error. A query whose context is done gets the error of the context.
func (q *Query[T]) checkErr() bool {
	if q.err == nil && q.xtx == nil && q.xdb == nil {
		// Probably the result of using a Query zero value.
		q.errorf("%w: invalid query, use QueryDB or QueryTx to make a query", ErrParam)
	}
	if q.err == nil {
		if err := q.ctxErr(); err != nil {
			// Set the error through q.error, so a transaction and lock held by
			// the query are released. Next and NextID return this error without
			// any further cleanup, and callers need not call Close after an error.
			q.error(err)
		}
	}
	return q.err == nil
}
// addFilter appends f to the filters of the query.
func (q *Query[T]) addFilter(f filter[T]) {
	filters := append(q.xfilters, f)
	q.xfilters = filters
}
// nextKey returns the key and optionally value for the next matching record.
// If there is no more matching record, ErrAbsent is returned and the query
// finished. ErrAbsent should be set on the query by the calling operation if
// appropriate (but not for Update/Delete, because it would prevent further
// operations on the query and its transaction).
//
// The actual work is handled by executing a query plan. The plan is selected
// on the first call, and this and later calls are forwarded to the execution
// of that plan. If selecting a plan fails, the error is set on the query.
//
// write indicates if a writable tx needs to be created (if any) for the
// operation that is initiating this data selection.
//
// value indicates if a full record should be parsed and returned, as opposed
// to only the PK. Some callers only care about the IDs of records, which can
// be handled more efficiently when going through an index.
func (q *Query[T]) nextKey(write, value bool) ([]byte, T, error) {
	if q.exec != nil {
		return q.exec.nextKey(write, value)
	}

	plan, err := q.selectPlan()
	if err != nil {
		var zero T
		q.error(err)
		return nil, zero, err
	}
	q.exec = plan.exec(q)
	return q.exec.nextKey(write, value)
}
// nextID fetches the PK of the next selected record and parses it into pkv.
// Errors from fetching the key, including ErrAbsent, are returned unchanged.
func (q *Query[T]) nextID(write bool, pkv reflect.Value) error {
	bk, _, err := q.nextKey(write, false)
	if err == nil {
		err = parsePK(pkv, bk)
	}
	return err
}
// foreachKey calls fn on each selected record. If value is set, fn's v is set,
// otherwise the zero value.
//
// Iteration ends without error when no more records match, or when fn returns
// StopForEach, in which case the query is marked finished. Any other error
// from fn is set on the query and returned.
func (q *Query[T]) foreachKey(write, value bool, fn func(bk []byte, v T) error) error {
	if q.err != nil {
		return q.err
	}
	for {
		bk, v, err := q.nextKey(write, value)
		if err == ErrAbsent {
			return nil
		}
		if err != nil {
			return err
		}

		switch ferr := fn(bk, v); {
		case ferr == StopForEach:
			q.error(ErrFinished)
			return nil
		case ferr != nil:
			q.error(ferr)
			return ferr
		}
	}
}
// foreachID calls fn with the primary key value for each selected record.
// Iteration ends without error when no more records match. An error from fn
// is set on the query and returned.
func (q *Query[T]) foreachID(write bool, fn func(pkv any) error) error {
	if q.err != nil {
		return q.err
	}
	// A single value is reused for parsing the PK of each record.
	pkv := reflect.New(q.pkType).Elem()
	for {
		err := q.nextID(write, pkv)
		if err == ErrAbsent {
			return nil
		}
		if err != nil {
			return err
		}
		if ferr := fn(pkv.Interface()); ferr != nil {
			q.error(ferr)
			return ferr
		}
	}
}
// lookupField finds the field called name in the current typeVersion. If no
// such field exists, an ErrParam error is set on the query and false is
// returned.
func (q *Query[T]) lookupField(name string) (field, bool) {
	fields := q.st.Current.Fields
	for i := range fields {
		if fields[i].Name == name {
			return fields[i], true
		}
	}
	q.errorf("%w: unknown field %q", ErrParam, name)
	return field{}, false
}
// Kinds that can be converted without loss of precision, identity is not in here.
type convertKinds struct{ from, to kind }

// convertFieldKinds is the set of (from, to) kind pairs for which prepareValue
// accepts a value of kind "from" for a field of kind "to".
//
// NOTE(review): int -> int32 and uint -> uint32 are listed as lossless, which
// only holds if int/uint values are limited to 32 bits elsewhere - confirm.
var convertFieldKinds = map[convertKinds]struct{}{
	{kindInt8, kindInt16}:      {},
	{kindInt8, kindInt32}:      {},
	{kindInt8, kindInt64}:      {},
	{kindInt8, kindInt}:        {},
	{kindInt16, kindInt32}:     {},
	{kindInt16, kindInt64}:     {},
	{kindInt16, kindInt}:       {},
	{kindInt32, kindInt}:       {},
	{kindInt32, kindInt64}:     {},
	{kindInt, kindInt32}:       {},
	{kindInt, kindInt64}:       {},
	{kindUint8, kindUint16}:    {},
	{kindUint8, kindUint32}:    {},
	{kindUint8, kindUint64}:    {},
	{kindUint8, kindUint}:      {},
	{kindUint16, kindUint32}:   {},
	{kindUint16, kindUint64}:   {},
	{kindUint16, kindUint}:     {},
	{kindUint32, kindUint}:     {},
	{kindUint32, kindUint64}:   {},
	{kindUint, kindUint32}:     {},
	{kindUint, kindUint64}:     {},
	{kindFloat32, kindFloat64}: {},
}
// Check type of value for field and return a reflect value that can directly be set on the field.
// If the field is a pointer, we allow non-pointers and convert them.
// We require value to be of a type that can be converted without loss of precision to the type of field.
//
// fname is the field name, used in error messages. ft is the type of the field
// and st the Go type of the struct field. rv is the value to check.
//
// If false is returned, an ErrParam error has been set on the query.
func (q *Query[T]) prepareValue(fname string, ft fieldType, st reflect.Type, rv reflect.Value) (reflect.Value, bool) {
	// An invalid rv is e.g. the result of reflect.ValueOf on a nil interface.
	if !rv.IsValid() {
		q.errorf("%w: invalid value", ErrParam)
		return rv, false
	}
	// Quick check first.
	t := rv.Type()
	if t == st {
		return rv, true
	}
	if !ft.Ptr && rv.Kind() == reflect.Ptr {
		q.errorf("%w: cannot set ptr value to nonptr field", ErrParam)
		return rv, false
	}
	k, err := typeKind(t)
	if err != nil {
		q.errorf("%w: type of field: %s", ErrParam, err)
		return reflect.Value{}, false
	}
	// The kind of the value must be the kind of the field, or be listed as
	// convertible to it.
	if _, ok := convertFieldKinds[convertKinds{k, ft.Kind}]; !ok && k != ft.Kind {
		q.errorf("%w: got %v for field %q, need %v", ErrParam, rv.Type(), fname, ft.Kind)
		return reflect.Value{}, false
	}
	if k != ft.Kind {
		// Convert to the non-pointer type of the field.
		dt := st
		if ft.Ptr {
			dt = dt.Elem()
		}
		rv = rv.Convert(dt)
	}
	// For a pointer field, store a non-pointer value in a newly allocated value
	// and use the pointer to it.
	if ft.Ptr && rv.Kind() != reflect.Ptr {
		nv := reflect.New(st.Elem())
		nv.Elem().Set(rv)
		rv = nv
	}
	return rv, true
}
// checkPK reports whether t is the type of the current typeVersion's PK. If
// not, an ErrParam error with a user-friendly message is set on the query.
func (q *Query[T]) checkPK(t reflect.Type) bool {
	if t == q.pkType {
		return true
	}
	q.errorf("%w: id type was %s, must be %s", ErrParam, t, q.pkType)
	return false
}
// FilterID selects the records with primary key id, which must be of the type
// of T's primary key. A nil id, or an id of another type, sets an ErrParam
// error on the query.
//
// If IDs were already selected with FilterID or FilterIDs, the result is the
// intersection: either only id remains selected, or no records are selected.
func (q *Query[T]) FilterID(id any) *Query[T] {
	if !q.checkErr() {
		return q
	}
	kv := reflect.ValueOf(id)
	// reflect.ValueOf(nil) is the invalid Value, and calling Type on it panics.
	if !kv.IsValid() {
		q.errorf("%w: id must be non-nil", ErrParam)
		return q
	}
	if !q.checkPK(kv.Type()) {
		return q
	}
	pk, err := packPK(kv)
	if err != nil {
		q.error(err)
		return q
	}
	if q.xfilterIDs != nil {
		// Intersection of this ID with the previous IDs. Either it is this single ID or the list becomes empty.
		if _, ok := q.xfilterIDs.ids[id]; !ok {
			q.xfilterIDs = &filterIDs[T]{map[any]struct{}{}, [][]byte{}}
			return q
		}
	}
	q.xfilterIDs = &filterIDs[T]{map[any]struct{}{id: {}}, [][]byte{pk}}
	return q
}
// FilterIDs selects the records with a primary key that is in ids. Ids must be
// a slice of T's primary key type, otherwise an ErrParam error is set on the
// query.
//
// If IDs were already selected with FilterID or FilterIDs, only the IDs also
// present in that earlier selection are kept. Duplicate values in ids are
// selected once.
func (q *Query[T]) FilterIDs(ids any) *Query[T] {
	if !q.checkErr() {
		return q
	}
	kv := reflect.ValueOf(ids)
	if kv.Kind() != reflect.Slice {
		q.errorf("%w: ids must be slice of %v, not %T", ErrParam, q.pkType, ids)
		return q
	}
	if !q.checkPK(kv.Type().Elem()) {
		return q
	}
	n := kv.Len()
	pks := make([][]byte, 0, n)
	var prevIDs map[any]struct{}
	if q.xfilterIDs != nil {
		prevIDs = q.xfilterIDs.ids // We use this to check intersection.
	}
	// todo: should we fail for a zero PK?
	nids := make(map[any]struct{}, n)
	for i := 0; i < n; i++ {
		rev := kv.Index(i)
		ev := rev.Interface()
		if prevIDs != nil {
			if _, ok := prevIDs[ev]; !ok {
				continue
			}
		}
		// Keep the packed keys consistent with the set of IDs: a value that
		// occurs more than once in ids is added only once.
		if _, dup := nids[ev]; dup {
			continue
		}
		pk, err := packPK(rev)
		if err != nil {
			q.error(err)
			return q
		}
		nids[ev] = struct{}{}
		pks = append(pks, pk)
	}
	q.xfilterIDs = &filterIDs[T]{nids, pks}
	return q
}
// FilterFn calls fn for each record selected so far. If fn returns true, the
// record is kept for further filters and finally the operation. A nil fn sets
// an ErrParam error on the query.
func (q *Query[T]) FilterFn(fn func(value T) bool) *Query[T] {
	if !q.checkErr() {
		return q
	}
	if fn != nil {
		q.addFilter(filterFn[T]{fn})
		return q
	}
	q.errorf("%w: nil fn", ErrParam)
	return q
}
// gatherNonzeroFields returns the fields of tv that have a non-zero value in
// rv, along with those values. Used for Update and FilterNonzero.
//
// allowID indicates if the primary key, the first field, is allowed to be
// nonzero (not for Updates).
//
// At least one field must be nonzero, otherwise an ErrParam error is returned.
func gatherNonzeroFields(tv *typeVersion, rv reflect.Value, allowID bool) ([]field, []reflect.Value, error) {
	var (
		fields []field
		values []reflect.Value
	)
	for i := range tv.Fields {
		f := tv.Fields[i]
		fv := rv.FieldByIndex(f.structField.Index)
		switch {
		case f.Type.isZero(fv):
			// Zero values are not gathered.
		case i == 0 && !allowID:
			return nil, nil, fmt.Errorf("%w: primary key must be zero", ErrParam)
		default:
			fields = append(fields, f)
			values = append(values, fv)
		}
	}
	if len(fields) > 0 {
		return fields, values, nil
	}
	return nil, nil, fmt.Errorf("%w: must have at least one nonzero field", ErrParam)
}
// FilterNonzero gathers the nonzero fields from value, and selects records that
// have equal values for those fields. At least one value must be nonzero. If a
// value comes from an external source, e.g. user input, make sure it is not
// the zero value.
//
// A nonzero primary key is handled as a call to FilterID.
//
// Keep in mind that filtering on an embed/anonymous field looks at individual
// fields in the embedded field for non-zeroness, not at the embed field as a whole.
func (q *Query[T]) FilterNonzero(value T) *Query[T] {
	if !q.checkErr() {
		return q
	}
	fields, values, err := gatherNonzeroFields(q.st.Current, reflect.ValueOf(value), true)
	if err != nil {
		q.error(err)
		return q
	}

	pkName := q.st.Current.Fields[0].Name
	for i := range fields {
		if fields[i].Name != pkName {
			q.addFilter(filterEqual[T]{fields[i], values[i]})
			continue
		}
		q.FilterID(values[i].Interface())
	}
	return q
}
// FilterEqual selects records that have one of values for fieldName.
//
// Note: Value must be a compatible type for comparison with fieldName. Go
// constant numbers become ints, which are not compatible with uint or float
// types.
func (q *Query[T]) FilterEqual(fieldName string, values ...any) *Query[T] {
	const negate = false
	q.filterEqual(fieldName, values, negate)
	return q
}
// FilterNotEqual selects records that do not have any of values for fieldName.
func (q *Query[T]) FilterNotEqual(fieldName string, values ...any) *Query[T] {
	const negate = true
	q.filterEqual(fieldName, values, negate)
	return q
}
// filterEqual adds a filter on (non-)equality of fieldName with values. With
// not set, records matching any of values are excluded, otherwise only those
// are selected.
//
// A single value results in a filterEqual/filterNotEqual, multiple values in a
// filterIn/filterNotIn. Unknown fields, pointer fields, an empty values and
// values of an incompatible type set an ErrParam error on the query.
func (q *Query[T]) filterEqual(fieldName string, values []any, not bool) {
	if !q.checkErr() {
		return
	}
	ff, ok := q.lookupField(fieldName)
	if !ok {
		return
	}
	switch {
	case len(values) == 0:
		q.errorf("%w: need at least one value for (not) equal", ErrParam)
		return
	case ff.Type.Ptr:
		q.errorf("%w: cannot compare pointer values", ErrParam)
		return
	}

	// Prepare all values before adding any filter.
	prepared := make([]reflect.Value, 0, len(values))
	for _, value := range values {
		rv, ok := q.prepareValue(ff.Name, ff.Type, ff.structField.Type, reflect.ValueOf(value))
		if !ok {
			return
		}
		prepared = append(prepared, rv)
	}

	single := len(prepared) == 1
	switch {
	case single && not:
		q.addFilter(filterNotEqual[T]{ff, prepared[0]})
	case single:
		q.addFilter(filterEqual[T]{ff, prepared[0]})
	case not:
		q.addFilter(filterNotIn[T]{ff, prepared})
	default:
		q.addFilter(filterIn[T]{ff, prepared})
	}
}
// FilterIn selects records that have value as one of the elements of the slice
// field fieldName.
//
// If fieldName has an index, it is used to select rows.
//
// Note: Value must be a compatible type for comparison with the elements of
// fieldName. Go constant numbers become ints, which are not compatible with uint
// or float types.
func (q *Query[T]) FilterIn(fieldName string, value any) *Query[T] {
	if !q.checkErr() {
		return q
	}
	ff, ok := q.lookupField(fieldName)
	if !ok {
		return q
	}
	switch {
	case ff.Type.Ptr:
		q.errorf("%w: cannot compare pointer values", ErrParam)
		return q
	case ff.Type.Kind != kindSlice:
		q.errorf("%w: field for FilterIn must be a slice", ErrParam)
		return q
	}

	elem := ff.Type.ListElem
	if elem.Ptr {
		q.errorf("%w: cannot compare element pointer values", ErrParam)
		return q
	}
	// The value is checked against the element type of the slice.
	elemType := ff.structField.Type.Elem()
	if rv, ok := q.prepareValue(ff.Name, *elem, elemType, reflect.ValueOf(value)); ok {
		q.addFilter(filterInSlice[T]{ff, rv})
	}
	return q
}
// FilterGreater selects records that have fieldName > value.
//
// Note: Value must be a compatible type for comparison with fieldName. Go
// constant numbers become ints, which are not compatible with uint or float
// types.
func (q *Query[T]) FilterGreater(fieldName string, value any) *Query[T] {
	rv := reflect.ValueOf(value)
	return q.filterCompare(fieldName, opGreater, rv)
}
// FilterGreaterEqual selects records that have fieldName >= value.
func (q *Query[T]) FilterGreaterEqual(fieldName string, value any) *Query[T] {
	rv := reflect.ValueOf(value)
	return q.filterCompare(fieldName, opGreaterEqual, rv)
}
// FilterLess selects records that have fieldName < value.
func (q *Query[T]) FilterLess(fieldName string, value any) *Query[T] {
	rv := reflect.ValueOf(value)
	return q.filterCompare(fieldName, opLess, rv)
}
// FilterLessEqual selects records that have fieldName <= value.
func (q *Query[T]) FilterLessEqual(fieldName string, value any) *Query[T] {
	rv := reflect.ValueOf(value)
	return q.filterCompare(fieldName, opLessEqual, rv)
}
// filterCompare adds a filter that compares fieldName against value with op.
// An unknown field, a field type that cannot be compared, or a value of an
// incompatible type sets an ErrParam error on the query.
func (q *Query[T]) filterCompare(fieldName string, op compareOp, value reflect.Value) *Query[T] {
	if !q.checkErr() {
		return q
	}
	ff, found := q.lookupField(fieldName)
	if !found {
		return q
	}
	if !comparable(ff.Type) {
		q.errorf("%w: cannot compare %s", ErrParam, ff.Type.Kind)
		return q
	}
	if rv, ok := q.prepareValue(ff.Name, ff.Type, ff.structField.Type, value); ok {
		q.addFilter(filterCompare[T]{ff, op, rv})
	}
	return q
}
// Limit stops selecting records after the first n records.
// Can only be called once. n must be >= 1.
func (q *Query[T]) Limit(n int) *Query[T] {
	switch {
	case !q.checkErr():
		// Query already has an error.
	case n <= 0:
		q.errorf("%w: limit must be >= 1", ErrParam)
	case q.xlimit > 0:
		q.errorf("%w: already have a limit", ErrParam)
	default:
		q.xlimit = n
	}
	return q
}
// SortAsc sorts the selected records by fieldNames in ascending order.
// Additional orderings can be added by more calls to SortAsc or SortDesc.
func (q *Query[T]) SortAsc(fieldNames ...string) *Query[T] {
	const ascending = true
	return q.order(fieldNames, ascending)
}
// SortDesc sorts the selected records by fieldNames in descending order.
// Additional orderings can be added by more calls to SortAsc or SortDesc.
func (q *Query[T]) SortDesc(fieldNames ...string) *Query[T] {
	const ascending = false
	return q.order(fieldNames, ascending)
}
// order adds an ordering for each of fieldNames, ascending if asc is set and
// descending otherwise. An empty fieldNames, an unknown field or a field that
// cannot be compared sets an ErrParam error on the query. Orderings for fields
// preceding the failing field have been added at that point.
func (q *Query[T]) order(fieldNames []string, asc bool) *Query[T] {
	if !q.checkErr() {
		return q
	}
	if len(fieldNames) == 0 {
		q.errorf("%w: sort fieldNames must be non-empty", ErrParam)
		return q
	}
	for _, name := range fieldNames {
		ff, found := q.lookupField(name)
		switch {
		case !found:
			return q
		case !comparable(ff.Type):
			q.errorf("%w: cannot sort by unorderable %q", ErrParam, name)
			return q
		}
		q.xorders = append(q.xorders, order{ff, asc})
	}
	return q
}
// Gather causes an Update or Delete operation to return the values of the
// affected records into l. For Update, the updated records are returned.
// Gather can be called only once, and l must be non-nil.
func (q *Query[T]) Gather(l *[]T) *Query[T] {
	if !q.checkErr() {
		return q
	}
	switch {
	case l == nil:
		q.errorf("%w: l must be non-nil", ErrParam)
	case q.gathers != nil:
		q.errorf("%w: can only have one Gather", ErrParam)
	default:
		q.gathers = l
	}
	return q
}
// GatherIDs causes an Update or Delete operation to return the primary keys of
// affected records into ids, which must be a non-nil pointer to a slice of T's
// primary key. GatherIDs can be called only once.
func (q *Query[T]) GatherIDs(ids any) *Query[T] {
	if !q.checkErr() {
		return q
	}
	if ids == nil {
		q.errorf("%w: ids must be non-nil", ErrParam)
		return q
	}
	rv := reflect.ValueOf(ids)
	t := rv.Type()
	if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Slice || t.Elem().Elem() != q.pkType {
		q.errorf("%w: ids must be pointer to slice of %v, not %T", ErrParam, q.pkType, ids)
		return q
	}
	// A typed nil pointer passes the nil check on the interface value above, but
	// appending to the slice it points to would panic during the operation.
	if rv.IsNil() {
		q.errorf("%w: ids must be non-nil", ErrParam)
		return q
	}
	if q.gatherIDs.IsValid() {
		q.errorf("%w: can only have one GatherIDs", ErrParam)
		return q
	}
	q.gatherIDs = rv
	return q
}
// gather records an updated/deleted record for the caller: v is appended to
// the slice set with Gather, and the primary key, taken from rv, is appended
// to the slice set with GatherIDs. Does nothing for slices that were not set.
func (q *Query[T]) gather(v T, rv reflect.Value) {
	if q.gathers != nil {
		*q.gathers = append(*q.gathers, v)
	}
	if !q.gatherIDs.IsValid() {
		return
	}
	pk := rv.FieldByIndex(q.st.Current.Fields[0].structField.Index)
	dst := q.gatherIDs.Elem()
	dst.Set(reflect.Append(dst, pk))
}
// Err returns the error set on the query, if any. Can happen for invalid
// filters or canceled contexts.
// Finished queries return ErrFinished.
func (q *Query[T]) Err() error {
	_ = q.checkErr()
	return q.err
}
// Delete removes the selected records, returning how many were deleted.
//
// See Gather and GatherIDs for collecting the deleted records or IDs.
func (q *Query[T]) Delete() (deleted int, rerr error) {
	defer q.finish(&rerr)
	q.checkNotNext()
	if !q.checkErr() {
		return 0, q.err
	}

	// The records to delete are collected first, and only deleted after the
	// selection has completed.
	type pending struct {
		key []byte
		old reflect.Value
	}
	var todo []pending
	collect := func(bk []byte, ov T) error {
		rov := reflect.ValueOf(ov)
		q.gather(ov, rov)
		todo = append(todo, pending{bk, rov})
		return nil
	}
	if err := q.foreachKey(true, true, collect); err != nil {
		return 0, err
	}

	for i := range todo {
		q.stats.Delete++
		if err := q.xtx.delete(q.exec.rb, q.st, todo[i].key, todo[i].old); err != nil {
			return 0, err
		}
	}
	return len(todo), nil
}
// Get returns the single selected record.
//
// ErrMultiple is returned if multiple records were selected.
// ErrAbsent is returned if no record was selected.
// Any other error while selecting records is returned as is.
func (q *Query[T]) Get() (value T, rerr error) {
	defer q.finish(&rerr)
	q.checkNotNext()
	if !q.checkErr() {
		var zero T
		return zero, q.err
	}

	_, v, err := q.nextKey(false, true)
	if err != nil {
		return v, err
	}

	// Look for a second record, to verify the first is the only one selected.
	_, _, err = q.nextKey(false, false)
	switch err {
	case nil:
		return v, ErrMultiple
	case ErrAbsent:
		return v, nil
	default:
		// The lookup itself failed, so it is unknown whether v is the only
		// match. Return the error instead of reporting success.
		var zero T
		return zero, err
	}
}
// Count returns the number of selected records. Only keys are fetched, records
// are not parsed.
func (q *Query[T]) Count() (n int, rerr error) {
	defer q.finish(&rerr)
	q.checkNotNext()
	if !q.checkErr() {
		return 0, q.err
	}

	tally := func([]byte, T) error {
		n++
		return nil
	}
	err := q.foreachKey(false, false, tally)
	return n, err
}
// List returns all selected records.
// On success with zero selected records, List returns the empty list.
// On error, a nil list is returned.
func (q *Query[T]) List() (list []T, rerr error) {
	defer q.finish(&rerr)
	q.checkNotNext()
	if !q.checkErr() {
		return nil, q.err
	}

	// Start with a non-nil slice, so success without records gives an empty list.
	records := make([]T, 0)
	collect := func(_ []byte, v T) error {
		records = append(records, v)
		return nil
	}
	if err := q.foreachKey(false, true, collect); err != nil {
		return nil, err
	}
	return records, nil
}
// UpdateNonzero updates all selected records with the non-zero fields from
// value, returning the number of records updated. The primary key in value
// must be zero, and at least one other field must be nonzero.
//
// Recall that false, 0, "" are all zero values. Use UpdateField or
// UpdateFields to set fields to their zero value. This is especially relevant
// if the field value comes from an external source, e.g. user input.
//
// See Gather and GatherIDs for collecting the updated records or IDs.
//
// Keep in mind that updating on an embed/anonymous field looks at individual
// fields in the embedded field for non-zeroness, not at the embed field as a whole.
func (q *Query[T]) UpdateNonzero(value T) (updated int, rerr error) {
	defer q.finish(&rerr)
	q.checkNotNext()
	if !q.checkErr() {
		return 0, q.err
	}

	fields, values, err := gatherNonzeroFields(q.st.Current, reflect.ValueOf(value), false)
	if err != nil {
		return 0, err
	}
	structFields := make([]reflect.StructField, 0, len(fields))
	for i := range fields {
		structFields = append(structFields, fields[i].structField)
	}
	return q.update(structFields, values)
}
// UpdateField calls UpdateFields for fieldName and value.
func (q *Query[T]) UpdateField(fieldName string, value any) (updated int, rerr error) {
	single := map[string]any{fieldName: value}
	return q.UpdateFields(single)
}
// UpdateFields updates all selected records, setting fields named by the map
// keys of fieldValues to the corresponding map value and returning the number
// of records updated.
//
// At least one field must be updated, the primary key cannot be updated, and
// unknown field names result in an ErrParam error.
//
// See Gather and GatherIDs for collecting the updated records or IDs.
//
// Entire embed fields can be updated, as well as their individual embedded
// fields.
func (q *Query[T]) UpdateFields(fieldValues map[string]any) (updated int, rerr error) {
	defer q.finish(&rerr)
	q.checkNotNext()
	if !q.checkErr() {
		return 0, q.err
	}
	if len(fieldValues) == 0 {
		return 0, fmt.Errorf("%w: must update at least one field", ErrParam)
	}
	fields := make([]reflect.StructField, 0, len(fieldValues))
	values := make([]reflect.Value, 0, len(fieldValues))
next:
	for name, value := range fieldValues {
		// First look for a regular field with this name. Fields[0] is the
		// primary key.
		for i, f := range q.st.Current.Fields {
			if f.Name != name {
				continue
			}
			if i == 0 {
				return 0, fmt.Errorf("%w: cannot update primary key", ErrParam)
			}
			rv, ok := q.prepareValue(f.Name, f.Type, f.structField.Type, reflect.ValueOf(value))
			if !ok {
				return 0, q.err
			}
			fields = append(fields, f.structField)
			values = append(values, rv)
			continue next
		}
		// Otherwise the name may refer to an embed field, updated as a whole.
		for _, ef := range q.st.Current.embedFields {
			if ef.Name != name {
				continue
			}
			rv, ok := q.prepareValue(ef.Name, ef.Type, ef.structField.Type, reflect.ValueOf(value))
			if !ok {
				return 0, q.err
			}
			fields = append(fields, ef.structField)
			values = append(values, rv)
			continue next
		}
		// Only reached when neither loop above found the name.
		return 0, fmt.Errorf("%w: unknown field %q", ErrParam, name)
	}
	return q.update(fields, values)
}
// update sets fields to the corresponding values on all selected records and
// stores the records, returning the number of records updated. fields and
// values must have the same length.
func (q *Query[T]) update(fields []reflect.StructField, values []reflect.Value) (int, error) {
	// todo: we could check if the updated fields are not relevant for the cursor (not in filter/sort query) and update inside foreach.
	// We first gather all records to be updated (using the query), then update the
	// records.
	type work struct {
		bk []byte
		rv reflect.Value // Record with the new field values.
		ov reflect.Value // Copy of the record as selected, before the update.
	}
	var updates []work
	err := q.foreachKey(true, true, func(bk []byte, v T) error {
		// rv is addressable and refers to v, so setting fields through rv
		// modifies v.
		rv := reflect.ValueOf(&v).Elem()
		// Copy the original record before modifying it.
		ov := reflect.New(q.st.Type).Elem()
		ov.Set(rv)
		for i, sf := range fields {
			frv := rv.FieldByIndex(sf.Index)
			frv.Set(values[i])
		}
		// Gather after setting the fields, so the updated record is gathered.
		q.gather(v, rv)
		updates = append(updates, work{bk, rv, ov})
		return nil
	})
	if err != nil {
		return 0, err
	}
	for _, w := range updates {
		q.stats.Update++
		if err := q.xtx.update(q.exec.rb, q.st, w.rv, w.ov, w.bk); err != nil {
			return 0, err
		}
	}
	return len(updates), nil
}
// IDs sets the slice pointed to by idsptr to the primary keys of the selected
// records. Idsptr must be a non-nil pointer to a slice of T's primary key
// type, otherwise an ErrParam error is returned.
//
// On success without selected records, the slice is set to an empty slice. On
// error, the slice is not modified.
func (q *Query[T]) IDs(idsptr any) (rerr error) {
	defer q.finish(&rerr)
	q.checkNotNext()
	if !q.checkErr() {
		return q.err
	}
	if idsptr == nil {
		return fmt.Errorf("%w: idsptr must not be nil", ErrParam)
	}
	rv := reflect.ValueOf(idsptr)
	if rv.Type().Kind() != reflect.Ptr || rv.Type().Elem().Kind() != reflect.Slice || rv.Type().Elem().Elem() != q.pkType {
		return fmt.Errorf("%w: idsptr must be a ptr to slice of %v, not %T", ErrParam, q.pkType, idsptr)
	}
	// A typed nil pointer passes the nil check on the interface value above, but
	// setting the slice it points to would panic.
	if rv.IsNil() {
		return fmt.Errorf("%w: idsptr must not be nil", ErrParam)
	}
	s := reflect.MakeSlice(rv.Type().Elem(), 0, 0)
	err := q.foreachID(false, func(pkv any) error {
		s = reflect.Append(s, reflect.ValueOf(pkv))
		return nil
	})
	if err != nil {
		return err
	}
	rv.Elem().Set(s)
	return nil
}
// todo: should we have an iteration object that we can call Next and NextID on?

// Next fetches the next record, moving the cursor forward.
//
// ErrAbsent is returned if no more records match.
//
// Automatically created transactions are read-only.
//
// Close must be called on a Query on which Next or NextID was called and that
// is not yet finished, i.e. has not yet returned an error (including
// ErrAbsent).
func (q *Query[T]) Next() (value T, rerr error) {
	// note: q.finish is not deferred here, the caller drives the iteration.
	if !q.checkErr() {
		var zero T
		return zero, q.err
	}

	q.nextOnly = true
	_, value, rerr = q.nextKey(false, true)
	if rerr == ErrAbsent {
		// No more records: mark the query as finished.
		q.error(rerr)
	}
	return value, rerr
}
// NextID is like Next, but only fetches the primary key of the next matching
// record, storing it in idptr. Idptr must be a non-nil pointer to a value of
// T's primary key type, otherwise an ErrParam error is set on the query and
// returned.
//
// ErrAbsent is returned if no more records match.
func (q *Query[T]) NextID(idptr any) (rerr error) {
	// note: no q.finish preamble because caller iterates over result themselves.
	if !q.checkErr() {
		return q.err
	}
	q.nextOnly = true
	if idptr == nil {
		q.errorf("%w: idptr must be non-nil", ErrParam)
		return q.err
	}
	rpkv := reflect.ValueOf(idptr)
	t := rpkv.Type()
	if t.Kind() != reflect.Ptr || t.Elem() != q.pkType {
		// Set the error on the query instead of only returning it: callers need
		// not call Close after an error, so cleanup must happen here.
		q.errorf("%w: value must be ptr to %v, not %v", ErrParam, q.pkType, t)
		return q.err
	}
	// A typed nil pointer passes the nil check on the interface value above, but
	// has no element to store the primary key in.
	if rpkv.IsNil() {
		q.errorf("%w: idptr must be non-nil", ErrParam)
		return q.err
	}
	err := q.nextID(false, rpkv.Elem())
	if err == ErrAbsent {
		// No more records: mark the query as finished.
		q.error(err)
	}
	return err
}
// Exists returns whether any record was selected. The limit of the query is
// set to 1, and only the key of the first matching record is fetched.
func (q *Query[T]) Exists() (exists bool, rerr error) {
	defer q.finish(&rerr)
	q.checkNotNext()
	if !q.checkErr() {
		return false, q.err
	}

	q.xlimit = 1
	_, _, err := q.nextKey(false, false)
	switch err {
	case nil:
		return true, nil
	case ErrAbsent:
		return false, nil
	default:
		return false, err
	}
}
// StopForEach is an error value that, if returned by the function passed to
// Query.ForEach, stops further iterations. ForEach itself then returns nil.
var StopForEach error = errors.New("stop foreach")
// ForEach calls fn on each selected record.
// If fn returns StopForEach, ForEach stops iterating, so no longer calls fn,
// and returns nil. Any other error from fn also stops the iteration and is
// returned.
// Fn must not update values, the internal cursor is not repositioned between
// invocations of fn, which would cause undefined behaviour (in practice,
// matching values could be skipped).
func (q *Query[T]) ForEach(fn func(value T) error) (rerr error) {
	defer q.finish(&rerr)
	q.checkNotNext()
	if !q.checkErr() {
		return q.err
	}

	visit := func(_ []byte, v T) error {
		return fn(v)
	}
	return q.foreachKey(false, true, visit)
}