// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// +build darwin dragonfly freebsd linux netbsd openbsd

package test

import (
	"bytes"
	"io"
	"io/ioutil"
	"math/rand"
	"net"
	"testing"
	"time"
)

// TestPortForward opens a listener through the client connection returned by
// server.Dial, echoes everything received on the first accepted connection,
// and verifies that about 1 MB written over a plain TCP connection to the
// listener's address is read back unchanged. It then closes the listener and
// checks that the address no longer accepts connections.
func TestPortForward(t *testing.T) {
	server := newServer(t)
	defer server.Shutdown()
	conn := server.Dial(clientConfig())
	defer conn.Close()

	sshListener, err := conn.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatal(err)
	}

	// Echo goroutine: copy everything read from the accepted connection
	// back to it. Failures are reported with Errorf rather than Fatalf
	// because FailNow must only be called from the goroutine running the
	// test function.
	go func() {
		sshConn, err := sshListener.Accept()
		if err != nil {
			t.Errorf("listen.Accept failed: %v", err)
			return
		}
		// Always close, even on a copy error, so the reader on the TCP
		// side is not left waiting for an EOF that never comes.
		defer sshConn.Close()

		if _, err := io.Copy(sshConn, sshConn); err != nil && err != io.EOF {
			t.Errorf("ssh client copy: %v", err)
		}
	}()

	forwardedAddr := sshListener.Addr().String()
	tcpConn, err := net.Dial("tcp", forwardedAddr)
	if err != nil {
		t.Fatalf("TCP dial failed: %v", err)
	}
	defer tcpConn.Close()

	// Buffered so the reader goroutine can always deliver its result and
	// exit, regardless of what the test goroutine is doing.
	readChan := make(chan []byte, 1)
	go func() {
		data, err := ioutil.ReadAll(tcpConn)
		if err != nil {
			t.Errorf("reading echoed data: %v", err)
		}
		readChan <- data
	}()

	// Invent some data.
	data := make([]byte, 100*1000)
	for i := range data {
		data[i] = byte(i % 255)
	}

	var sent []byte
	for len(sent) < 1000*1000 {
		// Send random sized chunks
		m := rand.Intn(len(data))
		n, err := tcpConn.Write(data[:m])
		// Record whatever was written, including a partial write that
		// accompanies an error, so sent matches what was put on the wire.
		sent = append(sent, data[:n]...)
		if err != nil {
			t.Errorf("tcpConn.Write: %v", err)
			break
		}
	}
	// Half-close so the echo side sees EOF while we can still read the
	// remaining echoed bytes.
	if err := tcpConn.(*net.TCPConn).CloseWrite(); err != nil {
		t.Errorf("tcpConn.CloseWrite: %v", err)
	}

	read := <-readChan

	if len(sent) != len(read) {
		t.Fatalf("got %d bytes, want %d", len(read), len(sent))
	}
	if !bytes.Equal(sent, read) {
		t.Fatalf("read back data does not match")
	}

	if err := sshListener.Close(); err != nil {
		t.Fatalf("sshListener.Close: %v", err)
	}

	// Check that the forward disappeared.
	if probe, err := net.Dial("tcp", forwardedAddr); err == nil {
		probe.Close()
		t.Errorf("still listening to %s after closing", forwardedAddr)
	}
}

// TestAcceptClose checks that closing a forwarded listener makes a
// concurrent Accept loop return an error within one second.
func TestAcceptClose(t *testing.T) {
	server := newServer(t)
	defer server.Shutdown()
	conn := server.Dial(clientConfig())
	// Release the client connection when the test ends, as TestPortForward
	// does, instead of leaking it.
	defer conn.Close()

	sshListener, err := conn.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatal(err)
	}

	// Buffered so the accept goroutine can report its error and exit even
	// if the test has already given up waiting.
	quit := make(chan error, 1)
	go func() {
		for {
			c, err := sshListener.Accept()
			if err != nil {
				quit <- err
				break
			}
			c.Close()
		}
	}()
	sshListener.Close()

	select {
	case <-time.After(1 * time.Second):
		t.Errorf("timeout: listener did not close.")
	case err := <-quit:
		t.Logf("quit as expected (error %v)", err)
	}
}

// TestPortForwardConnectionClose verifies that an Accept loop on a forwarded
// listener terminates with an error once the underlying client transport is
// closed, rather than blocking forever.
func TestPortForwardConnectionClose(t *testing.T) {
	server := newServer(t)
	defer server.Shutdown()
	client := server.Dial(clientConfig())

	forwarded, listenErr := client.Listen("tcp", "localhost:0")
	if listenErr != nil {
		t.Fatal(listenErr)
	}

	// Capacity 1 lets the accept loop hand over its error and finish
	// without needing a receiver.
	acceptFailed := make(chan error, 1)
	go func() {
		for {
			accepted, acceptErr := forwarded.Accept()
			if acceptErr == nil {
				accepted.Close()
				continue
			}
			acceptFailed <- acceptErr
			return
		}
	}()

	// Kill the client end of the transport. Closing the server end would
	// be a nicer test, but that is more involved because its fd is
	// dup()ed.
	server.clientConn.Close()

	deadline := time.After(1 * time.Second)
	select {
	case err := <-acceptFailed:
		t.Logf("quit as expected (error %v)", err)
	case <-deadline:
		t.Errorf("timeout: listener did not close.")
	}
}