Compare commits
9 Commits
14a317c2ab
...
3daa66c4bc
| Author | SHA1 | Date | |
|---|---|---|---|
| 3daa66c4bc | |||
| c5154b3d85 | |||
| c2ce95021c | |||
| d4ccdb282e | |||
| 2e4c693174 | |||
| c9480ba016 | |||
| 09b2259a8c | |||
| da01a0d119 | |||
| c326a2b6b9 |
62
dial.go
62
dial.go
@ -5,53 +5,59 @@ import "errors"
|
||||
import "context"
|
||||
import "crypto/tls"
|
||||
|
||||
// Dial opens a connection to a server. The network must be one of "quic",
|
||||
// "quic4", (IPv4-only) "quic6" (IPv6-only), or "unix". For now, "quic4" and
|
||||
// "quic6" don't do anything as the quic-go package doesn't seem to support this
|
||||
// behavior.
|
||||
func Dial(ctx context.Context, network, address string) (Conn, error) {
|
||||
return (Dialer { }).Dial(ctx, network, address)
|
||||
}
|
||||
|
||||
// Dialer allows for further configuration of the dialing process.
|
||||
type Dialer struct {
|
||||
TLSConfig *tls.Config
|
||||
}
|
||||
|
||||
// Dial opens a connection to a server. The network must be one of:
|
||||
//
|
||||
// - "quic"
|
||||
// - "quic4" (IPv4-only)
|
||||
// - "quic6" (IPv6-only)
|
||||
// - "tls"
|
||||
// - "tls4" (IPv4-only)
|
||||
// - "tls6" (IPv6-only)
|
||||
// - "tcp"
|
||||
// - "tcp4" (IPv4-only)
|
||||
// - "tcp6" (IPv6-only)
|
||||
// - "unix"
|
||||
//
|
||||
// For now, QUIC is unsupported.
|
||||
func (diale Dialer) Dial(ctx context.Context, network, address string) (Conn, error) {
|
||||
func Dial(ctx context.Context, network, address string, tlsConf *tls.Config) (Conn, error) {
|
||||
switch network {
|
||||
case "quic", "quic4", "quic6": return diale.dialQUIC(ctx, network, address)
|
||||
case "tcp", "tcp4", "tcp6": return diale.dialTLS(ctx, network, address)
|
||||
case "unix": return diale.dialUnix(ctx, network, address)
|
||||
case "quic", "quic4", "quic6": return DialQUIC(ctx, network, address, tlsConf)
|
||||
case "tls", "tls4", "tls6": return DialTLS(ctx, network, address, tlsConf)
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
addr, err := net.ResolveTCPAddr(network, address)
|
||||
if err != nil { return nil, err }
|
||||
return DialTCP(ctx, network, nil, addr)
|
||||
case "unix":
|
||||
addr, err := net.ResolveUnixAddr(network, address)
|
||||
if err != nil { return nil, err }
|
||||
return DialUnix(ctx, network, addr)
|
||||
default: return nil, ErrUnknownNetwork
|
||||
}
|
||||
}
|
||||
|
||||
func (diale Dialer) dialQUIC(ctx context.Context, network, address string) (Conn, error) {
|
||||
// DialQUIC opens a connection to a server over QUIC.
|
||||
func DialQUIC(ctx context.Context, network, address string, tlsConf *tls.Config) (Conn, error) {
|
||||
return nil, errors.New("quic is not yet implemented")
|
||||
}
|
||||
|
||||
func (diale Dialer) dialTLS(ctx context.Context, network, address string) (Conn, error) {
|
||||
conn, err := tls.Dial(network, address, diale.TLSConfig)
|
||||
// DialTLS opens a connection to a server over TLS.
|
||||
func DialTLS(ctx context.Context, network, address string, tlsConf *tls.Config) (Conn, error) {
|
||||
network, err := tlsNetworkToTCPNetwork(network)
|
||||
if err != nil { return nil, err }
|
||||
conn, err := tls.Dial(network, address, tlsConf)
|
||||
if err != nil { return nil, err }
|
||||
return AdaptA(conn, ClientSide), nil
|
||||
}
|
||||
|
||||
func (diale Dialer) dialUnix(ctx context.Context, network, address string) (Conn, error) {
|
||||
if network != "unix" { return nil, ErrUnknownNetwork }
|
||||
addr, err := net.ResolveUnixAddr(network, address)
|
||||
// DialTCP opens a connection to a server over TCP.
|
||||
func DialTCP(ctx context.Context, network string, laddr, raddr *net.TCPAddr) (Conn, error) {
|
||||
conn, err := net.DialTCP(network, laddr, raddr)
|
||||
if err != nil { return nil, err }
|
||||
return AdaptA(conn, ClientSide), nil
|
||||
}
|
||||
|
||||
// DialUnix opens a connection to a server over a Unix domain socket.
|
||||
func DialUnix(ctx context.Context, network string, addr *net.UnixAddr) (Conn, error) {
|
||||
conn, err := net.DialUnix(network, nil, addr)
|
||||
if err != nil { return nil, err }
|
||||
return AdaptA(conn, ClientSide), nil
|
||||
@ -69,7 +75,6 @@ func tlsConfig(conf *tls.Config) *tls.Config {
|
||||
return conf
|
||||
}
|
||||
|
||||
|
||||
func quicNetworkToUDPNetwork(network string) (string, error) {
|
||||
switch network {
|
||||
case "quic4": return "udp4", nil
|
||||
@ -78,3 +83,12 @@ func quicNetworkToUDPNetwork(network string) (string, error) {
|
||||
default: return "", ErrUnknownNetwork
|
||||
}
|
||||
}
|
||||
|
||||
func tlsNetworkToTCPNetwork(network string) (string, error) {
|
||||
switch network {
|
||||
case "tls4": return "tcp4", nil
|
||||
case "tls6": return "tcp6", nil
|
||||
case "tls": return "tcp", nil
|
||||
default: return "", ErrUnknownNetwork
|
||||
}
|
||||
}
|
||||
|
||||
@ -47,13 +47,10 @@ func main() {
|
||||
func join(address string, room string, nickname string) (hopp.Trans, error) {
|
||||
ctx, done := context.WithTimeout(context.Background(), 16 * time.Second)
|
||||
defer done()
|
||||
dialer := hopp.Dialer {
|
||||
TLSConfig: &tls.Config {
|
||||
conn, err := hopp.Dial(ctx, "tls", address, &tls.Config {
|
||||
// don't actually do this in real life
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
}
|
||||
conn, err := dialer.Dial(ctx, "tcp", address)
|
||||
})
|
||||
if err != nil { return nil, err }
|
||||
|
||||
transRoom, err := conn.OpenTrans()
|
||||
|
||||
@ -5,6 +5,7 @@ package chat
|
||||
// Please edit that file instead, and re-compile it to this location.
|
||||
// HOPP, TAPE, METADAPT, PDL/0 (c) 2025 holanet.xyz
|
||||
|
||||
import "fmt"
|
||||
import "git.tebibyte.media/sashakoshka/hopp"
|
||||
import "git.tebibyte.media/sashakoshka/hopp/tape"
|
||||
|
||||
@ -26,7 +27,9 @@ func Send(trans hopp.Trans, message Message) (n int, err error) {
|
||||
if err != nil { return n, err }
|
||||
defer writer.Close()
|
||||
encoder := tape.NewEncoder(writer)
|
||||
return message.Encode(encoder)
|
||||
n, err = message.Encode(encoder)
|
||||
if err != nil { return n, err }
|
||||
return n, encoder.Flush()
|
||||
}
|
||||
|
||||
// canAssign determines if data from the given source tag can be assigned to
|
||||
@ -81,29 +84,29 @@ func(this *MessageError) Encode(encoder *tape.Encoder) (n int, err error) {
|
||||
nn, err = encoder.WriteUintN(2, tag_1.CN() + 1)
|
||||
n += nn; if err != nil { return n, err }
|
||||
{
|
||||
nn, err = encoder.WriteUint16(0x0000)
|
||||
n += nn; if err != nil { return n, err }
|
||||
tag_2 := tape.LI.WithCN(1)
|
||||
nn, err = encoder.WriteUint8(uint8(tag_2))
|
||||
n += nn; if err != nil { return n, err }
|
||||
nn, err = encoder.WriteUint16(uint16((*this).Code))
|
||||
n += nn; if err != nil { return n, err }
|
||||
if value, ok := (*this).Description.Value(); ok {
|
||||
nn, err = encoder.WriteUint16(0x0001)
|
||||
n += nn; if err != nil { return n, err }
|
||||
tag_3 := tape.StringTag(string(value))
|
||||
nn, err = encoder.WriteUint8(uint8(tag_3))
|
||||
tag_2 := tape.StringTag(string(value))
|
||||
nn, err = encoder.WriteUint8(uint8(tag_2))
|
||||
n += nn; if err != nil { return n, err }
|
||||
if len(value) > tape.MaxStructureLength {
|
||||
return n, tape.ErrTooLong
|
||||
}
|
||||
if tag_3.Is(tape.LBA) {
|
||||
nn, err = encoder.WriteUintN(uint64(len(value)), tag_3.CN())
|
||||
if tag_2.Is(tape.LBA) {
|
||||
nn, err = encoder.WriteUintN(uint64(len(value)), tag_2.CN())
|
||||
n += nn; if err != nil { return n, err }
|
||||
}
|
||||
nn, err = encoder.Write([]byte(value))
|
||||
n += nn; if err != nil { return n, err }
|
||||
}
|
||||
nn, err = encoder.WriteUint16(0x0000)
|
||||
n += nn; if err != nil { return n, err }
|
||||
tag_3 := tape.LI.WithCN(1)
|
||||
nn, err = encoder.WriteUint8(uint8(tag_3))
|
||||
n += nn; if err != nil { return n, err }
|
||||
nn, err = encoder.WriteUint16(uint16((*this).Code))
|
||||
n += nn; if err != nil { return n, err }
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
@ -251,34 +254,34 @@ func(this *MessageChat) Encode(encoder *tape.Encoder) (n int, err error) {
|
||||
nn, err = encoder.WriteUintN(2, tag_8.CN() + 1)
|
||||
n += nn; if err != nil { return n, err }
|
||||
{
|
||||
nn, err = encoder.WriteUint16(0x0001)
|
||||
n += nn; if err != nil { return n, err }
|
||||
tag_9 := tape.StringTag(string((*this).Nickname))
|
||||
nn, err = encoder.WriteUint8(uint8(tag_9))
|
||||
n += nn; if err != nil { return n, err }
|
||||
if len((*this).Nickname) > tape.MaxStructureLength {
|
||||
return n, tape.ErrTooLong
|
||||
}
|
||||
if tag_9.Is(tape.LBA) {
|
||||
nn, err = encoder.WriteUintN(uint64(len((*this).Nickname)), tag_9.CN())
|
||||
n += nn; if err != nil { return n, err }
|
||||
}
|
||||
nn, err = encoder.Write([]byte((*this).Nickname))
|
||||
n += nn; if err != nil { return n, err }
|
||||
nn, err = encoder.WriteUint16(0x0000)
|
||||
n += nn; if err != nil { return n, err }
|
||||
tag_10 := tape.StringTag(string((*this).Content))
|
||||
nn, err = encoder.WriteUint8(uint8(tag_10))
|
||||
tag_9 := tape.StringTag(string((*this).Content))
|
||||
nn, err = encoder.WriteUint8(uint8(tag_9))
|
||||
n += nn; if err != nil { return n, err }
|
||||
if len((*this).Content) > tape.MaxStructureLength {
|
||||
return n, tape.ErrTooLong
|
||||
}
|
||||
if tag_10.Is(tape.LBA) {
|
||||
nn, err = encoder.WriteUintN(uint64(len((*this).Content)), tag_10.CN())
|
||||
if tag_9.Is(tape.LBA) {
|
||||
nn, err = encoder.WriteUintN(uint64(len((*this).Content)), tag_9.CN())
|
||||
n += nn; if err != nil { return n, err }
|
||||
}
|
||||
nn, err = encoder.Write([]byte((*this).Content))
|
||||
n += nn; if err != nil { return n, err }
|
||||
nn, err = encoder.WriteUint16(0x0001)
|
||||
n += nn; if err != nil { return n, err }
|
||||
tag_10 := tape.StringTag(string((*this).Nickname))
|
||||
nn, err = encoder.WriteUint8(uint8(tag_10))
|
||||
n += nn; if err != nil { return n, err }
|
||||
if len((*this).Nickname) > tape.MaxStructureLength {
|
||||
return n, tape.ErrTooLong
|
||||
}
|
||||
if tag_10.Is(tape.LBA) {
|
||||
nn, err = encoder.WriteUintN(uint64(len((*this).Nickname)), tag_10.CN())
|
||||
n += nn; if err != nil { return n, err }
|
||||
}
|
||||
nn, err = encoder.Write([]byte((*this).Nickname))
|
||||
n += nn; if err != nil { return n, err }
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
@ -695,36 +698,36 @@ func Receive(trans hopp.Trans) (message any, n int, err error) {
|
||||
decoder := tape.NewDecoder(reader)
|
||||
if err != nil { return nil, n, err }
|
||||
switch method {
|
||||
case 0001:
|
||||
var message MessageSuccess
|
||||
nn, err := message.Decode(decoder)
|
||||
n += nn; if err != nil { return nil, n, err }
|
||||
return message, n, nil
|
||||
case 0200:
|
||||
var message MessageJoin
|
||||
nn, err := message.Decode(decoder)
|
||||
n += nn; if err != nil { return nil, n, err }
|
||||
return message, n, nil
|
||||
case 0300:
|
||||
var message MessageChat
|
||||
nn, err := message.Decode(decoder)
|
||||
n += nn; if err != nil { return nil, n, err }
|
||||
return message, n, nil
|
||||
case 0400:
|
||||
var message MessageJoinNotify
|
||||
nn, err := message.Decode(decoder)
|
||||
n += nn; if err != nil { return nil, n, err }
|
||||
return message, n, nil
|
||||
case 0401:
|
||||
case 0x0401:
|
||||
var message MessageLeaveNotify
|
||||
nn, err := message.Decode(decoder)
|
||||
n += nn; if err != nil { return nil, n, err }
|
||||
return message, n, nil
|
||||
case 0000:
|
||||
case 0x0000:
|
||||
var message MessageError
|
||||
nn, err := message.Decode(decoder)
|
||||
n += nn; if err != nil { return nil, n, err }
|
||||
return message, n, nil
|
||||
case 0x0001:
|
||||
var message MessageSuccess
|
||||
nn, err := message.Decode(decoder)
|
||||
n += nn; if err != nil { return nil, n, err }
|
||||
return message, n, nil
|
||||
case 0x0200:
|
||||
var message MessageJoin
|
||||
nn, err := message.Decode(decoder)
|
||||
n += nn; if err != nil { return nil, n, err }
|
||||
return message, n, nil
|
||||
case 0x0300:
|
||||
var message MessageChat
|
||||
nn, err := message.Decode(decoder)
|
||||
n += nn; if err != nil { return nil, n, err }
|
||||
return message, n, nil
|
||||
case 0x0400:
|
||||
var message MessageJoinNotify
|
||||
nn, err := message.Decode(decoder)
|
||||
n += nn; if err != nil { return nil, n, err }
|
||||
return message, n, nil
|
||||
}
|
||||
return nil, n, hopp.ErrUnknownMethod
|
||||
return nil, n, fmt.Errorf("%w: M%04X", hopp.ErrUnknownMethod, method)
|
||||
}
|
||||
|
||||
@ -28,7 +28,7 @@ func main() {
|
||||
func host(address string, certPath, keyPath string) error {
|
||||
keyPair, err := tls.LoadX509KeyPair(certPath, keyPath)
|
||||
if err != nil { return err }
|
||||
listener, err := hopp.ListenTLS("tcp", address, &tls.Config {
|
||||
listener, err := hopp.Listen("tls", address, &tls.Config {
|
||||
InsecureSkipVerify: true,
|
||||
Certificates: []tls.Certificate { keyPair },
|
||||
})
|
||||
@ -60,11 +60,11 @@ func (this *client) run() {
|
||||
for {
|
||||
log.Println("accepting transaction")
|
||||
trans, err := this.conn.AcceptTrans()
|
||||
log.Println("accepted transaction")
|
||||
if err != nil {
|
||||
log.Printf("XXX %v failed: %v", this.conn.RemoteAddr(), err)
|
||||
continue
|
||||
return
|
||||
}
|
||||
log.Println("accepted transaction")
|
||||
go this.runTrans(trans)
|
||||
}
|
||||
}
|
||||
|
||||
@ -11,6 +11,7 @@ import "git.tebibyte.media/sashakoshka/hopp/tape"
|
||||
|
||||
const imports =
|
||||
`
|
||||
import "fmt"
|
||||
import "git.tebibyte.media/sashakoshka/hopp"
|
||||
import "git.tebibyte.media/sashakoshka/hopp/tape"
|
||||
`
|
||||
@ -41,7 +42,9 @@ func Send(trans hopp.Trans, message Message) (n int, err error) {
|
||||
if err != nil { return n, err }
|
||||
defer writer.Close()
|
||||
encoder := tape.NewEncoder(writer)
|
||||
return message.Encode(encoder)
|
||||
n, err = message.Encode(encoder)
|
||||
if err != nil { return n, err }
|
||||
return n, encoder.Flush()
|
||||
}
|
||||
|
||||
// canAssign determines if data from the given source tag can be assigned to
|
||||
@ -1216,7 +1219,7 @@ func (this *Generator) generateReceive() (n int, err error) {
|
||||
nn, err = this.iprintf("switch method {\n")
|
||||
n += nn; if err != nil { return n, err }
|
||||
for method, message := range this.protocol.Messages {
|
||||
nn, err = this.iprintf("case %04X:\n", method)
|
||||
nn, err = this.iprintf("case 0x%04X:\n", method)
|
||||
n += nn; if err != nil { return n, err }
|
||||
this.push()
|
||||
nn, err = this.iprintf(
|
||||
@ -1234,7 +1237,7 @@ func (this *Generator) generateReceive() (n int, err error) {
|
||||
}
|
||||
nn, err = this.iprint("}\n")
|
||||
n += nn; if err != nil { return n, err }
|
||||
nn, err = this.iprintf("return nil, n, hopp.ErrUnknownMethod\n")
|
||||
nn, err = this.iprint("return nil, n, fmt.Errorf(\"%w: M%04X\", hopp.ErrUnknownMethod, method)\n")
|
||||
n += nn; if err != nil { return n, err }
|
||||
this.pop()
|
||||
nn, err = this.iprint("}\n")
|
||||
|
||||
56
internal/connshark/connshark.go
Normal file
56
internal/connshark/connshark.go
Normal file
@ -0,0 +1,56 @@
|
||||
package connshark
|
||||
|
||||
import "os"
|
||||
import "io"
|
||||
import "fmt"
|
||||
import "net"
|
||||
import "log"
|
||||
import "sync"
|
||||
import "math/rand"
|
||||
import tu "git.tebibyte.media/sashakoshka/hopp/internal/testutil"
|
||||
|
||||
type insert struct {
|
||||
net.Conn
|
||||
output io.WriteCloser
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func LogDebugFile(underlying net.Conn) net.Conn {
|
||||
file, err := os.Create(fmt.Sprintf("connshark-%08X.log", rand.Uint32()))
|
||||
if err != nil {
|
||||
log.Println("XXX COULD NOT OPEN DEBUG FILE! reason: ", err)
|
||||
return underlying
|
||||
}
|
||||
return Log(underlying, file)
|
||||
}
|
||||
|
||||
func Log(underlying net.Conn, output io.WriteCloser) net.Conn {
|
||||
return &insert {
|
||||
Conn: underlying,
|
||||
output: output,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *insert) Read(buffer []byte) (n int, err error) {
|
||||
if n > 0 {
|
||||
this.lock.Lock()
|
||||
defer this.lock.Unlock()
|
||||
fmt.Fprintf(this.output, "TX: %s\n", tu.HexBytes(buffer[:n]))
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (this *insert) Write(buffer []byte) (n int, err error) {
|
||||
n, err = this.Conn.Write(buffer)
|
||||
if n > 0 {
|
||||
this.lock.Lock()
|
||||
defer this.lock.Unlock()
|
||||
fmt.Fprintf(this.output, "RX: %s\n", tu.HexBytes(buffer[:n]))
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (this *insert) Close() error {
|
||||
this.output.Close()
|
||||
return this.Conn.Close()
|
||||
}
|
||||
35
listen.go
35
listen.go
@ -20,6 +20,9 @@ type Listener interface {
|
||||
// - "quic"
|
||||
// - "quic4" (IPv4-only)
|
||||
// - "quic6" (IPv6-only)
|
||||
// - "tls"
|
||||
// - "tls4" (IPv4-only)
|
||||
// - "tls6" (IPv6-only)
|
||||
// - "tcp"
|
||||
// - "tcp4" (IPv4-only)
|
||||
// - "tcp6" (IPv6-only)
|
||||
@ -29,9 +32,17 @@ type Listener interface {
|
||||
func Listen(network, address string, tlsConf *tls.Config) (Listener, error) {
|
||||
switch network {
|
||||
case "quic", "quic4", "quic6": return ListenQUIC(network, address, tlsConf)
|
||||
case "tcp", "tcp4", "tcp6": return ListenTLS(network, address, tlsConf)
|
||||
case "unix": return ListenUnix(network, address)
|
||||
default: return nil, ErrUnknownNetwork
|
||||
case "tls", "tls4", "tls6": return ListenTLS(network, address, tlsConf)
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
addr, err := net.ResolveTCPAddr(network, address)
|
||||
if err != nil { return nil, err }
|
||||
return ListenTCP(network, addr)
|
||||
case "unix":
|
||||
addr, err := net.ResolveUnixAddr(network, address)
|
||||
if err != nil { return nil, err }
|
||||
return ListenUnix(network, addr)
|
||||
default:
|
||||
return nil, ErrUnknownNetwork
|
||||
}
|
||||
}
|
||||
|
||||
@ -46,7 +57,8 @@ func ListenQUIC(network, address string, tlsConf *tls.Config) (Listener, error)
|
||||
// ListenTLS listens for incoming HOPP connections using a TLS socket as a
|
||||
// transport. The network must be "tcp".
|
||||
func ListenTLS(network, address string, tlsConf *tls.Config) (Listener, error) {
|
||||
if network != "tcp" { return nil, ErrUnknownNetwork }
|
||||
network, err := tlsNetworkToTCPNetwork(network)
|
||||
if err != nil { return nil, err }
|
||||
listener, err := tls.Listen(network, address, tlsConf)
|
||||
if err != nil { return nil, err }
|
||||
return &netListenerWrapper {
|
||||
@ -54,12 +66,19 @@ func ListenTLS(network, address string, tlsConf *tls.Config) (Listener, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListenTCP listens for incoming HOPP connections using a TCP socket as a
|
||||
// transport. The network must be "tcp".
|
||||
func ListenTCP(network string, laddr *net.TCPAddr) (Listener, error) {
|
||||
listener, err := net.ListenTCP(network, laddr)
|
||||
if err != nil { return nil, err }
|
||||
return &netListenerWrapper {
|
||||
underlying: listener,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListenUnix listens for incoming HOPP connections using a Unix domain socket
|
||||
// as a transport. The network must be "unix".
|
||||
func ListenUnix(network, address string) (Listener, error) {
|
||||
if network != "unix" { return nil, ErrUnknownNetwork }
|
||||
addr, err := net.ResolveUnixAddr(network, address)
|
||||
if err != nil { return nil, err }
|
||||
func ListenUnix(network string, addr *net.UnixAddr) (Listener, error) {
|
||||
listener, err := net.ListenUnix(network, addr)
|
||||
if err != nil { return nil, err }
|
||||
return &netListenerWrapper {
|
||||
|
||||
21
metadapta.go
21
metadapta.go
@ -124,10 +124,10 @@ func (this *a) unlistTransactionSafe(id int64) {
|
||||
delete(this.transMap, id)
|
||||
}
|
||||
|
||||
func (this *a) sendMessageSafe(trans int64, method uint16, data []byte) error {
|
||||
func (this *a) sendMessageSafe(trans int64, method uint16, ccb uint64, data []byte) error {
|
||||
this.sendLock.Lock()
|
||||
defer this.sendLock.Unlock()
|
||||
return encodeMessageA(this.underlying, this.sizeLimit, trans, method, data)
|
||||
return encodeMessageA(this.underlying, this.sizeLimit, trans, method, 0, data)
|
||||
}
|
||||
|
||||
func (this *a) receive() {
|
||||
@ -251,7 +251,7 @@ func (this *transA) ID() int64 {
|
||||
}
|
||||
|
||||
func (this *transA) Send(method uint16, data []byte) error {
|
||||
return this.parent.sendMessageSafe(this.id, method, data)
|
||||
return this.parent.sendMessageSafe(this.id, method, 0, data)
|
||||
}
|
||||
|
||||
func (this *transA) SendWriter(method uint16) (io.WriteCloser, error) {
|
||||
@ -417,6 +417,7 @@ func (this *writerA) Write(data []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (this *writerA) Close() error {
|
||||
this.flush(0)
|
||||
this.open = false
|
||||
return nil
|
||||
}
|
||||
@ -430,21 +431,21 @@ func (this *writerA) writeOne(data []byte) (n int, err error) {
|
||||
n = len(data)
|
||||
// if have a full chunk, flush
|
||||
if int64(len(this.buffer)) == this.chunkSize {
|
||||
err = this.flush()
|
||||
err = this.flush(1)
|
||||
if err != nil { return n, err }
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// if not, flush and store as much as we can in the buffer
|
||||
err = this.flush()
|
||||
err = this.flush(1)
|
||||
if err != nil { return n, err }
|
||||
this.buffer = append(this.buffer, data...)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (this *writerA) flush() error {
|
||||
return this.parent.parent.sendMessageSafe(this.parent.id, this.method, this.buffer)
|
||||
func (this *writerA) flush(ccb uint64) error {
|
||||
return this.parent.parent.sendMessageSafe(this.parent.id, this.method, ccb, this.buffer)
|
||||
}
|
||||
|
||||
type incomingMessage struct {
|
||||
@ -458,15 +459,19 @@ func encodeMessageA(
|
||||
sizeLimit int64,
|
||||
trans int64,
|
||||
method uint16,
|
||||
ccb uint64,
|
||||
data []byte,
|
||||
) error {
|
||||
if int64(len(data)) > sizeLimit {
|
||||
return ErrPayloadTooLarge
|
||||
}
|
||||
buffer := make([]byte, 18 + len(data))
|
||||
// transaction ID field
|
||||
encodeI64(buffer[:8], trans)
|
||||
// method field
|
||||
encodeI16(buffer[8:10], method)
|
||||
encodeI64(buffer[10:18], uint64(len(data)))
|
||||
// payload size field
|
||||
encodeI64(buffer[10:18], uint64(len(data)) & 0x7FFFFFFFFFFFFFFF | ccb << 63)
|
||||
copy(buffer[18:], data)
|
||||
_, err := writer.Write(buffer)
|
||||
return err
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
package hopp
|
||||
|
||||
import "io"
|
||||
import "net"
|
||||
import "bytes"
|
||||
import "errors"
|
||||
import "slices"
|
||||
import "testing"
|
||||
import "context"
|
||||
|
||||
// some of these tests spawn goroutines that can signal a failure.
|
||||
// abide by the documentation for testing.T (https://pkg.go.dev/testing#T):
|
||||
@ -145,7 +145,7 @@ func TestTransOpenCloseA(test *testing.T) {
|
||||
func TestEncodeMessageA(test *testing.T) {
|
||||
buffer := new(bytes.Buffer)
|
||||
payload := []byte { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05 }
|
||||
err := encodeMessageA(buffer, defaultSizeLimit, 0x5800FEABC3104F04, 0x6B12, payload)
|
||||
err := encodeMessageA(buffer, defaultSizeLimit, 0x5800FEABC3104F04, 0x6B12, 0, payload)
|
||||
correct := []byte {
|
||||
0x58, 0x00, 0xFE, 0xAB, 0xC3, 0x10, 0x4F, 0x04,
|
||||
0x6B, 0x12,
|
||||
@ -163,7 +163,7 @@ func TestEncodeMessageA(test *testing.T) {
|
||||
func TestEncodeMessageAErr(test *testing.T) {
|
||||
buffer := new(bytes.Buffer)
|
||||
payload := make([]byte, 0x10000)
|
||||
err := encodeMessageA(buffer, 0x20, 0x5800FEABC3104F04, 0x6B12, payload)
|
||||
err := encodeMessageA(buffer, 0x20, 0x5800FEABC3104F04, 0x6B12, 0, payload)
|
||||
if !errors.Is(err, ErrPayloadTooLarge) {
|
||||
test.Fatalf("wrong error: %v", err)
|
||||
}
|
||||
@ -208,7 +208,7 @@ func TestEncodeDecodeMessageA(test *testing.T) {
|
||||
correctMethod := uint16(30)
|
||||
correctPayload := []byte("good")
|
||||
buffer := bytes.Buffer { }
|
||||
err := encodeMessageA(&buffer, defaultSizeLimit, correctTransID, correctMethod, correctPayload)
|
||||
err := encodeMessageA(&buffer, defaultSizeLimit, correctTransID, correctMethod, 0, correctPayload)
|
||||
if err != nil { test.Fatal(err) }
|
||||
transID, method, chunked, payload, err := decodeMessageA(&buffer, defaultSizeLimit)
|
||||
if got, correct := transID, int64(2); got != correct {
|
||||
@ -230,34 +230,31 @@ func clientServerEnvironment(test *testing.T, clientFunc func(conn Conn), server
|
||||
addr := "localhost:7959"
|
||||
|
||||
// server
|
||||
listener, err := net.Listen(network, addr)
|
||||
if err != nil { test.Fatal(err) }
|
||||
listener, err := Listen(network, addr, nil)
|
||||
test.Cleanup(func() { listener.Close() })
|
||||
go func() {
|
||||
test.Log("SERVER listening")
|
||||
conn, err := listener.Accept()
|
||||
if err != nil { test.Error("SERVER", err); return }
|
||||
|
||||
defer conn.Close()
|
||||
test.Cleanup(func() { conn.Close() })
|
||||
a := AdaptA(conn, ServerSide)
|
||||
test.Cleanup(func() { a.Close() })
|
||||
|
||||
serverFunc(a)
|
||||
serverFunc(conn)
|
||||
test.Log("SERVER closing")
|
||||
}()
|
||||
|
||||
// client
|
||||
test.Log("CLIENT dialing")
|
||||
conn, err := net.Dial(network, addr)
|
||||
conn, err := Dial(context.Background(), network, addr, nil)
|
||||
if err != nil { test.Fatal("CLIENT", err) }
|
||||
test.Cleanup(func() { conn.Close() })
|
||||
test.Log("CLIENT dialed")
|
||||
a := AdaptA(conn, ClientSide)
|
||||
test.Cleanup(func() { a.Close() })
|
||||
|
||||
clientFunc(a)
|
||||
clientFunc(conn)
|
||||
|
||||
test.Log("CLIENT waiting for connection close...")
|
||||
trans, err := a.AcceptTrans()
|
||||
trans, err := conn.AcceptTrans()
|
||||
if !errors.Is(err, io.EOF) {
|
||||
test.Error("CLIENT wrong error:", err)
|
||||
test.Fatal("CLIENT trans:", trans)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user