From 819aed35bff81574920c2a87eddbebed5ddfda1f Mon Sep 17 00:00:00 2001
From: wxiaoguang <wxiaoguang@gmail.com>
Date: Sun, 9 Jul 2023 20:25:53 +0800
Subject: [PATCH] Make route middleware/handler mockable (#25766)

To mock a handler:

```go
web.RouteMock(web.MockAfterMiddlewares, func(ctx *context.Context) {
	// ...
})
defer web.RouteMockReset()
```


It helps:

* Test the middleware's behavior (assert the ctx.Data, etc)
* Mock the middleware's behavior (prepare some context data for handler)
* Mock the handler's response for some test cases, especially for some
integration tests and e2e tests.
---
 modules/web/route.go          | 18 +++++++--
 modules/web/routemock.go      | 61 ++++++++++++++++++++++++++++++
 modules/web/routemock_test.go | 70 +++++++++++++++++++++++++++++++++++
 3 files changed, 145 insertions(+), 4 deletions(-)
 create mode 100644 modules/web/routemock.go
 create mode 100644 modules/web/routemock_test.go

diff --git a/modules/web/route.go b/modules/web/route.go
index 8685062a8e..dc87e112ec 100644
--- a/modules/web/route.go
+++ b/modules/web/route.go
@@ -50,7 +50,9 @@ func NewRoute() *Route {
 // Use supports two middlewares
 func (r *Route) Use(middlewares ...any) {
 	for _, m := range middlewares {
-		r.R.Use(toHandlerProvider(m))
+		if m != nil {
+			r.R.Use(toHandlerProvider(m))
+		}
 	}
 }
 
@@ -79,15 +81,23 @@ func (r *Route) getPattern(pattern string) string {
 }
 
 func (r *Route) wrapMiddlewareAndHandler(h []any) ([]func(http.Handler) http.Handler, http.HandlerFunc) {
-	handlerProviders := make([]func(http.Handler) http.Handler, 0, len(r.curMiddlewares)+len(h))
+	handlerProviders := make([]func(http.Handler) http.Handler, 0, len(r.curMiddlewares)+len(h)+1)
 	for _, m := range r.curMiddlewares {
-		handlerProviders = append(handlerProviders, toHandlerProvider(m))
+		if m != nil {
+			handlerProviders = append(handlerProviders, toHandlerProvider(m))
+		}
 	}
 	for _, m := range h {
-		handlerProviders = append(handlerProviders, toHandlerProvider(m))
+		if h != nil {
+			handlerProviders = append(handlerProviders, toHandlerProvider(m))
+		}
 	}
 	middlewares := handlerProviders[:len(handlerProviders)-1]
 	handlerFunc := handlerProviders[len(handlerProviders)-1](nil).ServeHTTP
+	mockPoint := RouteMockPoint(MockAfterMiddlewares)
+	if mockPoint != nil {
+		middlewares = append(middlewares, mockPoint)
+	}
 	return middlewares, handlerFunc
 }
 
diff --git a/modules/web/routemock.go b/modules/web/routemock.go
new file mode 100644
index 0000000000..cb41f63b91
--- /dev/null
+++ b/modules/web/routemock.go
@@ -0,0 +1,61 @@
+// Copyright 2023 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package web
+
+import (
+	"net/http"
+
+	"code.gitea.io/gitea/modules/setting"
+)
+
+// MockAfterMiddlewares is a general mock point, it's between middlewares and the handler
+const MockAfterMiddlewares = "MockAfterMiddlewares"
+
+var routeMockPoints = map[string]func(next http.Handler) http.Handler{}
+
+// RouteMockPoint registers a mock point as a middleware for testing, example:
+//
+//	r.Use(web.RouteMockPoint("my-mock-point-1"))
+//	r.Get("/foo", middleware2, web.RouteMockPoint("my-mock-point-2"), middleware2, handler)
+//
+// Then use web.RouteMock to mock the route execution.
+// It only takes effect in testing mode (setting.IsInTesting == true).
+func RouteMockPoint(pointName string) func(next http.Handler) http.Handler {
+	if !setting.IsInTesting {
+		return nil
+	}
+	routeMockPoints[pointName] = nil
+	return func(next http.Handler) http.Handler {
+		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			if h := routeMockPoints[pointName]; h != nil {
+				h(next).ServeHTTP(w, r)
+			} else {
+				next.ServeHTTP(w, r)
+			}
+		})
+	}
+}
+
+// RouteMock uses the registered mock point to mock the route execution, example:
+//
+//	defer web.RouteMockReset()
+//	web.RouteMock(web.MockAfterMiddlewares, func(ctx *context.Context) {
+//		ctx.WriteResponse(...)
+//	}
+//
+// Then the mock function will be executed as a middleware at the mock point.
+// It only takes effect in testing mode (setting.IsInTesting == true).
+func RouteMock(pointName string, h any) {
+	if _, ok := routeMockPoints[pointName]; !ok {
+		panic("route mock point not found: " + pointName)
+	}
+	routeMockPoints[pointName] = toHandlerProvider(h)
+}
+
+// RouteMockReset resets all mock points (no mock anymore)
+func RouteMockReset() {
+	for k := range routeMockPoints {
+		routeMockPoints[k] = nil // keep the keys because RouteMock will check the keys to make sure no misspelling
+	}
+}
diff --git a/modules/web/routemock_test.go b/modules/web/routemock_test.go
new file mode 100644
index 0000000000..04c6d1d82e
--- /dev/null
+++ b/modules/web/routemock_test.go
@@ -0,0 +1,70 @@
+// Copyright 2023 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package web
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"code.gitea.io/gitea/modules/setting"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestRouteMock(t *testing.T) {
+	setting.IsInTesting = true
+
+	r := NewRoute()
+	middleware1 := func(resp http.ResponseWriter, req *http.Request) {
+		resp.Header().Set("X-Test-Middleware1", "m1")
+	}
+	middleware2 := func(resp http.ResponseWriter, req *http.Request) {
+		resp.Header().Set("X-Test-Middleware2", "m2")
+	}
+	handler := func(resp http.ResponseWriter, req *http.Request) {
+		resp.Header().Set("X-Test-Handler", "h")
+	}
+	r.Get("/foo", middleware1, RouteMockPoint("mock-point"), middleware2, handler)
+
+	// normal request
+	recorder := httptest.NewRecorder()
+	req, err := http.NewRequest("GET", "http://localhost:8000/foo", nil)
+	assert.NoError(t, err)
+	r.ServeHTTP(recorder, req)
+	assert.Len(t, recorder.Header(), 3)
+	assert.EqualValues(t, "m1", recorder.Header().Get("X-Test-Middleware1"))
+	assert.EqualValues(t, "m2", recorder.Header().Get("X-Test-Middleware2"))
+	assert.EqualValues(t, "h", recorder.Header().Get("X-Test-Handler"))
+	RouteMockReset()
+
+	// mock at "mock-point"
+	RouteMock("mock-point", func(resp http.ResponseWriter, req *http.Request) {
+		resp.Header().Set("X-Test-MockPoint", "a")
+		resp.WriteHeader(http.StatusOK)
+	})
+	recorder = httptest.NewRecorder()
+	req, err = http.NewRequest("GET", "http://localhost:8000/foo", nil)
+	assert.NoError(t, err)
+	r.ServeHTTP(recorder, req)
+	assert.Len(t, recorder.Header(), 2)
+	assert.EqualValues(t, "m1", recorder.Header().Get("X-Test-Middleware1"))
+	assert.EqualValues(t, "a", recorder.Header().Get("X-Test-MockPoint"))
+	RouteMockReset()
+
+	// mock at MockAfterMiddlewares
+	RouteMock(MockAfterMiddlewares, func(resp http.ResponseWriter, req *http.Request) {
+		resp.Header().Set("X-Test-MockPoint", "b")
+		resp.WriteHeader(http.StatusOK)
+	})
+	recorder = httptest.NewRecorder()
+	req, err = http.NewRequest("GET", "http://localhost:8000/foo", nil)
+	assert.NoError(t, err)
+	r.ServeHTTP(recorder, req)
+	assert.Len(t, recorder.Header(), 3)
+	assert.EqualValues(t, "m1", recorder.Header().Get("X-Test-Middleware1"))
+	assert.EqualValues(t, "m2", recorder.Header().Get("X-Test-Middleware2"))
+	assert.EqualValues(t, "b", recorder.Header().Get("X-Test-MockPoint"))
+	RouteMockReset()
+}