From 779aaf3cea852d19598bd4c8b7e08be23006c72a Mon Sep 17 00:00:00 2001 From: Sasha Koshka Date: Tue, 5 Mar 2024 03:15:44 -0500 Subject: [PATCH] Generate union types --- generator/type-multiplex.go | 2 + generator/type.go | 74 +++++++++++++++++++++++++++++++++++++ generator/type_test.go | 34 +++++++++++++++++ 3 files changed, 110 insertions(+) diff --git a/generator/type-multiplex.go b/generator/type-multiplex.go index ab470c0..26319a9 100644 --- a/generator/type-multiplex.go +++ b/generator/type-multiplex.go @@ -37,6 +37,8 @@ func (this *generator) generateType (ty entity.Type) (llvm.Type, error) { return this.generateTypeStruct(ty.(*entity.TypeStruct)) case *entity.TypeInterface: return this.generateTypeInterface(ty.(*entity.TypeInterface)) + case *entity.TypeUnion: + return this.generateTypeUnion(ty.(*entity.TypeUnion)) case *entity.TypeInt: return this.generateTypeInt(ty.(*entity.TypeInt)) case *entity.TypeFloat: diff --git a/generator/type.go b/generator/type.go index 4643da5..337dc8c 100644 --- a/generator/type.go +++ b/generator/type.go @@ -72,6 +72,26 @@ func (this *generator) generateTypeInterface (ty *entity.TypeInterface) (llvm.Ty return irStruct, nil } +func (this *generator) generateTypeUnion (ty *entity.TypeUnion) (llvm.Type, error) { + size := uint64(0) + for _, allowed := range ty.Allowed { + irAllowed, err := this.generateType(allowed) + if err != nil { return nil, err } + allowedSize := this.sizeOfIrType(irAllowed) + if allowedSize > size { size = allowedSize } + } + + irStruct := &llvm.TypeStruct { + Fields: []llvm.Type { + // TODO: could this field be smaller? for example what + // does rust do? + &llvm.TypeInt { BitSize: 64 }, + &llvm.TypeInt { BitSize: size }, + }, + } + return irStruct, nil +} + func (this *generator) generateTypeInt (ty *entity.TypeInt) (llvm.Type, error) { return &llvm.TypeInt { BitSize: uint64(ty.Width) }, nil } @@ -125,6 +145,60 @@ func (this *generator) generateTypeFunction ( return irFunc, nil } +func (this *generator) sizeOfIrType (ty llvm.Type) uint64 { + switch ty := ty.(type) { + case *llvm.TypeArray: + return this.alignmentScale(this.sizeOfIrType(ty.Element)) * ty.Length + case *llvm.TypeVector: + return this.alignmentScale(this.sizeOfIrType(ty.Element)) * ty.Length + case *llvm.TypeDefined: + return this.sizeOfIrType(ty.Source) + case *llvm.TypeFloat: + switch ty.Kind { + case llvm.FloatKindHalf: return 16 + case llvm.FloatKindFloat: return 32 + case llvm.FloatKindDouble: return 64 + case llvm.FloatKindFP128: return 128 + case llvm.FloatKindX86_FP80: return 80 + case llvm.FloatKindPPC_FP128: return 128 + } + case *llvm.TypeFunction: return 0 + case *llvm.TypeInt: return ty.BitSize + case *llvm.TypeLabel: return 0 + case *llvm.TypeMMX: return 0 // is this correct? + case *llvm.TypeMetadata: return 0 + case *llvm.TypePointer: return this.target.WordSize + case *llvm.TypeStruct: + // TODO ensure this is correct because it might not be + total := uint64(0) + for _, field := range ty.Fields { + fieldSize := this.sizeOfIrType(field) + if !ty.Packed { + // if not packed, align members + empty := total == 0 + fieldSize = this.alignmentScale(fieldSize) + + total /= fieldSize + if !empty && total == 0 { total ++ } + total *= fieldSize + } + total += fieldSize + } + return total + case *llvm.TypeToken: return 0 + case *llvm.TypeVoid: return 0 + } + panic(fmt.Sprintln("generator doesn't know about LLVM type", ty)) +} + +// alignmentScale returns the smallest power of two that is greater than or +// equal to size. Note that it starts at 8. +func (this *generator) alignmentScale (size uint64) uint64 { + scale := uint64(8) + for size > scale { scale *= 2 } + return scale +} + func getInterface (ty entity.Type) *entity.TypeInterface { switch ty.(type) { case *entity.TypeNamed: diff --git a/generator/type_test.go b/generator/type_test.go index dea47cd..f62c585 100644 --- a/generator/type_test.go +++ b/generator/type_test.go @@ -129,3 +129,37 @@ B: A } `) } + +func TestTypeUnion (test *testing.T) { +testString (test, +`%"0zNZN147MN2wzMAQ6NS2dQ==::U" = type { i64, i64 } +%"0zNZN147MN2wzMAQ6NS2dQ==::SmallU" = type { i64, i16 } +%"0zNZN147MN2wzMAQ6NS2dQ==::Padded" = type { i8, i16 } +%"0zNZN147MN2wzMAQ6NS2dQ==::PaddedU" = type { i64, i32 } +%"0zNZN147MN2wzMAQ6NS2dQ==::Point" = type { i64, i64 } +%"0zNZN147MN2wzMAQ6NS2dQ==::Error" = type { ptr, ptr } +%"0zNZN147MN2wzMAQ6NS2dQ==::PointOrError" = type { i64, i128 } +define void @"0zNZN147MN2wzMAQ6NS2dQ==::main"() { +0: + %1 = alloca %"0zNZN147MN2wzMAQ6NS2dQ==::U" + %2 = alloca %"0zNZN147MN2wzMAQ6NS2dQ==::SmallU" + %3 = alloca %"0zNZN147MN2wzMAQ6NS2dQ==::PaddedU" + %4 = alloca %"0zNZN147MN2wzMAQ6NS2dQ==::PointOrError" + ret void +} +`, +` +U: (| I8 I16 I32 I64 Int) +SmallU: (| I8 U8 I16 U16) +Point: (. x:Int y:Int) +Padded: (. a:I8 b:I16) +PaddedU: (| Padded) +Error: (~ [error]:String) +PointOrError: (| Point Error) +[main] = { + u:U + su:SmallU + pu:PaddedU + pe:PointOrError +} +`)}