package middleware import ( "bytes" "fmt" "io/ioutil" "net/http" "net/url" "os" "path/filepath" "strings" "testing" "time" "text/template" ) func TestInclude(t *testing.T) { context := getContextOrFail(t) inputFilename := "test_file" absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename) defer func() { err := os.Remove(absInFilePath) if err != nil && !os.IsNotExist(err) { t.Fatalf("Failed to clean test file!") } }() tests := []struct { fileContent string expectedContent string shouldErr bool expectedErrorContent string }{ // Test 0 - all good { fileContent: `str1 {{ .Root }} str2`, expectedContent: fmt.Sprintf("str1 %s str2", context.Root), shouldErr: false, expectedErrorContent: "", }, // Test 1 - failure on template.Parse { fileContent: `str1 {{ .Root } str2`, expectedContent: "", shouldErr: true, expectedErrorContent: `unexpected "}" in operand`, }, // Test 3 - failure on template.Execute { fileContent: `str1 {{ .InvalidField }} str2`, expectedContent: "", shouldErr: true, expectedErrorContent: `InvalidField`, }, { fileContent: `str1 {{ .InvalidField }} str2`, expectedContent: "", shouldErr: true, expectedErrorContent: `type middleware.Context`, }, } for i, test := range tests { testPrefix := getTestPrefix(i) // WriteFile truncates the contentt err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm) if err != nil { t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err) } content, err := context.Include(inputFilename) if err != nil { if !test.shouldErr { t.Errorf(testPrefix+"Expected no error, found [%s]", test.expectedErrorContent, err.Error()) } if !strings.Contains(err.Error(), test.expectedErrorContent) { t.Errorf(testPrefix+"Expected error content [%s], found [%s]", test.expectedErrorContent, err.Error()) } } if err == nil && test.shouldErr { t.Errorf(testPrefix+"Expected error [%s] but found nil. Input file was: %s", test.expectedErrorContent, inputFilename) } if content != test.expectedContent { t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename) } } } func TestIncludeNotExisting(t *testing.T) { context := getContextOrFail(t) _, err := context.Include("not_existing") if err == nil { t.Errorf("Expected error but found nil!") } } func TestMarkdown(t *testing.T) { context := getContextOrFail(t) inputFilename := "test_file" absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename) defer func() { err := os.Remove(absInFilePath) if err != nil && !os.IsNotExist(err) { t.Fatalf("Failed to clean test file!") } }() tests := []struct { fileContent string expectedContent string }{ // Test 0 - test parsing of markdown { fileContent: "* str1\n* str2\n", expectedContent: "<ul>\n<li>str1</li>\n<li>str2</li>\n</ul>\n", }, } for i, test := range tests { testPrefix := getTestPrefix(i) // WriteFile truncates the contentt err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm) if err != nil { t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err) } content, _ := context.Markdown(inputFilename) if content != test.expectedContent { t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename) } } } func TestCookie(t *testing.T) { tests := []struct { cookie *http.Cookie cookieName string expectedValue string }{ // Test 0 - happy path { cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"}, cookieName: "cookieName", expectedValue: "cookieValue", }, // Test 1 - try to get a non-existing cookie { cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"}, cookieName: "notExisting", expectedValue: "", }, // Test 2 - partial name match { cookie: &http.Cookie{Name: "cookie", Value: "cookieValue"}, cookieName: "cook", expectedValue: "", }, // Test 3 - cookie with optional fields { cookie: &http.Cookie{Name: "cookie", Value: "cookieValue", Path: "/path", Domain: "https://localhost", Expires: (time.Now().Add(10 * time.Minute)), MaxAge: 120}, cookieName: "cookie", expectedValue: "cookieValue", }, } for i, test := range tests { testPrefix := getTestPrefix(i) // reinitialize the context for each test context := getContextOrFail(t) context.Req.AddCookie(test.cookie) actualCookieVal := context.Cookie(test.cookieName) if actualCookieVal != test.expectedValue { t.Errorf(testPrefix+"Expected cookie value [%s] but found [%s] for cookie with name %s", test.expectedValue, actualCookieVal, test.cookieName) } } } func TestCookieMultipleCookies(t *testing.T) { context := getContextOrFail(t) cookieNameBase, cookieValueBase := "cookieName", "cookieValue" // make sure that there's no state and multiple requests for different cookies return the correct result for i := 0; i < 10; i++ { context.Req.AddCookie(&http.Cookie{Name: fmt.Sprintf("%s%d", cookieNameBase, i), Value: fmt.Sprintf("%s%d", cookieValueBase, i)}) } for i := 0; i < 10; i++ { expectedCookieVal := fmt.Sprintf("%s%d", cookieValueBase, i) actualCookieVal := context.Cookie(fmt.Sprintf("%s%d", cookieNameBase, i)) if actualCookieVal != expectedCookieVal { t.Fatalf("Expected cookie value %s, found %s", expectedCookieVal, actualCookieVal) } } } func TestHeader(t *testing.T) { context := getContextOrFail(t) headerKey, headerVal := "Header1", "HeaderVal1" context.Req.Header.Add(headerKey, headerVal) actualHeaderVal := context.Header(headerKey) if actualHeaderVal != headerVal { t.Errorf("Expected header %s, found %s", headerVal, actualHeaderVal) } missingHeaderVal := context.Header("not-existing") if missingHeaderVal != "" { t.Errorf("Expected empty header value, found %s", missingHeaderVal) } } func TestIP(t *testing.T) { context := getContextOrFail(t) tests := []struct { inputRemoteAddr string expectedIP string }{ // Test 0 - ipv4 with port {"1.1.1.1:1111", "1.1.1.1"}, // Test 1 - ipv4 without port {"1.1.1.1", "1.1.1.1"}, // Test 2 - ipv6 with port {"[::1]:11", "::1"}, // Test 3 - ipv6 without port and brackets {"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"}, // Test 4 - ipv6 with zone and port {`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`}, } for i, test := range tests { testPrefix := getTestPrefix(i) context.Req.RemoteAddr = test.inputRemoteAddr actualIP := context.IP() if actualIP != test.expectedIP { t.Errorf(testPrefix+"Expected IP %s, found %s", test.expectedIP, actualIP) } } } func TestURL(t *testing.T) { context := getContextOrFail(t) inputURL := "http://localhost" context.Req.RequestURI = inputURL if inputURL != context.URI() { t.Errorf("Expected url %s, found %s", inputURL, context.URI()) } } func TestHost(t *testing.T) { tests := []struct { input string expectedHost string shouldErr bool }{ { input: "localhost:123", expectedHost: "localhost", shouldErr: false, }, { input: "localhost", expectedHost: "localhost", shouldErr: false, }, { input: "[::]", expectedHost: "", shouldErr: true, }, } for _, test := range tests { testHostOrPort(t, true, test.input, test.expectedHost, test.shouldErr) } } func TestPort(t *testing.T) { tests := []struct { input string expectedPort string shouldErr bool }{ { input: "localhost:123", expectedPort: "123", shouldErr: false, }, { input: "localhost", expectedPort: "80", // assuming 80 is the default port shouldErr: false, }, { input: ":8080", expectedPort: "8080", shouldErr: false, }, { input: "[::]", expectedPort: "", shouldErr: true, }, } for _, test := range tests { testHostOrPort(t, false, test.input, test.expectedPort, test.shouldErr) } } func testHostOrPort(t *testing.T, isTestingHost bool, input, expectedResult string, shouldErr bool) { context := getContextOrFail(t) context.Req.Host = input var actualResult, testedObject string var err error if isTestingHost { actualResult, err = context.Host() testedObject = "host" } else { actualResult, err = context.Port() testedObject = "port" } if shouldErr && err == nil { t.Errorf("Expected error, found nil!") return } if !shouldErr && err != nil { t.Errorf("Expected no error, found %s", err) return } if actualResult != expectedResult { t.Errorf("Expected %s %s, found %s", testedObject, expectedResult, actualResult) } } func TestMethod(t *testing.T) { context := getContextOrFail(t) method := "POST" context.Req.Method = method if method != context.Method() { t.Errorf("Expected method %s, found %s", method, context.Method()) } } func TestPathMatches(t *testing.T) { context := getContextOrFail(t) tests := []struct { urlStr string pattern string shouldMatch bool }{ // Test 0 { urlStr: "http://localhost/", pattern: "", shouldMatch: true, }, // Test 1 { urlStr: "http://localhost", pattern: "", shouldMatch: true, }, // Test 1 { urlStr: "http://localhost/", pattern: "/", shouldMatch: true, }, // Test 3 { urlStr: "http://localhost/?param=val", pattern: "/", shouldMatch: true, }, // Test 4 { urlStr: "http://localhost/dir1/dir2", pattern: "/dir2", shouldMatch: false, }, // Test 5 { urlStr: "http://localhost/dir1/dir2", pattern: "/dir1", shouldMatch: true, }, // Test 6 { urlStr: "http://localhost:444/dir1/dir2", pattern: "/dir1", shouldMatch: true, }, // Test 7 { urlStr: "http://localhost/dir1/dir2", pattern: "*/dir2", shouldMatch: false, }, } for i, test := range tests { testPrefix := getTestPrefix(i) var err error context.Req.URL, err = url.Parse(test.urlStr) if err != nil { t.Fatalf("Failed to prepare test URL from string %s! Error was: %s", test.urlStr, err) } matches := context.PathMatches(test.pattern) if matches != test.shouldMatch { t.Errorf(testPrefix+"Expected and actual result differ: expected to match [%t], actual matches [%t]", test.shouldMatch, matches) } } } func TestTruncate(t *testing.T) { context := getContextOrFail(t) tests := []struct { inputString string inputLength int expected string }{ // Test 0 - small length { inputString: "string", inputLength: 1, expected: "s", }, // Test 1 - exact length { inputString: "string", inputLength: 6, expected: "string", }, // Test 2 - bigger length { inputString: "string", inputLength: 10, expected: "string", }, // Test 3 - zero length { inputString: "string", inputLength: 0, expected: "", }, // Test 4 - negative, smaller length { inputString: "string", inputLength: -5, expected: "tring", }, // Test 5 - negative, exact length { inputString: "string", inputLength: -6, expected: "string", }, // Test 6 - negative, bigger length { inputString: "string", inputLength: -7, expected: "string", }, } for i, test := range tests { actual := context.Truncate(test.inputString, test.inputLength) if actual != test.expected { t.Errorf(getTestPrefix(i)+"Expected '%s', found '%s'. Input was Truncate(%q, %d)", test.expected, actual, test.inputString, test.inputLength) } } } func TestStripHTML(t *testing.T) { context := getContextOrFail(t) tests := []struct { input string expected string }{ // Test 0 - no tags { input: `h1`, expected: `h1`, }, // Test 1 - happy path { input: `<h1>h1</h1>`, expected: `h1`, }, // Test 2 - tag in quotes { input: `<h1">">h1</h1>`, expected: `h1`, }, // Test 3 - multiple tags { input: `<h1><b>h1</b></h1>`, expected: `h1`, }, // Test 4 - tags not closed { input: `<h1`, expected: `<h1`, }, // Test 5 - false start { input: `<h1<b>hi`, expected: `<h1hi`, }, } for i, test := range tests { actual := context.StripHTML(test.input) if actual != test.expected { t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was StripHTML(%s)", test.expected, actual, test.input) } } } func TestStripExt(t *testing.T) { context := getContextOrFail(t) tests := []struct { input string expected string }{ // Test 0 - empty input { input: "", expected: "", }, // Test 1 - relative file with ext { input: "file.ext", expected: "file", }, // Test 2 - relative file without ext { input: "file", expected: "file", }, // Test 3 - absolute file without ext { input: "/file", expected: "/file", }, // Test 4 - absolute file with ext { input: "/file.ext", expected: "/file", }, // Test 5 - with ext but ends with / { input: "/dir.ext/", expected: "/dir.ext/", }, // Test 6 - file with ext under dir with ext { input: "/dir.ext/file.ext", expected: "/dir.ext/file", }, } for i, test := range tests { actual := context.StripExt(test.input) if actual != test.expected { t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was StripExt(%q)", test.expected, actual, test.input) } } } func initTestContext() (Context, error) { body := bytes.NewBufferString("request body") request, err := http.NewRequest("GET", "https://localhost", body) if err != nil { return Context{}, err } return Context{Root: http.Dir(os.TempDir()), Req: request}, nil } func getContextOrFail(t *testing.T) Context { context, err := initTestContext() if err != nil { t.Fatalf("Failed to prepare test context") } return context } func getTestPrefix(testN int) string { return fmt.Sprintf("Test [%d]: ", testN) } func TestTemplates(t *testing.T) { tests := []struct{ tmpl, expected string }{ {`{{.ToUpper "aAA"}}`, "AAA"}, {`{{"bbb" | .ToUpper}}`, "BBB"}, {`{{.ToLower "CCc"}}`, "ccc"}, {`{{range (.Split "a,b,c" ",")}}{{.}}{{end}}`, "abc"}, {`{{range .Split "a,b,c" ","}}{{.}}{{end}}`, "abc"}, {`{{range .Slice "a" "b" "c"}}{{.}}{{end}}`, "abc"}, {`{{with .Map "A" "a" "B" "b" "c" "d"}}{{.A}}{{.B}}{{.c}}{{end}}`, "abd"}, } for i, test := range tests { ctx := getContextOrFail(t) tmpl, err := template.New("").Parse(test.tmpl) if err != nil { t.Errorf("Test %d: %s", i, err) continue } buf := &bytes.Buffer{} err = tmpl.Execute(buf, ctx) if err != nil { t.Errorf("Test %d: %s", i, err) continue } if buf.String() != test.expected { t.Errorf("Test %d: Results do not match. '%s' != '%s'", i, buf.String(), test.expected) } } }