package caddyhttp

import (
	"bytes"
	"context"
	"io"
	"net/http"
	"net/http/httptest"
	"net/netip"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
)

// writeFunc has the same shape as an io.Writer's Write method, which lets a
// bound method value such as buf.Write be passed around as a plain function.
type writeFunc func(p []byte) (int, error)

// nopSyncer wraps a writeFunc and adds a no-op Sync method so the function can
// be handed to zapcore.NewCore as the destination for encoded log entries.
type nopSyncer writeFunc

// Write hands p to the wrapped function and returns whatever it reports.
func (ws nopSyncer) Write(p []byte) (int, error) {
	written, err := ws(p)
	return written, err
}

// Sync does nothing and always succeeds.
func (nopSyncer) Sync() error { return nil }

// testLogger builds a debug-level zap logger whose JSON-encoded entries are
// passed to wf. Callers typically supply the Write method of a buffer they
// own (to assert on the output) or of io.Discard (for benchmarks).
func testLogger(wf writeFunc) *zap.Logger {
	// Only message, level and logger-name keys are configured for the encoder.
	jsonEnc := zapcore.NewJSONEncoder(zapcore.EncoderConfig{
		MessageKey:     "msg",
		LevelKey:       "level",
		NameKey:        "logger",
		EncodeLevel:    zapcore.LowercaseLevelEncoder,
		EncodeTime:     zapcore.ISO8601TimeEncoder,
		EncodeDuration: zapcore.StringDurationEncoder,
	})

	return zap.New(zapcore.NewCore(jsonEnc, nopSyncer(wf), zap.DebugLevel))
}

func TestServer_LogRequest(t *testing.T) {
	s := &Server{}

	ctx := context.Background()
	ctx = context.WithValue(ctx, ExtraLogFieldsCtxKey, new(ExtraLogFields))
	req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
	rec := httptest.NewRecorder()
	wrec := NewResponseRecorder(rec, nil, nil)

	duration := 50 * time.Millisecond
	repl := NewTestReplacer(req)
	bodyReader := &lengthReader{Source: req.Body}
	shouldLogCredentials := false

	buf := bytes.Buffer{}
	accLog := testLogger(buf.Write)
	s.logRequest(accLog, req, wrec, &duration, repl, bodyReader, shouldLogCredentials)

	assert.JSONEq(t, `{
		"msg":"handled request", "level":"info", "bytes_read":0,
		"duration":"50ms", "resp_headers": {}, "size":0,
		"status":0, "user_id":""
	}`, buf.String())
}

func TestServer_LogRequest_WithTraceID(t *testing.T) {
	s := &Server{}

	extra := new(ExtraLogFields)
	ctx := context.WithValue(context.Background(), ExtraLogFieldsCtxKey, extra)
	extra.Add(zap.String("traceID", "1234567890abcdef"))

	req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
	rec := httptest.NewRecorder()
	wrec := NewResponseRecorder(rec, nil, nil)

	duration := 50 * time.Millisecond
	repl := NewTestReplacer(req)
	bodyReader := &lengthReader{Source: req.Body}
	shouldLogCredentials := false

	buf := bytes.Buffer{}
	accLog := testLogger(buf.Write)
	s.logRequest(accLog, req, wrec, &duration, repl, bodyReader, shouldLogCredentials)

	assert.JSONEq(t, `{
		"msg":"handled request", "level":"info", "bytes_read":0,
		"duration":"50ms", "resp_headers": {}, "size":0,
		"status":0, "user_id":"",
		"traceID":"1234567890abcdef"
	}`, buf.String())
}

func BenchmarkServer_LogRequest(b *testing.B) {
	s := &Server{}

	extra := new(ExtraLogFields)
	ctx := context.WithValue(context.Background(), ExtraLogFieldsCtxKey, extra)

	req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
	rec := httptest.NewRecorder()
	wrec := NewResponseRecorder(rec, nil, nil)

	duration := 50 * time.Millisecond
	repl := NewTestReplacer(req)
	bodyReader := &lengthReader{Source: req.Body}

	buf := io.Discard
	accLog := testLogger(buf.Write)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		s.logRequest(accLog, req, wrec, &duration, repl, bodyReader, false)
	}
}

