hopp/generate/generate.go

617 lines
20 KiB
Go

package generate
import "io"
import "fmt"
import "maps"
import "math"
import "slices"
import "strings"
import "git.tebibyte.media/sashakoshka/hopp/tape"
// imports is the import section emitted into every generated file, right after
// the preamble. The generated code refers to the tape package for its encoder
// and decoder types, so the path must match the real module path of tape.
const imports =
`
import "git.tebibyte.media/sashakoshka/hopp/tape"
`
// preamble is the comment block that Generate emits immediately after the
// package clause of every generated file. It is printed verbatim, so the
// "<path>" text appears literally in the output.
// NOTE(review): "<path>" looks like a placeholder for the PDL source path that
// is never substituted — TODO confirm the intent.
const preamble = `
/* # Do not edit this package by hand!
*
* This file was automatically generated by the Holanet PDL compiler. The
* source file is located at <path>
* Please edit that file instead, and re-compile it to this location.
*
* HOPP, TAPE, METADAPT, PDL/0 (c) 2025 holanet.xyz
*/
`
// static is emitted verbatim after the imports of every generated file. It
// declares the declarations shared by all generated code: Table, the Go type
// that generateType uses for tables with an undefined schema, and the Message
// interface that generated message types are meant to satisfy.
// NOTE(review): generateTag emits calls to a bufferTag helper that is not
// declared here; the generated code needs it defined somewhere to compile.
const static = `
// Table is a KTV table with an undefined schema.
type Table map[uint16] any
// Message is any message that can be sent along this protocol.
type Message interface {
tape.Encodable
tape.Decodable
// Method returns the method code of the message.
Method() uint16
}
`
// Generator converts protocols into Go code.
type Generator struct {
	// Output is where the generated code will be sent.
	Output io.Writer
	// PackageName is the package name that will be used in the file. If
	// left empty, the default is "protocol".
	PackageName string
	// nestingLevel is the number of tab characters that iprint, iprintln,
	// and iprintf prepend to what they write. push and pop adjust it, and
	// Generate resets it to zero.
	nestingLevel int
	// protocol is the protocol currently being generated. Generate sets it
	// for its own duration and clears it afterwards; resolveTypeName reads
	// it to look up user-defined types.
	protocol *Protocol
}
// Generate writes a complete Go source file for protocol to this.Output and
// returns the number of bytes written. The file consists of the package
// clause, the fixed preamble, imports, and static sections, then one type
// definition per protocol type and one message type per message, each in
// sorted key order so that the output is deterministic. It stops at the first
// error, returning the bytes written so far.
func (this *Generator) Generate(protocol *Protocol) (n int, err error) {
	this.nestingLevel = 0
	this.protocol = protocol
	defer func() { this.protocol = nil }()

	packageName := this.PackageName
	if packageName == "" {
		packageName = "protocol"
	}
	written, err := this.iprintf("package %s\n", packageName)
	n += written; if err != nil { return n, err }

	// fixed sections shared by every generated file, in output order
	for _, section := range []string { preamble, imports, static } {
		written, err = this.print(section)
		n += written; if err != nil { return n, err }
	}

	for _, name := range slices.Sorted(maps.Keys(protocol.Types)) {
		written, err = this.generateTypedef(name, protocol.Types[name])
		n += written; if err != nil { return n, err }
	}

	for _, method := range slices.Sorted(maps.Keys(protocol.Messages)) {
		written, err = this.generateMessage(method, protocol.Messages[method])
		n += written; if err != nil { return n, err }
	}
	return n, nil
}
// generateTypedef generates the Go type definition for the protocol data type
// called name, followed by its EncodeValue and DecodeValue methods (both with
// pointer receivers). It returns the number of bytes written.
func (this *Generator) generateTypedef(name string, typ Type) (n int, err error) {
	// type definition
	nn, err := this.iprintf(
		"\n// %s represents the protocol data type %s.\n",
		name, name)
	n += nn; if err != nil { return n, err }
	nn, err = this.iprintf("type %s ", name)
	n += nn; if err != nil { return n, err }
	nn, err = this.generateType(typ)
	n += nn; if err != nil { return n, err }
	nn, err = this.println()
	n += nn; if err != nil { return n, err }

	// No Tag method is generated: the preferred tag of a value is computed
	// inline wherever it is needed (see generateTag).

	// EncodeValue method
	nn, err = this.iprintf(
		"\n// EncodeValue encodes the value of this type without the " +
		"tag. The value is\n// encoded according to the parameters " +
		"specified by the tag, if possible.\n")
	n += nn; if err != nil { return n, err }
	nn, err = this.iprintf(
		"func (this *%s) EncodeValue(encoder *tape.Encoder, tag tape.Tag) (n int, err error) {\n",
		name)
	n += nn; if err != nil { return n, err }
	this.push()
	// the statements emitted by generateEncodeValue assign to nn
	nn, err = this.iprintf("var nn int\n")
	n += nn; if err != nil { return n, err }
	nn, err = this.generateEncodeValue(typ, "(*this)", "tag")
	n += nn; if err != nil { return n, err }
	nn, err = this.iprintf("return n, nil\n")
	n += nn; if err != nil { return n, err }
	this.pop()
	nn, err = this.iprintf("}\n")
	n += nn; if err != nil { return n, err }

	// DecodeValue method
	nn, err = this.iprintf(
		"\n// DecodeValue decodes the value of this type without " +
		"the tag. The value is\n// decoded according to the " +
		"parameters specified by the tag, if possible.\n")
	n += nn; if err != nil { return n, err }
	nn, err = this.iprintf(
		"func (this *%s) DecodeValue(decoder *tape.Decoder, tag tape.Tag) (n int, err error) {\n",
		name)
	n += nn; if err != nil { return n, err }
	this.push()
	// TODO: generate decoding code. Until then, emit a panicking stub: a
	// function with result parameters must end in a terminating statement,
	// otherwise the generated file does not compile.
	nn, err = this.iprintf("panic(%q)\n", name + ".DecodeValue is not implemented yet")
	n += nn; if err != nil { return n, err }
	this.pop()
	nn, err = this.iprintf("}\n")
	n += nn; if err != nil { return n, err }
	return n, nil
}
// generateMessage generates the Go type for the protocol message with the
// given method code, named by resolveMessageName, followed by its Encode
// method. It returns the number of bytes written.
//
// NOTE(review): the static Message interface also requires Method() and
// tape.Decodable, which are not generated here yet (see the TODO below).
func (this *Generator) generateMessage(method uint16, message Message) (n int, err error) {
	typeName := this.resolveMessageName(message.Name)

	// type definition; the doc comment starts with the Go type name, as Go
	// doc comments conventionally do
	nn, err := this.iprintf(
		"\n// %s represents the protocol message M%04X %s.\n",
		typeName, method, message.Name)
	n += nn; if err != nil { return n, err }
	nn, err = this.iprintf("type %s ", typeName)
	n += nn; if err != nil { return n, err }
	nn, err = this.generateType(message.Type)
	n += nn; if err != nil { return n, err }
	nn, err = this.println()
	n += nn; if err != nil { return n, err }

	// Encode method. It takes a pointer receiver, like the generated
	// EncodeValue methods, because the generated body refers to the value
	// as (*this).
	nn, err = this.iprintf("\n// Encode encodes this message's tag and value.\n")
	n += nn; if err != nil { return n, err }
	nn, err = this.iprintf(
		"func (this *%s) Encode(encoder *tape.Encoder) (n int, err error) {\n",
		typeName)
	n += nn; if err != nil { return n, err }
	this.push()
	nn, err = this.iprintf("tag := ")
	n += nn; if err != nil { return n, err }
	nn, err = this.generateTag(message.Type, "(*this)")
	n += nn; if err != nil { return n, err }
	nn, err = this.println()
	n += nn; if err != nil { return n, err }
	// the tag is written as a single byte ahead of the value
	nn, err = this.iprintf("nn, err := encoder.WriteUint8(uint8(tag))\n")
	n += nn; if err != nil { return n, err }
	nn, err = this.generateErrorCheck()
	n += nn; if err != nil { return n, err }
	nn, err = this.generateEncodeValue(message.Type, "(*this)", "tag")
	n += nn; if err != nil { return n, err }
	nn, err = this.iprintf("return n, nil\n")
	n += nn; if err != nil { return n, err }
	this.pop()
	nn, err = this.iprintf("}\n")
	n += nn; if err != nil { return n, err }

	// TODO decode method
	return n, nil
}
// generateEncodeValue generates statements that encode the Go value found at
// the expression valueSource, without its tag, according to the tape.Tag found
// at the expression tagSource. It returns the number of bytes written.
//
// The generated statements assume that encoder, n, nn, and err are in scope.
// After every write they add nn to n and return early on error (see
// generateErrorCheck). Named types are resolved and their encoding is emitted
// inline. This mirrors generateType, which also emits the underlying
// structure of named types instead of their names.
func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource string) (n int, err error) {
	switch typ := typ.(type) {
	case TypeInt:
		// SI: (none)
		// LI: <value: IntN>
		if typ.Bits <= 5 {
			// SI stores the value in the tag, so we write nothing here
			break
		}
		// Encoder methods are named after the bit width and signedness of
		// the Go type that generateType produces (e.g. uint16 ->
		// WriteUint16), like the WriteUint8/WriteUint16 calls below.
		kind := "Uint"
		if typ.Signed { kind = "Int" }
		nn, err := this.iprintf(
			"nn, err = encoder.Write%s%d(%s)\n",
			kind, typ.Bits, valueSource)
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
	case TypeFloat:
		// FP: <value: FloatN>
		nn, err := this.iprintf(
			"nn, err = encoder.WriteFloat%d(%s)\n",
			typ.Bits, valueSource)
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
	case TypeString:
		// see TypeBuffer
		nn, err := this.generateEncodeValue(TypeBuffer { }, valueSource, tagSource)
		n += nn; if err != nil { return n, err }
	case TypeBuffer:
		// SBA: <data: U8>*
		// LBA: <length: UN> <data: U8>*
		nn, err := this.iprintf("if %s.Is(tape.LBA) {\n", tagSource)
		n += nn; if err != nil { return n, err }
		this.push()
		nn, err = this.iprintf(
			"nn, err = encoder.WriteUintN(%s.CN(), uint64(len(%s)))\n",
			tagSource, valueSource)
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
		this.pop()
		nn, err = this.iprintf("}\n")
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf("nn, err = encoder.Write([]byte(%s))\n", valueSource)
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
	case TypeArray:
		// OTA: <length: UN> <elementTag: tape.Tag> <values>*
		nn, err := this.iprintf(
			"nn, err = encoder.WriteUintN(%s.CN(), uint64(len(%s)))\n",
			tagSource, valueSource)
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf("{\n")
		n += nn; if err != nil { return n, err }
		this.push()
		// Start from the element's TN and adopt the tag of the item with the
		// largest CN, so one element tag can encode every item.
		nn, err = this.iprintf("itemTag := ")
		n += nn; if err != nil { return n, err }
		nn, err = this.generateTN(typ.Element)
		n += nn; if err != nil { return n, err }
		nn, err = this.println()
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf("for _, item := range %s {\n", valueSource)
		n += nn; if err != nil { return n, err }
		this.push()
		// For fixed-size element types the tag does not mention item, and
		// an unused range variable would not compile.
		nn, err = this.iprintf("_ = item\n")
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf("tag := ")
		n += nn; if err != nil { return n, err }
		nn, err = this.generateTag(typ.Element, "item")
		n += nn; if err != nil { return n, err }
		nn, err = this.println()
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf("if tag.Is(tape.SBA) { continue }\n")
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf("if tag.CN() > itemTag.CN() { itemTag = tag }\n")
		n += nn; if err != nil { return n, err }
		this.pop()
		nn, err = this.iprintf("}\n")
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf("if itemTag.Is(tape.SBA) { itemTag += 1 << 5 }\n")
		n += nn; if err != nil { return n, err }
		// the element tag is written once, ahead of all the values
		nn, err = this.iprintf("nn, err = encoder.WriteUint8(uint8(itemTag))\n")
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf("for _, item := range %s {\n", valueSource)
		n += nn; if err != nil { return n, err }
		this.push()
		nn, err = this.generateEncodeValue(typ.Element, "item", "itemTag")
		n += nn; if err != nil { return n, err }
		this.pop()
		nn, err = this.iprintf("}\n")
		n += nn; if err != nil { return n, err }
		this.pop()
		nn, err = this.iprintf("}\n")
		n += nn; if err != nil { return n, err }
	case TypeTable:
		// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
		nn, err := this.iprintf(
			"nn, err = encoder.WriteUintN(%s.CN(), uint64(len(%s)))\n",
			tagSource, valueSource)
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf("for key, item := range %s {\n", valueSource)
		n += nn; if err != nil { return n, err }
		this.push()
		nn, err = this.iprintf("nn, err = encoder.WriteUint16(key)\n")
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
		// Each value carries its own tag, derived from the item itself.
		// NOTE(review): assumes tape.TagAny(value any) tape.Tag and
		// tape.EncodeAny(encoder, value, tag) (n int, err error) — confirm
		// against the tape package.
		nn, err = this.iprintf("tag := tape.TagAny(item)\n")
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf("nn, err = encoder.WriteUint8(uint8(tag))\n")
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf("nn, err = tape.EncodeAny(encoder, item, tag)\n")
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
		this.pop()
		nn, err = this.iprintf("}\n")
		n += nn; if err != nil { return n, err }
	case TypeTableDefined:
		// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
		nn, err := this.iprintf(
			"nn, err = encoder.WriteUintN(%s.CN(), %d)\n",
			tagSource, len(typ.Fields))
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
		if len(typ.Fields) == 0 {
			// Nothing else to write. An empty block would also declare a tag
			// variable that is never used, which does not compile.
			break
		}
		nn, err = this.iprintf("{\n")
		n += nn; if err != nil { return n, err }
		this.push()
		nn, err = this.iprintf("var tag tape.Tag\n")
		n += nn; if err != nil { return n, err }
		// Iterate in key order so the generated code is deterministic and
		// matches the field order emitted by generateTypeTableDefined.
		for _, key := range slices.Sorted(maps.Keys(typ.Fields)) {
			field := typ.Fields[key]
			nn, err = this.iprintf("nn, err = encoder.WriteUint16(0x%04X)\n", key)
			n += nn; if err != nil { return n, err }
			nn, err = this.generateErrorCheck()
			n += nn; if err != nil { return n, err }
			nn, err = this.iprintf("tag = ")
			n += nn; if err != nil { return n, err }
			fieldSource := fmt.Sprintf("%s.%s", valueSource, field.Name)
			nn, err = this.generateTag(field.Type, fieldSource)
			n += nn; if err != nil { return n, err }
			nn, err = this.println()
			n += nn; if err != nil { return n, err }
			nn, err = this.iprintf("nn, err = encoder.WriteUint8(uint8(tag))\n")
			n += nn; if err != nil { return n, err }
			nn, err = this.generateErrorCheck()
			n += nn; if err != nil { return n, err }
			nn, err = this.generateEncodeValue(field.Type, fieldSource, "tag")
			n += nn; if err != nil { return n, err }
		}
		this.pop()
		nn, err = this.iprintf("}\n")
		n += nn; if err != nil { return n, err }
	case TypeNamed:
		// generateType inlines the structure of named types, so the Go
		// value has no EncodeValue method of its own. Encode it through
		// the resolved type, as generateTag and generateTN do.
		resolved, err := this.resolveTypeName(typ.Name)
		if err != nil { return n, err }
		nn, err := this.generateEncodeValue(resolved, valueSource, tagSource)
		n += nn; if err != nil { return n, err }
	}
	return n, nil
}
// generateErrorCheck emits, at the current indentation, the statement that
// generated encoding code places after every write: it adds nn to n and
// returns early when err is non-nil.
func (this *Generator) generateErrorCheck() (n int, err error) {
	const check = "n += nn; if err != nil { return n, err }\n"
	return this.iprint(check)
}
// generateTag generates an expression that evaluates to the preferred tag (TN
// and CN) for the value of the given type found at the Go expression source.
// The generated code is INLINE. Named types are resolved first. It returns the
// number of bytes written.
func (this *Generator) generateTag(typ Type, source string) (n int, err error) {
	switch typ := typ.(type) {
	case TypeInt:
		if typ.Bits <= 5 {
			nn, err := this.printf("tape.TagSI")
			n += nn; if err != nil { return n, err }
		} else {
			nn, err := this.printf("tape.TagLI.WithCN(%d)", bitsToCN(typ.Bits))
			n += nn; if err != nil { return n, err }
		}
	case TypeFloat:
		nn, err := this.printf("tape.TagFP.WithCN(%d)", bitsToCN(typ.Bits))
		n += nn; if err != nil { return n, err }
	case TypeString:
		nn, err := this.generateTag(TypeBuffer { }, source)
		n += nn; if err != nil { return n, err }
	case TypeBuffer:
		// NOTE(review): bufferTag is not declared by the static section this
		// generator emits; it must be provided for the generated code to
		// compile. Presumably it picks SBA or LBA by length — TODO confirm.
		nn, err := this.printf("bufferTag(%s)", source)
		n += nn; if err != nil { return n, err }
	case TypeArray:
		// Same shape as the TypeTable expression below, with parentheses
		// balanced.
		nn, err := this.printf("tape.TagOTA.WithCN(tape.IntBytes(uint64(len(%s))))", source)
		n += nn; if err != nil { return n, err }
	case TypeTable:
		nn, err := this.printf("tape.TagKTV.WithCN(tape.IntBytes(uint64(len(%s))))", source)
		n += nn; if err != nil { return n, err }
	case TypeTableDefined:
		// the field count is known at generation time, so the CN is a constant
		nn, err := this.printf("tape.TagKTV.WithCN(%d)", tape.IntBytes(uint64(len(typ.Fields))))
		n += nn; if err != nil { return n, err }
	case TypeNamed:
		resolved, err := this.resolveTypeName(typ.Name)
		if err != nil { return n, err }
		nn, err := this.generateTag(resolved, source)
		n += nn; if err != nil { return n, err }
	}
	return n, nil
}
// generateTN generates the TN constant for the given type. The generated code
// is INLINE and the tag it names has a CN of zero. For types whose TN depends
// on their length (strings and buffers), the TN that can carry more
// information (LBA) is chosen. Named types are resolved first; types it does
// not recognize produce no output.
func (this *Generator) generateTN(typ Type) (n int, err error) {
	var tn string
	switch typ := typ.(type) {
	case TypeInt:
		tn = "tape.TagLI"
		if typ.Bits <= 5 { tn = "tape.TagSI" }
	case TypeFloat:
		tn = "tape.TagFP"
	case TypeString, TypeBuffer:
		tn = "tape.TagLBA"
	case TypeArray:
		tn = "tape.TagOTA"
	case TypeTable, TypeTableDefined:
		tn = "tape.TagKTV"
	case TypeNamed:
		resolved, err := this.resolveTypeName(typ.Name)
		if err != nil { return 0, err }
		return this.generateTN(resolved)
	default:
		return 0, nil
	}
	return this.print(tn)
}
// generateType emits the Go type expression for typ, inline and without a
// trailing newline, and returns the number of bytes written. Named types are
// resolved and their underlying structure is emitted in place of the name.
// Integer and float widths with no native Go counterpart produce an error
// before anything is written.
func (this *Generator) generateType(typ Type) (n int, err error) {
	switch typ := typ.(type) {
	case TypeInt:
		if err := this.validateIntBitSize(typ.Bits); err != nil {
			return n, err
		}
		kind := "uint"
		if typ.Signed { kind = "int" }
		return this.printf("%s%d", kind, typ.Bits)
	case TypeFloat:
		if err := this.validateFloatBitSize(typ.Bits); err != nil {
			return n, err
		}
		// Go has no float16, so half-precision values are held in a float32.
		bits := typ.Bits
		if bits == 16 { bits = 32 }
		return this.printf("float%d", bits)
	case TypeString:
		return this.print("string")
	case TypeBuffer:
		return this.print("[]byte")
	case TypeArray:
		prefix, err := this.print("[]")
		if err != nil { return prefix, err }
		element, err := this.generateType(typ.Element)
		return prefix + element, err
	case TypeTable:
		return this.print("Table")
	case TypeTableDefined:
		return this.generateTypeTableDefined(typ)
	case TypeNamed:
		actual, err := this.resolveTypeName(typ.Name)
		if err != nil { return n, err }
		return this.generateType(actual)
	}
	return n, nil
}
// generateTypeTableDefined emits an anonymous Go struct type for a table with
// a defined schema. It writes one field per line, "Name Type", in ascending key
// order, so the output is deterministic. The closing brace has no trailing
// newline. It returns the number of bytes written.
func (this *Generator) generateTypeTableDefined(typ TypeTableDefined) (n int, err error) {
	var written int
	written, err = this.print("struct {\n")
	n += written; if err != nil { return n, err }

	this.push()
	keys := slices.Sorted(maps.Keys(typ.Fields))
	for _, key := range keys {
		field := typ.Fields[key]
		written, err = this.iprintf("%s ", field.Name)
		n += written; if err != nil { return n, err }
		written, err = this.generateType(field.Type)
		n += written; if err != nil { return n, err }
		written, err = this.print("\n")
		n += written; if err != nil { return n, err }
	}
	this.pop()

	written, err = this.iprint("}")
	n += written; if err != nil { return n, err }
	return n, nil
}
// validateIntBitSize returns nil if an integer of the given bit width maps to
// a native Go integer type (8, 16, 32, or 64 bits), and an error otherwise.
func (this *Generator) validateIntBitSize(size int) error {
	if !slices.Contains([]int { 8, 16, 32, 64 }, size) {
		return fmt.Errorf("integers of size %d are unsupported on this platform", size)
	}
	return nil
}
// validateFloatBitSize returns nil if a float of the given bit width can be
// generated (16, 32, or 64 bits), and an error otherwise.
func (this *Generator) validateFloatBitSize(size int) error {
	if !slices.Contains([]int { 16, 32, 64 }, size) {
		return fmt.Errorf("floats of size %d are unsupported on this platform", size)
	}
	return nil
}
// push indents subsequent iprint* output by one more tab.
func (this *Generator) push() {
	this.nestingLevel += 1
}
// pop undoes one push. It panics when the nesting level is already zero,
// because that means push and pop calls in the generator are unbalanced.
func (this *Generator) pop() {
	if this.nestingLevel <= 0 {
		panic("cannot pop when nesting level is less than 1")
	}
	this.nestingLevel -= 1
}
// indent returns the prefix for the current nesting level: one tab per level.
func (this *Generator) indent() string {
	const unit = "\t"
	return strings.Repeat(unit, this.nestingLevel)
}
// print writes args to Output with fmt.Fprint semantics and no indentation.
func (this *Generator) print(args ...any) (n int, err error) {
	return fmt.Fprint(this.Output, args...)
}
// println writes args to Output with fmt.Fprintln semantics and no
// indentation.
func (this *Generator) println(args ...any) (n int, err error) {
	return fmt.Fprintln(this.Output, args...)
}
// printf writes a formatted string to Output with no indentation.
func (this *Generator) printf(format string, args ...any) (n int, err error) {
	return fmt.Fprintf(this.Output, format, args...)
}
// iprint is print, prefixed by the current indentation.
func (this *Generator) iprint(args ...any) (n int, err error) {
	line := this.indent() + fmt.Sprint(args...)
	return this.print(line)
}
// iprintln is println, prefixed by the current indentation.
func (this *Generator) iprintln(args ...any) (n int, err error) {
	line := this.indent() + fmt.Sprint(args...)
	return this.println(line)
}
// iprintf is printf, prefixed by the current indentation. The indentation
// becomes part of the format string, so it must not contain '%' (tabs never
// do).
func (this *Generator) iprintf(format string, args ...any) (n int, err error) {
	indented := this.indent() + format
	return this.printf(indented, args...)
}
// resolveMessageName returns the Go type name used for the protocol message
// called message: the name prefixed with "Message".
func (this *Generator) resolveMessageName(message string) string {
	const prefix = "Message"
	return prefix + message
}
// resolveTypeName maps a type name to the type it denotes.
//
// Built-in names (U8…U256, I8…I256, F16…F256, String, Buffer, Table) map to
// their primitive types. Any other name is looked up in the protocol being
// generated. When that protocol type is itself just another name, the chain of
// names is followed until a non-named type is reached. It returns an error if a
// name is unknown, or if the chain loops back on itself; previously a loop
// recursed until the stack overflowed.
//
// It reads this.protocol, so it must only be called while Generate is
// running.
func (this *Generator) resolveTypeName(name string) (Type, error) {
	// Names already followed in the current chain of aliases. Chains are
	// short and built-ins never touch this, so a nil slice is enough.
	var visited []string
	for {
		switch name {
		case "U8":     return TypeInt { Bits: 8 }, nil
		case "U16":    return TypeInt { Bits: 16 }, nil
		case "U32":    return TypeInt { Bits: 32 }, nil
		case "U64":    return TypeInt { Bits: 64 }, nil
		case "U128":   return TypeInt { Bits: 128 }, nil
		case "U256":   return TypeInt { Bits: 256 }, nil
		case "I8":     return TypeInt { Bits: 8, Signed: true }, nil
		case "I16":    return TypeInt { Bits: 16, Signed: true }, nil
		case "I32":    return TypeInt { Bits: 32, Signed: true }, nil
		case "I64":    return TypeInt { Bits: 64, Signed: true }, nil
		case "I128":   return TypeInt { Bits: 128, Signed: true }, nil
		case "I256":   return TypeInt { Bits: 256, Signed: true }, nil
		case "F16":    return TypeFloat { Bits: 16 }, nil
		case "F32":    return TypeFloat { Bits: 32 }, nil
		case "F64":    return TypeFloat { Bits: 64 }, nil
		case "F128":   return TypeFloat { Bits: 128 }, nil
		case "F256":   return TypeFloat { Bits: 256 }, nil
		case "String": return TypeString { }, nil
		case "Buffer": return TypeBuffer { }, nil
		case "Table":  return TypeTable { }, nil
		}
		if slices.Contains(visited, name) {
			return nil, fmt.Errorf("type %s refers to itself through a cycle of named types", name)
		}
		visited = append(visited, name)
		typ, ok := this.protocol.Types[name]
		if !ok {
			return nil, fmt.Errorf("no type exists called %s", name)
		}
		named, isNamed := typ.(TypeNamed)
		if !isNamed {
			return typ, nil
		}
		name = named.Name
	}
}
// bitsToBytes returns the number of whole bytes needed to hold the given
// number of bits, rounding up (e.g. 5 -> 1, 9 -> 2, 64 -> 8).
func bitsToBytes(bits int) int {
	whole := math.Ceil(float64(bits) / 8.0)
	return int(whole)
}
// bitsToCN returns the CN that generateTag passes to WithCN for a value of
// the given bit width: its byte count minus one.
func bitsToCN(bits int) int {
	byteCount := bitsToBytes(bits)
	return byteCount - 1
}