diff --git a/llvm/type.go b/llvm/type.go index 6bba3f3..db8bb22 100644 --- a/llvm/type.go +++ b/llvm/type.go @@ -8,6 +8,7 @@ type Type interface { LLString () string Name () string SetName (name string) + Equals (Type) bool } type AbstractType struct { @@ -40,6 +41,10 @@ func (this *TypeDefined) String () string { return this.LLString() } +func (this *TypeDefined) Equals (ty Type) bool { + return this.Source.Equals(ty) +} + type TypeArray struct { AbstractType Element Type @@ -55,6 +60,14 @@ func (this *TypeArray) String () string { return this.LLString() } +func (this *TypeArray) Equals (ty Type) bool { + if ty, ok := ty.(*TypeArray); ok { + return this.Length == ty.Length && + TypesEqual(this.Element, ty.Element) + } + return false +} + type FloatKind uint8; const ( // 16-bit floating-point type (IEEE 754 half precision). FloatKindHalf FloatKind = iota // half @@ -96,6 +109,13 @@ func (this *TypeFloat) String () string { return this.LLString() } +func (this *TypeFloat) Equals (ty Type) bool { + if ty, ok := ty.(*TypeFloat); ok { + return this.Kind == ty.Kind + } + return false +} + type TypeFunction struct { AbstractType Return Type @@ -123,6 +143,19 @@ func (this *TypeFunction) String () string { return this.LLString() } +func (this *TypeFunction) Equals (ty Type) bool { + if ty, ok := ty.(*TypeFunction); ok { + if len(this.Parameters) != len(ty.Parameters) { return false } + for index, parameter := range this.Parameters { + if !TypesEqual(parameter, ty.Parameters[index]) { return false } + } + return TypesEqual(this.Return, ty.Return) && + this.Variadic == ty.Variadic + } + return false + +} + type TypeInt struct { AbstractType BitSize uint64 @@ -136,6 +169,13 @@ func (this *TypeInt) String () string { return this.LLString() } +func (this *TypeInt) Equals (ty Type) bool { + if ty, ok := ty.(*TypeInt); ok { + return this.BitSize == ty.BitSize + } + return false +} + type TypeLabel struct { AbstractType } @@ -149,6 +189,11 @@ func (this *TypeLabel) String () string { return this.LLString() } +func (this *TypeLabel) Equals (ty Type) bool { + _, ok := ty.(*TypeLabel) + return ok +} + type TypeMMX struct { AbstractType } @@ -162,6 +207,11 @@ func (this *TypeMMX) String () string { return this.LLString() } +func (this *TypeMMX) Equals (ty Type) bool { + _, ok := ty.(*TypeMMX) + return ok +} + type TypeMetadata struct { AbstractType } @@ -169,11 +219,17 @@ type TypeMetadata struct { func (this *TypeMetadata) LLString () string { return "metadata" } + func (this *TypeMetadata) String () string { if this.Named() { return EncodeTypeName(this.Name()) } return this.LLString() } +func (this *TypeMetadata) Equals (ty Type) bool { + _, ok := ty.(*TypeMetadata) + return ok +} + type TypePointer struct { AbstractType AddressSpace AddressSpace @@ -192,6 +248,11 @@ func (this *TypePointer) String () string { return this.LLString() } +func (this *TypePointer) Equals (ty Type) bool { + _, ok := ty.(*TypePointer) + return ok +} + type TypeStruct struct { AbstractType Fields []Type @@ -221,6 +282,17 @@ func (this *TypeStruct) String () string { return this.LLString() } +func (this *TypeStruct) Equals (ty Type) bool { + if ty, ok := ty.(*TypeStruct); ok { + if len(this.Fields) != len(ty.Fields) { return false } + for index, field := range this.Fields { + if !TypesEqual(field, ty.Fields[index]) { return false } + } + return this.Packed == ty.Packed && this.Opaque == ty.Opaque + } + return false +} + type TypeToken struct { AbstractType } @@ -234,6 +306,11 @@ func (this *TypeToken) String () string { return this.LLString() } +func (this *TypeToken) Equals (ty Type) bool { + _, ok := ty.(*TypeToken) + return ok +} + type TypeVector struct { AbstractType Element Type @@ -254,6 +331,15 @@ func (this *TypeVector) String () string { return this.LLString() } +func (this *TypeVector) Equals (ty Type) bool { + if ty, ok := ty.(*TypeVector); ok { + return this.Length == ty.Length && + this.Scalable == ty.Scalable && + TypesEqual(this.Element, ty.Element) + } + return false +} + type TypeVoid struct { AbstractType } @@ -267,6 +353,11 @@ func (this *TypeVoid) String () string { return this.LLString() } +func (this *TypeVoid) Equals (ty Type) bool { + _, ok := ty.(*TypeVoid) + return ok +} + func aggregateElemType (t Type, indices []uint64) Type { // Base case. if len(indices) == 0 { @@ -281,3 +372,18 @@ func aggregateElemType (t Type, indices []uint64) Type { panic(fmt.Errorf("support for aggregate type %T not yet implemented", t)) } } + +// TypesEqual checks if two types are equal to eachother, even if one or both +// are nil. +func TypesEqual (left, right Type) bool { + left = ReduceToBase(left) + right = ReduceToBase(right) + if (left == nil) != (right == nil) { return false } + if left == nil { return true } + return left.Equals(right) +} + +func ReduceToBase (ty Type) Type { + if ty, ok := ty.(*TypeDefined); ok { return ty.Source } + return ty +}