func BenchmarkServer_LogRequest_WithTraceID(b *testing.B) {
	s := &Server{}

	extra := new(ExtraLogFields)
	ctx := context.WithValue(context.Background(), ExtraLogFieldsCtxKey, extra)
	extra.Add(zap.String("traceID", "1234567890abcdef"))

	req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
	rec := httptest.NewRecorder()
	wrec := NewResponseRecorder(rec, nil, nil)

	duration := 50 * time.Millisecond
	repl := NewTestReplacer(req)
	bodyReader := &lengthReader{Source: req.Body}

	buf := io.Discard
	accLog := testLogger(buf.Write)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		s.logRequest(accLog, req, wrec, &duration, repl, bodyReader, false)
	}
}
// TestServer_TrustedRealClientIP_NoTrustedHeaders checks that, given an empty
// list of header names, trustedRealClientIP returns the IP passed as its third
// argument.
func TestServer_TrustedRealClientIP_NoTrustedHeaders(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "192.0.2.1:12345"
	ip := trustedRealClientIP(req, []string{}, "192.0.2.1")

	// assert.Equal takes (expected, actual); this order keeps failure output correctly labelled.
	assert.Equal(t, "192.0.2.1", ip)
}

// TestServer_TrustedRealClientIP_OneTrustedHeaderEmpty checks that when the
// single listed header (X-Forwarded-For) is not set on the request,
// trustedRealClientIP returns the IP passed as its third argument.
func TestServer_TrustedRealClientIP_OneTrustedHeaderEmpty(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "192.0.2.1:12345"
	ip := trustedRealClientIP(req, []string{"X-Forwarded-For"}, "192.0.2.1")

	// assert.Equal takes (expected, actual); this order keeps failure output correctly labelled.
	assert.Equal(t, "192.0.2.1", ip)
}

// TestServer_TrustedRealClientIP_OneTrustedHeaderInvalid checks that when the
// listed header contains no parseable IP ("not, an, ip"),
// trustedRealClientIP returns the IP passed as its third argument.
func TestServer_TrustedRealClientIP_OneTrustedHeaderInvalid(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "192.0.2.1:12345"
	req.Header.Set("X-Forwarded-For", "not, an, ip")
	ip := trustedRealClientIP(req, []string{"X-Forwarded-For"}, "192.0.2.1")

	// assert.Equal takes (expected, actual); this order keeps failure output correctly labelled.
	assert.Equal(t, "192.0.2.1", ip)
}

// TestServer_TrustedRealClientIP_OneTrustedHeaderValid checks that a single
// valid IP in the listed header is returned by trustedRealClientIP instead of
// the IP passed as its third argument.
func TestServer_TrustedRealClientIP_OneTrustedHeaderValid(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "192.0.2.1:12345"
	req.Header.Set("X-Forwarded-For", "10.0.0.1")
	ip := trustedRealClientIP(req, []string{"X-Forwarded-For"}, "192.0.2.1")

	// assert.Equal takes (expected, actual); this order keeps failure output correctly labelled.
	assert.Equal(t, "10.0.0.1", ip)
}

// TestServer_TrustedRealClientIP_OneTrustedHeaderValidArray checks that when
// the listed header holds a comma-separated list of valid IPs,
// trustedRealClientIP returns the first (left-most) entry.
func TestServer_TrustedRealClientIP_OneTrustedHeaderValidArray(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "192.0.2.1:12345"
	req.Header.Set("X-Forwarded-For", "1.1.1.1, 2.2.2.2, 3.3.3.3")
	ip := trustedRealClientIP(req, []string{"X-Forwarded-For"}, "192.0.2.1")

	// assert.Equal takes (expected, actual); this order keeps failure output correctly labelled.
	assert.Equal(t, "1.1.1.1", ip)
}

