package scram

import (
	"encoding/base64"
	"errors"
	"fmt"
	"strconv"
	"strings"
)

// parser holds the input being parsed together with a lower-cased copy and the
// current read offset. The two strings have the same length, so a single offset
// is valid for both. Fields are set positionally by newParser; keep their order.
type parser struct {
	s     string // Original casing.
	lower string // Lower casing, for case-insensitive token consumption.
	o     int    // Offset in s/lower.
}

// parseError is the panic value used by the parser's x-prefixed helpers. It
// carries the underlying error unchanged.
type parseError struct {
	err error
}

// Error returns the message of the wrapped error.
func (pe parseError) Error() string {
	inner := pe.err
	return inner.Error()
}

// Unwrap exposes the wrapped error to errors.Is and errors.As.
func (pe parseError) Unwrap() error {
	inner := pe.err
	return inner
}

// toLower returns s with the ASCII bytes A-Z replaced by a-z and every other
// byte left untouched. strings.ToLower is not used because it does too much: it
// would replace invalid bytes with unicode replacement characters, which would
// break our requirement that offsets into the original and lower case strings
// point to the same character.
func toLower(s string) string {
	b := []byte(s)
	for i := range b {
		if 'A' <= b[i] && b[i] <= 'Z' {
			b[i] += 'a' - 'A'
		}
	}
	return string(b)
}

// newParser returns a parser positioned at the start of buf. The lower-cased
// copy used for case-insensitive matching is computed once here.
func newParser(buf []byte) *parser {
	orig := string(buf)
	return &parser{
		s:     orig,
		lower: toLower(orig),
		o:     0,
	}
}

// recover turns a panic raised during parsing into an error stored in *rerr.
// Called with defer by functions that parse.
//
// A panic value that is not an error is re-raised unchanged. If the error chain
// of the panic value contains an Error, the panic value is stored as is.
// Otherwise a new error is stored that wraps ErrInvalidEncoding and includes the
// message of the panic value as descriptive text.
func (p *parser) recover(rerr *error) {
	// The builtin recover must be called directly from this deferred method.
	x := recover()
	if x == nil {
		return
	}
	err, isErr := x.(error)
	if !isErr {
		panic(x)
	}
	var scramErr Error
	if !errors.As(err, &scramErr) {
		err = fmt.Errorf("%w: %s", ErrInvalidEncoding, err)
	}
	*rerr = err
}

// xerrorf aborts parsing by panicking with a parseError whose message is built
// from format and args.
func (p *parser) xerrorf(format string, args ...any) {
	err := fmt.Errorf(format, args...)
	panic(parseError{err})
}

// xcheckf does nothing when err is nil. Otherwise it aborts parsing by panicking
// with a parseError that wraps err, prefixed with the formatted message.
func (p *parser) xcheckf(err error, format string, args ...any) {
	if err == nil {
		return
	}
	msg := fmt.Sprintf(format, args...)
	panic(parseError{fmt.Errorf("%s: %w", msg, err)})
}

// xempty aborts parsing unless all input has been consumed.
func (p *parser) xempty() {
	if p.o == len(p.s) {
		return
	}
	p.xerrorf("leftover data")
}

// xnonempty aborts parsing when no input remains at the current offset.
func (p *parser) xnonempty() {
	if p.o < len(p.s) {
		return
	}
	p.xerrorf("unexpected end")
}

// xbyte consumes one byte and returns it from the lower-cased input, so A-Z is
// returned as a-z. It aborts parsing when no input remains.
func (p *parser) xbyte() byte {
	p.xnonempty()
	at := p.o
	c := p.lower[at]
	p.o = at + 1
	return c
}

// peek reports whether the remaining lower-cased input starts with s, without
// consuming anything. Because matching is done against the lower-cased input, s
// only matches case-insensitively when s itself is lower case.
func (p *parser) peek(s string) bool {
	rest := p.lower[p.o:]
	return strings.HasPrefix(rest, s)
}

// take consumes s if the remaining input starts with it (see peek) and reports
// whether it did. The offset is unchanged when s does not match.
func (p *parser) take(s string) bool {
	if !p.peek(s) {
		return false
	}
	p.o += len(s)
	return true
}

// xtake consumes s, aborting parsing when the remaining input does not start
// with it.
func (p *parser) xtake(s string) {
	if p.take(s) {
		return
	}
	p.xerrorf("expected %q", s)
}

// xauthzid parses "a=" followed by a saslname and returns the decoded name.
func (p *parser) xauthzid() string {
	p.xtake("a=")
	authzid := p.xsaslname()
	return authzid
}

// xusername parses "n=" followed by a saslname and returns the decoded name.
func (p *parser) xusername() string {
	p.xtake("n=")
	username := p.xsaslname()
	return username
}

// xnonce parses "r=" followed by a non-empty nonce and returns the nonce in its
// original casing. The nonce runs up to, not including, the first byte that is
// a comma, is <= ' ' or is >= 0x7f, or up to the end of input.
func (p *parser) xnonce() string {
	p.xtake("r=")
	start := p.o
	end := start
	for end < len(p.s) {
		c := p.s[end]
		allowed := c > ' ' && c < 0x7f && c != ','
		if !allowed {
			break
		}
		end++
	}
	if end == start {
		p.xerrorf("empty nonce")
	}
	p.o = end
	return p.s[start:end]
}

// xattrval parses and discards an attribute/value pair: a single letter, "=",
// and a non-empty value as parsed by xvalue.
func (p *parser) xattrval() {
	c := p.xbyte()
	// xbyte reads from the lower-cased input, so the upper case range is not
	// expected to match; it is kept as a safeguard.
	isLower := c >= 'a' && c <= 'z'
	isUpper := c >= 'A' && c <= 'Z'
	if !isLower && !isUpper {
		p.xerrorf("expected alpha for attr-val")
	}
	p.xtake("=")
	p.xvalue()
}

// xvalue parses a non-empty value at the current offset and returns it in its
// original casing. The value runs up to, not including, the first "," or NUL
// byte, or up to the end of input. The terminating byte is not consumed.
//
// Parsing is aborted with "invalid empty value" when a terminator is found
// immediately, and with "unexpected end" when no input remains.
func (p *parser) xvalue() string {
	// o is the byte offset relative to p.o, not an absolute offset into p.s.
	for o, c := range p.s[p.o:] {
		if c == 0 || c == ',' {
			if o == 0 {
				p.xerrorf("invalid empty value")
			}
			r := p.s[p.o : p.o+o]
			// Advance by the relative offset. Assigning o directly would
			// move the parser to the wrong (typically earlier) position.
			p.o += o
			return r
		}
	}
	p.xnonempty()
	r := p.s[p.o:]
	p.o = len(p.s)
	return r
}

// xbase64 consumes the longest run of bytes from the standard base64 alphabet
// (letters, digits, "/", "+" and "=") at the current offset and returns the
// decoded data. Parsing is aborted when the run is not valid base64. The offset
// is only advanced after successful decoding.
func (p *parser) xbase64() []byte {
	isBase64 := func(c byte) bool {
		switch {
		case c >= 'a' && c <= 'z', c >= 'A' && c <= 'Z', c >= '0' && c <= '9':
			return true
		case c == '/', c == '+', c == '=':
			return true
		}
		return false
	}

	end := p.o
	for end < len(p.s) && isBase64(p.s[end]) {
		end++
	}
	decoded, err := base64.StdEncoding.DecodeString(p.s[p.o:end])
	p.xcheckf(err, "decoding base64")
	p.o = end
	return decoded
}

// xsaslname parses a non-empty saslname at the current offset and returns it
// with escapes decoded: "=2C" becomes "," and "=3D" becomes "=", with the hex
// letter in either case. Any other escape aborts parsing.
//
// The name runs up to, not including, the first "," or NUL byte, or up to the
// end of input. The terminating byte is not consumed. Parsing is aborted when
// the name is empty or ends in the middle of an escape.
//
// The input is iterated as runes, so bytes that are not valid UTF-8 end up in
// the result as the unicode replacement character.
func (p *parser) xsaslname() string {
	var esc string        // Characters seen so far after the "=" of an escape.
	var is bool           // Whether we are inside an escape.
	var r strings.Builder // Decoded name.
	// o is the byte offset relative to p.o.
	for o, c := range p.s[p.o:] {
		if c == 0 || c == ',' {
			if is {
				p.xerrorf("saslname unexpected end")
			}
			if o == 0 {
				p.xerrorf("saslname cannot be empty")
			}
			p.o += o
			return r.String()
		}
		if is {
			esc += string(c)
			if len(esc) < 2 {
				continue
			}
			switch esc {
			case "2c", "2C":
				r.WriteByte(',')
			case "3d", "3D":
				r.WriteByte('=')
			default:
				p.xerrorf("bad escape %q in saslname", esc)
			}
			is = false
			esc = ""
			continue
		}
		if c == '=' {
			is = true
			continue
		}
		r.WriteRune(c)
	}
	if is {
		p.xerrorf("saslname unexpected end")
	}
	if r.Len() == 0 {
		p.xerrorf("saslname cannot be empty")
	}
	p.o = len(p.s)
	return r.String()
}

// xchannelBinding parses "c=" followed by base64 data and returns the decoded
// data as a string.
func (p *parser) xchannelBinding() string {
	p.xtake("c=")
	decoded := p.xbase64()
	return string(decoded)
}

// xproof parses "p=" followed by base64 data and returns the decoded bytes.
func (p *parser) xproof() []byte {
	p.xtake("p=")
	proof := p.xbase64()
	return proof
}

// xsalt parses "s=" followed by base64 data and returns the decoded bytes.
func (p *parser) xsalt() []byte {
	p.xtake("s=")
	salt := p.xbase64()
	return salt
}

// xtakefn1 consumes and returns the longest non-empty prefix of the remaining
// input, in original casing, for which fn returns true. fn is called for each
// rune with the rune and its byte offset relative to the start of the match.
// Parsing is aborted when the first rune is rejected or no input remains.
func (p *parser) xtakefn1(fn func(rune, int) bool) string {
	rest := p.s[p.o:]
	for i, c := range rest {
		if fn(c, i) {
			continue
		}
		if i == 0 {
			p.xerrorf("non-empty match required")
		}
		p.o += i
		return rest[:i]
	}
	// Everything that remains matched; it must not be empty.
	p.xnonempty()
	p.o = len(p.s)
	return rest
}

// xiterations parses "i=" followed by a decimal number without leading zero and
// returns it. Parsing is aborted when the digits are missing or the number does
// not fit in 32 bits.
func (p *parser) xiterations() int {
	p.xtake("i=")
	isDigit := func(c rune, i int) bool {
		if c == '0' {
			// A zero is only allowed after the first digit.
			return i > 0
		}
		return c >= '1' && c <= '9'
	}
	digits := p.xtakefn1(isDigit)
	n, err := strconv.ParseInt(digits, 10, 32)
	p.xcheckf(err, "parsing int")
	return int(n)
}