Compare commits

...

9 Commits

9 changed files with 215 additions and 121 deletions

62
dial.go
View File

@ -5,53 +5,59 @@ import "errors"
import "context"
import "crypto/tls"
// Dial opens a connection to a server. The network must be one of "quic",
// "quic4", (IPv4-only) "quic6" (IPv6-only), or "unix". For now, "quic4" and
// "quic6" don't do anything as the quic-go package doesn't seem to support this
// behavior.
func Dial(ctx context.Context, network, address string) (Conn, error) {
return (Dialer { }).Dial(ctx, network, address)
}
// Dialer allows for further configuration of the dialing process.
type Dialer struct {
TLSConfig *tls.Config
}
// Dial opens a connection to a server. The network must be one of:
//
// - "quic"
// - "quic4" (IPv4-only)
// - "quic6" (IPv6-only)
// - "tls"
// - "tls4" (IPv4-only)
// - "tls6" (IPv6-only)
// - "tcp"
// - "tcp4" (IPv4-only)
// - "tcp6" (IPv6-only)
// - "unix"
//
// For now, QUIC is unsupported.
func (diale Dialer) Dial(ctx context.Context, network, address string) (Conn, error) {
func Dial(ctx context.Context, network, address string, tlsConf *tls.Config) (Conn, error) {
switch network {
case "quic", "quic4", "quic6": return diale.dialQUIC(ctx, network, address)
case "tcp", "tcp4", "tcp6": return diale.dialTLS(ctx, network, address)
case "unix": return diale.dialUnix(ctx, network, address)
case "quic", "quic4", "quic6": return DialQUIC(ctx, network, address, tlsConf)
case "tls", "tls4", "tls6": return DialTLS(ctx, network, address, tlsConf)
case "tcp", "tcp4", "tcp6":
addr, err := net.ResolveTCPAddr(network, address)
if err != nil { return nil, err }
return DialTCP(ctx, network, nil, addr)
case "unix":
addr, err := net.ResolveUnixAddr(network, address)
if err != nil { return nil, err }
return DialUnix(ctx, network, addr)
default: return nil, ErrUnknownNetwork
}
}
func (diale Dialer) dialQUIC(ctx context.Context, network, address string) (Conn, error) {
// DialQUIC opens a connection to a server over QUIC.
func DialQUIC(ctx context.Context, network, address string, tlsConf *tls.Config) (Conn, error) {
return nil, errors.New("quic is not yet implemented")
}
func (diale Dialer) dialTLS(ctx context.Context, network, address string) (Conn, error) {
conn, err := tls.Dial(network, address, diale.TLSConfig)
// DialTLS opens a connection to a server over TLS.
func DialTLS(ctx context.Context, network, address string, tlsConf *tls.Config) (Conn, error) {
network, err := tlsNetworkToTCPNetwork(network)
if err != nil { return nil, err }
conn, err := tls.Dial(network, address, tlsConf)
if err != nil { return nil, err }
return AdaptA(conn, ClientSide), nil
}
func (diale Dialer) dialUnix(ctx context.Context, network, address string) (Conn, error) {
if network != "unix" { return nil, ErrUnknownNetwork }
addr, err := net.ResolveUnixAddr(network, address)
// DialTCP opens a connection to a server over TCP.
func DialTCP(ctx context.Context, network string, laddr, raddr *net.TCPAddr) (Conn, error) {
conn, err := net.DialTCP(network, laddr, raddr)
if err != nil { return nil, err }
return AdaptA(conn, ClientSide), nil
}
// DialUnix opens a connection to a server over a Unix domain socket.
func DialUnix(ctx context.Context, network string, addr *net.UnixAddr) (Conn, error) {
conn, err := net.DialUnix(network, nil, addr)
if err != nil { return nil, err }
return AdaptA(conn, ClientSide), nil
@ -69,7 +75,6 @@ func tlsConfig(conf *tls.Config) *tls.Config {
return conf
}
func quicNetworkToUDPNetwork(network string) (string, error) {
switch network {
case "quic4": return "udp4", nil
@ -78,3 +83,12 @@ func quicNetworkToUDPNetwork(network string) (string, error) {
default: return "", ErrUnknownNetwork
}
}
func tlsNetworkToTCPNetwork(network string) (string, error) {
switch network {
case "tls4": return "tcp4", nil
case "tls6": return "tcp6", nil
case "tls": return "tcp", nil
default: return "", ErrUnknownNetwork
}
}

View File

@ -47,13 +47,10 @@ func main() {
func join(address string, room string, nickname string) (hopp.Trans, error) {
ctx, done := context.WithTimeout(context.Background(), 16 * time.Second)
defer done()
dialer := hopp.Dialer {
TLSConfig: &tls.Config {
conn, err := hopp.Dial(ctx, "tls", address, &tls.Config {
// don't actually do this in real life
InsecureSkipVerify: true,
},
}
conn, err := dialer.Dial(ctx, "tcp", address)
})
if err != nil { return nil, err }
transRoom, err := conn.OpenTrans()

View File

@ -5,6 +5,7 @@ package chat
// Please edit that file instead, and re-compile it to this location.
// HOPP, TAPE, METADAPT, PDL/0 (c) 2025 holanet.xyz
import "fmt"
import "git.tebibyte.media/sashakoshka/hopp"
import "git.tebibyte.media/sashakoshka/hopp/tape"
@ -26,7 +27,9 @@ func Send(trans hopp.Trans, message Message) (n int, err error) {
if err != nil { return n, err }
defer writer.Close()
encoder := tape.NewEncoder(writer)
return message.Encode(encoder)
n, err = message.Encode(encoder)
if err != nil { return n, err }
return n, encoder.Flush()
}
// canAssign determines if data from the given source tag can be assigned to
@ -81,29 +84,29 @@ func(this *MessageError) Encode(encoder *tape.Encoder) (n int, err error) {
nn, err = encoder.WriteUintN(2, tag_1.CN() + 1)
n += nn; if err != nil { return n, err }
{
nn, err = encoder.WriteUint16(0x0000)
n += nn; if err != nil { return n, err }
tag_2 := tape.LI.WithCN(1)
nn, err = encoder.WriteUint8(uint8(tag_2))
n += nn; if err != nil { return n, err }
nn, err = encoder.WriteUint16(uint16((*this).Code))
n += nn; if err != nil { return n, err }
if value, ok := (*this).Description.Value(); ok {
nn, err = encoder.WriteUint16(0x0001)
n += nn; if err != nil { return n, err }
tag_3 := tape.StringTag(string(value))
nn, err = encoder.WriteUint8(uint8(tag_3))
tag_2 := tape.StringTag(string(value))
nn, err = encoder.WriteUint8(uint8(tag_2))
n += nn; if err != nil { return n, err }
if len(value) > tape.MaxStructureLength {
return n, tape.ErrTooLong
}
if tag_3.Is(tape.LBA) {
nn, err = encoder.WriteUintN(uint64(len(value)), tag_3.CN())
if tag_2.Is(tape.LBA) {
nn, err = encoder.WriteUintN(uint64(len(value)), tag_2.CN())
n += nn; if err != nil { return n, err }
}
nn, err = encoder.Write([]byte(value))
n += nn; if err != nil { return n, err }
}
nn, err = encoder.WriteUint16(0x0000)
n += nn; if err != nil { return n, err }
tag_3 := tape.LI.WithCN(1)
nn, err = encoder.WriteUint8(uint8(tag_3))
n += nn; if err != nil { return n, err }
nn, err = encoder.WriteUint16(uint16((*this).Code))
n += nn; if err != nil { return n, err }
}
return n, nil
}
@ -251,34 +254,34 @@ func(this *MessageChat) Encode(encoder *tape.Encoder) (n int, err error) {
nn, err = encoder.WriteUintN(2, tag_8.CN() + 1)
n += nn; if err != nil { return n, err }
{
nn, err = encoder.WriteUint16(0x0001)
n += nn; if err != nil { return n, err }
tag_9 := tape.StringTag(string((*this).Nickname))
nn, err = encoder.WriteUint8(uint8(tag_9))
n += nn; if err != nil { return n, err }
if len((*this).Nickname) > tape.MaxStructureLength {
return n, tape.ErrTooLong
}
if tag_9.Is(tape.LBA) {
nn, err = encoder.WriteUintN(uint64(len((*this).Nickname)), tag_9.CN())
n += nn; if err != nil { return n, err }
}
nn, err = encoder.Write([]byte((*this).Nickname))
n += nn; if err != nil { return n, err }
nn, err = encoder.WriteUint16(0x0000)
n += nn; if err != nil { return n, err }
tag_10 := tape.StringTag(string((*this).Content))
nn, err = encoder.WriteUint8(uint8(tag_10))
tag_9 := tape.StringTag(string((*this).Content))
nn, err = encoder.WriteUint8(uint8(tag_9))
n += nn; if err != nil { return n, err }
if len((*this).Content) > tape.MaxStructureLength {
return n, tape.ErrTooLong
}
if tag_10.Is(tape.LBA) {
nn, err = encoder.WriteUintN(uint64(len((*this).Content)), tag_10.CN())
if tag_9.Is(tape.LBA) {
nn, err = encoder.WriteUintN(uint64(len((*this).Content)), tag_9.CN())
n += nn; if err != nil { return n, err }
}
nn, err = encoder.Write([]byte((*this).Content))
n += nn; if err != nil { return n, err }
nn, err = encoder.WriteUint16(0x0001)
n += nn; if err != nil { return n, err }
tag_10 := tape.StringTag(string((*this).Nickname))
nn, err = encoder.WriteUint8(uint8(tag_10))
n += nn; if err != nil { return n, err }
if len((*this).Nickname) > tape.MaxStructureLength {
return n, tape.ErrTooLong
}
if tag_10.Is(tape.LBA) {
nn, err = encoder.WriteUintN(uint64(len((*this).Nickname)), tag_10.CN())
n += nn; if err != nil { return n, err }
}
nn, err = encoder.Write([]byte((*this).Nickname))
n += nn; if err != nil { return n, err }
}
return n, nil
}
@ -695,36 +698,36 @@ func Receive(trans hopp.Trans) (message any, n int, err error) {
decoder := tape.NewDecoder(reader)
if err != nil { return nil, n, err }
switch method {
case 0001:
var message MessageSuccess
nn, err := message.Decode(decoder)
n += nn; if err != nil { return nil, n, err }
return message, n, nil
case 0200:
var message MessageJoin
nn, err := message.Decode(decoder)
n += nn; if err != nil { return nil, n, err }
return message, n, nil
case 0300:
var message MessageChat
nn, err := message.Decode(decoder)
n += nn; if err != nil { return nil, n, err }
return message, n, nil
case 0400:
var message MessageJoinNotify
nn, err := message.Decode(decoder)
n += nn; if err != nil { return nil, n, err }
return message, n, nil
case 0401:
case 0x0401:
var message MessageLeaveNotify
nn, err := message.Decode(decoder)
n += nn; if err != nil { return nil, n, err }
return message, n, nil
case 0000:
case 0x0000:
var message MessageError
nn, err := message.Decode(decoder)
n += nn; if err != nil { return nil, n, err }
return message, n, nil
case 0x0001:
var message MessageSuccess
nn, err := message.Decode(decoder)
n += nn; if err != nil { return nil, n, err }
return message, n, nil
case 0x0200:
var message MessageJoin
nn, err := message.Decode(decoder)
n += nn; if err != nil { return nil, n, err }
return message, n, nil
case 0x0300:
var message MessageChat
nn, err := message.Decode(decoder)
n += nn; if err != nil { return nil, n, err }
return message, n, nil
case 0x0400:
var message MessageJoinNotify
nn, err := message.Decode(decoder)
n += nn; if err != nil { return nil, n, err }
return message, n, nil
}
return nil, n, hopp.ErrUnknownMethod
return nil, n, fmt.Errorf("%w: M%04X", hopp.ErrUnknownMethod, method)
}

View File

@ -28,7 +28,7 @@ func main() {
func host(address string, certPath, keyPath string) error {
keyPair, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil { return err }
listener, err := hopp.ListenTLS("tcp", address, &tls.Config {
listener, err := hopp.Listen("tls", address, &tls.Config {
InsecureSkipVerify: true,
Certificates: []tls.Certificate { keyPair },
})
@ -60,11 +60,11 @@ func (this *client) run() {
for {
log.Println("accepting transaction")
trans, err := this.conn.AcceptTrans()
log.Println("accepted transaction")
if err != nil {
log.Printf("XXX %v failed: %v", this.conn.RemoteAddr(), err)
continue
return
}
log.Println("accepted transaction")
go this.runTrans(trans)
}
}

View File

@ -11,6 +11,7 @@ import "git.tebibyte.media/sashakoshka/hopp/tape"
const imports =
`
import "fmt"
import "git.tebibyte.media/sashakoshka/hopp"
import "git.tebibyte.media/sashakoshka/hopp/tape"
`
@ -41,7 +42,9 @@ func Send(trans hopp.Trans, message Message) (n int, err error) {
if err != nil { return n, err }
defer writer.Close()
encoder := tape.NewEncoder(writer)
return message.Encode(encoder)
n, err = message.Encode(encoder)
if err != nil { return n, err }
return n, encoder.Flush()
}
// canAssign determines if data from the given source tag can be assigned to
@ -1216,7 +1219,7 @@ func (this *Generator) generateReceive() (n int, err error) {
nn, err = this.iprintf("switch method {\n")
n += nn; if err != nil { return n, err }
for method, message := range this.protocol.Messages {
nn, err = this.iprintf("case %04X:\n", method)
nn, err = this.iprintf("case 0x%04X:\n", method)
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf(
@ -1234,7 +1237,7 @@ func (this *Generator) generateReceive() (n int, err error) {
}
nn, err = this.iprint("}\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("return nil, n, hopp.ErrUnknownMethod\n")
nn, err = this.iprint("return nil, n, fmt.Errorf(\"%w: M%04X\", hopp.ErrUnknownMethod, method)\n")
n += nn; if err != nil { return n, err }
this.pop()
nn, err = this.iprint("}\n")

View File

@ -0,0 +1,56 @@
package connshark
import "os"
import "io"
import "fmt"
import "net"
import "log"
import "sync"
import "math/rand"
import tu "git.tebibyte.media/sashakoshka/hopp/internal/testutil"
type insert struct {
net.Conn
output io.WriteCloser
lock sync.Mutex
}
func LogDebugFile(underlying net.Conn) net.Conn {
file, err := os.Create(fmt.Sprintf("connshark-%08X.log", rand.Uint32()))
if err != nil {
log.Println("XXX COULD NOT OPEN DEBUG FILE! reason: ", err)
return underlying
}
return Log(underlying, file)
}
func Log(underlying net.Conn, output io.WriteCloser) net.Conn {
return &insert {
Conn: underlying,
output: output,
}
}
func (this *insert) Read(buffer []byte) (n int, err error) {
if n > 0 {
this.lock.Lock()
defer this.lock.Unlock()
fmt.Fprintf(this.output, "TX: %s\n", tu.HexBytes(buffer[:n]))
}
return n, err
}
func (this *insert) Write(buffer []byte) (n int, err error) {
n, err = this.Conn.Write(buffer)
if n > 0 {
this.lock.Lock()
defer this.lock.Unlock()
fmt.Fprintf(this.output, "RX: %s\n", tu.HexBytes(buffer[:n]))
}
return n, err
}
func (this *insert) Close() error {
this.output.Close()
return this.Conn.Close()
}

View File

@ -20,6 +20,9 @@ type Listener interface {
// - "quic"
// - "quic4" (IPv4-only)
// - "quic6" (IPv6-only)
// - "tls"
// - "tls4" (IPv4-only)
// - "tls6" (IPv6-only)
// - "tcp"
// - "tcp4" (IPv4-only)
// - "tcp6" (IPv6-only)
@ -29,9 +32,17 @@ type Listener interface {
func Listen(network, address string, tlsConf *tls.Config) (Listener, error) {
switch network {
case "quic", "quic4", "quic6": return ListenQUIC(network, address, tlsConf)
case "tcp", "tcp4", "tcp6": return ListenTLS(network, address, tlsConf)
case "unix": return ListenUnix(network, address)
default: return nil, ErrUnknownNetwork
case "tls", "tls4", "tls6": return ListenTLS(network, address, tlsConf)
case "tcp", "tcp4", "tcp6":
addr, err := net.ResolveTCPAddr(network, address)
if err != nil { return nil, err }
return ListenTCP(network, addr)
case "unix":
addr, err := net.ResolveUnixAddr(network, address)
if err != nil { return nil, err }
return ListenUnix(network, addr)
default:
return nil, ErrUnknownNetwork
}
}
@ -46,7 +57,8 @@ func ListenQUIC(network, address string, tlsConf *tls.Config) (Listener, error)
// ListenTLS listens for incoming HOPP connections using a TLS socket as a
// transport. The network must be "tcp".
func ListenTLS(network, address string, tlsConf *tls.Config) (Listener, error) {
if network != "tcp" { return nil, ErrUnknownNetwork }
network, err := tlsNetworkToTCPNetwork(network)
if err != nil { return nil, err }
listener, err := tls.Listen(network, address, tlsConf)
if err != nil { return nil, err }
return &netListenerWrapper {
@ -54,12 +66,19 @@ func ListenTLS(network, address string, tlsConf *tls.Config) (Listener, error) {
}, nil
}
// ListenTCP listens for incoming HOPP connections using a TCP socket as a
// transport. The network must be "tcp".
func ListenTCP(network string, laddr *net.TCPAddr) (Listener, error) {
listener, err := net.ListenTCP(network, laddr)
if err != nil { return nil, err }
return &netListenerWrapper {
underlying: listener,
}, nil
}
// ListenUnix listens for incoming HOPP connections using a Unix domain socket
// as a transport. The network must be "unix".
func ListenUnix(network, address string) (Listener, error) {
if network != "unix" { return nil, ErrUnknownNetwork }
addr, err := net.ResolveUnixAddr(network, address)
if err != nil { return nil, err }
func ListenUnix(network string, addr *net.UnixAddr) (Listener, error) {
listener, err := net.ListenUnix(network, addr)
if err != nil { return nil, err }
return &netListenerWrapper {

View File

@ -124,10 +124,10 @@ func (this *a) unlistTransactionSafe(id int64) {
delete(this.transMap, id)
}
func (this *a) sendMessageSafe(trans int64, method uint16, data []byte) error {
func (this *a) sendMessageSafe(trans int64, method uint16, ccb uint64, data []byte) error {
this.sendLock.Lock()
defer this.sendLock.Unlock()
return encodeMessageA(this.underlying, this.sizeLimit, trans, method, data)
return encodeMessageA(this.underlying, this.sizeLimit, trans, method, 0, data)
}
func (this *a) receive() {
@ -251,7 +251,7 @@ func (this *transA) ID() int64 {
}
func (this *transA) Send(method uint16, data []byte) error {
return this.parent.sendMessageSafe(this.id, method, data)
return this.parent.sendMessageSafe(this.id, method, 0, data)
}
func (this *transA) SendWriter(method uint16) (io.WriteCloser, error) {
@ -417,6 +417,7 @@ func (this *writerA) Write(data []byte) (n int, err error) {
}
func (this *writerA) Close() error {
this.flush(0)
this.open = false
return nil
}
@ -430,21 +431,21 @@ func (this *writerA) writeOne(data []byte) (n int, err error) {
n = len(data)
// if have a full chunk, flush
if int64(len(this.buffer)) == this.chunkSize {
err = this.flush()
err = this.flush(1)
if err != nil { return n, err }
}
return n, nil
}
// if not, flush and store as much as we can in the buffer
err = this.flush()
err = this.flush(1)
if err != nil { return n, err }
this.buffer = append(this.buffer, data...)
return n, nil
}
func (this *writerA) flush() error {
return this.parent.parent.sendMessageSafe(this.parent.id, this.method, this.buffer)
func (this *writerA) flush(ccb uint64) error {
return this.parent.parent.sendMessageSafe(this.parent.id, this.method, ccb, this.buffer)
}
type incomingMessage struct {
@ -458,15 +459,19 @@ func encodeMessageA(
sizeLimit int64,
trans int64,
method uint16,
ccb uint64,
data []byte,
) error {
if int64(len(data)) > sizeLimit {
return ErrPayloadTooLarge
}
buffer := make([]byte, 18 + len(data))
// transaction ID field
encodeI64(buffer[:8], trans)
// method field
encodeI16(buffer[8:10], method)
encodeI64(buffer[10:18], uint64(len(data)))
// payload size field
encodeI64(buffer[10:18], uint64(len(data)) & 0x7FFFFFFFFFFFFFFF | ccb << 63)
copy(buffer[18:], data)
_, err := writer.Write(buffer)
return err

View File

@ -1,11 +1,11 @@
package hopp
import "io"
import "net"
import "bytes"
import "errors"
import "slices"
import "testing"
import "context"
// some of these tests spawn goroutines that can signal a failure.
// abide by the documentation for testing.T (https://pkg.go.dev/testing#T):
@ -145,7 +145,7 @@ func TestTransOpenCloseA(test *testing.T) {
func TestEncodeMessageA(test *testing.T) {
buffer := new(bytes.Buffer)
payload := []byte { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05 }
err := encodeMessageA(buffer, defaultSizeLimit, 0x5800FEABC3104F04, 0x6B12, payload)
err := encodeMessageA(buffer, defaultSizeLimit, 0x5800FEABC3104F04, 0x6B12, 0, payload)
correct := []byte {
0x58, 0x00, 0xFE, 0xAB, 0xC3, 0x10, 0x4F, 0x04,
0x6B, 0x12,
@ -163,7 +163,7 @@ func TestEncodeMessageA(test *testing.T) {
func TestEncodeMessageAErr(test *testing.T) {
buffer := new(bytes.Buffer)
payload := make([]byte, 0x10000)
err := encodeMessageA(buffer, 0x20, 0x5800FEABC3104F04, 0x6B12, payload)
err := encodeMessageA(buffer, 0x20, 0x5800FEABC3104F04, 0x6B12, 0, payload)
if !errors.Is(err, ErrPayloadTooLarge) {
test.Fatalf("wrong error: %v", err)
}
@ -208,7 +208,7 @@ func TestEncodeDecodeMessageA(test *testing.T) {
correctMethod := uint16(30)
correctPayload := []byte("good")
buffer := bytes.Buffer { }
err := encodeMessageA(&buffer, defaultSizeLimit, correctTransID, correctMethod, correctPayload)
err := encodeMessageA(&buffer, defaultSizeLimit, correctTransID, correctMethod, 0, correctPayload)
if err != nil { test.Fatal(err) }
transID, method, chunked, payload, err := decodeMessageA(&buffer, defaultSizeLimit)
if got, correct := transID, int64(2); got != correct {
@ -230,34 +230,31 @@ func clientServerEnvironment(test *testing.T, clientFunc func(conn Conn), server
addr := "localhost:7959"
// server
listener, err := net.Listen(network, addr)
if err != nil { test.Fatal(err) }
listener, err := Listen(network, addr, nil)
test.Cleanup(func() { listener.Close() })
go func() {
test.Log("SERVER listening")
conn, err := listener.Accept()
if err != nil { test.Error("SERVER", err); return }
defer conn.Close()
test.Cleanup(func() { conn.Close() })
a := AdaptA(conn, ServerSide)
test.Cleanup(func() { a.Close() })
serverFunc(a)
serverFunc(conn)
test.Log("SERVER closing")
}()
// client
test.Log("CLIENT dialing")
conn, err := net.Dial(network, addr)
conn, err := Dial(context.Background(), network, addr, nil)
if err != nil { test.Fatal("CLIENT", err) }
test.Cleanup(func() { conn.Close() })
test.Log("CLIENT dialed")
a := AdaptA(conn, ClientSide)
test.Cleanup(func() { a.Close() })
clientFunc(a)
clientFunc(conn)
test.Log("CLIENT waiting for connection close...")
trans, err := a.AcceptTrans()
trans, err := conn.AcceptTrans()
if !errors.Is(err, io.EOF) {
test.Error("CLIENT wrong error:", err)
test.Fatal("CLIENT trans:", trans)