// TestServer_TrustedRealClientIP_IncludesPort checks that a header value of
// the form "ip:port" yields just the IP portion from trustedRealClientIP.
func TestServer_TrustedRealClientIP_IncludesPort(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "192.0.2.1:12345"
	req.Header.Set("X-Forwarded-For", "1.1.1.1:1234")
	ip := trustedRealClientIP(req, []string{"X-Forwarded-For"}, "192.0.2.1")

	// assert.Equal takes (expected, actual); this order keeps failure output correctly labelled.
	assert.Equal(t, "1.1.1.1", ip)
}

// TestServer_TrustedRealClientIP_SkipsInvalidIps checks that unparseable
// entries at the start of the header list are passed over and the first valid
// IP ("10.0.0.1") is returned by trustedRealClientIP.
func TestServer_TrustedRealClientIP_SkipsInvalidIps(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "192.0.2.1:12345"
	req.Header.Set("X-Forwarded-For", "not an ip, bad bad, 10.0.0.1")
	ip := trustedRealClientIP(req, []string{"X-Forwarded-For"}, "192.0.2.1")

	// assert.Equal takes (expected, actual); this order keeps failure output correctly labelled.
	assert.Equal(t, "10.0.0.1", ip)
}

// TestServer_TrustedRealClientIP_MultipleTrustedHeaderValidArray checks how
// trustedRealClientIP picks among several listed headers: the result comes
// from the first header in the list that is present on the request, and a
// listed header that is absent ("Missing-Header-IP") is passed over.
func TestServer_TrustedRealClientIP_MultipleTrustedHeaderValidArray(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "192.0.2.1:12345"
	req.Header.Set("Real-Client-IP", "1.1.1.1, 2.2.2.2, 3.3.3.3")
	req.Header.Set("X-Forwarded-For", "3.3.3.3, 4.4.4.4")
	ip1 := trustedRealClientIP(req, []string{"X-Forwarded-For", "Real-Client-IP"}, "192.0.2.1")
	ip2 := trustedRealClientIP(req, []string{"Real-Client-IP", "X-Forwarded-For"}, "192.0.2.1")
	ip3 := trustedRealClientIP(req, []string{"Missing-Header-IP", "Real-Client-IP", "X-Forwarded-For"}, "192.0.2.1")

	// assert.Equal takes (expected, actual); this order keeps failure output correctly labelled.
	assert.Equal(t, "3.3.3.3", ip1)
	assert.Equal(t, "1.1.1.1", ip2)
	assert.Equal(t, "1.1.1.1", ip3)
}

// TestServer_DetermineTrustedProxy_NoConfig checks that a zero-value Server
// reports the request as untrusted and uses the host part of RemoteAddr as the
// client IP.
func TestServer_DetermineTrustedProxy_NoConfig(t *testing.T) {
	server := &Server{}

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "192.0.2.1:12345"

	trusted, clientIP := determineTrustedProxy(req, server)

	assert.False(t, trusted)
	// assert.Equal takes (expected, actual); this order keeps failure output correctly labelled.
	assert.Equal(t, "192.0.2.1", clientIP)
}

// TestServer_DetermineTrustedProxy_NoConfigIpv6 checks that a zero-value
// Server reports the request as untrusted and extracts the IPv6 address from
// a bracketed "[addr]:port" RemoteAddr.
func TestServer_DetermineTrustedProxy_NoConfigIpv6(t *testing.T) {
	server := &Server{}

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "[::1]:12345"

	trusted, clientIP := determineTrustedProxy(req, server)

	assert.False(t, trusted)
	// assert.Equal takes (expected, actual); this order keeps failure output correctly labelled.
	assert.Equal(t, "::1", clientIP)
}

