mirror of
https://github.com/caddyserver/caddy.git
synced 2025-01-09 12:28:49 +03:00
Merge branch 'master' into getcertificate
This commit is contained in:
commit
d25a3e95e4
16 changed files with 343 additions and 24 deletions
|
@ -69,6 +69,23 @@ var directiveOrder = []directive{
|
||||||
{"browse", setup.Browse},
|
{"browse", setup.Browse},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterDirective adds the given directive to caddy's list of directives.
|
||||||
|
// Pass the name of a directive you want it to be placed after,
|
||||||
|
// otherwise it will be placed at the bottom of the stack.
|
||||||
|
func RegisterDirective(name string, setup SetupFunc, after string) {
|
||||||
|
dir := directive{name: name, setup: setup}
|
||||||
|
idx := len(directiveOrder)
|
||||||
|
for i := range directiveOrder {
|
||||||
|
if directiveOrder[i].name == after {
|
||||||
|
idx = i + 1
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
newDirectives := append(directiveOrder[:idx], append([]directive{dir}, directiveOrder[idx:]...)...)
|
||||||
|
directiveOrder = newDirectives
|
||||||
|
parse.ValidDirectives[name] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
// directive ties together a directive name with its setup function.
|
// directive ties together a directive name with its setup function.
|
||||||
type directive struct {
|
type directive struct {
|
||||||
name string
|
name string
|
||||||
|
|
31
caddy/directives_test.go
Normal file
31
caddy/directives_test.go
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
package caddy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRegister(t *testing.T) {
|
||||||
|
directives := []directive{
|
||||||
|
{"dummy", nil},
|
||||||
|
{"dummy2", nil},
|
||||||
|
}
|
||||||
|
directiveOrder = directives
|
||||||
|
RegisterDirective("foo", nil, "dummy")
|
||||||
|
if len(directiveOrder) != 3 {
|
||||||
|
t.Fatal("Should have 3 directives now")
|
||||||
|
}
|
||||||
|
getNames := func() (s []string) {
|
||||||
|
for _, d := range directiveOrder {
|
||||||
|
s = append(s, d.name)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(getNames(), []string{"dummy", "foo", "dummy2"}) {
|
||||||
|
t.Fatalf("directive order doesn't match: %s", getNames())
|
||||||
|
}
|
||||||
|
RegisterDirective("bar", nil, "ASDASD")
|
||||||
|
if !reflect.DeepEqual(getNames(), []string{"dummy", "foo", "dummy2", "bar"}) {
|
||||||
|
t.Fatalf("directive order doesn't match: %s", getNames())
|
||||||
|
}
|
||||||
|
}
|
|
@ -23,9 +23,9 @@ func RequestCallback(w http.ResponseWriter, r *http.Request) bool {
|
||||||
scheme = "https"
|
scheme = "https"
|
||||||
}
|
}
|
||||||
|
|
||||||
hostname, _, err := net.SplitHostPort(r.URL.Host)
|
hostname, _, err := net.SplitHostPort(r.Host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostname = r.URL.Host
|
hostname = r.Host
|
||||||
}
|
}
|
||||||
|
|
||||||
upstream, err := url.Parse(scheme + "://" + hostname + ":" + AlternatePort)
|
upstream, err := url.Parse(scheme + "://" + hostname + ":" + AlternatePort)
|
||||||
|
|
8
dist/README.txt
vendored
8
dist/README.txt
vendored
|
@ -8,9 +8,11 @@ Source Code
|
||||||
https://github.com/mholt/caddy
|
https://github.com/mholt/caddy
|
||||||
|
|
||||||
|
|
||||||
For instructions on using Caddy, please see the user guide on the website. For a list of what's new in this version, see CHANGES.txt.
|
For instructions on using Caddy, please see the user guide on the website.
|
||||||
|
For a list of what's new in this version, see CHANGES.txt.
|
||||||
|
|
||||||
If you have a question, bug report, or would like to contribute, please open an issue or submit a pull request on GitHub. Your contributions do not go unnoticed!
|
If you have a question, bug report, or would like to contribute, please open an
|
||||||
|
issue or submit a pull request on GitHub. Your contributions do not go unnoticed!
|
||||||
|
|
||||||
For a good time, follow @mholt6 on Twitter.
|
For a good time, follow @mholt6 on Twitter.
|
||||||
|
|
||||||
|
@ -18,4 +20,4 @@ And thanks - you're awesome!
|
||||||
|
|
||||||
|
|
||||||
---
|
---
|
||||||
(c) 2015 Matthew Holt
|
(c) 2015 - 2016 Matthew Holt
|
||||||
|
|
|
@ -139,7 +139,7 @@ md5:$apr1$l42y8rex$pOA2VJ0x/0TwaFeAF9nX61`
|
||||||
if rule.Password, err = GetHtpasswdMatcher(filename, rule.Username, siteRoot); err != nil {
|
if rule.Password, err = GetHtpasswdMatcher(filename, rule.Username, siteRoot); err != nil {
|
||||||
t.Fatalf("GetHtpasswdMatcher(%q, %q): %v", htfh.Name(), rule.Username, err)
|
t.Fatalf("GetHtpasswdMatcher(%q, %q): %v", htfh.Name(), rule.Username, err)
|
||||||
}
|
}
|
||||||
t.Logf("%d. username=%q password=%v", i, rule.Username, rule.Password)
|
t.Logf("%d. username=%q", i, rule.Username)
|
||||||
if !rule.Password(htpasswdPasswd) || rule.Password(htpasswdPasswd+"!") {
|
if !rule.Password(htpasswdPasswd) || rule.Password(htpasswdPasswd+"!") {
|
||||||
t.Errorf("%d (%s) password does not match.", i, rule.Username)
|
t.Errorf("%d (%s) password does not match.", i, rule.Username)
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,8 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/russross/blackfriday"
|
||||||
)
|
)
|
||||||
|
|
||||||
// This file contains the context and functions available for
|
// This file contains the context and functions available for
|
||||||
|
@ -190,3 +192,17 @@ func (c Context) StripExt(path string) string {
|
||||||
func (c Context) Replace(input, find, replacement string) string {
|
func (c Context) Replace(input, find, replacement string) string {
|
||||||
return strings.Replace(input, find, replacement, -1)
|
return strings.Replace(input, find, replacement, -1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Markdown returns the HTML contents of the markdown contained in filename
|
||||||
|
// (relative to the site root).
|
||||||
|
func (c Context) Markdown(filename string) (string, error) {
|
||||||
|
body, err := c.Include(filename)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
renderer := blackfriday.HtmlRenderer(0, "", "")
|
||||||
|
extns := blackfriday.EXTENSION_TABLES | blackfriday.EXTENSION_FENCED_CODE | blackfriday.EXTENSION_STRIKETHROUGH | blackfriday.EXTENSION_DEFINITION_LISTS
|
||||||
|
markdown := blackfriday.Markdown([]byte(body), renderer, extns)
|
||||||
|
|
||||||
|
return string(markdown), nil
|
||||||
|
}
|
||||||
|
|
|
@ -92,6 +92,45 @@ func TestIncludeNotExisting(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMarkdown(t *testing.T) {
|
||||||
|
context := getContextOrFail(t)
|
||||||
|
|
||||||
|
inputFilename := "test_file"
|
||||||
|
absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename)
|
||||||
|
defer func() {
|
||||||
|
err := os.Remove(absInFilePath)
|
||||||
|
if err != nil && !os.IsNotExist(err) {
|
||||||
|
t.Fatalf("Failed to clean test file!")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
fileContent string
|
||||||
|
expectedContent string
|
||||||
|
}{
|
||||||
|
// Test 0 - test parsing of markdown
|
||||||
|
{
|
||||||
|
fileContent: "* str1\n* str2\n",
|
||||||
|
expectedContent: "<ul>\n<li>str1</li>\n<li>str2</li>\n</ul>\n",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, test := range tests {
|
||||||
|
testPrefix := getTestPrefix(i)
|
||||||
|
|
||||||
|
// WriteFile truncates the contentt
|
||||||
|
err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
content, _ := context.Markdown(inputFilename)
|
||||||
|
if content != test.expectedContent {
|
||||||
|
t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCookie(t *testing.T) {
|
func TestCookie(t *testing.T) {
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
|
|
@ -70,7 +70,8 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect to FastCGI gateway
|
// Connect to FastCGI gateway
|
||||||
fcgi, err := getClient(&rule)
|
network, address := rule.parseAddress()
|
||||||
|
fcgi, err := Dial(network, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return http.StatusBadGateway, err
|
return http.StatusBadGateway, err
|
||||||
}
|
}
|
||||||
|
@ -128,15 +129,28 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
|
||||||
return h.Next.ServeHTTP(w, r)
|
return h.Next.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getClient(r *Rule) (*FCGIClient, error) {
|
// parseAddress returns the network and address of r.
|
||||||
// check if unix socket or TCP
|
// The first string is the network, "tcp" or "unix", implied from the scheme and address.
|
||||||
|
// The second string is r.Address, with scheme prefixes removed.
|
||||||
|
// The two returned strings can be used as parameters to the Dial() function.
|
||||||
|
func (r Rule) parseAddress() (string, string) {
|
||||||
|
// check if address has tcp scheme explicitly set
|
||||||
|
if strings.HasPrefix(r.Address, "tcp://") {
|
||||||
|
return "tcp", r.Address[len("tcp://"):]
|
||||||
|
}
|
||||||
|
// check if address has fastcgi scheme explicity set
|
||||||
|
if strings.HasPrefix(r.Address, "fastcgi://") {
|
||||||
|
return "tcp", r.Address[len("fastcgi://"):]
|
||||||
|
}
|
||||||
|
// check if unix socket
|
||||||
if trim := strings.HasPrefix(r.Address, "unix"); strings.HasPrefix(r.Address, "/") || trim {
|
if trim := strings.HasPrefix(r.Address, "unix"); strings.HasPrefix(r.Address, "/") || trim {
|
||||||
if trim {
|
if trim {
|
||||||
r.Address = r.Address[len("unix:"):]
|
return "unix", r.Address[len("unix:"):]
|
||||||
}
|
}
|
||||||
return Dial("unix", r.Address)
|
return "unix", r.Address
|
||||||
}
|
}
|
||||||
return Dial("tcp", r.Address)
|
// default case, a plain tcp address with no scheme
|
||||||
|
return "tcp", r.Address
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeHeader(w http.ResponseWriter, r *http.Response) {
|
func writeHeader(w http.ResponseWriter, r *http.Response) {
|
||||||
|
@ -168,7 +182,7 @@ func (h Handler) buildEnv(r *http.Request, rule Rule, fpath string) (map[string]
|
||||||
|
|
||||||
// Separate remote IP and port; more lenient than net.SplitHostPort
|
// Separate remote IP and port; more lenient than net.SplitHostPort
|
||||||
var ip, port string
|
var ip, port string
|
||||||
if idx := strings.Index(r.RemoteAddr, ":"); idx > -1 {
|
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx > -1 {
|
||||||
ip = r.RemoteAddr[:idx]
|
ip = r.RemoteAddr[:idx]
|
||||||
port = r.RemoteAddr[idx+1:]
|
port = r.RemoteAddr[idx+1:]
|
||||||
} else {
|
} else {
|
||||||
|
|
95
middleware/fastcgi/fastcgi_test.go
Normal file
95
middleware/fastcgi/fastcgi_test.go
Normal file
|
@ -0,0 +1,95 @@
|
||||||
|
package fastcgi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRuleParseAddress(t *testing.T) {
|
||||||
|
|
||||||
|
getClientTestTable := []struct {
|
||||||
|
rule *Rule
|
||||||
|
expectednetwork string
|
||||||
|
expectedaddress string
|
||||||
|
}{
|
||||||
|
{&Rule{Address: "tcp://172.17.0.1:9000"}, "tcp", "172.17.0.1:9000"},
|
||||||
|
{&Rule{Address: "fastcgi://localhost:9000"}, "tcp", "localhost:9000"},
|
||||||
|
{&Rule{Address: "172.17.0.15"}, "tcp", "172.17.0.15"},
|
||||||
|
{&Rule{Address: "/my/unix/socket"}, "unix", "/my/unix/socket"},
|
||||||
|
{&Rule{Address: "unix:/second/unix/socket"}, "unix", "/second/unix/socket"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, entry := range getClientTestTable {
|
||||||
|
if actualnetwork, _ := entry.rule.parseAddress(); actualnetwork != entry.expectednetwork {
|
||||||
|
t.Errorf("Unexpected network for address string %v. Got %v, expected %v", entry.rule.Address, actualnetwork, entry.expectednetwork)
|
||||||
|
}
|
||||||
|
if _, actualaddress := entry.rule.parseAddress(); actualaddress != entry.expectedaddress {
|
||||||
|
t.Errorf("Unexpected parsed address for address string %v. Got %v, expected %v", entry.rule.Address, actualaddress, entry.expectedaddress)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildEnv(t *testing.T) {
|
||||||
|
|
||||||
|
buildEnvSingle := func(r *http.Request, rule Rule, fpath string, envExpected map[string]string, t *testing.T) {
|
||||||
|
|
||||||
|
h := Handler{}
|
||||||
|
|
||||||
|
env, err := h.buildEnv(r, rule, fpath)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Unexpected error:", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range envExpected {
|
||||||
|
if env[k] != v {
|
||||||
|
t.Errorf("Unexpected %v. Got %v, expected %v", k, env[k], v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := Rule{}
|
||||||
|
url, err := url.Parse("http://localhost:2015/fgci_test.php?test=blabla")
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Unexpected error:", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
r := http.Request{
|
||||||
|
Method: "GET",
|
||||||
|
URL: url,
|
||||||
|
Proto: "HTTP/1.1",
|
||||||
|
ProtoMajor: 1,
|
||||||
|
ProtoMinor: 1,
|
||||||
|
Host: "localhost:2015",
|
||||||
|
RemoteAddr: "[2b02:1810:4f2d:9400:70ab:f822:be8a:9093]:51688",
|
||||||
|
RequestURI: "/fgci_test.php",
|
||||||
|
}
|
||||||
|
|
||||||
|
fpath := "/fgci_test.php"
|
||||||
|
|
||||||
|
var envExpected = map[string]string{
|
||||||
|
"REMOTE_ADDR": "[2b02:1810:4f2d:9400:70ab:f822:be8a:9093]",
|
||||||
|
"REMOTE_PORT": "51688",
|
||||||
|
"SERVER_PROTOCOL": "HTTP/1.1",
|
||||||
|
"QUERY_STRING": "test=blabla",
|
||||||
|
"REQUEST_METHOD": "GET",
|
||||||
|
"HTTP_HOST": "localhost:2015",
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. Test for full canonical IPv6 address
|
||||||
|
buildEnvSingle(&r, rule, fpath, envExpected, t)
|
||||||
|
|
||||||
|
// 2. Test for shorthand notation of IPv6 address
|
||||||
|
r.RemoteAddr = "[::1]:51688"
|
||||||
|
envExpected["REMOTE_ADDR"] = "[::1]"
|
||||||
|
buildEnvSingle(&r, rule, fpath, envExpected, t)
|
||||||
|
|
||||||
|
// 3. Test for IPv4 address
|
||||||
|
r.RemoteAddr = "192.168.0.10:51688"
|
||||||
|
envExpected["REMOTE_ADDR"] = "192.168.0.10"
|
||||||
|
buildEnvSingle(&r, rule, fpath, envExpected, t)
|
||||||
|
|
||||||
|
}
|
|
@ -169,12 +169,11 @@ type FCGIClient struct {
|
||||||
reqID uint16
|
reqID uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dial connects to the fcgi responder at the specified network address.
|
// DialWithDialer connects to the fcgi responder at the specified network address, using custom net.Dialer.
|
||||||
// See func net.Dial for a description of the network and address parameters.
|
// See func net.Dial for a description of the network and address parameters.
|
||||||
func Dial(network, address string) (fcgi *FCGIClient, err error) {
|
func DialWithDialer(network, address string, dialer net.Dialer) (fcgi *FCGIClient, err error) {
|
||||||
var conn net.Conn
|
var conn net.Conn
|
||||||
|
conn, err = dialer.Dial(network, address)
|
||||||
conn, err = net.Dial(network, address)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -188,6 +187,12 @@ func Dial(network, address string) (fcgi *FCGIClient, err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dial connects to the fcgi responder at the specified network address, using default net.Dialer.
|
||||||
|
// See func net.Dial for a description of the network and address parameters.
|
||||||
|
func Dial(network, address string) (fcgi *FCGIClient, err error) {
|
||||||
|
return DialWithDialer(network, address, net.Dialer{})
|
||||||
|
}
|
||||||
|
|
||||||
// Close closes fcgi connnection
|
// Close closes fcgi connnection
|
||||||
func (c *FCGIClient) Close() {
|
func (c *FCGIClient) Close() {
|
||||||
c.rwc.Close()
|
c.rwc.Close()
|
||||||
|
|
|
@ -68,7 +68,7 @@ func (md Markdown) Process(c *Config, requestPath string, b []byte, ctx middlewa
|
||||||
}
|
}
|
||||||
|
|
||||||
// process markdown
|
// process markdown
|
||||||
extns := blackfriday.EXTENSION_TABLES | blackfriday.EXTENSION_FENCED_CODE | blackfriday.EXTENSION_STRIKETHROUGH
|
extns := blackfriday.EXTENSION_TABLES | blackfriday.EXTENSION_FENCED_CODE | blackfriday.EXTENSION_STRIKETHROUGH | blackfriday.EXTENSION_DEFINITION_LISTS
|
||||||
markdown = blackfriday.Markdown(markdown, c.Renderer, extns)
|
markdown = blackfriday.Markdown(markdown, c.Renderer, extns)
|
||||||
|
|
||||||
// set it as body for template
|
// set it as body for template
|
||||||
|
|
|
@ -63,7 +63,6 @@ var tryDuration = 60 * time.Second
|
||||||
|
|
||||||
// ServeHTTP satisfies the middleware.Handler interface.
|
// ServeHTTP satisfies the middleware.Handler interface.
|
||||||
func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||||
|
|
||||||
for _, upstream := range p.Upstreams {
|
for _, upstream := range p.Upstreams {
|
||||||
if middleware.Path(r.URL.Path).Matches(upstream.From()) && upstream.IsAllowedPath(r.URL.Path) {
|
if middleware.Path(r.URL.Path).Matches(upstream.From()) && upstream.IsAllowedPath(r.URL.Path) {
|
||||||
var replacer middleware.Replacer
|
var replacer middleware.Replacer
|
||||||
|
|
|
@ -3,6 +3,7 @@ package proxy
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
|
@ -13,7 +14,9 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
"golang.org/x/net/websocket"
|
"golang.org/x/net/websocket"
|
||||||
)
|
)
|
||||||
|
@ -160,6 +163,69 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUnixSocketProxy(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
trialMsg := "Is it working?"
|
||||||
|
|
||||||
|
var proxySuccess bool
|
||||||
|
|
||||||
|
// This is our fake "application" we want to proxy to
|
||||||
|
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Request was proxied when this is called
|
||||||
|
proxySuccess = true
|
||||||
|
|
||||||
|
fmt.Fprint(w, trialMsg)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Get absolute path for unix: socket
|
||||||
|
socketPath, err := filepath.Abs("./test_socket")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to get absolute path: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Change httptest.Server listener to listen to unix: socket
|
||||||
|
ln, err := net.Listen("unix", socketPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to listen: %v", err)
|
||||||
|
}
|
||||||
|
ts.Listener = ln
|
||||||
|
|
||||||
|
ts.Start()
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
url := strings.Replace(ts.URL, "http://", "unix:", 1)
|
||||||
|
p := newWebSocketTestProxy(url)
|
||||||
|
|
||||||
|
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
p.ServeHTTP(w, r)
|
||||||
|
}))
|
||||||
|
defer echoProxy.Close()
|
||||||
|
|
||||||
|
res, err := http.Get(echoProxy.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to GET: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
greeting, err := ioutil.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to GET: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
actualMsg := fmt.Sprintf("%s", greeting)
|
||||||
|
|
||||||
|
if !proxySuccess {
|
||||||
|
t.Errorf("Expected request to be proxied, but it wasn't")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actualMsg != trialMsg {
|
||||||
|
t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func newFakeUpstream(name string, insecure bool) *fakeUpstream {
|
func newFakeUpstream(name string, insecure bool) *fakeUpstream {
|
||||||
uri, _ := url.Parse(name)
|
uri, _ := url.Parse(name)
|
||||||
u := &fakeUpstream{
|
u := &fakeUpstream{
|
||||||
|
|
|
@ -59,6 +59,18 @@ func singleJoiningSlash(a, b string) string {
|
||||||
return a + b
|
return a + b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Though the relevant directive prefix is just "unix:", url.Parse
|
||||||
|
// will - assuming the regular URL scheme - add additional slashes
|
||||||
|
// as if "unix" was a request protocol.
|
||||||
|
// What we need is just the path, so if "unix:/var/run/www.socket"
|
||||||
|
// was the proxy directive, the parsed hostName would be
|
||||||
|
// "unix:///var/run/www.socket", hence the ambiguous trimming.
|
||||||
|
func socketDial(hostName string) func(network, addr string) (conn net.Conn, err error) {
|
||||||
|
return func(network, addr string) (conn net.Conn, err error) {
|
||||||
|
return net.Dial("unix", hostName[len("unix://"):])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
|
// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
|
||||||
// URLs to the scheme, host, and base path provided in target. If the
|
// URLs to the scheme, host, and base path provided in target. If the
|
||||||
// target's path is "/base" and the incoming request was for "/dir",
|
// target's path is "/base" and the incoming request was for "/dir",
|
||||||
|
@ -68,8 +80,15 @@ func singleJoiningSlash(a, b string) string {
|
||||||
func NewSingleHostReverseProxy(target *url.URL, without string) *ReverseProxy {
|
func NewSingleHostReverseProxy(target *url.URL, without string) *ReverseProxy {
|
||||||
targetQuery := target.RawQuery
|
targetQuery := target.RawQuery
|
||||||
director := func(req *http.Request) {
|
director := func(req *http.Request) {
|
||||||
|
if target.Scheme == "unix" {
|
||||||
|
// to make Dial work with unix URL,
|
||||||
|
// scheme and host have to be faked
|
||||||
|
req.URL.Scheme = "http"
|
||||||
|
req.URL.Host = "socket"
|
||||||
|
} else {
|
||||||
req.URL.Scheme = target.Scheme
|
req.URL.Scheme = target.Scheme
|
||||||
req.URL.Host = target.Host
|
req.URL.Host = target.Host
|
||||||
|
}
|
||||||
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
|
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
|
||||||
if targetQuery == "" || req.URL.RawQuery == "" {
|
if targetQuery == "" || req.URL.RawQuery == "" {
|
||||||
req.URL.RawQuery = targetQuery + req.URL.RawQuery
|
req.URL.RawQuery = targetQuery + req.URL.RawQuery
|
||||||
|
@ -80,7 +99,13 @@ func NewSingleHostReverseProxy(target *url.URL, without string) *ReverseProxy {
|
||||||
req.URL.Path = strings.TrimPrefix(req.URL.Path, without)
|
req.URL.Path = strings.TrimPrefix(req.URL.Path, without)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &ReverseProxy{Director: director}
|
rp := &ReverseProxy{Director: director}
|
||||||
|
if target.Scheme == "unix" {
|
||||||
|
rp.Transport = &http.Transport{
|
||||||
|
Dial: socketDial(target.String()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return rp
|
||||||
}
|
}
|
||||||
|
|
||||||
func copyHeader(dst, src http.Header) {
|
func copyHeader(dst, src http.Header) {
|
||||||
|
@ -104,6 +129,9 @@ var hopHeaders = []string{
|
||||||
"Upgrade",
|
"Upgrade",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InsecureTransport is used to facilitate HTTPS proxying
|
||||||
|
// when it is OK for upstream to be using a bad certificate,
|
||||||
|
// since this transport skips verification.
|
||||||
var InsecureTransport http.RoundTripper = &http.Transport{
|
var InsecureTransport http.RoundTripper = &http.Transport{
|
||||||
Proxy: http.ProxyFromEnvironment,
|
Proxy: http.ProxyFromEnvironment,
|
||||||
Dial: (&net.Dialer{
|
Dial: (&net.Dialer{
|
||||||
|
|
|
@ -65,7 +65,8 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) {
|
||||||
|
|
||||||
upstream.Hosts = make([]*UpstreamHost, len(to))
|
upstream.Hosts = make([]*UpstreamHost, len(to))
|
||||||
for i, host := range to {
|
for i, host := range to {
|
||||||
if !strings.HasPrefix(host, "http") {
|
if !strings.HasPrefix(host, "http") &&
|
||||||
|
!strings.HasPrefix(host, "unix:") {
|
||||||
host = "http://" + host
|
host = "http://" + host
|
||||||
}
|
}
|
||||||
uh := &UpstreamHost{
|
uh := &UpstreamHost{
|
||||||
|
|
|
@ -329,9 +329,15 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
DefaultErrorFunc(w, r, status)
|
DefaultErrorFunc(w, r, status)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
// Get the remote host
|
||||||
|
remoteHost, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
remoteHost = r.RemoteAddr
|
||||||
|
}
|
||||||
|
|
||||||
w.WriteHeader(http.StatusNotFound)
|
w.WriteHeader(http.StatusNotFound)
|
||||||
fmt.Fprintf(w, "No such host at %s", s.Server.Addr)
|
fmt.Fprintf(w, "No such host at %s", s.Server.Addr)
|
||||||
log.Printf("[INFO] %s - No such host at %s", host, s.Server.Addr)
|
log.Printf("[INFO] %s - No such host at %s (requested by %s)", host, s.Server.Addr, remoteHost)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue