diff --git a/dial.go b/dial.go index ce7b312..bdb93f1 100644 --- a/dial.go +++ b/dial.go @@ -5,53 +5,57 @@ import "errors" import "context" import "crypto/tls" -// Dial opens a connection to a server. The network must be one of "quic", -// "quic4", (IPv4-only) "quic6" (IPv6-only), or "unix". For now, "quic4" and -// "quic6" don't do anything as the quic-go package doesn't seem to support this -// behavior. -func Dial(ctx context.Context, network, address string) (Conn, error) { - return (Dialer { }).Dial(ctx, network, address) -} - -// Dialer allows for further configuration of the dialing process. -type Dialer struct { - TLSConfig *tls.Config -} - // Dial opens a connection to a server. The network must be one of: // // - "quic" // - "quic4" (IPv4-only) // - "quic6" (IPv6-only) +// - "tls" +// - "tls4" (IPv4-only) +// - "tls6" (IPv6-only) // - "tcp" // - "tcp4" (IPv4-only) // - "tcp6" (IPv6-only) // - "unix" // // For now, QUIC is unsupported. -func (diale Dialer) Dial(ctx context.Context, network, address string) (Conn, error) { +func Dial(ctx context.Context, network, address string, tlsConf *tls.Config) (Conn, error) { switch network { - case "quic", "quic4", "quic6": return diale.dialQUIC(ctx, network, address) - case "tcp", "tcp4", "tcp6": return diale.dialTLS(ctx, network, address) - case "unix": return diale.dialUnix(ctx, network, address) + case "quic", "quic4", "quic6": return DialQUIC(ctx, network, address, tlsConf) + case "tls", "tls4", "tls6": return DialTLS(ctx, network, address, tlsConf) + case "tcp", "tcp4", "tcp6": + addr, err := net.ResolveTCPAddr(network, address) + if err != nil { return nil, err } + return DialTCP(ctx, network, nil, addr) + case "unix": + addr, err := net.ResolveUnixAddr(network, address) + if err != nil { return nil, err } + return DialUnix(ctx, network, addr) default: return nil, ErrUnknownNetwork } } -func (diale Dialer) dialQUIC(ctx context.Context, network, address string) (Conn, error) { +// DialQUIC opens a connection to a server over QUIC. +func DialQUIC(ctx context.Context, network, address string, tlsConf *tls.Config) (Conn, error) { return nil, errors.New("quic is not yet implemented") } -func (diale Dialer) dialTLS(ctx context.Context, network, address string) (Conn, error) { - conn, err := tls.Dial(network, address, diale.TLSConfig) +// DialTLS opens a connection to a server over TLS. +func DialTLS(ctx context.Context, network, address string, tlsConf *tls.Config) (Conn, error) { + conn, err := tls.Dial(network, address, tlsConf) if err != nil { return nil, err } return AdaptA(conn, ClientSide), nil } -func (diale Dialer) dialUnix(ctx context.Context, network, address string) (Conn, error) { - if network != "unix" { return nil, ErrUnknownNetwork } - addr, err := net.ResolveUnixAddr(network, address) +// DialTCP opens a connection to a server over TCP. +func DialTCP(ctx context.Context, network string, laddr, raddr *net.TCPAddr) (Conn, error) { + conn, err := net.DialTCP(network, laddr, raddr) if err != nil { return nil, err } + return AdaptA(conn, ClientSide), nil +} + +// DialUnix opens a connection to a server over a Unix domain socket. +func DialUnix(ctx context.Context, network string, addr *net.UnixAddr) (Conn, error) { conn, err := net.DialUnix(network, nil, addr) if err != nil { return nil, err } return AdaptA(conn, ClientSide), nil diff --git a/listen.go b/listen.go index 51d355f..9c33d07 100644 --- a/listen.go +++ b/listen.go @@ -20,6 +20,9 @@ type Listener interface { // - "quic" // - "quic4" (IPv4-only) // - "quic6" (IPv6-only) +// - "tls" +// - "tls4" (IPv4-only) +// - "tls6" (IPv6-only) // - "tcp" // - "tcp4" (IPv4-only) // - "tcp6" (IPv6-only) @@ -29,9 +32,17 @@ type Listener interface { func Listen(network, address string, tlsConf *tls.Config) (Listener, error) { switch network { case "quic", "quic4", "quic6": return ListenQUIC(network, address, tlsConf) - case "tcp", "tcp4", "tcp6": return ListenTLS(network, address, tlsConf) - case "unix": return ListenUnix(network, address) - default: return nil, ErrUnknownNetwork + case "tls", "tls4", "tls6": return ListenTLS(network, address, tlsConf) + case "tcp", "tcp4", "tcp6": + addr, err := net.ResolveTCPAddr(network, address) + if err != nil { return nil, err } + return ListenTCP(network, addr) + case "unix": + addr, err := net.ResolveUnixAddr(network, address) + if err != nil { return nil, err } + return ListenUnix(network, addr) + default: + return nil, ErrUnknownNetwork } } @@ -46,7 +57,6 @@ func ListenQUIC(network, address string, tlsConf *tls.Config) (Listener, error) // ListenTLS listens for incoming HOPP connections using a TLS socket as a // transport. The network must be "tcp". func ListenTLS(network, address string, tlsConf *tls.Config) (Listener, error) { - if network != "tcp" { return nil, ErrUnknownNetwork } listener, err := tls.Listen(network, address, tlsConf) if err != nil { return nil, err } return &netListenerWrapper { @@ -54,12 +64,19 @@ func ListenTLS(network, address string, tlsConf *tls.Config) (Listener, error) { }, nil } +// ListenTCP listens for incoming HOPP connections using a TCP socket as a +// transport. The network must be "tcp". +func ListenTCP(network string, laddr *net.TCPAddr) (Listener, error) { + listener, err := net.ListenTCP(network, laddr) + if err != nil { return nil, err } + return &netListenerWrapper { + underlying: listener, + }, nil +} + // ListenUnix listens for incoming HOPP connections using a Unix domain socket // as a transport. The network must be "unix". -func ListenUnix(network, address string) (Listener, error) { - if network != "unix" { return nil, ErrUnknownNetwork } - addr, err := net.ResolveUnixAddr(network, address) - if err != nil { return nil, err } +func ListenUnix(network string, addr *net.UnixAddr) (Listener, error) { listener, err := net.ListenUnix(network, addr) if err != nil { return nil, err } return &netListenerWrapper { diff --git a/metadapta_test.go b/metadapta_test.go index 62dfdd9..be27b89 100644 --- a/metadapta_test.go +++ b/metadapta_test.go @@ -1,11 +1,11 @@ package hopp import "io" -import "net" import "bytes" import "errors" import "slices" import "testing" +import "context" // some of these tests spawn goroutines that can signal a failure. // abide by the documentation for testing.T (https://pkg.go.dev/testing#T): @@ -230,34 +230,31 @@ func clientServerEnvironment(test *testing.T, clientFunc func(conn Conn), server addr := "localhost:7959" // server - listener, err := net.Listen(network, addr) - if err != nil { test.Fatal(err) } + listener, err := Listen(network, addr, nil) test.Cleanup(func() { listener.Close() }) go func() { test.Log("SERVER listening") conn, err := listener.Accept() if err != nil { test.Error("SERVER", err); return } + defer conn.Close() test.Cleanup(func() { conn.Close() }) - a := AdaptA(conn, ServerSide) - test.Cleanup(func() { a.Close() }) - - serverFunc(a) + + serverFunc(conn) test.Log("SERVER closing") }() // client test.Log("CLIENT dialing") - conn, err := net.Dial(network, addr) + conn, err := Dial(context.Background(), network, addr, nil) if err != nil { test.Fatal("CLIENT", err) } + test.Cleanup(func() { conn.Close() }) test.Log("CLIENT dialed") - a := AdaptA(conn, ClientSide) - test.Cleanup(func() { a.Close() }) - clientFunc(a) + clientFunc(conn) test.Log("CLIENT waiting for connection close...") - trans, err := a.AcceptTrans() + trans, err := conn.AcceptTrans() if !errors.Is(err, io.EOF) { test.Error("CLIENT wrong error:", err) test.Fatal("CLIENT trans:", trans)