package smtp

import (
	"bufio"
	"errors"
	"io"
	"strings"
	"testing"
)

// TestDataWrite verifies that DataWrite refuses message data which does not
// end in CRLF, and otherwise dot-stuffs lines starting with "." and appends
// the ".\r\n" end-of-data marker.
func TestDataWrite(t *testing.T) {
	// Inputs lacking a trailing CRLF must be rejected with errMissingCRLF.
	// errors.Is(nil, errMissingCRLF) is false, so a nil error fails here too.
	for _, input := range []string{"bad", "."} {
		err := DataWrite(io.Discard, strings.NewReader(input))
		if !errors.Is(err, errMissingCRLF) {
			t.Fatalf("got err %v, expected errMissingCRLF", err)
		}
	}

	// Well-formed inputs and the exact wire bytes DataWrite must produce.
	cases := []struct {
		msg  string
		want string
	}{
		{msg: "", want: ".\r\n"},
		{msg: ".\r\n", want: "..\r\n.\r\n"},
		{msg: "header: abc\r\n\r\nmessage\r\n", want: "header: abc\r\n\r\nmessage\r\n.\r\n"},
	}
	for _, tc := range cases {
		var out strings.Builder
		if err := DataWrite(&out, strings.NewReader(tc.msg)); err != nil {
			t.Fatalf("writing smtp data: %s", err)
		}
		if got := out.String(); got != tc.want {
			t.Fatalf("got %q, expected %q, for msg %q", got, tc.want, tc.msg)
		}
	}
}

// TestDataReader verifies that the reader returned by NewDataReader undoes
// dot-stuffing, stops at the ".\r\n" end-of-data line, and reports
// io.ErrUnexpectedEOF when the input ends without that terminator. Every case
// is run twice: once with io.Copy (large reads) and once with 1-byte reads, so
// that handling of state across Read boundaries is exercised as well.
func TestDataReader(t *testing.T) {
	// smallCopy copies r to d through a 1-byte buffer, so each Read on r can
	// return at most one byte. It has the same signature as io.Copy so both can
	// be used interchangeably below.
	smallCopy := func(d io.Writer, r io.Reader) (int64, error) {
		var wrote int64
		buf := make([]byte, 1)
		for {
			n, rerr := r.Read(buf)
			if n > 0 {
				// Write only the bytes actually read, per the io.Reader contract.
				nn, werr := d.Write(buf[:n])
				wrote += int64(nn)
				if werr != nil {
					return wrote, werr
				}
			}
			if rerr == io.EOF {
				return wrote, nil
			}
			if rerr != nil {
				return wrote, rerr
			}
		}
	}

	copiers := []struct {
		name string
		fn   func(io.Writer, io.Reader) (int64, error)
	}{
		{name: "io.Copy", fn: io.Copy},
		{name: "smallCopy", fn: smallCopy},
	}

	check := func(data, want string) {
		t.Helper()
		for _, c := range copiers {
			s := &strings.Builder{}
			dr := NewDataReader(bufio.NewReader(strings.NewReader(data)))
			if _, err := c.fn(s, dr); err != nil {
				t.Fatalf("%s: got err %v, for %q", c.name, err, data)
			}
			if got := s.String(); got != want {
				t.Fatalf("%s: got %q, expected %q, for %q", c.name, got, want, data)
			}
		}
	}

	check("test\r\n.\r\n", "test\r\n")
	check(".\r\n", "")
	check(".test\r\n.\r\n", "test\r\n") // Unnecessary dot, but valid in SMTP.
	check("..test\r\n.\r\n", ".test\r\n")

	// Data ending without the ".\r\n" terminator must be reported as
	// io.ErrUnexpectedEOF, however the caller sizes its reads. errors.Is keeps
	// the check valid if the reader ever wraps the error with context.
	for _, c := range copiers {
		s := &strings.Builder{}
		dr := NewDataReader(bufio.NewReader(strings.NewReader("no end")))
		if _, err := c.fn(s, dr); !errors.Is(err, io.ErrUnexpectedEOF) {
			t.Fatalf("%s: got err %v, expected io.ErrUnexpectedEOF", c.name, err)
		}
	}
}