From 99fa4581aa0e8c309f0745687d8aa7e832b31ae4 Mon Sep 17 00:00:00 2001 From: jordi collell Date: Sun, 10 May 2015 08:20:58 +0200 Subject: [PATCH] basicauth: patch for overlapping rules --- middleware/basicauth/basicauth.go | 20 +++++++++++++++++--- middleware/basicauth/basicauth_test.go | 6 +++--- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/middleware/basicauth/basicauth.go b/middleware/basicauth/basicauth.go index de29b8d9..25d1d504 100644 --- a/middleware/basicauth/basicauth.go +++ b/middleware/basicauth/basicauth.go @@ -19,6 +19,10 @@ type BasicAuth struct { // ServeHTTP implements the middleware.Handler interface. func (a BasicAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { + + var hasAuth bool + var isAuthenticated bool + for _, rule := range a.Rules { for _, res := range rule.Resources { if !middleware.Path(r.URL.Path).Matches(res) { @@ -27,16 +31,26 @@ func (a BasicAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error // Path matches; parse auth header username, password, ok := r.BasicAuth() + hasAuth = true // Check credentials if !ok || username != rule.Username || password != rule.Password { - w.Header().Set("WWW-Authenticate", "Basic") - return http.StatusUnauthorized, nil + continue } - + // flag set only on success authentication + isAuthenticated = true + } + } + + if hasAuth { + if !isAuthenticated { + w.Header().Set("WWW-Authenticate", "Basic") + return http.StatusUnauthorized, nil + } else { // "It's an older code, sir, but it checks out. I was about to clear them." return a.Next.ServeHTTP(w, r) } + } // Pass-thru when no paths match diff --git a/middleware/basicauth/basicauth_test.go b/middleware/basicauth/basicauth_test.go index 04d9fc83..b590bc35 100644 --- a/middleware/basicauth/basicauth_test.go +++ b/middleware/basicauth/basicauth_test.go @@ -84,12 +84,13 @@ func TestMultipleOverlappingRules(t *testing.T) { {"/t", http.StatusOK, "t:p1"}, {"/t/t", http.StatusOK, "t:p1"}, {"/t/t", http.StatusOK, "t1:p2"}, - + {"/a", http.StatusOK, "t1:p2"}, + {"/t/t", http.StatusUnauthorized, "t1:p3"}, + {"/t", http.StatusUnauthorized, "t1:p2"}, } for i, test := range tests { - req, err := http.NewRequest("GET", test.from, nil) if err != nil { @@ -108,7 +109,6 @@ func TestMultipleOverlappingRules(t *testing.T) { i, test.result, result) } - } }