package mtasts import ( "fmt" "strconv" "strings" "github.com/mjl-/mox/dns" ) type parseErr string func (e parseErr) Error() string { return string(e) } var _ error = parseErr("") // ParseRecord parses an MTA-STS record. func ParseRecord(txt string) (record *Record, ismtasts bool, err error) { defer func() { x := recover() if x == nil { return } if xerr, ok := x.(parseErr); ok { record = nil err = fmt.Errorf("%w: %s", ErrRecordSyntax, xerr) return } panic(x) }() // Parsing is mostly case-sensitive. // ../rfc/8461:306 p := newParser(txt) record = &Record{ Version: "STSv1", } seen := map[string]struct{}{} p.xtake("v=STSv1") p.xdelim() ismtasts = true for { k := p.xkey() p.xtake("=") // Section 3.1 about the TXT record does not say anything about duplicate fields. // But section 3.2 about (parsing) policies has a paragraph that starts // requirements on both TXT and policy records. That paragraph ends with a note // about handling duplicate fields. Let's assume that note also applies to TXT // records. ../rfc/8461:517 _, dup := seen[k] seen[k] = struct{}{} switch k { case "id": if !dup { record.ID = p.xid() } default: v := p.xvalue() record.Extensions = append(record.Extensions, Pair{k, v}) } if !p.delim() || p.empty() { break } } if !p.empty() { p.xerrorf("leftover characters") } if record.ID == "" { p.xerrorf("missing id") } return } // ParsePolicy parses an MTA-STS policy. func ParsePolicy(s string) (policy *Policy, err error) { defer func() { x := recover() if x == nil { return } if xerr, ok := x.(parseErr); ok { policy = nil err = fmt.Errorf("%w: %s", ErrPolicySyntax, xerr) return } panic(x) }() // ../rfc/8461:426 p := newParser(s) policy = &Policy{ Version: "STSv1", } seen := map[string]struct{}{} for { k := p.xkey() // For fields except "mx", only the first must be used. ../rfc/8461:517 _, dup := seen[k] seen[k] = struct{}{} p.xtake(":") p.wsp() switch k { case "version": policy.Version = p.xtake("STSv1") case "mode": mode := Mode(p.xtakelist("testing", "enforce", "none")) if !dup { policy.Mode = mode } case "max_age": maxage := p.xmaxage() if !dup { policy.MaxAgeSeconds = maxage } case "mx": policy.MX = append(policy.MX, p.xmx()) default: v := p.xpolicyvalue() policy.Extensions = append(policy.Extensions, Pair{k, v}) } p.wsp() if !p.eol() || p.empty() { break } } if !p.empty() { p.xerrorf("leftover characters") } required := []string{"version", "mode", "max_age"} for _, req := range required { if _, ok := seen[req]; !ok { p.xerrorf("missing field %q", req) } } if _, ok := seen["mx"]; !ok && policy.Mode != ModeNone { // ../rfc/8461:437 p.xerrorf("missing mx given mode") } return } type parser struct { s string o int } func newParser(s string) *parser { return &parser{s: s} } func (p *parser) xerrorf(format string, args ...any) { msg := fmt.Sprintf(format, args...) if p.o < len(p.s) { msg += fmt.Sprintf(" (remain %q)", p.s[p.o:]) } panic(parseErr(msg)) } func (p *parser) xtake(s string) string { if !p.prefix(s) { p.xerrorf("expected %q", s) } p.o += len(s) return s } func (p *parser) xdelim() { if !p.delim() { p.xerrorf("expected semicolon") } } func (p *parser) xtaken(n int) string { r := p.s[p.o : p.o+n] p.o += n return r } func (p *parser) xtakefn1(fn func(rune, int) bool) string { for i, b := range p.s[p.o:] { if !fn(b, i) { if i == 0 { p.xerrorf("expected at least one char") } return p.xtaken(i) } } if p.empty() { p.xerrorf("expected at least 1 char") } return p.xtaken(len(p.s) - p.o) } func (p *parser) prefix(s string) bool { return strings.HasPrefix(p.s[p.o:], s) } // File name, the known values match this syntax. // ../rfc/8461:482 func (p *parser) xkey() string { return p.xtakefn1(func(b rune, i int) bool { return i < 32 && (b >= 'a' && b <= 'z' || b >= 'A' && b <= 'Z' || b >= '0' && b <= '9' || (i > 0 && b == '_' || b == '-' || b == '.')) }) } // ../rfc/8461:319 func (p *parser) xid() string { return p.xtakefn1(func(b rune, i int) bool { return i < 32 && (b >= 'a' && b <= 'z' || b >= 'A' && b <= 'Z' || b >= '0' && b <= '9') }) } // ../rfc/8461:326 func (p *parser) xvalue() string { return p.xtakefn1(func(b rune, i int) bool { return b > ' ' && b < 0x7f && b != '=' && b != ';' }) } // ../rfc/8461:315 func (p *parser) delim() bool { o := p.o e := len(p.s) for o < e && (p.s[o] == ' ' || p.s[o] == '\t') { o++ } if o >= e || p.s[o] != ';' { return false } o++ for o < e && (p.s[o] == ' ' || p.s[o] == '\t') { o++ } p.o = o return true } func (p *parser) empty() bool { return p.o >= len(p.s) } // ../rfc/8461:485 func (p *parser) eol() bool { return p.take("\n") || p.take("\r\n") } func (p *parser) xtakelist(l ...string) string { for _, s := range l { if p.prefix(s) { return p.xtaken(len(s)) } } p.xerrorf("expected one of %s", strings.Join(l, ", ")) return "" // not reached } // ../rfc/8461:476 func (p *parser) xmaxage() int { digits := p.xtakefn1(func(b rune, i int) bool { return b >= '0' && b <= '9' && i < 10 }) v, err := strconv.ParseInt(digits, 10, 32) if err != nil { p.xerrorf("parsing int: %s", err) } return int(v) } func (p *parser) take(s string) bool { if p.prefix(s) { p.o += len(s) return true } return false } // ../rfc/8461:469 func (p *parser) xmx() (mx STSMX) { if p.prefix("*.") { mx.Wildcard = true p.o += 2 } mx.Domain = p.xdomain() return mx } // ../rfc/5321:2291 func (p *parser) xdomain() dns.Domain { s := p.xsubdomain() for p.take(".") { s += "." + p.xsubdomain() } d, err := dns.ParseDomain(s) if err != nil { p.xerrorf("parsing domain %q: %s", s, err) } return d } // ../rfc/8461:487 func (p *parser) xsubdomain() string { // note: utf-8 is valid, but U-labels are explicitly not allowed. ../rfc/8461:411 ../rfc/5321:2303 unicode := false s := p.xtakefn1(func(c rune, i int) bool { if c > 0x7f { unicode = true } return c >= '0' && c <= '9' || c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || (i > 0 && c == '-') || c > 0x7f }) if unicode { p.xerrorf("domain must be specified in A labels, not U labels (unicode)") } return s } // ../rfc/8461:487 func (p *parser) xpolicyvalue() string { e := len(p.s) for i, c := range p.s[p.o:] { if c > ' ' && c < 0x7f || c >= 0x80 || (c == ' ' && i > 0) { continue } e = p.o + i break } // Walk back on trailing spaces. for e > p.o && p.s[e-1] == ' ' { e-- } n := e - p.o if n <= 0 { p.xerrorf("empty extension value") } return p.xtaken(n) } // "*WSP" func (p *parser) wsp() { n := len(p.s) for p.o < n && (p.s[p.o] == ' ' || p.s[p.o] == '\t') { p.o++ } }