// TestServer_DetermineTrustedProxy_NoConfigIpv6Zones checks that an IPv6
// RemoteAddr carrying a zone ("%eth2") is reported as the bare address "::1",
// and that a zero-value Server treats the request as untrusted.
func TestServer_DetermineTrustedProxy_NoConfigIpv6Zones(t *testing.T) {
	server := &Server{}

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "[::1%eth2]:12345"

	trusted, clientIP := determineTrustedProxy(req, server)

	assert.False(t, trusted)
	// assert.Equal takes (expected, actual); this order keeps failure output correctly labelled.
	assert.Equal(t, "::1", clientIP)
}

// TestServer_DetermineTrustedProxy_TrustedLoopback checks that a request whose
// RemoteAddr falls inside the configured trusted range (127.0.0.1/8) is
// reported as trusted and that the client IP is taken from X-Forwarded-For.
func TestServer_DetermineTrustedProxy_TrustedLoopback(t *testing.T) {
	// MustParsePrefix panics on a malformed literal instead of silently yielding a zero Prefix.
	loopbackPrefix := netip.MustParsePrefix("127.0.0.1/8")

	server := &Server{
		trustedProxies: &StaticIPRange{
			ranges: []netip.Prefix{loopbackPrefix},
		},
		ClientIPHeaders: []string{"X-Forwarded-For"},
	}

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "127.0.0.1:12345"
	req.Header.Set("X-Forwarded-For", "31.40.0.10")

	trusted, clientIP := determineTrustedProxy(req, server)

	assert.True(t, trusted)
	// assert.Equal takes (expected, actual); this order keeps failure output correctly labelled.
	assert.Equal(t, "31.40.0.10", clientIP)
}

// TestServer_DetermineTrustedProxy_UntrustedPrefix checks that a request whose
// RemoteAddr (10.0.0.1) lies outside the configured trusted range
// (127.0.0.1/8) is reported as untrusted, and that its X-Forwarded-For header
// is ignored in favour of the RemoteAddr host.
func TestServer_DetermineTrustedProxy_UntrustedPrefix(t *testing.T) {
	// MustParsePrefix panics on a malformed literal instead of silently yielding a zero Prefix.
	loopbackPrefix := netip.MustParsePrefix("127.0.0.1/8")

	server := &Server{
		trustedProxies: &StaticIPRange{
			ranges: []netip.Prefix{loopbackPrefix},
		},
		ClientIPHeaders: []string{"X-Forwarded-For"},
	}

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "10.0.0.1:12345"
	req.Header.Set("X-Forwarded-For", "31.40.0.10")

	trusted, clientIP := determineTrustedProxy(req, server)

	assert.False(t, trusted)
	// assert.Equal takes (expected, actual); this order keeps failure output correctly labelled.
	assert.Equal(t, "10.0.0.1", clientIP)
}

// TestServer_DetermineTrustedProxy_MultipleTrustedPrefixes checks that a
// RemoteAddr matching the second of two configured trusted ranges is reported
// as trusted and that the client IP is taken from X-Forwarded-For.
func TestServer_DetermineTrustedProxy_MultipleTrustedPrefixes(t *testing.T) {
	// MustParsePrefix panics on a malformed literal instead of silently yielding a zero Prefix.
	loopbackPrefix := netip.MustParsePrefix("127.0.0.1/8")
	localPrivatePrefix := netip.MustParsePrefix("10.0.0.0/8")

	server := &Server{
		trustedProxies: &StaticIPRange{
			ranges: []netip.Prefix{loopbackPrefix, localPrivatePrefix},
		},
		ClientIPHeaders: []string{"X-Forwarded-For"},
	}

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "10.0.0.1:12345"
	req.Header.Set("X-Forwarded-For", "31.40.0.10")

	trusted, clientIP := determineTrustedProxy(req, server)

	assert.True(t, trusted)
	// assert.Equal takes (expected, actual); this order keeps failure output correctly labelled.
	assert.Equal(t, "31.40.0.10", clientIP)
}

