Skip to content

Discriminated Union code-generation [even more improvements!]

Compare
Choose a tag to compare
@louthy louthy released this 10 Nov 14:46

Continuing from the two releases [1],[2] this weekend relating to the new discriminated-union feature of language-ext...

There is now support for creating unions from abstract classes. Although this is slightly less terse than using interfaces, there is a major benefit: classes can contain operators and so the equality and ordering operators can be automatically generated.

So, as well as being able to create unions from interfaces like so:

    // Union declared in interface form. Each method names one case of the
    // union; its parameters describe the data that case carries.
    [Union]
    public interface Shape
    {
        // Rectangle case: two float values, width and length.
        Shape Rectangle(float width, float length);
        // Circle case: a single float radius.
        Shape Circle(float radius);
        // Prism case: two float values, width and height.
        Shape Prism(float width, float height);
    }

You can now additionally create them from an abstract partial class like so:

    // Union declared in abstract-partial-class form. Each abstract method names
    // one case of the union; its parameters describe the data that case carries.
    // The generated _ShapeBase class shown further down overrides these methods,
    // so the case classes do not have to implement them.
    [Union]
    public abstract partial class Shape
    {
        // Rectangle case: two float values, width and length.
        public abstract Shape Rectangle(float width, float length);
        // Circle case: a single float radius.
        public abstract Shape Circle(float radius);
        // Prism case: two float values, width and height.
        public abstract Shape Prism(float width, float height);
    }

Which allows for:

    // Two circles built with the same radius, and a third with a smaller one.
    // ShapeCon is the generated static class of case constructors.
    Shape shape1 = ShapeCon.Circle(100);
    Shape shape2 = ShapeCon.Circle(100);
    Shape shape3 = ShapeCon.Circle(50);

    // == and > are the operators generated on the abstract Shape class:
    // == goes through Equals (compares the case data), > goes through CompareTo.
    Assert.True(shape1 == shape2);  
    Assert.False(shape2 == shape3);  
    Assert.True(shape2 > shape3);  

Case classes are now sealed rather than partial. Making them partial opens the door to the addition of fields and properties, which could compromise the case-type, so extension methods are the best way of adding functionality to the case-types.

