mirror of
https://github.com/mjl-/mox.git
synced 2024-12-26 16:33:47 +03:00
update to latest bstore (with support for an index on a []string: Message.DKIMDomains), and cyclic data types (to be used for Message.Part soon); also adds a context.Context to database operations.
This commit is contained in:
parent
f6ed860ccb
commit
e81930ba20
58 changed files with 1970 additions and 1035 deletions
16
ctl.go
16
ctl.go
|
@ -386,7 +386,7 @@ func servectlcmd(ctx context.Context, log *mlog.Log, ctl *ctl, xcmd *string, shu
|
||||||
< "ok"
|
< "ok"
|
||||||
< stream
|
< stream
|
||||||
*/
|
*/
|
||||||
qmsgs, err := queue.List()
|
qmsgs, err := queue.List(ctx)
|
||||||
ctl.xcheck(err, "listing queue")
|
ctl.xcheck(err, "listing queue")
|
||||||
ctl.xwriteok()
|
ctl.xwriteok()
|
||||||
|
|
||||||
|
@ -425,10 +425,10 @@ func servectlcmd(ctx context.Context, log *mlog.Log, ctl *ctl, xcmd *string, shu
|
||||||
|
|
||||||
var count int
|
var count int
|
||||||
if cmd == "queuekick" {
|
if cmd == "queuekick" {
|
||||||
count, err = queue.Kick(id, todomain, recipient)
|
count, err = queue.Kick(ctx, id, todomain, recipient)
|
||||||
ctl.xcheck(err, "kicking queue")
|
ctl.xcheck(err, "kicking queue")
|
||||||
} else {
|
} else {
|
||||||
count, err = queue.Drop(id, todomain, recipient)
|
count, err = queue.Drop(ctx, id, todomain, recipient)
|
||||||
ctl.xcheck(err, "dropping messages from queue")
|
ctl.xcheck(err, "dropping messages from queue")
|
||||||
}
|
}
|
||||||
ctl.xwrite(fmt.Sprintf("%d", count))
|
ctl.xwrite(fmt.Sprintf("%d", count))
|
||||||
|
@ -447,7 +447,7 @@ func servectlcmd(ctx context.Context, log *mlog.Log, ctl *ctl, xcmd *string, shu
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctl.xcheck(err, "parsing id")
|
ctl.xcheck(err, "parsing id")
|
||||||
}
|
}
|
||||||
mr, err := queue.OpenMessage(id)
|
mr, err := queue.OpenMessage(ctx, id)
|
||||||
ctl.xcheck(err, "opening message")
|
ctl.xcheck(err, "opening message")
|
||||||
defer func() {
|
defer func() {
|
||||||
err := mr.Close()
|
err := mr.Close()
|
||||||
|
@ -458,7 +458,7 @@ func servectlcmd(ctx context.Context, log *mlog.Log, ctl *ctl, xcmd *string, shu
|
||||||
|
|
||||||
case "importmaildir", "importmbox":
|
case "importmaildir", "importmbox":
|
||||||
mbox := cmd == "importmbox"
|
mbox := cmd == "importmbox"
|
||||||
importctl(ctl, mbox)
|
importctl(ctx, ctl, mbox)
|
||||||
|
|
||||||
case "domainadd":
|
case "domainadd":
|
||||||
/* protocol:
|
/* protocol:
|
||||||
|
@ -609,7 +609,7 @@ func servectlcmd(ctx context.Context, log *mlog.Log, ctl *ctl, xcmd *string, shu
|
||||||
log.Check(err, "removing old junkfilter bloom filter file", mlog.Field("path", bloomPath))
|
log.Check(err, "removing old junkfilter bloom filter file", mlog.Field("path", bloomPath))
|
||||||
|
|
||||||
// Open junk filter, this creates new files.
|
// Open junk filter, this creates new files.
|
||||||
jf, _, err := acc.OpenJunkFilter(ctl.log)
|
jf, _, err := acc.OpenJunkFilter(ctx, ctl.log)
|
||||||
ctl.xcheck(err, "open new junk filter")
|
ctl.xcheck(err, "open new junk filter")
|
||||||
defer func() {
|
defer func() {
|
||||||
if jf == nil {
|
if jf == nil {
|
||||||
|
@ -621,10 +621,10 @@ func servectlcmd(ctx context.Context, log *mlog.Log, ctl *ctl, xcmd *string, shu
|
||||||
|
|
||||||
// Read through messages with junk or nonjunk flag set, and train them.
|
// Read through messages with junk or nonjunk flag set, and train them.
|
||||||
var total, trained int
|
var total, trained int
|
||||||
q := bstore.QueryDB[store.Message](acc.DB)
|
q := bstore.QueryDB[store.Message](ctx, acc.DB)
|
||||||
err = q.ForEach(func(m store.Message) error {
|
err = q.ForEach(func(m store.Message) error {
|
||||||
total++
|
total++
|
||||||
ok, err := acc.TrainMessage(ctl.log, jf, m)
|
ok, err := acc.TrainMessage(ctx, ctl.log, jf, m)
|
||||||
if ok {
|
if ok {
|
||||||
trained++
|
trained++
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,13 +67,13 @@ type DomainFeedback struct {
|
||||||
dmarcrpt.Feedback
|
dmarcrpt.Feedback
|
||||||
}
|
}
|
||||||
|
|
||||||
func database() (rdb *bstore.DB, rerr error) {
|
func database(ctx context.Context) (rdb *bstore.DB, rerr error) {
|
||||||
mutex.Lock()
|
mutex.Lock()
|
||||||
defer mutex.Unlock()
|
defer mutex.Unlock()
|
||||||
if dmarcDB == nil {
|
if dmarcDB == nil {
|
||||||
p := mox.DataDirPath("dmarcrpt.db")
|
p := mox.DataDirPath("dmarcrpt.db")
|
||||||
os.MkdirAll(filepath.Dir(p), 0770)
|
os.MkdirAll(filepath.Dir(p), 0770)
|
||||||
db, err := bstore.Open(p, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, DomainFeedback{})
|
db, err := bstore.Open(ctx, p, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, DomainFeedback{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -84,7 +84,7 @@ func database() (rdb *bstore.DB, rerr error) {
|
||||||
|
|
||||||
// Init opens the database.
|
// Init opens the database.
|
||||||
func Init() error {
|
func Init() error {
|
||||||
_, err := database()
|
_, err := database(mox.Shutdown)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -93,7 +93,7 @@ func Init() error {
|
||||||
//
|
//
|
||||||
// fromDomain is the domain in the report message From header.
|
// fromDomain is the domain in the report message From header.
|
||||||
func AddReport(ctx context.Context, f *dmarcrpt.Feedback, fromDomain dns.Domain) error {
|
func AddReport(ctx context.Context, f *dmarcrpt.Feedback, fromDomain dns.Domain) error {
|
||||||
db, err := database()
|
db, err := database(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -104,7 +104,7 @@ func AddReport(ctx context.Context, f *dmarcrpt.Feedback, fromDomain dns.Domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
df := DomainFeedback{0, d.Name(), fromDomain.Name(), *f}
|
df := DomainFeedback{0, d.Name(), fromDomain.Name(), *f}
|
||||||
if err := db.Insert(&df); err != nil {
|
if err := db.Insert(ctx, &df); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -143,30 +143,30 @@ func AddReport(ctx context.Context, f *dmarcrpt.Feedback, fromDomain dns.Domain)
|
||||||
|
|
||||||
// Records returns all reports in the database.
|
// Records returns all reports in the database.
|
||||||
func Records(ctx context.Context) ([]DomainFeedback, error) {
|
func Records(ctx context.Context) ([]DomainFeedback, error) {
|
||||||
db, err := database()
|
db, err := database(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return bstore.QueryDB[DomainFeedback](db).List()
|
return bstore.QueryDB[DomainFeedback](ctx, db).List()
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecordID returns the report for the ID.
|
// RecordID returns the report for the ID.
|
||||||
func RecordID(ctx context.Context, id int64) (DomainFeedback, error) {
|
func RecordID(ctx context.Context, id int64) (DomainFeedback, error) {
|
||||||
db, err := database()
|
db, err := database(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return DomainFeedback{}, err
|
return DomainFeedback{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
e := DomainFeedback{ID: id}
|
e := DomainFeedback{ID: id}
|
||||||
err = db.Get(&e)
|
err = db.Get(ctx, &e)
|
||||||
return e, err
|
return e, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecordsPeriodDomain returns the reports overlapping start and end, for the given
|
// RecordsPeriodDomain returns the reports overlapping start and end, for the given
|
||||||
// domain. If domain is empty, all records match for domain.
|
// domain. If domain is empty, all records match for domain.
|
||||||
func RecordsPeriodDomain(ctx context.Context, start, end time.Time, domain string) ([]DomainFeedback, error) {
|
func RecordsPeriodDomain(ctx context.Context, start, end time.Time, domain string) ([]DomainFeedback, error) {
|
||||||
db, err := database()
|
db, err := database(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -174,7 +174,7 @@ func RecordsPeriodDomain(ctx context.Context, start, end time.Time, domain strin
|
||||||
s := start.Unix()
|
s := start.Unix()
|
||||||
e := end.Unix()
|
e := end.Unix()
|
||||||
|
|
||||||
q := bstore.QueryDB[DomainFeedback](db)
|
q := bstore.QueryDB[DomainFeedback](ctx, db)
|
||||||
if domain != "" {
|
if domain != "" {
|
||||||
q.FilterNonzero(DomainFeedback{Domain: domain})
|
q.FilterNonzero(DomainFeedback{Domain: domain})
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,10 @@ import (
|
||||||
"github.com/mjl-/mox/mox-"
|
"github.com/mjl-/mox/mox-"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ctxbg = context.Background()
|
||||||
|
|
||||||
func TestDMARCDB(t *testing.T) {
|
func TestDMARCDB(t *testing.T) {
|
||||||
|
mox.Shutdown = ctxbg
|
||||||
mox.ConfigStaticPath = "../testdata/dmarcdb/fake.conf"
|
mox.ConfigStaticPath = "../testdata/dmarcdb/fake.conf"
|
||||||
mox.Conf.Static.DataDir = "."
|
mox.Conf.Static.DataDir = "."
|
||||||
|
|
||||||
|
@ -76,32 +79,32 @@ func TestDMARCDB(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if err := AddReport(context.Background(), feedback, dns.Domain{ASCII: "google.com"}); err != nil {
|
if err := AddReport(ctxbg, feedback, dns.Domain{ASCII: "google.com"}); err != nil {
|
||||||
t.Fatalf("adding report: %s", err)
|
t.Fatalf("adding report: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
records, err := Records(context.Background())
|
records, err := Records(ctxbg)
|
||||||
if err != nil || len(records) != 1 || !reflect.DeepEqual(&records[0].Feedback, feedback) {
|
if err != nil || len(records) != 1 || !reflect.DeepEqual(&records[0].Feedback, feedback) {
|
||||||
t.Fatalf("records: got err %v, records %#v, expected no error, single record with feedback %#v", err, records, feedback)
|
t.Fatalf("records: got err %v, records %#v, expected no error, single record with feedback %#v", err, records, feedback)
|
||||||
}
|
}
|
||||||
|
|
||||||
record, err := RecordID(context.Background(), records[0].ID)
|
record, err := RecordID(ctxbg, records[0].ID)
|
||||||
if err != nil || !reflect.DeepEqual(&record.Feedback, feedback) {
|
if err != nil || !reflect.DeepEqual(&record.Feedback, feedback) {
|
||||||
t.Fatalf("record id: got err %v, record %#v, expected feedback %#v", err, record, feedback)
|
t.Fatalf("record id: got err %v, record %#v, expected feedback %#v", err, record, feedback)
|
||||||
}
|
}
|
||||||
|
|
||||||
start := time.Unix(1596412800, 0)
|
start := time.Unix(1596412800, 0)
|
||||||
end := time.Unix(1596499199, 0)
|
end := time.Unix(1596499199, 0)
|
||||||
records, err = RecordsPeriodDomain(context.Background(), start, end, "example.org")
|
records, err = RecordsPeriodDomain(ctxbg, start, end, "example.org")
|
||||||
if err != nil || len(records) != 1 || !reflect.DeepEqual(&records[0].Feedback, feedback) {
|
if err != nil || len(records) != 1 || !reflect.DeepEqual(&records[0].Feedback, feedback) {
|
||||||
t.Fatalf("records: got err %v, records %#v, expected no error, single record with feedback %#v", err, records, feedback)
|
t.Fatalf("records: got err %v, records %#v, expected no error, single record with feedback %#v", err, records, feedback)
|
||||||
}
|
}
|
||||||
|
|
||||||
records, err = RecordsPeriodDomain(context.Background(), end, end, "example.org")
|
records, err = RecordsPeriodDomain(ctxbg, end, end, "example.org")
|
||||||
if err != nil || len(records) != 0 {
|
if err != nil || len(records) != 0 {
|
||||||
t.Fatalf("records: got err %v, records %#v, expected no error and no records", err, records)
|
t.Fatalf("records: got err %v, records %#v, expected no error and no records", err, records)
|
||||||
}
|
}
|
||||||
records, err = RecordsPeriodDomain(context.Background(), start, end, "other.example")
|
records, err = RecordsPeriodDomain(ctxbg, start, end, "other.example")
|
||||||
if err != nil || len(records) != 0 {
|
if err != nil || len(records) != 0 {
|
||||||
t.Fatalf("records: got err %v, records %#v, expected no error and no records", err, records)
|
t.Fatalf("records: got err %v, records %#v, expected no error and no records", err, records)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"log"
|
"log"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"time"
|
"time"
|
||||||
|
@ -56,7 +57,7 @@ func xcmdExport(mbox bool, args []string, c *cmd) {
|
||||||
}
|
}
|
||||||
|
|
||||||
dbpath := filepath.Join(accountDir, "index.db")
|
dbpath := filepath.Join(accountDir, "index.db")
|
||||||
db, err := bstore.Open(dbpath, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, store.Message{}, store.Recipient{}, store.Mailbox{})
|
db, err := bstore.Open(context.Background(), dbpath, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, store.Message{}, store.Recipient{}, store.Mailbox{})
|
||||||
xcheckf(err, "open database %q", dbpath)
|
xcheckf(err, "open database %q", dbpath)
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := db.Close(); err != nil {
|
if err := db.Close(); err != nil {
|
||||||
|
@ -65,7 +66,7 @@ func xcmdExport(mbox bool, args []string, c *cmd) {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
a := store.DirArchiver{Dir: dst}
|
a := store.DirArchiver{Dir: dst}
|
||||||
err = store.ExportMessages(mlog.New("export"), db, accountDir, a, !mbox, mailbox)
|
err = store.ExportMessages(context.Background(), mlog.New("export"), db, accountDir, a, !mbox, mailbox)
|
||||||
xcheckf(err, "exporting messages")
|
xcheckf(err, "exporting messages")
|
||||||
err = a.Close()
|
err = a.Close()
|
||||||
xcheckf(err, "closing archiver")
|
xcheckf(err, "closing archiver")
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -3,7 +3,7 @@ module github.com/mjl-/mox
|
||||||
go 1.18
|
go 1.18
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/mjl-/bstore v0.0.0-20230211204415-a9899ef6e782
|
github.com/mjl-/bstore v0.0.1
|
||||||
github.com/mjl-/sconf v0.0.4
|
github.com/mjl-/sconf v0.0.4
|
||||||
github.com/mjl-/sherpa v0.6.5
|
github.com/mjl-/sherpa v0.6.5
|
||||||
github.com/mjl-/sherpadoc v0.0.10
|
github.com/mjl-/sherpadoc v0.0.10
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -145,8 +145,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||||
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
|
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
|
||||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||||
github.com/mjl-/bstore v0.0.0-20230211204415-a9899ef6e782 h1:dVwJA/wXzXXUROM9oM3Stg3cmqixiFh4Zi1Xumvtj74=
|
github.com/mjl-/bstore v0.0.1 h1:OzQfYgpMCvNjNIj9FFJ3HidYzG6eSlLSYzCTzw9sptY=
|
||||||
github.com/mjl-/bstore v0.0.0-20230211204415-a9899ef6e782/go.mod h1:/cD25FNBaDfvL/plFRxI3Ba3E+wcB0XVOS8nJDqndg0=
|
github.com/mjl-/bstore v0.0.1/go.mod h1:/cD25FNBaDfvL/plFRxI3Ba3E+wcB0XVOS8nJDqndg0=
|
||||||
github.com/mjl-/sconf v0.0.4 h1:uyfn4vv5qOULSgiwQsPbbgkiONKnMFMsSOhsHfAiYwI=
|
github.com/mjl-/sconf v0.0.4 h1:uyfn4vv5qOULSgiwQsPbbgkiONKnMFMsSOhsHfAiYwI=
|
||||||
github.com/mjl-/sconf v0.0.4/go.mod h1:ezf7YOn7gtClo8y71SqgZKaEkyMQ5Te7vkv4PmTTfwM=
|
github.com/mjl-/sconf v0.0.4/go.mod h1:ezf7YOn7gtClo8y71SqgZKaEkyMQ5Te7vkv4PmTTfwM=
|
||||||
github.com/mjl-/sherpa v0.6.5 h1:d90uG/j8fw+2M+ohCTAcVwTSUURGm8ktYDScJO1nKog=
|
github.com/mjl-/sherpa v0.6.5 h1:d90uG/j8fw+2M+ohCTAcVwTSUURGm8ktYDScJO1nKog=
|
||||||
|
|
|
@ -222,7 +222,7 @@ func accountHandle(w http.ResponseWriter, r *http.Request) {
|
||||||
err := archiver.Close()
|
err := archiver.Close()
|
||||||
log.Check(err, "exporting mail close")
|
log.Check(err, "exporting mail close")
|
||||||
}()
|
}()
|
||||||
if err := store.ExportMessages(log, acc.DB, acc.Dir, archiver, maildir, ""); err != nil {
|
if err := store.ExportMessages(r.Context(), log, acc.DB, acc.Dir, archiver, maildir, ""); err != nil {
|
||||||
log.Errorx("exporting mail", err)
|
log.Errorx("exporting mail", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1500,21 +1500,21 @@ func (Admin) ClientConfigDomain(ctx context.Context, domain string) mox.ClientCo
|
||||||
|
|
||||||
// QueueList returns the messages currently in the outgoing queue.
|
// QueueList returns the messages currently in the outgoing queue.
|
||||||
func (Admin) QueueList(ctx context.Context) []queue.Msg {
|
func (Admin) QueueList(ctx context.Context) []queue.Msg {
|
||||||
l, err := queue.List()
|
l, err := queue.List(ctx)
|
||||||
xcheckf(ctx, err, "listing messages in queue")
|
xcheckf(ctx, err, "listing messages in queue")
|
||||||
return l
|
return l
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueueSize returns the number of messages currently in the outgoing queue.
|
// QueueSize returns the number of messages currently in the outgoing queue.
|
||||||
func (Admin) QueueSize(ctx context.Context) int {
|
func (Admin) QueueSize(ctx context.Context) int {
|
||||||
n, err := queue.Count()
|
n, err := queue.Count(ctx)
|
||||||
xcheckf(ctx, err, "listing messages in queue")
|
xcheckf(ctx, err, "listing messages in queue")
|
||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueueKick initiates delivery of a message from the queue.
|
// QueueKick initiates delivery of a message from the queue.
|
||||||
func (Admin) QueueKick(ctx context.Context, id int64) {
|
func (Admin) QueueKick(ctx context.Context, id int64) {
|
||||||
n, err := queue.Kick(id, "", "")
|
n, err := queue.Kick(ctx, id, "", "")
|
||||||
if err == nil && n == 0 {
|
if err == nil && n == 0 {
|
||||||
err = errors.New("message not found")
|
err = errors.New("message not found")
|
||||||
}
|
}
|
||||||
|
@ -1523,7 +1523,7 @@ func (Admin) QueueKick(ctx context.Context, id int64) {
|
||||||
|
|
||||||
// QueueDrop removes a message from the queue.
|
// QueueDrop removes a message from the queue.
|
||||||
func (Admin) QueueDrop(ctx context.Context, id int64) {
|
func (Admin) QueueDrop(ctx context.Context, id int64) {
|
||||||
n, err := queue.Drop(id, "", "")
|
n, err := queue.Drop(ctx, id, "", "")
|
||||||
if err == nil && n == 0 {
|
if err == nil && n == 0 {
|
||||||
err = errors.New("message not found")
|
err = errors.New("message not found")
|
||||||
}
|
}
|
||||||
|
|
|
@ -249,7 +249,7 @@ func importStart(log *mlog.Log, accName string, f *os.File, skipMailboxPrefix st
|
||||||
}
|
}
|
||||||
acc.Lock() // Not using WithWLock because importMessage is responsible for unlocking.
|
acc.Lock() // Not using WithWLock because importMessage is responsible for unlocking.
|
||||||
|
|
||||||
tx, err := acc.DB.Begin(true)
|
tx, err := acc.DB.Begin(context.Background(), true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
acc.Unlock()
|
acc.Unlock()
|
||||||
xerr := acc.Close()
|
xerr := acc.Close()
|
||||||
|
@ -346,7 +346,7 @@ func importMessages(ctx context.Context, log *mlog.Log, token string, acc *store
|
||||||
|
|
||||||
conf, _ := acc.Conf()
|
conf, _ := acc.Conf()
|
||||||
|
|
||||||
jf, _, err := acc.OpenJunkFilter(log)
|
jf, _, err := acc.OpenJunkFilter(ctx, log)
|
||||||
if err != nil && !errors.Is(err, store.ErrNoJunkFilter) {
|
if err != nil && !errors.Is(err, store.ErrNoJunkFilter) {
|
||||||
ximportcheckf(err, "open junk filter")
|
ximportcheckf(err, "open junk filter")
|
||||||
}
|
}
|
||||||
|
@ -376,7 +376,7 @@ func importMessages(ctx context.Context, log *mlog.Log, token string, acc *store
|
||||||
problemf("parsing message %s for updating junk filter: %v (continuing)", pos, err)
|
problemf("parsing message %s for updating junk filter: %v (continuing)", pos, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = jf.Train(!m.Junk, words)
|
err = jf.Train(ctx, !m.Junk, words)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
problemf("training junk filter for message %s: %v (continuing)", pos, err)
|
problemf("training junk filter for message %s: %v (continuing)", pos, err)
|
||||||
return
|
return
|
||||||
|
|
|
@ -371,7 +371,7 @@ func (c *conn) utf8strings() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) xdbwrite(fn func(tx *bstore.Tx)) {
|
func (c *conn) xdbwrite(fn func(tx *bstore.Tx)) {
|
||||||
err := c.account.DB.Write(func(tx *bstore.Tx) error {
|
err := c.account.DB.Write(context.TODO(), func(tx *bstore.Tx) error {
|
||||||
fn(tx)
|
fn(tx)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
@ -379,7 +379,7 @@ func (c *conn) xdbwrite(fn func(tx *bstore.Tx)) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) xdbread(fn func(tx *bstore.Tx)) {
|
func (c *conn) xdbread(fn func(tx *bstore.Tx)) {
|
||||||
err := c.account.DB.Read(func(tx *bstore.Tx) error {
|
err := c.account.DB.Read(context.TODO(), func(tx *bstore.Tx) error {
|
||||||
fn(tx)
|
fn(tx)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
@ -1574,7 +1574,7 @@ func (c *conn) cmdAuthenticate(tag, cmd string, p *parser) {
|
||||||
}()
|
}()
|
||||||
var ipadhash, opadhash hash.Hash
|
var ipadhash, opadhash hash.Hash
|
||||||
acc.WithRLock(func() {
|
acc.WithRLock(func() {
|
||||||
err := acc.DB.Read(func(tx *bstore.Tx) error {
|
err := acc.DB.Read(context.TODO(), func(tx *bstore.Tx) error {
|
||||||
password, err := bstore.QueryTx[store.Password](tx).Get()
|
password, err := bstore.QueryTx[store.Password](tx).Get()
|
||||||
if err == bstore.ErrAbsent {
|
if err == bstore.ErrAbsent {
|
||||||
xusercodeErrorf("AUTHENTICATIONFAILED", "bad credentials")
|
xusercodeErrorf("AUTHENTICATIONFAILED", "bad credentials")
|
||||||
|
@ -1644,7 +1644,7 @@ func (c *conn) cmdAuthenticate(tag, cmd string, p *parser) {
|
||||||
}
|
}
|
||||||
var xscram store.SCRAM
|
var xscram store.SCRAM
|
||||||
acc.WithRLock(func() {
|
acc.WithRLock(func() {
|
||||||
err := acc.DB.Read(func(tx *bstore.Tx) error {
|
err := acc.DB.Read(context.TODO(), func(tx *bstore.Tx) error {
|
||||||
password, err := bstore.QueryTx[store.Password](tx).Get()
|
password, err := bstore.QueryTx[store.Password](tx).Get()
|
||||||
if authVariant == "scram-sha-1" {
|
if authVariant == "scram-sha-1" {
|
||||||
xscram = password.SCRAMSHA1
|
xscram = password.SCRAMSHA1
|
||||||
|
@ -1998,7 +1998,7 @@ func (c *conn) cmdDelete(tag, cmd string, p *parser) {
|
||||||
remove[i].Junk = false
|
remove[i].Junk = false
|
||||||
remove[i].Notjunk = false
|
remove[i].Notjunk = false
|
||||||
}
|
}
|
||||||
err = c.account.RetrainMessages(c.log, tx, remove, true)
|
err = c.account.RetrainMessages(context.TODO(), c.log, tx, remove, true)
|
||||||
xcheckf(err, "untraining deleted messages")
|
xcheckf(err, "untraining deleted messages")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2743,7 +2743,7 @@ func (c *conn) xexpunge(uidSet *numSet, missingMailboxOK bool) []store.Message {
|
||||||
remove[i].Junk = false
|
remove[i].Junk = false
|
||||||
remove[i].Notjunk = false
|
remove[i].Notjunk = false
|
||||||
}
|
}
|
||||||
err = c.account.RetrainMessages(c.log, tx, remove, true)
|
err = c.account.RetrainMessages(context.TODO(), c.log, tx, remove, true)
|
||||||
xcheckf(err, "untraining deleted messages")
|
xcheckf(err, "untraining deleted messages")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -3030,7 +3030,7 @@ func (c *conn) cmdxCopy(isUID bool, tag, cmd string, p *parser) {
|
||||||
createdIDs = append(createdIDs, newMsgIDs[i])
|
createdIDs = append(createdIDs, newMsgIDs[i])
|
||||||
}
|
}
|
||||||
|
|
||||||
err = c.account.RetrainMessages(c.log, tx, nmsgs, false)
|
err = c.account.RetrainMessages(context.TODO(), c.log, tx, nmsgs, false)
|
||||||
xcheckf(err, "train copied messages")
|
xcheckf(err, "train copied messages")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -3169,7 +3169,7 @@ func (c *conn) cmdxMove(isUID bool, tag, cmd string, p *parser) {
|
||||||
xcheckf(err, "updating moved message in database")
|
xcheckf(err, "updating moved message in database")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = c.account.RetrainMessages(c.log, tx, msgs, false)
|
err = c.account.RetrainMessages(context.TODO(), c.log, tx, msgs, false)
|
||||||
xcheckf(err, "retraining messages after move")
|
xcheckf(err, "retraining messages after move")
|
||||||
|
|
||||||
// Prepare broadcast changes to other connections.
|
// Prepare broadcast changes to other connections.
|
||||||
|
@ -3269,7 +3269,7 @@ func (c *conn) cmdxStore(isUID bool, tag, cmd string, p *parser) {
|
||||||
xcheckf(err, "updating flags")
|
xcheckf(err, "updating flags")
|
||||||
}
|
}
|
||||||
|
|
||||||
err := c.account.RetrainMessages(c.log, tx, updated, false)
|
err := c.account.RetrainMessages(context.TODO(), c.log, tx, updated, false)
|
||||||
xcheckf(err, "training messages")
|
xcheckf(err, "training messages")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -107,7 +108,7 @@ func xcmdImport(mbox bool, args []string, c *cmd) {
|
||||||
fmt.Fprintf(os.Stderr, "%s imported\n", count)
|
fmt.Fprintf(os.Stderr, "%s imported\n", count)
|
||||||
}
|
}
|
||||||
|
|
||||||
func importctl(ctl *ctl, mbox bool) {
|
func importctl(ctx context.Context, ctl *ctl, mbox bool) {
|
||||||
/* protocol:
|
/* protocol:
|
||||||
> "importmaildir" or "importmbox"
|
> "importmaildir" or "importmbox"
|
||||||
> account
|
> account
|
||||||
|
@ -177,7 +178,7 @@ func importctl(ctl *ctl, mbox bool) {
|
||||||
msgreader = store.NewMaildirReader(store.CreateMessageTemp, mdnewf, mdcurf, ctl.log)
|
msgreader = store.NewMaildirReader(store.CreateMessageTemp, mdnewf, mdcurf, ctl.log)
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := a.DB.Begin(true)
|
tx, err := a.DB.Begin(ctx, true)
|
||||||
ctl.xcheck(err, "begin transaction")
|
ctl.xcheck(err, "begin transaction")
|
||||||
defer func() {
|
defer func() {
|
||||||
if tx != nil {
|
if tx != nil {
|
||||||
|
@ -239,7 +240,7 @@ func importctl(ctl *ctl, mbox bool) {
|
||||||
mb, changes, err = a.MailboxEnsure(tx, mailbox, true)
|
mb, changes, err = a.MailboxEnsure(tx, mailbox, true)
|
||||||
ctl.xcheck(err, "ensuring mailbox exists")
|
ctl.xcheck(err, "ensuring mailbox exists")
|
||||||
|
|
||||||
jf, _, err := a.OpenJunkFilter(ctl.log)
|
jf, _, err := a.OpenJunkFilter(ctx, ctl.log)
|
||||||
if err != nil && !errors.Is(err, store.ErrNoJunkFilter) {
|
if err != nil && !errors.Is(err, store.ErrNoJunkFilter) {
|
||||||
ctl.xcheck(err, "open junk filter")
|
ctl.xcheck(err, "open junk filter")
|
||||||
}
|
}
|
||||||
|
@ -287,7 +288,7 @@ func importctl(ctl *ctl, mbox bool) {
|
||||||
if words, err := jf.ParseMessage(p); err != nil {
|
if words, err := jf.ParseMessage(p); err != nil {
|
||||||
ctl.log.Infox("parsing message for updating junk filter", err, mlog.Field("parse", ""), mlog.Field("path", origPath))
|
ctl.log.Infox("parsing message for updating junk filter", err, mlog.Field("parse", ""), mlog.Field("path", origPath))
|
||||||
} else {
|
} else {
|
||||||
err = jf.Train(!m.Junk, words)
|
err = jf.Train(ctx, !m.Junk, words)
|
||||||
ctl.xcheck(err, "training junk filter")
|
ctl.xcheck(err, "training junk filter")
|
||||||
m.TrainedJunk = &m.Junk
|
m.TrainedJunk = &m.Junk
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,6 +28,8 @@ import (
|
||||||
"github.com/mjl-/mox/store"
|
"github.com/mjl-/mox/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ctxbg = context.Background()
|
||||||
|
|
||||||
func tcheck(t *testing.T, err error, msg string) {
|
func tcheck(t *testing.T, err error, msg string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -51,7 +53,7 @@ func TestDeliver(t *testing.T) {
|
||||||
// Load mox config.
|
// Load mox config.
|
||||||
mox.ConfigStaticPath = "testdata/integration/config/mox.conf"
|
mox.ConfigStaticPath = "testdata/integration/config/mox.conf"
|
||||||
filepath.Join(filepath.Dir(mox.ConfigStaticPath), "domains.conf")
|
filepath.Join(filepath.Dir(mox.ConfigStaticPath), "domains.conf")
|
||||||
if errs := mox.LoadConfig(context.Background(), false); len(errs) > 0 {
|
if errs := mox.LoadConfig(ctxbg, false); len(errs) > 0 {
|
||||||
t.Fatalf("loading mox config: %v", errs)
|
t.Fatalf("loading mox config: %v", errs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,7 +82,7 @@ func TestDeliver(t *testing.T) {
|
||||||
latestMsgID := func(username string) int64 {
|
latestMsgID := func(username string) int64 {
|
||||||
// We open the account index database created by mox for the test user. And we keep looking for the email we sent.
|
// We open the account index database created by mox for the test user. And we keep looking for the email we sent.
|
||||||
dbpath := fmt.Sprintf("testdata/integration/data/accounts/%s/index.db", username)
|
dbpath := fmt.Sprintf("testdata/integration/data/accounts/%s/index.db", username)
|
||||||
db, err := bstore.Open(dbpath, &bstore.Options{Timeout: 3 * time.Second}, store.Message{}, store.Recipient{}, store.Mailbox{}, store.Password{})
|
db, err := bstore.Open(ctxbg, dbpath, &bstore.Options{Timeout: 3 * time.Second}, store.Message{}, store.Recipient{}, store.Mailbox{}, store.Password{})
|
||||||
if err != nil && errors.Is(err, bolt.ErrTimeout) {
|
if err != nil && errors.Is(err, bolt.ErrTimeout) {
|
||||||
log.Printf("db open timeout (normal delay for new sender with account and db file kept open)")
|
log.Printf("db open timeout (normal delay for new sender with account and db file kept open)")
|
||||||
return 0
|
return 0
|
||||||
|
@ -88,7 +90,7 @@ func TestDeliver(t *testing.T) {
|
||||||
tcheck(t, err, "open test account database")
|
tcheck(t, err, "open test account database")
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
q := bstore.QueryDB[store.Mailbox](db)
|
q := bstore.QueryDB[store.Mailbox](ctxbg, db)
|
||||||
q.FilterNonzero(store.Mailbox{Name: "Inbox"})
|
q.FilterNonzero(store.Mailbox{Name: "Inbox"})
|
||||||
inbox, err := q.Get()
|
inbox, err := q.Get()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -96,7 +98,7 @@ func TestDeliver(t *testing.T) {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
qm := bstore.QueryDB[store.Message](db)
|
qm := bstore.QueryDB[store.Message](ctxbg, db)
|
||||||
qm.FilterNonzero(store.Message{MailboxID: inbox.ID})
|
qm.FilterNonzero(store.Message{MailboxID: inbox.ID})
|
||||||
qm.SortDesc("ID")
|
qm.SortDesc("ID")
|
||||||
qm.Limit(1)
|
qm.Limit(1)
|
||||||
|
|
21
junk.go
21
junk.go
|
@ -14,6 +14,7 @@ own ham/spam emails.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
@ -134,7 +135,7 @@ func cmdJunkTrain(c *cmd) {
|
||||||
defer a.Profile()()
|
defer a.Profile()()
|
||||||
a.SetLogLevel()
|
a.SetLogLevel()
|
||||||
|
|
||||||
f := must(junk.NewFilter(mlog.New("junktrain"), a.params, a.databasePath, a.bloomfilterPath))
|
f := must(junk.NewFilter(context.Background(), mlog.New("junktrain"), a.params, a.databasePath, a.bloomfilterPath))
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := f.Close(); err != nil {
|
if err := f.Close(); err != nil {
|
||||||
log.Printf("closing junk filter: %v", err)
|
log.Printf("closing junk filter: %v", err)
|
||||||
|
@ -164,14 +165,14 @@ func cmdJunkCheck(c *cmd) {
|
||||||
defer a.Profile()()
|
defer a.Profile()()
|
||||||
a.SetLogLevel()
|
a.SetLogLevel()
|
||||||
|
|
||||||
f := must(junk.OpenFilter(mlog.New("junkcheck"), a.params, a.databasePath, a.bloomfilterPath, false))
|
f := must(junk.OpenFilter(context.Background(), mlog.New("junkcheck"), a.params, a.databasePath, a.bloomfilterPath, false))
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := f.Close(); err != nil {
|
if err := f.Close(); err != nil {
|
||||||
log.Printf("closing junk filter: %v", err)
|
log.Printf("closing junk filter: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
prob, _, _, _, err := f.ClassifyMessagePath(args[0])
|
prob, _, _, _, err := f.ClassifyMessagePath(context.Background(), args[0])
|
||||||
xcheckf(err, "testing mail")
|
xcheckf(err, "testing mail")
|
||||||
|
|
||||||
fmt.Printf("%.6f\n", prob)
|
fmt.Printf("%.6f\n", prob)
|
||||||
|
@ -189,7 +190,7 @@ func cmdJunkTest(c *cmd) {
|
||||||
defer a.Profile()()
|
defer a.Profile()()
|
||||||
a.SetLogLevel()
|
a.SetLogLevel()
|
||||||
|
|
||||||
f := must(junk.OpenFilter(mlog.New("junktest"), a.params, a.databasePath, a.bloomfilterPath, false))
|
f := must(junk.OpenFilter(context.Background(), mlog.New("junktest"), a.params, a.databasePath, a.bloomfilterPath, false))
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := f.Close(); err != nil {
|
if err := f.Close(); err != nil {
|
||||||
log.Printf("closing junk filter: %v", err)
|
log.Printf("closing junk filter: %v", err)
|
||||||
|
@ -202,7 +203,7 @@ func cmdJunkTest(c *cmd) {
|
||||||
xcheckf(err, "readdir %q", dir)
|
xcheckf(err, "readdir %q", dir)
|
||||||
for _, fi := range files {
|
for _, fi := range files {
|
||||||
path := dir + "/" + fi.Name()
|
path := dir + "/" + fi.Name()
|
||||||
prob, _, _, _, err := f.ClassifyMessagePath(path)
|
prob, _, _, _, err := f.ClassifyMessagePath(context.Background(), path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("classify message %q: %s", path, err)
|
log.Printf("classify message %q: %s", path, err)
|
||||||
continue
|
continue
|
||||||
|
@ -246,7 +247,7 @@ messages are shuffled, with optional random seed.`
|
||||||
defer a.Profile()()
|
defer a.Profile()()
|
||||||
a.SetLogLevel()
|
a.SetLogLevel()
|
||||||
|
|
||||||
f := must(junk.NewFilter(mlog.New("junkanalyze"), a.params, a.databasePath, a.bloomfilterPath))
|
f := must(junk.NewFilter(context.Background(), mlog.New("junkanalyze"), a.params, a.databasePath, a.bloomfilterPath))
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := f.Close(); err != nil {
|
if err := f.Close(); err != nil {
|
||||||
log.Printf("closing junk filter: %v", err)
|
log.Printf("closing junk filter: %v", err)
|
||||||
|
@ -295,7 +296,7 @@ messages are shuffled, with optional random seed.`
|
||||||
testDir := func(dir string, files []string, ham bool) (ok, bad, malformed int) {
|
testDir := func(dir string, files []string, ham bool) (ok, bad, malformed int) {
|
||||||
for _, name := range files {
|
for _, name := range files {
|
||||||
path := dir + "/" + name
|
path := dir + "/" + name
|
||||||
prob, _, _, _, err := f.ClassifyMessagePath(path)
|
prob, _, _, _, err := f.ClassifyMessagePath(context.Background(), path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// log.Infof("%s: %s", path, err)
|
// log.Infof("%s: %s", path, err)
|
||||||
malformed++
|
malformed++
|
||||||
|
@ -338,7 +339,7 @@ func cmdJunkPlay(c *cmd) {
|
||||||
defer a.Profile()()
|
defer a.Profile()()
|
||||||
a.SetLogLevel()
|
a.SetLogLevel()
|
||||||
|
|
||||||
f := must(junk.NewFilter(mlog.New("junkplay"), a.params, a.databasePath, a.bloomfilterPath))
|
f := must(junk.NewFilter(context.Background(), mlog.New("junkplay"), a.params, a.databasePath, a.bloomfilterPath))
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := f.Close(); err != nil {
|
if err := f.Close(); err != nil {
|
||||||
log.Printf("closing junk filter: %v", err)
|
log.Printf("closing junk filter: %v", err)
|
||||||
|
@ -414,7 +415,7 @@ func cmdJunkPlay(c *cmd) {
|
||||||
if !msg.sent {
|
if !msg.sent {
|
||||||
var prob float64
|
var prob float64
|
||||||
var err error
|
var err error
|
||||||
prob, words, _, _, err = f.ClassifyMessagePath(path)
|
prob, words, _, _, err = f.ClassifyMessagePath(context.Background(), path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
nbad++
|
nbad++
|
||||||
return
|
return
|
||||||
|
@ -455,7 +456,7 @@ func cmdJunkPlay(c *cmd) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := f.Train(msg.ham, words); err != nil {
|
if err := f.Train(context.Background(), msg.ham, words); err != nil {
|
||||||
log.Printf("train: %s", err)
|
log.Printf("train: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,6 +11,7 @@ package junk
|
||||||
// todo: perhaps: whether anchor text in links in html are different from the url
|
// todo: perhaps: whether anchor text in links in html are different from the url
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -108,7 +109,7 @@ func (f *Filter) Close() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func OpenFilter(log *mlog.Log, params Params, dbPath, bloomPath string, loadBloom bool) (*Filter, error) {
|
func OpenFilter(ctx context.Context, log *mlog.Log, params Params, dbPath, bloomPath string, loadBloom bool) (*Filter, error) {
|
||||||
var bloom *Bloom
|
var bloom *Bloom
|
||||||
if loadBloom {
|
if loadBloom {
|
||||||
var err error
|
var err error
|
||||||
|
@ -122,7 +123,7 @@ func OpenFilter(log *mlog.Log, params Params, dbPath, bloomPath string, loadBloo
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := openDB(dbPath)
|
db, err := openDB(ctx, dbPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("open database: %s", err)
|
return nil, fmt.Errorf("open database: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -137,7 +138,7 @@ func OpenFilter(log *mlog.Log, params Params, dbPath, bloomPath string, loadBloo
|
||||||
db: db,
|
db: db,
|
||||||
bloom: bloom,
|
bloom: bloom,
|
||||||
}
|
}
|
||||||
err = f.db.Read(func(tx *bstore.Tx) error {
|
err = f.db.Read(ctx, func(tx *bstore.Tx) error {
|
||||||
wc := wordscore{Word: "-"}
|
wc := wordscore{Word: "-"}
|
||||||
err := tx.Get(&wc)
|
err := tx.Get(&wc)
|
||||||
f.hams = wc.Ham
|
f.hams = wc.Ham
|
||||||
|
@ -156,7 +157,7 @@ func OpenFilter(log *mlog.Log, params Params, dbPath, bloomPath string, loadBloo
|
||||||
// filter is marked as new until the first save, will be done automatically if
|
// filter is marked as new until the first save, will be done automatically if
|
||||||
// TrainDirs is called. If the bloom and/or database files exist, an error is
|
// TrainDirs is called. If the bloom and/or database files exist, an error is
|
||||||
// returned.
|
// returned.
|
||||||
func NewFilter(log *mlog.Log, params Params, dbPath, bloomPath string) (*Filter, error) {
|
func NewFilter(ctx context.Context, log *mlog.Log, params Params, dbPath, bloomPath string) (*Filter, error) {
|
||||||
var err error
|
var err error
|
||||||
if _, err := os.Stat(bloomPath); err == nil {
|
if _, err := os.Stat(bloomPath); err == nil {
|
||||||
return nil, fmt.Errorf("bloom filter already exists on disk: %s", bloomPath)
|
return nil, fmt.Errorf("bloom filter already exists on disk: %s", bloomPath)
|
||||||
|
@ -182,7 +183,7 @@ func NewFilter(log *mlog.Log, params Params, dbPath, bloomPath string) (*Filter,
|
||||||
err = bf.Close()
|
err = bf.Close()
|
||||||
log.Check(err, "closing bloomfilter file")
|
log.Check(err, "closing bloomfilter file")
|
||||||
|
|
||||||
db, err := newDB(log, dbPath)
|
db, err := newDB(ctx, log, dbPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xerr := os.Remove(bloomPath)
|
xerr := os.Remove(bloomPath)
|
||||||
log.Check(xerr, "removing bloom filter file after db init error")
|
log.Check(xerr, "removing bloom filter file after db init error")
|
||||||
|
@ -216,7 +217,7 @@ func openBloom(path string) (*Bloom, error) {
|
||||||
return NewBloom(buf, bloomK)
|
return NewBloom(buf, bloomK)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDB(log *mlog.Log, path string) (db *bstore.DB, rerr error) {
|
func newDB(ctx context.Context, log *mlog.Log, path string) (db *bstore.DB, rerr error) {
|
||||||
// Remove any existing files.
|
// Remove any existing files.
|
||||||
os.Remove(path)
|
os.Remove(path)
|
||||||
|
|
||||||
|
@ -227,18 +228,18 @@ func newDB(log *mlog.Log, path string) (db *bstore.DB, rerr error) {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
db, err := bstore.Open(path, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, wordscore{})
|
db, err := bstore.Open(ctx, path, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, wordscore{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("open new database: %w", err)
|
return nil, fmt.Errorf("open new database: %w", err)
|
||||||
}
|
}
|
||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func openDB(path string) (*bstore.DB, error) {
|
func openDB(ctx context.Context, path string) (*bstore.DB, error) {
|
||||||
if _, err := os.Stat(path); err != nil {
|
if _, err := os.Stat(path); err != nil {
|
||||||
return nil, fmt.Errorf("stat db file: %w", err)
|
return nil, fmt.Errorf("stat db file: %w", err)
|
||||||
}
|
}
|
||||||
return bstore.Open(path, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, wordscore{})
|
return bstore.Open(ctx, path, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, wordscore{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save stores modifications, e.g. from training, to the database and bloom
|
// Save stores modifications, e.g. from training, to the database and bloom
|
||||||
|
@ -280,7 +281,7 @@ func (f *Filter) Save() error {
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err := f.db.Write(func(tx *bstore.Tx) error {
|
err := f.db.Write(context.Background(), func(tx *bstore.Tx) error {
|
||||||
update := func(w string, ham, spam uint32) error {
|
update := func(w string, ham, spam uint32) error {
|
||||||
if f.isNew {
|
if f.isNew {
|
||||||
return tx.Insert(&wordscore{w, ham, spam})
|
return tx.Insert(&wordscore{w, ham, spam})
|
||||||
|
@ -318,12 +319,12 @@ func (f *Filter) Save() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadWords(db *bstore.DB, l []string, dst map[string]word) error {
|
func loadWords(ctx context.Context, db *bstore.DB, l []string, dst map[string]word) error {
|
||||||
sort.Slice(l, func(i, j int) bool {
|
sort.Slice(l, func(i, j int) bool {
|
||||||
return l[i] < l[j]
|
return l[i] < l[j]
|
||||||
})
|
})
|
||||||
|
|
||||||
err := db.Read(func(tx *bstore.Tx) error {
|
err := db.Read(ctx, func(tx *bstore.Tx) error {
|
||||||
for _, w := range l {
|
for _, w := range l {
|
||||||
wc := wordscore{Word: w}
|
wc := wordscore{Word: w}
|
||||||
if err := tx.Get(&wc); err == nil {
|
if err := tx.Get(&wc); err == nil {
|
||||||
|
@ -339,7 +340,7 @@ func loadWords(db *bstore.DB, l []string, dst map[string]word) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClassifyWords returns the spam probability for the given words, and number of recognized ham and spam words.
|
// ClassifyWords returns the spam probability for the given words, and number of recognized ham and spam words.
|
||||||
func (f *Filter) ClassifyWords(words map[string]struct{}) (probability float64, nham, nspam int, rerr error) {
|
func (f *Filter) ClassifyWords(ctx context.Context, words map[string]struct{}) (probability float64, nham, nspam int, rerr error) {
|
||||||
if f.closed {
|
if f.closed {
|
||||||
return 0, 0, 0, errClosed
|
return 0, 0, 0, errClosed
|
||||||
}
|
}
|
||||||
|
@ -380,7 +381,7 @@ func (f *Filter) ClassifyWords(words map[string]struct{}) (probability float64,
|
||||||
// Fetch words from database.
|
// Fetch words from database.
|
||||||
fetched := map[string]word{}
|
fetched := map[string]word{}
|
||||||
if len(lookupWords) > 0 {
|
if len(lookupWords) > 0 {
|
||||||
if err := loadWords(f.db, lookupWords, fetched); err != nil {
|
if err := loadWords(ctx, f.db, lookupWords, fetched); err != nil {
|
||||||
return 0, 0, 0, err
|
return 0, 0, 0, err
|
||||||
}
|
}
|
||||||
for w, c := range fetched {
|
for w, c := range fetched {
|
||||||
|
@ -477,7 +478,7 @@ func (f *Filter) ClassifyWords(words map[string]struct{}) (probability float64,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClassifyMessagePath is a convenience wrapper for calling ClassifyMessage on a file.
|
// ClassifyMessagePath is a convenience wrapper for calling ClassifyMessage on a file.
|
||||||
func (f *Filter) ClassifyMessagePath(path string) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
|
func (f *Filter) ClassifyMessagePath(ctx context.Context, path string) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
|
||||||
if f.closed {
|
if f.closed {
|
||||||
return 0, nil, 0, 0, errClosed
|
return 0, nil, 0, 0, errClosed
|
||||||
}
|
}
|
||||||
|
@ -494,35 +495,35 @@ func (f *Filter) ClassifyMessagePath(path string) (probability float64, words ma
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, nil, 0, 0, err
|
return 0, nil, 0, 0, err
|
||||||
}
|
}
|
||||||
return f.ClassifyMessageReader(mf, fi.Size())
|
return f.ClassifyMessageReader(ctx, mf, fi.Size())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Filter) ClassifyMessageReader(mf io.ReaderAt, size int64) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
|
func (f *Filter) ClassifyMessageReader(ctx context.Context, mf io.ReaderAt, size int64) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
|
||||||
m, err := message.EnsurePart(mf, size)
|
m, err := message.EnsurePart(mf, size)
|
||||||
if err != nil && errors.Is(err, message.ErrBadContentType) {
|
if err != nil && errors.Is(err, message.ErrBadContentType) {
|
||||||
// Invalid content-type header is a sure sign of spam.
|
// Invalid content-type header is a sure sign of spam.
|
||||||
//f.log.Infox("parsing content", err)
|
//f.log.Infox("parsing content", err)
|
||||||
return 1, nil, 0, 0, nil
|
return 1, nil, 0, 0, nil
|
||||||
}
|
}
|
||||||
return f.ClassifyMessage(m)
|
return f.ClassifyMessage(ctx, m)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClassifyMessage parses the mail message in r and returns the spam probability
|
// ClassifyMessage parses the mail message in r and returns the spam probability
|
||||||
// (between 0 and 1), along with the tokenized words found in the message, and the
|
// (between 0 and 1), along with the tokenized words found in the message, and the
|
||||||
// number of recognized ham and spam words.
|
// number of recognized ham and spam words.
|
||||||
func (f *Filter) ClassifyMessage(m message.Part) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
|
func (f *Filter) ClassifyMessage(ctx context.Context, m message.Part) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
|
||||||
var err error
|
var err error
|
||||||
words, err = f.ParseMessage(m)
|
words, err = f.ParseMessage(m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, nil, 0, 0, err
|
return 0, nil, 0, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
probability, nham, nspam, err = f.ClassifyWords(words)
|
probability, nham, nspam, err = f.ClassifyWords(ctx, words)
|
||||||
return probability, words, nham, nspam, err
|
return probability, words, nham, nspam, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Train adds the words of a single message to the filter.
|
// Train adds the words of a single message to the filter.
|
||||||
func (f *Filter) Train(ham bool, words map[string]struct{}) error {
|
func (f *Filter) Train(ctx context.Context, ham bool, words map[string]struct{}) error {
|
||||||
if err := f.ensureBloom(); err != nil {
|
if err := f.ensureBloom(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -539,7 +540,7 @@ func (f *Filter) Train(ham bool, words map[string]struct{}) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := f.loadCache(lwords); err != nil {
|
if err := f.loadCache(ctx, lwords); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -563,34 +564,34 @@ func (f *Filter) Train(ham bool, words map[string]struct{}) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Filter) TrainMessage(r io.ReaderAt, size int64, ham bool) error {
|
func (f *Filter) TrainMessage(ctx context.Context, r io.ReaderAt, size int64, ham bool) error {
|
||||||
p, _ := message.EnsurePart(r, size)
|
p, _ := message.EnsurePart(r, size)
|
||||||
words, err := f.ParseMessage(p)
|
words, err := f.ParseMessage(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parsing mail contents: %v", err)
|
return fmt.Errorf("parsing mail contents: %v", err)
|
||||||
}
|
}
|
||||||
return f.Train(ham, words)
|
return f.Train(ctx, ham, words)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Filter) UntrainMessage(r io.ReaderAt, size int64, ham bool) error {
|
func (f *Filter) UntrainMessage(ctx context.Context, r io.ReaderAt, size int64, ham bool) error {
|
||||||
p, _ := message.EnsurePart(r, size)
|
p, _ := message.EnsurePart(r, size)
|
||||||
words, err := f.ParseMessage(p)
|
words, err := f.ParseMessage(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parsing mail contents: %v", err)
|
return fmt.Errorf("parsing mail contents: %v", err)
|
||||||
}
|
}
|
||||||
return f.Untrain(ham, words)
|
return f.Untrain(ctx, ham, words)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Filter) loadCache(lwords []string) error {
|
func (f *Filter) loadCache(ctx context.Context, lwords []string) error {
|
||||||
if len(lwords) == 0 {
|
if len(lwords) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return loadWords(f.db, lwords, f.cache)
|
return loadWords(ctx, f.db, lwords, f.cache)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Untrain adjusts the filter to undo a previous training of the words.
|
// Untrain adjusts the filter to undo a previous training of the words.
|
||||||
func (f *Filter) Untrain(ham bool, words map[string]struct{}) error {
|
func (f *Filter) Untrain(ctx context.Context, ham bool, words map[string]struct{}) error {
|
||||||
if err := f.ensureBloom(); err != nil {
|
if err := f.ensureBloom(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -602,7 +603,7 @@ func (f *Filter) Untrain(ham bool, words map[string]struct{}) error {
|
||||||
lwords = append(lwords, w)
|
lwords = append(lwords, w)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := f.loadCache(lwords); err != nil {
|
if err := f.loadCache(ctx, lwords); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package junk
|
package junk
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"os"
|
"os"
|
||||||
|
@ -10,6 +11,8 @@ import (
|
||||||
"github.com/mjl-/mox/mlog"
|
"github.com/mjl-/mox/mlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ctxbg = context.Background()
|
||||||
|
|
||||||
func tcheck(t *testing.T, err error, msg string) {
|
func tcheck(t *testing.T, err error, msg string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -43,12 +46,12 @@ func TestFilter(t *testing.T) {
|
||||||
bloomPath := "../testdata/junk/filter.bloom"
|
bloomPath := "../testdata/junk/filter.bloom"
|
||||||
os.Remove(dbPath)
|
os.Remove(dbPath)
|
||||||
os.Remove(bloomPath)
|
os.Remove(bloomPath)
|
||||||
f, err := NewFilter(log, params, dbPath, bloomPath)
|
f, err := NewFilter(ctxbg, log, params, dbPath, bloomPath)
|
||||||
tcheck(t, err, "new filter")
|
tcheck(t, err, "new filter")
|
||||||
err = f.Close()
|
err = f.Close()
|
||||||
tcheck(t, err, "close filter")
|
tcheck(t, err, "close filter")
|
||||||
|
|
||||||
f, err = OpenFilter(log, params, dbPath, bloomPath, true)
|
f, err = OpenFilter(ctxbg, log, params, dbPath, bloomPath, true)
|
||||||
tcheck(t, err, "open filter")
|
tcheck(t, err, "open filter")
|
||||||
|
|
||||||
// Ensure these dirs exist. Developers should bring their own ham/spam example
|
// Ensure these dirs exist. Developers should bring their own ham/spam example
|
||||||
|
@ -75,13 +78,13 @@ func TestFilter(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
prob, _, _, _, err := f.ClassifyMessagePath(filepath.Join(hamdir, hamfiles[0]))
|
prob, _, _, _, err := f.ClassifyMessagePath(ctxbg, filepath.Join(hamdir, hamfiles[0]))
|
||||||
tcheck(t, err, "classify ham message")
|
tcheck(t, err, "classify ham message")
|
||||||
if prob > 0.1 {
|
if prob > 0.1 {
|
||||||
t.Fatalf("trained ham file has prob %v, expected <= 0.1", prob)
|
t.Fatalf("trained ham file has prob %v, expected <= 0.1", prob)
|
||||||
}
|
}
|
||||||
|
|
||||||
prob, _, _, _, err = f.ClassifyMessagePath(filepath.Join(spamdir, spamfiles[0]))
|
prob, _, _, _, err = f.ClassifyMessagePath(ctxbg, filepath.Join(spamdir, spamfiles[0]))
|
||||||
tcheck(t, err, "classify spam message")
|
tcheck(t, err, "classify spam message")
|
||||||
if prob < 0.9 {
|
if prob < 0.9 {
|
||||||
t.Fatalf("trained spam file has prob %v, expected > 0.9", prob)
|
t.Fatalf("trained spam file has prob %v, expected > 0.9", prob)
|
||||||
|
@ -94,7 +97,7 @@ func TestFilter(t *testing.T) {
|
||||||
// classified as ham/spam. Then we untrain to see they are no longer classified.
|
// classified as ham/spam. Then we untrain to see they are no longer classified.
|
||||||
os.Remove(dbPath)
|
os.Remove(dbPath)
|
||||||
os.Remove(bloomPath)
|
os.Remove(bloomPath)
|
||||||
f, err = NewFilter(log, params, dbPath, bloomPath)
|
f, err = NewFilter(ctxbg, log, params, dbPath, bloomPath)
|
||||||
tcheck(t, err, "open filter")
|
tcheck(t, err, "open filter")
|
||||||
|
|
||||||
hamf, err := os.Open(filepath.Join(hamdir, hamfiles[0]))
|
hamf, err := os.Open(filepath.Join(hamdir, hamfiles[0]))
|
||||||
|
@ -112,18 +115,18 @@ func TestFilter(t *testing.T) {
|
||||||
spamsize := spamstat.Size()
|
spamsize := spamstat.Size()
|
||||||
|
|
||||||
// Train each message twice, to prevent single occurrences from being ignored.
|
// Train each message twice, to prevent single occurrences from being ignored.
|
||||||
err = f.TrainMessage(hamf, hamsize, true)
|
err = f.TrainMessage(ctxbg, hamf, hamsize, true)
|
||||||
tcheck(t, err, "train ham message")
|
tcheck(t, err, "train ham message")
|
||||||
_, err = hamf.Seek(0, 0)
|
_, err = hamf.Seek(0, 0)
|
||||||
tcheck(t, err, "seek ham message")
|
tcheck(t, err, "seek ham message")
|
||||||
err = f.TrainMessage(hamf, hamsize, true)
|
err = f.TrainMessage(ctxbg, hamf, hamsize, true)
|
||||||
tcheck(t, err, "train ham message")
|
tcheck(t, err, "train ham message")
|
||||||
|
|
||||||
err = f.TrainMessage(spamf, spamsize, false)
|
err = f.TrainMessage(ctxbg, spamf, spamsize, false)
|
||||||
tcheck(t, err, "train spam message")
|
tcheck(t, err, "train spam message")
|
||||||
_, err = spamf.Seek(0, 0)
|
_, err = spamf.Seek(0, 0)
|
||||||
tcheck(t, err, "seek spam message")
|
tcheck(t, err, "seek spam message")
|
||||||
err = f.TrainMessage(spamf, spamsize, true)
|
err = f.TrainMessage(ctxbg, spamf, spamsize, true)
|
||||||
tcheck(t, err, "train spam message")
|
tcheck(t, err, "train spam message")
|
||||||
|
|
||||||
if !f.modified {
|
if !f.modified {
|
||||||
|
@ -142,7 +145,7 @@ func TestFilter(t *testing.T) {
|
||||||
// Classify and verify.
|
// Classify and verify.
|
||||||
_, err = hamf.Seek(0, 0)
|
_, err = hamf.Seek(0, 0)
|
||||||
tcheck(t, err, "seek ham message")
|
tcheck(t, err, "seek ham message")
|
||||||
prob, _, _, _, err = f.ClassifyMessageReader(hamf, hamsize)
|
prob, _, _, _, err = f.ClassifyMessageReader(ctxbg, hamf, hamsize)
|
||||||
tcheck(t, err, "classify ham")
|
tcheck(t, err, "classify ham")
|
||||||
if prob > 0.1 {
|
if prob > 0.1 {
|
||||||
t.Fatalf("got prob %v, expected <= 0.1", prob)
|
t.Fatalf("got prob %v, expected <= 0.1", prob)
|
||||||
|
@ -150,7 +153,7 @@ func TestFilter(t *testing.T) {
|
||||||
|
|
||||||
_, err = spamf.Seek(0, 0)
|
_, err = spamf.Seek(0, 0)
|
||||||
tcheck(t, err, "seek spam message")
|
tcheck(t, err, "seek spam message")
|
||||||
prob, _, _, _, err = f.ClassifyMessageReader(spamf, spamsize)
|
prob, _, _, _, err = f.ClassifyMessageReader(ctxbg, spamf, spamsize)
|
||||||
tcheck(t, err, "classify spam")
|
tcheck(t, err, "classify spam")
|
||||||
if prob < 0.9 {
|
if prob < 0.9 {
|
||||||
t.Fatalf("got prob %v, expected >= 0.9", prob)
|
t.Fatalf("got prob %v, expected >= 0.9", prob)
|
||||||
|
@ -159,20 +162,20 @@ func TestFilter(t *testing.T) {
|
||||||
// Untrain ham & spam.
|
// Untrain ham & spam.
|
||||||
_, err = hamf.Seek(0, 0)
|
_, err = hamf.Seek(0, 0)
|
||||||
tcheck(t, err, "seek ham message")
|
tcheck(t, err, "seek ham message")
|
||||||
err = f.UntrainMessage(hamf, hamsize, true)
|
err = f.UntrainMessage(ctxbg, hamf, hamsize, true)
|
||||||
tcheck(t, err, "untrain ham message")
|
tcheck(t, err, "untrain ham message")
|
||||||
_, err = hamf.Seek(0, 0)
|
_, err = hamf.Seek(0, 0)
|
||||||
tcheck(t, err, "seek ham message")
|
tcheck(t, err, "seek ham message")
|
||||||
err = f.UntrainMessage(hamf, spamsize, true)
|
err = f.UntrainMessage(ctxbg, hamf, spamsize, true)
|
||||||
tcheck(t, err, "untrain ham message")
|
tcheck(t, err, "untrain ham message")
|
||||||
|
|
||||||
_, err = spamf.Seek(0, 0)
|
_, err = spamf.Seek(0, 0)
|
||||||
tcheck(t, err, "seek spam message")
|
tcheck(t, err, "seek spam message")
|
||||||
err = f.UntrainMessage(spamf, spamsize, true)
|
err = f.UntrainMessage(ctxbg, spamf, spamsize, true)
|
||||||
tcheck(t, err, "untrain spam message")
|
tcheck(t, err, "untrain spam message")
|
||||||
_, err = spamf.Seek(0, 0)
|
_, err = spamf.Seek(0, 0)
|
||||||
tcheck(t, err, "seek spam message")
|
tcheck(t, err, "seek spam message")
|
||||||
err = f.UntrainMessage(spamf, spamsize, true)
|
err = f.UntrainMessage(ctxbg, spamf, spamsize, true)
|
||||||
tcheck(t, err, "untrain spam message")
|
tcheck(t, err, "untrain spam message")
|
||||||
|
|
||||||
if !f.modified {
|
if !f.modified {
|
||||||
|
@ -182,7 +185,7 @@ func TestFilter(t *testing.T) {
|
||||||
// Classify again, should be unknown.
|
// Classify again, should be unknown.
|
||||||
_, err = hamf.Seek(0, 0)
|
_, err = hamf.Seek(0, 0)
|
||||||
tcheck(t, err, "seek ham message")
|
tcheck(t, err, "seek ham message")
|
||||||
prob, _, _, _, err = f.ClassifyMessageReader(hamf, hamsize)
|
prob, _, _, _, err = f.ClassifyMessageReader(ctxbg, hamf, hamsize)
|
||||||
tcheck(t, err, "classify ham")
|
tcheck(t, err, "classify ham")
|
||||||
if math.Abs(prob-0.5) > 0.1 {
|
if math.Abs(prob-0.5) > 0.1 {
|
||||||
t.Fatalf("got prob %v, expected 0.5 +-0.1", prob)
|
t.Fatalf("got prob %v, expected 0.5 +-0.1", prob)
|
||||||
|
@ -190,7 +193,7 @@ func TestFilter(t *testing.T) {
|
||||||
|
|
||||||
_, err = spamf.Seek(0, 0)
|
_, err = spamf.Seek(0, 0)
|
||||||
tcheck(t, err, "seek spam message")
|
tcheck(t, err, "seek spam message")
|
||||||
prob, _, _, _, err = f.ClassifyMessageReader(spamf, spamsize)
|
prob, _, _, _, err = f.ClassifyMessageReader(ctxbg, spamf, spamsize)
|
||||||
tcheck(t, err, "classify spam")
|
tcheck(t, err, "classify spam")
|
||||||
if math.Abs(prob-0.5) > 0.1 {
|
if math.Abs(prob-0.5) > 0.1 {
|
||||||
t.Fatalf("got prob %v, expected 0.5 +-0.1", prob)
|
t.Fatalf("got prob %v, expected 0.5 +-0.1", prob)
|
||||||
|
|
|
@ -23,7 +23,7 @@ func FuzzParseMessage(f *testing.F) {
|
||||||
os.Remove(dbPath)
|
os.Remove(dbPath)
|
||||||
os.Remove(bloomPath)
|
os.Remove(bloomPath)
|
||||||
params := Params{Twograms: true}
|
params := Params{Twograms: true}
|
||||||
jf, err := NewFilter(xlog, params, dbPath, bloomPath)
|
jf, err := NewFilter(ctxbg, xlog, params, dbPath, bloomPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.Fatalf("new filter: %v", err)
|
f.Fatalf("new filter: %v", err)
|
||||||
}
|
}
|
||||||
|
|
4
main.go
4
main.go
|
@ -1879,7 +1879,7 @@ func cmdEnsureParsed(c *cmd) {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
n := 0
|
n := 0
|
||||||
err = a.DB.Write(func(tx *bstore.Tx) error {
|
err = a.DB.Write(context.Background(), func(tx *bstore.Tx) error {
|
||||||
q := bstore.QueryTx[store.Message](tx)
|
q := bstore.QueryTx[store.Message](tx)
|
||||||
q.FilterFn(func(m store.Message) bool {
|
q.FilterFn(func(m store.Message) bool {
|
||||||
return all || m.ParsedBuf == nil
|
return all || m.ParsedBuf == nil
|
||||||
|
@ -1952,7 +1952,7 @@ func cmdBumpUIDValidity(c *cmd) {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var uidvalidity uint32
|
var uidvalidity uint32
|
||||||
err = a.DB.Write(func(tx *bstore.Tx) error {
|
err = a.DB.Write(context.Background(), func(tx *bstore.Tx) error {
|
||||||
mb, err := bstore.QueryTx[store.Mailbox](tx).FilterEqual("Name", args[1]).Get()
|
mb, err := bstore.QueryTx[store.Mailbox](tx).FilterEqual("Name", args[1]).Get()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("looking up mailbox: %v", err)
|
return fmt.Errorf("looking up mailbox: %v", err)
|
||||||
|
|
|
@ -152,14 +152,15 @@ func Listen(network, addr string) (net.Listener, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown is canceled when a graceful shutdown is initiated. SMTP, IMAP, periodic
|
// Shutdown is canceled when a graceful shutdown is initiated. SMTP, IMAP, periodic
|
||||||
// processes should check this before starting a new operation. If true, the
|
// processes should check this before starting a new operation. If this context is
|
||||||
// operation should be aborted, and new connections should receive a message that
|
// canaceled, the operation should not be started, and new connections/commands should
|
||||||
// the service is currently not available.
|
// receive a message that the service is currently not available.
|
||||||
var Shutdown context.Context
|
var Shutdown context.Context
|
||||||
var ShutdownCancel func()
|
var ShutdownCancel func()
|
||||||
|
|
||||||
// Context should be used as parent by all operations. It is canceled when mox is
|
// This context should be used as parent by most operations. It is canceled 1
|
||||||
// shutdown, aborting all pending operations.
|
// second after graceful shutdown was initiated with the cancelation of the
|
||||||
|
// Shutdown context. This should abort active operations.
|
||||||
//
|
//
|
||||||
// Operations typically have context timeouts, 30s for single i/o like DNS queries,
|
// Operations typically have context timeouts, 30s for single i/o like DNS queries,
|
||||||
// and 1 minute for operations with more back and forth. These are set through a
|
// and 1 minute for operations with more back and forth. These are set through a
|
||||||
|
@ -167,6 +168,7 @@ var ShutdownCancel func()
|
||||||
// when shutting down.
|
// when shutting down.
|
||||||
//
|
//
|
||||||
// HTTP servers don't get graceful shutdown, their connections are just aborted.
|
// HTTP servers don't get graceful shutdown, their connections are just aborted.
|
||||||
|
// todo: should shut down http connections as well, and shut down the listener and/or return 503 for new requests.
|
||||||
var Context context.Context
|
var Context context.Context
|
||||||
var ContextCancel func()
|
var ContextCancel func()
|
||||||
|
|
||||||
|
|
|
@ -63,13 +63,13 @@ var (
|
||||||
var mtastsDB *bstore.DB
|
var mtastsDB *bstore.DB
|
||||||
var mutex sync.Mutex
|
var mutex sync.Mutex
|
||||||
|
|
||||||
func database() (rdb *bstore.DB, rerr error) {
|
func database(ctx context.Context) (rdb *bstore.DB, rerr error) {
|
||||||
mutex.Lock()
|
mutex.Lock()
|
||||||
defer mutex.Unlock()
|
defer mutex.Unlock()
|
||||||
if mtastsDB == nil {
|
if mtastsDB == nil {
|
||||||
p := mox.DataDirPath("mtasts.db")
|
p := mox.DataDirPath("mtasts.db")
|
||||||
os.MkdirAll(filepath.Dir(p), 0770)
|
os.MkdirAll(filepath.Dir(p), 0770)
|
||||||
db, err := bstore.Open(p, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, PolicyRecord{})
|
db, err := bstore.Open(ctx, p, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, PolicyRecord{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -81,7 +81,7 @@ func database() (rdb *bstore.DB, rerr error) {
|
||||||
// Init opens the database and starts a goroutine that refreshes policies in
|
// Init opens the database and starts a goroutine that refreshes policies in
|
||||||
// the database, and keeps doing so periodically.
|
// the database, and keeps doing so periodically.
|
||||||
func Init(refresher bool) error {
|
func Init(refresher bool) error {
|
||||||
_, err := database()
|
_, err := database(mox.Shutdown)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -110,7 +110,7 @@ func Close() {
|
||||||
// Only non-expired records are returned.
|
// Only non-expired records are returned.
|
||||||
func lookup(ctx context.Context, domain dns.Domain) (*PolicyRecord, error) {
|
func lookup(ctx context.Context, domain dns.Domain) (*PolicyRecord, error) {
|
||||||
log := xlog.WithContext(ctx)
|
log := xlog.WithContext(ctx)
|
||||||
db, err := database()
|
db, err := database(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -119,7 +119,7 @@ func lookup(ctx context.Context, domain dns.Domain) (*PolicyRecord, error) {
|
||||||
return nil, fmt.Errorf("empty domain")
|
return nil, fmt.Errorf("empty domain")
|
||||||
}
|
}
|
||||||
now := timeNow()
|
now := timeNow()
|
||||||
q := bstore.QueryDB[PolicyRecord](db)
|
q := bstore.QueryDB[PolicyRecord](ctx, db)
|
||||||
q.FilterNonzero(PolicyRecord{Domain: domain.Name()})
|
q.FilterNonzero(PolicyRecord{Domain: domain.Name()})
|
||||||
q.FilterGreater("ValidEnd", now)
|
q.FilterGreater("ValidEnd", now)
|
||||||
pr, err := q.Get()
|
pr, err := q.Get()
|
||||||
|
@ -130,7 +130,7 @@ func lookup(ctx context.Context, domain dns.Domain) (*PolicyRecord, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
pr.LastUse = now
|
pr.LastUse = now
|
||||||
if err := db.Update(&pr); err != nil {
|
if err := db.Update(ctx, &pr); err != nil {
|
||||||
log.Errorx("marking cached mta-sts policy as used in database", err)
|
log.Errorx("marking cached mta-sts policy as used in database", err)
|
||||||
}
|
}
|
||||||
if pr.Backoff {
|
if pr.Backoff {
|
||||||
|
@ -141,13 +141,13 @@ func lookup(ctx context.Context, domain dns.Domain) (*PolicyRecord, error) {
|
||||||
|
|
||||||
// Upsert adds the policy to the database, overwriting an existing policy for the domain.
|
// Upsert adds the policy to the database, overwriting an existing policy for the domain.
|
||||||
// Policy can be nil, indicating a failure to fetch the policy.
|
// Policy can be nil, indicating a failure to fetch the policy.
|
||||||
func Upsert(domain dns.Domain, recordID string, policy *mtasts.Policy) error {
|
func Upsert(ctx context.Context, domain dns.Domain, recordID string, policy *mtasts.Policy) error {
|
||||||
db, err := database()
|
db, err := database(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return db.Write(func(tx *bstore.Tx) error {
|
return db.Write(ctx, func(tx *bstore.Tx) error {
|
||||||
pr := PolicyRecord{Domain: domain.Name()}
|
pr := PolicyRecord{Domain: domain.Name()}
|
||||||
err := tx.Get(&pr)
|
err := tx.Get(&pr)
|
||||||
if err != nil && err != bstore.ErrAbsent {
|
if err != nil && err != bstore.ErrAbsent {
|
||||||
|
@ -185,11 +185,11 @@ func Upsert(domain dns.Domain, recordID string, policy *mtasts.Policy) error {
|
||||||
// PolicyRecords returns all policies in the database, sorted descending by last
|
// PolicyRecords returns all policies in the database, sorted descending by last
|
||||||
// use, domain.
|
// use, domain.
|
||||||
func PolicyRecords(ctx context.Context) ([]PolicyRecord, error) {
|
func PolicyRecords(ctx context.Context) ([]PolicyRecord, error) {
|
||||||
db, err := database()
|
db, err := database(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return bstore.QueryDB[PolicyRecord](db).SortDesc("LastUse", "Domain").List()
|
return bstore.QueryDB[PolicyRecord](ctx, db).SortDesc("LastUse", "Domain").List()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get retrieves an MTA-STS policy for domain and whether it is fresh.
|
// Get retrieves an MTA-STS policy for domain and whether it is fresh.
|
||||||
|
@ -244,7 +244,7 @@ func Get(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (policy
|
||||||
if record != nil {
|
if record != nil {
|
||||||
recordID = record.ID
|
recordID = record.ID
|
||||||
}
|
}
|
||||||
if err := Upsert(domain, recordID, p); err != nil {
|
if err := Upsert(ctx, domain, recordID, p); err != nil {
|
||||||
log.Errorx("inserting policy into cache, continuing", err)
|
log.Errorx("inserting policy into cache, continuing", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -259,9 +259,9 @@ func Get(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (policy
|
||||||
|
|
||||||
// Policy was found in database. Check in DNS it is still fresh.
|
// Policy was found in database. Check in DNS it is still fresh.
|
||||||
policy = &cachedPolicy.Policy
|
policy = &cachedPolicy.Policy
|
||||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
nctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
record, _, _, err := mtasts.LookupRecord(ctx, resolver, domain)
|
record, _, _, err := mtasts.LookupRecord(nctx, resolver, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, mtasts.ErrNoRecord) {
|
if !errors.Is(err, mtasts.ErrNoRecord) {
|
||||||
// Could be a temporary DNS or configuration error.
|
// Could be a temporary DNS or configuration error.
|
||||||
|
@ -271,15 +271,16 @@ func Get(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (policy
|
||||||
} else if record.ID == cachedPolicy.RecordID {
|
} else if record.ID == cachedPolicy.RecordID {
|
||||||
return policy, true, nil
|
return policy, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// New policy should be available.
|
// New policy should be available.
|
||||||
ctx, cancel = context.WithTimeout(ctx, 30*time.Second)
|
nctx, cancel = context.WithTimeout(ctx, 30*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
p, _, err := mtasts.FetchPolicy(ctx, domain)
|
p, _, err := mtasts.FetchPolicy(nctx, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorx("fetching updated policy for domain, continuing with previously cached policy", err)
|
log.Errorx("fetching updated policy for domain, continuing with previously cached policy", err)
|
||||||
return policy, false, nil
|
return policy, false, nil
|
||||||
}
|
}
|
||||||
if err := Upsert(domain, record.ID, p); err != nil {
|
if err := Upsert(ctx, domain, record.ID, p); err != nil {
|
||||||
log.Errorx("inserting refreshed policy into cache, continuing with fresh policy", err)
|
log.Errorx("inserting refreshed policy into cache, continuing with fresh policy", err)
|
||||||
}
|
}
|
||||||
return p, true, nil
|
return p, true, nil
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package mtastsdb
|
package mtastsdb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
@ -24,6 +23,7 @@ func tcheckf(t *testing.T, err error, format string, args ...any) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDB(t *testing.T) {
|
func TestDB(t *testing.T) {
|
||||||
|
mox.Shutdown = ctxbg
|
||||||
mox.ConfigStaticPath = "../testdata/mtasts/fake.conf"
|
mox.ConfigStaticPath = "../testdata/mtasts/fake.conf"
|
||||||
mox.Conf.Static.DataDir = "."
|
mox.Conf.Static.DataDir = "."
|
||||||
|
|
||||||
|
@ -37,14 +37,12 @@ func TestDB(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer Close()
|
defer Close()
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// Mock time.
|
// Mock time.
|
||||||
now := time.Now().Round(0)
|
now := time.Now().Round(0)
|
||||||
timeNow = func() time.Time { return now }
|
timeNow = func() time.Time { return now }
|
||||||
defer func() { timeNow = time.Now }()
|
defer func() { timeNow = time.Now }()
|
||||||
|
|
||||||
if p, err := lookup(ctx, dns.Domain{ASCII: "example.com"}); err != ErrNotFound {
|
if p, err := lookup(ctxbg, dns.Domain{ASCII: "example.com"}); err != ErrNotFound {
|
||||||
t.Fatalf("expected not found, got %v, %#v", err, p)
|
t.Fatalf("expected not found, got %v, %#v", err, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -58,10 +56,10 @@ func TestDB(t *testing.T) {
|
||||||
},
|
},
|
||||||
MaxAgeSeconds: 1296000,
|
MaxAgeSeconds: 1296000,
|
||||||
}
|
}
|
||||||
if err := Upsert(dns.Domain{ASCII: "example.com"}, "123", &policy1); err != nil {
|
if err := Upsert(ctxbg, dns.Domain{ASCII: "example.com"}, "123", &policy1); err != nil {
|
||||||
t.Fatalf("upsert record: %s", err)
|
t.Fatalf("upsert record: %s", err)
|
||||||
}
|
}
|
||||||
if got, err := lookup(ctx, dns.Domain{ASCII: "example.com"}); err != nil {
|
if got, err := lookup(ctxbg, dns.Domain{ASCII: "example.com"}); err != nil {
|
||||||
t.Fatalf("lookup after insert: %s", err)
|
t.Fatalf("lookup after insert: %s", err)
|
||||||
} else if !reflect.DeepEqual(got.Policy, policy1) {
|
} else if !reflect.DeepEqual(got.Policy, policy1) {
|
||||||
t.Fatalf("mismatch between inserted and retrieved: got %#v, want %#v", got, policy1)
|
t.Fatalf("mismatch between inserted and retrieved: got %#v, want %#v", got, policy1)
|
||||||
|
@ -75,17 +73,17 @@ func TestDB(t *testing.T) {
|
||||||
},
|
},
|
||||||
MaxAgeSeconds: 360000,
|
MaxAgeSeconds: 360000,
|
||||||
}
|
}
|
||||||
if err := Upsert(dns.Domain{ASCII: "example.com"}, "124", &policy2); err != nil {
|
if err := Upsert(ctxbg, dns.Domain{ASCII: "example.com"}, "124", &policy2); err != nil {
|
||||||
t.Fatalf("upsert record: %s", err)
|
t.Fatalf("upsert record: %s", err)
|
||||||
}
|
}
|
||||||
if got, err := lookup(ctx, dns.Domain{ASCII: "example.com"}); err != nil {
|
if got, err := lookup(ctxbg, dns.Domain{ASCII: "example.com"}); err != nil {
|
||||||
t.Fatalf("lookup after insert: %s", err)
|
t.Fatalf("lookup after insert: %s", err)
|
||||||
} else if !reflect.DeepEqual(got.Policy, policy2) {
|
} else if !reflect.DeepEqual(got.Policy, policy2) {
|
||||||
t.Fatalf("mismatch between inserted and retrieved: got %v, want %v", got, policy2)
|
t.Fatalf("mismatch between inserted and retrieved: got %v, want %v", got, policy2)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if database holds expected record.
|
// Check if database holds expected record.
|
||||||
records, err := PolicyRecords(context.Background())
|
records, err := PolicyRecords(ctxbg)
|
||||||
tcheckf(t, err, "policyrecords")
|
tcheckf(t, err, "policyrecords")
|
||||||
expRecords := []PolicyRecord{
|
expRecords := []PolicyRecord{
|
||||||
{"example.com", now, now.Add(time.Duration(policy2.MaxAgeSeconds) * time.Second), now, now, false, "124", policy2},
|
{"example.com", now, now.Add(time.Duration(policy2.MaxAgeSeconds) * time.Second), now, now, false, "124", policy2},
|
||||||
|
@ -96,10 +94,10 @@ func TestDB(t *testing.T) {
|
||||||
t.Fatalf("records mismatch, got %#v, expected %#v", records, expRecords)
|
t.Fatalf("records mismatch, got %#v, expected %#v", records, expRecords)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := Upsert(dns.Domain{ASCII: "other.example.com"}, "", nil); err != nil {
|
if err := Upsert(ctxbg, dns.Domain{ASCII: "other.example.com"}, "", nil); err != nil {
|
||||||
t.Fatalf("upsert record: %s", err)
|
t.Fatalf("upsert record: %s", err)
|
||||||
}
|
}
|
||||||
records, err = PolicyRecords(context.Background())
|
records, err = PolicyRecords(ctxbg)
|
||||||
tcheckf(t, err, "policyrecords")
|
tcheckf(t, err, "policyrecords")
|
||||||
expRecords = []PolicyRecord{
|
expRecords = []PolicyRecord{
|
||||||
{"other.example.com", now, now.Add(5 * 60 * time.Second), now, now, true, "", mtasts.Policy{Mode: mtasts.ModeNone, MaxAgeSeconds: 5 * 60}},
|
{"other.example.com", now, now.Add(5 * 60 * time.Second), now, now, true, "", mtasts.Policy{Mode: mtasts.ModeNone, MaxAgeSeconds: 5 * 60}},
|
||||||
|
@ -109,7 +107,7 @@ func TestDB(t *testing.T) {
|
||||||
t.Fatalf("records mismatch, got %#v, expected %#v", records, expRecords)
|
t.Fatalf("records mismatch, got %#v, expected %#v", records, expRecords)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := lookup(context.Background(), dns.Domain{ASCII: "other.example.com"}); err != ErrBackoff {
|
if _, err := lookup(ctxbg, dns.Domain{ASCII: "other.example.com"}); err != ErrBackoff {
|
||||||
t.Fatalf("got %#v, expected ErrBackoff", err)
|
t.Fatalf("got %#v, expected ErrBackoff", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -126,7 +124,7 @@ func TestDB(t *testing.T) {
|
||||||
|
|
||||||
testGet := func(domain string, expPolicy *mtasts.Policy, expFresh bool, expErr error) {
|
testGet := func(domain string, expPolicy *mtasts.Policy, expFresh bool, expErr error) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
p, fresh, err := Get(context.Background(), resolver, dns.Domain{ASCII: domain})
|
p, fresh, err := Get(ctxbg, resolver, dns.Domain{ASCII: domain})
|
||||||
if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
|
if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
|
||||||
t.Fatalf("got err %v, expected %v", err, expErr)
|
t.Fatalf("got err %v, expected %v", err, expErr)
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,19 +52,19 @@ func refresh() int {
|
||||||
// jitter to the timing. Each refresh is done in a new goroutine, so a single slow
|
// jitter to the timing. Each refresh is done in a new goroutine, so a single slow
|
||||||
// refresh doesn't mess up the timing.
|
// refresh doesn't mess up the timing.
|
||||||
func refresh1(ctx context.Context, resolver dns.Resolver, sleep func(d time.Duration)) (int, error) {
|
func refresh1(ctx context.Context, resolver dns.Resolver, sleep func(d time.Duration)) (int, error) {
|
||||||
db, err := database()
|
db, err := database(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
now := timeNow()
|
now := timeNow()
|
||||||
qdel := bstore.QueryDB[PolicyRecord](db)
|
qdel := bstore.QueryDB[PolicyRecord](ctx, db)
|
||||||
qdel.FilterLess("LastUse", now.Add(-180*24*time.Hour))
|
qdel.FilterLess("LastUse", now.Add(-180*24*time.Hour))
|
||||||
if _, err := qdel.Delete(); err != nil {
|
if _, err := qdel.Delete(); err != nil {
|
||||||
return 0, fmt.Errorf("deleting old unused policies: %s", err)
|
return 0, fmt.Errorf("deleting old unused policies: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
qup := bstore.QueryDB[PolicyRecord](db)
|
qup := bstore.QueryDB[PolicyRecord](ctx, db)
|
||||||
qup.FilterLess("LastUpdate", now.Add(-12*time.Hour))
|
qup.FilterLess("LastUpdate", now.Add(-12*time.Hour))
|
||||||
prs, err := qup.List()
|
prs, err := qup.List()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -127,7 +127,7 @@ func refreshDomain(ctx context.Context, db *bstore.DB, resolver dns.Resolver, pr
|
||||||
log.Debug("refreshing mta-sts policy for domain", mlog.Field("domain", d))
|
log.Debug("refreshing mta-sts policy for domain", mlog.Field("domain", d))
|
||||||
record, _, _, err := mtasts.LookupRecord(ctx, resolver, d)
|
record, _, _, err := mtasts.LookupRecord(ctx, resolver, d)
|
||||||
if err == nil && record.ID == pr.RecordID {
|
if err == nil && record.ID == pr.RecordID {
|
||||||
qup := bstore.QueryDB[PolicyRecord](db)
|
qup := bstore.QueryDB[PolicyRecord](ctx, db)
|
||||||
qup.FilterNonzero(PolicyRecord{Domain: pr.Domain, LastUpdate: pr.LastUpdate})
|
qup.FilterNonzero(PolicyRecord{Domain: pr.Domain, LastUpdate: pr.LastUpdate})
|
||||||
now := timeNow()
|
now := timeNow()
|
||||||
update := PolicyRecord{
|
update := PolicyRecord{
|
||||||
|
@ -166,7 +166,7 @@ func refreshDomain(ctx context.Context, db *bstore.DB, resolver dns.Resolver, pr
|
||||||
if record != nil {
|
if record != nil {
|
||||||
update["RecordID"] = record.ID
|
update["RecordID"] = record.ID
|
||||||
}
|
}
|
||||||
qup := bstore.QueryDB[PolicyRecord](db)
|
qup := bstore.QueryDB[PolicyRecord](ctx, db)
|
||||||
qup.FilterNonzero(PolicyRecord{Domain: pr.Domain, LastUpdate: pr.LastUpdate})
|
qup.FilterNonzero(PolicyRecord{Domain: pr.Domain, LastUpdate: pr.LastUpdate})
|
||||||
if n, err := qup.UpdateFields(update); err != nil {
|
if n, err := qup.UpdateFields(update); err != nil {
|
||||||
log.Errorx("updating refreshed, modified policy in database", err)
|
log.Errorx("updating refreshed, modified policy in database", err)
|
||||||
|
|
|
@ -26,7 +26,10 @@ import (
|
||||||
"github.com/mjl-/mox/mtasts"
|
"github.com/mjl-/mox/mtasts"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ctxbg = context.Background()
|
||||||
|
|
||||||
func TestRefresh(t *testing.T) {
|
func TestRefresh(t *testing.T) {
|
||||||
|
mox.Shutdown = ctxbg
|
||||||
mox.ConfigStaticPath = "../testdata/mtasts/fake.conf"
|
mox.ConfigStaticPath = "../testdata/mtasts/fake.conf"
|
||||||
mox.Conf.Static.DataDir = "."
|
mox.Conf.Static.DataDir = "."
|
||||||
|
|
||||||
|
@ -40,7 +43,7 @@ func TestRefresh(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer Close()
|
defer Close()
|
||||||
|
|
||||||
db, err := database()
|
db, err := database(ctxbg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("database: %s", err)
|
t.Fatalf("database: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -66,7 +69,7 @@ func TestRefresh(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
pr := PolicyRecord{domain, time.Time{}, validEnd, lastUpdate, lastUse, backoff, recordID, policy}
|
pr := PolicyRecord{domain, time.Time{}, validEnd, lastUpdate, lastUse, backoff, recordID, policy}
|
||||||
if err := db.Insert(&pr); err != nil {
|
if err := db.Insert(ctxbg, &pr); err != nil {
|
||||||
t.Fatalf("insert policy: %s", err)
|
t.Fatalf("insert policy: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -132,7 +135,7 @@ func TestRefresh(t *testing.T) {
|
||||||
t.Fatalf("bad sleep duration %v", d)
|
t.Fatalf("bad sleep duration %v", d)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if n, err := refresh1(context.Background(), resolver, sleep); err != nil || n != 3 {
|
if n, err := refresh1(ctxbg, resolver, sleep); err != nil || n != 3 {
|
||||||
t.Fatalf("refresh1: err %s, n %d, expected no error, 3", err, n)
|
t.Fatalf("refresh1: err %s, n %d, expected no error, 3", err, n)
|
||||||
}
|
}
|
||||||
if slept != 2 {
|
if slept != 2 {
|
||||||
|
@ -141,19 +144,19 @@ func TestRefresh(t *testing.T) {
|
||||||
time.Sleep(time.Second / 10) // Give goroutine time to write result, before we cleanup the database.
|
time.Sleep(time.Second / 10) // Give goroutine time to write result, before we cleanup the database.
|
||||||
|
|
||||||
// Should not do any more refreshes and return immediately.
|
// Should not do any more refreshes and return immediately.
|
||||||
q := bstore.QueryDB[PolicyRecord](db)
|
q := bstore.QueryDB[PolicyRecord](ctxbg, db)
|
||||||
q.FilterNonzero(PolicyRecord{Domain: "policybad.mox.example"})
|
q.FilterNonzero(PolicyRecord{Domain: "policybad.mox.example"})
|
||||||
if _, err := q.Delete(); err != nil {
|
if _, err := q.Delete(); err != nil {
|
||||||
t.Fatalf("delete record that would be refreshed: %v", err)
|
t.Fatalf("delete record that would be refreshed: %v", err)
|
||||||
}
|
}
|
||||||
mox.Context = context.Background()
|
mox.Context = ctxbg
|
||||||
mox.Shutdown, mox.ShutdownCancel = context.WithCancel(context.Background())
|
mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg)
|
||||||
mox.ShutdownCancel()
|
mox.ShutdownCancel()
|
||||||
n := refresh()
|
n := refresh()
|
||||||
if n != 0 {
|
if n != 0 {
|
||||||
t.Fatalf("refresh found unexpected work, n %d", n)
|
t.Fatalf("refresh found unexpected work, n %d", n)
|
||||||
}
|
}
|
||||||
mox.Shutdown, mox.ShutdownCancel = context.WithCancel(context.Background())
|
mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg)
|
||||||
}
|
}
|
||||||
|
|
||||||
type pipeListener struct {
|
type pipeListener struct {
|
||||||
|
|
|
@ -122,7 +122,7 @@ func Init() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
queueDB, err = bstore.Open(qpath, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, Msg{})
|
queueDB, err = bstore.Open(mox.Shutdown, qpath, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, Msg{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if isNew {
|
if isNew {
|
||||||
os.Remove(qpath)
|
os.Remove(qpath)
|
||||||
|
@ -141,8 +141,8 @@ func Shutdown() {
|
||||||
|
|
||||||
// List returns all messages in the delivery queue.
|
// List returns all messages in the delivery queue.
|
||||||
// Ordered by earliest delivery attempt first.
|
// Ordered by earliest delivery attempt first.
|
||||||
func List() ([]Msg, error) {
|
func List(ctx context.Context) ([]Msg, error) {
|
||||||
qmsgs, err := bstore.QueryDB[Msg](queueDB).List()
|
qmsgs, err := bstore.QueryDB[Msg](ctx, queueDB).List()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -165,8 +165,8 @@ func List() ([]Msg, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count returns the number of messages in the delivery queue.
|
// Count returns the number of messages in the delivery queue.
|
||||||
func Count() (int, error) {
|
func Count(ctx context.Context) (int, error) {
|
||||||
return bstore.QueryDB[Msg](queueDB).Count()
|
return bstore.QueryDB[Msg](ctx, queueDB).Count()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add a new message to the queue. The queue is kicked immediately to start a
|
// Add a new message to the queue. The queue is kicked immediately to start a
|
||||||
|
@ -179,7 +179,7 @@ func Count() (int, error) {
|
||||||
// this data is used as the message when delivering the DSN and the remote SMTP
|
// this data is used as the message when delivering the DSN and the remote SMTP
|
||||||
// server supports SMTPUTF8. If the remote SMTP server does not support SMTPUTF8,
|
// server supports SMTPUTF8. If the remote SMTP server does not support SMTPUTF8,
|
||||||
// the regular non-utf8 message is delivered.
|
// the regular non-utf8 message is delivered.
|
||||||
func Add(log *mlog.Log, senderAccount string, mailFrom, rcptTo smtp.Path, has8bit, smtputf8 bool, size int64, msgPrefix []byte, msgFile *os.File, dsnutf8Opt []byte, consumeFile bool) error {
|
func Add(ctx context.Context, log *mlog.Log, senderAccount string, mailFrom, rcptTo smtp.Path, has8bit, smtputf8 bool, size int64, msgPrefix []byte, msgFile *os.File, dsnutf8Opt []byte, consumeFile bool) error {
|
||||||
// todo: Add should accept multiple rcptTo if they are for the same domain. so we can queue them for delivery in one (or just a few) session(s), transferring the data only once. ../rfc/5321:3759
|
// todo: Add should accept multiple rcptTo if they are for the same domain. so we can queue them for delivery in one (or just a few) session(s), transferring the data only once. ../rfc/5321:3759
|
||||||
|
|
||||||
if Localserve {
|
if Localserve {
|
||||||
|
@ -187,7 +187,7 @@ func Add(log *mlog.Log, senderAccount string, mailFrom, rcptTo smtp.Path, has8bi
|
||||||
return fmt.Errorf("no queuing with localserve")
|
return fmt.Errorf("no queuing with localserve")
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := queueDB.Begin(true)
|
tx, err := queueDB.Begin(ctx, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("begin transaction: %w", err)
|
return fmt.Errorf("begin transaction: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -287,8 +287,8 @@ func queuekick() {
|
||||||
// and kicks the queue, attempting delivery of those messages. If all parameters
|
// and kicks the queue, attempting delivery of those messages. If all parameters
|
||||||
// are zero, all messages are kicked.
|
// are zero, all messages are kicked.
|
||||||
// Returns number of messages queued for immediate delivery.
|
// Returns number of messages queued for immediate delivery.
|
||||||
func Kick(ID int64, toDomain string, recipient string) (int, error) {
|
func Kick(ctx context.Context, ID int64, toDomain string, recipient string) (int, error) {
|
||||||
q := bstore.QueryDB[Msg](queueDB)
|
q := bstore.QueryDB[Msg](ctx, queueDB)
|
||||||
if ID > 0 {
|
if ID > 0 {
|
||||||
q.FilterID(ID)
|
q.FilterID(ID)
|
||||||
}
|
}
|
||||||
|
@ -311,8 +311,8 @@ func Kick(ID int64, toDomain string, recipient string) (int, error) {
|
||||||
// Drop removes messages from the queue that match all nonzero parameters.
|
// Drop removes messages from the queue that match all nonzero parameters.
|
||||||
// If all parameters are zero, all messages are removed.
|
// If all parameters are zero, all messages are removed.
|
||||||
// Returns number of messages removed.
|
// Returns number of messages removed.
|
||||||
func Drop(ID int64, toDomain string, recipient string) (int, error) {
|
func Drop(ctx context.Context, ID int64, toDomain string, recipient string) (int, error) {
|
||||||
q := bstore.QueryDB[Msg](queueDB)
|
q := bstore.QueryDB[Msg](ctx, queueDB)
|
||||||
if ID > 0 {
|
if ID > 0 {
|
||||||
q.FilterID(ID)
|
q.FilterID(ID)
|
||||||
}
|
}
|
||||||
|
@ -337,9 +337,9 @@ type ReadReaderAtCloser interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenMessage opens a message present in the queue.
|
// OpenMessage opens a message present in the queue.
|
||||||
func OpenMessage(id int64) (ReadReaderAtCloser, error) {
|
func OpenMessage(ctx context.Context, id int64) (ReadReaderAtCloser, error) {
|
||||||
qm := Msg{ID: id}
|
qm := Msg{ID: id}
|
||||||
err := queueDB.Get(&qm)
|
err := queueDB.Get(ctx, &qm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -382,14 +382,14 @@ func Start(resolver dns.Resolver, done chan struct{}) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
launchWork(resolver, busyDomains)
|
launchWork(resolver, busyDomains)
|
||||||
timer.Reset(nextWork(busyDomains))
|
timer.Reset(nextWork(mox.Shutdown, busyDomains))
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func nextWork(busyDomains map[string]struct{}) time.Duration {
|
func nextWork(ctx context.Context, busyDomains map[string]struct{}) time.Duration {
|
||||||
q := bstore.QueryDB[Msg](queueDB)
|
q := bstore.QueryDB[Msg](ctx, queueDB)
|
||||||
if len(busyDomains) > 0 {
|
if len(busyDomains) > 0 {
|
||||||
var doms []any
|
var doms []any
|
||||||
for d := range busyDomains {
|
for d := range busyDomains {
|
||||||
|
@ -410,7 +410,7 @@ func nextWork(busyDomains map[string]struct{}) time.Duration {
|
||||||
}
|
}
|
||||||
|
|
||||||
func launchWork(resolver dns.Resolver, busyDomains map[string]struct{}) int {
|
func launchWork(resolver dns.Resolver, busyDomains map[string]struct{}) int {
|
||||||
q := bstore.QueryDB[Msg](queueDB)
|
q := bstore.QueryDB[Msg](mox.Shutdown, queueDB)
|
||||||
q.FilterLessEqual("NextAttempt", time.Now())
|
q.FilterLessEqual("NextAttempt", time.Now())
|
||||||
q.SortAsc("NextAttempt")
|
q.SortAsc("NextAttempt")
|
||||||
q.Limit(maxConcurrentDeliveries)
|
q.Limit(maxConcurrentDeliveries)
|
||||||
|
@ -424,7 +424,7 @@ func launchWork(resolver dns.Resolver, busyDomains map[string]struct{}) int {
|
||||||
msgs, err := q.List()
|
msgs, err := q.List()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xlog.Errorx("querying for work in queue", err)
|
xlog.Errorx("querying for work in queue", err)
|
||||||
mox.Sleep(mox.Context, 1*time.Second)
|
mox.Sleep(mox.Shutdown, 1*time.Second)
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -436,8 +436,8 @@ func launchWork(resolver dns.Resolver, busyDomains map[string]struct{}) int {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove message from queue in database and file system.
|
// Remove message from queue in database and file system.
|
||||||
func queueDelete(msgID int64) error {
|
func queueDelete(ctx context.Context, msgID int64) error {
|
||||||
if err := queueDB.Delete(&Msg{ID: msgID}); err != nil {
|
if err := queueDB.Delete(ctx, &Msg{ID: msgID}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// If removing from database fails, we'll also leave the file in the file system.
|
// If removing from database fails, we'll also leave the file in the file system.
|
||||||
|
@ -483,7 +483,7 @@ func deliver(resolver dns.Resolver, m Msg) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
m.LastAttempt = &now
|
m.LastAttempt = &now
|
||||||
m.NextAttempt = now.Add(backoff)
|
m.NextAttempt = now.Add(backoff)
|
||||||
qup := bstore.QueryDB[Msg](queueDB)
|
qup := bstore.QueryDB[Msg](mox.Shutdown, queueDB)
|
||||||
qup.FilterID(m.ID)
|
qup.FilterID(m.ID)
|
||||||
update := Msg{Attempts: m.Attempts, NextAttempt: m.NextAttempt, LastAttempt: m.LastAttempt}
|
update := Msg{Attempts: m.Attempts, NextAttempt: m.NextAttempt, LastAttempt: m.LastAttempt}
|
||||||
if _, err := qup.UpdateNonzero(update); err != nil {
|
if _, err := qup.UpdateNonzero(update); err != nil {
|
||||||
|
@ -496,13 +496,13 @@ func deliver(resolver dns.Resolver, m Msg) {
|
||||||
qlog.Errorx("permanent failure delivering from queue", errors.New(errmsg))
|
qlog.Errorx("permanent failure delivering from queue", errors.New(errmsg))
|
||||||
queueDSNFailure(qlog, m, remoteMTA, secodeOpt, errmsg)
|
queueDSNFailure(qlog, m, remoteMTA, secodeOpt, errmsg)
|
||||||
|
|
||||||
if err := queueDelete(m.ID); err != nil {
|
if err := queueDelete(context.Background(), m.ID); err != nil {
|
||||||
qlog.Errorx("deleting message from queue after permanent failure", err)
|
qlog.Errorx("deleting message from queue after permanent failure", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
qup := bstore.QueryDB[Msg](queueDB)
|
qup := bstore.QueryDB[Msg](context.Background(), queueDB)
|
||||||
qup.FilterID(m.ID)
|
qup.FilterID(m.ID)
|
||||||
if _, err := qup.UpdateNonzero(Msg{LastError: errmsg, DialedIPs: m.DialedIPs}); err != nil {
|
if _, err := qup.UpdateNonzero(Msg{LastError: errmsg, DialedIPs: m.DialedIPs}); err != nil {
|
||||||
qlog.Errorx("storing delivery error", err, mlog.Field("deliveryerror", errmsg))
|
qlog.Errorx("storing delivery error", err, mlog.Field("deliveryerror", errmsg))
|
||||||
|
@ -534,7 +534,7 @@ func deliver(resolver dns.Resolver, m Msg) {
|
||||||
var policy *mtasts.Policy
|
var policy *mtasts.Policy
|
||||||
tlsModeDefault := smtpclient.TLSOpportunistic
|
tlsModeDefault := smtpclient.TLSOpportunistic
|
||||||
if !effectiveDomain.IsZero() {
|
if !effectiveDomain.IsZero() {
|
||||||
cidctx := context.WithValue(mox.Context, mlog.CidKey, cid)
|
cidctx := context.WithValue(mox.Shutdown, mlog.CidKey, cid)
|
||||||
policy, policyFresh, err = mtastsdb.Get(cidctx, resolver, effectiveDomain)
|
policy, policyFresh, err = mtastsdb.Get(cidctx, resolver, effectiveDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// No need to refuse to deliver if we have some mtasts error.
|
// No need to refuse to deliver if we have some mtasts error.
|
||||||
|
@ -586,7 +586,7 @@ func deliver(resolver dns.Resolver, m Msg) {
|
||||||
}
|
}
|
||||||
if ok {
|
if ok {
|
||||||
nqlog.Info("delivered from queue")
|
nqlog.Info("delivered from queue")
|
||||||
if err := queueDelete(m.ID); err != nil {
|
if err := queueDelete(context.Background(), m.ID); err != nil {
|
||||||
nqlog.Errorx("deleting message from queue after delivery", err)
|
nqlog.Errorx("deleting message from queue after delivery", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
|
@ -21,6 +21,8 @@ import (
|
||||||
"github.com/mjl-/mox/store"
|
"github.com/mjl-/mox/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ctxbg = context.Background()
|
||||||
|
|
||||||
func tcheck(t *testing.T, err error, msg string) {
|
func tcheck(t *testing.T, err error, msg string) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
@ -31,7 +33,7 @@ func tcheck(t *testing.T, err error, msg string) {
|
||||||
func setup(t *testing.T) (*store.Account, func()) {
|
func setup(t *testing.T) (*store.Account, func()) {
|
||||||
// Prepare config so email can be delivered to mjl@mox.example.
|
// Prepare config so email can be delivered to mjl@mox.example.
|
||||||
os.RemoveAll("../testdata/queue/data")
|
os.RemoveAll("../testdata/queue/data")
|
||||||
mox.Context = context.Background()
|
mox.Context = ctxbg
|
||||||
mox.ConfigStaticPath = "../testdata/queue/mox.conf"
|
mox.ConfigStaticPath = "../testdata/queue/mox.conf"
|
||||||
mox.MustLoadConfig(false)
|
mox.MustLoadConfig(false)
|
||||||
acc, err := store.OpenAccount("mjl")
|
acc, err := store.OpenAccount("mjl")
|
||||||
|
@ -39,11 +41,11 @@ func setup(t *testing.T) (*store.Account, func()) {
|
||||||
err = acc.SetPassword("testtest")
|
err = acc.SetPassword("testtest")
|
||||||
tcheck(t, err, "set password")
|
tcheck(t, err, "set password")
|
||||||
switchDone := store.Switchboard()
|
switchDone := store.Switchboard()
|
||||||
mox.Shutdown, mox.ShutdownCancel = context.WithCancel(context.Background())
|
mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg)
|
||||||
return acc, func() {
|
return acc, func() {
|
||||||
acc.Close()
|
acc.Close()
|
||||||
mox.ShutdownCancel()
|
mox.ShutdownCancel()
|
||||||
mox.Shutdown, mox.ShutdownCancel = context.WithCancel(context.Background())
|
mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg)
|
||||||
Shutdown()
|
Shutdown()
|
||||||
close(switchDone)
|
close(switchDone)
|
||||||
}
|
}
|
||||||
|
@ -71,22 +73,22 @@ func TestQueue(t *testing.T) {
|
||||||
err := Init()
|
err := Init()
|
||||||
tcheck(t, err, "queue init")
|
tcheck(t, err, "queue init")
|
||||||
|
|
||||||
msgs, err := List()
|
msgs, err := List(ctxbg)
|
||||||
tcheck(t, err, "listing messages in queue")
|
tcheck(t, err, "listing messages in queue")
|
||||||
if len(msgs) != 0 {
|
if len(msgs) != 0 {
|
||||||
t.Fatalf("got %d messages in queue, expected 0", len(msgs))
|
t.Fatalf("got %d messages in queue, expected 0", len(msgs))
|
||||||
}
|
}
|
||||||
|
|
||||||
path := smtp.Path{Localpart: "mjl", IPDomain: dns.IPDomain{Domain: dns.Domain{ASCII: "mox.example"}}}
|
path := smtp.Path{Localpart: "mjl", IPDomain: dns.IPDomain{Domain: dns.Domain{ASCII: "mox.example"}}}
|
||||||
err = Add(xlog, "mjl", path, path, false, false, int64(len(testmsg)), nil, prepareFile(t), nil, true)
|
err = Add(ctxbg, xlog, "mjl", path, path, false, false, int64(len(testmsg)), nil, prepareFile(t), nil, true)
|
||||||
tcheck(t, err, "add message to queue for delivery")
|
tcheck(t, err, "add message to queue for delivery")
|
||||||
|
|
||||||
mf2 := prepareFile(t)
|
mf2 := prepareFile(t)
|
||||||
err = Add(xlog, "mjl", path, path, false, false, int64(len(testmsg)), nil, mf2, nil, false)
|
err = Add(ctxbg, xlog, "mjl", path, path, false, false, int64(len(testmsg)), nil, mf2, nil, false)
|
||||||
tcheck(t, err, "add message to queue for delivery")
|
tcheck(t, err, "add message to queue for delivery")
|
||||||
os.Remove(mf2.Name())
|
os.Remove(mf2.Name())
|
||||||
|
|
||||||
msgs, err = List()
|
msgs, err = List(ctxbg)
|
||||||
tcheck(t, err, "listing queue")
|
tcheck(t, err, "listing queue")
|
||||||
if len(msgs) != 2 {
|
if len(msgs) != 2 {
|
||||||
t.Fatalf("got msgs %v, expected 1", msgs)
|
t.Fatalf("got msgs %v, expected 1", msgs)
|
||||||
|
@ -95,18 +97,18 @@ func TestQueue(t *testing.T) {
|
||||||
if msg.Attempts != 0 {
|
if msg.Attempts != 0 {
|
||||||
t.Fatalf("msg attempts %d, expected 0", msg.Attempts)
|
t.Fatalf("msg attempts %d, expected 0", msg.Attempts)
|
||||||
}
|
}
|
||||||
n, err := Drop(msgs[1].ID, "", "")
|
n, err := Drop(ctxbg, msgs[1].ID, "", "")
|
||||||
tcheck(t, err, "drop")
|
tcheck(t, err, "drop")
|
||||||
if n != 1 {
|
if n != 1 {
|
||||||
t.Fatalf("dropped %d, expected 1", n)
|
t.Fatalf("dropped %d, expected 1", n)
|
||||||
}
|
}
|
||||||
|
|
||||||
next := nextWork(nil)
|
next := nextWork(ctxbg, nil)
|
||||||
if next > 0 {
|
if next > 0 {
|
||||||
t.Fatalf("nextWork in %s, should be now", next)
|
t.Fatalf("nextWork in %s, should be now", next)
|
||||||
}
|
}
|
||||||
busy := map[string]struct{}{"mox.example": {}}
|
busy := map[string]struct{}{"mox.example": {}}
|
||||||
if x := nextWork(busy); x != 24*time.Hour {
|
if x := nextWork(ctxbg, busy); x != 24*time.Hour {
|
||||||
t.Fatalf("nextWork in %s for busy domain, should be in 24 hours", x)
|
t.Fatalf("nextWork in %s for busy domain, should be in 24 hours", x)
|
||||||
}
|
}
|
||||||
if nn := launchWork(nil, busy); nn != 0 {
|
if nn := launchWork(nil, busy); nn != 0 {
|
||||||
|
@ -133,7 +135,7 @@ func TestQueue(t *testing.T) {
|
||||||
case <-dialed:
|
case <-dialed:
|
||||||
i := 0
|
i := 0
|
||||||
for {
|
for {
|
||||||
m, err := bstore.QueryDB[Msg](queueDB).Get()
|
m, err := bstore.QueryDB[Msg](ctxbg, queueDB).Get()
|
||||||
tcheck(t, err, "get")
|
tcheck(t, err, "get")
|
||||||
if m.Attempts == 1 {
|
if m.Attempts == 1 {
|
||||||
break
|
break
|
||||||
|
@ -149,11 +151,11 @@ func TestQueue(t *testing.T) {
|
||||||
}
|
}
|
||||||
<-deliveryResult // Deliver sends here.
|
<-deliveryResult // Deliver sends here.
|
||||||
|
|
||||||
_, err = OpenMessage(msg.ID + 1)
|
_, err = OpenMessage(ctxbg, msg.ID+1)
|
||||||
if err != bstore.ErrAbsent {
|
if err != bstore.ErrAbsent {
|
||||||
t.Fatalf("OpenMessage, got %v, expected ErrAbsent", err)
|
t.Fatalf("OpenMessage, got %v, expected ErrAbsent", err)
|
||||||
}
|
}
|
||||||
reader, err := OpenMessage(msg.ID)
|
reader, err := OpenMessage(ctxbg, msg.ID)
|
||||||
tcheck(t, err, "open message")
|
tcheck(t, err, "open message")
|
||||||
defer reader.Close()
|
defer reader.Close()
|
||||||
msgbuf, err := io.ReadAll(reader)
|
msgbuf, err := io.ReadAll(reader)
|
||||||
|
@ -162,12 +164,12 @@ func TestQueue(t *testing.T) {
|
||||||
t.Fatalf("message mismatch, got %q, expected %q", string(msgbuf), testmsg)
|
t.Fatalf("message mismatch, got %q, expected %q", string(msgbuf), testmsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err = Kick(msg.ID+1, "", "")
|
n, err = Kick(ctxbg, msg.ID+1, "", "")
|
||||||
tcheck(t, err, "kick")
|
tcheck(t, err, "kick")
|
||||||
if n != 0 {
|
if n != 0 {
|
||||||
t.Fatalf("kick %d, expected 0", n)
|
t.Fatalf("kick %d, expected 0", n)
|
||||||
}
|
}
|
||||||
n, err = Kick(msg.ID, "", "")
|
n, err = Kick(ctxbg, msg.ID, "", "")
|
||||||
tcheck(t, err, "kick")
|
tcheck(t, err, "kick")
|
||||||
if n != 1 {
|
if n != 1 {
|
||||||
t.Fatalf("kicked %d, expected 1", n)
|
t.Fatalf("kicked %d, expected 1", n)
|
||||||
|
@ -215,7 +217,7 @@ func TestQueue(t *testing.T) {
|
||||||
case <-smtpdone:
|
case <-smtpdone:
|
||||||
i := 0
|
i := 0
|
||||||
for {
|
for {
|
||||||
xmsgs, err := List()
|
xmsgs, err := List(ctxbg)
|
||||||
tcheck(t, err, "list queue")
|
tcheck(t, err, "list queue")
|
||||||
if len(xmsgs) == 0 {
|
if len(xmsgs) == 0 {
|
||||||
break
|
break
|
||||||
|
@ -235,10 +237,10 @@ func TestQueue(t *testing.T) {
|
||||||
<-deliveryResult // Deliver sends here.
|
<-deliveryResult // Deliver sends here.
|
||||||
|
|
||||||
// Add another message that we'll fail to deliver entirely.
|
// Add another message that we'll fail to deliver entirely.
|
||||||
err = Add(xlog, "mjl", path, path, false, false, int64(len(testmsg)), nil, prepareFile(t), nil, true)
|
err = Add(ctxbg, xlog, "mjl", path, path, false, false, int64(len(testmsg)), nil, prepareFile(t), nil, true)
|
||||||
tcheck(t, err, "add message to queue for delivery")
|
tcheck(t, err, "add message to queue for delivery")
|
||||||
|
|
||||||
msgs, err = List()
|
msgs, err = List(ctxbg)
|
||||||
tcheck(t, err, "list queue")
|
tcheck(t, err, "list queue")
|
||||||
if len(msgs) != 1 {
|
if len(msgs) != 1 {
|
||||||
t.Fatalf("queue has %d messages, expected 1", len(msgs))
|
t.Fatalf("queue has %d messages, expected 1", len(msgs))
|
||||||
|
@ -283,7 +285,7 @@ func TestQueue(t *testing.T) {
|
||||||
for i := 1; i < 8; i++ {
|
for i := 1; i < 8; i++ {
|
||||||
go func() { <-deliveryResult }() // Deliver sends here.
|
go func() { <-deliveryResult }() // Deliver sends here.
|
||||||
deliver(resolver, msg)
|
deliver(resolver, msg)
|
||||||
err = queueDB.Get(&msg)
|
err = queueDB.Get(ctxbg, &msg)
|
||||||
tcheck(t, err, "get msg")
|
tcheck(t, err, "get msg")
|
||||||
if msg.Attempts != i {
|
if msg.Attempts != i {
|
||||||
t.Fatalf("got attempt %d, expected %d", msg.Attempts, i)
|
t.Fatalf("got attempt %d, expected %d", msg.Attempts, i)
|
||||||
|
@ -306,7 +308,7 @@ func TestQueue(t *testing.T) {
|
||||||
// Trigger final failure.
|
// Trigger final failure.
|
||||||
go func() { <-deliveryResult }() // Deliver sends here.
|
go func() { <-deliveryResult }() // Deliver sends here.
|
||||||
deliver(resolver, msg)
|
deliver(resolver, msg)
|
||||||
err = queueDB.Get(&msg)
|
err = queueDB.Get(ctxbg, &msg)
|
||||||
if err != bstore.ErrAbsent {
|
if err != bstore.ErrAbsent {
|
||||||
t.Fatalf("attempt to fetch delivered and removed message from queue, got err %v, expected ErrAbsent", err)
|
t.Fatalf("attempt to fetch delivered and removed message from queue, got err %v, expected ErrAbsent", err)
|
||||||
}
|
}
|
||||||
|
@ -343,7 +345,7 @@ func TestQueueStart(t *testing.T) {
|
||||||
defer func() {
|
defer func() {
|
||||||
mox.ShutdownCancel()
|
mox.ShutdownCancel()
|
||||||
<-done
|
<-done
|
||||||
mox.Shutdown, mox.ShutdownCancel = context.WithCancel(context.Background())
|
mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg)
|
||||||
}()
|
}()
|
||||||
err := Start(resolver, done)
|
err := Start(resolver, done)
|
||||||
tcheck(t, err, "queue start")
|
tcheck(t, err, "queue start")
|
||||||
|
@ -369,7 +371,7 @@ func TestQueueStart(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
path := smtp.Path{Localpart: "mjl", IPDomain: dns.IPDomain{Domain: dns.Domain{ASCII: "mox.example"}}}
|
path := smtp.Path{Localpart: "mjl", IPDomain: dns.IPDomain{Domain: dns.Domain{ASCII: "mox.example"}}}
|
||||||
err = Add(xlog, "mjl", path, path, false, false, int64(len(testmsg)), nil, prepareFile(t), nil, true)
|
err = Add(ctxbg, xlog, "mjl", path, path, false, false, int64(len(testmsg)), nil, prepareFile(t), nil, true)
|
||||||
tcheck(t, err, "add message to queue for delivery")
|
tcheck(t, err, "add message to queue for delivery")
|
||||||
checkDialed(true)
|
checkDialed(true)
|
||||||
|
|
||||||
|
@ -378,7 +380,7 @@ func TestQueueStart(t *testing.T) {
|
||||||
checkDialed(false)
|
checkDialed(false)
|
||||||
|
|
||||||
// Kick for real, should see another attempt.
|
// Kick for real, should see another attempt.
|
||||||
n, err := Kick(0, "mox.example", "")
|
n, err := Kick(ctxbg, 0, "mox.example", "")
|
||||||
tcheck(t, err, "kick queue")
|
tcheck(t, err, "kick queue")
|
||||||
if n != 1 {
|
if n != 1 {
|
||||||
t.Fatalf("kick changed %d messages, expected 1", n)
|
t.Fatalf("kick changed %d messages, expected 1", n)
|
||||||
|
@ -402,7 +404,7 @@ func TestWriteFile(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGatherHosts(t *testing.T) {
|
func TestGatherHosts(t *testing.T) {
|
||||||
mox.Context = context.Background()
|
mox.Context = ctxbg
|
||||||
|
|
||||||
// Test basic MX lookup case, but also following CNAME, detecting CNAME loops and
|
// Test basic MX lookup case, but also following CNAME, detecting CNAME loops and
|
||||||
// having a CNAME limit, connecting directly to a host, and domain that does not
|
// having a CNAME limit, connecting directly to a host, and domain that does not
|
||||||
|
@ -524,11 +526,11 @@ func TestDialHost(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
m := Msg{DialedIPs: map[string][]net.IP{}}
|
m := Msg{DialedIPs: map[string][]net.IP{}}
|
||||||
_, ip, dualstack, err := dialHost(context.Background(), xlog, resolver, ipdomain("dualstack.example"), &m)
|
_, ip, dualstack, err := dialHost(ctxbg, xlog, resolver, ipdomain("dualstack.example"), &m)
|
||||||
if err != nil || ip.String() != "10.0.0.1" || !dualstack {
|
if err != nil || ip.String() != "10.0.0.1" || !dualstack {
|
||||||
t.Fatalf("expected err nil, address 10.0.0.1, dualstack true, got %v %v %v", err, ip, dualstack)
|
t.Fatalf("expected err nil, address 10.0.0.1, dualstack true, got %v %v %v", err, ip, dualstack)
|
||||||
}
|
}
|
||||||
_, ip, dualstack, err = dialHost(context.Background(), xlog, resolver, ipdomain("dualstack.example"), &m)
|
_, ip, dualstack, err = dialHost(ctxbg, xlog, resolver, ipdomain("dualstack.example"), &m)
|
||||||
if err != nil || ip.String() != "2001:db8::1" || !dualstack {
|
if err != nil || ip.String() != "2001:db8::1" || !dualstack {
|
||||||
t.Fatalf("expected err nil, address 2001:db8::1, dualstack true, got %v %v %v", err, ip, dualstack)
|
t.Fatalf("expected err nil, address 2001:db8::1, dualstack true, got %v %v %v", err, ip, dualstack)
|
||||||
}
|
}
|
||||||
|
|
|
@ -161,7 +161,7 @@ func analyze(ctx context.Context, log *mlog.Log, resolver dns.Resolver, d delive
|
||||||
var reason string
|
var reason string
|
||||||
var err error
|
var err error
|
||||||
d.acc.WithRLock(func() {
|
d.acc.WithRLock(func() {
|
||||||
err = d.acc.DB.Read(func(tx *bstore.Tx) error {
|
err = d.acc.DB.Read(ctx, func(tx *bstore.Tx) error {
|
||||||
// Set message MailboxID to which mail will be delivered. Reputation is
|
// Set message MailboxID to which mail will be delivered. Reputation is
|
||||||
// per-mailbox. If referenced mailbox is not found (e.g. does not yet exist), we
|
// per-mailbox. If referenced mailbox is not found (e.g. does not yet exist), we
|
||||||
// can still determine a reputation because we also base it on outgoing
|
// can still determine a reputation because we also base it on outgoing
|
||||||
|
@ -251,13 +251,13 @@ func analyze(ctx context.Context, log *mlog.Log, resolver dns.Resolver, d delive
|
||||||
reason = reasonNoBadSignals
|
reason = reasonNoBadSignals
|
||||||
accept := true
|
accept := true
|
||||||
var junkSubjectpass bool
|
var junkSubjectpass bool
|
||||||
f, jf, err := d.acc.OpenJunkFilter(log)
|
f, jf, err := d.acc.OpenJunkFilter(ctx, log)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
defer func() {
|
defer func() {
|
||||||
err := f.Close()
|
err := f.Close()
|
||||||
log.Check(err, "closing junkfilter")
|
log.Check(err, "closing junkfilter")
|
||||||
}()
|
}()
|
||||||
contentProb, _, _, _, err := f.ClassifyMessageReader(store.FileMsgReader(d.m.MsgPrefix, d.dataFile), d.m.Size)
|
contentProb, _, _, _, err := f.ClassifyMessageReader(ctx, store.FileMsgReader(d.m.MsgPrefix, d.dataFile), d.m.Size)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorx("testing for spam", err)
|
log.Errorx("testing for spam", err)
|
||||||
return reject(smtp.C451LocalErr, smtp.SeSys3Other0, "error processing", err, reasonJunkClassifyError)
|
return reject(smtp.C451LocalErr, smtp.SeSys3Other0, "error processing", err, reasonJunkClassifyError)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package smtpserver
|
package smtpserver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
@ -11,7 +12,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// compose dsn message and add it to the queue for delivery to rcptTo.
|
// compose dsn message and add it to the queue for delivery to rcptTo.
|
||||||
func queueDSN(c *conn, rcptTo smtp.Path, m dsn.Message) error {
|
func queueDSN(ctx context.Context, c *conn, rcptTo smtp.Path, m dsn.Message) error {
|
||||||
buf, err := m.Compose(c.log, false)
|
buf, err := m.Compose(c.log, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -45,7 +46,7 @@ func queueDSN(c *conn, rcptTo smtp.Path, m dsn.Message) error {
|
||||||
// ../rfc/3464:433
|
// ../rfc/3464:433
|
||||||
const has8bit = false
|
const has8bit = false
|
||||||
const smtputf8 = false
|
const smtputf8 = false
|
||||||
if err := queue.Add(c.log, "", smtp.Path{}, rcptTo, has8bit, smtputf8, int64(len(buf)), nil, f, bufUTF8, true); err != nil {
|
if err := queue.Add(ctx, c.log, "", smtp.Path{}, rcptTo, has8bit, smtputf8, int64(len(buf)), nil, f, bufUTF8, true); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = f.Close()
|
err = f.Close()
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package smtpserver
|
package smtpserver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
@ -30,7 +29,7 @@ func FuzzServer(f *testing.F) {
|
||||||
f.Add("NOOP")
|
f.Add("NOOP")
|
||||||
f.Add("QUIT")
|
f.Add("QUIT")
|
||||||
|
|
||||||
mox.Context = context.Background()
|
mox.Context = ctxbg
|
||||||
mox.ConfigStaticPath = "../testdata/smtp/mox.conf"
|
mox.ConfigStaticPath = "../testdata/smtp/mox.conf"
|
||||||
mox.MustLoadConfig(false)
|
mox.MustLoadConfig(false)
|
||||||
dataDir := mox.ConfigDirPath(mox.Conf.Static.DataDir)
|
dataDir := mox.ConfigDirPath(mox.Conf.Static.DataDir)
|
||||||
|
|
|
@ -2,6 +2,7 @@ package smtpserver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -40,7 +41,7 @@ func rejectPresent(log *mlog.Log, acc *store.Account, rejectsMailbox string, m *
|
||||||
var exists bool
|
var exists bool
|
||||||
var err error
|
var err error
|
||||||
acc.WithRLock(func() {
|
acc.WithRLock(func() {
|
||||||
err = acc.DB.Read(func(tx *bstore.Tx) error {
|
err = acc.DB.Read(context.TODO(), func(tx *bstore.Tx) error {
|
||||||
mbq := bstore.QueryTx[store.Mailbox](tx)
|
mbq := bstore.QueryTx[store.Mailbox](tx)
|
||||||
mbq.FilterNonzero(store.Mailbox{Name: rejectsMailbox})
|
mbq.FilterNonzero(store.Mailbox{Name: rejectsMailbox})
|
||||||
mb, err := mbq.Get()
|
mb, err := mbq.Get()
|
||||||
|
|
|
@ -272,16 +272,8 @@ func reputation(tx *bstore.Tx, log *mlog.Log, m *store.Message) (rjunk *bool, rc
|
||||||
dkimspfsignals := []float64{}
|
dkimspfsignals := []float64{}
|
||||||
dkimspfmsgs := 0
|
dkimspfmsgs := 0
|
||||||
for _, dom := range m.DKIMDomains {
|
for _, dom := range m.DKIMDomains {
|
||||||
// todo: should get dkimdomains in an index for faster lookup. bstore does not yet support "in" indexes.
|
|
||||||
q := messageQuery(nil, year/2, 50)
|
q := messageQuery(nil, year/2, 50)
|
||||||
q.FilterFn(func(m store.Message) bool {
|
q.FilterIn("DKIMDomains", dom)
|
||||||
for _, d := range m.DKIMDomains {
|
|
||||||
if d == dom {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
})
|
|
||||||
msgs := xmessageList(q, "dkimdomain")
|
msgs := xmessageList(q, "dkimdomain")
|
||||||
if len(msgs) > 0 {
|
if len(msgs) > 0 {
|
||||||
nspam := 0
|
nspam := 0
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package smtpserver
|
package smtpserver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
@ -77,7 +76,7 @@ func TestReputation(t *testing.T) {
|
||||||
|
|
||||||
MsgFromLocalpart: msgFrom.Localpart,
|
MsgFromLocalpart: msgFrom.Localpart,
|
||||||
MsgFromDomain: msgFrom.Domain.Name(),
|
MsgFromDomain: msgFrom.Domain.Name(),
|
||||||
MsgFromOrgDomain: publicsuffix.Lookup(context.Background(), msgFrom.Domain).Name(),
|
MsgFromOrgDomain: publicsuffix.Lookup(ctxbg, msgFrom.Domain).Name(),
|
||||||
|
|
||||||
MailFromValidated: mailfromValid,
|
MailFromValidated: mailfromValid,
|
||||||
EHLOValidated: ehloValid,
|
EHLOValidated: ehloValid,
|
||||||
|
@ -103,11 +102,11 @@ func TestReputation(t *testing.T) {
|
||||||
p := "../testdata/smtpserver-reputation.db"
|
p := "../testdata/smtpserver-reputation.db"
|
||||||
defer os.Remove(p)
|
defer os.Remove(p)
|
||||||
|
|
||||||
db, err := bstore.Open(p, &bstore.Options{Timeout: 5 * time.Second}, store.Message{}, store.Recipient{}, store.Mailbox{})
|
db, err := bstore.Open(ctxbg, p, &bstore.Options{Timeout: 5 * time.Second}, store.Message{}, store.Recipient{}, store.Mailbox{})
|
||||||
tcheck(t, err, "open db")
|
tcheck(t, err, "open db")
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
err = db.Write(func(tx *bstore.Tx) error {
|
err = db.Write(ctxbg, func(tx *bstore.Tx) error {
|
||||||
err = tx.Insert(&store.Mailbox{ID: 1, Name: "Inbox"})
|
err = tx.Insert(&store.Mailbox{ID: 1, Name: "Inbox"})
|
||||||
tcheck(t, err, "insert into db")
|
tcheck(t, err, "insert into db")
|
||||||
|
|
||||||
|
@ -117,7 +116,7 @@ func TestReputation(t *testing.T) {
|
||||||
|
|
||||||
rcptToDomain, err := dns.ParseDomain(hm.RcptToDomain)
|
rcptToDomain, err := dns.ParseDomain(hm.RcptToDomain)
|
||||||
tcheck(t, err, "parse rcptToDomain")
|
tcheck(t, err, "parse rcptToDomain")
|
||||||
rcptToOrgDomain := publicsuffix.Lookup(context.Background(), rcptToDomain)
|
rcptToOrgDomain := publicsuffix.Lookup(ctxbg, rcptToDomain)
|
||||||
r := store.Recipient{MessageID: hm.ID, Localpart: hm.RcptToLocalpart, Domain: hm.RcptToDomain, OrgDomain: rcptToOrgDomain.Name(), Sent: hm.Received}
|
r := store.Recipient{MessageID: hm.ID, Localpart: hm.RcptToLocalpart, Domain: hm.RcptToDomain, OrgDomain: rcptToOrgDomain.Name(), Sent: hm.Received}
|
||||||
err = tx.Insert(&r)
|
err = tx.Insert(&r)
|
||||||
tcheck(t, err, "insert recipient")
|
tcheck(t, err, "insert recipient")
|
||||||
|
@ -130,7 +129,7 @@ func TestReputation(t *testing.T) {
|
||||||
var isjunk *bool
|
var isjunk *bool
|
||||||
var conclusive bool
|
var conclusive bool
|
||||||
var method reputationMethod
|
var method reputationMethod
|
||||||
err = db.Read(func(tx *bstore.Tx) error {
|
err = db.Read(ctxbg, func(tx *bstore.Tx) error {
|
||||||
var err error
|
var err error
|
||||||
isjunk, conclusive, method, err = reputation(tx, xlog, &m)
|
isjunk, conclusive, method, err = reputation(tx, xlog, &m)
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -1028,7 +1028,7 @@ func (c *conn) cmdAuth(p *parser) {
|
||||||
}()
|
}()
|
||||||
var ipadhash, opadhash hash.Hash
|
var ipadhash, opadhash hash.Hash
|
||||||
acc.WithRLock(func() {
|
acc.WithRLock(func() {
|
||||||
err := acc.DB.Read(func(tx *bstore.Tx) error {
|
err := acc.DB.Read(context.TODO(), func(tx *bstore.Tx) error {
|
||||||
password, err := bstore.QueryTx[store.Password](tx).Get()
|
password, err := bstore.QueryTx[store.Password](tx).Get()
|
||||||
if err == bstore.ErrAbsent {
|
if err == bstore.ErrAbsent {
|
||||||
xsmtpUserErrorf(smtp.C535AuthBadCreds, smtp.SePol7AuthBadCreds8, "bad user/pass")
|
xsmtpUserErrorf(smtp.C535AuthBadCreds, smtp.SePol7AuthBadCreds8, "bad user/pass")
|
||||||
|
@ -1101,7 +1101,7 @@ func (c *conn) cmdAuth(p *parser) {
|
||||||
}
|
}
|
||||||
var xscram store.SCRAM
|
var xscram store.SCRAM
|
||||||
acc.WithRLock(func() {
|
acc.WithRLock(func() {
|
||||||
err := acc.DB.Read(func(tx *bstore.Tx) error {
|
err := acc.DB.Read(context.TODO(), func(tx *bstore.Tx) error {
|
||||||
password, err := bstore.QueryTx[store.Password](tx).Get()
|
password, err := bstore.QueryTx[store.Password](tx).Get()
|
||||||
if authVariant == "scram-sha-1" {
|
if authVariant == "scram-sha-1" {
|
||||||
xscram = password.SCRAMSHA1
|
xscram = password.SCRAMSHA1
|
||||||
|
@ -1672,7 +1672,7 @@ func (c *conn) submit(ctx context.Context, recvHdrFor func(string) string, msgWr
|
||||||
// Limit damage to the internet and our reputation in case of account compromise by
|
// Limit damage to the internet and our reputation in case of account compromise by
|
||||||
// limiting the max number of messages sent in a 24 hour window, both total number
|
// limiting the max number of messages sent in a 24 hour window, both total number
|
||||||
// of messages and number of first-time recipients.
|
// of messages and number of first-time recipients.
|
||||||
err = c.account.DB.Read(func(tx *bstore.Tx) error {
|
err = c.account.DB.Read(ctx, func(tx *bstore.Tx) error {
|
||||||
conf, _ := c.account.Conf()
|
conf, _ := c.account.Conf()
|
||||||
msgmax := conf.MaxOutgoingMessagesPerDay
|
msgmax := conf.MaxOutgoingMessagesPerDay
|
||||||
if msgmax == 0 {
|
if msgmax == 0 {
|
||||||
|
@ -1828,7 +1828,7 @@ func (c *conn) submit(ctx context.Context, recvHdrFor func(string) string, msgWr
|
||||||
metricSubmission.WithLabelValues("ok").Inc()
|
metricSubmission.WithLabelValues("ok").Inc()
|
||||||
c.log.Info("submitted message delivered", mlog.Field("mailfrom", *c.mailFrom), mlog.Field("rcptto", rcptAcc.rcptTo), mlog.Field("smtputf8", c.smtputf8), mlog.Field("msgsize", msgSize))
|
c.log.Info("submitted message delivered", mlog.Field("mailfrom", *c.mailFrom), mlog.Field("rcptto", rcptAcc.rcptTo), mlog.Field("smtputf8", c.smtputf8), mlog.Field("msgsize", msgSize))
|
||||||
|
|
||||||
err := c.account.DB.Insert(&store.Outgoing{Recipient: rcptAcc.rcptTo.XString(true)})
|
err := c.account.DB.Insert(ctx, &store.Outgoing{Recipient: rcptAcc.rcptTo.XString(true)})
|
||||||
xcheckf(err, "adding outgoing message")
|
xcheckf(err, "adding outgoing message")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -1852,7 +1852,7 @@ func (c *conn) submit(ctx context.Context, recvHdrFor func(string) string, msgWr
|
||||||
}
|
}
|
||||||
|
|
||||||
msgSize := int64(len(xmsgPrefix)) + msgWriter.Size
|
msgSize := int64(len(xmsgPrefix)) + msgWriter.Size
|
||||||
if err := queue.Add(c.log, c.account.Name, *c.mailFrom, rcptAcc.rcptTo, msgWriter.Has8bit, c.smtputf8, msgSize, xmsgPrefix, dataFile, nil, i == len(c.recipients)-1); err != nil {
|
if err := queue.Add(ctx, c.log, c.account.Name, *c.mailFrom, rcptAcc.rcptTo, msgWriter.Has8bit, c.smtputf8, msgSize, xmsgPrefix, dataFile, nil, i == len(c.recipients)-1); err != nil {
|
||||||
// Aborting the transaction is not great. But continuing and generating DSNs will
|
// Aborting the transaction is not great. But continuing and generating DSNs will
|
||||||
// probably result in errors as well...
|
// probably result in errors as well...
|
||||||
metricSubmission.WithLabelValues("queueerror").Inc()
|
metricSubmission.WithLabelValues("queueerror").Inc()
|
||||||
|
@ -1862,7 +1862,7 @@ func (c *conn) submit(ctx context.Context, recvHdrFor func(string) string, msgWr
|
||||||
metricSubmission.WithLabelValues("ok").Inc()
|
metricSubmission.WithLabelValues("ok").Inc()
|
||||||
c.log.Info("message queued for delivery", mlog.Field("mailfrom", *c.mailFrom), mlog.Field("rcptto", rcptAcc.rcptTo), mlog.Field("smtputf8", c.smtputf8), mlog.Field("msgsize", msgSize))
|
c.log.Info("message queued for delivery", mlog.Field("mailfrom", *c.mailFrom), mlog.Field("rcptto", rcptAcc.rcptTo), mlog.Field("smtputf8", c.smtputf8), mlog.Field("msgsize", msgSize))
|
||||||
|
|
||||||
err := c.account.DB.Insert(&store.Outgoing{Recipient: rcptAcc.rcptTo.XString(true)})
|
err := c.account.DB.Insert(ctx, &store.Outgoing{Recipient: rcptAcc.rcptTo.XString(true)})
|
||||||
xcheckf(err, "adding outgoing message")
|
xcheckf(err, "adding outgoing message")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2267,7 +2267,7 @@ func (c *conn) deliver(ctx context.Context, recvHdrFor func(string) string, msgW
|
||||||
// account. They may fill up the mailbox, either with messages that have to be
|
// account. They may fill up the mailbox, either with messages that have to be
|
||||||
// purged, or by filling the disk. We check both cases for IP's and networks.
|
// purged, or by filling the disk. We check both cases for IP's and networks.
|
||||||
var rateError bool // Whether returned error represents a rate error.
|
var rateError bool // Whether returned error represents a rate error.
|
||||||
err = acc.DB.Read(func(tx *bstore.Tx) (retErr error) {
|
err = acc.DB.Read(ctx, func(tx *bstore.Tx) (retErr error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
defer func() {
|
defer func() {
|
||||||
log.Debugx("checking message and size delivery rates", retErr, mlog.Field("duration", time.Since(now)))
|
log.Debugx("checking message and size delivery rates", retErr, mlog.Field("duration", time.Since(now)))
|
||||||
|
@ -2575,7 +2575,7 @@ func (c *conn) deliver(ctx context.Context, recvHdrFor func(string) string, msgW
|
||||||
|
|
||||||
if Localserve {
|
if Localserve {
|
||||||
c.log.Error("not queueing dsn for incoming delivery due to localserve")
|
c.log.Error("not queueing dsn for incoming delivery due to localserve")
|
||||||
} else if err := queueDSN(c, *c.mailFrom, dsnMsg); err != nil {
|
} else if err := queueDSN(context.TODO(), c, *c.mailFrom, dsnMsg); err != nil {
|
||||||
metricServerErrors.WithLabelValues("queuedsn").Inc()
|
metricServerErrors.WithLabelValues("queuedsn").Inc()
|
||||||
c.log.Errorx("queuing DSN for incoming delivery, no DSN sent", err)
|
c.log.Errorx("queuing DSN for incoming delivery, no DSN sent", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,6 +39,8 @@ import (
|
||||||
"github.com/mjl-/mox/tlsrptdb"
|
"github.com/mjl-/mox/tlsrptdb"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ctxbg = context.Background()
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
// Don't make tests slow.
|
// Don't make tests slow.
|
||||||
badClientDelay = 0
|
badClientDelay = 0
|
||||||
|
@ -88,7 +90,7 @@ func newTestServer(t *testing.T, configPath string, resolver dns.Resolver) *test
|
||||||
|
|
||||||
ts := testserver{t: t, cid: 1, resolver: resolver, tlsmode: smtpclient.TLSOpportunistic}
|
ts := testserver{t: t, cid: 1, resolver: resolver, tlsmode: smtpclient.TLSOpportunistic}
|
||||||
|
|
||||||
mox.Context = context.Background()
|
mox.Context = ctxbg
|
||||||
mox.ConfigStaticPath = configPath
|
mox.ConfigStaticPath = configPath
|
||||||
mox.MustLoadConfig(false)
|
mox.MustLoadConfig(false)
|
||||||
dataDir := mox.ConfigDirPath(mox.Conf.Static.DataDir)
|
dataDir := mox.ConfigDirPath(mox.Conf.Static.DataDir)
|
||||||
|
@ -138,7 +140,7 @@ func (ts *testserver) run(fn func(helloErr error, client *smtpclient.Client)) {
|
||||||
authLine = fmt.Sprintf("AUTH PLAIN %s", base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("\u0000%s\u0000%s", ts.user, ts.pass))))
|
authLine = fmt.Sprintf("AUTH PLAIN %s", base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("\u0000%s\u0000%s", ts.user, ts.pass))))
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := smtpclient.New(context.Background(), xlog.WithCid(ts.cid-1), clientConn, ts.tlsmode, "mox.example", authLine)
|
client, err := smtpclient.New(ctxbg, xlog.WithCid(ts.cid-1), clientConn, ts.tlsmode, "mox.example", authLine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
clientConn.Close()
|
clientConn.Close()
|
||||||
} else {
|
} else {
|
||||||
|
@ -199,7 +201,7 @@ func TestSubmission(t *testing.T) {
|
||||||
mailFrom := "mjl@mox.example"
|
mailFrom := "mjl@mox.example"
|
||||||
rcptTo := "remote@example.org"
|
rcptTo := "remote@example.org"
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = client.Deliver(context.Background(), mailFrom, rcptTo, int64(len(submitMessage)), strings.NewReader(submitMessage), false, false)
|
err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(submitMessage)), strings.NewReader(submitMessage), false, false)
|
||||||
}
|
}
|
||||||
var cerr smtpclient.Error
|
var cerr smtpclient.Error
|
||||||
if expErr == nil && err != nil || expErr != nil && (err == nil || !errors.As(err, &cerr) || cerr.Secode != expErr.Secode) {
|
if expErr == nil && err != nil || expErr != nil && (err == nil || !errors.As(err, &cerr) || cerr.Secode != expErr.Secode) {
|
||||||
|
@ -230,7 +232,7 @@ func TestDelivery(t *testing.T) {
|
||||||
mailFrom := "remote@example.org"
|
mailFrom := "remote@example.org"
|
||||||
rcptTo := "mjl@127.0.0.10"
|
rcptTo := "mjl@127.0.0.10"
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = client.Deliver(context.Background(), mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
||||||
}
|
}
|
||||||
var cerr smtpclient.Error
|
var cerr smtpclient.Error
|
||||||
if err == nil || !errors.As(err, &cerr) || cerr.Code != smtp.C550MailboxUnavail {
|
if err == nil || !errors.As(err, &cerr) || cerr.Code != smtp.C550MailboxUnavail {
|
||||||
|
@ -242,7 +244,7 @@ func TestDelivery(t *testing.T) {
|
||||||
mailFrom := "remote@example.org"
|
mailFrom := "remote@example.org"
|
||||||
rcptTo := "mjl@test.example" // Not configured as destination.
|
rcptTo := "mjl@test.example" // Not configured as destination.
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = client.Deliver(context.Background(), mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
||||||
}
|
}
|
||||||
var cerr smtpclient.Error
|
var cerr smtpclient.Error
|
||||||
if err == nil || !errors.As(err, &cerr) || cerr.Code != smtp.C550MailboxUnavail {
|
if err == nil || !errors.As(err, &cerr) || cerr.Code != smtp.C550MailboxUnavail {
|
||||||
|
@ -254,7 +256,7 @@ func TestDelivery(t *testing.T) {
|
||||||
mailFrom := "remote@example.org"
|
mailFrom := "remote@example.org"
|
||||||
rcptTo := "unknown@mox.example" // User unknown.
|
rcptTo := "unknown@mox.example" // User unknown.
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = client.Deliver(context.Background(), mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
||||||
}
|
}
|
||||||
var cerr smtpclient.Error
|
var cerr smtpclient.Error
|
||||||
if err == nil || !errors.As(err, &cerr) || cerr.Code != smtp.C550MailboxUnavail {
|
if err == nil || !errors.As(err, &cerr) || cerr.Code != smtp.C550MailboxUnavail {
|
||||||
|
@ -266,7 +268,7 @@ func TestDelivery(t *testing.T) {
|
||||||
mailFrom := "remote@example.org"
|
mailFrom := "remote@example.org"
|
||||||
rcptTo := "mjl@mox.example"
|
rcptTo := "mjl@mox.example"
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = client.Deliver(context.Background(), mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
||||||
}
|
}
|
||||||
var cerr smtpclient.Error
|
var cerr smtpclient.Error
|
||||||
if err == nil || !errors.As(err, &cerr) || cerr.Code != smtp.C451LocalErr {
|
if err == nil || !errors.As(err, &cerr) || cerr.Code != smtp.C451LocalErr {
|
||||||
|
@ -280,7 +282,7 @@ func TestDelivery(t *testing.T) {
|
||||||
mailFrom := "remote@example.org"
|
mailFrom := "remote@example.org"
|
||||||
rcptTo := "mjl@mox.example"
|
rcptTo := "mjl@mox.example"
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = client.Deliver(context.Background(), mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
||||||
}
|
}
|
||||||
tcheck(t, err, "deliver to remote")
|
tcheck(t, err, "deliver to remote")
|
||||||
|
|
||||||
|
@ -319,12 +321,12 @@ func tretrain(t *testing.T, acc *store.Account) {
|
||||||
bloomPath := filepath.Join(basePath, acc.Name, "junkfilter.bloom")
|
bloomPath := filepath.Join(basePath, acc.Name, "junkfilter.bloom")
|
||||||
os.Remove(dbPath)
|
os.Remove(dbPath)
|
||||||
os.Remove(bloomPath)
|
os.Remove(bloomPath)
|
||||||
jf, _, err := acc.OpenJunkFilter(xlog)
|
jf, _, err := acc.OpenJunkFilter(ctxbg, xlog)
|
||||||
tcheck(t, err, "open junk filter")
|
tcheck(t, err, "open junk filter")
|
||||||
defer jf.Close()
|
defer jf.Close()
|
||||||
|
|
||||||
// Fetch messags to retrain on.
|
// Fetch messags to retrain on.
|
||||||
q := bstore.QueryDB[store.Message](acc.DB)
|
q := bstore.QueryDB[store.Message](ctxbg, acc.DB)
|
||||||
q.FilterFn(func(m store.Message) bool {
|
q.FilterFn(func(m store.Message) bool {
|
||||||
return m.Flags.Junk || m.Flags.Notjunk
|
return m.Flags.Junk || m.Flags.Notjunk
|
||||||
})
|
})
|
||||||
|
@ -339,7 +341,7 @@ func tretrain(t *testing.T, acc *store.Account) {
|
||||||
tcheck(t, err, "open message")
|
tcheck(t, err, "open message")
|
||||||
r := store.FileMsgReader(m.MsgPrefix, f)
|
r := store.FileMsgReader(m.MsgPrefix, f)
|
||||||
|
|
||||||
jf.TrainMessage(r, m.Size, ham)
|
jf.TrainMessage(ctxbg, r, m.Size, ham)
|
||||||
|
|
||||||
err = r.Close()
|
err = r.Close()
|
||||||
tcheck(t, err, "close message")
|
tcheck(t, err, "close message")
|
||||||
|
@ -388,11 +390,11 @@ func TestSpam(t *testing.T) {
|
||||||
|
|
||||||
checkRejectsCount := func(expect int) {
|
checkRejectsCount := func(expect int) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
q := bstore.QueryDB[store.Mailbox](ts.acc.DB)
|
q := bstore.QueryDB[store.Mailbox](ctxbg, ts.acc.DB)
|
||||||
q.FilterNonzero(store.Mailbox{Name: "Rejects"})
|
q.FilterNonzero(store.Mailbox{Name: "Rejects"})
|
||||||
mb, err := q.Get()
|
mb, err := q.Get()
|
||||||
tcheck(t, err, "get rejects mailbox")
|
tcheck(t, err, "get rejects mailbox")
|
||||||
qm := bstore.QueryDB[store.Message](ts.acc.DB)
|
qm := bstore.QueryDB[store.Message](ctxbg, ts.acc.DB)
|
||||||
qm.FilterNonzero(store.Message{MailboxID: mb.ID})
|
qm.FilterNonzero(store.Message{MailboxID: mb.ID})
|
||||||
n, err := qm.Count()
|
n, err := qm.Count()
|
||||||
tcheck(t, err, "count messages in rejects mailbox")
|
tcheck(t, err, "count messages in rejects mailbox")
|
||||||
|
@ -406,7 +408,7 @@ func TestSpam(t *testing.T) {
|
||||||
mailFrom := "remote@example.org"
|
mailFrom := "remote@example.org"
|
||||||
rcptTo := "mjl@mox.example"
|
rcptTo := "mjl@mox.example"
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = client.Deliver(context.Background(), mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
||||||
}
|
}
|
||||||
var cerr smtpclient.Error
|
var cerr smtpclient.Error
|
||||||
if err == nil || !errors.As(err, &cerr) || cerr.Code != smtp.C451LocalErr {
|
if err == nil || !errors.As(err, &cerr) || cerr.Code != smtp.C451LocalErr {
|
||||||
|
@ -418,7 +420,7 @@ func TestSpam(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
// Mark the messages as having good reputation.
|
// Mark the messages as having good reputation.
|
||||||
q := bstore.QueryDB[store.Message](ts.acc.DB)
|
q := bstore.QueryDB[store.Message](ctxbg, ts.acc.DB)
|
||||||
_, err := q.UpdateFields(map[string]any{"Junk": false, "Notjunk": true})
|
_, err := q.UpdateFields(map[string]any{"Junk": false, "Notjunk": true})
|
||||||
tcheck(t, err, "update junkiness")
|
tcheck(t, err, "update junkiness")
|
||||||
|
|
||||||
|
@ -427,7 +429,7 @@ func TestSpam(t *testing.T) {
|
||||||
mailFrom := "remote@example.org"
|
mailFrom := "remote@example.org"
|
||||||
rcptTo := "mjl@mox.example"
|
rcptTo := "mjl@mox.example"
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = client.Deliver(context.Background(), mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
||||||
}
|
}
|
||||||
tcheck(t, err, "deliver")
|
tcheck(t, err, "deliver")
|
||||||
|
|
||||||
|
@ -437,7 +439,7 @@ func TestSpam(t *testing.T) {
|
||||||
|
|
||||||
// Undo dmarc pass, mark messages as junk, and train the filter.
|
// Undo dmarc pass, mark messages as junk, and train the filter.
|
||||||
resolver.TXT = nil
|
resolver.TXT = nil
|
||||||
q = bstore.QueryDB[store.Message](ts.acc.DB)
|
q = bstore.QueryDB[store.Message](ctxbg, ts.acc.DB)
|
||||||
_, err = q.UpdateFields(map[string]any{"Junk": true, "Notjunk": false})
|
_, err = q.UpdateFields(map[string]any{"Junk": true, "Notjunk": false})
|
||||||
tcheck(t, err, "update junkiness")
|
tcheck(t, err, "update junkiness")
|
||||||
tretrain(t, ts.acc)
|
tretrain(t, ts.acc)
|
||||||
|
@ -447,7 +449,7 @@ func TestSpam(t *testing.T) {
|
||||||
mailFrom := "remote@example.org"
|
mailFrom := "remote@example.org"
|
||||||
rcptTo := "mjl@mox.example"
|
rcptTo := "mjl@mox.example"
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = client.Deliver(context.Background(), mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
||||||
}
|
}
|
||||||
var cerr smtpclient.Error
|
var cerr smtpclient.Error
|
||||||
if err == nil || !errors.As(err, &cerr) || cerr.Code != smtp.C451LocalErr {
|
if err == nil || !errors.As(err, &cerr) || cerr.Code != smtp.C451LocalErr {
|
||||||
|
@ -488,7 +490,7 @@ func TestDMARCSent(t *testing.T) {
|
||||||
mailFrom := "remote@example.org"
|
mailFrom := "remote@example.org"
|
||||||
rcptTo := "mjl@mox.example"
|
rcptTo := "mjl@mox.example"
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = client.Deliver(context.Background(), mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
||||||
}
|
}
|
||||||
var cerr smtpclient.Error
|
var cerr smtpclient.Error
|
||||||
if err == nil || !errors.As(err, &cerr) || cerr.Code != smtp.C451LocalErr {
|
if err == nil || !errors.As(err, &cerr) || cerr.Code != smtp.C451LocalErr {
|
||||||
|
@ -499,7 +501,7 @@ func TestDMARCSent(t *testing.T) {
|
||||||
// Insert a message that we sent to the address that is about to send to us.
|
// Insert a message that we sent to the address that is about to send to us.
|
||||||
var sentMsg store.Message
|
var sentMsg store.Message
|
||||||
tinsertmsg(t, ts.acc, "Sent", &sentMsg, deliverMessage)
|
tinsertmsg(t, ts.acc, "Sent", &sentMsg, deliverMessage)
|
||||||
err := ts.acc.DB.Insert(&store.Recipient{MessageID: sentMsg.ID, Localpart: "remote", Domain: "example.org", OrgDomain: "example.org", Sent: time.Now()})
|
err := ts.acc.DB.Insert(ctxbg, &store.Recipient{MessageID: sentMsg.ID, Localpart: "remote", Domain: "example.org", OrgDomain: "example.org", Sent: time.Now()})
|
||||||
tcheck(t, err, "inserting message recipient")
|
tcheck(t, err, "inserting message recipient")
|
||||||
|
|
||||||
// We should now be accepting the message because we recently sent a message.
|
// We should now be accepting the message because we recently sent a message.
|
||||||
|
@ -507,7 +509,7 @@ func TestDMARCSent(t *testing.T) {
|
||||||
mailFrom := "remote@example.org"
|
mailFrom := "remote@example.org"
|
||||||
rcptTo := "mjl@mox.example"
|
rcptTo := "mjl@mox.example"
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = client.Deliver(context.Background(), mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
||||||
}
|
}
|
||||||
tcheck(t, err, "deliver")
|
tcheck(t, err, "deliver")
|
||||||
})
|
})
|
||||||
|
@ -540,7 +542,7 @@ func TestBlocklistedSubjectpass(t *testing.T) {
|
||||||
mailFrom := "remote@example.org"
|
mailFrom := "remote@example.org"
|
||||||
rcptTo := "mjl@mox.example"
|
rcptTo := "mjl@mox.example"
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = client.Deliver(context.Background(), mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
||||||
}
|
}
|
||||||
var cerr smtpclient.Error
|
var cerr smtpclient.Error
|
||||||
if err == nil || !errors.As(err, &cerr) || cerr.Code != smtp.C451LocalErr {
|
if err == nil || !errors.As(err, &cerr) || cerr.Code != smtp.C451LocalErr {
|
||||||
|
@ -559,7 +561,7 @@ func TestBlocklistedSubjectpass(t *testing.T) {
|
||||||
mailFrom := "remote@example.org"
|
mailFrom := "remote@example.org"
|
||||||
rcptTo := "mjl@mox.example"
|
rcptTo := "mjl@mox.example"
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = client.Deliver(context.Background(), mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
||||||
}
|
}
|
||||||
var cerr smtpclient.Error
|
var cerr smtpclient.Error
|
||||||
if err == nil || !errors.As(err, &cerr) || cerr.Code != smtp.C550MailboxUnavail {
|
if err == nil || !errors.As(err, &cerr) || cerr.Code != smtp.C550MailboxUnavail {
|
||||||
|
@ -577,7 +579,7 @@ func TestBlocklistedSubjectpass(t *testing.T) {
|
||||||
rcptTo := "mjl@mox.example"
|
rcptTo := "mjl@mox.example"
|
||||||
passMessage := strings.Replace(deliverMessage, "Subject: test", "Subject: test "+pass, 1)
|
passMessage := strings.Replace(deliverMessage, "Subject: test", "Subject: test "+pass, 1)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = client.Deliver(context.Background(), mailFrom, rcptTo, int64(len(passMessage)), strings.NewReader(passMessage), false, false)
|
err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(passMessage)), strings.NewReader(passMessage), false, false)
|
||||||
}
|
}
|
||||||
tcheck(t, err, "deliver with subjectpass")
|
tcheck(t, err, "deliver with subjectpass")
|
||||||
})
|
})
|
||||||
|
@ -619,11 +621,11 @@ func TestDMARCReport(t *testing.T) {
|
||||||
msg := msgb.String()
|
msg := msgb.String()
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = client.Deliver(context.Background(), mailFrom, rcptTo, int64(len(msg)), strings.NewReader(msg), false, false)
|
err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(msg)), strings.NewReader(msg), false, false)
|
||||||
}
|
}
|
||||||
tcheck(t, err, "deliver")
|
tcheck(t, err, "deliver")
|
||||||
|
|
||||||
records, err := dmarcdb.Records(context.Background())
|
records, err := dmarcdb.Records(ctxbg)
|
||||||
tcheck(t, err, "dmarcdb records")
|
tcheck(t, err, "dmarcdb records")
|
||||||
if len(records) != n {
|
if len(records) != n {
|
||||||
t.Fatalf("got %d dmarcdb records, expected %d or more", len(records), n)
|
t.Fatalf("got %d dmarcdb records, expected %d or more", len(records), n)
|
||||||
|
@ -736,16 +738,16 @@ func TestTLSReport(t *testing.T) {
|
||||||
tcheck(t, xerr, "write msg")
|
tcheck(t, xerr, "write msg")
|
||||||
msg := msgb.String()
|
msg := msgb.String()
|
||||||
|
|
||||||
headers, xerr := dkim.Sign(context.Background(), "remote", dns.Domain{ASCII: "example.org"}, dkimConf, false, strings.NewReader(msg))
|
headers, xerr := dkim.Sign(ctxbg, "remote", dns.Domain{ASCII: "example.org"}, dkimConf, false, strings.NewReader(msg))
|
||||||
tcheck(t, xerr, "dkim sign")
|
tcheck(t, xerr, "dkim sign")
|
||||||
msg = headers + msg
|
msg = headers + msg
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = client.Deliver(context.Background(), mailFrom, rcptTo, int64(len(msg)), strings.NewReader(msg), false, false)
|
err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(msg)), strings.NewReader(msg), false, false)
|
||||||
}
|
}
|
||||||
tcheck(t, err, "deliver")
|
tcheck(t, err, "deliver")
|
||||||
|
|
||||||
records, err := tlsrptdb.Records(context.Background())
|
records, err := tlsrptdb.Records(ctxbg)
|
||||||
tcheck(t, err, "tlsrptdb records")
|
tcheck(t, err, "tlsrptdb records")
|
||||||
if len(records) != n {
|
if len(records) != n {
|
||||||
t.Fatalf("got %d tlsrptdb records, expected %d", len(records), n)
|
t.Fatalf("got %d tlsrptdb records, expected %d", len(records), n)
|
||||||
|
@ -840,11 +842,11 @@ func TestRatelimitDelivery(t *testing.T) {
|
||||||
mailFrom := "remote@example.org"
|
mailFrom := "remote@example.org"
|
||||||
rcptTo := "mjl@mox.example"
|
rcptTo := "mjl@mox.example"
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = client.Deliver(context.Background(), mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
||||||
}
|
}
|
||||||
tcheck(t, err, "deliver to remote")
|
tcheck(t, err, "deliver to remote")
|
||||||
|
|
||||||
err = client.Deliver(context.Background(), mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
||||||
var cerr smtpclient.Error
|
var cerr smtpclient.Error
|
||||||
if err == nil || !errors.As(err, &cerr) || cerr.Code != smtp.C452StorageFull {
|
if err == nil || !errors.As(err, &cerr) || cerr.Code != smtp.C452StorageFull {
|
||||||
t.Fatalf("got err %v, expected smtpclient error with code 452 for storage full", err)
|
t.Fatalf("got err %v, expected smtpclient error with code 452 for storage full", err)
|
||||||
|
@ -857,7 +859,7 @@ func TestRatelimitDelivery(t *testing.T) {
|
||||||
// Message was already delivered once. We'll do another one. But the 3rd will fail.
|
// Message was already delivered once. We'll do another one. But the 3rd will fail.
|
||||||
// We need the actual size with prepended headers, since that is used in the
|
// We need the actual size with prepended headers, since that is used in the
|
||||||
// calculations.
|
// calculations.
|
||||||
msg, err := bstore.QueryDB[store.Message](ts.acc.DB).Get()
|
msg, err := bstore.QueryDB[store.Message](ctxbg, ts.acc.DB).Get()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("getting delivered message for its size: %v", err)
|
t.Fatalf("getting delivered message for its size: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -869,11 +871,11 @@ func TestRatelimitDelivery(t *testing.T) {
|
||||||
mailFrom := "remote@example.org"
|
mailFrom := "remote@example.org"
|
||||||
rcptTo := "mjl@mox.example"
|
rcptTo := "mjl@mox.example"
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = client.Deliver(context.Background(), mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
||||||
}
|
}
|
||||||
tcheck(t, err, "deliver to remote")
|
tcheck(t, err, "deliver to remote")
|
||||||
|
|
||||||
err = client.Deliver(context.Background(), mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
||||||
var cerr smtpclient.Error
|
var cerr smtpclient.Error
|
||||||
if err == nil || !errors.As(err, &cerr) || cerr.Code != smtp.C452StorageFull {
|
if err == nil || !errors.As(err, &cerr) || cerr.Code != smtp.C452StorageFull {
|
||||||
t.Fatalf("got err %v, expected smtpclient error with code 452 for storage full", err)
|
t.Fatalf("got err %v, expected smtpclient error with code 452 for storage full", err)
|
||||||
|
@ -933,7 +935,7 @@ func TestLimitOutgoing(t *testing.T) {
|
||||||
ts.pass = "testtest"
|
ts.pass = "testtest"
|
||||||
ts.submission = true
|
ts.submission = true
|
||||||
|
|
||||||
err := ts.acc.DB.Insert(&store.Outgoing{Recipient: "a@other.example", Submitted: time.Now().Add(-24*time.Hour - time.Minute)})
|
err := ts.acc.DB.Insert(ctxbg, &store.Outgoing{Recipient: "a@other.example", Submitted: time.Now().Add(-24*time.Hour - time.Minute)})
|
||||||
tcheck(t, err, "inserting outgoing/recipient past 24h window")
|
tcheck(t, err, "inserting outgoing/recipient past 24h window")
|
||||||
|
|
||||||
testSubmit := func(rcptTo string, expErr *smtpclient.Error) {
|
testSubmit := func(rcptTo string, expErr *smtpclient.Error) {
|
||||||
|
@ -942,7 +944,7 @@ func TestLimitOutgoing(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
mailFrom := "mjl@mox.example"
|
mailFrom := "mjl@mox.example"
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = client.Deliver(context.Background(), mailFrom, rcptTo, int64(len(submitMessage)), strings.NewReader(submitMessage), false, false)
|
err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(submitMessage)), strings.NewReader(submitMessage), false, false)
|
||||||
}
|
}
|
||||||
var cerr smtpclient.Error
|
var cerr smtpclient.Error
|
||||||
if expErr == nil && err != nil || expErr != nil && (err == nil || !errors.As(err, &cerr) || cerr.Secode != expErr.Secode) {
|
if expErr == nil && err != nil || expErr != nil && (err == nil || !errors.As(err, &cerr) || cerr.Secode != expErr.Secode) {
|
||||||
|
@ -979,7 +981,7 @@ func TestCatchall(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
mailFrom := "mjl@other.example"
|
mailFrom := "mjl@other.example"
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = client.Deliver(context.Background(), mailFrom, rcptTo, int64(len(submitMessage)), strings.NewReader(submitMessage), false, false)
|
err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(submitMessage)), strings.NewReader(submitMessage), false, false)
|
||||||
}
|
}
|
||||||
var cerr smtpclient.Error
|
var cerr smtpclient.Error
|
||||||
if expErr == nil && err != nil || expErr != nil && (err == nil || !errors.As(err, &cerr) || cerr.Secode != expErr.Secode) {
|
if expErr == nil && err != nil || expErr != nil && (err == nil || !errors.As(err, &cerr) || cerr.Secode != expErr.Secode) {
|
||||||
|
@ -993,14 +995,14 @@ func TestCatchall(t *testing.T) {
|
||||||
testDeliver("MJL+TEST@mox.example", nil) // Again, and case insensitive.
|
testDeliver("MJL+TEST@mox.example", nil) // Again, and case insensitive.
|
||||||
testDeliver("unknown@mox.example", nil) // Catchall address, to account catchall.
|
testDeliver("unknown@mox.example", nil) // Catchall address, to account catchall.
|
||||||
|
|
||||||
n, err := bstore.QueryDB[store.Message](ts.acc.DB).Count()
|
n, err := bstore.QueryDB[store.Message](ctxbg, ts.acc.DB).Count()
|
||||||
tcheck(t, err, "checking delivered messages")
|
tcheck(t, err, "checking delivered messages")
|
||||||
tcompare(t, n, 3)
|
tcompare(t, n, 3)
|
||||||
|
|
||||||
acc, err := store.OpenAccount("catchall")
|
acc, err := store.OpenAccount("catchall")
|
||||||
tcheck(t, err, "open account")
|
tcheck(t, err, "open account")
|
||||||
defer acc.Close()
|
defer acc.Close()
|
||||||
n, err = bstore.QueryDB[store.Message](acc.DB).Count()
|
n, err = bstore.QueryDB[store.Message](ctxbg, acc.DB).Count()
|
||||||
tcheck(t, err, "checking delivered messages to catchall account")
|
tcheck(t, err, "checking delivered messages to catchall account")
|
||||||
tcompare(t, n, 1)
|
tcompare(t, n, 1)
|
||||||
}
|
}
|
||||||
|
@ -1072,21 +1074,21 @@ test email
|
||||||
|
|
||||||
rcptTo := "remote@example.org"
|
rcptTo := "remote@example.org"
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = client.Deliver(context.Background(), mailFrom, rcptTo, int64(len(msg)), strings.NewReader(msg), false, false)
|
err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(msg)), strings.NewReader(msg), false, false)
|
||||||
}
|
}
|
||||||
tcheck(t, err, "deliver")
|
tcheck(t, err, "deliver")
|
||||||
|
|
||||||
msgs, err := queue.List()
|
msgs, err := queue.List(ctxbg)
|
||||||
tcheck(t, err, "listing queue")
|
tcheck(t, err, "listing queue")
|
||||||
n++
|
n++
|
||||||
tcompare(t, len(msgs), n)
|
tcompare(t, len(msgs), n)
|
||||||
sort.Slice(msgs, func(i, j int) bool {
|
sort.Slice(msgs, func(i, j int) bool {
|
||||||
return msgs[i].ID > msgs[j].ID
|
return msgs[i].ID > msgs[j].ID
|
||||||
})
|
})
|
||||||
f, err := queue.OpenMessage(msgs[0].ID)
|
f, err := queue.OpenMessage(ctxbg, msgs[0].ID)
|
||||||
tcheck(t, err, "open message in queue")
|
tcheck(t, err, "open message in queue")
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
results, err := dkim.Verify(context.Background(), resolver, false, dkim.DefaultPolicy, f, false)
|
results, err := dkim.Verify(ctxbg, resolver, false, dkim.DefaultPolicy, f, false)
|
||||||
tcheck(t, err, "verifying dkim message")
|
tcheck(t, err, "verifying dkim message")
|
||||||
tcompare(t, len(results), 1)
|
tcompare(t, len(results), 1)
|
||||||
tcompare(t, results[0].Status, dkim.StatusPass)
|
tcompare(t, results[0].Status, dkim.StatusPass)
|
||||||
|
@ -1117,7 +1119,7 @@ func TestPostmaster(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
mailFrom := "mjl@other.example"
|
mailFrom := "mjl@other.example"
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = client.Deliver(context.Background(), mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(deliverMessage)), strings.NewReader(deliverMessage), false, false)
|
||||||
}
|
}
|
||||||
var cerr smtpclient.Error
|
var cerr smtpclient.Error
|
||||||
if expErr == nil && err != nil || expErr != nil && (err == nil || !errors.As(err, &cerr) || cerr.Code != expErr.Code || cerr.Secode != expErr.Secode) {
|
if expErr == nil && err != nil || expErr != nil && (err == nil || !errors.As(err, &cerr) || cerr.Code != expErr.Code || cerr.Secode != expErr.Secode) {
|
||||||
|
|
|
@ -278,7 +278,7 @@ type Message struct {
|
||||||
MsgFromValidation Validation // Desirable validations: Strict, DMARC, Relaxed. Will not be just Pass.
|
MsgFromValidation Validation // Desirable validations: Strict, DMARC, Relaxed. Will not be just Pass.
|
||||||
|
|
||||||
// todo: needs an "in" index, which bstore does not yet support. for performance while checking reputation.
|
// todo: needs an "in" index, which bstore does not yet support. for performance while checking reputation.
|
||||||
DKIMDomains []string // Domains with verified DKIM signatures. Unicode string.
|
DKIMDomains []string `bstore:"index DKIMDomains+Received"` // Domains with verified DKIM signatures. Unicode string.
|
||||||
|
|
||||||
// Value of Message-Id header. Only set for messages that were
|
// Value of Message-Id header. Only set for messages that were
|
||||||
// delivered to the rejects mailbox. For ensuring such messages are
|
// delivered to the rejects mailbox. For ensuring such messages are
|
||||||
|
@ -455,7 +455,7 @@ func openAccount(name string) (a *Account, rerr error) {
|
||||||
os.MkdirAll(dir, 0770)
|
os.MkdirAll(dir, 0770)
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := bstore.Open(dbpath, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, NextUIDValidity{}, Message{}, Recipient{}, Mailbox{}, Subscription{}, Outgoing{}, Password{}, Subjectpass{})
|
db, err := bstore.Open(context.TODO(), dbpath, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, NextUIDValidity{}, Message{}, Recipient{}, Mailbox{}, Subscription{}, Outgoing{}, Password{}, Subjectpass{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -484,7 +484,7 @@ func openAccount(name string) (a *Account, rerr error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func initAccount(db *bstore.DB) error {
|
func initAccount(db *bstore.DB) error {
|
||||||
return db.Write(func(tx *bstore.Tx) error {
|
return db.Write(context.TODO(), func(tx *bstore.Tx) error {
|
||||||
uidvalidity := InitialUIDValidity()
|
uidvalidity := InitialUIDValidity()
|
||||||
|
|
||||||
mailboxes := InitialMailboxes
|
mailboxes := InitialMailboxes
|
||||||
|
@ -700,7 +700,7 @@ func (a *Account) DeliverMessage(log *mlog.Log, tx *bstore.Tx, m *Message, msgFi
|
||||||
|
|
||||||
if !notrain && m.NeedsTraining() {
|
if !notrain && m.NeedsTraining() {
|
||||||
l := []Message{*m}
|
l := []Message{*m}
|
||||||
if err := a.RetrainMessages(log, tx, l, false); err != nil {
|
if err := a.RetrainMessages(context.TODO(), log, tx, l, false); err != nil {
|
||||||
return fmt.Errorf("training junkfilter: %w", err)
|
return fmt.Errorf("training junkfilter: %w", err)
|
||||||
}
|
}
|
||||||
*m = l[0]
|
*m = l[0]
|
||||||
|
@ -739,7 +739,7 @@ func (a *Account) SetPassword(password string) error {
|
||||||
return fmt.Errorf("generating password hash: %w", err)
|
return fmt.Errorf("generating password hash: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = a.DB.Write(func(tx *bstore.Tx) error {
|
err = a.DB.Write(context.TODO(), func(tx *bstore.Tx) error {
|
||||||
if _, err := bstore.QueryTx[Password](tx).Delete(); err != nil {
|
if _, err := bstore.QueryTx[Password](tx).Delete(); err != nil {
|
||||||
return fmt.Errorf("deleting existing password: %v", err)
|
return fmt.Errorf("deleting existing password: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -793,7 +793,7 @@ func (a *Account) SetPassword(password string) error {
|
||||||
// Subjectpass returns the signing key for use with subjectpass for the given
|
// Subjectpass returns the signing key for use with subjectpass for the given
|
||||||
// email address with canonical localpart.
|
// email address with canonical localpart.
|
||||||
func (a *Account) Subjectpass(email string) (key string, err error) {
|
func (a *Account) Subjectpass(email string) (key string, err error) {
|
||||||
return key, a.DB.Write(func(tx *bstore.Tx) error {
|
return key, a.DB.Write(context.TODO(), func(tx *bstore.Tx) error {
|
||||||
v := Subjectpass{Email: email}
|
v := Subjectpass{Email: email}
|
||||||
err := tx.Get(&v)
|
err := tx.Get(&v)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -1036,7 +1036,7 @@ func (a *Account) Deliver(log *mlog.Log, dest config.Destination, m *Message, ms
|
||||||
// Message delivery and possible mailbox creation are broadcasted.
|
// Message delivery and possible mailbox creation are broadcasted.
|
||||||
func (a *Account) DeliverMailbox(log *mlog.Log, mailbox string, m *Message, msgFile *os.File, consumeFile bool) error {
|
func (a *Account) DeliverMailbox(log *mlog.Log, mailbox string, m *Message, msgFile *os.File, consumeFile bool) error {
|
||||||
var changes []Change
|
var changes []Change
|
||||||
err := a.DB.Write(func(tx *bstore.Tx) error {
|
err := a.DB.Write(context.TODO(), func(tx *bstore.Tx) error {
|
||||||
mb, chl, err := a.MailboxEnsure(tx, mailbox, true)
|
mb, chl, err := a.MailboxEnsure(tx, mailbox, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("ensuring mailbox: %w", err)
|
return fmt.Errorf("ensuring mailbox: %w", err)
|
||||||
|
@ -1075,7 +1075,7 @@ func (a *Account) TidyRejectsMailbox(log *mlog.Log, rejectsMailbox string) (hasS
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err := a.DB.Write(func(tx *bstore.Tx) error {
|
err := a.DB.Write(context.TODO(), func(tx *bstore.Tx) error {
|
||||||
mb, err := a.MailboxFind(tx, rejectsMailbox)
|
mb, err := a.MailboxFind(tx, rejectsMailbox)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("finding mailbox: %w", err)
|
return fmt.Errorf("finding mailbox: %w", err)
|
||||||
|
@ -1096,7 +1096,7 @@ func (a *Account) TidyRejectsMailbox(log *mlog.Log, rejectsMailbox string) (hasS
|
||||||
return fmt.Errorf("listing old messages: %w", err)
|
return fmt.Errorf("listing old messages: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
changes, err = a.removeMessages(log, tx, mb, remove)
|
changes, err = a.removeMessages(context.TODO(), log, tx, mb, remove)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("removing messages: %w", err)
|
return fmt.Errorf("removing messages: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -1125,7 +1125,7 @@ func (a *Account) TidyRejectsMailbox(log *mlog.Log, rejectsMailbox string) (hasS
|
||||||
return hasSpace, nil
|
return hasSpace, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Account) removeMessages(log *mlog.Log, tx *bstore.Tx, mb *Mailbox, l []Message) ([]Change, error) {
|
func (a *Account) removeMessages(ctx context.Context, log *mlog.Log, tx *bstore.Tx, mb *Mailbox, l []Message) ([]Change, error) {
|
||||||
if len(l) == 0 {
|
if len(l) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -1158,7 +1158,7 @@ func (a *Account) removeMessages(log *mlog.Log, tx *bstore.Tx, mb *Mailbox, l []
|
||||||
deleted[i].Junk = false
|
deleted[i].Junk = false
|
||||||
deleted[i].Notjunk = false
|
deleted[i].Notjunk = false
|
||||||
}
|
}
|
||||||
if err := a.RetrainMessages(log, tx, deleted, true); err != nil {
|
if err := a.RetrainMessages(ctx, log, tx, deleted, true); err != nil {
|
||||||
return nil, fmt.Errorf("training deleted messages: %w", err)
|
return nil, fmt.Errorf("training deleted messages: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1184,7 +1184,7 @@ func (a *Account) RejectsRemove(log *mlog.Log, rejectsMailbox, messageID string)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err := a.DB.Write(func(tx *bstore.Tx) error {
|
err := a.DB.Write(context.TODO(), func(tx *bstore.Tx) error {
|
||||||
mb, err := a.MailboxFind(tx, rejectsMailbox)
|
mb, err := a.MailboxFind(tx, rejectsMailbox)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("finding mailbox: %w", err)
|
return fmt.Errorf("finding mailbox: %w", err)
|
||||||
|
@ -1200,7 +1200,7 @@ func (a *Account) RejectsRemove(log *mlog.Log, rejectsMailbox, messageID string)
|
||||||
return fmt.Errorf("listing messages to remove: %w", err)
|
return fmt.Errorf("listing messages to remove: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
changes, err = a.removeMessages(log, tx, mb, remove)
|
changes, err = a.removeMessages(context.TODO(), log, tx, mb, remove)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("removing messages: %w", err)
|
return fmt.Errorf("removing messages: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -1262,7 +1262,7 @@ func OpenEmailAuth(email string, password string) (acc *Account, rerr error) {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
pw, err := bstore.QueryDB[Password](acc.DB).Get()
|
pw, err := bstore.QueryDB[Password](context.TODO(), acc.DB).Get()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == bstore.ErrAbsent {
|
if err == bstore.ErrAbsent {
|
||||||
return acc, ErrUnknownCredentials
|
return acc, ErrUnknownCredentials
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package store
|
package store
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -16,6 +17,8 @@ import (
|
||||||
"github.com/mjl-/mox/mox-"
|
"github.com/mjl-/mox/mox-"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ctxbg = context.Background()
|
||||||
|
|
||||||
func tcheck(t *testing.T, err error, msg string) {
|
func tcheck(t *testing.T, err error, msg string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -66,7 +69,7 @@ func TestMailbox(t *testing.T) {
|
||||||
err := acc.Deliver(xlog, conf.Destinations["mjl"], &m, msgFile, false)
|
err := acc.Deliver(xlog, conf.Destinations["mjl"], &m, msgFile, false)
|
||||||
tcheck(t, err, "deliver without consume")
|
tcheck(t, err, "deliver without consume")
|
||||||
|
|
||||||
err = acc.DB.Write(func(tx *bstore.Tx) error {
|
err = acc.DB.Write(ctxbg, func(tx *bstore.Tx) error {
|
||||||
var err error
|
var err error
|
||||||
mbsent, err = bstore.QueryTx[Mailbox](tx).FilterNonzero(Mailbox{Name: "Sent"}).Get()
|
mbsent, err = bstore.QueryTx[Mailbox](tx).FilterNonzero(Mailbox{Name: "Sent"}).Get()
|
||||||
tcheck(t, err, "sent mailbox")
|
tcheck(t, err, "sent mailbox")
|
||||||
|
@ -89,10 +92,10 @@ func TestMailbox(t *testing.T) {
|
||||||
err = acc.Deliver(xlog, conf.Destinations["mjl"], &mconsumed, msgFile, true)
|
err = acc.Deliver(xlog, conf.Destinations["mjl"], &mconsumed, msgFile, true)
|
||||||
tcheck(t, err, "deliver with consume")
|
tcheck(t, err, "deliver with consume")
|
||||||
|
|
||||||
err = acc.DB.Write(func(tx *bstore.Tx) error {
|
err = acc.DB.Write(ctxbg, func(tx *bstore.Tx) error {
|
||||||
m.Junk = true
|
m.Junk = true
|
||||||
l := []Message{m}
|
l := []Message{m}
|
||||||
err = acc.RetrainMessages(log, tx, l, false)
|
err = acc.RetrainMessages(ctxbg, log, tx, l, false)
|
||||||
tcheck(t, err, "train as junk")
|
tcheck(t, err, "train as junk")
|
||||||
m = l[0]
|
m = l[0]
|
||||||
return nil
|
return nil
|
||||||
|
@ -102,18 +105,18 @@ func TestMailbox(t *testing.T) {
|
||||||
|
|
||||||
m.Junk = false
|
m.Junk = false
|
||||||
m.Notjunk = true
|
m.Notjunk = true
|
||||||
jf, _, err := acc.OpenJunkFilter(log)
|
jf, _, err := acc.OpenJunkFilter(ctxbg, log)
|
||||||
tcheck(t, err, "open junk filter")
|
tcheck(t, err, "open junk filter")
|
||||||
err = acc.DB.Write(func(tx *bstore.Tx) error {
|
err = acc.DB.Write(ctxbg, func(tx *bstore.Tx) error {
|
||||||
return acc.RetrainMessage(log, tx, jf, &m, false)
|
return acc.RetrainMessage(ctxbg, log, tx, jf, &m, false)
|
||||||
})
|
})
|
||||||
tcheck(t, err, "retraining as non-junk")
|
tcheck(t, err, "retraining as non-junk")
|
||||||
err = jf.Close()
|
err = jf.Close()
|
||||||
tcheck(t, err, "close junk filter")
|
tcheck(t, err, "close junk filter")
|
||||||
|
|
||||||
m.Notjunk = false
|
m.Notjunk = false
|
||||||
err = acc.DB.Write(func(tx *bstore.Tx) error {
|
err = acc.DB.Write(ctxbg, func(tx *bstore.Tx) error {
|
||||||
return acc.RetrainMessages(log, tx, []Message{m}, false)
|
return acc.RetrainMessages(ctxbg, log, tx, []Message{m}, false)
|
||||||
})
|
})
|
||||||
tcheck(t, err, "untraining non-junk")
|
tcheck(t, err, "untraining non-junk")
|
||||||
|
|
||||||
|
@ -134,18 +137,18 @@ func TestMailbox(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
acc.WithWLock(func() {
|
acc.WithWLock(func() {
|
||||||
err := acc.DB.Write(func(tx *bstore.Tx) error {
|
err := acc.DB.Write(ctxbg, func(tx *bstore.Tx) error {
|
||||||
_, _, err := acc.MailboxEnsure(tx, "Testbox", true)
|
_, _, err := acc.MailboxEnsure(tx, "Testbox", true)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
tcheck(t, err, "ensure mailbox exists")
|
tcheck(t, err, "ensure mailbox exists")
|
||||||
err = acc.DB.Read(func(tx *bstore.Tx) error {
|
err = acc.DB.Read(ctxbg, func(tx *bstore.Tx) error {
|
||||||
_, _, err := acc.MailboxEnsure(tx, "Testbox", true)
|
_, _, err := acc.MailboxEnsure(tx, "Testbox", true)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
tcheck(t, err, "ensure mailbox exists")
|
tcheck(t, err, "ensure mailbox exists")
|
||||||
|
|
||||||
err = acc.DB.Write(func(tx *bstore.Tx) error {
|
err = acc.DB.Write(ctxbg, func(tx *bstore.Tx) error {
|
||||||
_, _, err := acc.MailboxEnsure(tx, "Testbox2", false)
|
_, _, err := acc.MailboxEnsure(tx, "Testbox2", false)
|
||||||
tcheck(t, err, "create mailbox")
|
tcheck(t, err, "create mailbox")
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"archive/zip"
|
"archive/zip"
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
@ -104,13 +105,13 @@ func (a DirArchiver) Close() error {
|
||||||
// Some errors are not fatal and result in skipped messages. In that happens, a
|
// Some errors are not fatal and result in skipped messages. In that happens, a
|
||||||
// file "errors.txt" is added to the archive describing the errors. The goal is to
|
// file "errors.txt" is added to the archive describing the errors. The goal is to
|
||||||
// let users export (hopefully) most messages even in the face of errors.
|
// let users export (hopefully) most messages even in the face of errors.
|
||||||
func ExportMessages(log *mlog.Log, db *bstore.DB, accountDir string, archiver Archiver, maildir bool, mailboxOpt string) error {
|
func ExportMessages(ctx context.Context, log *mlog.Log, db *bstore.DB, accountDir string, archiver Archiver, maildir bool, mailboxOpt string) error {
|
||||||
// todo optimize: should prepare next file to add to archive (can be an mbox with many messages) while writing a file to the archive (which typically compresses, which takes time).
|
// todo optimize: should prepare next file to add to archive (can be an mbox with many messages) while writing a file to the archive (which typically compresses, which takes time).
|
||||||
|
|
||||||
// Start transaction without closure, we are going to close it early, but don't
|
// Start transaction without closure, we are going to close it early, but don't
|
||||||
// want to deal with declaring many variables now to be able to assign them in a
|
// want to deal with declaring many variables now to be able to assign them in a
|
||||||
// closure and use them afterwards.
|
// closure and use them afterwards.
|
||||||
tx, err := db.Begin(false)
|
tx, err := db.Begin(ctx, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("transaction: %v", err)
|
return fmt.Errorf("transaction: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,7 +49,7 @@ func TestExport(t *testing.T) {
|
||||||
|
|
||||||
archive := func(archiver Archiver, maildir bool) {
|
archive := func(archiver Archiver, maildir bool) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
err = ExportMessages(log, acc.DB, acc.Dir, archiver, maildir, "")
|
err = ExportMessages(ctxbg, log, acc.DB, acc.Dir, archiver, maildir, "")
|
||||||
tcheck(t, err, "export messages")
|
tcheck(t, err, "export messages")
|
||||||
err = archiver.Close()
|
err = archiver.Close()
|
||||||
tcheck(t, err, "archiver close")
|
tcheck(t, err, "archiver close")
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package store
|
package store
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
@ -21,7 +22,7 @@ var ErrNoJunkFilter = errors.New("junkfilter: not configured")
|
||||||
// If the account does not have a junk filter enabled, ErrNotConfigured is returned.
|
// If the account does not have a junk filter enabled, ErrNotConfigured is returned.
|
||||||
// Do not forget to save the filter after modifying, and to always close the filter when done.
|
// Do not forget to save the filter after modifying, and to always close the filter when done.
|
||||||
// An empty filter is initialized on first access of the filter.
|
// An empty filter is initialized on first access of the filter.
|
||||||
func (a *Account) OpenJunkFilter(log *mlog.Log) (*junk.Filter, *config.JunkFilter, error) {
|
func (a *Account) OpenJunkFilter(ctx context.Context, log *mlog.Log) (*junk.Filter, *config.JunkFilter, error) {
|
||||||
conf, ok := mox.Conf.Account(a.Name)
|
conf, ok := mox.Conf.Account(a.Name)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, nil, ErrAccountUnknown
|
return nil, nil, ErrAccountUnknown
|
||||||
|
@ -36,16 +37,16 @@ func (a *Account) OpenJunkFilter(log *mlog.Log) (*junk.Filter, *config.JunkFilte
|
||||||
bloomPath := filepath.Join(basePath, a.Name, "junkfilter.bloom")
|
bloomPath := filepath.Join(basePath, a.Name, "junkfilter.bloom")
|
||||||
|
|
||||||
if _, xerr := os.Stat(dbPath); xerr != nil && os.IsNotExist(xerr) {
|
if _, xerr := os.Stat(dbPath); xerr != nil && os.IsNotExist(xerr) {
|
||||||
f, err := junk.NewFilter(log, jf.Params, dbPath, bloomPath)
|
f, err := junk.NewFilter(ctx, log, jf.Params, dbPath, bloomPath)
|
||||||
return f, jf, err
|
return f, jf, err
|
||||||
}
|
}
|
||||||
f, err := junk.OpenFilter(log, jf.Params, dbPath, bloomPath, false)
|
f, err := junk.OpenFilter(ctx, log, jf.Params, dbPath, bloomPath, false)
|
||||||
return f, jf, err
|
return f, jf, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// RetrainMessages (un)trains messages, if relevant given their flags. Updates
|
// RetrainMessages (un)trains messages, if relevant given their flags. Updates
|
||||||
// m.TrainedJunk after retraining.
|
// m.TrainedJunk after retraining.
|
||||||
func (a *Account) RetrainMessages(log *mlog.Log, tx *bstore.Tx, msgs []Message, absentOK bool) (rerr error) {
|
func (a *Account) RetrainMessages(ctx context.Context, log *mlog.Log, tx *bstore.Tx, msgs []Message, absentOK bool) (rerr error) {
|
||||||
if len(msgs) == 0 {
|
if len(msgs) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -60,7 +61,7 @@ func (a *Account) RetrainMessages(log *mlog.Log, tx *bstore.Tx, msgs []Message,
|
||||||
// Lazy open the junk filter.
|
// Lazy open the junk filter.
|
||||||
if jf == nil {
|
if jf == nil {
|
||||||
var err error
|
var err error
|
||||||
jf, _, err = a.OpenJunkFilter(log)
|
jf, _, err = a.OpenJunkFilter(ctx, log)
|
||||||
if err != nil && errors.Is(err, ErrNoJunkFilter) {
|
if err != nil && errors.Is(err, ErrNoJunkFilter) {
|
||||||
// No junk filter configured. Nothing more to do.
|
// No junk filter configured. Nothing more to do.
|
||||||
return nil
|
return nil
|
||||||
|
@ -76,7 +77,7 @@ func (a *Account) RetrainMessages(log *mlog.Log, tx *bstore.Tx, msgs []Message,
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
if err := a.RetrainMessage(log, tx, jf, &msgs[i], absentOK); err != nil {
|
if err := a.RetrainMessage(ctx, log, tx, jf, &msgs[i], absentOK); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -85,7 +86,7 @@ func (a *Account) RetrainMessages(log *mlog.Log, tx *bstore.Tx, msgs []Message,
|
||||||
|
|
||||||
// RetrainMessage untrains and/or trains a message, if relevant given m.TrainedJunk
|
// RetrainMessage untrains and/or trains a message, if relevant given m.TrainedJunk
|
||||||
// and m.Junk/m.Notjunk. Updates m.TrainedJunk after retraining.
|
// and m.Junk/m.Notjunk. Updates m.TrainedJunk after retraining.
|
||||||
func (a *Account) RetrainMessage(log *mlog.Log, tx *bstore.Tx, jf *junk.Filter, m *Message, absentOK bool) error {
|
func (a *Account) RetrainMessage(ctx context.Context, log *mlog.Log, tx *bstore.Tx, jf *junk.Filter, m *Message, absentOK bool) error {
|
||||||
untrain := m.TrainedJunk != nil
|
untrain := m.TrainedJunk != nil
|
||||||
untrainJunk := untrain && *m.TrainedJunk
|
untrainJunk := untrain && *m.TrainedJunk
|
||||||
train := m.Junk || m.Notjunk && !(m.Junk && m.Notjunk)
|
train := m.Junk || m.Notjunk && !(m.Junk && m.Notjunk)
|
||||||
|
@ -116,14 +117,14 @@ func (a *Account) RetrainMessage(log *mlog.Log, tx *bstore.Tx, jf *junk.Filter,
|
||||||
}
|
}
|
||||||
|
|
||||||
if untrain {
|
if untrain {
|
||||||
err := jf.Untrain(!untrainJunk, words)
|
err := jf.Untrain(ctx, !untrainJunk, words)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
m.TrainedJunk = nil
|
m.TrainedJunk = nil
|
||||||
}
|
}
|
||||||
if train {
|
if train {
|
||||||
err := jf.Train(!trainJunk, words)
|
err := jf.Train(ctx, !trainJunk, words)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -137,7 +138,7 @@ func (a *Account) RetrainMessage(log *mlog.Log, tx *bstore.Tx, jf *junk.Filter,
|
||||||
|
|
||||||
// TrainMessage trains the junk filter based on the current m.Junk/m.Notjunk flags,
|
// TrainMessage trains the junk filter based on the current m.Junk/m.Notjunk flags,
|
||||||
// disregarding m.TrainedJunk and not updating that field.
|
// disregarding m.TrainedJunk and not updating that field.
|
||||||
func (a *Account) TrainMessage(log *mlog.Log, jf *junk.Filter, m Message) (bool, error) {
|
func (a *Account) TrainMessage(ctx context.Context, log *mlog.Log, jf *junk.Filter, m Message) (bool, error) {
|
||||||
if !m.Junk && !m.Notjunk || (m.Junk && m.Notjunk) {
|
if !m.Junk && !m.Notjunk || (m.Junk && m.Notjunk) {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
@ -160,5 +161,5 @@ func (a *Account) TrainMessage(log *mlog.Log, jf *junk.Filter, m Message) (bool,
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return true, jf.Train(m.Notjunk, words)
|
return true, jf.Train(ctx, m.Notjunk, words)
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,13 +61,13 @@ type TLSReportRecord struct {
|
||||||
Report tlsrpt.Report
|
Report tlsrpt.Report
|
||||||
}
|
}
|
||||||
|
|
||||||
func database() (rdb *bstore.DB, rerr error) {
|
func database(ctx context.Context) (rdb *bstore.DB, rerr error) {
|
||||||
mutex.Lock()
|
mutex.Lock()
|
||||||
defer mutex.Unlock()
|
defer mutex.Unlock()
|
||||||
if tlsrptDB == nil {
|
if tlsrptDB == nil {
|
||||||
p := mox.DataDirPath("tlsrpt.db")
|
p := mox.DataDirPath("tlsrpt.db")
|
||||||
os.MkdirAll(filepath.Dir(p), 0770)
|
os.MkdirAll(filepath.Dir(p), 0770)
|
||||||
db, err := bstore.Open(p, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, TLSReportRecord{})
|
db, err := bstore.Open(ctx, p, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, TLSReportRecord{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -78,7 +78,7 @@ func database() (rdb *bstore.DB, rerr error) {
|
||||||
|
|
||||||
// Init opens and possibly initializes the database.
|
// Init opens and possibly initializes the database.
|
||||||
func Init() error {
|
func Init() error {
|
||||||
_, err := database()
|
_, err := database(mox.Shutdown)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -106,7 +106,7 @@ func Close() {
|
||||||
func AddReport(ctx context.Context, verifiedFromDomain dns.Domain, mailFrom string, r *tlsrpt.Report) error {
|
func AddReport(ctx context.Context, verifiedFromDomain dns.Domain, mailFrom string, r *tlsrpt.Report) error {
|
||||||
log := xlog.WithContext(ctx)
|
log := xlog.WithContext(ctx)
|
||||||
|
|
||||||
db, err := database()
|
db, err := database(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -149,39 +149,39 @@ func AddReport(ctx context.Context, verifiedFromDomain dns.Domain, mailFrom stri
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
record.Domain = reportdom.Name()
|
record.Domain = reportdom.Name()
|
||||||
return db.Insert(&record)
|
return db.Insert(ctx, &record)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Records returns all TLS reports in the database.
|
// Records returns all TLS reports in the database.
|
||||||
func Records(ctx context.Context) ([]TLSReportRecord, error) {
|
func Records(ctx context.Context) ([]TLSReportRecord, error) {
|
||||||
db, err := database()
|
db, err := database(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return bstore.QueryDB[TLSReportRecord](db).List()
|
return bstore.QueryDB[TLSReportRecord](ctx, db).List()
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecordID returns the report for the ID.
|
// RecordID returns the report for the ID.
|
||||||
func RecordID(ctx context.Context, id int64) (TLSReportRecord, error) {
|
func RecordID(ctx context.Context, id int64) (TLSReportRecord, error) {
|
||||||
db, err := database()
|
db, err := database(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return TLSReportRecord{}, err
|
return TLSReportRecord{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
e := TLSReportRecord{ID: id}
|
e := TLSReportRecord{ID: id}
|
||||||
err = db.Get(&e)
|
err = db.Get(ctx, &e)
|
||||||
return e, err
|
return e, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecordsPeriodDomain returns the reports overlapping start and end, for the given
|
// RecordsPeriodDomain returns the reports overlapping start and end, for the given
|
||||||
// domain. If domain is empty, all records match for domain.
|
// domain. If domain is empty, all records match for domain.
|
||||||
func RecordsPeriodDomain(ctx context.Context, start, end time.Time, domain string) ([]TLSReportRecord, error) {
|
func RecordsPeriodDomain(ctx context.Context, start, end time.Time, domain string) ([]TLSReportRecord, error) {
|
||||||
db, err := database()
|
db, err := database(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
q := bstore.QueryDB[TLSReportRecord](db)
|
q := bstore.QueryDB[TLSReportRecord](ctx, db)
|
||||||
if domain != "" {
|
if domain != "" {
|
||||||
q.FilterNonzero(TLSReportRecord{Domain: domain})
|
q.FilterNonzero(TLSReportRecord{Domain: domain})
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,8 @@ import (
|
||||||
"github.com/mjl-/mox/tlsrpt"
|
"github.com/mjl-/mox/tlsrpt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ctxbg = context.Background()
|
||||||
|
|
||||||
const reportJSON = `{
|
const reportJSON = `{
|
||||||
"organization-name": "Company-X",
|
"organization-name": "Company-X",
|
||||||
"date-range": {
|
"date-range": {
|
||||||
|
@ -59,6 +61,8 @@ const reportJSON = `{
|
||||||
}`
|
}`
|
||||||
|
|
||||||
func TestReport(t *testing.T) {
|
func TestReport(t *testing.T) {
|
||||||
|
mox.Context = ctxbg
|
||||||
|
mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg)
|
||||||
mox.ConfigStaticPath = "../testdata/tlsrpt/fake.conf"
|
mox.ConfigStaticPath = "../testdata/tlsrpt/fake.conf"
|
||||||
mox.Conf.Static.DataDir = "."
|
mox.Conf.Static.DataDir = "."
|
||||||
// Recognize as configured domain.
|
// Recognize as configured domain.
|
||||||
|
@ -89,7 +93,7 @@ func TestReport(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("parsing TLSRPT from message %q: %s", file.Name(), err)
|
t.Fatalf("parsing TLSRPT from message %q: %s", file.Name(), err)
|
||||||
}
|
}
|
||||||
if err := AddReport(context.Background(), dns.Domain{ASCII: "mox.example"}, "tlsrpt@mox.example", report); err != nil {
|
if err := AddReport(ctxbg, dns.Domain{ASCII: "mox.example"}, "tlsrpt@mox.example", report); err != nil {
|
||||||
t.Fatalf("adding report to database: %s", err)
|
t.Fatalf("adding report to database: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -97,11 +101,11 @@ func TestReport(t *testing.T) {
|
||||||
report, err := tlsrpt.Parse(strings.NewReader(reportJSON))
|
report, err := tlsrpt.Parse(strings.NewReader(reportJSON))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("parsing report: %v", err)
|
t.Fatalf("parsing report: %v", err)
|
||||||
} else if err := AddReport(context.Background(), dns.Domain{ASCII: "company-y.example"}, "tlsrpt@company-y.example", report); err != nil {
|
} else if err := AddReport(ctxbg, dns.Domain{ASCII: "company-y.example"}, "tlsrpt@company-y.example", report); err != nil {
|
||||||
t.Fatalf("adding report to database: %s", err)
|
t.Fatalf("adding report to database: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
records, err := Records(context.Background())
|
records, err := Records(ctxbg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("fetching records: %s", err)
|
t.Fatalf("fetching records: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -112,14 +116,14 @@ func TestReport(t *testing.T) {
|
||||||
if !reflect.DeepEqual(&r.Report, report) {
|
if !reflect.DeepEqual(&r.Report, report) {
|
||||||
t.Fatalf("report, got %#v, expected %#v", r.Report, report)
|
t.Fatalf("report, got %#v, expected %#v", r.Report, report)
|
||||||
}
|
}
|
||||||
if _, err := RecordID(context.Background(), r.ID); err != nil {
|
if _, err := RecordID(ctxbg, r.ID); err != nil {
|
||||||
t.Fatalf("get record by id: %v", err)
|
t.Fatalf("get record by id: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
start, _ := time.Parse(time.RFC3339, "2016-04-01T00:00:00Z")
|
start, _ := time.Parse(time.RFC3339, "2016-04-01T00:00:00Z")
|
||||||
end, _ := time.Parse(time.RFC3339, "2016-04-01T23:59:59Z")
|
end, _ := time.Parse(time.RFC3339, "2016-04-01T23:59:59Z")
|
||||||
records, err = RecordsPeriodDomain(context.Background(), start, end, "test.xmox.nl")
|
records, err = RecordsPeriodDomain(ctxbg, start, end, "test.xmox.nl")
|
||||||
if err != nil || len(records) != 1 {
|
if err != nil || len(records) != 1 {
|
||||||
t.Fatalf("got err %v, records %#v, expected no error with 1 record", err, records)
|
t.Fatalf("got err %v, records %#v, expected no error with 1 record", err, records)
|
||||||
}
|
}
|
||||||
|
|
3
vendor/github.com/mjl-/bstore/.gitignore
generated
vendored
3
vendor/github.com/mjl-/bstore/.gitignore
generated
vendored
|
@ -1,3 +1,4 @@
|
||||||
/cover.out
|
/cover.out
|
||||||
/cover.html
|
/cover.html
|
||||||
/testdata/*.db
|
/testdata/tmp.*.db
|
||||||
|
/testdata/mail.db
|
||||||
|
|
2
vendor/github.com/mjl-/bstore/Makefile
generated
vendored
2
vendor/github.com/mjl-/bstore/Makefile
generated
vendored
|
@ -2,7 +2,7 @@ build:
|
||||||
go build ./...
|
go build ./...
|
||||||
go vet ./...
|
go vet ./...
|
||||||
GOARCH=386 go vet ./...
|
GOARCH=386 go vet ./...
|
||||||
staticcheck ./...
|
staticcheck -checks 'all,-ST1012' ./...
|
||||||
./gendoc.sh
|
./gendoc.sh
|
||||||
|
|
||||||
fmt:
|
fmt:
|
||||||
|
|
56
vendor/github.com/mjl-/bstore/README.md
generated
vendored
56
vendor/github.com/mjl-/bstore/README.md
generated
vendored
|
@ -1,22 +1,42 @@
|
||||||
bstore is a database library for storing and quering Go struct data.
|
Bstore is a database library for storing and quering Go values.
|
||||||
|
|
||||||
See https://pkg.go.dev/github.com/mjl-/bstore
|
Bstore is designed as a small, pure Go library that still provides most of
|
||||||
|
the common data consistency requirements for modest database use cases. Bstore
|
||||||
|
aims to make basic use of cgo-based libraries, such as sqlite, unnecessary.
|
||||||
|
|
||||||
|
See https://pkg.go.dev/github.com/mjl-/bstore for features, examples and full
|
||||||
|
documentation.
|
||||||
|
|
||||||
MIT-licensed
|
MIT-licensed
|
||||||
|
|
||||||
# Comparison
|
|
||||||
|
|
||||||
Bstore is designed as a small, pure Go library that still provides most of the
|
# FAQ - Frequently Asked Questions
|
||||||
common data consistency requirements for modest database use cases. Bstore aims
|
|
||||||
to make basic use of cgo-based libraries, such as sqlite, unnecessary. Sqlite
|
## Is bstore an ORM?
|
||||||
is a great library, but Go applications that require cgo are hard to
|
|
||||||
|
No. The API for bstore may look like an ORM. But instead of mapping bstore
|
||||||
|
"queries" (function calls) to an SQL query string, bstore executes them
|
||||||
|
directly without converting to a query language, storing the data itself.
|
||||||
|
|
||||||
|
## How does bstore store its data?
|
||||||
|
|
||||||
|
A bstore database is a single-file BoltDB database. BoltDB provides ACID
|
||||||
|
properties. Bstore uses a BoltDB "bucket" (key/value store) for each Go type
|
||||||
|
stored, with multiple subbuckets: one for type definitions, one for the actual
|
||||||
|
data, and one bucket per index. BoltDB stores data in a B+tree. See format.md
|
||||||
|
for details.
|
||||||
|
|
||||||
|
## How does bstore compare to sqlite?
|
||||||
|
|
||||||
|
Sqlite is a great library, but Go applications that require cgo are hard to
|
||||||
cross-compile. With bstore, cross-compiling to most Go-supported platforms
|
cross-compile. With bstore, cross-compiling to most Go-supported platforms
|
||||||
stays trivial. Although bstore is much more limited in so many aspects than
|
stays trivial (though not plan9, unfortunately). Although bstore is much more
|
||||||
sqlite, bstore also offers some advantages as well.
|
limited in so many aspects than sqlite, bstore also offers some advantages as
|
||||||
|
well. Some points of comparison:
|
||||||
|
|
||||||
- Cross-compilation and reproducibility: Trivial with bstore due to pure Go,
|
- Cross-compilation and reproducibility: Trivial with bstore due to pure Go,
|
||||||
much harder with sqlite because of cgo.
|
much harder with sqlite because of cgo.
|
||||||
- Code complexity: low with bstore (6k lines including comments/docs), high
|
- Code complexity: low with bstore (7k lines including comments/docs), high
|
||||||
with sqlite.
|
with sqlite.
|
||||||
- Query language: mostly-type-checked function calls in bstore, free-form query
|
- Query language: mostly-type-checked function calls in bstore, free-form query
|
||||||
strings only checked at runtime with sqlite.
|
strings only checked at runtime with sqlite.
|
||||||
|
@ -33,19 +53,3 @@ sqlite, bstore also offers some advantages as well.
|
||||||
WAL or journal files).
|
WAL or journal files).
|
||||||
- Test coverage: decent coverage but limited real-world for bstore, versus
|
- Test coverage: decent coverage but limited real-world for bstore, versus
|
||||||
extremely thoroughly tested and with enormous real-world use.
|
extremely thoroughly tested and with enormous real-world use.
|
||||||
|
|
||||||
# FAQ
|
|
||||||
|
|
||||||
Q: Is bstore an ORM?
|
|
||||||
|
|
||||||
A: No. The API for bstore may look like an ORM. But instead of mapping bstore
|
|
||||||
"queries" (function calls) to an SQL query string, bstore executes them
|
|
||||||
directly without converting to a query language.
|
|
||||||
|
|
||||||
Q: How does bstore store its data?
|
|
||||||
|
|
||||||
A bstore database is a single-file BoltDB database. BoltDB provides ACID
|
|
||||||
properties. Bstore uses a BoltDB "bucket" (key/value store) for each Go type
|
|
||||||
stored, with multiple subbuckets: one for type definitions, one for the actual
|
|
||||||
data, and one bucket per index. BoltDB stores data in a B+tree. See format.md
|
|
||||||
for details.
|
|
||||||
|
|
17
vendor/github.com/mjl-/bstore/default.go
generated
vendored
17
vendor/github.com/mjl-/bstore/default.go
generated
vendored
|
@ -24,7 +24,7 @@ func (f field) applyDefault(rv reflect.Value) error {
|
||||||
case kindBytes, kindBinaryMarshal, kindMap:
|
case kindBytes, kindBinaryMarshal, kindMap:
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
case kindSlice, kindStruct:
|
case kindSlice, kindStruct, kindArray:
|
||||||
return f.Type.applyDefault(rv)
|
return f.Type.applyDefault(rv)
|
||||||
|
|
||||||
case kindBool, kindInt, kindInt8, kindInt16, kindInt32, kindInt64, kindUint, kindUint8, kindUint16, kindUint32, kindUint64, kindFloat32, kindFloat64, kindString, kindTime:
|
case kindBool, kindInt, kindInt8, kindInt16, kindInt32, kindInt64, kindUint, kindUint8, kindUint16, kindUint32, kindUint64, kindFloat32, kindFloat64, kindString, kindTime:
|
||||||
|
@ -53,9 +53,9 @@ func (f field) applyDefault(rv reflect.Value) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// only for recursing. we do not support recursing into maps because it would
|
// only for recursing. we do not support recursing into maps because it would
|
||||||
// involve more work making values settable. and how sensible it it anyway?
|
// involve more work making values settable. and how sensible is it anyway?
|
||||||
func (ft fieldType) applyDefault(rv reflect.Value) error {
|
func (ft fieldType) applyDefault(rv reflect.Value) error {
|
||||||
if ft.Ptr && (rv.IsZero() || rv.IsNil()) {
|
if ft.Ptr && rv.IsZero() {
|
||||||
return nil
|
return nil
|
||||||
} else if ft.Ptr {
|
} else if ft.Ptr {
|
||||||
rv = rv.Elem()
|
rv = rv.Elem()
|
||||||
|
@ -64,12 +64,19 @@ func (ft fieldType) applyDefault(rv reflect.Value) error {
|
||||||
case kindSlice:
|
case kindSlice:
|
||||||
n := rv.Len()
|
n := rv.Len()
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
if err := ft.List.applyDefault(rv.Index(i)); err != nil {
|
if err := ft.ListElem.applyDefault(rv.Index(i)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case kindArray:
|
||||||
|
n := ft.ArrayLength
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
if err := ft.ListElem.applyDefault(rv.Index(i)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case kindStruct:
|
case kindStruct:
|
||||||
for _, nf := range ft.Fields {
|
for _, nf := range ft.structFields {
|
||||||
nfv := rv.FieldByIndex(nf.structField.Index)
|
nfv := rv.FieldByIndex(nf.structField.Index)
|
||||||
if err := nf.applyDefault(nfv); err != nil {
|
if err := nf.applyDefault(nfv); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
172
vendor/github.com/mjl-/bstore/doc.go
generated
vendored
172
vendor/github.com/mjl-/bstore/doc.go
generated
vendored
|
@ -1,5 +1,5 @@
|
||||||
/*
|
/*
|
||||||
Package bstore is a database library for storing and querying Go struct data.
|
Package bstore is a database library for storing and querying Go values.
|
||||||
|
|
||||||
Bstore is designed as a small, pure Go library that still provides most of
|
Bstore is designed as a small, pure Go library that still provides most of
|
||||||
the common data consistency requirements for modest database use cases. Bstore
|
the common data consistency requirements for modest database use cases. Bstore
|
||||||
|
@ -9,7 +9,7 @@ Bstore implements autoincrementing primary keys, indices, default values,
|
||||||
enforcement of nonzero, unique and referential integrity constraints, automatic
|
enforcement of nonzero, unique and referential integrity constraints, automatic
|
||||||
schema updates and a query API for combining filters/sorting/limits. Queries
|
schema updates and a query API for combining filters/sorting/limits. Queries
|
||||||
are planned and executed using indices for fast execution where possible.
|
are planned and executed using indices for fast execution where possible.
|
||||||
Bstores is designed with the Go type system in mind: you typically don't have to
|
Bstore is designed with the Go type system in mind: you typically don't have to
|
||||||
write any (un)marshal code for your types.
|
write any (un)marshal code for your types.
|
||||||
|
|
||||||
# Field types
|
# Field types
|
||||||
|
@ -21,7 +21,7 @@ types, but not pointers to pointers:
|
||||||
- uint (as uint32), uint8, uint16, uint32, uint64
|
- uint (as uint32), uint8, uint16, uint32, uint64
|
||||||
- bool, float32, float64, string, []byte
|
- bool, float32, float64, string, []byte
|
||||||
- Maps, with keys and values of any supported type, except keys with pointer types.
|
- Maps, with keys and values of any supported type, except keys with pointer types.
|
||||||
- Slices, with elements of any supported type.
|
- Slices and arrays, with elements of any supported type.
|
||||||
- time.Time
|
- time.Time
|
||||||
- Types that implement binary.MarshalBinary and binary.UnmarshalBinary, useful
|
- Types that implement binary.MarshalBinary and binary.UnmarshalBinary, useful
|
||||||
for struct types with state in private fields. Do not change the
|
for struct types with state in private fields. Do not change the
|
||||||
|
@ -32,24 +32,27 @@ Note: int and uint are stored as int32 and uint32, for compatibility of database
|
||||||
files between 32bit and 64bit systems. Where possible, use explicit (u)int32 or
|
files between 32bit and 64bit systems. Where possible, use explicit (u)int32 or
|
||||||
(u)int64 types.
|
(u)int64 types.
|
||||||
|
|
||||||
Embedded structs are handled by storing the individual fields of the embedded
|
Cyclic types are supported, but cyclic data is not. Attempting to store cyclic
|
||||||
struct. The named embedded type is not part of the type schema, and can
|
data will likely result in a stack overflow panic.
|
||||||
currently only be used with UpdateField and UpdateFields, not for filtering.
|
|
||||||
|
Anonymous struct fields are handled by taking in each of the anonymous struct's
|
||||||
|
fields as a type's own fields. The named embedded type is not part of the type
|
||||||
|
schema, and with a Query it can currently only be used with UpdateField and
|
||||||
|
UpdateFields, not for filtering.
|
||||||
|
|
||||||
Bstore embraces the use of Go zero values. Use zero values, possibly pointers,
|
Bstore embraces the use of Go zero values. Use zero values, possibly pointers,
|
||||||
where you would use NULL values in SQL.
|
where you would use NULL values in SQL.
|
||||||
|
|
||||||
Types that have not yet been implemented: interface values, (fixed length) arrays,
|
|
||||||
complex numbers.
|
|
||||||
|
|
||||||
# Struct tags
|
# Struct tags
|
||||||
|
|
||||||
The typical Go struct can be stored in the database. The first field of a
|
The typical Go struct can be stored in the database. The first field of a
|
||||||
struct type is its primary key, and must always be unique. Additional behaviour
|
struct type is its primary key, must always be unique, and in case of an
|
||||||
can be configured through struct tag "bstore". The values are comma-separated.
|
integer type the insertion of a zero value automatically changes it to the next
|
||||||
Typically one word, but some have multiple space-separated words:
|
sequence number by default. Additional behaviour can be configured through
|
||||||
|
struct tag "bstore". The values are comma-separated. Typically one word, but
|
||||||
|
some have multiple space-separated words:
|
||||||
|
|
||||||
- "-" ignores the field entirely.
|
- "-" ignores the field entirely, not stored.
|
||||||
- "name <fieldname>", use "fieldname" instead of the Go type field name.
|
- "name <fieldname>", use "fieldname" instead of the Go type field name.
|
||||||
- "nonzero", enforces that field values are not the zero value.
|
- "nonzero", enforces that field values are not the zero value.
|
||||||
- "noauto", only valid for integer types, and only for the primary key. By
|
- "noauto", only valid for integer types, and only for the primary key. By
|
||||||
|
@ -57,16 +60,19 @@ Typically one word, but some have multiple space-separated words:
|
||||||
assigned on insert when it is 0. With noauto inserting a 0 value results in an
|
assigned on insert when it is 0. With noauto inserting a 0 value results in an
|
||||||
error. For primary keys of other types inserting the zero value always results
|
error. For primary keys of other types inserting the zero value always results
|
||||||
in an error.
|
in an error.
|
||||||
- "index" or "index <field1+field2+...> [<name>]", adds an index. In the first
|
- "index" or "index <field1>+<field2>+<...> [<name>]", adds an index. In the
|
||||||
form, the index is on the field on which the tag is specified, and the index
|
first form, the index is on the field on which the tag is specified, and the
|
||||||
name is the same as the field name. In the second form multiple fields can be
|
index name is the same as the field name. In the second form multiple fields can
|
||||||
specified, and an optional name. The first field must be the field on which
|
be specified, and an optional name. The first field must be the field on which
|
||||||
the tag is specified. The field names are +-separated. The default name for
|
the tag is specified. The field names are +-separated. The default name for the
|
||||||
the second form is the same +-separated string but can be set explicitly to
|
second form is the same +-separated string but can be set explicitly with the
|
||||||
the second parameter. An index can only be set for basic integer types, bools,
|
second parameter. An index can only be set for basic integer types, bools, time
|
||||||
time and strings. Indices are automatically (re)created when registering a
|
and strings. A field of slice type can also have an index (but not a unique
|
||||||
type.
|
index, and only one slice field per index), allowing fast lookup of any single
|
||||||
- "unique" or "unique <field1+field2+...> [<name>]", adds an index as with
|
value in the slice with Query.FilterIn. Indices are automatically (re)created
|
||||||
|
when registering a type. Fields with a pointer type cannot have an index.
|
||||||
|
String values used in an index cannot contain a \0.
|
||||||
|
- "unique" or "unique <field1>+<field2>+<...> [<name>]", adds an index as with
|
||||||
"index" and also enforces a unique constraint. For time.Time the timezone is
|
"index" and also enforces a unique constraint. For time.Time the timezone is
|
||||||
ignored for the uniqueness check.
|
ignored for the uniqueness check.
|
||||||
- "ref <type>", enforces that the value exists as primary key for "type".
|
- "ref <type>", enforces that the value exists as primary key for "type".
|
||||||
|
@ -80,8 +86,8 @@ Typically one word, but some have multiple space-separated words:
|
||||||
Times are parsed as time.RFC3339 otherwise. Supported types: bool
|
Times are parsed as time.RFC3339 otherwise. Supported types: bool
|
||||||
("true"/"false"), integers, floats, strings. Value is not quoted and no escaping
|
("true"/"false"), integers, floats, strings. Value is not quoted and no escaping
|
||||||
of special characters, like the comma that separates struct tag words, is
|
of special characters, like the comma that separates struct tag words, is
|
||||||
possible. Defaults are also replaced on fields in nested structs and
|
possible. Defaults are also replaced on fields in nested structs, slices
|
||||||
slices, but not in maps.
|
and arrays, but not in maps.
|
||||||
- "typename <name>", override name of the type. The name of the Go type is
|
- "typename <name>", override name of the type. The name of the Go type is
|
||||||
used by default. Can only be present on the first field (primary key).
|
used by default. Can only be present on the first field (primary key).
|
||||||
Useful for doing schema updates.
|
Useful for doing schema updates.
|
||||||
|
@ -89,18 +95,14 @@ Typically one word, but some have multiple space-separated words:
|
||||||
# Schema updates
|
# Schema updates
|
||||||
|
|
||||||
Before using a Go type, you must register it for use with the open database by
|
Before using a Go type, you must register it for use with the open database by
|
||||||
passing a (zero) value of that type to the Open or Register functions. For each
|
passing a (possibly zero) value of that type to the Open or Register functions.
|
||||||
type, a type definition is stored in the database. If a type has an updated
|
For each type, a type definition is stored in the database. If a type has an
|
||||||
definition since the previous database open, a new type definition is added to
|
updated definition since the previous database open, a new type definition is
|
||||||
the database automatically and any required modifications are made: Indexes
|
added to the database automatically and any required modifications are made and
|
||||||
(re)created, fields added/removed, new nonzero/unique/reference constraints
|
checked: Indexes (re)created, fields added/removed, new
|
||||||
validated.
|
nonzero/unique/reference constraints validated.
|
||||||
|
|
||||||
If data/types cannot be updated automatically (e.g. converting an int field into
|
As a special case, you can change field types between pointer and non-pointer
|
||||||
a string field), custom data migration code is needed. You may have to keep
|
|
||||||
track of a data/schema version.
|
|
||||||
|
|
||||||
As a special case, you can switch field types between pointer and non-pointer
|
|
||||||
types. With one exception: changing from pointer to non-pointer where the type
|
types. With one exception: changing from pointer to non-pointer where the type
|
||||||
has a field that must be nonzero is not allowed. The on-disk encoding will not be
|
has a field that must be nonzero is not allowed. The on-disk encoding will not be
|
||||||
changed, and nil pointers will turn into zero values, and zero values into nil
|
changed, and nil pointers will turn into zero values, and zero values into nil
|
||||||
|
@ -110,33 +112,95 @@ Because named embed structs are not part of the type definition, you can
|
||||||
wrap/unwrap fields into a embed/anonymous struct field. No new type definition
|
wrap/unwrap fields into a embed/anonymous struct field. No new type definition
|
||||||
is created.
|
is created.
|
||||||
|
|
||||||
# BoltDB
|
Some schema conversions are not allowed. In some cases due to architectural
|
||||||
|
limitations. In some cases because the constraint checks haven't been
|
||||||
|
implemented yet, or the parsing code does not yet know how to parse the old
|
||||||
|
on-disk values into the updated Go types. If you need a conversion that is not
|
||||||
|
supported, you will need to write a manual conversion, and you would have to
|
||||||
|
keep track whether the update has been executed.
|
||||||
|
|
||||||
BoltDB is used as underlying storage. Bolt provides ACID transactions, storing
|
Changes that are allowed:
|
||||||
its data in a B+tree. Only a single write transaction can be active at a time,
|
|
||||||
but otherwise multiple read-only transactions can be active. Do not start a
|
|
||||||
blocking read-only transaction while holding a writable transaction or vice
|
|
||||||
versa, this will cause deadlock.
|
|
||||||
|
|
||||||
Bolt uses Go types that are memory mapped to the database file. This means bolt
|
- From smaller to larger integer types (same signedness).
|
||||||
database files cannot be transferred between machines with different endianness.
|
- Removal of "noauto" on primary keys (always integer types). This updates the
|
||||||
Bolt uses explicit widths for its types, so files can be transferred between
|
"next sequence" counter automatically to continue after the current maximum
|
||||||
32bit and 64bit machines of same endianness.
|
value.
|
||||||
|
- Adding/removing/modifying an index, including a unique index. When a unique
|
||||||
|
index is added, the current records are verified to be unique.
|
||||||
|
- Adding/removing a reference. When a reference is added, the current records
|
||||||
|
are verified to be valid references.
|
||||||
|
- Add/remove a nonzero constraint. Existing records are verified.
|
||||||
|
|
||||||
|
Conversions that are not currently allowed, but may be in the future:
|
||||||
|
|
||||||
|
- Signedness of integer types. With a one-time check that old values fit in the new
|
||||||
|
type, this could be allowed in the future.
|
||||||
|
- Conversions between basic types: strings, []byte, integers, floats, boolean.
|
||||||
|
Checks would have to be added for some of these conversions. For example,
|
||||||
|
from string to integer: the on-disk string values would have to be valid
|
||||||
|
integers.
|
||||||
|
- Types of primary keys cannot be changed, also not from one integer type to a
|
||||||
|
wider integer type of same signedness.
|
||||||
|
|
||||||
|
# BoltDB and storage
|
||||||
|
|
||||||
|
BoltDB is used as underlying storage. BoltDB stores key/values in a single
|
||||||
|
file, in multiple/nested buckets (namespaces) in a B+tree and provides ACID
|
||||||
|
transactions. Either a single write transaction or multiple read-only
|
||||||
|
transactions can be active at a time. Do not start a blocking read-only
|
||||||
|
transaction while holding a writable transaction or vice versa, this will cause
|
||||||
|
deadlock.
|
||||||
|
|
||||||
|
BoltDB returns Go values that are memory mapped to the database file. This
|
||||||
|
means BoltDB/bstore database files cannot be transferred between machines with
|
||||||
|
different endianness. BoltDB uses explicit widths for its types, so files can
|
||||||
|
be transferred between 32bit and 64bit machines of same endianness. While
|
||||||
|
BoltDB returns read-only memory mapped Go values, bstore only ever returns
|
||||||
|
parsed/copied regular writable Go values that require no special programmer
|
||||||
|
attention.
|
||||||
|
|
||||||
|
For each Go type opened for a database file, bstore ensures a BoltDB bucket
|
||||||
|
exists with two subbuckets:
|
||||||
|
|
||||||
|
- "types", with type descriptions of the stored records. Each time the database
|
||||||
|
file is opened with a modified Go type (add/removed/modified
|
||||||
|
field/type/bstore struct tag), a new type description is automatically added,
|
||||||
|
identified by sequence number.
|
||||||
|
- "records", containing all data, with the type's primary key as BoltDB key,
|
||||||
|
and the encoded remaining fields as value. The encoding starts with a
|
||||||
|
reference to a type description.
|
||||||
|
|
||||||
|
For each index, another subbucket is created, its name starting with "index.".
|
||||||
|
The stored keys consist of the index fields followed by the primary key, and an
|
||||||
|
empty value.
|
||||||
|
|
||||||
# Limitations
|
# Limitations
|
||||||
|
|
||||||
|
Bstore has limitations, not all of which are architectural so may be fixed in
|
||||||
|
the future.
|
||||||
|
|
||||||
Bstore does not implement the equivalent of SQL joins, aggregates, and many
|
Bstore does not implement the equivalent of SQL joins, aggregates, and many
|
||||||
other concepts.
|
other concepts.
|
||||||
|
|
||||||
Filtering/comparing/sorting on pointer fields is not currently allowed. Pointer
|
Filtering/comparing/sorting on pointer fields is not allowed. Pointer fields
|
||||||
fields cannot have a (unique) index due to the current index format. Using zero
|
cannot have a (unique) index. Use non-pointer values with the zero value as the
|
||||||
values is recommended instead for now.
|
equivalent of a nil pointer.
|
||||||
|
|
||||||
Integer field types can be expanded to wider types, but not to a different
|
The first field of a stored struct is always the primary key. Autoincrement is
|
||||||
signedness or a smaller integer (fewer bits). The primary key of a type cannot
|
only available for the primary key.
|
||||||
currently be changed.
|
|
||||||
|
|
||||||
The first field of a struct is always the primary key. Types requires an
|
BoltDB opens the database file with a lock. Only one process can have the
|
||||||
explicit primary key. Autoincrement is only available for the primary key.
|
database open at a time.
|
||||||
|
|
||||||
|
An index stored on disk in BoltDB can consume more disk space than other
|
||||||
|
database systems would: For each record, the indexed field(s) and primary key
|
||||||
|
are stored in full. Because bstore uses BoltDB as key/value store, and doesn't
|
||||||
|
manage disk pages itself, it cannot as efficiently pack an index page with many
|
||||||
|
records.
|
||||||
|
|
||||||
|
Interface values cannot be stored. This would require storing the type along
|
||||||
|
with the value. Instead, use a type that is a BinaryMarshaler.
|
||||||
|
|
||||||
|
Values of builtin type "complex" cannot be stored.
|
||||||
*/
|
*/
|
||||||
package bstore
|
package bstore
|
||||||
|
|
12
vendor/github.com/mjl-/bstore/equal.go
generated
vendored
12
vendor/github.com/mjl-/bstore/equal.go
generated
vendored
|
@ -63,7 +63,15 @@ func (ft fieldType) equal(ov, v reflect.Value) (r bool) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
if !ft.List.equal(ov.Index(i), v.Index(i)) {
|
if !ft.ListElem.equal(ov.Index(i), v.Index(i)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
case kindArray:
|
||||||
|
n := ft.ArrayLength
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
if !ft.ListElem.equal(ov.Index(i), v.Index(i)) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -78,7 +86,7 @@ func (ft fieldType) equal(ov, v reflect.Value) (r bool) {
|
||||||
}
|
}
|
||||||
return bytes.Equal(obuf, buf)
|
return bytes.Equal(obuf, buf)
|
||||||
case kindStruct:
|
case kindStruct:
|
||||||
for _, f := range ft.Fields {
|
for _, f := range ft.structFields {
|
||||||
fov := ov.FieldByIndex(f.structField.Index)
|
fov := ov.FieldByIndex(f.structField.Index)
|
||||||
fv := v.FieldByIndex(f.structField.Index)
|
fv := v.FieldByIndex(f.structField.Index)
|
||||||
if !f.Type.equal(fov, fv) {
|
if !f.Type.equal(fov, fv) {
|
||||||
|
|
36
vendor/github.com/mjl-/bstore/exec.go
generated
vendored
36
vendor/github.com/mjl-/bstore/exec.go
generated
vendored
|
@ -10,6 +10,8 @@ import (
|
||||||
bolt "go.etcd.io/bbolt"
|
bolt "go.etcd.io/bbolt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// todo optimize: do not fetch full record if we can apply the filters with just the values we glean from the index key.
|
||||||
|
|
||||||
// exec represents the execution of a query plan.
|
// exec represents the execution of a query plan.
|
||||||
type exec[T any] struct {
|
type exec[T any] struct {
|
||||||
q *Query[T]
|
q *Query[T]
|
||||||
|
@ -94,6 +96,13 @@ func (e *exec[T]) nextKey(write, value bool) ([]byte, T, error) {
|
||||||
|
|
||||||
q := e.q
|
q := e.q
|
||||||
|
|
||||||
|
if q.err == nil {
|
||||||
|
select {
|
||||||
|
case <-q.ctxDone:
|
||||||
|
q.error(q.ctx.Err())
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
if q.err != nil {
|
if q.err != nil {
|
||||||
return nil, zero, q.err
|
return nil, zero, q.err
|
||||||
}
|
}
|
||||||
|
@ -424,6 +433,25 @@ func (e *exec[T]) checkFilter(p *pair[T]) (rok bool, rerr error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case filterInSlice[T]:
|
||||||
|
v, err := p.Value(e)
|
||||||
|
if err != nil {
|
||||||
|
q.error(err)
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
rv := reflect.ValueOf(v)
|
||||||
|
frv := rv.FieldByIndex(f.field.structField.Index)
|
||||||
|
n := frv.Len()
|
||||||
|
var have bool
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
if f.field.Type.ListElem.equal(frv.Index(i), f.rvalue) {
|
||||||
|
have = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !have {
|
||||||
|
return
|
||||||
|
}
|
||||||
case filterCompare[T]:
|
case filterCompare[T]:
|
||||||
v, err := p.Value(e)
|
v, err := p.Value(e)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -531,10 +559,10 @@ func compare(k kind, a, b reflect.Value) int {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *exec[T]) sort() {
|
func (e *exec[T]) sort() {
|
||||||
// todo: We should check whether we actually need to load values. We're just
|
// todo: We should check whether we actually need to load values. We're
|
||||||
// always it now for the time being because SortStableFunc isn't going to
|
// always loading it for the time being because SortStableFunc isn't
|
||||||
// give us a *pair (even though it could because of the slice) so we
|
// going to give us a *pair (even though it could because of the slice)
|
||||||
// couldn't set/cache the value T during sorting.
|
// so we couldn't set/cache the value T during sorting.
|
||||||
q := e.q
|
q := e.q
|
||||||
|
|
||||||
for i := range e.data {
|
for i := range e.data {
|
||||||
|
|
314
vendor/github.com/mjl-/bstore/export.go
generated
vendored
314
vendor/github.com/mjl-/bstore/export.go
generated
vendored
|
@ -13,15 +13,17 @@ import (
|
||||||
// Types returns the types present in the database, regardless of whether they
|
// Types returns the types present in the database, regardless of whether they
|
||||||
// are currently registered using Open or Register. Useful for exporting data
|
// are currently registered using Open or Register. Useful for exporting data
|
||||||
// with Keys and Records.
|
// with Keys and Records.
|
||||||
func (db *DB) Types() ([]string, error) {
|
func (tx *Tx) Types() ([]string, error) {
|
||||||
var types []string
|
if err := tx.ctx.Err(); err != nil {
|
||||||
err := db.Read(func(tx *Tx) error {
|
return nil, err
|
||||||
return tx.btx.ForEach(func(bname []byte, b *bolt.Bucket) error {
|
}
|
||||||
// note: we do not track stats for types operations.
|
|
||||||
|
|
||||||
types = append(types, string(bname))
|
var types []string
|
||||||
return nil
|
err := tx.btx.ForEach(func(bname []byte, b *bolt.Bucket) error {
|
||||||
})
|
// note: we do not track stats for types operations.
|
||||||
|
|
||||||
|
types = append(types, string(bname))
|
||||||
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -31,9 +33,12 @@ func (db *DB) Types() ([]string, error) {
|
||||||
|
|
||||||
// prepareType prepares typeName for export/introspection with DB.Keys,
|
// prepareType prepares typeName for export/introspection with DB.Keys,
|
||||||
// DB.Record, DB.Records. It is different in that it does not require a
|
// DB.Record, DB.Records. It is different in that it does not require a
|
||||||
// reflect.Type to parse into. It parses to a map, e.g. for export to JSON. The
|
// reflect.Type to parse into. It parses to a map, e.g. for export to JSON.
|
||||||
// returned typeVersion has no structFields set in its fields.
|
|
||||||
func (db *DB) prepareType(tx *Tx, typeName string) (map[uint32]*typeVersion, *typeVersion, *bolt.Bucket, []string, error) {
|
func (db *DB) prepareType(tx *Tx, typeName string) (map[uint32]*typeVersion, *typeVersion, *bolt.Bucket, []string, error) {
|
||||||
|
if err := tx.ctx.Err(); err != nil {
|
||||||
|
return nil, nil, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
rb, err := tx.recordsBucket(typeName, 0.5)
|
rb, err := tx.recordsBucket(typeName, 0.5)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, nil, err
|
return nil, nil, nil, nil, err
|
||||||
|
@ -51,6 +56,7 @@ func (db *DB) prepareType(tx *Tx, typeName string) (map[uint32]*typeVersion, *ty
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
versions[ntv.Version] = ntv
|
versions[ntv.Version] = ntv
|
||||||
if tv == nil || ntv.Version > tv.Version {
|
if tv == nil || ntv.Version > tv.Version {
|
||||||
tv = ntv
|
tv = ntv
|
||||||
|
@ -74,23 +80,28 @@ func (db *DB) prepareType(tx *Tx, typeName string) (map[uint32]*typeVersion, *ty
|
||||||
// Keys returns the parsed primary keys for the type "typeName". The type does
|
// Keys returns the parsed primary keys for the type "typeName". The type does
|
||||||
// not have to be registered with Open or Register. For use with Record(s) to
|
// not have to be registered with Open or Register. For use with Record(s) to
|
||||||
// export data.
|
// export data.
|
||||||
func (db *DB) Keys(typeName string, fn func(pk any) error) error {
|
func (tx *Tx) Keys(typeName string, fn func(pk any) error) error {
|
||||||
return db.Read(func(tx *Tx) error {
|
_, tv, rb, _, err := tx.db.prepareType(tx, typeName)
|
||||||
_, tv, rb, _, err := db.prepareType(tx, typeName)
|
if err != nil {
|
||||||
if err != nil {
|
return err
|
||||||
return err
|
}
|
||||||
|
|
||||||
|
ctxDone := tx.ctx.Done()
|
||||||
|
|
||||||
|
v := reflect.New(reflect.TypeOf(tv.Fields[0].Type.zeroKey())).Elem()
|
||||||
|
return rb.ForEach(func(bk, bv []byte) error {
|
||||||
|
tx.stats.Records.Cursor++
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctxDone:
|
||||||
|
return tx.ctx.Err()
|
||||||
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo: do not pass nil parser?
|
if err := parsePK(v, bk); err != nil {
|
||||||
v := reflect.New(reflect.TypeOf(tv.Fields[0].Type.zero(nil))).Elem()
|
return err
|
||||||
return rb.ForEach(func(bk, bv []byte) error {
|
}
|
||||||
tx.stats.Records.Cursor++
|
return fn(v.Interface())
|
||||||
|
|
||||||
if err := parsePK(v, bk); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return fn(v.Interface())
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,108 +109,109 @@ func (db *DB) Keys(typeName string, fn func(pk any) error) error {
|
||||||
// "Fields" is set to the fields of the type. The type does not have to be
|
// "Fields" is set to the fields of the type. The type does not have to be
|
||||||
// registered with Open or Register. Record parses the data without the Go
|
// registered with Open or Register. Record parses the data without the Go
|
||||||
// type present. BinaryMarshal fields are returned as bytes.
|
// type present. BinaryMarshal fields are returned as bytes.
|
||||||
func (db *DB) Record(typeName, key string, fields *[]string) (map[string]any, error) {
|
func (tx *Tx) Record(typeName, key string, fields *[]string) (map[string]any, error) {
|
||||||
var r map[string]any
|
versions, tv, rb, xfields, err := tx.db.prepareType(tx, typeName)
|
||||||
err := db.Read(func(tx *Tx) error {
|
if err != nil {
|
||||||
versions, tv, rb, xfields, err := db.prepareType(tx, typeName)
|
return nil, err
|
||||||
if err != nil {
|
}
|
||||||
return err
|
*fields = xfields
|
||||||
}
|
|
||||||
*fields = xfields
|
|
||||||
|
|
||||||
var kv any
|
var kv any
|
||||||
switch tv.Fields[0].Type.Kind {
|
switch tv.Fields[0].Type.Kind {
|
||||||
case kindBool:
|
case kindBool:
|
||||||
switch key {
|
switch key {
|
||||||
case "true":
|
case "true":
|
||||||
kv = true
|
kv = true
|
||||||
case "false":
|
case "false":
|
||||||
kv = false
|
kv = false
|
||||||
default:
|
|
||||||
err = fmt.Errorf("%w: invalid bool %q", ErrParam, key)
|
|
||||||
}
|
|
||||||
case kindInt8:
|
|
||||||
kv, err = strconv.ParseInt(key, 10, 8)
|
|
||||||
case kindInt16:
|
|
||||||
kv, err = strconv.ParseInt(key, 10, 16)
|
|
||||||
case kindInt32:
|
|
||||||
kv, err = strconv.ParseInt(key, 10, 32)
|
|
||||||
case kindInt:
|
|
||||||
kv, err = strconv.ParseInt(key, 10, 32)
|
|
||||||
case kindInt64:
|
|
||||||
kv, err = strconv.ParseInt(key, 10, 64)
|
|
||||||
case kindUint8:
|
|
||||||
kv, err = strconv.ParseUint(key, 10, 8)
|
|
||||||
case kindUint16:
|
|
||||||
kv, err = strconv.ParseUint(key, 10, 16)
|
|
||||||
case kindUint32:
|
|
||||||
kv, err = strconv.ParseUint(key, 10, 32)
|
|
||||||
case kindUint:
|
|
||||||
kv, err = strconv.ParseUint(key, 10, 32)
|
|
||||||
case kindUint64:
|
|
||||||
kv, err = strconv.ParseUint(key, 10, 64)
|
|
||||||
case kindString:
|
|
||||||
kv = key
|
|
||||||
case kindBytes:
|
|
||||||
kv = []byte(key) // todo: or decode from base64?
|
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("internal error: unknown primary key kind %v", tv.Fields[0].Type.Kind)
|
err = fmt.Errorf("%w: invalid bool %q", ErrParam, key)
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
pkv := reflect.ValueOf(kv)
|
|
||||||
kind, err := typeKind(pkv.Type())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if kind != tv.Fields[0].Type.Kind {
|
|
||||||
// Convert from various int types above to required type. The ParseInt/ParseUint
|
|
||||||
// calls already validated that the values fit.
|
|
||||||
pkt := reflect.TypeOf(tv.Fields[0].Type.zero(nil))
|
|
||||||
pkv = pkv.Convert(pkt)
|
|
||||||
}
|
|
||||||
k, err := packPK(pkv)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
case kindInt8:
|
||||||
|
kv, err = strconv.ParseInt(key, 10, 8)
|
||||||
|
case kindInt16:
|
||||||
|
kv, err = strconv.ParseInt(key, 10, 16)
|
||||||
|
case kindInt32:
|
||||||
|
kv, err = strconv.ParseInt(key, 10, 32)
|
||||||
|
case kindInt:
|
||||||
|
kv, err = strconv.ParseInt(key, 10, 32)
|
||||||
|
case kindInt64:
|
||||||
|
kv, err = strconv.ParseInt(key, 10, 64)
|
||||||
|
case kindUint8:
|
||||||
|
kv, err = strconv.ParseUint(key, 10, 8)
|
||||||
|
case kindUint16:
|
||||||
|
kv, err = strconv.ParseUint(key, 10, 16)
|
||||||
|
case kindUint32:
|
||||||
|
kv, err = strconv.ParseUint(key, 10, 32)
|
||||||
|
case kindUint:
|
||||||
|
kv, err = strconv.ParseUint(key, 10, 32)
|
||||||
|
case kindUint64:
|
||||||
|
kv, err = strconv.ParseUint(key, 10, 64)
|
||||||
|
case kindString:
|
||||||
|
kv = key
|
||||||
|
case kindBytes:
|
||||||
|
kv = []byte(key) // todo: or decode from base64?
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("internal error: unknown primary key kind %v", tv.Fields[0].Type.Kind)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
pkv := reflect.ValueOf(kv)
|
||||||
|
kind, err := typeKind(pkv.Type())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if kind != tv.Fields[0].Type.Kind {
|
||||||
|
// Convert from various int types above to required type. The ParseInt/ParseUint
|
||||||
|
// calls already validated that the values fit.
|
||||||
|
pkt := reflect.TypeOf(tv.Fields[0].Type.zeroKey())
|
||||||
|
pkv = pkv.Convert(pkt)
|
||||||
|
}
|
||||||
|
k, err := packPK(pkv)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
tx.stats.Records.Get++
|
tx.stats.Records.Get++
|
||||||
bv := rb.Get(k)
|
bv := rb.Get(k)
|
||||||
if bv == nil {
|
if bv == nil {
|
||||||
return ErrAbsent
|
return nil, ErrAbsent
|
||||||
}
|
}
|
||||||
record, err := parseMap(versions, k, bv)
|
record, err := parseMap(versions, k, bv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
r = record
|
return record, nil
|
||||||
return nil
|
|
||||||
})
|
|
||||||
return r, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Records calls "fn" for each record of "typeName". Records sets "fields" to
|
// Records calls "fn" for each record of "typeName". Records sets "fields" to
|
||||||
// the fields of the type. The type does not have to be registered with Open or
|
// the fields of the type. The type does not have to be registered with Open or
|
||||||
// Register. Record parses the data without the Go type present. BinaryMarshal
|
// Register. Record parses the data without the Go type present. BinaryMarshal
|
||||||
// fields are returned as bytes.
|
// fields are returned as bytes.
|
||||||
func (db *DB) Records(typeName string, fields *[]string, fn func(map[string]any) error) error {
|
func (tx *Tx) Records(typeName string, fields *[]string, fn func(map[string]any) error) error {
|
||||||
return db.Read(func(tx *Tx) error {
|
versions, _, rb, xfields, err := tx.db.prepareType(tx, typeName)
|
||||||
versions, _, rb, xfields, err := db.prepareType(tx, typeName)
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*fields = xfields
|
||||||
|
|
||||||
|
ctxDone := tx.ctx.Done()
|
||||||
|
|
||||||
|
return rb.ForEach(func(bk, bv []byte) error {
|
||||||
|
tx.stats.Records.Cursor++
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctxDone:
|
||||||
|
return tx.ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
record, err := parseMap(versions, bk, bv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
*fields = xfields
|
return fn(record)
|
||||||
|
|
||||||
return rb.ForEach(func(bk, bv []byte) error {
|
|
||||||
tx.stats.Records.Cursor++
|
|
||||||
|
|
||||||
record, err := parseMap(versions, bk, bv)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return fn(record)
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -228,7 +240,7 @@ func parseMap(versions map[uint32]*typeVersion, bk, bv []byte) (record map[strin
|
||||||
|
|
||||||
r := map[string]any{}
|
r := map[string]any{}
|
||||||
|
|
||||||
v := reflect.New(reflect.TypeOf(tv.Fields[0].Type.zero(p))).Elem()
|
v := reflect.New(reflect.TypeOf(tv.Fields[0].Type.zeroKey())).Elem()
|
||||||
err := parsePK(v, bk)
|
err := parsePK(v, bk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -243,12 +255,12 @@ func parseMap(versions map[uint32]*typeVersion, bk, bv []byte) (record map[strin
|
||||||
if fm.Nonzero(i) {
|
if fm.Nonzero(i) {
|
||||||
r[f.Name] = f.Type.parseValue(p)
|
r[f.Name] = f.Type.parseValue(p)
|
||||||
} else {
|
} else {
|
||||||
r[f.Name] = f.Type.zero(p)
|
r[f.Name] = f.Type.zeroExportValue()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(p.buf) != 0 {
|
if len(p.buf) != 0 {
|
||||||
return nil, fmt.Errorf("%w: leftover data after parsing", ErrStore)
|
return nil, fmt.Errorf("%w: leftover data after parsing (%d %x %q)", ErrStore, len(p.buf), p.buf, p.buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
return r, nil
|
return r, nil
|
||||||
|
@ -315,14 +327,21 @@ func (ft fieldType) parseValue(p *parser) any {
|
||||||
var l []any
|
var l []any
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
if fm.Nonzero(i) {
|
if fm.Nonzero(i) {
|
||||||
l = append(l, ft.List.parseValue(p))
|
l = append(l, ft.ListElem.parseValue(p))
|
||||||
} else {
|
} else {
|
||||||
// Always add non-zero elements, or we would
|
// Always add non-zero elements, or we would
|
||||||
// change the number of elements in a list.
|
// change the number of elements in a list.
|
||||||
l = append(l, ft.List.zero(p))
|
l = append(l, ft.ListElem.zeroExportValue())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return l
|
return l
|
||||||
|
case kindArray:
|
||||||
|
n := ft.ArrayLength
|
||||||
|
var l []any
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
l = append(l, ft.ListElem.parseValue(p))
|
||||||
|
}
|
||||||
|
return l
|
||||||
case kindMap:
|
case kindMap:
|
||||||
un := p.Uvarint()
|
un := p.Uvarint()
|
||||||
n := p.checkInt(un)
|
n := p.checkInt(un)
|
||||||
|
@ -338,19 +357,19 @@ func (ft fieldType) parseValue(p *parser) any {
|
||||||
if fm.Nonzero(i) {
|
if fm.Nonzero(i) {
|
||||||
v = ft.MapValue.parseValue(p)
|
v = ft.MapValue.parseValue(p)
|
||||||
} else {
|
} else {
|
||||||
v = ft.MapValue.zero(p)
|
v = ft.MapValue.zeroExportValue()
|
||||||
}
|
}
|
||||||
m[k] = v
|
m[k] = v
|
||||||
}
|
}
|
||||||
return m
|
return m
|
||||||
case kindStruct:
|
case kindStruct:
|
||||||
fm := p.Fieldmap(len(ft.Fields))
|
fm := p.Fieldmap(len(ft.structFields))
|
||||||
m := map[string]any{}
|
m := map[string]any{}
|
||||||
for i, f := range ft.Fields {
|
for i, f := range ft.structFields {
|
||||||
if fm.Nonzero(i) {
|
if fm.Nonzero(i) {
|
||||||
m[f.Name] = f.Type.parseValue(p)
|
m[f.Name] = f.Type.parseValue(p)
|
||||||
} else {
|
} else {
|
||||||
m[f.Name] = f.Type.zero(p)
|
m[f.Name] = f.Type.zeroExportValue()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return m
|
return m
|
||||||
|
@ -359,7 +378,7 @@ func (ft fieldType) parseValue(p *parser) any {
|
||||||
panic("cannot happen")
|
panic("cannot happen")
|
||||||
}
|
}
|
||||||
|
|
||||||
var zerovalues = map[kind]any{
|
var zeroExportValues = map[kind]any{
|
||||||
kindBytes: []byte(nil),
|
kindBytes: []byte(nil),
|
||||||
kindBinaryMarshal: []byte(nil), // We don't have the actual type available, so we just return binary data.
|
kindBinaryMarshal: []byte(nil), // We don't have the actual type available, so we just return binary data.
|
||||||
kindBool: false,
|
kindBool: false,
|
||||||
|
@ -380,12 +399,53 @@ var zerovalues = map[kind]any{
|
||||||
kindSlice: []any(nil),
|
kindSlice: []any(nil),
|
||||||
kindMap: map[string]any(nil),
|
kindMap: map[string]any(nil),
|
||||||
kindStruct: map[string]any(nil),
|
kindStruct: map[string]any(nil),
|
||||||
|
// kindArray handled in zeroExportValue()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ft fieldType) zero(p *parser) any {
|
// zeroExportValue returns the zero value for a fieldType for use with exporting.
|
||||||
v, ok := zerovalues[ft.Kind]
|
func (ft fieldType) zeroExportValue() any {
|
||||||
|
if ft.Kind == kindArray {
|
||||||
|
ev := ft.ListElem.zeroExportValue()
|
||||||
|
l := make([]any, ft.ArrayLength)
|
||||||
|
for i := 0; i < ft.ArrayLength; i++ {
|
||||||
|
l[i] = ev
|
||||||
|
}
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
v, ok := zeroExportValues[ft.Kind]
|
||||||
if !ok {
|
if !ok {
|
||||||
p.Errorf("internal error: unhandled zero value for field type %v", ft.Kind)
|
panic(fmt.Errorf("internal error: unhandled zero value for field type %v", ft.Kind))
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
var zeroKeys = map[kind]any{
|
||||||
|
kindBytes: []byte(nil),
|
||||||
|
kindBool: false,
|
||||||
|
kindInt8: int8(0),
|
||||||
|
kindInt16: int16(0),
|
||||||
|
kindInt32: int32(0),
|
||||||
|
kindInt: int(0),
|
||||||
|
kindInt64: int64(0),
|
||||||
|
kindUint8: uint8(0),
|
||||||
|
kindUint16: uint16(0),
|
||||||
|
kindUint32: uint32(0),
|
||||||
|
kindUint: uint(0),
|
||||||
|
kindUint64: uint64(0),
|
||||||
|
kindString: "",
|
||||||
|
kindTime: zerotime,
|
||||||
|
// kindSlice handled in zeroKeyValue()
|
||||||
|
}
|
||||||
|
|
||||||
|
// zeroKeyValue returns the zero value for a fieldType for use with exporting.
|
||||||
|
func (ft fieldType) zeroKey() any {
|
||||||
|
k := ft.Kind
|
||||||
|
if k == kindSlice {
|
||||||
|
k = ft.ListElem.Kind
|
||||||
|
}
|
||||||
|
v, ok := zeroKeys[k]
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Errorf("internal error: unhandled zero value for field type %v", ft.Kind))
|
||||||
}
|
}
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
8
vendor/github.com/mjl-/bstore/format.md
generated
vendored
8
vendor/github.com/mjl-/bstore/format.md
generated
vendored
|
@ -17,8 +17,8 @@ version is added to the "types" subbucket. Data is always inserted/updated with
|
||||||
the most recent type version. But the database may still hold data records
|
the most recent type version. But the database may still hold data records
|
||||||
referencing older type versions. Bstore decodes a packed data record with the
|
referencing older type versions. Bstore decodes a packed data record with the
|
||||||
referenced type version. For storage efficiency: the type version is reused for
|
referenced type version. For storage efficiency: the type version is reused for
|
||||||
many stored records, a self-describing format (like JSON) would duplicate the
|
many stored records, a self-describing format (like JSON) for each stored
|
||||||
field names in each stored record.
|
record would duplicate the field names in each stored record.
|
||||||
|
|
||||||
# Record storage
|
# Record storage
|
||||||
|
|
||||||
|
@ -51,6 +51,8 @@ more space than the single bit and are stored consecutively after the fieldmap:
|
||||||
the zero value marked in the fieldmap.
|
the zero value marked in the fieldmap.
|
||||||
- Slices use a uvarint for the number of elements, followed by a bitmap for
|
- Slices use a uvarint for the number of elements, followed by a bitmap for
|
||||||
nonzero values, followed by the encoded nonzero elements.
|
nonzero values, followed by the encoded nonzero elements.
|
||||||
|
- Arrays (fixed length) start with a bitmap for nonzero values, followed by
|
||||||
|
the encoded nonzero elements.
|
||||||
- Maps use a uvariant for the number of key/value pairs, followed by a
|
- Maps use a uvariant for the number of key/value pairs, followed by a
|
||||||
fieldmap for the values (the keys are always present), followed by each
|
fieldmap for the values (the keys are always present), followed by each
|
||||||
pair: key (always present), value (only if nonzero); key, value; etc.
|
pair: key (always present), value (only if nonzero); key, value; etc.
|
||||||
|
@ -71,7 +73,7 @@ unsigned integer, or between string and []byte.
|
||||||
Indexes are stored in subbuckets, named starting with "index." followed by the
|
Indexes are stored in subbuckets, named starting with "index." followed by the
|
||||||
index name. Keys are a self-delimiting encodings of the fields that make up the
|
index name. Keys are a self-delimiting encodings of the fields that make up the
|
||||||
key, followed by the primary key for the "records" bucket. Values are always
|
key, followed by the primary key for the "records" bucket. Values are always
|
||||||
empty in index buckets. For bool and integer types, the same fixed with
|
empty in index buckets. For bool and integer types, the same fixed width
|
||||||
encoding as for primary keys in the "records" subbucket is used. Strings are
|
encoding as for primary keys in the "records" subbucket is used. Strings are
|
||||||
encoded by their bytes (no \0 allowed) followed by a delimiting \0. Unlike
|
encoded by their bytes (no \0 allowed) followed by a delimiting \0. Unlike
|
||||||
primary keys, an index can cover a field with type time.Time. Times are encoded
|
primary keys, an index can cover a field with type time.Time. Times are encoded
|
||||||
|
|
176
vendor/github.com/mjl-/bstore/keys.go
generated
vendored
176
vendor/github.com/mjl-/bstore/keys.go
generated
vendored
|
@ -10,7 +10,7 @@ import (
|
||||||
|
|
||||||
/*
|
/*
|
||||||
The records buckets map a primary key to the record data. The primary key is of
|
The records buckets map a primary key to the record data. The primary key is of
|
||||||
a form that we can scan/range over. So fixed with for integers. For strings and
|
a form that we can scan/range over. So fixed width for integers. For strings and
|
||||||
bytes they are just their byte representation. We do not store the PK in the
|
bytes they are just their byte representation. We do not store the PK in the
|
||||||
record data. This means we cannot store a time.Time as primary key, because we
|
record data. This means we cannot store a time.Time as primary key, because we
|
||||||
cannot have the timezone encoded for comparison reasons.
|
cannot have the timezone encoded for comparison reasons.
|
||||||
|
@ -150,7 +150,12 @@ fields:
|
||||||
if err != nil {
|
if err != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
switch f.Type.Kind {
|
ft := f.Type
|
||||||
|
if ft.Kind == kindSlice {
|
||||||
|
// For an index on a slice, we store each value in the slice in a separate index key.
|
||||||
|
ft = *ft.ListElem
|
||||||
|
}
|
||||||
|
switch ft.Kind {
|
||||||
case kindString:
|
case kindString:
|
||||||
for i, b := range buf {
|
for i, b := range buf {
|
||||||
if b == 0 {
|
if b == 0 {
|
||||||
|
@ -174,6 +179,8 @@ fields:
|
||||||
take(8)
|
take(8)
|
||||||
case kindTime:
|
case kindTime:
|
||||||
take(8 + 4)
|
take(8 + 4)
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("%w: unhandled kind %v for index key", ErrStore, ft.Kind)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -203,9 +210,14 @@ fields:
|
||||||
return pk, nil, nil
|
return pk, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// packKey returns a key to store in an index: first the prefix without pk, then
|
type indexkey struct {
|
||||||
// the prefix including pk.
|
pre []byte // Packed fields excluding PK, a slice of full.
|
||||||
func (idx *index) packKey(rv reflect.Value, pk []byte) ([]byte, []byte, error) {
|
full []byte // Packed fields including PK.
|
||||||
|
}
|
||||||
|
|
||||||
|
// packKey returns keys to store in an index: first the key prefixes without pk, then
|
||||||
|
// the prefixes including pk.
|
||||||
|
func (idx *index) packKey(rv reflect.Value, pk []byte) ([]indexkey, error) {
|
||||||
var l []reflect.Value
|
var l []reflect.Value
|
||||||
for _, f := range idx.Fields {
|
for _, f := range idx.Fields {
|
||||||
frv := rv.FieldByIndex(f.structField.Index)
|
frv := rv.FieldByIndex(f.structField.Index)
|
||||||
|
@ -215,68 +227,108 @@ func (idx *index) packKey(rv reflect.Value, pk []byte) ([]byte, []byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// packIndexKeys packs values from l, followed by the pk.
|
// packIndexKeys packs values from l, followed by the pk.
|
||||||
// It returns the key prefix (without pk), and full key with pk.
|
// It returns the key prefixes (without pk), and full keys with pk.
|
||||||
func packIndexKeys(l []reflect.Value, pk []byte) ([]byte, []byte, error) {
|
func packIndexKeys(l []reflect.Value, pk []byte) ([]indexkey, error) {
|
||||||
var prek, ik []byte
|
ikl := []indexkey{{}}
|
||||||
for _, frv := range l {
|
for _, frv := range l {
|
||||||
k, err := typeKind(frv.Type())
|
bufs, err := packIndexKey(frv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var buf []byte
|
|
||||||
switch k {
|
if len(bufs) == 1 {
|
||||||
case kindBool:
|
for i := range ikl {
|
||||||
buf = []byte{0}
|
ikl[i].full = append(ikl[i].full, bufs[0]...)
|
||||||
if frv.Bool() {
|
|
||||||
buf[0] = 1
|
|
||||||
}
|
}
|
||||||
case kindInt8:
|
} else if len(ikl) == 1 && len(bufs) > 1 {
|
||||||
buf = []byte{byte(int8(frv.Int()) + math.MinInt8)}
|
nikl := make([]indexkey, len(bufs))
|
||||||
case kindInt16:
|
for i, buf := range bufs {
|
||||||
buf = binary.BigEndian.AppendUint16(nil, uint16(int16(frv.Int())+math.MinInt16))
|
nikl[i] = indexkey{full: append(append([]byte{}, ikl[0].full...), buf...)}
|
||||||
case kindInt32:
|
|
||||||
buf = binary.BigEndian.AppendUint32(nil, uint32(int32(frv.Int())+math.MinInt32))
|
|
||||||
case kindInt:
|
|
||||||
i := frv.Int()
|
|
||||||
if i < math.MinInt32 || i > math.MaxInt32 {
|
|
||||||
return nil, nil, fmt.Errorf("%w: int value %d does not fit in int32", ErrParam, i)
|
|
||||||
}
|
}
|
||||||
buf = binary.BigEndian.AppendUint32(nil, uint32(int32(i)+math.MinInt32))
|
ikl = nikl
|
||||||
case kindInt64:
|
} else if len(bufs) == 0 {
|
||||||
buf = binary.BigEndian.AppendUint64(nil, uint64(frv.Int()+math.MinInt64))
|
return nil, nil
|
||||||
case kindUint8:
|
} else {
|
||||||
buf = []byte{byte(frv.Uint())}
|
return nil, fmt.Errorf("%w: multiple index fields that result in multiple values, or no data for index key, %d keys so far, %d new buffers", ErrStore, len(ikl), len(bufs))
|
||||||
case kindUint16:
|
|
||||||
buf = binary.BigEndian.AppendUint16(nil, uint16(frv.Uint()))
|
|
||||||
case kindUint32:
|
|
||||||
buf = binary.BigEndian.AppendUint32(nil, uint32(frv.Uint()))
|
|
||||||
case kindUint:
|
|
||||||
i := frv.Uint()
|
|
||||||
if i > math.MaxUint32 {
|
|
||||||
return nil, nil, fmt.Errorf("%w: uint value %d does not fit in uint32", ErrParam, i)
|
|
||||||
}
|
|
||||||
buf = binary.BigEndian.AppendUint32(nil, uint32(i))
|
|
||||||
case kindUint64:
|
|
||||||
buf = binary.BigEndian.AppendUint64(nil, uint64(frv.Uint()))
|
|
||||||
case kindString:
|
|
||||||
buf = []byte(frv.String())
|
|
||||||
for _, c := range buf {
|
|
||||||
if c == 0 {
|
|
||||||
return nil, nil, fmt.Errorf("%w: string used as index key cannot have \\0", ErrParam)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
buf = append(buf, 0)
|
|
||||||
case kindTime:
|
|
||||||
tm := frv.Interface().(time.Time)
|
|
||||||
buf = binary.BigEndian.AppendUint64(nil, uint64(tm.Unix()+math.MinInt64))
|
|
||||||
buf = binary.BigEndian.AppendUint32(buf, uint32(tm.Nanosecond()))
|
|
||||||
default:
|
|
||||||
return nil, nil, fmt.Errorf("internal error: bad type %v for index", frv.Type()) // todo: should be caught when making index type
|
|
||||||
}
|
}
|
||||||
ik = append(ik, buf...)
|
|
||||||
}
|
}
|
||||||
n := len(ik)
|
for i := range ikl {
|
||||||
ik = append(ik, pk...)
|
n := len(ikl[i].full)
|
||||||
prek = ik[:n]
|
ikl[i].full = append(ikl[i].full, pk...)
|
||||||
return prek, ik, nil
|
ikl[i].pre = ikl[i].full[:n]
|
||||||
|
}
|
||||||
|
return ikl, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func packIndexKey(frv reflect.Value) ([][]byte, error) {
|
||||||
|
k, err := typeKind(frv.Type())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf []byte
|
||||||
|
switch k {
|
||||||
|
case kindBool:
|
||||||
|
buf = []byte{0}
|
||||||
|
if frv.Bool() {
|
||||||
|
buf[0] = 1
|
||||||
|
}
|
||||||
|
case kindInt8:
|
||||||
|
buf = []byte{byte(int8(frv.Int()) + math.MinInt8)}
|
||||||
|
case kindInt16:
|
||||||
|
buf = binary.BigEndian.AppendUint16(nil, uint16(int16(frv.Int())+math.MinInt16))
|
||||||
|
case kindInt32:
|
||||||
|
buf = binary.BigEndian.AppendUint32(nil, uint32(int32(frv.Int())+math.MinInt32))
|
||||||
|
case kindInt:
|
||||||
|
i := frv.Int()
|
||||||
|
if i < math.MinInt32 || i > math.MaxInt32 {
|
||||||
|
return nil, fmt.Errorf("%w: int value %d does not fit in int32", ErrParam, i)
|
||||||
|
}
|
||||||
|
buf = binary.BigEndian.AppendUint32(nil, uint32(int32(i)+math.MinInt32))
|
||||||
|
case kindInt64:
|
||||||
|
buf = binary.BigEndian.AppendUint64(nil, uint64(frv.Int()+math.MinInt64))
|
||||||
|
case kindUint8:
|
||||||
|
buf = []byte{byte(frv.Uint())}
|
||||||
|
case kindUint16:
|
||||||
|
buf = binary.BigEndian.AppendUint16(nil, uint16(frv.Uint()))
|
||||||
|
case kindUint32:
|
||||||
|
buf = binary.BigEndian.AppendUint32(nil, uint32(frv.Uint()))
|
||||||
|
case kindUint:
|
||||||
|
i := frv.Uint()
|
||||||
|
if i > math.MaxUint32 {
|
||||||
|
return nil, fmt.Errorf("%w: uint value %d does not fit in uint32", ErrParam, i)
|
||||||
|
}
|
||||||
|
buf = binary.BigEndian.AppendUint32(nil, uint32(i))
|
||||||
|
case kindUint64:
|
||||||
|
buf = binary.BigEndian.AppendUint64(nil, uint64(frv.Uint()))
|
||||||
|
case kindString:
|
||||||
|
buf = []byte(frv.String())
|
||||||
|
for _, c := range buf {
|
||||||
|
if c == 0 {
|
||||||
|
return nil, fmt.Errorf("%w: string used as index key cannot have \\0", ErrParam)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
buf = append(buf, 0)
|
||||||
|
case kindTime:
|
||||||
|
tm := frv.Interface().(time.Time)
|
||||||
|
buf = binary.BigEndian.AppendUint64(nil, uint64(tm.Unix()+math.MinInt64))
|
||||||
|
buf = binary.BigEndian.AppendUint32(buf, uint32(tm.Nanosecond()))
|
||||||
|
case kindSlice:
|
||||||
|
n := frv.Len()
|
||||||
|
bufs := make([][]byte, n)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
nbufs, err := packIndexKey(frv.Index(i))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("packing element from slice field: %w", err)
|
||||||
|
}
|
||||||
|
if len(nbufs) != 1 {
|
||||||
|
return nil, fmt.Errorf("packing element from slice field resulted in multiple buffers (%d)", len(bufs))
|
||||||
|
}
|
||||||
|
bufs[i] = nbufs[0]
|
||||||
|
}
|
||||||
|
return bufs, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("internal error: bad type %v for index", frv.Type()) // todo: should be caught when making index type
|
||||||
|
}
|
||||||
|
return [][]byte{buf}, nil
|
||||||
}
|
}
|
||||||
|
|
393
vendor/github.com/mjl-/bstore/nonzero.go
generated
vendored
393
vendor/github.com/mjl-/bstore/nonzero.go
generated
vendored
|
@ -12,207 +12,324 @@ func (ft fieldType) isZero(v reflect.Value) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if ft.Ptr {
|
if ft.Ptr {
|
||||||
return v.IsNil()
|
return v.IsZero()
|
||||||
}
|
}
|
||||||
switch ft.Kind {
|
switch ft.Kind {
|
||||||
case kindStruct:
|
case kindStruct:
|
||||||
for _, f := range ft.Fields {
|
for _, f := range ft.structFields {
|
||||||
if !f.Type.isZero(v.FieldByIndex(f.structField.Index)) {
|
if !f.Type.isZero(v.FieldByIndex(f.structField.Index)) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use standard IsZero otherwise, also for kindBinaryMarshal.
|
// Use standard IsZero otherwise, also for kindBinaryMarshal.
|
||||||
return v.IsZero()
|
return v.IsZero()
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkNonzero compare ofields and nfields (from previous type schema vs newly
|
// We ensure nonzero constraints when opening a database. An updated schema, with
|
||||||
// created type schema) for nonzero struct tag. If an existing field got a
|
// added nonzero constraints, can mean all records have to be checked. With cyclic
|
||||||
// nonzero struct tag added, we verify that there are indeed no nonzero values
|
// types, we have to take care not to recurse, and for efficiency we want to only
|
||||||
// in the database. If there are, we return ErrZero.
|
// check fields/types that are affected. Steps:
|
||||||
|
//
|
||||||
|
// - Go through each field of the struct, and recurse into the field types,
|
||||||
|
// gathering the types and newly nonzero fields.
|
||||||
|
// - Propagate the need for nonzero checks to types that reference the changed
|
||||||
|
// types.
|
||||||
|
// - By now, if there was a new nonzero constraint, the top-level type will be
|
||||||
|
// marked as needing a check, so we'll read through all records and check all the
|
||||||
|
// immediate newly nonzero fields of a type, and recurse into fields of types that
|
||||||
|
// are marked as needing a check.
|
||||||
|
|
||||||
|
// nonzeroCheckType is tracked per reflect.Type that has been analysed (always the
|
||||||
|
// non-pointer type, i.e. a pointer is dereferenced). These types can be cyclic. We
|
||||||
|
// gather them for all types involved, including map and slice types and basic
|
||||||
|
// types, but "newlyNonzero" and "fields" will only be set for structs.
|
||||||
|
type nonzeroCheckType struct {
|
||||||
|
needsCheck bool
|
||||||
|
|
||||||
|
newlyNonzero []field // Fields in this type that have a new nonzero constraint themselves.
|
||||||
|
fields []field // All fields in a struct type.
|
||||||
|
|
||||||
|
// Types that reference this type. Used to propagate needsCheck to the top.
|
||||||
|
referencedBy map[reflect.Type]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ct *nonzeroCheckType) markRefBy(t reflect.Type) {
|
||||||
|
if t != nil {
|
||||||
|
ct.referencedBy[t] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkNonzero compares ofields (optional previous type schema) and nfields (new
|
||||||
|
// type schema) for nonzero struct tags. If an existing field has a new nonzero
|
||||||
|
// constraint, we verify that there are indeed no nonzero values in the existing
|
||||||
|
// records. If there are, we return ErrZero. checkNonzero looks at (potentially
|
||||||
|
// cyclic) types referenced by fields.
|
||||||
func (tx *Tx) checkNonzero(st storeType, tv *typeVersion, ofields, nfields []field) error {
|
func (tx *Tx) checkNonzero(st storeType, tv *typeVersion, ofields, nfields []field) error {
|
||||||
// First we gather paths that we need to check, so we can later simply
|
// Gather all new nonzero constraints on fields.
|
||||||
// execute those steps on all data we need to read.
|
m := map[reflect.Type]*nonzeroCheckType{}
|
||||||
paths := &follows{}
|
nonzeroCheckGather(m, st.Type, nil, ofields, nfields)
|
||||||
next:
|
|
||||||
for _, f := range nfields {
|
// Propagate the need for a check on all types due to a referenced type having a
|
||||||
for _, of := range ofields {
|
// new nonzero constraint.
|
||||||
if f.Name == of.Name {
|
// todo: this can probably be done more elegantly, with fewer graph walks...
|
||||||
err := f.checkNonzeroGather(&of, paths)
|
for t, ct := range m {
|
||||||
if err != nil {
|
if ct.needsCheck {
|
||||||
return err
|
nonzeroCheckPropagate(m, t, t, ct)
|
||||||
}
|
|
||||||
continue next
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := f.checkNonzeroGather(nil, paths); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(paths.paths) == 0 {
|
// If needsCheck wasn't propagated to the top-level, there was no new nonzero
|
||||||
// Common case, not reading all data.
|
// constraint, and we're not going to read all the data. This is the common case
|
||||||
|
// when opening a database.
|
||||||
|
if !m[st.Type].needsCheck {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finally actually do the checks.
|
// Read through all data, and check the new nonzero constraint.
|
||||||
// todo: if there are only top-level fields to check, and we have an index, we can use the index check this without reading all data.
|
// todo optimize: if there are only top-level fields to check, and we have indices on those fields, we can use the index to check this without reading all data.
|
||||||
return tx.checkNonzeroPaths(st, tv, paths.paths)
|
return checkNonzeroRecords(tx, st, tv, m)
|
||||||
}
|
}
|
||||||
|
|
||||||
type follow struct {
|
// Walk down fields, gathering their types (including those they reference), and
|
||||||
mapKey, mapValue bool
|
// marking needsCheck if any of a type's immediate field has a new nonzero
|
||||||
field field
|
// constraint. The need for a check is not propagated to referencing types by this
|
||||||
}
|
// function.
|
||||||
|
func nonzeroCheckGather(m map[reflect.Type]*nonzeroCheckType, t, refBy reflect.Type, ofields, nfields []field) {
|
||||||
type follows struct {
|
ct := m[t]
|
||||||
current []follow
|
if ct != nil {
|
||||||
paths [][]follow
|
// Already gathered, don't recurse, for cyclic types.
|
||||||
}
|
ct.markRefBy(refBy)
|
||||||
|
return
|
||||||
func (f *follows) push(ff follow) {
|
|
||||||
f.current = append(f.current, ff)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *follows) pop() {
|
|
||||||
f.current = f.current[:len(f.current)-1]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *follows) add() {
|
|
||||||
f.paths = append(f.paths, append([]follow{}, f.current...))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f field) checkNonzeroGather(of *field, paths *follows) error {
|
|
||||||
paths.push(follow{field: f})
|
|
||||||
defer paths.pop()
|
|
||||||
if f.Nonzero && (of == nil || !of.Nonzero) {
|
|
||||||
paths.add()
|
|
||||||
}
|
}
|
||||||
if of != nil {
|
ct = &nonzeroCheckType{
|
||||||
return f.Type.checkNonzeroGather(of.Type, paths)
|
fields: nfields,
|
||||||
|
referencedBy: map[reflect.Type]struct{}{},
|
||||||
}
|
}
|
||||||
return nil
|
ct.markRefBy(refBy)
|
||||||
}
|
m[t] = ct
|
||||||
|
|
||||||
func (ft fieldType) checkNonzeroGather(oft fieldType, paths *follows) error {
|
for _, f := range nfields {
|
||||||
switch ft.Kind {
|
// Check if this field is newly nonzero.
|
||||||
case kindMap:
|
var of *field
|
||||||
paths.push(follow{mapKey: true})
|
for i := range ofields {
|
||||||
if err := ft.MapKey.checkNonzeroGather(*oft.MapKey, paths); err != nil {
|
if f.Name == ofields[i].Name {
|
||||||
return err
|
of = &ofields[i]
|
||||||
}
|
// Compare with existing field.
|
||||||
paths.pop()
|
if f.Nonzero && !of.Nonzero {
|
||||||
|
ct.newlyNonzero = append(ct.newlyNonzero, f)
|
||||||
paths.push(follow{mapValue: true})
|
ct.needsCheck = true
|
||||||
if err := ft.MapValue.checkNonzeroGather(*oft.MapValue, paths); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
paths.pop()
|
|
||||||
|
|
||||||
case kindSlice:
|
|
||||||
err := ft.List.checkNonzeroGather(*oft.List, paths)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
case kindStruct:
|
|
||||||
next:
|
|
||||||
for _, ff := range ft.Fields {
|
|
||||||
for _, off := range oft.Fields {
|
|
||||||
if ff.Name == off.Name {
|
|
||||||
err := ff.checkNonzeroGather(&off, paths)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
continue next
|
|
||||||
}
|
}
|
||||||
}
|
break
|
||||||
err := ff.checkNonzeroGather(nil, paths)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Check if this is a new field entirely, with nonzero constraint.
|
||||||
|
if of == nil && f.Nonzero {
|
||||||
|
ct.newlyNonzero = append(ct.newlyNonzero, f)
|
||||||
|
ct.needsCheck = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Descend into referenced types, adding references back to this type.
|
||||||
|
var oft *fieldType
|
||||||
|
if of != nil {
|
||||||
|
oft = &of.Type
|
||||||
|
}
|
||||||
|
ft := f.structField.Type
|
||||||
|
nonzeroCheckGatherFieldType(m, ft, t, oft, f.Type)
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkNonzero reads through all records of a type, and checks that the fields
|
// gather new nonzero constraints for type "t", which is referenced by "refBy" (and
|
||||||
|
// will be marked as such). type "t" is described by "nft" and optionally
|
||||||
|
// previously by "oft".
|
||||||
|
func nonzeroCheckGatherFieldType(m map[reflect.Type]*nonzeroCheckType, t, refBy reflect.Type, oft *fieldType, nft fieldType) {
|
||||||
|
// If this is a pointer type, dereference the reflect type.
|
||||||
|
if nft.Ptr {
|
||||||
|
t = t.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if nft.Kind == kindStruct {
|
||||||
|
var fofields []field
|
||||||
|
if oft != nil {
|
||||||
|
fofields = oft.structFields
|
||||||
|
}
|
||||||
|
nonzeroCheckGather(m, t, refBy, fofields, nft.structFields)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark this type as gathered, so we don't process it again if we recurse.
|
||||||
|
ct := m[t]
|
||||||
|
if ct != nil {
|
||||||
|
ct.markRefBy(refBy)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ct = &nonzeroCheckType{
|
||||||
|
fields: nft.structFields,
|
||||||
|
referencedBy: map[reflect.Type]struct{}{},
|
||||||
|
}
|
||||||
|
ct.markRefBy(refBy)
|
||||||
|
m[t] = ct
|
||||||
|
|
||||||
|
switch nft.Kind {
|
||||||
|
case kindMap:
|
||||||
|
var koft, voft *fieldType
|
||||||
|
if oft != nil {
|
||||||
|
koft = oft.MapKey
|
||||||
|
voft = oft.MapValue
|
||||||
|
}
|
||||||
|
nonzeroCheckGatherFieldType(m, t.Key(), t, koft, *nft.MapKey)
|
||||||
|
nonzeroCheckGatherFieldType(m, t.Elem(), t, voft, *nft.MapValue)
|
||||||
|
case kindSlice:
|
||||||
|
var loft *fieldType
|
||||||
|
if oft != nil {
|
||||||
|
loft = oft.ListElem
|
||||||
|
}
|
||||||
|
nonzeroCheckGatherFieldType(m, t.Elem(), t, loft, *nft.ListElem)
|
||||||
|
case kindArray:
|
||||||
|
var loft *fieldType
|
||||||
|
if oft != nil {
|
||||||
|
loft = oft.ListElem
|
||||||
|
}
|
||||||
|
nonzeroCheckGatherFieldType(m, t.Elem(), t, loft, *nft.ListElem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Propagate that type "t" is affected by a new nonzero constrained and needs to be
|
||||||
|
// checked. The types referencing "t" are in ct.referencedBy. "origt" is the
|
||||||
|
// starting type for this propagation.
|
||||||
|
func nonzeroCheckPropagate(m map[reflect.Type]*nonzeroCheckType, origt, t reflect.Type, ct *nonzeroCheckType) {
|
||||||
|
for rt := range ct.referencedBy {
|
||||||
|
if rt == origt {
|
||||||
|
continue // End recursion.
|
||||||
|
}
|
||||||
|
m[rt].needsCheck = true
|
||||||
|
nonzeroCheckPropagate(m, origt, rt, m[rt])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkNonzeroPaths reads through all records of a type, and checks that the fields
|
||||||
// indicated by paths are nonzero. If not, ErrZero is returned.
|
// indicated by paths are nonzero. If not, ErrZero is returned.
|
||||||
func (tx *Tx) checkNonzeroPaths(st storeType, tv *typeVersion, paths [][]follow) error {
|
func checkNonzeroRecords(tx *Tx, st storeType, tv *typeVersion, m map[reflect.Type]*nonzeroCheckType) error {
|
||||||
rb, err := tx.recordsBucket(st.Current.name, st.Current.fillPercent)
|
rb, err := tx.recordsBucket(st.Current.name, st.Current.fillPercent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctxDone := tx.ctx.Done()
|
||||||
|
|
||||||
return rb.ForEach(func(bk, bv []byte) error {
|
return rb.ForEach(func(bk, bv []byte) error {
|
||||||
tx.stats.Records.Cursor++
|
tx.stats.Records.Cursor++
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctxDone:
|
||||||
|
return tx.ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
// todo optimize: instead of parsing the full record, use the fieldmap to see if the value is nonzero.
|
||||||
rv, err := st.parseNew(bk, bv)
|
rv, err := st.parseNew(bk, bv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// todo optimization: instead of parsing the full record, use the fieldmap to see if the value is nonzero.
|
ct := m[st.Type]
|
||||||
for _, path := range paths {
|
return checkNonzeroFields(m, st.Type, ct.newlyNonzero, ct.fields, rv)
|
||||||
frv := rv.FieldByIndex(path[0].field.structField.Index)
|
|
||||||
if err := path[0].field.checkNonzero(frv, path[1:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f field) checkNonzero(rv reflect.Value, path []follow) error {
|
// checkNonzeroFields checks that the newly nonzero fields of a struct value are
|
||||||
if len(path) == 0 {
|
// indeed nonzero, and walks down referenced types, checking the constraint.
|
||||||
if !f.Nonzero {
|
func checkNonzeroFields(m map[reflect.Type]*nonzeroCheckType, t reflect.Type, newlyNonzero, fields []field, rv reflect.Value) error {
|
||||||
return fmt.Errorf("internal error: checkNonzero: expected field to have Nonzero set")
|
// Check the newly nonzero fields.
|
||||||
}
|
for _, f := range newlyNonzero {
|
||||||
if f.Type.isZero(rv) {
|
frv := rv.FieldByIndex(f.structField.Index)
|
||||||
|
if f.Type.isZero(frv) {
|
||||||
return fmt.Errorf("%w: field %q", ErrZero, f.Name)
|
return fmt.Errorf("%w: field %q", ErrZero, f.Name)
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
return f.Type.checkNonzero(rv, path)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ft fieldType) checkNonzero(rv reflect.Value, path []follow) error {
|
// Descend into referenced types.
|
||||||
switch ft.Kind {
|
for _, f := range fields {
|
||||||
case kindMap:
|
switch f.Type.Kind {
|
||||||
follow := path[0]
|
case kindMap, kindSlice, kindStruct, kindArray:
|
||||||
path = path[1:]
|
ft := f.structField.Type
|
||||||
key := follow.mapKey
|
if err := checkNonzeroFieldType(m, f.Type, ft, rv.FieldByIndex(f.structField.Index)); err != nil {
|
||||||
if !key && !follow.mapValue {
|
|
||||||
return fmt.Errorf("internal error: following map, expected mapKey or mapValue, got %#v", follow)
|
|
||||||
}
|
|
||||||
|
|
||||||
iter := rv.MapRange()
|
|
||||||
for iter.Next() {
|
|
||||||
var err error
|
|
||||||
if key {
|
|
||||||
err = ft.MapKey.checkNonzero(iter.Key(), path)
|
|
||||||
} else {
|
|
||||||
err = ft.MapValue.checkNonzero(iter.Value(), path)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkNonzeroFieldType walks down a value, and checks that its (struct) types
|
||||||
|
// don't violate nonzero constraints.
|
||||||
|
// Does not check whether the value itself is nonzero. If required, that has
|
||||||
|
// already been checked.
|
||||||
|
func checkNonzeroFieldType(m map[reflect.Type]*nonzeroCheckType, ft fieldType, t reflect.Type, rv reflect.Value) error {
|
||||||
|
if ft.Ptr {
|
||||||
|
t = t.Elem()
|
||||||
|
}
|
||||||
|
if !m[t].needsCheck {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if ft.Ptr && rv.IsZero() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if ft.Ptr {
|
||||||
|
rv = rv.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
unptr := func(t reflect.Type, ptr bool) reflect.Type {
|
||||||
|
if ptr {
|
||||||
|
return t.Elem()
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
switch ft.Kind {
|
||||||
|
case kindMap:
|
||||||
|
kt := t.Key()
|
||||||
|
vt := t.Elem()
|
||||||
|
checkKey := m[unptr(kt, ft.MapKey.Ptr)].needsCheck
|
||||||
|
checkValue := m[unptr(vt, ft.MapValue.Ptr)].needsCheck
|
||||||
|
iter := rv.MapRange()
|
||||||
|
for iter.Next() {
|
||||||
|
if checkKey {
|
||||||
|
if err := checkNonzeroFieldType(m, *ft.MapKey, kt, iter.Key()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if checkValue {
|
||||||
|
if err := checkNonzeroFieldType(m, *ft.MapValue, vt, iter.Value()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
case kindSlice:
|
case kindSlice:
|
||||||
|
et := t.Elem()
|
||||||
n := rv.Len()
|
n := rv.Len()
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
if err := ft.List.checkNonzero(rv.Index(i), path); err != nil {
|
if err := checkNonzeroFieldType(m, *ft.ListElem, et, rv.Index(i)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case kindArray:
|
||||||
|
et := t.Elem()
|
||||||
|
n := ft.ArrayLength
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
if err := checkNonzeroFieldType(m, *ft.ListElem, et, rv.Index(i)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case kindStruct:
|
case kindStruct:
|
||||||
follow := path[0]
|
ct := m[t]
|
||||||
path = path[1:]
|
if err := checkNonzeroFields(m, t, ct.newlyNonzero, ct.fields, rv); err != nil {
|
||||||
frv := rv.FieldByIndex(follow.field.structField.Index)
|
|
||||||
if err := follow.field.checkNonzero(frv, path); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
default:
|
|
||||||
return fmt.Errorf("internal error: checkNonzero with non-empty path, but kind %v", ft.Kind)
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
64
vendor/github.com/mjl-/bstore/pack.go
generated
vendored
64
vendor/github.com/mjl-/bstore/pack.go
generated
vendored
|
@ -17,14 +17,14 @@ type fieldmap struct {
|
||||||
buf []byte // Bitmap, we write the next 0/1 at bit n.
|
buf []byte // Bitmap, we write the next 0/1 at bit n.
|
||||||
n int // Fields seen so far.
|
n int // Fields seen so far.
|
||||||
offset int // In final output, we write buf back after finish. Only relevant for packing.
|
offset int // In final output, we write buf back after finish. Only relevant for packing.
|
||||||
Errorf func(format string, args ...any)
|
errorf func(format string, args ...any)
|
||||||
}
|
}
|
||||||
|
|
||||||
// add bit to fieldmap indicating if the field is nonzero.
|
// add bit to fieldmap indicating if the field is nonzero.
|
||||||
func (f *fieldmap) Field(nonzero bool) {
|
func (f *fieldmap) Field(nonzero bool) {
|
||||||
o := f.n / 8
|
o := f.n / 8
|
||||||
if f.n >= f.max {
|
if f.n >= f.max {
|
||||||
f.Errorf("internal error: too many fields, max %d", f.max)
|
f.errorf("internal error: too many fields, max %d", f.max)
|
||||||
}
|
}
|
||||||
if nonzero {
|
if nonzero {
|
||||||
f.buf[o] |= 1 << (7 - f.n%8)
|
f.buf[o] |= 1 << (7 - f.n%8)
|
||||||
|
@ -46,7 +46,7 @@ type packer struct {
|
||||||
popped []*fieldmap // Completed fieldmaps, to be written back during finish.
|
popped []*fieldmap // Completed fieldmaps, to be written back during finish.
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *packer) Errorf(format string, args ...any) {
|
func (p *packer) errorf(format string, args ...any) {
|
||||||
panic(packErr{fmt.Errorf(format, args...)})
|
panic(packErr{fmt.Errorf(format, args...)})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -54,7 +54,7 @@ func (p *packer) Errorf(format string, args ...any) {
|
||||||
func (p *packer) PushFieldmap(n int) {
|
func (p *packer) PushFieldmap(n int) {
|
||||||
p.fieldmaps = append(p.fieldmaps, p.fieldmap)
|
p.fieldmaps = append(p.fieldmaps, p.fieldmap)
|
||||||
buf := make([]byte, (n+7)/8)
|
buf := make([]byte, (n+7)/8)
|
||||||
p.fieldmap = &fieldmap{max: n, buf: buf, offset: p.offset, Errorf: p.Errorf}
|
p.fieldmap = &fieldmap{max: n, buf: buf, offset: p.offset, errorf: p.errorf}
|
||||||
p.Write(buf) // Updates offset. Write errors cause panic.
|
p.Write(buf) // Updates offset. Write errors cause panic.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -62,7 +62,7 @@ func (p *packer) PushFieldmap(n int) {
|
||||||
// bytes during finish.
|
// bytes during finish.
|
||||||
func (p *packer) PopFieldmap() {
|
func (p *packer) PopFieldmap() {
|
||||||
if p.fieldmap.n != p.fieldmap.max {
|
if p.fieldmap.n != p.fieldmap.max {
|
||||||
p.Errorf("internal error: fieldmap n %d != max %d", p.fieldmap.n, p.fieldmap.max)
|
p.errorf("internal error: fieldmap n %d != max %d", p.fieldmap.n, p.fieldmap.max)
|
||||||
}
|
}
|
||||||
p.popped = append(p.popped, p.fieldmap)
|
p.popped = append(p.popped, p.fieldmap)
|
||||||
p.fieldmap = p.fieldmaps[len(p.fieldmaps)-1]
|
p.fieldmap = p.fieldmaps[len(p.fieldmaps)-1]
|
||||||
|
@ -73,7 +73,7 @@ func (p *packer) PopFieldmap() {
|
||||||
// returning the final bytes representation of this record.
|
// returning the final bytes representation of this record.
|
||||||
func (p *packer) Finish() []byte {
|
func (p *packer) Finish() []byte {
|
||||||
if p.fieldmap != nil {
|
if p.fieldmap != nil {
|
||||||
p.Errorf("internal error: leftover fieldmap during finish")
|
p.errorf("internal error: leftover fieldmap during finish")
|
||||||
}
|
}
|
||||||
buf := p.b.Bytes()
|
buf := p.b.Bytes()
|
||||||
for _, f := range p.popped {
|
for _, f := range p.popped {
|
||||||
|
@ -90,7 +90,7 @@ func (p *packer) Field(nonzero bool) {
|
||||||
func (p *packer) Write(buf []byte) (int, error) {
|
func (p *packer) Write(buf []byte) (int, error) {
|
||||||
n, err := p.b.Write(buf)
|
n, err := p.b.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.Errorf("write: %w", err)
|
p.errorf("write: %w", err)
|
||||||
}
|
}
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
p.offset += n
|
p.offset += n
|
||||||
|
@ -149,11 +149,12 @@ func (tv typeVersion) pack(p *packer, rv reflect.Value) {
|
||||||
nrv := rv.FieldByIndex(f.structField.Index)
|
nrv := rv.FieldByIndex(f.structField.Index)
|
||||||
if f.Type.isZero(nrv) {
|
if f.Type.isZero(nrv) {
|
||||||
if f.Nonzero {
|
if f.Nonzero {
|
||||||
p.Errorf("%w: %q", ErrZero, f.Name)
|
p.errorf("%w: %q", ErrZero, f.Name)
|
||||||
}
|
}
|
||||||
p.Field(false)
|
p.Field(false)
|
||||||
// Pretend to pack to get the nonzero checks.
|
// Pretend to pack to get the nonzero checks.
|
||||||
if nrv.IsValid() && (nrv.Kind() != reflect.Ptr || !nrv.IsNil()) {
|
// todo: we should be able to do nonzero-check without pretending to pack.
|
||||||
|
if nrv.IsValid() && (nrv.Kind() != reflect.Ptr || !nrv.IsZero()) {
|
||||||
f.Type.pack(&packer{b: &bytes.Buffer{}}, nrv)
|
f.Type.pack(&packer{b: &bytes.Buffer{}}, nrv)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -176,7 +177,7 @@ func (ft fieldType) pack(p *packer, rv reflect.Value) {
|
||||||
v := rv
|
v := rv
|
||||||
buf, err := v.Interface().(encoding.BinaryMarshaler).MarshalBinary()
|
buf, err := v.Interface().(encoding.BinaryMarshaler).MarshalBinary()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.Errorf("marshalbinary: %w", err)
|
p.errorf("marshalbinary: %w", err)
|
||||||
}
|
}
|
||||||
p.AddBytes(buf)
|
p.AddBytes(buf)
|
||||||
case kindBool:
|
case kindBool:
|
||||||
|
@ -192,7 +193,7 @@ func (ft fieldType) pack(p *packer, rv reflect.Value) {
|
||||||
case kindInt:
|
case kindInt:
|
||||||
v := rv.Int()
|
v := rv.Int()
|
||||||
if v < math.MinInt32 || v > math.MaxInt32 {
|
if v < math.MinInt32 || v > math.MaxInt32 {
|
||||||
p.Errorf("%w: int %d does not fit in int32", ErrParam, v)
|
p.errorf("%w: int %d does not fit in int32", ErrParam, v)
|
||||||
}
|
}
|
||||||
p.Varint(v)
|
p.Varint(v)
|
||||||
case kindInt8, kindInt16, kindInt32, kindInt64:
|
case kindInt8, kindInt16, kindInt32, kindInt64:
|
||||||
|
@ -202,7 +203,7 @@ func (ft fieldType) pack(p *packer, rv reflect.Value) {
|
||||||
case kindUint:
|
case kindUint:
|
||||||
v := rv.Uint()
|
v := rv.Uint()
|
||||||
if v > math.MaxUint32 {
|
if v > math.MaxUint32 {
|
||||||
p.Errorf("%w: uint %d does not fit in uint32", ErrParam, v)
|
p.errorf("%w: uint %d does not fit in uint32", ErrParam, v)
|
||||||
}
|
}
|
||||||
p.Uvarint(v)
|
p.Uvarint(v)
|
||||||
case kindFloat32:
|
case kindFloat32:
|
||||||
|
@ -214,7 +215,7 @@ func (ft fieldType) pack(p *packer, rv reflect.Value) {
|
||||||
case kindTime:
|
case kindTime:
|
||||||
buf, err := rv.Interface().(time.Time).MarshalBinary()
|
buf, err := rv.Interface().(time.Time).MarshalBinary()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.Errorf("%w: pack time: %s", ErrParam, err)
|
p.errorf("%w: pack time: %s", ErrParam, err)
|
||||||
}
|
}
|
||||||
p.AddBytes(buf)
|
p.AddBytes(buf)
|
||||||
case kindSlice:
|
case kindSlice:
|
||||||
|
@ -223,15 +224,32 @@ func (ft fieldType) pack(p *packer, rv reflect.Value) {
|
||||||
p.PushFieldmap(n)
|
p.PushFieldmap(n)
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
nrv := rv.Index(i)
|
nrv := rv.Index(i)
|
||||||
if ft.List.isZero(nrv) {
|
if ft.ListElem.isZero(nrv) {
|
||||||
p.Field(false)
|
p.Field(false)
|
||||||
// Pretend to pack to get the nonzero checks of the element.
|
// Pretend to pack to get the nonzero checks of the element.
|
||||||
if nrv.IsValid() && (nrv.Kind() != reflect.Ptr || !nrv.IsNil()) {
|
if nrv.IsValid() && (nrv.Kind() != reflect.Ptr || !nrv.IsZero()) {
|
||||||
ft.List.pack(&packer{b: &bytes.Buffer{}}, nrv)
|
ft.ListElem.pack(&packer{b: &bytes.Buffer{}}, nrv)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
p.Field(true)
|
p.Field(true)
|
||||||
ft.List.pack(p, nrv)
|
ft.ListElem.pack(p, nrv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
p.PopFieldmap()
|
||||||
|
case kindArray:
|
||||||
|
n := ft.ArrayLength
|
||||||
|
p.PushFieldmap(n)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
nrv := rv.Index(i)
|
||||||
|
if ft.ListElem.isZero(nrv) {
|
||||||
|
p.Field(false)
|
||||||
|
// Pretend to pack to get the nonzero checks of the element.
|
||||||
|
if nrv.IsValid() && (nrv.Kind() != reflect.Ptr || !nrv.IsZero()) {
|
||||||
|
ft.ListElem.pack(&packer{b: &bytes.Buffer{}}, nrv)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
p.Field(true)
|
||||||
|
ft.ListElem.pack(p, nrv)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
p.PopFieldmap()
|
p.PopFieldmap()
|
||||||
|
@ -249,7 +267,7 @@ func (ft fieldType) pack(p *packer, rv reflect.Value) {
|
||||||
if ft.MapValue.isZero(v) {
|
if ft.MapValue.isZero(v) {
|
||||||
p.Field(false)
|
p.Field(false)
|
||||||
// Pretend to pack to get the nonzero checks of the key type.
|
// Pretend to pack to get the nonzero checks of the key type.
|
||||||
if v.IsValid() && (v.Kind() != reflect.Ptr || !v.IsNil()) {
|
if v.IsValid() && (v.Kind() != reflect.Ptr || !v.IsZero()) {
|
||||||
ft.MapValue.pack(&packer{b: &bytes.Buffer{}}, v)
|
ft.MapValue.pack(&packer{b: &bytes.Buffer{}}, v)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -259,16 +277,16 @@ func (ft fieldType) pack(p *packer, rv reflect.Value) {
|
||||||
}
|
}
|
||||||
p.PopFieldmap()
|
p.PopFieldmap()
|
||||||
case kindStruct:
|
case kindStruct:
|
||||||
p.PushFieldmap(len(ft.Fields))
|
p.PushFieldmap(len(ft.structFields))
|
||||||
for _, f := range ft.Fields {
|
for _, f := range ft.structFields {
|
||||||
nrv := rv.FieldByIndex(f.structField.Index)
|
nrv := rv.FieldByIndex(f.structField.Index)
|
||||||
if f.Type.isZero(nrv) {
|
if f.Type.isZero(nrv) {
|
||||||
if f.Nonzero {
|
if f.Nonzero {
|
||||||
p.Errorf("%w: %q", ErrZero, f.Name)
|
p.errorf("%w: %q", ErrZero, f.Name)
|
||||||
}
|
}
|
||||||
p.Field(false)
|
p.Field(false)
|
||||||
// Pretend to pack to get the nonzero checks.
|
// Pretend to pack to get the nonzero checks.
|
||||||
if nrv.IsValid() && (nrv.Kind() != reflect.Ptr || !nrv.IsNil()) {
|
if nrv.IsValid() && (nrv.Kind() != reflect.Ptr || !nrv.IsZero()) {
|
||||||
f.Type.pack(&packer{b: &bytes.Buffer{}}, nrv)
|
f.Type.pack(&packer{b: &bytes.Buffer{}}, nrv)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -278,6 +296,6 @@ func (ft fieldType) pack(p *packer, rv reflect.Value) {
|
||||||
}
|
}
|
||||||
p.PopFieldmap()
|
p.PopFieldmap()
|
||||||
default:
|
default:
|
||||||
p.Errorf("internal error: unhandled field type") // should be prevented when registering type
|
p.errorf("internal error: unhandled field type") // should be prevented when registering type
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
37
vendor/github.com/mjl-/bstore/parse.go
generated
vendored
37
vendor/github.com/mjl-/bstore/parse.go
generated
vendored
|
@ -119,7 +119,7 @@ func (st storeType) parse(rv reflect.Value, buf []byte) (rerr error) {
|
||||||
tv.parse(p, rv)
|
tv.parse(p, rv)
|
||||||
|
|
||||||
if len(p.buf) != 0 {
|
if len(p.buf) != 0 {
|
||||||
return fmt.Errorf("%w: leftover data after parsing", ErrStore)
|
return fmt.Errorf("%w: leftover data after parsing (%d, %x %q)", ErrStore, len(p.buf), p.buf, p.buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -173,7 +173,8 @@ func (tv typeVersion) parse(p *parser, rv reflect.Value) {
|
||||||
|
|
||||||
// parse a nonzero fieldType.
|
// parse a nonzero fieldType.
|
||||||
func (ft fieldType) parse(p *parser, rv reflect.Value) {
|
func (ft fieldType) parse(p *parser, rv reflect.Value) {
|
||||||
// Because we allow schema changes from ptr to nonptr, rv can be a pointer or direct value regardless of ft.Ptr.
|
// Because we allow schema changes from ptr to nonptr, rv can be a
|
||||||
|
// pointer or direct value regardless of ft.Ptr.
|
||||||
if rv.Kind() == reflect.Ptr {
|
if rv.Kind() == reflect.Ptr {
|
||||||
nrv := reflect.New(rv.Type().Elem())
|
nrv := reflect.New(rv.Type().Elem())
|
||||||
rv.Set(nrv)
|
rv.Set(nrv)
|
||||||
|
@ -239,10 +240,18 @@ func (ft fieldType) parse(p *parser, rv reflect.Value) {
|
||||||
slc := reflect.MakeSlice(rv.Type(), n, n)
|
slc := reflect.MakeSlice(rv.Type(), n, n)
|
||||||
for i := 0; i < int(n); i++ {
|
for i := 0; i < int(n); i++ {
|
||||||
if fm.Nonzero(i) {
|
if fm.Nonzero(i) {
|
||||||
ft.List.parse(p, slc.Index(i))
|
ft.ListElem.parse(p, slc.Index(i))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rv.Set(slc)
|
rv.Set(slc)
|
||||||
|
case kindArray:
|
||||||
|
n := ft.ArrayLength
|
||||||
|
fm := p.Fieldmap(n)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
if fm.Nonzero(i) {
|
||||||
|
ft.ListElem.parse(p, rv.Index(i))
|
||||||
|
}
|
||||||
|
}
|
||||||
case kindMap:
|
case kindMap:
|
||||||
un := p.Uvarint()
|
un := p.Uvarint()
|
||||||
n := p.checkInt(un)
|
n := p.checkInt(un)
|
||||||
|
@ -259,11 +268,13 @@ func (ft fieldType) parse(p *parser, rv reflect.Value) {
|
||||||
}
|
}
|
||||||
rv.Set(mp)
|
rv.Set(mp)
|
||||||
case kindStruct:
|
case kindStruct:
|
||||||
fm := p.Fieldmap(len(ft.Fields))
|
fm := p.Fieldmap(len(ft.structFields))
|
||||||
strct := reflect.New(rv.Type()).Elem()
|
strct := reflect.New(rv.Type()).Elem()
|
||||||
for i, f := range ft.Fields {
|
for i, f := range ft.structFields {
|
||||||
if f.structField.Type == nil {
|
if f.structField.Type == nil {
|
||||||
f.Type.skip(p)
|
if fm.Nonzero(i) {
|
||||||
|
f.Type.skip(p)
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if fm.Nonzero(i) {
|
if fm.Nonzero(i) {
|
||||||
|
@ -303,7 +314,15 @@ func (ft fieldType) skip(p *parser) {
|
||||||
fm := p.Fieldmap(n)
|
fm := p.Fieldmap(n)
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
if fm.Nonzero(i) {
|
if fm.Nonzero(i) {
|
||||||
ft.List.skip(p)
|
ft.ListElem.skip(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case kindArray:
|
||||||
|
n := ft.ArrayLength
|
||||||
|
fm := p.Fieldmap(n)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
if fm.Nonzero(i) {
|
||||||
|
ft.ListElem.skip(p)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case kindMap:
|
case kindMap:
|
||||||
|
@ -317,8 +336,8 @@ func (ft fieldType) skip(p *parser) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case kindStruct:
|
case kindStruct:
|
||||||
fm := p.Fieldmap(len(ft.Fields))
|
fm := p.Fieldmap(len(ft.structFields))
|
||||||
for i, f := range ft.Fields {
|
for i, f := range ft.structFields {
|
||||||
if fm.Nonzero(i) {
|
if fm.Nonzero(i) {
|
||||||
f.Type.skip(p)
|
f.Type.skip(p)
|
||||||
}
|
}
|
||||||
|
|
87
vendor/github.com/mjl-/bstore/plan.go
generated
vendored
87
vendor/github.com/mjl-/bstore/plan.go
generated
vendored
|
@ -7,6 +7,12 @@ import (
|
||||||
"sort"
|
"sort"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// todo: cache query plans? perhaps explicitly through something like a prepared statement. the current plan includes values in keys,start,stop, which would need to be calculated for each execution. should benchmark time spent in planning first.
|
||||||
|
// todo optimize: handle multiple sorts with multikey indices if they match
|
||||||
|
// todo optimize: combine multiple filter (not)in/equals calls for same field
|
||||||
|
// todo optimize: efficiently pack booleans in an index (eg for Message.Flags), and use it to query.
|
||||||
|
// todo optimize: do multiple range scans if necessary when we can use an index for an equal check with multiple values.
|
||||||
|
|
||||||
// Plan represents a plan to execute a query, possibly using a simple/quick
|
// Plan represents a plan to execute a query, possibly using a simple/quick
|
||||||
// bucket "get" or cursor scan (forward/backward) on either the records or an
|
// bucket "get" or cursor scan (forward/backward) on either the records or an
|
||||||
// index.
|
// index.
|
||||||
|
@ -31,9 +37,9 @@ type plan[T any] struct {
|
||||||
startInclusive bool // If the start and stop values are inclusive or exclusive.
|
startInclusive bool // If the start and stop values are inclusive or exclusive.
|
||||||
stopInclusive bool
|
stopInclusive bool
|
||||||
|
|
||||||
// Filter we need to apply on after retrieving the record. If all
|
// Filter we need to apply after retrieving the record. If all original filters
|
||||||
// original filters from a query were handled by "keys" above, or by a
|
// from a query were handled by "keys" above, or by a range scan, this field is
|
||||||
// range scan, this field is empty.
|
// empty.
|
||||||
filters []filter[T]
|
filters []filter[T]
|
||||||
|
|
||||||
// Orders we need to apply after first retrieving all records. As with
|
// Orders we need to apply after first retrieving all records. As with
|
||||||
|
@ -73,8 +79,7 @@ func (q *Query[T]) selectPlan() (*plan[T], error) {
|
||||||
// filter per field. If there are multiple, we would use the last one.
|
// filter per field. If there are multiple, we would use the last one.
|
||||||
// That's okay, we'll filter records out when we execute the leftover
|
// That's okay, we'll filter records out when we execute the leftover
|
||||||
// filters. Probably not common.
|
// filters. Probably not common.
|
||||||
// This is common for filterEqual and filterIn on
|
// This is common for filterEqual and filterIn on fields that have a unique index.
|
||||||
// fields that have a unique index.
|
|
||||||
equalsIn := map[string]*filter[T]{}
|
equalsIn := map[string]*filter[T]{}
|
||||||
for i := range q.xfilters {
|
for i := range q.xfilters {
|
||||||
ff := &q.xfilters[i]
|
ff := &q.xfilters[i]
|
||||||
|
@ -98,8 +103,8 @@ indices:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Calculate all keys that we need to retrieve from the index.
|
// Calculate all keys that we need to retrieve from the index.
|
||||||
// todo optimization: if there is a sort involving these fields, we could do the sorting before fetching data.
|
// todo optimize: if there is a sort involving these fields, we could do the sorting before fetching data.
|
||||||
// todo optimization: we can generate the keys on demand, will help when limit is in use: we are not generating all keys.
|
// todo optimize: we can generate the keys on demand, will help when limit is in use: we are not generating all keys.
|
||||||
var keys [][]byte
|
var keys [][]byte
|
||||||
var skipFilters []*filter[T] // Filters to remove from the full list because they are handled by quering the index.
|
var skipFilters []*filter[T] // Filters to remove from the full list because they are handled by quering the index.
|
||||||
for i, f := range idx.Fields {
|
for i, f := range idx.Fields {
|
||||||
|
@ -116,12 +121,15 @@ indices:
|
||||||
}
|
}
|
||||||
fekeys := make([][]byte, len(rvalues))
|
fekeys := make([][]byte, len(rvalues))
|
||||||
for j, fv := range rvalues {
|
for j, fv := range rvalues {
|
||||||
key, _, err := packIndexKeys([]reflect.Value{fv}, nil)
|
ikl, err := packIndexKeys([]reflect.Value{fv}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
q.error(err)
|
q.error(err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
fekeys[j] = key
|
if len(ikl) != 1 {
|
||||||
|
return nil, fmt.Errorf("internal error: multiple index keys for unique index (%d)", len(ikl))
|
||||||
|
}
|
||||||
|
fekeys[j] = ikl[0].pre
|
||||||
}
|
}
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
keys = fekeys
|
keys = fekeys
|
||||||
|
@ -148,22 +156,26 @@ indices:
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try all other indices. We treat them all as non-unique indices now.
|
// Try all other indices. We treat them all as non-unique indices now.
|
||||||
// We want to use the one with as many "equal" prefix fields as
|
// We want to use the one with as many "equal" or "inslice" field filters as
|
||||||
// possible. Then we hope to use a scan on the remaining, either
|
// possible. Then we hope to use a scan on the remaining, either because of a
|
||||||
// because of a filterCompare, or for an ordering. If there is a limit,
|
// filterCompare, or for an ordering. If there is a limit, orderings are preferred
|
||||||
// orderings are preferred over compares.
|
// over compares.
|
||||||
equals := map[string]*filter[T]{}
|
equals := map[string]*filter[T]{}
|
||||||
|
inslices := map[string]*filter[T]{}
|
||||||
for i := range q.xfilters {
|
for i := range q.xfilters {
|
||||||
ff := &q.xfilters[i]
|
ff := &q.xfilters[i]
|
||||||
switch f := (*ff).(type) {
|
switch f := (*ff).(type) {
|
||||||
case filterEqual[T]:
|
case filterEqual[T]:
|
||||||
equals[f.field.Name] = ff
|
equals[f.field.Name] = ff
|
||||||
|
case filterInSlice[T]:
|
||||||
|
inslices[f.field.Name] = ff
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// We are going to generate new plans, and keep the new one if it is better than what we have.
|
// We are going to generate new plans, and keep the new one if it is better than
|
||||||
|
// what we have so far.
|
||||||
var p *plan[T]
|
var p *plan[T]
|
||||||
var nequals int
|
var nexact int
|
||||||
var nrange int
|
var nrange int
|
||||||
var ordered bool
|
var ordered bool
|
||||||
|
|
||||||
|
@ -181,18 +193,27 @@ indices:
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
packKeys = func(l []reflect.Value) ([]byte, error) {
|
packKeys = func(l []reflect.Value) ([]byte, error) {
|
||||||
key, _, err := packIndexKeys(l, nil)
|
ikl, err := packIndexKeys(l, nil)
|
||||||
return key, err
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err == nil && len(ikl) != 1 {
|
||||||
|
return nil, fmt.Errorf("internal error: multiple index keys for exact filters, %v", ikl)
|
||||||
|
}
|
||||||
|
return ikl[0].pre, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var neq = 0
|
var nex = 0
|
||||||
// log.Printf("idx %v", idx)
|
// log.Printf("idx %v", idx)
|
||||||
var skipFilters []*filter[T]
|
var skipFilters []*filter[T]
|
||||||
for _, f := range idx.Fields {
|
for _, f := range idx.Fields {
|
||||||
if ff, ok := equals[f.Name]; ok {
|
if equals[f.Name] != nil && f.Type.Kind != kindSlice {
|
||||||
skipFilters = append(skipFilters, ff)
|
skipFilters = append(skipFilters, equals[f.Name])
|
||||||
neq++
|
nex++
|
||||||
|
} else if inslices[f.Name] != nil && f.Type.Kind == kindSlice {
|
||||||
|
skipFilters = append(skipFilters, inslices[f.Name])
|
||||||
|
nex++
|
||||||
} else {
|
} else {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -203,8 +224,8 @@ indices:
|
||||||
var nrng int
|
var nrng int
|
||||||
var order *order
|
var order *order
|
||||||
orders := q.xorders
|
orders := q.xorders
|
||||||
if neq < len(idx.Fields) {
|
if nex < len(idx.Fields) {
|
||||||
nf := idx.Fields[neq]
|
nf := idx.Fields[nex]
|
||||||
for i := range q.xfilters {
|
for i := range q.xfilters {
|
||||||
ff := &q.xfilters[i]
|
ff := &q.xfilters[i]
|
||||||
switch f := (*ff).(type) {
|
switch f := (*ff).(type) {
|
||||||
|
@ -230,7 +251,7 @@ indices:
|
||||||
}
|
}
|
||||||
|
|
||||||
// See if it can be used for ordering.
|
// See if it can be used for ordering.
|
||||||
// todo optimization: we could use multiple orders
|
// todo optimize: we could use multiple orders
|
||||||
if len(orders) > 0 && orders[0].field.Name == nf.Name {
|
if len(orders) > 0 && orders[0].field.Name == nf.Name {
|
||||||
order = &orders[0]
|
order = &orders[0]
|
||||||
orders = orders[1:]
|
orders = orders[1:]
|
||||||
|
@ -238,23 +259,29 @@ indices:
|
||||||
}
|
}
|
||||||
|
|
||||||
// See if this is better than what we had.
|
// See if this is better than what we had.
|
||||||
if !(neq > nequals || (neq == nequals && (nrng > nrange || order != nil && !ordered && (q.xlimit > 0 || nrng == nrange)))) {
|
if !(nex > nexact || (nex == nexact && (nrng > nrange || order != nil && !ordered && (q.xlimit > 0 || nrng == nrange)))) {
|
||||||
// log.Printf("plan not better, neq %d, nrng %d, limit %d, order %v ordered %v", neq, nrng, q.limit, order, ordered)
|
// log.Printf("plan not better, nex %d, nrng %d, limit %d, order %v ordered %v", nex, nrng, q.limit, order, ordered)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
nequals = neq
|
nexact = nex
|
||||||
nrange = nrng
|
nrange = nrng
|
||||||
ordered = order != nil
|
ordered = order != nil
|
||||||
|
|
||||||
// Calculate the prefix key.
|
// Calculate the prefix key.
|
||||||
var kvalues []reflect.Value
|
var kvalues []reflect.Value
|
||||||
for i := 0; i < neq; i++ {
|
for i := 0; i < nex; i++ {
|
||||||
f := idx.Fields[i]
|
f := idx.Fields[i]
|
||||||
kvalues = append(kvalues, (*equals[f.Name]).(filterEqual[T]).rvalue)
|
var v reflect.Value
|
||||||
|
if f.Type.Kind != kindSlice {
|
||||||
|
v = (*equals[f.Name]).(filterEqual[T]).rvalue
|
||||||
|
} else {
|
||||||
|
v = (*inslices[f.Name]).(filterInSlice[T]).rvalue
|
||||||
|
}
|
||||||
|
kvalues = append(kvalues, v)
|
||||||
}
|
}
|
||||||
var key []byte
|
var key []byte
|
||||||
var err error
|
var err error
|
||||||
if neq > 0 {
|
if nex > 0 {
|
||||||
key, err = packKeys(kvalues)
|
key, err = packKeys(kvalues)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
130
vendor/github.com/mjl-/bstore/query.go
generated
vendored
130
vendor/github.com/mjl-/bstore/query.go
generated
vendored
|
@ -1,6 +1,8 @@
|
||||||
package bstore
|
package bstore
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
@ -23,12 +25,14 @@ import (
|
||||||
//
|
//
|
||||||
// A Query is not safe for concurrent use.
|
// A Query is not safe for concurrent use.
|
||||||
type Query[T any] struct {
|
type Query[T any] struct {
|
||||||
st storeType // Of T.
|
ctx context.Context
|
||||||
pkType reflect.Type // Shortcut for st.Current.Fields[0].
|
ctxDone <-chan struct{} // ctx.Done(), kept here for fast access.
|
||||||
xtx *Tx // If nil, a new transaction is automatically created from db. Using a tx goes through tx() one exists.
|
st storeType // Of T.
|
||||||
xdb *DB // If not nil, xtx was created to execute the operation and is when the operation finishes (also on error).
|
pkType reflect.Type // Shortcut for st.Current.Fields[0].
|
||||||
err error // If set, returned by operations. For indicating failed filters, or that an operation has finished.
|
xtx *Tx // If nil, a new transaction is automatically created from db. Using a tx goes through tx() one exists.
|
||||||
xfilterIDs *filterIDs[T] // Kept separately from filters because these filters make us use the PK without further index planning.
|
xdb *DB // If not nil, xtx was created to execute the operation and is when the operation finishes (also on error).
|
||||||
|
err error // If set, returned by operations. For indicating failed filters, or that an operation has finished.
|
||||||
|
xfilterIDs *filterIDs[T] // Kept separately from filters because these filters make us use the PK without further index planning.
|
||||||
xfilters []filter[T]
|
xfilters []filter[T]
|
||||||
xorders []order
|
xorders []order
|
||||||
|
|
||||||
|
@ -99,6 +103,14 @@ type filterNotIn[T any] struct {
|
||||||
|
|
||||||
func (filterNotIn[T]) filter() {}
|
func (filterNotIn[T]) filter() {}
|
||||||
|
|
||||||
|
// For matching one of the values in a field that is a slice of the same type.
|
||||||
|
type filterInSlice[T any] struct {
|
||||||
|
field field // Of field type, a slice.
|
||||||
|
rvalue reflect.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (filterInSlice[T]) filter() {}
|
||||||
|
|
||||||
type compareOp byte
|
type compareOp byte
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -158,23 +170,37 @@ func (p *pair[T]) Value(e *exec[T]) (T, error) {
|
||||||
// QueryDB returns a new Query for type T. When an operation on the query is
|
// QueryDB returns a new Query for type T. When an operation on the query is
|
||||||
// executed, a read-only/writable transaction is created as appropriate for the
|
// executed, a read-only/writable transaction is created as appropriate for the
|
||||||
// operation.
|
// operation.
|
||||||
func QueryDB[T any](db *DB) *Query[T] {
|
func QueryDB[T any](ctx context.Context, db *DB) *Query[T] {
|
||||||
// We lock db for storeTypes. We keep it locked until Query is done.
|
// We lock db for storeTypes. We keep it locked until Query is done.
|
||||||
db.typesMutex.RLock()
|
db.typesMutex.RLock()
|
||||||
q := &Query[T]{xdb: db}
|
q := &Query[T]{xdb: db}
|
||||||
q.init(db)
|
q.init(ctx, db)
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query returns a new Query that operates on type T using transaction tx.
|
// QueryTx returns a new Query that operates on type T using transaction tx.
|
||||||
|
// The context of the transaction is used for the query.
|
||||||
func QueryTx[T any](tx *Tx) *Query[T] {
|
func QueryTx[T any](tx *Tx) *Query[T] {
|
||||||
// note: Since we are in a transaction, we already hold an rlock on the
|
// note: Since we are in a transaction, we already hold an rlock on the
|
||||||
// db types.
|
// db types.
|
||||||
q := &Query[T]{xtx: tx}
|
q := &Query[T]{xtx: tx}
|
||||||
q.init(tx.db)
|
if tx.err != nil {
|
||||||
|
q.err = tx.err
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
q.init(tx.ctx, tx.db)
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (q *Query[T]) ctxErr() error {
|
||||||
|
select {
|
||||||
|
case <-q.ctxDone:
|
||||||
|
return q.ctx.Err()
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Stats returns the current statistics for this query. When a query finishes,
|
// Stats returns the current statistics for this query. When a query finishes,
|
||||||
// its stats are added to those of its transaction. When a transaction
|
// its stats are added to those of its transaction. When a transaction
|
||||||
// finishes, its stats are added to those of its database.
|
// finishes, its stats are added to those of its database.
|
||||||
|
@ -182,7 +208,7 @@ func (q *Query[T]) Stats() Stats {
|
||||||
return q.stats
|
return q.stats
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Query[T]) init(db *DB) {
|
func (q *Query[T]) init(ctx context.Context, db *DB) {
|
||||||
var v T
|
var v T
|
||||||
t := reflect.TypeOf(v)
|
t := reflect.TypeOf(v)
|
||||||
if t.Kind() != reflect.Struct {
|
if t.Kind() != reflect.Struct {
|
||||||
|
@ -194,6 +220,11 @@ func (q *Query[T]) init(db *DB) {
|
||||||
q.stats.LastType = q.st.Name
|
q.stats.LastType = q.st.Name
|
||||||
q.pkType = q.st.Current.Fields[0].structField.Type
|
q.pkType = q.st.Current.Fields[0].structField.Type
|
||||||
}
|
}
|
||||||
|
q.ctx = ctx
|
||||||
|
q.ctxDone = ctx.Done()
|
||||||
|
if err := q.ctxErr(); q.err == nil && err != nil {
|
||||||
|
q.err = err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Query[T]) tx(write bool) (*Tx, error) {
|
func (q *Query[T]) tx(write bool) (*Tx, error) {
|
||||||
|
@ -207,7 +238,7 @@ func (q *Query[T]) tx(write bool) (*Tx, error) {
|
||||||
q.error(err)
|
q.error(err)
|
||||||
return nil, q.err
|
return nil, q.err
|
||||||
}
|
}
|
||||||
q.xtx = &Tx{db: q.xdb, btx: tx}
|
q.xtx = &Tx{ctx: q.ctx, db: q.xdb, btx: tx}
|
||||||
if write {
|
if write {
|
||||||
q.stats.Writes++
|
q.stats.Writes++
|
||||||
} else {
|
} else {
|
||||||
|
@ -308,6 +339,11 @@ func (q *Query[T]) checkErr() bool {
|
||||||
// Probably the result of using a Query zero value.
|
// Probably the result of using a Query zero value.
|
||||||
q.errorf("%w: invalid query, use QueryDB or QueryTx to make a query", ErrParam)
|
q.errorf("%w: invalid query, use QueryDB or QueryTx to make a query", ErrParam)
|
||||||
}
|
}
|
||||||
|
if q.err == nil {
|
||||||
|
if err := q.ctxErr(); err != nil {
|
||||||
|
q.err = err
|
||||||
|
}
|
||||||
|
}
|
||||||
return q.err == nil
|
return q.err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -365,7 +401,10 @@ func (q *Query[T]) foreachKey(write, value bool, fn func(bk []byte, v T) error)
|
||||||
return nil
|
return nil
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if err := fn(bk, v); err != nil {
|
} else if err := fn(bk, v); err == StopForEach {
|
||||||
|
q.error(ErrFinished)
|
||||||
|
return nil
|
||||||
|
} else if err != nil {
|
||||||
q.error(err)
|
q.error(err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -436,14 +475,14 @@ var convertFieldKinds = map[convertKinds]struct{}{
|
||||||
// Check type of value for field and return a reflect value that can directly be set on the field.
|
// Check type of value for field and return a reflect value that can directly be set on the field.
|
||||||
// If the field is a pointer, we allow non-pointers and convert them.
|
// If the field is a pointer, we allow non-pointers and convert them.
|
||||||
// We require value to be of a type that can be converted without loss of precision to the type of field.
|
// We require value to be of a type that can be converted without loss of precision to the type of field.
|
||||||
func (q *Query[T]) prepareValue(fname string, ft fieldType, sf reflect.StructField, rv reflect.Value) (reflect.Value, bool) {
|
func (q *Query[T]) prepareValue(fname string, ft fieldType, st reflect.Type, rv reflect.Value) (reflect.Value, bool) {
|
||||||
if !rv.IsValid() {
|
if !rv.IsValid() {
|
||||||
q.errorf("%w: invalid value", ErrParam)
|
q.errorf("%w: invalid value", ErrParam)
|
||||||
return rv, false
|
return rv, false
|
||||||
}
|
}
|
||||||
// Quick check first.
|
// Quick check first.
|
||||||
t := rv.Type()
|
t := rv.Type()
|
||||||
if t == sf.Type {
|
if t == st {
|
||||||
return rv, true
|
return rv, true
|
||||||
}
|
}
|
||||||
if !ft.Ptr && rv.Kind() == reflect.Ptr {
|
if !ft.Ptr && rv.Kind() == reflect.Ptr {
|
||||||
|
@ -461,14 +500,14 @@ func (q *Query[T]) prepareValue(fname string, ft fieldType, sf reflect.StructFie
|
||||||
return reflect.Value{}, false
|
return reflect.Value{}, false
|
||||||
}
|
}
|
||||||
if k != ft.Kind {
|
if k != ft.Kind {
|
||||||
dt := sf.Type
|
dt := st
|
||||||
if ft.Ptr {
|
if ft.Ptr {
|
||||||
dt = dt.Elem()
|
dt = dt.Elem()
|
||||||
}
|
}
|
||||||
rv = rv.Convert(dt)
|
rv = rv.Convert(dt)
|
||||||
}
|
}
|
||||||
if ft.Ptr && rv.Kind() != reflect.Ptr {
|
if ft.Ptr && rv.Kind() != reflect.Ptr {
|
||||||
nv := reflect.New(sf.Type.Elem())
|
nv := reflect.New(st.Elem())
|
||||||
nv.Elem().Set(rv)
|
nv.Elem().Set(rv)
|
||||||
rv = nv
|
rv = nv
|
||||||
}
|
}
|
||||||
|
@ -654,7 +693,7 @@ func (q *Query[T]) filterEqual(fieldName string, values []any, not bool) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(values) == 1 {
|
if len(values) == 1 {
|
||||||
rv, ok := q.prepareValue(ff.Name, ff.Type, ff.structField, reflect.ValueOf(values[0]))
|
rv, ok := q.prepareValue(ff.Name, ff.Type, ff.structField.Type, reflect.ValueOf(values[0]))
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -667,7 +706,7 @@ func (q *Query[T]) filterEqual(fieldName string, values []any, not bool) {
|
||||||
}
|
}
|
||||||
rvs := make([]reflect.Value, len(values))
|
rvs := make([]reflect.Value, len(values))
|
||||||
for i, value := range values {
|
for i, value := range values {
|
||||||
rv, ok := q.prepareValue(ff.Name, ff.Type, ff.structField, reflect.ValueOf(value))
|
rv, ok := q.prepareValue(ff.Name, ff.Type, ff.structField.Type, reflect.ValueOf(value))
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -680,6 +719,42 @@ func (q *Query[T]) filterEqual(fieldName string, values []any, not bool) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FilterIn selects records that have one of values of the string slice fieldName.
|
||||||
|
//
|
||||||
|
// If fieldName has an index, it is used to select rows.
|
||||||
|
//
|
||||||
|
// Note: Value must be a compatible type for comparison with the elements of
|
||||||
|
// fieldName. Go constant numbers become ints, which are not compatible with uint
|
||||||
|
// or float types.
|
||||||
|
func (q *Query[T]) FilterIn(fieldName string, value any) *Query[T] {
|
||||||
|
if !q.checkErr() {
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
ff, ok := q.lookupField(fieldName)
|
||||||
|
if !ok {
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
if ff.Type.Ptr {
|
||||||
|
q.errorf("%w: cannot compare pointer values", ErrParam)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
if ff.Type.Kind != kindSlice {
|
||||||
|
q.errorf("%w: field for FilterIn must be a slice", ErrParam)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
et := ff.Type.ListElem
|
||||||
|
if et.Ptr {
|
||||||
|
q.errorf("%w: cannot compare element pointer values", ErrParam)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
rv, ok := q.prepareValue(ff.Name, *et, ff.structField.Type.Elem(), reflect.ValueOf(value))
|
||||||
|
if !ok {
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
q.addFilter(filterInSlice[T]{ff, rv})
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
// FilterGreater selects records that have fieldName > value.
|
// FilterGreater selects records that have fieldName > value.
|
||||||
//
|
//
|
||||||
// Note: Value must be a compatible type for comparison with fieldName. Go
|
// Note: Value must be a compatible type for comparison with fieldName. Go
|
||||||
|
@ -716,7 +791,7 @@ func (q *Query[T]) filterCompare(fieldName string, op compareOp, value reflect.V
|
||||||
q.errorf("%w: cannot compare %s", ErrParam, ff.Type.Kind)
|
q.errorf("%w: cannot compare %s", ErrParam, ff.Type.Kind)
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
rv, ok := q.prepareValue(ff.Name, ff.Type, ff.structField, value)
|
rv, ok := q.prepareValue(ff.Name, ff.Type, ff.structField.Type, value)
|
||||||
if !ok {
|
if !ok {
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
@ -831,7 +906,8 @@ func (q *Query[T]) gather(v T, rv reflect.Value) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Err returns if an error is set on the query. Can happen for invalid filters.
|
// Err returns if an error is set on the query. Can happen for invalid filters or
|
||||||
|
// canceled contexts.
|
||||||
// Finished queries return ErrFinished.
|
// Finished queries return ErrFinished.
|
||||||
func (q *Query[T]) Err() error {
|
func (q *Query[T]) Err() error {
|
||||||
q.checkErr()
|
q.checkErr()
|
||||||
|
@ -979,7 +1055,7 @@ next:
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
return 0, fmt.Errorf("%w: cannot update primary key", ErrParam)
|
return 0, fmt.Errorf("%w: cannot update primary key", ErrParam)
|
||||||
}
|
}
|
||||||
rv, ok := q.prepareValue(f.Name, f.Type, f.structField, reflect.ValueOf(value))
|
rv, ok := q.prepareValue(f.Name, f.Type, f.structField.Type, reflect.ValueOf(value))
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, q.err
|
return 0, q.err
|
||||||
}
|
}
|
||||||
|
@ -991,7 +1067,7 @@ next:
|
||||||
if ef.Name != name {
|
if ef.Name != name {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
rv, ok := q.prepareValue(ef.Name, ef.Type, ef.structField, reflect.ValueOf(value))
|
rv, ok := q.prepareValue(ef.Name, ef.Type, ef.structField.Type, reflect.ValueOf(value))
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, q.err
|
return 0, q.err
|
||||||
}
|
}
|
||||||
|
@ -1051,6 +1127,8 @@ func (q *Query[T]) IDs(idsptr any) (rerr error) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// todo: should we have an iteration object that we can call Next and NextID on?
|
||||||
|
|
||||||
// Next fetches the next record, moving the cursor forward.
|
// Next fetches the next record, moving the cursor forward.
|
||||||
//
|
//
|
||||||
// ErrAbsent is returned if no more records match.
|
// ErrAbsent is returned if no more records match.
|
||||||
|
@ -1116,7 +1194,13 @@ func (q *Query[T]) Exists() (exists bool, rerr error) {
|
||||||
return err == nil, err
|
return err == nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StopForEach is an error value that, if returned by the function passed to
|
||||||
|
// Query.ForEach, stops further iterations.
|
||||||
|
var StopForEach error = errors.New("stop foreach")
|
||||||
|
|
||||||
// ForEach calls fn on each selected record.
|
// ForEach calls fn on each selected record.
|
||||||
|
// If fn returns StopForEach, ForEach stops iterating, so no longer calls fn,
|
||||||
|
// and returns nil.
|
||||||
func (q *Query[T]) ForEach(fn func(value T) error) (rerr error) {
|
func (q *Query[T]) ForEach(fn func(value T) error) (rerr error) {
|
||||||
defer q.finish(&rerr)
|
defer q.finish(&rerr)
|
||||||
q.checkNotNext()
|
q.checkNotNext()
|
||||||
|
|
356
vendor/github.com/mjl-/bstore/register.go
generated
vendored
356
vendor/github.com/mjl-/bstore/register.go
generated
vendored
|
@ -2,9 +2,12 @@ package bstore
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -14,10 +17,22 @@ import (
|
||||||
bolt "go.etcd.io/bbolt"
|
bolt "go.etcd.io/bbolt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// todo: implement changing PK type, eg to wider int. requires rewriting all values, and removing old typeVersions.
|
||||||
|
// todo: allow schema change between []byte and string?
|
||||||
|
// todo: allow more schema changes, eg int to string, bool to int or string, int to bool, perhaps even string to int/bool. and between structs and maps. would require rewriting the records.
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// First version.
|
||||||
ondiskVersion1 = 1
|
ondiskVersion1 = 1
|
||||||
|
|
||||||
|
// With support for cyclic types, adding typeField.FieldsTypeSeq to
|
||||||
|
// define/reference types. Only used when a type has a field that references another
|
||||||
|
// struct type.
|
||||||
|
ondiskVersion2 = 2
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var errSchemaCheck = errors.New("schema check")
|
||||||
|
|
||||||
// Register registers the Go types of each value in typeValues for use with the
|
// Register registers the Go types of each value in typeValues for use with the
|
||||||
// database. Each value must be a struct, not a pointer.
|
// database. Each value must be a struct, not a pointer.
|
||||||
//
|
//
|
||||||
|
@ -30,7 +45,11 @@ const (
|
||||||
//
|
//
|
||||||
// Register can be called multiple times, with different types. But types that
|
// Register can be called multiple times, with different types. But types that
|
||||||
// reference each other must be registered in the same call to Registers.
|
// reference each other must be registered in the same call to Registers.
|
||||||
func (db *DB) Register(typeValues ...any) error {
|
//
|
||||||
|
// To help during development, if environment variable "bstore_schema_check" is set
|
||||||
|
// to "changed", an error is returned if there is no schema change. If it is set to
|
||||||
|
// "unchanged", an error is returned if there was a schema change.
|
||||||
|
func (db *DB) Register(ctx context.Context, typeValues ...any) error {
|
||||||
// We will drop/create new indices as needed. For changed indices, we drop
|
// We will drop/create new indices as needed. For changed indices, we drop
|
||||||
// and recreate. E.g. if an index becomes a unique index, or if a field in
|
// and recreate. E.g. if an index becomes a unique index, or if a field in
|
||||||
// an index changes. These values map type and index name to their index.
|
// an index changes. These values map type and index name to their index.
|
||||||
|
@ -41,7 +60,10 @@ func (db *DB) Register(typeValues ...any) error {
|
||||||
ntypeversions := map[string]*typeVersion{} // New typeversions, through new types or updated versions of existing types.
|
ntypeversions := map[string]*typeVersion{} // New typeversions, through new types or updated versions of existing types.
|
||||||
registered := map[string]*storeType{} // Registered in this call.
|
registered := map[string]*storeType{} // Registered in this call.
|
||||||
|
|
||||||
return db.Write(func(tx *Tx) error {
|
checkSchemaChanged := os.Getenv("bstore_schema_check") == "changed"
|
||||||
|
checkSchemaUnchanged := os.Getenv("bstore_schema_check") == "unchanged"
|
||||||
|
|
||||||
|
return db.Write(ctx, func(tx *Tx) error {
|
||||||
for _, t := range typeValues {
|
for _, t := range typeValues {
|
||||||
rt := reflect.TypeOf(t)
|
rt := reflect.TypeOf(t)
|
||||||
if rt.Kind() != reflect.Struct {
|
if rt.Kind() != reflect.Struct {
|
||||||
|
@ -118,6 +140,11 @@ func (db *DB) Register(typeValues ...any) error {
|
||||||
|
|
||||||
// Decide if we need to add a new typeVersion to the database. I.e. a new type schema.
|
// Decide if we need to add a new typeVersion to the database. I.e. a new type schema.
|
||||||
if st.Current == nil || !st.Current.typeEqual(*tv) {
|
if st.Current == nil || !st.Current.typeEqual(*tv) {
|
||||||
|
if checkSchemaUnchanged {
|
||||||
|
return fmt.Errorf("%w: schema changed but bstore_schema_check=unchanged is set (type %v)", errSchemaCheck, st.Name)
|
||||||
|
}
|
||||||
|
checkSchemaChanged = false // After registering types, we check that it is false.
|
||||||
|
|
||||||
tv.Version = 1
|
tv.Version = 1
|
||||||
if st.Current != nil {
|
if st.Current != nil {
|
||||||
tv.Version = st.Current.Version + 1
|
tv.Version = st.Current.Version + 1
|
||||||
|
@ -127,6 +154,13 @@ func (db *DB) Register(typeValues ...any) error {
|
||||||
return fmt.Errorf("internal error: packing schema for type %q", tv.name)
|
return fmt.Errorf("internal error: packing schema for type %q", tv.name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sanity check: parse the typeVersion again, and check that we think it is equal to the typeVersion just written.
|
||||||
|
if xtv, err := parseSchema(k, v); err != nil {
|
||||||
|
return fmt.Errorf("%w: parsing generated typeVersion: %v", ErrStore, err)
|
||||||
|
} else if !xtv.typeEqual(*tv) {
|
||||||
|
return fmt.Errorf("%w: generated typeVersion not equal to itself after pack and parse", ErrStore)
|
||||||
|
}
|
||||||
|
|
||||||
// note: we don't track types bucket operations in stats.
|
// note: we don't track types bucket operations in stats.
|
||||||
if err := tb.Put(k, v); err != nil {
|
if err := tb.Put(k, v); err != nil {
|
||||||
return fmt.Errorf("storing new schema: %w", err)
|
return fmt.Errorf("storing new schema: %w", err)
|
||||||
|
@ -202,6 +236,10 @@ func (db *DB) Register(typeValues ...any) error {
|
||||||
registered[st.Name] = &st
|
registered[st.Name] = &st
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if checkSchemaChanged {
|
||||||
|
return fmt.Errorf("%w: schema did not change while bstore_schema_check=changed is set", errSchemaCheck)
|
||||||
|
}
|
||||||
|
|
||||||
// Check that referenced types exist, and make links in the referenced types.
|
// Check that referenced types exist, and make links in the referenced types.
|
||||||
for _, st := range registered {
|
for _, st := range registered {
|
||||||
tv := st.Current
|
tv := st.Current
|
||||||
|
@ -237,7 +275,7 @@ func (db *DB) Register(typeValues ...any) error {
|
||||||
// We cannot just recalculate the ReferencedBy, because the whole point is to
|
// We cannot just recalculate the ReferencedBy, because the whole point is to
|
||||||
// detect types that are missing in this Register.
|
// detect types that are missing in this Register.
|
||||||
updateReferencedBy := map[string]struct{}{}
|
updateReferencedBy := map[string]struct{}{}
|
||||||
for _, ntv := range ntypeversions {
|
for ntname, ntv := range ntypeversions {
|
||||||
otv := otypeversions[ntv.name] // Can be nil, on first register.
|
otv := otypeversions[ntv.name] // Can be nil, on first register.
|
||||||
|
|
||||||
// Look for references that were added.
|
// Look for references that were added.
|
||||||
|
@ -251,6 +289,66 @@ func (db *DB) Register(typeValues ...any) error {
|
||||||
if _, ok := registered[name].Current.ReferencedBy[ntv.name]; ok {
|
if _, ok := registered[name].Current.ReferencedBy[ntv.name]; ok {
|
||||||
return fmt.Errorf("%w: type %q introduces reference to %q but is already marked as ReferencedBy in that type", ErrStore, ntv.name, name)
|
return fmt.Errorf("%w: type %q introduces reference to %q but is already marked as ReferencedBy in that type", ErrStore, ntv.name, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Verify that the new reference does not violate the foreign key constraint.
|
||||||
|
var foundField bool
|
||||||
|
for _, f := range ntv.Fields {
|
||||||
|
for _, rname := range f.References {
|
||||||
|
if rname != name {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
foundField = true
|
||||||
|
|
||||||
|
// For newly added references, check they are valid.
|
||||||
|
b, err := tx.recordsBucket(ntname, ntv.fillPercent)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%w: bucket for type %s with field with new reference: %v", ErrStore, ntname, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rb, err := tx.recordsBucket(name, registered[name].Current.fillPercent)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%w: bucket for referenced type %s: %v", ErrStore, name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nst := registered[ntname]
|
||||||
|
rv := reflect.New(nst.Type).Elem()
|
||||||
|
ctxDone := ctx.Done()
|
||||||
|
err = b.ForEach(func(bk, bv []byte) error {
|
||||||
|
tx.stats.Records.Cursor++
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctxDone:
|
||||||
|
return tx.ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := nst.parse(rv, bv); err != nil {
|
||||||
|
return fmt.Errorf("parsing record for %s: %w", ntname, err)
|
||||||
|
}
|
||||||
|
frv := rv.FieldByIndex(f.structField.Index)
|
||||||
|
if frv.IsZero() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
rpk, err := packPK(frv)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("packing pk for referenced type %s: %w", name, err)
|
||||||
|
}
|
||||||
|
tx.stats.Records.Cursor++
|
||||||
|
if rb.Get(rpk) == nil {
|
||||||
|
return fmt.Errorf("%w: value %v not in %s", ErrReference, frv.Interface(), name)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%w: ensuring referential integrity for newly added reference of %s.%s", err, ntname, f.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundField {
|
||||||
|
return fmt.Errorf("%w: could not find field causing newly referenced type %s in type %s", ErrStore, name, ntname)
|
||||||
|
}
|
||||||
|
|
||||||
// note: we are updating the previous tv's ReferencedBy, not tidy but it is safe.
|
// note: we are updating the previous tv's ReferencedBy, not tidy but it is safe.
|
||||||
registered[name].Current.ReferencedBy[ntv.name] = struct{}{}
|
registered[name].Current.ReferencedBy[ntv.name] = struct{}{}
|
||||||
updateReferencedBy[name] = struct{}{}
|
updateReferencedBy[name] = struct{}{}
|
||||||
|
@ -271,8 +369,10 @@ func (db *DB) Register(typeValues ...any) error {
|
||||||
if _, ok := ntv.references[name]; ok {
|
if _, ok := ntv.references[name]; ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if _, ok := registered[name].Current.ReferencedBy[ntv.name]; !ok {
|
if rtv, ok := registered[name]; !ok {
|
||||||
return fmt.Errorf("%w: previously referenced type %q not present in %q", ErrStore, ntv.name, name)
|
return fmt.Errorf("%w: type %q formerly referenced by %q not yet registered", ErrStore, name, ntv.name)
|
||||||
|
} else if _, ok := rtv.Current.ReferencedBy[ntv.name]; !ok {
|
||||||
|
return fmt.Errorf("%w: formerly referenced type %q missing from %q", ErrStore, name, ntv.name)
|
||||||
}
|
}
|
||||||
// note: we are updating the previous tv's ReferencedBy, not tidy but it is safe.
|
// note: we are updating the previous tv's ReferencedBy, not tidy but it is safe.
|
||||||
delete(registered[name].Current.ReferencedBy, ntv.name)
|
delete(registered[name].Current.ReferencedBy, ntv.name)
|
||||||
|
@ -416,20 +516,29 @@ func (db *DB) Register(typeValues ...any) error {
|
||||||
}
|
}
|
||||||
ibkeys := make([][]key, len(idxs))
|
ibkeys := make([][]key, len(idxs))
|
||||||
|
|
||||||
|
ctxDone := ctx.Done()
|
||||||
err = rb.ForEach(func(bk, bv []byte) error {
|
err = rb.ForEach(func(bk, bv []byte) error {
|
||||||
tx.stats.Records.Cursor++
|
tx.stats.Records.Cursor++
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctxDone:
|
||||||
|
return tx.ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
rv := reflect.New(st.Type).Elem()
|
rv := reflect.New(st.Type).Elem()
|
||||||
if err := st.parse(rv, bv); err != nil {
|
if err := st.parse(rv, bv); err != nil {
|
||||||
return fmt.Errorf("parsing record for index for %s: %w", name, err)
|
return fmt.Errorf("parsing record for index for %s: %w", name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, idx := range idxs {
|
for i, idx := range idxs {
|
||||||
prek, ik, err := idx.packKey(rv, bk)
|
ikl, err := idx.packKey(rv, bk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("creating key for %s.%s: %w", name, idx.Name, err)
|
return fmt.Errorf("creating key for %s.%s: %w", name, idx.Name, err)
|
||||||
}
|
}
|
||||||
ibkeys[i] = append(ibkeys[i], key{ik, uint16(len(prek))})
|
for _, ik := range ikl {
|
||||||
|
ibkeys[i] = append(ibkeys[i], key{ik.full, uint16(len(ik.pre))})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
@ -447,14 +556,14 @@ func (db *DB) Register(typeValues ...any) error {
|
||||||
prev := keys[i-1]
|
prev := keys[i-1]
|
||||||
if bytes.Equal(prev.buf[:prev.pre], k.buf[:k.pre]) {
|
if bytes.Equal(prev.buf[:prev.pre], k.buf[:k.pre]) {
|
||||||
// Do quite a bit of work to make a helpful error message.
|
// Do quite a bit of work to make a helpful error message.
|
||||||
a := reflect.New(reflect.TypeOf(idx.tv.Fields[0].Type.zero(nil))).Elem()
|
a := reflect.New(reflect.TypeOf(idx.tv.Fields[0].Type.zeroKey())).Elem()
|
||||||
b := reflect.New(reflect.TypeOf(idx.tv.Fields[0].Type.zero(nil))).Elem()
|
b := reflect.New(reflect.TypeOf(idx.tv.Fields[0].Type.zeroKey())).Elem()
|
||||||
parsePK(a, prev.buf[prev.pre:]) // Ignore error, nothing to do.
|
parsePK(a, prev.buf[prev.pre:]) // Ignore error, nothing to do.
|
||||||
parsePK(b, k.buf[k.pre:]) // Ignore error, nothing to do.
|
parsePK(b, k.buf[k.pre:]) // Ignore error, nothing to do.
|
||||||
var dup []any
|
var dup []any
|
||||||
_, values, _ := idx.parseKey(k.buf, true)
|
_, values, _ := idx.parseKey(k.buf, true)
|
||||||
for i := range values {
|
for i := range values {
|
||||||
x := reflect.New(reflect.TypeOf(idx.Fields[i].Type.zero(nil))).Elem()
|
x := reflect.New(reflect.TypeOf(idx.Fields[i].Type.zeroKey())).Elem()
|
||||||
parsePK(x, values[i]) // Ignore error, nothing to do.
|
parsePK(x, values[i]) // Ignore error, nothing to do.
|
||||||
dup = append(dup, x.Interface())
|
dup = append(dup, x.Interface())
|
||||||
}
|
}
|
||||||
|
@ -502,8 +611,10 @@ func parseSchema(bk, bv []byte) (*typeVersion, error) {
|
||||||
if tv.Version != version {
|
if tv.Version != version {
|
||||||
return nil, fmt.Errorf("%w: version in schema %d does not match key %d", ErrStore, tv.Version, version)
|
return nil, fmt.Errorf("%w: version in schema %d does not match key %d", ErrStore, tv.Version, version)
|
||||||
}
|
}
|
||||||
if tv.OndiskVersion != ondiskVersion1 {
|
switch tv.OndiskVersion {
|
||||||
return nil, fmt.Errorf("internal error: OndiskVersion %d not supported", tv.OndiskVersion)
|
case ondiskVersion1, ondiskVersion2:
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("internal error: OndiskVersion %d not recognized/supported", tv.OndiskVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fill references, used for comparing/checking schema updates.
|
// Fill references, used for comparing/checking schema updates.
|
||||||
|
@ -514,12 +625,72 @@ func parseSchema(bk, bv []byte) (*typeVersion, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Resolve fieldType.structFields, for referencing defined types, used for
|
||||||
|
// supporting cyclic types. The type itself always implicitly has sequence 1.
|
||||||
|
seqFields := map[int][]field{1: tv.Fields}
|
||||||
|
origOndiskVersion := tv.OndiskVersion
|
||||||
|
for i := range tv.Fields {
|
||||||
|
if err := tv.resolveStructFields(seqFields, &tv.Fields[i].Type); err != nil {
|
||||||
|
return nil, fmt.Errorf("%w: resolving struct fields for referencing types: %v", ErrStore, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if tv.OndiskVersion != origOndiskVersion {
|
||||||
|
return nil, fmt.Errorf("%w: resolving cyclic types changed ondisk version from %d to %d", ErrStore, origOndiskVersion, tv.OndiskVersion)
|
||||||
|
}
|
||||||
|
|
||||||
return &tv, nil
|
return &tv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Resolve structFields in ft (and recursively), either by setting it to Fields
|
||||||
|
// (common), or by setting it to the fields of a referenced type in case of a
|
||||||
|
// cyclic data type.
|
||||||
|
func (tv *typeVersion) resolveStructFields(seqFields map[int][]field, ft *fieldType) error {
|
||||||
|
if ft.Kind == kindStruct {
|
||||||
|
if ft.FieldsTypeSeq < 0 {
|
||||||
|
var ok bool
|
||||||
|
ft.structFields, ok = seqFields[-ft.FieldsTypeSeq]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("reference to undefined FieldsTypeSeq %d (n %d)", -ft.FieldsTypeSeq, len(seqFields))
|
||||||
|
}
|
||||||
|
if len(ft.DefinitionFields) != 0 {
|
||||||
|
return fmt.Errorf("reference to FieldsTypeSeq while also defining fields")
|
||||||
|
}
|
||||||
|
} else if ft.FieldsTypeSeq > 0 {
|
||||||
|
if _, ok := seqFields[ft.FieldsTypeSeq]; ok {
|
||||||
|
return fmt.Errorf("duplicate definition of FieldsTypeSeq %d (n %d)", ft.FieldsTypeSeq, len(seqFields))
|
||||||
|
}
|
||||||
|
seqFields[ft.FieldsTypeSeq] = ft.DefinitionFields
|
||||||
|
ft.structFields = ft.DefinitionFields
|
||||||
|
}
|
||||||
|
// note: ondiskVersion1 does not have/use this field, so it defaults to 0.
|
||||||
|
if ft.FieldsTypeSeq == 0 {
|
||||||
|
ft.structFields = ft.DefinitionFields
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range ft.DefinitionFields {
|
||||||
|
if err := tv.resolveStructFields(seqFields, &ft.DefinitionFields[i].Type); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
xftl := []*fieldType{ft.MapKey, ft.MapValue, ft.ListElem}
|
||||||
|
for _, xft := range xftl {
|
||||||
|
if xft == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := tv.resolveStructFields(seqFields, xft); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// packSchema returns a key and value to store in the types bucket.
|
// packSchema returns a key and value to store in the types bucket.
|
||||||
func packSchema(tv *typeVersion) ([]byte, []byte, error) {
|
func packSchema(tv *typeVersion) ([]byte, []byte, error) {
|
||||||
if tv.OndiskVersion != ondiskVersion1 {
|
switch tv.OndiskVersion {
|
||||||
|
case ondiskVersion1, ondiskVersion2:
|
||||||
|
default:
|
||||||
return nil, nil, fmt.Errorf("internal error: invalid OndiskVersion %d", tv.OndiskVersion)
|
return nil, nil, fmt.Errorf("internal error: invalid OndiskVersion %d", tv.OndiskVersion)
|
||||||
}
|
}
|
||||||
v, err := json.Marshal(tv)
|
v, err := json.Marshal(tv)
|
||||||
|
@ -540,12 +711,17 @@ func gatherTypeVersion(t reflect.Type) (*typeVersion, error) {
|
||||||
}
|
}
|
||||||
tv := &typeVersion{
|
tv := &typeVersion{
|
||||||
Version: 0, // Set by caller.
|
Version: 0, // Set by caller.
|
||||||
OndiskVersion: ondiskVersion1, // Current on-disk format.
|
OndiskVersion: ondiskVersion2, // When opening a database with ondiskVersion1, we add a new typeVersion.
|
||||||
ReferencedBy: map[string]struct{}{},
|
ReferencedBy: map[string]struct{}{},
|
||||||
name: tname,
|
name: tname,
|
||||||
fillPercent: 0.5,
|
fillPercent: 0.5,
|
||||||
}
|
}
|
||||||
tv.Fields, tv.embedFields, err = gatherTypeFields(t, true, true, false)
|
|
||||||
|
// The type being parsed implicitly has sequence 1. Next struct types will be
|
||||||
|
// assigned the next value (based on length of typeseqs). FieldTypes referencing
|
||||||
|
// another type are resolved below, after all fields have been gathered.
|
||||||
|
typeSeqs := map[reflect.Type]int{t: 1}
|
||||||
|
tv.Fields, tv.embedFields, err = gatherTypeFields(typeSeqs, t, true, true, false, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -562,6 +738,15 @@ func gatherTypeVersion(t reflect.Type) (*typeVersion, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Resolve structFields for the typeFields that reference an earlier defined type,
|
||||||
|
// using the same function as used when loading a type from disk.
|
||||||
|
seqFields := map[int][]field{1: tv.Fields}
|
||||||
|
for i := range tv.Fields {
|
||||||
|
if err := tv.resolveStructFields(seqFields, &tv.Fields[i].Type); err != nil {
|
||||||
|
return nil, fmt.Errorf("%w: resolving struct fields for referencing types: %v", ErrStore, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Find indices.
|
// Find indices.
|
||||||
tv.Indices = map[string]*index{}
|
tv.Indices = map[string]*index{}
|
||||||
|
|
||||||
|
@ -572,6 +757,7 @@ func gatherTypeVersion(t reflect.Type) (*typeVersion, error) {
|
||||||
}
|
}
|
||||||
idx = &index{unique, iname, nil, tv}
|
idx = &index{unique, iname, nil, tv}
|
||||||
tv.Indices[iname] = idx
|
tv.Indices[iname] = idx
|
||||||
|
nslice := 0
|
||||||
for _, f := range fields {
|
for _, f := range fields {
|
||||||
// todo: can we have a unique index on bytes? seems like this should be possible to have max 1 []byte in an index key, only to be used for unique get plans.
|
// todo: can we have a unique index on bytes? seems like this should be possible to have max 1 []byte in an index key, only to be used for unique get plans.
|
||||||
if f.Type.Ptr {
|
if f.Type.Ptr {
|
||||||
|
@ -579,6 +765,14 @@ func gatherTypeVersion(t reflect.Type) (*typeVersion, error) {
|
||||||
}
|
}
|
||||||
switch f.Type.Kind {
|
switch f.Type.Kind {
|
||||||
case kindBool, kindInt8, kindInt16, kindInt32, kindInt64, kindInt, kindUint8, kindUint16, kindUint32, kindUint64, kindUint, kindString, kindTime:
|
case kindBool, kindInt8, kindInt16, kindInt32, kindInt64, kindInt, kindUint8, kindUint16, kindUint32, kindUint64, kindUint, kindString, kindTime:
|
||||||
|
case kindSlice:
|
||||||
|
nslice++
|
||||||
|
if nslice > 1 {
|
||||||
|
return fmt.Errorf("%w: can only have one slice field in index, for field %q", ErrType, f.Name)
|
||||||
|
}
|
||||||
|
if unique {
|
||||||
|
return fmt.Errorf("%w: can only use slice field %v in field %q as index without unique", ErrType, f.Type.Kind, f.Name)
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: cannot use type %v in field %q as index/unique", ErrType, f.Type.Kind, f.Name)
|
return fmt.Errorf("%w: cannot use type %v in field %q as index/unique", ErrType, f.Type.Kind, f.Name)
|
||||||
}
|
}
|
||||||
|
@ -692,7 +886,7 @@ func gatherTypeVersion(t reflect.Type) (*typeVersion, error) {
|
||||||
// field must not be ignored and be a valid primary key field (eg no pointer).
|
// field must not be ignored and be a valid primary key field (eg no pointer).
|
||||||
// topLevel must be true only for the top-level struct fields, not for fields of
|
// topLevel must be true only for the top-level struct fields, not for fields of
|
||||||
// deeper levels. Deeper levels cannot have index/unique constraints.
|
// deeper levels. Deeper levels cannot have index/unique constraints.
|
||||||
func gatherTypeFields(t reflect.Type, needFirst, topLevel, inMap bool) ([]field, []embed, error) {
|
func gatherTypeFields(typeSeqs map[reflect.Type]int, t reflect.Type, needFirst, topLevel, inMap, newSeq bool) ([]field, []embed, error) {
|
||||||
var fields []field
|
var fields []field
|
||||||
var embedFields []embed
|
var embedFields []embed
|
||||||
|
|
||||||
|
@ -744,7 +938,7 @@ func gatherTypeFields(t reflect.Type, needFirst, topLevel, inMap bool) ([]field,
|
||||||
}
|
}
|
||||||
names[name] = struct{}{}
|
names[name] = struct{}{}
|
||||||
|
|
||||||
ft, err := gatherFieldType(sf.Type, inMap)
|
ft, err := gatherFieldType(typeSeqs, sf.Type, inMap, newSeq && !sf.Anonymous)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("field %q: %w", sf.Name, err)
|
return nil, nil, fmt.Errorf("field %q: %w", sf.Name, err)
|
||||||
}
|
}
|
||||||
|
@ -817,11 +1011,13 @@ func gatherTypeFields(t reflect.Type, needFirst, topLevel, inMap bool) ([]field,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if sf.Anonymous {
|
// We don't store anonymous/embed fields, unless it is a cyclic type, because then
|
||||||
|
// we wouldn't have included any of its type's fields.
|
||||||
|
if sf.Anonymous && ft.FieldsTypeSeq == 0 {
|
||||||
e := embed{name, ft, sf}
|
e := embed{name, ft, sf}
|
||||||
embedFields = append(embedFields, e)
|
embedFields = append(embedFields, e)
|
||||||
} else {
|
} else {
|
||||||
f := field{name, ft, nonzero, tags.List("ref"), defstr, def, sf, nil}
|
f := field{name, ft, nonzero, tags.List("ref"), defstr, def, sf, false, nil}
|
||||||
fields = append(fields, f)
|
fields = append(fields, f)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -842,12 +1038,13 @@ func checkKeyType(t reflect.Type) error {
|
||||||
return fmt.Errorf("%w: type %v not valid for primary key", ErrType, t)
|
return fmt.Errorf("%w: type %v not valid for primary key", ErrType, t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func gatherFieldType(t reflect.Type, inMap bool) (fieldType, error) {
|
func gatherFieldType(typeSeqs map[reflect.Type]int, t reflect.Type, inMap, newSeq bool) (fieldType, error) {
|
||||||
ft := fieldType{}
|
ft := fieldType{}
|
||||||
if t.Kind() == reflect.Ptr {
|
if t.Kind() == reflect.Ptr {
|
||||||
t = t.Elem()
|
t = t.Elem()
|
||||||
ft.Ptr = true
|
ft.Ptr = true
|
||||||
}
|
}
|
||||||
|
|
||||||
k, err := typeKind(t)
|
k, err := typeKind(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fieldType{}, err
|
return fieldType{}, err
|
||||||
|
@ -855,32 +1052,52 @@ func gatherFieldType(t reflect.Type, inMap bool) (fieldType, error) {
|
||||||
ft.Kind = k
|
ft.Kind = k
|
||||||
switch ft.Kind {
|
switch ft.Kind {
|
||||||
case kindSlice:
|
case kindSlice:
|
||||||
l, err := gatherFieldType(t.Elem(), inMap)
|
l, err := gatherFieldType(typeSeqs, t.Elem(), inMap, newSeq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ft, fmt.Errorf("list: %w", err)
|
return ft, fmt.Errorf("slice: %w", err)
|
||||||
}
|
}
|
||||||
ft.List = &l
|
ft.ListElem = &l
|
||||||
|
case kindArray:
|
||||||
|
l, err := gatherFieldType(typeSeqs, t.Elem(), inMap, newSeq)
|
||||||
|
if err != nil {
|
||||||
|
return ft, fmt.Errorf("array: %w", err)
|
||||||
|
}
|
||||||
|
ft.ListElem = &l
|
||||||
|
ft.ArrayLength = t.Len()
|
||||||
case kindMap:
|
case kindMap:
|
||||||
kft, err := gatherFieldType(t.Key(), true)
|
kft, err := gatherFieldType(typeSeqs, t.Key(), true, newSeq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ft, fmt.Errorf("map key: %w", err)
|
return ft, fmt.Errorf("map key: %w", err)
|
||||||
}
|
}
|
||||||
if kft.Ptr {
|
if kft.Ptr {
|
||||||
return ft, fmt.Errorf("%w: map key with pointer type not supported", ErrType)
|
return ft, fmt.Errorf("%w: map key with pointer type not supported", ErrType)
|
||||||
}
|
}
|
||||||
vft, err := gatherFieldType(t.Elem(), true)
|
vft, err := gatherFieldType(typeSeqs, t.Elem(), true, newSeq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ft, fmt.Errorf("map value: %w", err)
|
return ft, fmt.Errorf("map value: %w", err)
|
||||||
}
|
}
|
||||||
ft.MapKey = &kft
|
ft.MapKey = &kft
|
||||||
ft.MapValue = &vft
|
ft.MapValue = &vft
|
||||||
case kindStruct:
|
case kindStruct:
|
||||||
// note: we have no reason to gather embed field beyond top-level
|
// If this is a known type, track a reference to the earlier defined type. Once the
|
||||||
fields, _, err := gatherTypeFields(t, false, false, inMap)
|
// type with all Fields is fully parsed, the references will be resolved.
|
||||||
|
if seq, ok := typeSeqs[t]; ok {
|
||||||
|
ft.FieldsTypeSeq = -seq
|
||||||
|
return ft, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we are processing an anonymous (embed) field, we don't assign a new seq,
|
||||||
|
// because we won't be walking it when resolving again.
|
||||||
|
seq := len(typeSeqs) + 1
|
||||||
|
if newSeq {
|
||||||
|
typeSeqs[t] = seq
|
||||||
|
ft.FieldsTypeSeq = seq
|
||||||
|
}
|
||||||
|
fields, _, err := gatherTypeFields(typeSeqs, t, false, false, inMap, newSeq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fieldType{}, fmt.Errorf("struct: %w", err)
|
return fieldType{}, fmt.Errorf("struct: %w", err)
|
||||||
}
|
}
|
||||||
ft.Fields = fields
|
ft.DefinitionFields = fields
|
||||||
}
|
}
|
||||||
return ft, nil
|
return ft, nil
|
||||||
}
|
}
|
||||||
|
@ -941,6 +1158,10 @@ tv:
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *field) prepare(nfields []field, later, mvlater [][]field) {
|
func (f *field) prepare(nfields []field, later, mvlater [][]field) {
|
||||||
|
if f.prepared {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.prepared = true
|
||||||
for _, nf := range nfields {
|
for _, nf := range nfields {
|
||||||
if nf.Name == f.Name {
|
if nf.Name == f.Name {
|
||||||
f.structField = nf.structField
|
f.structField = nf.structField
|
||||||
|
@ -954,26 +1175,26 @@ func (ft fieldType) laterFields() (later, mvlater []field) {
|
||||||
later, _ = ft.MapKey.laterFields()
|
later, _ = ft.MapKey.laterFields()
|
||||||
mvlater, _ = ft.MapValue.laterFields()
|
mvlater, _ = ft.MapValue.laterFields()
|
||||||
return later, mvlater
|
return later, mvlater
|
||||||
} else if ft.List != nil {
|
} else if ft.ListElem != nil {
|
||||||
return ft.List.laterFields()
|
return ft.ListElem.laterFields()
|
||||||
}
|
}
|
||||||
return ft.Fields, nil
|
return ft.structFields, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ft fieldType) prepare(nft *fieldType, later, mvlater [][]field) {
|
func (ft fieldType) prepare(nft *fieldType, later, mvlater [][]field) {
|
||||||
for i, f := range ft.Fields {
|
for i, f := range ft.structFields {
|
||||||
nlater, nmvlater, skip := lookupLater(f.Name, later)
|
nlater, nmvlater, skip := lookupLater(f.Name, later)
|
||||||
if skip {
|
if skip {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
ft.Fields[i].prepare(nft.Fields, nlater, nmvlater)
|
ft.structFields[i].prepare(nft.structFields, nlater, nmvlater)
|
||||||
}
|
}
|
||||||
if ft.MapKey != nil {
|
if ft.MapKey != nil {
|
||||||
ft.MapKey.prepare(nft.MapKey, later, nil)
|
ft.MapKey.prepare(nft.MapKey, later, nil)
|
||||||
ft.MapValue.prepare(nft.MapValue, mvlater, nil)
|
ft.MapValue.prepare(nft.MapValue, mvlater, nil)
|
||||||
}
|
}
|
||||||
if ft.List != nil {
|
if ft.ListElem != nil {
|
||||||
ft.List.prepare(nft.List, later, mvlater)
|
ft.ListElem.prepare(nft.ListElem, later, mvlater)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1032,18 +1253,24 @@ func (ft fieldType) typeEqual(nft fieldType) bool {
|
||||||
if ft.Ptr != nft.Ptr || ft.Kind != nft.Kind {
|
if ft.Ptr != nft.Ptr || ft.Kind != nft.Kind {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if len(ft.Fields) != len(nft.Fields) {
|
if ft.FieldsTypeSeq != nft.FieldsTypeSeq {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
for i, f := range ft.Fields {
|
if len(ft.DefinitionFields) != len(nft.DefinitionFields) {
|
||||||
if !f.typeEqual(nft.Fields[i]) {
|
return false
|
||||||
|
}
|
||||||
|
for i, f := range ft.DefinitionFields {
|
||||||
|
if !f.typeEqual(nft.DefinitionFields[i]) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if ft.MapKey != nil && (!ft.MapKey.typeEqual(*nft.MapKey) || !ft.MapValue.typeEqual(*nft.MapValue)) {
|
if ft.MapKey != nil && (!ft.MapKey.typeEqual(*nft.MapKey) || !ft.MapValue.typeEqual(*nft.MapValue)) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if ft.List != nil && !ft.List.typeEqual(*nft.List) {
|
if ft.ListElem != nil && !ft.ListElem.typeEqual(*nft.ListElem) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if ft.ArrayLength != nft.ArrayLength {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
|
@ -1069,12 +1296,16 @@ func (idx *index) typeEqual(nidx *index) bool {
|
||||||
// into an int32. Indices that need to be recreated (for an int width change) are
|
// into an int32. Indices that need to be recreated (for an int width change) are
|
||||||
// recorded in recreateIndices.
|
// recorded in recreateIndices.
|
||||||
func (tx *Tx) checkTypes(otv, ntv *typeVersion, recreateIndices map[string]struct{}) error {
|
func (tx *Tx) checkTypes(otv, ntv *typeVersion, recreateIndices map[string]struct{}) error {
|
||||||
|
// Used to track that two nonzero FieldsTypeSeq have been checked, to prevent
|
||||||
|
// recursing while checking.
|
||||||
|
checked := map[[2]int]struct{}{}
|
||||||
|
|
||||||
for _, f := range ntv.Fields {
|
for _, f := range ntv.Fields {
|
||||||
for _, of := range otv.Fields {
|
for _, of := range otv.Fields {
|
||||||
if f.Name != of.Name {
|
if f.Name != of.Name {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
increase, err := of.Type.compatible(f.Type)
|
increase, err := of.Type.compatible(f.Type, checked)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: field %q: %s", ErrIncompatible, f.Name, err)
|
return fmt.Errorf("%w: field %q: %s", ErrIncompatible, f.Name, err)
|
||||||
}
|
}
|
||||||
|
@ -1099,7 +1330,7 @@ func (tx *Tx) checkTypes(otv, ntv *typeVersion, recreateIndices map[string]struc
|
||||||
// for maps/slices/structs). If not an error is returned. If they are, the first
|
// for maps/slices/structs). If not an error is returned. If they are, the first
|
||||||
// return value indicates if this is a field that needs it index recreated
|
// return value indicates if this is a field that needs it index recreated
|
||||||
// (currently for ints that are packed with fixed width encoding).
|
// (currently for ints that are packed with fixed width encoding).
|
||||||
func (ft fieldType) compatible(nft fieldType) (bool, error) {
|
func (ft fieldType) compatible(nft fieldType, checked map[[2]int]struct{}) (bool, error) {
|
||||||
need := func(incr bool, l ...kind) (bool, error) {
|
need := func(incr bool, l ...kind) (bool, error) {
|
||||||
for _, k := range l {
|
for _, k := range l {
|
||||||
if nft.Kind == k {
|
if nft.Kind == k {
|
||||||
|
@ -1160,10 +1391,10 @@ func (ft fieldType) compatible(nft fieldType) (bool, error) {
|
||||||
if nk != k {
|
if nk != k {
|
||||||
return false, fmt.Errorf("map to %v: %w", nk, ErrIncompatible)
|
return false, fmt.Errorf("map to %v: %w", nk, ErrIncompatible)
|
||||||
}
|
}
|
||||||
if _, err := ft.MapKey.compatible(*nft.MapKey); err != nil {
|
if _, err := ft.MapKey.compatible(*nft.MapKey, checked); err != nil {
|
||||||
return false, fmt.Errorf("map key: %w", err)
|
return false, fmt.Errorf("map key: %w", err)
|
||||||
}
|
}
|
||||||
if _, err := ft.MapValue.compatible(*nft.MapValue); err != nil {
|
if _, err := ft.MapValue.compatible(*nft.MapValue, checked); err != nil {
|
||||||
return false, fmt.Errorf("map value: %w", err)
|
return false, fmt.Errorf("map value: %w", err)
|
||||||
}
|
}
|
||||||
return false, nil
|
return false, nil
|
||||||
|
@ -1171,18 +1402,41 @@ func (ft fieldType) compatible(nft fieldType) (bool, error) {
|
||||||
if nk != k {
|
if nk != k {
|
||||||
return false, fmt.Errorf("slice to %v: %w", nk, ErrIncompatible)
|
return false, fmt.Errorf("slice to %v: %w", nk, ErrIncompatible)
|
||||||
}
|
}
|
||||||
if _, err := ft.List.compatible(*nft.List); err != nil {
|
if _, err := ft.ListElem.compatible(*nft.ListElem, checked); err != nil {
|
||||||
return false, fmt.Errorf("list: %w", err)
|
return false, fmt.Errorf("slice: %w", err)
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
case kindArray:
|
||||||
|
if nk != k {
|
||||||
|
return false, fmt.Errorf("array to %v: %w", nk, ErrIncompatible)
|
||||||
|
}
|
||||||
|
if nft.ArrayLength != ft.ArrayLength {
|
||||||
|
return false, fmt.Errorf("array size cannot change (from %d to %d)", ft.ArrayLength, nft.ArrayLength)
|
||||||
|
}
|
||||||
|
if _, err := ft.ListElem.compatible(*nft.ListElem, checked); err != nil {
|
||||||
|
return false, fmt.Errorf("array: %w", err)
|
||||||
}
|
}
|
||||||
return false, nil
|
return false, nil
|
||||||
case kindStruct:
|
case kindStruct:
|
||||||
if nk != k {
|
if nk != k {
|
||||||
return false, fmt.Errorf("struct to %v: %w", nk, ErrIncompatible)
|
return false, fmt.Errorf("struct to %v: %w", nk, ErrIncompatible)
|
||||||
}
|
}
|
||||||
for _, nf := range nft.Fields {
|
|
||||||
for _, f := range ft.Fields {
|
// For ondiskVersion2, the seqs are both nonzero, and we must check that we already
|
||||||
|
// did the check to prevent recursion.
|
||||||
|
haveSeq := nft.FieldsTypeSeq != 0 || ft.FieldsTypeSeq != 0
|
||||||
|
if haveSeq {
|
||||||
|
k := [2]int{nft.FieldsTypeSeq, ft.FieldsTypeSeq}
|
||||||
|
if _, ok := checked[k]; ok {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
checked[k] = struct{}{} // Set early to prevent recursion in call below.
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, nf := range nft.structFields {
|
||||||
|
for _, f := range ft.structFields {
|
||||||
if nf.Name == f.Name {
|
if nf.Name == f.Name {
|
||||||
_, err := f.Type.compatible(nf.Type)
|
_, err := f.Type.compatible(nf.Type, checked)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("field %q: %w", nf.Name, err)
|
return false, fmt.Errorf("field %q: %w", nf.Name, err)
|
||||||
}
|
}
|
||||||
|
@ -1201,11 +1455,11 @@ func (ft fieldType) hasNonzeroField(stopAtPtr bool) bool {
|
||||||
}
|
}
|
||||||
switch ft.Kind {
|
switch ft.Kind {
|
||||||
case kindMap:
|
case kindMap:
|
||||||
return ft.List.hasNonzeroField(true)
|
|
||||||
case kindSlice:
|
|
||||||
return ft.MapValue.hasNonzeroField(true)
|
return ft.MapValue.hasNonzeroField(true)
|
||||||
|
case kindSlice, kindArray:
|
||||||
|
return ft.ListElem.hasNonzeroField(true)
|
||||||
case kindStruct:
|
case kindStruct:
|
||||||
for _, f := range ft.Fields {
|
for _, f := range ft.structFields {
|
||||||
if f.Nonzero || f.Type.hasNonzeroField(true) {
|
if f.Nonzero || f.Type.hasNonzeroField(true) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
305
vendor/github.com/mjl-/bstore/store.go
generated
vendored
305
vendor/github.com/mjl-/bstore/store.go
generated
vendored
|
@ -1,7 +1,9 @@
|
||||||
package bstore
|
package bstore
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding"
|
"encoding"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -14,6 +16,18 @@ import (
|
||||||
bolt "go.etcd.io/bbolt"
|
bolt "go.etcd.io/bbolt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
/*
|
||||||
|
- todo: should thoroughly review guarantees, where some of the bstore struct tags are allowed (e.g. top-level fields vs deeper struct fields), check that all features work well when combined (cyclic types, embed structs, default values, nonzero checks, type equality, zero values with fieldmap, skipping values (hidden due to later typeversions) and having different type versions), write more extensive tests.
|
||||||
|
- todo: write tests for invalid (meta)data inside the boltdb buckets (not for invalid boltdb files). we should detect the error properly, give a reasonable message. we shouldn't panic (nil deref, out of bounds index, consume too much memory). typeVersions, records, indices.
|
||||||
|
- todo: add benchmarks. is there a standard dataset databases use for benchmarking?
|
||||||
|
- todo optimize: profile and see if we can optimize for some quick wins.
|
||||||
|
- todo: should we add a way for ad-hoc data manipulation? e.g. with sql-like queries, e.g. update, delete, insert; and export results of queries to csv.
|
||||||
|
- todo: should we have a function that returns records in a map? eg Map() that is like List() but maps a key to T (too bad we cannot have a type for the key!).
|
||||||
|
- todo: better error messages (ordering of description & error; mention typename, fields (path), field types and offending value & type more often)
|
||||||
|
- todo: should we add types for dates and numerics?
|
||||||
|
- todo: struct tag for enums? where we check if the values match.
|
||||||
|
*/
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrAbsent = errors.New("absent") // If a function can return an ErrAbsent, it can be compared directly, without errors.Is.
|
ErrAbsent = errors.New("absent") // If a function can return an ErrAbsent, it can be compared directly, without errors.Is.
|
||||||
ErrZero = errors.New("must be nonzero")
|
ErrZero = errors.New("must be nonzero")
|
||||||
|
@ -26,6 +40,7 @@ var (
|
||||||
ErrFinished = errors.New("query finished")
|
ErrFinished = errors.New("query finished")
|
||||||
ErrStore = errors.New("internal/storage error") // E.g. when buckets disappear, possibly by external users of the underlying BoltDB database.
|
ErrStore = errors.New("internal/storage error") // E.g. when buckets disappear, possibly by external users of the underlying BoltDB database.
|
||||||
ErrParam = errors.New("bad parameters")
|
ErrParam = errors.New("bad parameters")
|
||||||
|
ErrTxBotched = errors.New("botched transaction") // Set on transactions after failed and aborted write operations.
|
||||||
|
|
||||||
errTxClosed = errors.New("transaction is closed")
|
errTxClosed = errors.New("transaction is closed")
|
||||||
errNestedIndex = errors.New("struct tags index/unique only allowed at top-level structs")
|
errNestedIndex = errors.New("struct tags index/unique only allowed at top-level structs")
|
||||||
|
@ -42,7 +57,7 @@ type DB struct {
|
||||||
// needs a wlock.
|
// needs a wlock.
|
||||||
typesMutex sync.RWMutex
|
typesMutex sync.RWMutex
|
||||||
types map[reflect.Type]storeType
|
types map[reflect.Type]storeType
|
||||||
typeNames map[string]storeType // Go type name to store type, for checking duplicates.
|
typeNames map[string]storeType // Type name to store type, for checking duplicates.
|
||||||
|
|
||||||
statsMutex sync.Mutex
|
statsMutex sync.Mutex
|
||||||
stats Stats
|
stats Stats
|
||||||
|
@ -52,7 +67,9 @@ type DB struct {
|
||||||
//
|
//
|
||||||
// A Tx is not safe for concurrent use.
|
// A Tx is not safe for concurrent use.
|
||||||
type Tx struct {
|
type Tx struct {
|
||||||
db *DB // If nil, this transaction is closed.
|
ctx context.Context // Check before starting operations, query next calls, and during foreach.
|
||||||
|
err error // If not nil, operations return this error. Set when write operations fail, e.g. insert with constraint violations.
|
||||||
|
db *DB // If nil, this transaction is closed.
|
||||||
btx *bolt.Tx
|
btx *bolt.Tx
|
||||||
|
|
||||||
bucketCache map[bucketKey]*bolt.Bucket
|
bucketCache map[bucketKey]*bolt.Bucket
|
||||||
|
@ -109,9 +126,9 @@ type typeVersion struct {
|
||||||
type field struct {
|
type field struct {
|
||||||
Name string
|
Name string
|
||||||
Type fieldType
|
Type fieldType
|
||||||
Nonzero bool
|
Nonzero bool `json:",omitempty"`
|
||||||
References []string // Referenced fields. Only for the top-level struct fields, not for nested structs.
|
References []string `json:",omitempty"` // Referenced fields. Only for the top-level struct fields, not for nested structs.
|
||||||
Default string // As specified in struct tag. Processed version is defaultValue.
|
Default string `json:",omitempty"` // As specified in struct tag. Processed version is defaultValue.
|
||||||
|
|
||||||
// If not the zero reflect.Value, set this value instead of a zero value on insert.
|
// If not the zero reflect.Value, set this value instead of a zero value on insert.
|
||||||
// This is always a non-pointer value. Only set for the current typeVersion
|
// This is always a non-pointer value. Only set for the current typeVersion
|
||||||
|
@ -123,6 +140,9 @@ type field struct {
|
||||||
// if this field is no longer in the type, or if it has been removed and
|
// if this field is no longer in the type, or if it has been removed and
|
||||||
// added again in later schema versions.
|
// added again in later schema versions.
|
||||||
structField reflect.StructField
|
structField reflect.StructField
|
||||||
|
// Whether this field has been prepared for parsing into, i.e.
|
||||||
|
// structField set if needed.
|
||||||
|
prepared bool
|
||||||
|
|
||||||
indices map[string]*index
|
indices map[string]*index
|
||||||
}
|
}
|
||||||
|
@ -134,79 +154,135 @@ type embed struct {
|
||||||
structField reflect.StructField
|
structField reflect.StructField
|
||||||
}
|
}
|
||||||
|
|
||||||
type kind int
|
type kind string
|
||||||
|
|
||||||
|
func (k kind) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(string(k))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *kind) UnmarshalJSON(buf []byte) error {
|
||||||
|
if string(buf) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(buf) > 0 && buf[0] == '"' {
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(buf, &s); err != nil {
|
||||||
|
return fmt.Errorf("parsing fieldType.Kind string value %q: %v", buf, err)
|
||||||
|
}
|
||||||
|
nk, ok := kindsMap[s]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unknown fieldType.Kind value %q", s)
|
||||||
|
}
|
||||||
|
*k = nk
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// In ondiskVersion1, the kinds were integers, starting at 1.
|
||||||
|
var i int
|
||||||
|
if err := json.Unmarshal(buf, &i); err != nil {
|
||||||
|
return fmt.Errorf("parsing fieldType.Kind int value %q: %v", buf, err)
|
||||||
|
}
|
||||||
|
if i <= 0 || i-1 >= len(kinds) {
|
||||||
|
return fmt.Errorf("unknown fieldType.Kind value %d", i)
|
||||||
|
}
|
||||||
|
*k = kinds[i-1]
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
kindInvalid kind = iota
|
kindBytes kind = "bytes" // 1, etc
|
||||||
kindBytes
|
kindBool kind = "bool"
|
||||||
kindBool
|
kindInt kind = "int"
|
||||||
kindInt
|
kindInt8 kind = "int8"
|
||||||
kindInt8
|
kindInt16 kind = "int16"
|
||||||
kindInt16
|
kindInt32 kind = "int32"
|
||||||
kindInt32
|
kindInt64 kind = "int64"
|
||||||
kindInt64
|
kindUint kind = "uint"
|
||||||
kindUint
|
kindUint8 kind = "uint8"
|
||||||
kindUint8
|
kindUint16 kind = "uint16"
|
||||||
kindUint16
|
kindUint32 kind = "uint32"
|
||||||
kindUint32
|
kindUint64 kind = "uint64"
|
||||||
kindUint64
|
kindFloat32 kind = "float32"
|
||||||
kindFloat32
|
kindFloat64 kind = "float64"
|
||||||
kindFloat64
|
kindMap kind = "map"
|
||||||
kindMap
|
kindSlice kind = "slice"
|
||||||
kindSlice
|
kindString kind = "string"
|
||||||
kindString
|
kindTime kind = "time"
|
||||||
kindTime
|
kindBinaryMarshal kind = "binarymarshal"
|
||||||
kindBinaryMarshal
|
kindStruct kind = "struct"
|
||||||
kindStruct
|
kindArray kind = "array"
|
||||||
)
|
)
|
||||||
|
|
||||||
var kindStrings = []string{
|
// In ondiskVersion1, the kinds were integers, starting at 1.
|
||||||
"(invalid)",
|
// Do not change the order. Add new values at the end.
|
||||||
"bytes",
|
var kinds = []kind{
|
||||||
"bool",
|
kindBytes,
|
||||||
"int",
|
kindBool,
|
||||||
"int8",
|
kindInt,
|
||||||
"int16",
|
kindInt8,
|
||||||
"int32",
|
kindInt16,
|
||||||
"int64",
|
kindInt32,
|
||||||
"uint",
|
kindInt64,
|
||||||
"uint8",
|
kindUint,
|
||||||
"uint16",
|
kindUint8,
|
||||||
"uint32",
|
kindUint16,
|
||||||
"uint64",
|
kindUint32,
|
||||||
"float32",
|
kindUint64,
|
||||||
"float64",
|
kindFloat32,
|
||||||
"map",
|
kindFloat64,
|
||||||
"slice",
|
kindMap,
|
||||||
"string",
|
kindSlice,
|
||||||
"time",
|
kindString,
|
||||||
"binarymarshal",
|
kindTime,
|
||||||
"struct",
|
kindBinaryMarshal,
|
||||||
|
kindStruct,
|
||||||
|
kindArray,
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k kind) String() string {
|
func makeKindsMap() map[string]kind {
|
||||||
return kindStrings[k]
|
m := map[string]kind{}
|
||||||
|
for _, k := range kinds {
|
||||||
|
m[string(k)] = k
|
||||||
|
}
|
||||||
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var kindsMap = makeKindsMap()
|
||||||
|
|
||||||
type fieldType struct {
|
type fieldType struct {
|
||||||
Ptr bool // If type is a pointer.
|
Ptr bool `json:",omitempty"` // If type is a pointer.
|
||||||
Kind kind // Type with possible Ptr deferenced.
|
Kind kind // Type with possible Ptr deferenced.
|
||||||
Fields []field // For kindStruct.
|
|
||||||
MapKey, MapValue *fieldType // For kindMap.
|
|
||||||
List *fieldType // For kindSlice.
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ft fieldType) String() string {
|
MapKey *fieldType `json:",omitempty"`
|
||||||
s := ft.Kind.String()
|
MapValue *fieldType `json:",omitempty"` // For kindMap.
|
||||||
if ft.Ptr {
|
ListElem *fieldType `json:"List,omitempty"` // For kindSlice and kindArray. Named List in JSON for compatibility.
|
||||||
return s + "ptr"
|
ArrayLength int `json:",omitempty"` // For kindArray.
|
||||||
}
|
|
||||||
return s
|
// For kindStruct, the fields of the struct. Only set for the first use of the type
|
||||||
|
// within a registered type. Code dealing with fields should use structFields
|
||||||
|
// (below) most of the time instead, it has the effective fields after resolving
|
||||||
|
// the type reference.
|
||||||
|
// Named "Fields" in JSON to stay compatible with ondiskVersion1, named
|
||||||
|
// DefinitionFields in Go for clarity.
|
||||||
|
DefinitionFields []field `json:"Fields,omitempty"`
|
||||||
|
|
||||||
|
// For struct types, the sequence number of this type (within the registered type).
|
||||||
|
// Needed for supporting cyclic types. Each struct type is assigned the next
|
||||||
|
// sequence number. The registered type implicitly has sequence 1. If positive,
|
||||||
|
// this defines a type (i.e. when it is first encountered analyzing fields
|
||||||
|
// depth-first). If negative, it references the type with positive seq (when a
|
||||||
|
// field is encountered of a type that was seen before). New since ondiskVersion2,
|
||||||
|
// structs in ondiskVersion1 will have zero value 0.
|
||||||
|
FieldsTypeSeq int `json:",omitempty"`
|
||||||
|
|
||||||
|
// Fields after taking cyclic types into account. Set when registering/loading a
|
||||||
|
// type. Not stored on disk because of potential cyclic data.
|
||||||
|
structFields []field
|
||||||
}
|
}
|
||||||
|
|
||||||
// Options configure how a database should be opened or initialized.
|
// Options configure how a database should be opened or initialized.
|
||||||
type Options struct {
|
type Options struct {
|
||||||
Timeout time.Duration // Abort if opening DB takes longer than Timeout.
|
Timeout time.Duration // Abort if opening DB takes longer than Timeout. If not set, the deadline from the context is used.
|
||||||
Perm fs.FileMode // Permissions for new file if created. If zero, 0600 is used.
|
Perm fs.FileMode // Permissions for new file if created. If zero, 0600 is used.
|
||||||
MustExist bool // Before opening, check that file exists. If not, io/fs.ErrNotExist is returned.
|
MustExist bool // Before opening, check that file exists. If not, io/fs.ErrNotExist is returned.
|
||||||
}
|
}
|
||||||
|
@ -219,10 +295,19 @@ type Options struct {
|
||||||
//
|
//
|
||||||
// Only one DB instance can be open for a file at a time. Use opts.Timeout to
|
// Only one DB instance can be open for a file at a time. Use opts.Timeout to
|
||||||
// specify a timeout during open to prevent indefinite blocking.
|
// specify a timeout during open to prevent indefinite blocking.
|
||||||
func Open(path string, opts *Options, typeValues ...any) (*DB, error) {
|
//
|
||||||
|
// The context is used for opening and initializing the database, not for further
|
||||||
|
// operations. If the context is canceled while waiting on the database file lock,
|
||||||
|
// the operation is not aborted other than when the deadline/timeout is reached.
|
||||||
|
//
|
||||||
|
// See function Register for checks for changed/unchanged schema during open
|
||||||
|
// based on environment variable "bstore_schema_check".
|
||||||
|
func Open(ctx context.Context, path string, opts *Options, typeValues ...any) (*DB, error) {
|
||||||
var bopts *bolt.Options
|
var bopts *bolt.Options
|
||||||
if opts != nil && opts.Timeout > 0 {
|
if opts != nil && opts.Timeout > 0 {
|
||||||
bopts = &bolt.Options{Timeout: opts.Timeout}
|
bopts = &bolt.Options{Timeout: opts.Timeout}
|
||||||
|
} else if end, ok := ctx.Deadline(); ok {
|
||||||
|
bopts = &bolt.Options{Timeout: time.Until(end)}
|
||||||
}
|
}
|
||||||
var mode fs.FileMode = 0600
|
var mode fs.FileMode = 0600
|
||||||
if opts != nil && opts.Perm != 0 {
|
if opts != nil && opts.Perm != 0 {
|
||||||
|
@ -241,7 +326,7 @@ func Open(path string, opts *Options, typeValues ...any) (*DB, error) {
|
||||||
typeNames := map[string]storeType{}
|
typeNames := map[string]storeType{}
|
||||||
types := map[reflect.Type]storeType{}
|
types := map[reflect.Type]storeType{}
|
||||||
db := &DB{bdb: bdb, typeNames: typeNames, types: types}
|
db := &DB{bdb: bdb, typeNames: typeNames, types: types}
|
||||||
if err := db.Register(typeValues...); err != nil {
|
if err := db.Register(ctx, typeValues...); err != nil {
|
||||||
bdb.Close()
|
bdb.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -272,6 +357,9 @@ func (tx *Tx) Stats() Stats {
|
||||||
|
|
||||||
// WriteTo writes the entire database to w, not including changes made during this transaction.
|
// WriteTo writes the entire database to w, not including changes made during this transaction.
|
||||||
func (tx *Tx) WriteTo(w io.Writer) (n int64, err error) {
|
func (tx *Tx) WriteTo(w io.Writer) (n int64, err error) {
|
||||||
|
if err := tx.error(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
return tx.btx.WriteTo(w)
|
return tx.btx.WriteTo(w)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -328,58 +416,67 @@ func (tx *Tx) indexBucket(idx *index) (*bolt.Bucket, error) {
|
||||||
// If a type is still referenced by another type, eg through a "ref" struct tag,
|
// If a type is still referenced by another type, eg through a "ref" struct tag,
|
||||||
// ErrReference is returned.
|
// ErrReference is returned.
|
||||||
// If the type does not exist, ErrAbsent is returned.
|
// If the type does not exist, ErrAbsent is returned.
|
||||||
func (db *DB) Drop(name string) error {
|
func (db *DB) Drop(ctx context.Context, name string) error {
|
||||||
return db.Write(func(tx *Tx) error {
|
var st storeType
|
||||||
|
var ok bool
|
||||||
|
err := db.Write(ctx, func(tx *Tx) error {
|
||||||
tx.stats.Bucket.Get++
|
tx.stats.Bucket.Get++
|
||||||
if tx.btx.Bucket([]byte(name)) == nil {
|
if tx.btx.Bucket([]byte(name)) == nil {
|
||||||
return ErrAbsent
|
return ErrAbsent
|
||||||
}
|
}
|
||||||
|
|
||||||
if st, ok := db.typeNames[name]; ok && len(st.Current.referencedBy) > 0 {
|
st, ok = db.typeNames[name]
|
||||||
|
if ok && len(st.Current.referencedBy) > 0 {
|
||||||
return fmt.Errorf("%w: type is still referenced", ErrReference)
|
return fmt.Errorf("%w: type is still referenced", ErrReference)
|
||||||
} else if ok {
|
|
||||||
for ref := range st.Current.references {
|
|
||||||
var n []*index
|
|
||||||
for _, idx := range db.typeNames[ref].Current.referencedBy {
|
|
||||||
if idx.tv != st.Current {
|
|
||||||
n = append(n, idx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
db.typeNames[ref].Current.referencedBy = n
|
|
||||||
}
|
|
||||||
delete(db.typeNames, name)
|
|
||||||
delete(db.types, st.Type)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tx.stats.Bucket.Delete++
|
tx.stats.Bucket.Delete++
|
||||||
return tx.btx.DeleteBucket([]byte(name))
|
return tx.btx.DeleteBucket([]byte(name))
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
for ref := range st.Current.references {
|
||||||
|
var n []*index
|
||||||
|
for _, idx := range db.typeNames[ref].Current.referencedBy {
|
||||||
|
if idx.tv != st.Current {
|
||||||
|
n = append(n, idx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
db.typeNames[ref].Current.referencedBy = n
|
||||||
|
}
|
||||||
|
delete(db.typeNames, name)
|
||||||
|
delete(db.types, st.Type)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete calls Delete on a new writable Tx.
|
// Delete calls Delete on a new writable Tx.
|
||||||
func (db *DB) Delete(values ...any) error {
|
func (db *DB) Delete(ctx context.Context, values ...any) error {
|
||||||
return db.Write(func(tx *Tx) error {
|
return db.Write(ctx, func(tx *Tx) error {
|
||||||
return tx.Delete(values...)
|
return tx.Delete(values...)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get calls Get on a new read-only Tx.
|
// Get calls Get on a new read-only Tx.
|
||||||
func (db *DB) Get(values ...any) error {
|
func (db *DB) Get(ctx context.Context, values ...any) error {
|
||||||
return db.Read(func(tx *Tx) error {
|
return db.Read(ctx, func(tx *Tx) error {
|
||||||
return tx.Get(values...)
|
return tx.Get(values...)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert calls Insert on a new writable Tx.
|
// Insert calls Insert on a new writable Tx.
|
||||||
func (db *DB) Insert(values ...any) error {
|
func (db *DB) Insert(ctx context.Context, values ...any) error {
|
||||||
return db.Write(func(tx *Tx) error {
|
return db.Write(ctx, func(tx *Tx) error {
|
||||||
return tx.Insert(values...)
|
return tx.Insert(values...)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update calls Update on a new writable Tx.
|
// Update calls Update on a new writable Tx.
|
||||||
func (db *DB) Update(values ...any) error {
|
func (db *DB) Update(ctx context.Context, values ...any) error {
|
||||||
return db.Write(func(tx *Tx) error {
|
return db.Write(ctx, func(tx *Tx) error {
|
||||||
return tx.Update(values...)
|
return tx.Update(values...)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -401,6 +498,7 @@ var typeKinds = map[reflect.Kind]kind{
|
||||||
reflect.Map: kindMap,
|
reflect.Map: kindMap,
|
||||||
reflect.Slice: kindSlice,
|
reflect.Slice: kindSlice,
|
||||||
reflect.String: kindString,
|
reflect.String: kindString,
|
||||||
|
reflect.Array: kindArray,
|
||||||
}
|
}
|
||||||
|
|
||||||
func typeKind(t reflect.Type) (kind, error) {
|
func typeKind(t reflect.Type) (kind, error) {
|
||||||
|
@ -424,7 +522,10 @@ func typeKind(t reflect.Type) (kind, error) {
|
||||||
if t.Kind() == reflect.Struct {
|
if t.Kind() == reflect.Struct {
|
||||||
return kindStruct, nil
|
return kindStruct, nil
|
||||||
}
|
}
|
||||||
return kind(0), fmt.Errorf("%w: unsupported type %v", ErrType, t)
|
if t.Kind() == reflect.Ptr {
|
||||||
|
return "", fmt.Errorf("%w: pointer to pointers not supported: %v", ErrType, t.Elem())
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("%w: unsupported type %v", ErrType, t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func typeName(t reflect.Type) (string, error) {
|
func typeName(t reflect.Type) (string, error) {
|
||||||
|
@ -509,27 +610,39 @@ func (tv typeVersion) keyValue(tx *Tx, rv reflect.Value, insert bool, rb *bolt.B
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read calls function fn with a new read-only transaction, ensuring transaction rollback.
|
// Read calls function fn with a new read-only transaction, ensuring transaction rollback.
|
||||||
func (db *DB) Read(fn func(*Tx) error) error {
|
func (db *DB) Read(ctx context.Context, fn func(*Tx) error) error {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
db.typesMutex.RLock()
|
db.typesMutex.RLock()
|
||||||
defer db.typesMutex.RUnlock()
|
defer db.typesMutex.RUnlock()
|
||||||
return db.bdb.View(func(btx *bolt.Tx) error {
|
return db.bdb.View(func(btx *bolt.Tx) error {
|
||||||
tx := &Tx{db: db, btx: btx}
|
tx := &Tx{ctx: ctx, db: db, btx: btx}
|
||||||
tx.stats.Reads++
|
tx.stats.Reads++
|
||||||
defer tx.addStats()
|
defer tx.addStats()
|
||||||
return fn(tx)
|
if err := fn(tx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.err
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write calls function fn with a new read-write transaction. If fn returns
|
// Write calls function fn with a new read-write transaction. If fn returns
|
||||||
// nil, the transaction is committed. Otherwise the transaction is rolled back.
|
// nil, the transaction is committed. Otherwise the transaction is rolled back.
|
||||||
func (db *DB) Write(fn func(*Tx) error) error {
|
func (db *DB) Write(ctx context.Context, fn func(*Tx) error) error {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
db.typesMutex.RLock()
|
db.typesMutex.RLock()
|
||||||
defer db.typesMutex.RUnlock()
|
defer db.typesMutex.RUnlock()
|
||||||
return db.bdb.Update(func(btx *bolt.Tx) error {
|
return db.bdb.Update(func(btx *bolt.Tx) error {
|
||||||
tx := &Tx{db: db, btx: btx}
|
tx := &Tx{ctx: ctx, db: db, btx: btx}
|
||||||
tx.stats.Writes++
|
tx.stats.Writes++
|
||||||
defer tx.addStats()
|
defer tx.addStats()
|
||||||
return fn(tx)
|
if err := fn(tx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.err
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
128
vendor/github.com/mjl-/bstore/tx.go
generated
vendored
128
vendor/github.com/mjl-/bstore/tx.go
generated
vendored
|
@ -2,12 +2,39 @@ package bstore
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
bolt "go.etcd.io/bbolt"
|
bolt "go.etcd.io/bbolt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Mark a tx as botched, mentioning last actual error.
|
||||||
|
// Used when write operations fail. The transaction can be in inconsistent
|
||||||
|
// state, e.g. only some of a type's indicies may have been updated. We never
|
||||||
|
// want to commit such transactions.
|
||||||
|
func (tx *Tx) markError(err *error) {
|
||||||
|
if *err != nil && tx.err == nil {
|
||||||
|
tx.err = fmt.Errorf("%w (after %v)", ErrTxBotched, *err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return if an error condition is set on on the transaction. To be called before
|
||||||
|
// starting an operation.
|
||||||
|
func (tx *Tx) error() error {
|
||||||
|
if tx.err != nil {
|
||||||
|
return tx.err
|
||||||
|
}
|
||||||
|
if tx.db == nil {
|
||||||
|
return errTxClosed
|
||||||
|
}
|
||||||
|
if err := tx.ctx.Err(); err != nil {
|
||||||
|
tx.err = err
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (tx *Tx) structptr(value any) (reflect.Value, error) {
|
func (tx *Tx) structptr(value any) (reflect.Value, error) {
|
||||||
rv := reflect.ValueOf(value)
|
rv := reflect.ValueOf(value)
|
||||||
if !rv.IsValid() || rv.Kind() != reflect.Ptr || !rv.Elem().IsValid() || rv.Type().Elem().Kind() != reflect.Struct {
|
if !rv.IsValid() || rv.Kind() != reflect.Ptr || !rv.Elem().IsValid() || rv.Type().Elem().Kind() != reflect.Struct {
|
||||||
|
@ -42,10 +69,23 @@ func (tx *Tx) updateIndices(tv *typeVersion, pk []byte, ov, v reflect.Value) err
|
||||||
|
|
||||||
changed := func(idx *index) bool {
|
changed := func(idx *index) bool {
|
||||||
for _, f := range idx.Fields {
|
for _, f := range idx.Fields {
|
||||||
rofv := ov.FieldByIndex(f.structField.Index)
|
ofv := ov.FieldByIndex(f.structField.Index)
|
||||||
nofv := v.FieldByIndex(f.structField.Index)
|
nfv := v.FieldByIndex(f.structField.Index)
|
||||||
// note: checking the interface values is enough, we only allow comparable types as index fields.
|
if f.Type.Kind == kindSlice {
|
||||||
if rofv.Interface() != nofv.Interface() {
|
// Index field is a slice type, cannot use direct interface comparison.
|
||||||
|
on := ofv.Len()
|
||||||
|
nn := nfv.Len()
|
||||||
|
if on != nn {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for i := 0; i < nn; i++ {
|
||||||
|
// Slice elements are comparable.
|
||||||
|
if ofv.Index(i) != nfv.Index(i) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if ofv.Interface() != nfv.Interface() {
|
||||||
|
// note: checking the interface values is enough.
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -69,36 +109,40 @@ func (tx *Tx) updateIndices(tv *typeVersion, pk []byte, ov, v reflect.Value) err
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if remove {
|
if remove {
|
||||||
_, ik, err := idx.packKey(ov, pk)
|
ikl, err := idx.packKey(ov, pk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
tx.stats.Index.Delete++
|
for _, ik := range ikl {
|
||||||
if sanityChecks {
|
tx.stats.Index.Delete++
|
||||||
tx.stats.Index.Get++
|
if sanityChecks {
|
||||||
if ib.Get(ik) == nil {
|
tx.stats.Index.Get++
|
||||||
return fmt.Errorf("internal error: key missing from index")
|
if ib.Get(ik.full) == nil {
|
||||||
|
return fmt.Errorf("%w: key missing from index", ErrStore)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := ib.Delete(ik.full); err != nil {
|
||||||
|
return fmt.Errorf("%w: removing from index: %s", ErrStore, err)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if err := ib.Delete(ik); err != nil {
|
|
||||||
return fmt.Errorf("%w: removing from index: %s", ErrStore, err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if add {
|
if add {
|
||||||
prek, ik, err := idx.packKey(v, pk)
|
ikl, err := idx.packKey(v, pk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if idx.Unique {
|
for _, ik := range ikl {
|
||||||
tx.stats.Index.Cursor++
|
if idx.Unique {
|
||||||
if xk, _ := ib.Cursor().Seek(prek); xk != nil && bytes.HasPrefix(xk, prek) {
|
tx.stats.Index.Cursor++
|
||||||
return fmt.Errorf("%w: %q", ErrUnique, idx.Name)
|
if xk, _ := ib.Cursor().Seek(ik.pre); xk != nil && bytes.HasPrefix(xk, ik.pre) {
|
||||||
|
return fmt.Errorf("%w: %q", ErrUnique, idx.Name)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
tx.stats.Index.Put++
|
tx.stats.Index.Put++
|
||||||
if err := ib.Put(ik, []byte{}); err != nil {
|
if err := ib.Put(ik.full, []byte{}); err != nil {
|
||||||
return fmt.Errorf("inserting into index: %w", err)
|
return fmt.Errorf("inserting into index: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -124,7 +168,7 @@ func (tx *Tx) checkReferences(tv *typeVersion, pk []byte, ov, rv reflect.Value)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if rb.Get(k) == nil {
|
if rb.Get(k) == nil {
|
||||||
return fmt.Errorf("%w: value %v from field %q to %q", ErrReference, frv.Interface(), f.Name, name)
|
return fmt.Errorf("%w: value %v from %q to %q", ErrReference, frv.Interface(), tv.name+"."+f.Name, name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -143,8 +187,8 @@ func (tx *Tx) addStats() {
|
||||||
//
|
//
|
||||||
// ErrAbsent is returned if the record does not exist.
|
// ErrAbsent is returned if the record does not exist.
|
||||||
func (tx *Tx) Get(values ...any) error {
|
func (tx *Tx) Get(values ...any) error {
|
||||||
if tx.db == nil {
|
if err := tx.error(); err != nil {
|
||||||
return errTxClosed
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
|
@ -184,8 +228,8 @@ func (tx *Tx) Get(values ...any) error {
|
||||||
// ErrAbsent is returned if the record does not exist.
|
// ErrAbsent is returned if the record does not exist.
|
||||||
// ErrReference is returned if another record still references this record.
|
// ErrReference is returned if another record still references this record.
|
||||||
func (tx *Tx) Delete(values ...any) error {
|
func (tx *Tx) Delete(values ...any) error {
|
||||||
if tx.db == nil {
|
if err := tx.error(); err != nil {
|
||||||
return errTxClosed
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
|
@ -222,7 +266,7 @@ func (tx *Tx) Delete(values ...any) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tx *Tx) delete(rb *bolt.Bucket, st storeType, k []byte, rov reflect.Value) error {
|
func (tx *Tx) delete(rb *bolt.Bucket, st storeType, k []byte, rov reflect.Value) (rerr error) {
|
||||||
// Check that anyone referencing this type does not reference this record.
|
// Check that anyone referencing this type does not reference this record.
|
||||||
for _, refBy := range st.Current.referencedBy {
|
for _, refBy := range st.Current.referencedBy {
|
||||||
if ib, err := tx.indexBucket(refBy); err != nil {
|
if ib, err := tx.indexBucket(refBy); err != nil {
|
||||||
|
@ -236,6 +280,7 @@ func (tx *Tx) delete(rb *bolt.Bucket, st storeType, k []byte, rov reflect.Value)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete value from indices.
|
// Delete value from indices.
|
||||||
|
defer tx.markError(&rerr)
|
||||||
if err := tx.updateIndices(st.Current, k, rov, reflect.Value{}); err != nil {
|
if err := tx.updateIndices(st.Current, k, rov, reflect.Value{}); err != nil {
|
||||||
return fmt.Errorf("removing from indices: %w", err)
|
return fmt.Errorf("removing from indices: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -250,8 +295,8 @@ func (tx *Tx) delete(rb *bolt.Bucket, st storeType, k []byte, rov reflect.Value)
|
||||||
//
|
//
|
||||||
// ErrAbsent is returned if the record does not exist.
|
// ErrAbsent is returned if the record does not exist.
|
||||||
func (tx *Tx) Update(values ...any) error {
|
func (tx *Tx) Update(values ...any) error {
|
||||||
if tx.db == nil {
|
if err := tx.error(); err != nil {
|
||||||
return errTxClosed
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
|
@ -282,8 +327,8 @@ func (tx *Tx) Update(values ...any) error {
|
||||||
// ErrZero is returned if a nonzero constraint would be violated.
|
// ErrZero is returned if a nonzero constraint would be violated.
|
||||||
// ErrReference is returned if another record is referenced that does not exist.
|
// ErrReference is returned if another record is referenced that does not exist.
|
||||||
func (tx *Tx) Insert(values ...any) error {
|
func (tx *Tx) Insert(values ...any) error {
|
||||||
if tx.db == nil {
|
if err := tx.error(); err != nil {
|
||||||
return errTxClosed
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
|
@ -298,6 +343,7 @@ func (tx *Tx) Insert(values ...any) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// todo optimize: should track per field whether it (or a child) has a default value, and only applyDefault if so.
|
||||||
if err := st.Current.applyDefault(rv); err != nil {
|
if err := st.Current.applyDefault(rv); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -344,7 +390,7 @@ func (tx *Tx) put(st storeType, rv reflect.Value, insert bool) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tx *Tx) insert(rb *bolt.Bucket, st storeType, rv, krv reflect.Value, k []byte) error {
|
func (tx *Tx) insert(rb *bolt.Bucket, st storeType, rv, krv reflect.Value, k []byte) (rerr error) {
|
||||||
v, err := st.pack(rv)
|
v, err := st.pack(rv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -352,6 +398,7 @@ func (tx *Tx) insert(rb *bolt.Bucket, st storeType, rv, krv reflect.Value, k []b
|
||||||
if err := tx.checkReferences(st.Current, k, reflect.Value{}, rv); err != nil {
|
if err := tx.checkReferences(st.Current, k, reflect.Value{}, rv); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
defer tx.markError(&rerr)
|
||||||
if err := tx.updateIndices(st.Current, k, reflect.Value{}, rv); err != nil {
|
if err := tx.updateIndices(st.Current, k, reflect.Value{}, rv); err != nil {
|
||||||
return fmt.Errorf("updating indices for inserted value: %w", err)
|
return fmt.Errorf("updating indices for inserted value: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -363,7 +410,7 @@ func (tx *Tx) insert(rb *bolt.Bucket, st storeType, rv, krv reflect.Value, k []b
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tx *Tx) update(rb *bolt.Bucket, st storeType, rv, rov reflect.Value, k []byte) error {
|
func (tx *Tx) update(rb *bolt.Bucket, st storeType, rv, rov reflect.Value, k []byte) (rerr error) {
|
||||||
if st.Current.equal(rov, rv) {
|
if st.Current.equal(rov, rv) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -375,6 +422,7 @@ func (tx *Tx) update(rb *bolt.Bucket, st storeType, rv, rov reflect.Value, k []b
|
||||||
if err := tx.checkReferences(st.Current, k, rov, rv); err != nil {
|
if err := tx.checkReferences(st.Current, k, rov, rv); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
defer tx.markError(&rerr)
|
||||||
if err := tx.updateIndices(st.Current, k, rov, rv); err != nil {
|
if err := tx.updateIndices(st.Current, k, rov, rv); err != nil {
|
||||||
return fmt.Errorf("updating indices for updated record: %w", err)
|
return fmt.Errorf("updating indices for updated record: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -391,13 +439,16 @@ func (tx *Tx) update(rb *bolt.Bucket, st storeType, rv, rov reflect.Value, k []b
|
||||||
//
|
//
|
||||||
// A writable Tx can be committed or rolled back. A read-only transaction must
|
// A writable Tx can be committed or rolled back. A read-only transaction must
|
||||||
// always be rolled back.
|
// always be rolled back.
|
||||||
func (db *DB) Begin(writable bool) (*Tx, error) {
|
func (db *DB) Begin(ctx context.Context, writable bool) (*Tx, error) {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
btx, err := db.bdb.Begin(writable)
|
btx, err := db.bdb.Begin(writable)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
db.typesMutex.RLock()
|
db.typesMutex.RLock()
|
||||||
tx := &Tx{db: db, btx: btx}
|
tx := &Tx{ctx: ctx, db: db, btx: btx}
|
||||||
if writable {
|
if writable {
|
||||||
tx.stats.Writes++
|
tx.stats.Writes++
|
||||||
} else {
|
} else {
|
||||||
|
@ -422,9 +473,14 @@ func (tx *Tx) Rollback() error {
|
||||||
|
|
||||||
// Commit commits changes made in the transaction to the database.
|
// Commit commits changes made in the transaction to the database.
|
||||||
// Statistics are added to its DB.
|
// Statistics are added to its DB.
|
||||||
|
// If the commit fails, or the transaction was botched, the transaction is
|
||||||
|
// rolled back.
|
||||||
func (tx *Tx) Commit() error {
|
func (tx *Tx) Commit() error {
|
||||||
if tx.db == nil {
|
if tx.db == nil {
|
||||||
return errTxClosed
|
return errTxClosed
|
||||||
|
} else if tx.err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return tx.err
|
||||||
}
|
}
|
||||||
|
|
||||||
tx.addStats()
|
tx.addStats()
|
||||||
|
|
2
vendor/modules.txt
vendored
2
vendor/modules.txt
vendored
|
@ -11,7 +11,7 @@ github.com/golang/protobuf/ptypes/timestamp
|
||||||
# github.com/matttproud/golang_protobuf_extensions v1.0.1
|
# github.com/matttproud/golang_protobuf_extensions v1.0.1
|
||||||
## explicit
|
## explicit
|
||||||
github.com/matttproud/golang_protobuf_extensions/pbutil
|
github.com/matttproud/golang_protobuf_extensions/pbutil
|
||||||
# github.com/mjl-/bstore v0.0.0-20230211204415-a9899ef6e782
|
# github.com/mjl-/bstore v0.0.1
|
||||||
## explicit; go 1.19
|
## explicit; go 1.19
|
||||||
github.com/mjl-/bstore
|
github.com/mjl-/bstore
|
||||||
# github.com/mjl-/sconf v0.0.4
|
# github.com/mjl-/sconf v0.0.4
|
||||||
|
|
Loading…
Reference in a new issue