// TestServer_DetermineTrustedProxy_MultipleTrustedClientHeaders checks that
// with two ClientIPHeaders configured and both present on a trusted request,
// the client IP comes from the first configured header (CF-Connecting-IP) and
// is its left-most entry.
func TestServer_DetermineTrustedProxy_MultipleTrustedClientHeaders(t *testing.T) {
	// MustParsePrefix panics on a malformed literal instead of silently yielding a zero Prefix.
	loopbackPrefix := netip.MustParsePrefix("127.0.0.1/8")
	localPrivatePrefix := netip.MustParsePrefix("10.0.0.0/8")

	server := &Server{
		trustedProxies: &StaticIPRange{
			ranges: []netip.Prefix{loopbackPrefix, localPrivatePrefix},
		},
		ClientIPHeaders: []string{"CF-Connecting-IP", "X-Forwarded-For"},
	}

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "10.0.0.1:12345"
	req.Header.Set("CF-Connecting-IP", "1.1.1.1, 2.2.2.2")
	req.Header.Set("X-Forwarded-For", "3.3.3.3, 4.4.4.4")

	trusted, clientIP := determineTrustedProxy(req, server)

	assert.True(t, trusted)
	// assert.Equal takes (expected, actual); this order keeps failure output correctly labelled.
	assert.Equal(t, "1.1.1.1", clientIP)
}

// TestServer_DetermineTrustedProxy_MatchLeftMostValidIp checks that, with
// TrustedProxiesStrict left at its zero value, the client IP for a trusted
// request is the left-most entry of X-Forwarded-For.
func TestServer_DetermineTrustedProxy_MatchLeftMostValidIp(t *testing.T) {
	// MustParsePrefix panics on a malformed literal instead of silently yielding a zero Prefix.
	localPrivatePrefix := netip.MustParsePrefix("10.0.0.0/8")

	server := &Server{
		trustedProxies: &StaticIPRange{
			ranges: []netip.Prefix{localPrivatePrefix},
		},
		ClientIPHeaders: []string{"X-Forwarded-For"},
	}

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "10.0.0.1:12345"
	req.Header.Set("X-Forwarded-For", "30.30.30.30, 45.54.45.54, 10.0.0.1")

	trusted, clientIP := determineTrustedProxy(req, server)

	assert.True(t, trusted)
	// assert.Equal takes (expected, actual); this order keeps failure output correctly labelled.
	assert.Equal(t, "30.30.30.30", clientIP)
}

// TestServer_DetermineTrustedProxy_MatchRightMostUntrusted checks that, with
// TrustedProxiesStrict set to 1, the client IP for a trusted request is the
// right-most X-Forwarded-For entry that is outside the trusted range
// (45.54.45.54; the trailing 10.0.0.1 is inside 10.0.0.0/8).
func TestServer_DetermineTrustedProxy_MatchRightMostUntrusted(t *testing.T) {
	// MustParsePrefix panics on a malformed literal instead of silently yielding a zero Prefix.
	localPrivatePrefix := netip.MustParsePrefix("10.0.0.0/8")

	server := &Server{
		trustedProxies: &StaticIPRange{
			ranges: []netip.Prefix{localPrivatePrefix},
		},
		ClientIPHeaders:      []string{"X-Forwarded-For"},
		TrustedProxiesStrict: 1,
	}

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "10.0.0.1:12345"
	req.Header.Set("X-Forwarded-For", "30.30.30.30, 45.54.45.54, 10.0.0.1")

	trusted, clientIP := determineTrustedProxy(req, server)

	assert.True(t, trusted)
	// assert.Equal takes (expected, actual); this order keeps failure output correctly labelled.
	assert.Equal(t, "45.54.45.54", clientIP)
}

