// Package subjectpass implements a mechanism for reject an incoming message with a challenge to include a token in a next delivery attempt. package subjectpass import ( "crypto/hmac" "crypto/sha256" "encoding/base64" "errors" "fmt" "io" "strings" "time" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/mjl-/mox/dns" "github.com/mjl-/mox/message" "github.com/mjl-/mox/mlog" "github.com/mjl-/mox/smtp" ) var log = mlog.New("subjectpass") var ( metricGenerate = promauto.NewCounter( prometheus.CounterOpts{ Name: "mox_subjectpass_generate_total", Help: "Number of generated subjectpass challenges.", }, ) metricVerify = promauto.NewCounterVec( prometheus.CounterOpts{ Name: "mox_subjectpass_verify_total", Help: "Number of subjectpass verifications.", }, []string{ "result", // ok, fail }, ) ) var ( ErrMessage = errors.New("subjectpass: malformed message") ErrAbsent = errors.New("subjectpass: no token found") ErrFrom = errors.New("subjectpass: bad From") ErrInvalid = errors.New("subjectpass: malformed token") ErrVerify = errors.New("subjectpass: verification failed") ErrExpired = errors.New("subjectpass: token expired") ) var Explanation = "Your message resembles spam. If your email is legitimate, please send it again with the following added to the email message subject: " // Generate generates a token that is valid for "mailFrom", starting from "tm" // and signed with "key". // The token is of the form: (pass:) func Generate(mailFrom smtp.Address, key []byte, tm time.Time) string { metricGenerate.Inc() log.Debug("subjectpass generate", mlog.Field("mailfrom", mailFrom)) // We discard the lower 8 bits of the time, we can do with less precision. t := tm.Unix() buf := []byte{ 0 | (byte(t>>32) & 0x0f), // 4 bits version, 4 bits time byte(t>>24) & 0xff, byte(t>>16) & 0xff, byte(t>>8) & 0xff, } mac := hmac.New(sha256.New, key) mac.Write(buf) mac.Write([]byte(mailFrom.String())) h := mac.Sum(nil)[:12] buf = append(buf, h...) return "(pass:" + base64.RawURLEncoding.EncodeToString(buf) + ")" } // Verify parses "message" and checks if it includes a subjectpass token in its // Subject header that is still valid (within "period") and signed with "key". func Verify(log *mlog.Log, r io.ReaderAt, key []byte, period time.Duration) (rerr error) { var token string defer func() { result := "fail" if rerr == nil { result = "ok" } metricVerify.WithLabelValues(result).Inc() log.Debugx("subjectpass verify result", rerr, mlog.Field("token", token), mlog.Field("period", period)) }() p, err := message.Parse(log, true, r) if err != nil { return fmt.Errorf("%w: parse message: %s", ErrMessage, err) } header, err := p.Header() if err != nil { return fmt.Errorf("%w: parse message headers: %s", ErrMessage, err) } subject := header.Get("Subject") if subject == "" { log.Info("no subject header") return fmt.Errorf("%w: no subject header", ErrAbsent) } t := strings.SplitN(subject, "(pass:", 2) if len(t) != 2 { return fmt.Errorf("%w: no token in subject", ErrAbsent) } t = strings.SplitN(t[1], ")", 2) if len(t) != 2 { return fmt.Errorf("%w: no token in subject (2)", ErrAbsent) } token = t[0] if len(p.Envelope.From) != 1 { return fmt.Errorf("%w: need 1 from address, got %d", ErrFrom, len(p.Envelope.From)) } from := p.Envelope.From[0] d, err := dns.ParseDomain(from.Host) if err != nil { return fmt.Errorf("%w: from address with bad domain: %v", ErrFrom, err) } addr := smtp.Address{Localpart: smtp.Localpart(from.User), Domain: d}.Pack(true) buf, err := base64.RawURLEncoding.DecodeString(token) if err != nil { return fmt.Errorf("%w: parsing base64: %s", ErrInvalid, err) } if len(buf) == 0 { return fmt.Errorf("%w: empty pass token", ErrInvalid) } version := buf[0] >> 4 if version != 0 { return fmt.Errorf("%w: unknown version %d", ErrInvalid, version) } if len(buf) != 4+12 { return fmt.Errorf("%w: bad length of pass token, %d", ErrInvalid, len(buf)) } mac := hmac.New(sha256.New, key) mac.Write(buf[:4]) mac.Write([]byte(addr)) h := mac.Sum(nil)[:12] if !hmac.Equal(buf[4:], h) { return ErrVerify } tsign := time.Unix(int64(buf[0]&0x0f)<<32|int64(buf[1])<<24|int64(buf[2])<<16|int64(buf[3])<<8, 0) if time.Since(tsign) > period { return fmt.Errorf("%w: pass token expired, signed at %s, period %s", ErrExpired, tsign, period) } return nil }