Simplify how dialing and listening works

Also add support for bare TCP
This commit is contained in:
Sasha Koshka 2025-10-20 16:44:07 -04:00
parent da01a0d119
commit 09b2259a8c
3 changed files with 61 additions and 43 deletions

50
dial.go
View File

@ -5,53 +5,57 @@ import "errors"
import "context"
import "crypto/tls"
// Dial opens a connection to a server. The network must be one of "quic",
// "quic4", (IPv4-only) "quic6" (IPv6-only), or "unix". For now, "quic4" and
// "quic6" don't do anything as the quic-go package doesn't seem to support this
// behavior.
func Dial(ctx context.Context, network, address string) (Conn, error) {
return (Dialer { }).Dial(ctx, network, address)
}
// Dialer allows for further configuration of the dialing process.
type Dialer struct {
TLSConfig *tls.Config
}
// Dial opens a connection to a server. The network must be one of:
//
// - "quic"
// - "quic4" (IPv4-only)
// - "quic6" (IPv6-only)
// - "tls"
// - "tls4" (IPv4-only)
// - "tls6" (IPv6-only)
// - "tcp"
// - "tcp4" (IPv4-only)
// - "tcp6" (IPv6-only)
// - "unix"
//
// For now, QUIC is unsupported.
func (diale Dialer) Dial(ctx context.Context, network, address string) (Conn, error) {
func Dial(ctx context.Context, network, address string, tlsConf *tls.Config) (Conn, error) {
switch network {
case "quic", "quic4", "quic6": return diale.dialQUIC(ctx, network, address)
case "tcp", "tcp4", "tcp6": return diale.dialTLS(ctx, network, address)
case "unix": return diale.dialUnix(ctx, network, address)
case "quic", "quic4", "quic6": return DialQUIC(ctx, network, address, tlsConf)
case "tls", "tls4", "tls6": return DialTLS(ctx, network, address, tlsConf)
case "tcp", "tcp4", "tcp6":
addr, err := net.ResolveTCPAddr(network, address)
if err != nil { return nil, err }
return DialTCP(ctx, network, nil, addr)
case "unix":
addr, err := net.ResolveUnixAddr(network, address)
if err != nil { return nil, err }
return DialUnix(ctx, network, addr)
default: return nil, ErrUnknownNetwork
}
}
func (diale Dialer) dialQUIC(ctx context.Context, network, address string) (Conn, error) {
// DialQUIC opens a connection to a server over QUIC.
func DialQUIC(ctx context.Context, network, address string, tlsConf *tls.Config) (Conn, error) {
return nil, errors.New("quic is not yet implemented")
}
func (diale Dialer) dialTLS(ctx context.Context, network, address string) (Conn, error) {
conn, err := tls.Dial(network, address, diale.TLSConfig)
// DialTLS opens a connection to a server over TLS.
func DialTLS(ctx context.Context, network, address string, tlsConf *tls.Config) (Conn, error) {
conn, err := tls.Dial(network, address, tlsConf)
if err != nil { return nil, err }
return AdaptA(conn, ClientSide), nil
}
func (diale Dialer) dialUnix(ctx context.Context, network, address string) (Conn, error) {
if network != "unix" { return nil, ErrUnknownNetwork }
addr, err := net.ResolveUnixAddr(network, address)
// DialTCP opens a connection to a server over TCP.
func DialTCP(ctx context.Context, network string, laddr, raddr *net.TCPAddr) (Conn, error) {
conn, err := net.DialTCP(network, laddr, raddr)
if err != nil { return nil, err }
return AdaptA(conn, ClientSide), nil
}
// DialUnix opens a connection to a server over a Unix domain socket.
func DialUnix(ctx context.Context, network string, addr *net.UnixAddr) (Conn, error) {
conn, err := net.DialUnix(network, nil, addr)
if err != nil { return nil, err }
return AdaptA(conn, ClientSide), nil

View File

@ -20,6 +20,9 @@ type Listener interface {
// - "quic"
// - "quic4" (IPv4-only)
// - "quic6" (IPv6-only)
// - "tls"
// - "tls4" (IPv4-only)
// - "tls6" (IPv6-only)
// - "tcp"
// - "tcp4" (IPv4-only)
// - "tcp6" (IPv6-only)
@ -29,9 +32,17 @@ type Listener interface {
func Listen(network, address string, tlsConf *tls.Config) (Listener, error) {
switch network {
case "quic", "quic4", "quic6": return ListenQUIC(network, address, tlsConf)
case "tcp", "tcp4", "tcp6": return ListenTLS(network, address, tlsConf)
case "unix": return ListenUnix(network, address)
default: return nil, ErrUnknownNetwork
case "tls", "tls4", "tls6": return ListenTLS(network, address, tlsConf)
case "tcp", "tcp4", "tcp6":
addr, err := net.ResolveTCPAddr(network, address)
if err != nil { return nil, err }
return ListenTCP(network, addr)
case "unix":
addr, err := net.ResolveUnixAddr(network, address)
if err != nil { return nil, err }
return ListenUnix(network, addr)
default:
return nil, ErrUnknownNetwork
}
}
@ -46,7 +57,6 @@ func ListenQUIC(network, address string, tlsConf *tls.Config) (Listener, error)
// ListenTLS listens for incoming HOPP connections using a TLS socket as a
// transport. The network must be "tcp".
func ListenTLS(network, address string, tlsConf *tls.Config) (Listener, error) {
if network != "tcp" { return nil, ErrUnknownNetwork }
listener, err := tls.Listen(network, address, tlsConf)
if err != nil { return nil, err }
return &netListenerWrapper {
@ -54,12 +64,19 @@ func ListenTLS(network, address string, tlsConf *tls.Config) (Listener, error) {
}, nil
}
// ListenTCP listens for incoming HOPP connections using a TCP socket as a
// transport. The network must be "tcp".
func ListenTCP(network string, laddr *net.TCPAddr) (Listener, error) {
listener, err := net.ListenTCP(network, laddr)
if err != nil { return nil, err }
return &netListenerWrapper {
underlying: listener,
}, nil
}
// ListenUnix listens for incoming HOPP connections using a Unix domain socket
// as a transport. The network must be "unix".
func ListenUnix(network, address string) (Listener, error) {
if network != "unix" { return nil, ErrUnknownNetwork }
addr, err := net.ResolveUnixAddr(network, address)
if err != nil { return nil, err }
func ListenUnix(network string, addr *net.UnixAddr) (Listener, error) {
listener, err := net.ListenUnix(network, addr)
if err != nil { return nil, err }
return &netListenerWrapper {

View File

@ -1,11 +1,11 @@
package hopp
import "io"
import "net"
import "bytes"
import "errors"
import "slices"
import "testing"
import "context"
// some of these tests spawn goroutines that can signal a failure.
// abide by the documentation for testing.T (https://pkg.go.dev/testing#T):
@ -230,34 +230,31 @@ func clientServerEnvironment(test *testing.T, clientFunc func(conn Conn), server
addr := "localhost:7959"
// server
listener, err := net.Listen(network, addr)
if err != nil { test.Fatal(err) }
listener, err := Listen(network, addr, nil)
test.Cleanup(func() { listener.Close() })
go func() {
test.Log("SERVER listening")
conn, err := listener.Accept()
if err != nil { test.Error("SERVER", err); return }
defer conn.Close()
test.Cleanup(func() { conn.Close() })
a := AdaptA(conn, ServerSide)
test.Cleanup(func() { a.Close() })
serverFunc(a)
serverFunc(conn)
test.Log("SERVER closing")
}()
// client
test.Log("CLIENT dialing")
conn, err := net.Dial(network, addr)
conn, err := Dial(context.Background(), network, addr, nil)
if err != nil { test.Fatal("CLIENT", err) }
test.Cleanup(func() { conn.Close() })
test.Log("CLIENT dialed")
a := AdaptA(conn, ClientSide)
test.Cleanup(func() { a.Close() })
clientFunc(a)
clientFunc(conn)
test.Log("CLIENT waiting for connection close...")
trans, err := a.AcceptTrans()
trans, err := conn.AcceptTrans()
if !errors.Is(err, io.EOF) {
test.Error("CLIENT wrong error:", err)
test.Fatal("CLIENT trans:", trans)