// TestServer_DetermineTrustedProxy_MatchRightMostUntrustedSkippingEmpty checks
// the TrustedProxiesStrict=1 case when earlier configured headers yield
// nothing usable: "Missing-Header" is absent and CF-Connecting-IP holds no
// valid IP, so the result is expected to come from X-Forwarded-For, where the
// right-most parseable entry is 45.54.45.54.
func TestServer_DetermineTrustedProxy_MatchRightMostUntrustedSkippingEmpty(t *testing.T) {
	// MustParsePrefix panics on a malformed literal instead of silently yielding a zero Prefix.
	localPrivatePrefix := netip.MustParsePrefix("10.0.0.0/8")

	server := &Server{
		trustedProxies: &StaticIPRange{
			ranges: []netip.Prefix{localPrivatePrefix},
		},
		ClientIPHeaders:      []string{"Missing-Header", "CF-Connecting-IP", "X-Forwarded-For"},
		TrustedProxiesStrict: 1,
	}

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "10.0.0.1:12345"
	req.Header.Set("CF-Connecting-IP", "not a real IP")
	req.Header.Set("X-Forwarded-For", "30.30.30.30, bad, 45.54.45.54, not real")

	trusted, clientIP := determineTrustedProxy(req, server)

	assert.True(t, trusted)
	// assert.Equal takes (expected, actual); this order keeps failure output correctly labelled.
	assert.Equal(t, "45.54.45.54", clientIP)
}

// TestServer_DetermineTrustedProxy_MatchRightMostUntrustedSkippingTrusted
// checks the TrustedProxiesStrict=1 case when every entry of the first
// configured header (CF-Connecting-IP) lies inside the trusted range: the
// result is expected to come from X-Forwarded-For, whose right-most entry
// outside 10.0.0.0/8 is 45.54.45.54.
func TestServer_DetermineTrustedProxy_MatchRightMostUntrustedSkippingTrusted(t *testing.T) {
	// MustParsePrefix panics on a malformed literal instead of silently yielding a zero Prefix.
	localPrivatePrefix := netip.MustParsePrefix("10.0.0.0/8")

	server := &Server{
		trustedProxies: &StaticIPRange{
			ranges: []netip.Prefix{localPrivatePrefix},
		},
		ClientIPHeaders:      []string{"CF-Connecting-IP", "X-Forwarded-For"},
		TrustedProxiesStrict: 1,
	}

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "10.0.0.1:12345"
	req.Header.Set("CF-Connecting-IP", "10.0.0.1, 10.0.0.2, 10.0.0.3")
	req.Header.Set("X-Forwarded-For", "30.30.30.30, 45.54.45.54, 10.0.0.4")

	trusted, clientIP := determineTrustedProxy(req, server)

	assert.True(t, trusted)
	// assert.Equal takes (expected, actual); this order keeps failure output correctly labelled.
	assert.Equal(t, "45.54.45.54", clientIP)
}

// TestServer_DetermineTrustedProxy_MatchRightMostUntrustedFirst checks the
// TrustedProxiesStrict=1 case when the first configured header
// (CF-Connecting-IP) itself contains an entry outside the trusted range: that
// entry (90.100.110.120) is expected as the client IP, and X-Forwarded-For is
// not consulted.
func TestServer_DetermineTrustedProxy_MatchRightMostUntrustedFirst(t *testing.T) {
	// MustParsePrefix panics on a malformed literal instead of silently yielding a zero Prefix.
	localPrivatePrefix := netip.MustParsePrefix("10.0.0.0/8")

	server := &Server{
		trustedProxies: &StaticIPRange{
			ranges: []netip.Prefix{localPrivatePrefix},
		},
		ClientIPHeaders:      []string{"CF-Connecting-IP", "X-Forwarded-For"},
		TrustedProxiesStrict: 1,
	}

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "10.0.0.1:12345"
	req.Header.Set("CF-Connecting-IP", "10.0.0.1, 90.100.110.120, 10.0.0.2, 10.0.0.3")
	req.Header.Set("X-Forwarded-For", "30.30.30.30, 45.54.45.54, 10.0.0.4")

	trusted, clientIP := determineTrustedProxy(req, server)

	assert.True(t, trusted)
	// assert.Equal takes (expected, actual); this order keeps failure output correctly labelled.
	assert.Equal(t, "90.100.110.120", clientIP)
}