Simplify how dialing and listening works
Also add support for bare TCP
This commit is contained in:
parent
da01a0d119
commit
09b2259a8c
50
dial.go
50
dial.go
@ -5,53 +5,57 @@ import "errors"
|
|||||||
import "context"
|
import "context"
|
||||||
import "crypto/tls"
|
import "crypto/tls"
|
||||||
|
|
||||||
// Dial opens a connection to a server. The network must be one of "quic",
|
|
||||||
// "quic4", (IPv4-only) "quic6" (IPv6-only), or "unix". For now, "quic4" and
|
|
||||||
// "quic6" don't do anything as the quic-go package doesn't seem to support this
|
|
||||||
// behavior.
|
|
||||||
func Dial(ctx context.Context, network, address string) (Conn, error) {
|
|
||||||
return (Dialer { }).Dial(ctx, network, address)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dialer allows for further configuration of the dialing process.
|
|
||||||
type Dialer struct {
|
|
||||||
TLSConfig *tls.Config
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dial opens a connection to a server. The network must be one of:
|
// Dial opens a connection to a server. The network must be one of:
|
||||||
//
|
//
|
||||||
// - "quic"
|
// - "quic"
|
||||||
// - "quic4" (IPv4-only)
|
// - "quic4" (IPv4-only)
|
||||||
// - "quic6" (IPv6-only)
|
// - "quic6" (IPv6-only)
|
||||||
|
// - "tls"
|
||||||
|
// - "tls4" (IPv4-only)
|
||||||
|
// - "tls6" (IPv6-only)
|
||||||
// - "tcp"
|
// - "tcp"
|
||||||
// - "tcp4" (IPv4-only)
|
// - "tcp4" (IPv4-only)
|
||||||
// - "tcp6" (IPv6-only)
|
// - "tcp6" (IPv6-only)
|
||||||
// - "unix"
|
// - "unix"
|
||||||
//
|
//
|
||||||
// For now, QUIC is unsupported.
|
// For now, QUIC is unsupported.
|
||||||
func (diale Dialer) Dial(ctx context.Context, network, address string) (Conn, error) {
|
func Dial(ctx context.Context, network, address string, tlsConf *tls.Config) (Conn, error) {
|
||||||
switch network {
|
switch network {
|
||||||
case "quic", "quic4", "quic6": return diale.dialQUIC(ctx, network, address)
|
case "quic", "quic4", "quic6": return DialQUIC(ctx, network, address, tlsConf)
|
||||||
case "tcp", "tcp4", "tcp6": return diale.dialTLS(ctx, network, address)
|
case "tls", "tls4", "tls6": return DialTLS(ctx, network, address, tlsConf)
|
||||||
case "unix": return diale.dialUnix(ctx, network, address)
|
case "tcp", "tcp4", "tcp6":
|
||||||
|
addr, err := net.ResolveTCPAddr(network, address)
|
||||||
|
if err != nil { return nil, err }
|
||||||
|
return DialTCP(ctx, network, nil, addr)
|
||||||
|
case "unix":
|
||||||
|
addr, err := net.ResolveUnixAddr(network, address)
|
||||||
|
if err != nil { return nil, err }
|
||||||
|
return DialUnix(ctx, network, addr)
|
||||||
default: return nil, ErrUnknownNetwork
|
default: return nil, ErrUnknownNetwork
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (diale Dialer) dialQUIC(ctx context.Context, network, address string) (Conn, error) {
|
// DialQUIC opens a connection to a server over QUIC.
|
||||||
|
func DialQUIC(ctx context.Context, network, address string, tlsConf *tls.Config) (Conn, error) {
|
||||||
return nil, errors.New("quic is not yet implemented")
|
return nil, errors.New("quic is not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (diale Dialer) dialTLS(ctx context.Context, network, address string) (Conn, error) {
|
// DialTLS opens a connection to a server over TLS.
|
||||||
conn, err := tls.Dial(network, address, diale.TLSConfig)
|
func DialTLS(ctx context.Context, network, address string, tlsConf *tls.Config) (Conn, error) {
|
||||||
|
conn, err := tls.Dial(network, address, tlsConf)
|
||||||
if err != nil { return nil, err }
|
if err != nil { return nil, err }
|
||||||
return AdaptA(conn, ClientSide), nil
|
return AdaptA(conn, ClientSide), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (diale Dialer) dialUnix(ctx context.Context, network, address string) (Conn, error) {
|
// DialTCP opens a connection to a server over TCP.
|
||||||
if network != "unix" { return nil, ErrUnknownNetwork }
|
func DialTCP(ctx context.Context, network string, laddr, raddr *net.TCPAddr) (Conn, error) {
|
||||||
addr, err := net.ResolveUnixAddr(network, address)
|
conn, err := net.DialTCP(network, laddr, raddr)
|
||||||
if err != nil { return nil, err }
|
if err != nil { return nil, err }
|
||||||
|
return AdaptA(conn, ClientSide), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialUnix opens a connection to a server over a Unix domain socket.
|
||||||
|
func DialUnix(ctx context.Context, network string, addr *net.UnixAddr) (Conn, error) {
|
||||||
conn, err := net.DialUnix(network, nil, addr)
|
conn, err := net.DialUnix(network, nil, addr)
|
||||||
if err != nil { return nil, err }
|
if err != nil { return nil, err }
|
||||||
return AdaptA(conn, ClientSide), nil
|
return AdaptA(conn, ClientSide), nil
|
||||||
|
|||||||
33
listen.go
33
listen.go
@ -20,6 +20,9 @@ type Listener interface {
|
|||||||
// - "quic"
|
// - "quic"
|
||||||
// - "quic4" (IPv4-only)
|
// - "quic4" (IPv4-only)
|
||||||
// - "quic6" (IPv6-only)
|
// - "quic6" (IPv6-only)
|
||||||
|
// - "tls"
|
||||||
|
// - "tls4" (IPv4-only)
|
||||||
|
// - "tls6" (IPv6-only)
|
||||||
// - "tcp"
|
// - "tcp"
|
||||||
// - "tcp4" (IPv4-only)
|
// - "tcp4" (IPv4-only)
|
||||||
// - "tcp6" (IPv6-only)
|
// - "tcp6" (IPv6-only)
|
||||||
@ -29,9 +32,17 @@ type Listener interface {
|
|||||||
func Listen(network, address string, tlsConf *tls.Config) (Listener, error) {
|
func Listen(network, address string, tlsConf *tls.Config) (Listener, error) {
|
||||||
switch network {
|
switch network {
|
||||||
case "quic", "quic4", "quic6": return ListenQUIC(network, address, tlsConf)
|
case "quic", "quic4", "quic6": return ListenQUIC(network, address, tlsConf)
|
||||||
case "tcp", "tcp4", "tcp6": return ListenTLS(network, address, tlsConf)
|
case "tls", "tls4", "tls6": return ListenTLS(network, address, tlsConf)
|
||||||
case "unix": return ListenUnix(network, address)
|
case "tcp", "tcp4", "tcp6":
|
||||||
default: return nil, ErrUnknownNetwork
|
addr, err := net.ResolveTCPAddr(network, address)
|
||||||
|
if err != nil { return nil, err }
|
||||||
|
return ListenTCP(network, addr)
|
||||||
|
case "unix":
|
||||||
|
addr, err := net.ResolveUnixAddr(network, address)
|
||||||
|
if err != nil { return nil, err }
|
||||||
|
return ListenUnix(network, addr)
|
||||||
|
default:
|
||||||
|
return nil, ErrUnknownNetwork
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -46,7 +57,6 @@ func ListenQUIC(network, address string, tlsConf *tls.Config) (Listener, error)
|
|||||||
// ListenTLS listens for incoming HOPP connections using a TLS socket as a
|
// ListenTLS listens for incoming HOPP connections using a TLS socket as a
|
||||||
// transport. The network must be "tcp".
|
// transport. The network must be "tcp".
|
||||||
func ListenTLS(network, address string, tlsConf *tls.Config) (Listener, error) {
|
func ListenTLS(network, address string, tlsConf *tls.Config) (Listener, error) {
|
||||||
if network != "tcp" { return nil, ErrUnknownNetwork }
|
|
||||||
listener, err := tls.Listen(network, address, tlsConf)
|
listener, err := tls.Listen(network, address, tlsConf)
|
||||||
if err != nil { return nil, err }
|
if err != nil { return nil, err }
|
||||||
return &netListenerWrapper {
|
return &netListenerWrapper {
|
||||||
@ -54,12 +64,19 @@ func ListenTLS(network, address string, tlsConf *tls.Config) (Listener, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListenTCP listens for incoming HOPP connections using a TCP socket as a
|
||||||
|
// transport. The network must be "tcp".
|
||||||
|
func ListenTCP(network string, laddr *net.TCPAddr) (Listener, error) {
|
||||||
|
listener, err := net.ListenTCP(network, laddr)
|
||||||
|
if err != nil { return nil, err }
|
||||||
|
return &netListenerWrapper {
|
||||||
|
underlying: listener,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ListenUnix listens for incoming HOPP connections using a Unix domain socket
|
// ListenUnix listens for incoming HOPP connections using a Unix domain socket
|
||||||
// as a transport. The network must be "unix".
|
// as a transport. The network must be "unix".
|
||||||
func ListenUnix(network, address string) (Listener, error) {
|
func ListenUnix(network string, addr *net.UnixAddr) (Listener, error) {
|
||||||
if network != "unix" { return nil, ErrUnknownNetwork }
|
|
||||||
addr, err := net.ResolveUnixAddr(network, address)
|
|
||||||
if err != nil { return nil, err }
|
|
||||||
listener, err := net.ListenUnix(network, addr)
|
listener, err := net.ListenUnix(network, addr)
|
||||||
if err != nil { return nil, err }
|
if err != nil { return nil, err }
|
||||||
return &netListenerWrapper {
|
return &netListenerWrapper {
|
||||||
|
|||||||
@ -1,11 +1,11 @@
|
|||||||
package hopp
|
package hopp
|
||||||
|
|
||||||
import "io"
|
import "io"
|
||||||
import "net"
|
|
||||||
import "bytes"
|
import "bytes"
|
||||||
import "errors"
|
import "errors"
|
||||||
import "slices"
|
import "slices"
|
||||||
import "testing"
|
import "testing"
|
||||||
|
import "context"
|
||||||
|
|
||||||
// some of these tests spawn goroutines that can signal a failure.
|
// some of these tests spawn goroutines that can signal a failure.
|
||||||
// abide by the documentation for testing.T (https://pkg.go.dev/testing#T):
|
// abide by the documentation for testing.T (https://pkg.go.dev/testing#T):
|
||||||
@ -230,34 +230,31 @@ func clientServerEnvironment(test *testing.T, clientFunc func(conn Conn), server
|
|||||||
addr := "localhost:7959"
|
addr := "localhost:7959"
|
||||||
|
|
||||||
// server
|
// server
|
||||||
listener, err := net.Listen(network, addr)
|
listener, err := Listen(network, addr, nil)
|
||||||
if err != nil { test.Fatal(err) }
|
|
||||||
test.Cleanup(func() { listener.Close() })
|
test.Cleanup(func() { listener.Close() })
|
||||||
go func() {
|
go func() {
|
||||||
test.Log("SERVER listening")
|
test.Log("SERVER listening")
|
||||||
conn, err := listener.Accept()
|
conn, err := listener.Accept()
|
||||||
if err != nil { test.Error("SERVER", err); return }
|
if err != nil { test.Error("SERVER", err); return }
|
||||||
|
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
test.Cleanup(func() { conn.Close() })
|
test.Cleanup(func() { conn.Close() })
|
||||||
a := AdaptA(conn, ServerSide)
|
|
||||||
test.Cleanup(func() { a.Close() })
|
|
||||||
|
|
||||||
serverFunc(a)
|
serverFunc(conn)
|
||||||
test.Log("SERVER closing")
|
test.Log("SERVER closing")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// client
|
// client
|
||||||
test.Log("CLIENT dialing")
|
test.Log("CLIENT dialing")
|
||||||
conn, err := net.Dial(network, addr)
|
conn, err := Dial(context.Background(), network, addr, nil)
|
||||||
if err != nil { test.Fatal("CLIENT", err) }
|
if err != nil { test.Fatal("CLIENT", err) }
|
||||||
|
test.Cleanup(func() { conn.Close() })
|
||||||
test.Log("CLIENT dialed")
|
test.Log("CLIENT dialed")
|
||||||
a := AdaptA(conn, ClientSide)
|
|
||||||
test.Cleanup(func() { a.Close() })
|
|
||||||
|
|
||||||
clientFunc(a)
|
clientFunc(conn)
|
||||||
|
|
||||||
test.Log("CLIENT waiting for connection close...")
|
test.Log("CLIENT waiting for connection close...")
|
||||||
trans, err := a.AcceptTrans()
|
trans, err := conn.AcceptTrans()
|
||||||
if !errors.Is(err, io.EOF) {
|
if !errors.Is(err, io.EOF) {
|
||||||
test.Error("CLIENT wrong error:", err)
|
test.Error("CLIENT wrong error:", err)
|
||||||
test.Fatal("CLIENT trans:", trans)
|
test.Fatal("CLIENT trans:", trans)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user