Added a way to test IR type equality

This commit is contained in:
Sasha Koshka 2024-01-27 06:49:30 +00:00
parent f12405b3f4
commit 73e2cffda2
1 changed files with 106 additions and 0 deletions

View File

@ -8,6 +8,7 @@ type Type interface {
LLString () string
Name () string
SetName (name string)
Equals (Type) bool
}
type AbstractType struct {
@ -40,6 +41,10 @@ func (this *TypeDefined) String () string {
return this.LLString()
}
func (this *TypeDefined) Equals (ty Type) bool {
return this.Source.Equals(ty)
}
type TypeArray struct {
AbstractType
Element Type
@ -55,6 +60,14 @@ func (this *TypeArray) String () string {
return this.LLString()
}
func (this *TypeArray) Equals (ty Type) bool {
if ty, ok := ty.(*TypeArray); ok {
return this.Length == ty.Length &&
TypesEqual(this.Element, ty.Element)
}
return false
}
type FloatKind uint8; const (
// 16-bit floating-point type (IEEE 754 half precision).
FloatKindHalf FloatKind = iota // half
@ -96,6 +109,13 @@ func (this *TypeFloat) String () string {
return this.LLString()
}
func (this *TypeFloat) Equals (ty Type) bool {
if ty, ok := ty.(*TypeFloat); ok {
return this.Kind == ty.Kind
}
return false
}
type TypeFunction struct {
AbstractType
Return Type
@ -123,6 +143,19 @@ func (this *TypeFunction) String () string {
return this.LLString()
}
func (this *TypeFunction) Equals (ty Type) bool {
if ty, ok := ty.(*TypeFunction); ok {
if len(this.Parameters) != len(ty.Parameters) { return false }
for index, parameter := range this.Parameters {
if !TypesEqual(parameter, ty.Parameters[index]) { return false }
}
return TypesEqual(this.Return, ty.Return) &&
this.Variadic == ty.Variadic
}
return false
}
type TypeInt struct {
AbstractType
BitSize uint64
@ -136,6 +169,13 @@ func (this *TypeInt) String () string {
return this.LLString()
}
func (this *TypeInt) Equals (ty Type) bool {
if ty, ok := ty.(*TypeInt); ok {
return this.BitSize == ty.BitSize
}
return false
}
type TypeLabel struct {
AbstractType
}
@ -149,6 +189,11 @@ func (this *TypeLabel) String () string {
return this.LLString()
}
func (this *TypeLabel) Equals (ty Type) bool {
_, ok := ty.(*TypeLabel)
return ok
}
type TypeMMX struct {
AbstractType
}
@ -162,6 +207,11 @@ func (this *TypeMMX) String () string {
return this.LLString()
}
func (this *TypeMMX) Equals (ty Type) bool {
_, ok := ty.(*TypeMMX)
return ok
}
type TypeMetadata struct {
AbstractType
}
@ -169,11 +219,17 @@ type TypeMetadata struct {
func (this *TypeMetadata) LLString () string {
return "metadata"
}
func (this *TypeMetadata) String () string {
if this.Named() { return EncodeTypeName(this.Name()) }
return this.LLString()
}
func (this *TypeMetadata) Equals (ty Type) bool {
_, ok := ty.(*TypeMetadata)
return ok
}
type TypePointer struct {
AbstractType
AddressSpace AddressSpace
@ -192,6 +248,11 @@ func (this *TypePointer) String () string {
return this.LLString()
}
func (this *TypePointer) Equals (ty Type) bool {
_, ok := ty.(*TypePointer)
return ok
}
type TypeStruct struct {
AbstractType
Fields []Type
@ -221,6 +282,17 @@ func (this *TypeStruct) String () string {
return this.LLString()
}
func (this *TypeStruct) Equals (ty Type) bool {
if ty, ok := ty.(*TypeStruct); ok {
if len(this.Fields) != len(ty.Fields) { return false }
for index, field := range this.Fields {
if !TypesEqual(field, ty.Fields[index]) { return false }
}
return this.Packed == ty.Packed && this.Opaque == ty.Opaque
}
return false
}
type TypeToken struct {
AbstractType
}
@ -234,6 +306,11 @@ func (this *TypeToken) String () string {
return this.LLString()
}
func (this *TypeToken) Equals (ty Type) bool {
_, ok := ty.(*TypeToken)
return ok
}
type TypeVector struct {
AbstractType
Element Type
@ -254,6 +331,15 @@ func (this *TypeVector) String () string {
return this.LLString()
}
func (this *TypeVector) Equals (ty Type) bool {
if ty, ok := ty.(*TypeVector); ok {
return this.Length == ty.Length &&
this.Scalable == ty.Scalable &&
TypesEqual(this.Element, ty.Element)
}
return false
}
type TypeVoid struct {
AbstractType
}
@ -267,6 +353,11 @@ func (this *TypeVoid) String () string {
return this.LLString()
}
func (this *TypeVoid) Equals (ty Type) bool {
_, ok := ty.(*TypeVoid)
return ok
}
func aggregateElemType (t Type, indices []uint64) Type {
// Base case.
if len(indices) == 0 {
@ -281,3 +372,18 @@ func aggregateElemType (t Type, indices []uint64) Type {
panic(fmt.Errorf("support for aggregate type %T not yet implemented", t))
}
}
// TypesEqual checks if two types are equal to eachother, even if one or both
// are nil.
func TypesEqual (left, right Type) bool {
left = ReduceToBase(left)
right = ReduceToBase(right)
if (left == nil) != (right == nil) { return false }
if left == nil { return true }
return left.Equals(right)
}
func ReduceToBase (ty Type) Type {
if ty, ok := ty.(*TypeDefined); ok { return ty.Source }
return ty
}