To make all of this work with abstract classes I needed to remove the inheritance of Record<CASE_TYPE> from each union case, and so now the generated code does the work of the Record type at compile-time rather than at run-time. It's led to a slight explosion in the amount of generated code, but I guess it shows how hard it is to do this manually!

    /// <summary>
    /// Generated case class for the Rectangle case of the Shape union. It holds
    /// the case's two float fields and spells out equality, ordering, hashing,
    /// deconstruction, serialization helpers, a With method and lenses.
    /// </summary>
    [System.Serializable]
    public sealed class Rectangle : _ShapeBase, System.IEquatable<Rectangle>, System.IComparable<Rectangle>, System.IComparable
    {
        public readonly float Width;
        public readonly float Length;
        // Case discriminator; CompareTo(Shape) orders different cases by this value.
        public override int _Tag => 1;
        // Stores the supplied values in the read-only fields.
        public Rectangle(float width, float length)
        {
            this.Width = width;
            this.Length = length;
        }

        // Copies both fields into the out parameters (enables deconstruction syntax).
        public void Deconstruct(out float Width, out float Length)
        {
            Width = this.Width;
            Length = this.Length;
        }

        // Rebuilds the fields from the values stored under the keys "Width" and "Length".
        public Rectangle(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context)
        {
            Width = (float)info.GetValue("Width", typeof(float));
            Length = (float)info.GetValue("Length", typeof(float));
        }

        // Writes both fields under the same keys the constructor above reads.
        // NOTE(review): ISerializable is not in the base list visible here, so
        // confirm that a serializer actually calls this pair - TODO verify.
        public void GetObjectData(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context)
        {
            info.AddValue("Width", Width);
            info.AddValue("Length", Length);
        }

        // Null-safe operators. The ReferenceEquals checks deal with identical and
        // null operands before Equals/CompareTo is invoked on x. A null left
        // operand is treated as less than a non-null right operand.
        public static bool operator ==(Rectangle x, Rectangle y) => ReferenceEquals(x, y) || (x?.Equals(y) ?? false);
        public static bool operator !=(Rectangle x, Rectangle y) => !(x == y);
        public static bool operator>(Rectangle x, Rectangle y) => !ReferenceEquals(x, y) && !ReferenceEquals(x, null) && x.CompareTo(y) > 0;
        public static bool operator <(Rectangle x, Rectangle y) => !ReferenceEquals(x, y) && (ReferenceEquals(x, null) && !ReferenceEquals(y, null) || x.CompareTo(y) < 0);
        public static bool operator >=(Rectangle x, Rectangle y) => ReferenceEquals(x, y) || (!ReferenceEquals(x, null) && x.CompareTo(y) >= 0);
        public static bool operator <=(Rectangle x, Rectangle y) => ReferenceEquals(x, y) || (ReferenceEquals(x, null) && !ReferenceEquals(y, null) || x.CompareTo(y) <= 0);
        // Field-by-field equality via EqDefault<float>; returns false when
        // isnull(other) reports true.
        public bool Equals(Rectangle other)
        {
            if (LanguageExt.Prelude.isnull(other))
                return false;
            if (!default(EqDefault<float>).Equals(Width, other.Width))
                return false;
            if (!default(EqDefault<float>).Equals(Length, other.Length))
                return false;
            return true;
        }

        // Both overrides are true only when the argument is a Rectangle with equal fields.
        public override bool Equals(object obj) => obj is Rectangle tobj && Equals(tobj);
        public override bool Equals(Shape obj) => obj is Rectangle tobj && Equals(tobj);
        // Anything that is not a Shape (including null) yields 1.
        public override int CompareTo(object obj) => obj is Shape p ? CompareTo(p) : 1;
        // Same case: compare the field data. Null: returns 1. Different case: order by _Tag.
        public override int CompareTo(Shape obj) => obj is Rectangle tobj ? CompareTo(tobj) : obj is null ? 1 : _Tag.CompareTo(obj._Tag);
        // Compares Width first, then Length, via OrdDefault<float>; the first
        // non-zero result is returned.
        public int CompareTo(Rectangle other)
        {
            if (LanguageExt.Prelude.isnull(other))
                return 1;
            int cmp = 0;
            cmp = default(OrdDefault<float>).Compare(Width, other.Width);
            if (cmp != 0)
                return cmp;
            cmp = default(OrdDefault<float>).Compare(Length, other.Length);
            if (cmp != 0)
                return cmp;
            return 0;
        }

        // Combines the field hashes: each one is XOR-ed into the running state,
        // which is then multiplied by fnvPrime. The unchecked block lets the
        // multiplication wrap instead of throwing on overflow.
        public override int GetHashCode()
        {
            const int fnvOffsetBasis = -2128831035;
            const int fnvPrime = 16777619;
            int state = fnvOffsetBasis;
            unchecked
            {
                state = (default(EqDefault<float>).GetHashCode(Width) ^ state) * fnvPrime;
                state = (default(EqDefault<float>).GetHashCode(Length) ^ state) * fnvPrime;
            }

            return state;
        }

        // Renders as "Rectangle(Width: <w>, Length: <l>)"; a field for which
        // isnull returns true is shown as "[null]" instead of its value.
        public override string ToString()
        {
            var sb = new StringBuilder();
            sb.Append("Rectangle(");
            sb.Append(LanguageExt.Prelude.isnull(Width) ? $"Width: [null]" : $"Width: {Width}");
            sb.Append($", ");
            sb.Append(LanguageExt.Prelude.isnull(Length) ? $"Length: [null]" : $"Length: {Length}");
            sb.Append(")");
            return sb.ToString();
        }

        // Returns a new Rectangle; any argument left as null keeps this instance's value.
        public Rectangle With(float? Width = null, float? Length = null) => new Rectangle(Width ?? this.Width, Length ?? this.Length);
        // One lens per field: the first function reads the field, the second
        // produces an updated copy through With.
        public static readonly Lens<Rectangle, float> width = Lens<Rectangle, float>.New(_x => _x.Width, _x => _y => _y.With(Width: _x));
        public static readonly Lens<Rectangle, float> length = Lens<Rectangle, float>.New(_x => _x.Length, _x => _y => _y.With(Length: _x));
    }

    /// <summary>
    /// Generated case class for the Circle case of the Shape union. It holds the
    /// case's single float field and spells out equality, ordering, hashing,
    /// deconstruction, serialization helpers, a With method and a lens.
    /// </summary>
    [System.Serializable]
    public sealed class Circle : _ShapeBase, System.IEquatable<Circle>, System.IComparable<Circle>, System.IComparable
    {
        public readonly float Radius;
        // Case discriminator; CompareTo(Shape) orders different cases by this value.
        public override int _Tag => 2;
        // Stores the supplied value in the read-only field.
        public Circle(float radius)
        {
            this.Radius = radius;
        }

        // Copies the field into the out parameter (enables deconstruction syntax).
        public void Deconstruct(out float Radius)
        {
            Radius = this.Radius;
        }

        // Rebuilds the field from the value stored under the key "Radius".
        public Circle(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context)
        {
            Radius = (float)info.GetValue("Radius", typeof(float));
        }

        // Writes the field under the same key the constructor above reads.
        // NOTE(review): ISerializable is not in the base list visible here, so
        // confirm that a serializer actually calls this pair - TODO verify.
        public void GetObjectData(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context)
        {
            info.AddValue("Radius", Radius);
        }

        // Null-safe operators. The ReferenceEquals checks deal with identical and
        // null operands before Equals/CompareTo is invoked on x. A null left
        // operand is treated as less than a non-null right operand.
        public static bool operator ==(Circle x, Circle y) => ReferenceEquals(x, y) || (x?.Equals(y) ?? false);
        public static bool operator !=(Circle x, Circle y) => !(x == y);
        public static bool operator>(Circle x, Circle y) => !ReferenceEquals(x, y) && !ReferenceEquals(x, null) && x.CompareTo(y) > 0;
        public static bool operator <(Circle x, Circle y) => !ReferenceEquals(x, y) && (ReferenceEquals(x, null) && !ReferenceEquals(y, null) || x.CompareTo(y) < 0);
        public static bool operator >=(Circle x, Circle y) => ReferenceEquals(x, y) || (!ReferenceEquals(x, null) && x.CompareTo(y) >= 0);
        public static bool operator <=(Circle x, Circle y) => ReferenceEquals(x, y) || (ReferenceEquals(x, null) && !ReferenceEquals(y, null) || x.CompareTo(y) <= 0);
        // Equality on Radius via EqDefault<float>; returns false when
        // isnull(other) reports true.
        public bool Equals(Circle other)
        {
            if (LanguageExt.Prelude.isnull(other))
                return false;
            if (!default(EqDefault<float>).Equals(Radius, other.Radius))
                return false;
            return true;
        }

        // Both overrides are true only when the argument is a Circle with an equal field.
        public override bool Equals(object obj) => obj is Circle tobj && Equals(tobj);
        public override bool Equals(Shape obj) => obj is Circle tobj && Equals(tobj);
        // Anything that is not a Shape (including null) yields 1.
        public override int CompareTo(object obj) => obj is Shape p ? CompareTo(p) : 1;
        // Same case: compare the field data. Null: returns 1. Different case: order by _Tag.
        public override int CompareTo(Shape obj) => obj is Circle tobj ? CompareTo(tobj) : obj is null ? 1 : _Tag.CompareTo(obj._Tag);
        // Compares Radius via OrdDefault<float>; returns 1 when isnull(other) reports true.
        public int CompareTo(Circle other)
        {
            if (LanguageExt.Prelude.isnull(other))
                return 1;
            int cmp = 0;
            cmp = default(OrdDefault<float>).Compare(Radius, other.Radius);
            if (cmp != 0)
                return cmp;
            return 0;
        }

        // XORs the field hash into the starting state and multiplies by
        // fnvPrime. The unchecked block lets the multiplication wrap instead of
        // throwing on overflow.
        public override int GetHashCode()
        {
            const int fnvOffsetBasis = -2128831035;
            const int fnvPrime = 16777619;
            int state = fnvOffsetBasis;
            unchecked
            {
                state = (default(EqDefault<float>).GetHashCode(Radius) ^ state) * fnvPrime;
            }

            return state;
        }

        // Renders as "Circle(Radius: <r>)"; if isnull returns true for the field
        // it is shown as "[null]" instead of its value.
        public override string ToString()
        {
            var sb = new StringBuilder();
            sb.Append("Circle(");
            sb.Append(LanguageExt.Prelude.isnull(Radius) ? $"Radius: [null]" : $"Radius: {Radius}");
            sb.Append(")");
            return sb.ToString();
        }

        // Returns a new Circle; an argument left as null keeps this instance's value.
        public Circle With(float? Radius = null) => new Circle(Radius ?? this.Radius);
        // Lens for the field: the first function reads it, the second produces
        // an updated copy through With.
        public static readonly Lens<Circle, float> radius = Lens<Circle, float>.New(_x => _x.Radius, _x => _y => _y.With(Radius: _x));
    }

    /// <summary>
    /// Generated case class for the Prism case of the Shape union. It holds the
    /// case's two float fields and spells out equality, ordering, hashing,
    /// deconstruction, serialization helpers, a With method and lenses.
    /// </summary>
    [System.Serializable]
    public sealed class Prism : _ShapeBase, System.IEquatable<Prism>, System.IComparable<Prism>, System.IComparable
    {
        public readonly float Width;
        public readonly float Height;
        // Case discriminator; CompareTo(Shape) orders different cases by this value.
        public override int _Tag => 3;
        // Stores the supplied values in the read-only fields.
        public Prism(float width, float height)
        {
            this.Width = width;
            this.Height = height;
        }

        // Copies both fields into the out parameters (enables deconstruction syntax).
        public void Deconstruct(out float Width, out float Height)
        {
            Width = this.Width;
            Height = this.Height;
        }

        // Rebuilds the fields from the values stored under the keys "Width" and "Height".
        public Prism(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context)
        {
            Width = (float)info.GetValue("Width", typeof(float));
            Height = (float)info.GetValue("Height", typeof(float));
        }

        // Writes both fields under the same keys the constructor above reads.
        // NOTE(review): ISerializable is not in the base list visible here, so
        // confirm that a serializer actually calls this pair - TODO verify.
        public void GetObjectData(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context)
        {
            info.AddValue("Width", Width);
            info.AddValue("Height", Height);
        }

        // Null-safe operators. The ReferenceEquals checks deal with identical and
        // null operands before Equals/CompareTo is invoked on x. A null left
        // operand is treated as less than a non-null right operand.
        public static bool operator ==(Prism x, Prism y) => ReferenceEquals(x, y) || (x?.Equals(y) ?? false);
        public static bool operator !=(Prism x, Prism y) => !(x == y);
        public static bool operator>(Prism x, Prism y) => !ReferenceEquals(x, y) && !ReferenceEquals(x, null) && x.CompareTo(y) > 0;
        public static bool operator <(Prism x, Prism y) => !ReferenceEquals(x, y) && (ReferenceEquals(x, null) && !ReferenceEquals(y, null) || x.CompareTo(y) < 0);
        public static bool operator >=(Prism x, Prism y) => ReferenceEquals(x, y) || (!ReferenceEquals(x, null) && x.CompareTo(y) >= 0);
        public static bool operator <=(Prism x, Prism y) => ReferenceEquals(x, y) || (ReferenceEquals(x, null) && !ReferenceEquals(y, null) || x.CompareTo(y) <= 0);
        // Field-by-field equality via EqDefault<float>; returns false when
        // isnull(other) reports true.
        public bool Equals(Prism other)
        {
            if (LanguageExt.Prelude.isnull(other))
                return false;
            if (!default(EqDefault<float>).Equals(Width, other.Width))
                return false;
            if (!default(EqDefault<float>).Equals(Height, other.Height))
                return false;
            return true;
        }

        // Both overrides are true only when the argument is a Prism with equal fields.
        public override bool Equals(object obj) => obj is Prism tobj && Equals(tobj);
        public override bool Equals(Shape obj) => obj is Prism tobj && Equals(tobj);
        // Anything that is not a Shape (including null) yields 1.
        public override int CompareTo(object obj) => obj is Shape p ? CompareTo(p) : 1;
        // Same case: compare the field data. Null: returns 1. Different case: order by _Tag.
        public override int CompareTo(Shape obj) => obj is Prism tobj ? CompareTo(tobj) : obj is null ? 1 : _Tag.CompareTo(obj._Tag);
        // Compares Width first, then Height, via OrdDefault<float>; the first
        // non-zero result is returned.
        public int CompareTo(Prism other)
        {
            if (LanguageExt.Prelude.isnull(other))
                return 1;
            int cmp = 0;
            cmp = default(OrdDefault<float>).Compare(Width, other.Width);
            if (cmp != 0)
                return cmp;
            cmp = default(OrdDefault<float>).Compare(Height, other.Height);
            if (cmp != 0)
                return cmp;
            return 0;
        }

        // Combines the field hashes: each one is XOR-ed into the running state,
        // which is then multiplied by fnvPrime. The unchecked block lets the
        // multiplication wrap instead of throwing on overflow.
        public override int GetHashCode()
        {
            const int fnvOffsetBasis = -2128831035;
            const int fnvPrime = 16777619;
            int state = fnvOffsetBasis;
            unchecked
            {
                state = (default(EqDefault<float>).GetHashCode(Width) ^ state) * fnvPrime;
                state = (default(EqDefault<float>).GetHashCode(Height) ^ state) * fnvPrime;
            }

            return state;
        }

        // Renders as "Prism(Width: <w>, Height: <h>)"; a field for which isnull
        // returns true is shown as "[null]" instead of its value.
        public override string ToString()
        {
            var sb = new StringBuilder();
            sb.Append("Prism(");
            sb.Append(LanguageExt.Prelude.isnull(Width) ? $"Width: [null]" : $"Width: {Width}");
            sb.Append($", ");
            sb.Append(LanguageExt.Prelude.isnull(Height) ? $"Height: [null]" : $"Height: {Height}");
            sb.Append(")");
            return sb.ToString();
        }

        // Returns a new Prism; any argument left as null keeps this instance's value.
        public Prism With(float? Width = null, float? Height = null) => new Prism(Width ?? this.Width, Height ?? this.Height);
        // One lens per field: the first function reads the field, the second
        // produces an updated copy through With.
        public static readonly Lens<Prism, float> width = Lens<Prism, float>.New(_x => _x.Width, _x => _y => _y.With(Width: _x));
        public static readonly Lens<Prism, float> height = Lens<Prism, float>.New(_x => _x.Height, _x => _y => _y.With(Height: _x));
    }

    /// <summary>
    /// Generated static constructor functions, one per union case. Each creates
    /// the matching case class and returns it typed as Shape.
    /// </summary>
    public static partial class ShapeCon
    {
        public static Shape Rectangle(float width, float length) => new Rectangle(width, length);
        public static Shape Circle(float radius) => new Circle(radius);
        public static Shape Prism(float width, float height) => new Prism(width, height);
    }

    /// <summary>
    /// Generated half of the partial Shape class. It declares the abstract
    /// members every case class overrides (_Tag, CompareTo, Equals) and defines
    /// the equality and ordering operators in terms of them.
    /// </summary>
    [System.Serializable]
    public abstract partial class Shape : IEquatable<Shape>, IComparable<Shape>, IComparable
    {
        // Per-case discriminator; the case classes use it to order values of
        // different cases.
        public abstract int _Tag
        {
            get;
        }

        public abstract int CompareTo(object obj);
        public abstract int CompareTo(Shape other);
        public abstract bool Equals(Shape other);
        // Forwards to Equals(Shape) when obj is a Shape; otherwise false.
        public override bool Equals(object obj) => obj is Shape tobj && Equals(tobj);
        // Throws here; each generated case class supplies its own override.
        public override int GetHashCode() => throw new System.NotSupportedException();
        // Null-safe operators. The ReferenceEquals checks deal with identical and
        // null operands before Equals/CompareTo is invoked on x. A null left
        // operand is treated as less than a non-null right operand.
        public static bool operator ==(Shape x, Shape y) => ReferenceEquals(x, y) || (x?.Equals(y) ?? false);
        public static bool operator !=(Shape x, Shape y) => !(x == y);
        public static bool operator>(Shape x, Shape y) => !ReferenceEquals(x, y) && !ReferenceEquals(x, null) && x.CompareTo(y) > 0;
        public static bool operator <(Shape x, Shape y) => !ReferenceEquals(x, y) && (ReferenceEquals(x, null) && !ReferenceEquals(y, null) || x.CompareTo(y) < 0);
        public static bool operator >=(Shape x, Shape y) => ReferenceEquals(x, y) || (!ReferenceEquals(x, null) && x.CompareTo(y) >= 0);
        public static bool operator <=(Shape x, Shape y) => ReferenceEquals(x, y) || (ReferenceEquals(x, null) && !ReferenceEquals(y, null) || x.CompareTo(y) <= 0);
    }

    /// <summary>
    /// Generated base class for the case classes. It overrides the abstract
    /// case-declaring methods of the user-written Shape class so the case
    /// classes do not have to; calling any of them throws NotSupportedException.
    /// </summary>
    public abstract partial class _ShapeBase : Shape
    {
        public override Shape Rectangle(float width, float length) => throw new NotSupportedException();
        public override Shape Circle(float radius) => throw new NotSupportedException();
        public override Shape Prism(float width, float height) => throw new NotSupportedException();
    }

This will soon be adapted to support a [Record] attribute for generating records at compile-time, removing the need to derive from Record<TYPE>. That will mean struct records will be easy to create.