Saturday, May 16, 2009

Implementing Equals in C#

In this post I present a way to implement equality between objects that is, in my opinion, correct, or at least correct with respect to a specific set of objectives. There are many ways to write Equals(), each targeting its own goals, and I do not believe that there is only one right way to do it.

To start with, let's state what the objectives are in this case:

  • Equality for immutable value objects.
  • Two objects are equal iff they are of exactly the same class, independently of the type of the variable that holds them.
  • Support inheritance and reuse comparison logic from base classes.
  • Support the == operator.
  • Reduce the programmer's workload as much as possible.
The code is below:

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Equal
{
    public class Util
    {
        // Returns true if a and b are both non null and are of exactly the same class
        public static bool AreSameClass(Object a, Object b)
        {
            return a != null && b != null && a.GetType().Equals(b.GetType());
        }
    }

    public class A
    {
        private readonly int a; // readonly: A is an immutable value object

        public A(int a)
        {
            this.a = a;
        }

        public override int GetHashCode()
        {
            return a;
        }

        public override bool Equals(object o)
        {
            return Util.AreSameClass(this, o) && this.EqualMembers((A)o);
        }

        protected bool EqualMembers(A o)
        {
            return a == o.a;
        }

        public static bool operator ==(A a, A b)
        {
            return object.Equals(a, b); // handle cases where a or b is null.
        }

        public static bool operator !=(A a, A b)
        {
            return !(a == b);
        }
    }

    public class B : A
    {
        private readonly int b; // readonly: B is an immutable value object

        public B(int a, int b)
            : base(a)
        {
            this.b = b;
        }

        public override int GetHashCode()
        {
            return base.GetHashCode() + b;
        }

        public override bool Equals(object o)
        {
            return Util.AreSameClass(this, o) && this.EqualMembers((B)o);
        }

        protected bool EqualMembers(B o)
        {
            // compare the members inherited from A, then the members of B
            return base.EqualMembers(o) && b == o.b;
        }
    }
}

The Util class is a helper that determines whether two objects are of exactly the same class. If at least one of them is null, the objects are considered not to be of the same class; as we will see later, null handling is done elsewhere. Then comes the base class A. This class follows good practices in terms of implementing GetHashCode(): two equal objects always return the same hash code.
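To make the "exactly the same class" rule concrete, here is a small sketch contrasting Util.AreSameClass with the is operator. Util is copied from the listing above so the snippet compiles on its own; the classes Base and Derived are hypothetical and exist only for this illustration.

```csharp
using System;

// Util copied from the article so that this sketch compiles on its own.
public class Util
{
    // Returns true if a and b are both non null and are of exactly the same class
    public static bool AreSameClass(Object a, Object b)
    {
        return a != null && b != null && a.GetType().Equals(b.GetType());
    }
}

// Hypothetical classes used only for this illustration.
public class Base { }
public class Derived : Base { }

public static class Program
{
    public static void Main()
    {
        object baseObj = new Base();
        object derivedObj = new Derived();

        // The "is" operator accepts subclasses...
        Console.WriteLine(derivedObj is Base);                            // True
        // ...whereas AreSameClass requires identical runtime types.
        Console.WriteLine(Util.AreSameClass(baseObj, derivedObj));        // False
        Console.WriteLine(Util.AreSameClass(new Derived(), derivedObj));  // True
        // A null on either side never counts as "the same class".
        Console.WriteLine(Util.AreSameClass(baseObj, null));              // False
        Console.WriteLine(Util.AreSameClass(null, null));                 // False
    }
}
```

This is why a B can never be equal to an A, even when the variable holding the B is declared as A.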

In addition:

  • The class A overrides A.Equals(object o). It first tests, using the Util class described above, whether this and the object it is compared to are of exactly the same class. If they are, it then calls the protected, non-virtual A.EqualMembers(A o) method to compare the member variables of class A pairwise.
  • A.EqualMembers(A o) compares two objects of type A with respect to their member variables. Each member is compared in the most appropriate way for its type: with Equals(), with the == operator, or with object.ReferenceEquals() where needed. A.EqualMembers(A o) is not meant to be redefined by subclasses. If that does happen, the redefinition should simply call base.EqualMembers(o).
  • Class A also implements the dual operators == and !=. The == operator calls the static object.Equals(a, b), which returns true if both references are identical (including when both are null), false if only one of them is null, and otherwise delegates to a.Equals(b). Nulls are therefore handled properly.
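A minimal sketch of how these operators behave with nulls, assuming class A exactly as in the first listing (with its field marked readonly). It also suggests why the static object.Equals(a, b) is used: writing a.Equals(b) inside the operator would throw a NullReferenceException when a is null, and writing a == b would call the operator itself recursively. The variable names in Main are illustrative.

```csharp
using System;

// Util and A as in the article's first listing (A's field marked readonly).
public class Util
{
    // Returns true if a and b are both non null and are of exactly the same class
    public static bool AreSameClass(Object a, Object b)
    {
        return a != null && b != null && a.GetType().Equals(b.GetType());
    }
}

public class A
{
    private readonly int a;

    public A(int a) { this.a = a; }

    public override int GetHashCode() { return a; }

    public override bool Equals(object o)
    {
        return Util.AreSameClass(this, o) && this.EqualMembers((A)o);
    }

    protected bool EqualMembers(A o) { return a == o.a; }

    public static bool operator ==(A a, A b)
    {
        // Static object.Equals: true if both references are the same (including
        // both null), false if exactly one is null, otherwise a.Equals(b).
        return object.Equals(a, b);
    }

    public static bool operator !=(A a, A b) { return !(a == b); }
}

public static class Program
{
    public static void Main()
    {
        A none = null;
        A one = new A(1);
        A otherOne = new A(1);

        Console.WriteLine(none == null);                     // True, no NullReferenceException
        Console.WriteLine(one == none);                      // False
        Console.WriteLine(one == otherOne);                  // True: value equality
        Console.WriteLine(one != new A(2));                  // True
        Console.WriteLine((object)one == (object)otherOne);  // False: reference comparison
    }
}
```

The last line shows that casting both operands to object bypasses the user-defined operator and falls back to reference comparison.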

The implementation of Equals() for the subclass B is no more complex:

  • Class B overrides GetHashCode(), which calls base.GetHashCode().
  • Class B overrides Equals(object o) the same way as class A does.
  • Class B implements EqualMembers(B o). This method first calls base.EqualMembers(o) to compare the member variables defined by class A, and then compares the member variables defined by class B.
  • Class B does not need to implement the == and != operators: the operators inherited from A end up calling object.Equals(), which dispatches to the virtual B.Equals(object o).
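The following sketch, assuming classes A and B as in the first listing (fields marked readonly, B.EqualMembers calling base.EqualMembers), checks that B really gets by with the inherited operators. It also observes GetHashCode() through a HashSet<A>; since A does not implement IEquatable<A>, the set relies on Equals(object) and GetHashCode(). The element values are illustrative.

```csharp
using System;
using System.Collections.Generic;

public class Util
{
    // Returns true if a and b are both non null and are of exactly the same class
    public static bool AreSameClass(Object a, Object b)
    {
        return a != null && b != null && a.GetType().Equals(b.GetType());
    }
}

public class A
{
    private readonly int a;

    public A(int a) { this.a = a; }

    public override int GetHashCode() { return a; }

    public override bool Equals(object o)
    {
        return Util.AreSameClass(this, o) && this.EqualMembers((A)o);
    }

    protected bool EqualMembers(A o) { return a == o.a; }

    public static bool operator ==(A a, A b) { return object.Equals(a, b); }

    public static bool operator !=(A a, A b) { return !(a == b); }
}

public class B : A
{
    private readonly int b;

    public B(int a, int b) : base(a) { this.b = b; }

    public override int GetHashCode() { return base.GetHashCode() + b; }

    public override bool Equals(object o)
    {
        return Util.AreSameClass(this, o) && this.EqualMembers((B)o);
    }

    // Compares the members inherited from A first, then the members of B.
    protected bool EqualMembers(B o) { return base.EqualMembers(o) && b == o.b; }
}

public static class Program
{
    public static void Main()
    {
        // B declares no operators: these calls use A's operators, which reach
        // the virtual B.Equals(object) through object.Equals(a, b).
        Console.WriteLine(new B(1, 2) == new B(1, 2));  // True
        Console.WriteLine(new B(1, 2) != new B(1, 3));  // True
        Console.WriteLine(new A(1) == new B(1, 2));     // False: different classes

        // The set keeps only one of the two equal B(1, 2) instances.
        var set = new HashSet<A> { new B(1, 2), new B(1, 2), new B(1, 3), new A(1) };
        Console.WriteLine(set.Count);                   // 3
    }
}
```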
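
One may ask why we define the non-virtual EqualMembers() methods, since the role of comparing objects in a type-safe way is usually given to IEquatable<T>.Equals(T other). The answer is simple: if these helpers were overloads named Equals, a call such as b1.Equals(a2), where a2 is declared as A, could be bound by the compiler to the overload Equals(A), which compares only the members defined by A. Two different objects of type B whose common parts (A) are the same would then be considered equal whenever the equality is tested from within a method of class B, where the protected overload is visible. In other words, equality would return different results depending on the caller's context.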

This is shown in the code below:


using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Equal
{
public class Util
{
  // Returns true if a and b are both non null and are of exactly the same class
  public static bool AreSameClass(Object a, Object b)
  {
      return a != null && b != null && a.GetType().Equals(b.GetType());
  }
}

public class A
{
  private readonly int a;

  public A(int a)
  {
      this.a = a;
  }

  public override int GetHashCode()
  {
      return a;
  }

  public override bool Equals(object o)
  {
      return Util.AreSameClass(this, o) && this.Equals((A)o);
  }

  protected bool Equals(A o)
  {
      return a == o.a;
  }

  public static bool operator ==(A a, A b)
  {
      return object.Equals(a, b); // handle cases where a or b is null.
  }

  public static bool operator !=(A a, A b)
  {
      return !(a == b);
  }
}

public class B : A
{
  private readonly int b;

  public B(int a, int b)
      : base(a)
  {
      this.b = b;
  }

  public override int GetHashCode()
  {
      return base.GetHashCode() + b;
  }

  public override bool Equals(object o)
  {
      return Util.AreSameClass(this, o) && this.Equals((B)o);
  }

  protected bool Equals(B o)
  {
      return base.Equals(o) && b == o.b;
  }

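  public static void test()
  {
      B b1 = new B(1, 2);
      B b2 = new B(1, 3);

      // implicit upcast: a2 refers to b2 through a variable of type A
      A a2 = b2;

      // returns true although b1 and b2 are different: from inside B, the
      // compiler binds this call to the protected A.Equals(A o), which only
      // compares the members defined by A
      bool r = b1.Equals(a2);
  }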
}
}
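The listing above only shows the call made from inside B. The following sketch, assuming the same counter-example classes (operators omitted, fields marked readonly), makes the context dependence explicit by performing the very same call b1.Equals(a2) once from a static method of B and once from an unrelated class. The methods CompareFromInside and CompareFromOutside are hypothetical, introduced only for this illustration.

```csharp
using System;

public class Util
{
    // Returns true if a and b are both non null and are of exactly the same class
    public static bool AreSameClass(Object a, Object b)
    {
        return a != null && b != null && a.GetType().Equals(b.GetType());
    }
}

// Counter-example: the type-specific helpers are protected overloads of Equals.
public class A
{
    private readonly int a;

    public A(int a) { this.a = a; }

    public override int GetHashCode() { return a; }

    public override bool Equals(object o)
    {
        return Util.AreSameClass(this, o) && this.Equals((A)o);
    }

    protected bool Equals(A o) { return a == o.a; }
}

public class B : A
{
    private readonly int b;

    public B(int a, int b) : base(a) { this.b = b; }

    public override int GetHashCode() { return base.GetHashCode() + b; }

    public override bool Equals(object o)
    {
        return Util.AreSameClass(this, o) && this.Equals((B)o);
    }

    protected bool Equals(B o) { return base.Equals(o) && b == o.b; }

    // Hypothetical helper. Inside B the protected overload Equals(A) is accessible,
    // so the compiler binds x.Equals(y) to it: only the A part is compared.
    public static bool CompareFromInside(B x, A y) { return x.Equals(y); }
}

public static class Program
{
    // Hypothetical helper. Outside the hierarchy only Equals(object) is visible,
    // so the call is dispatched virtually to B.Equals(object).
    public static bool CompareFromOutside(B x, A y) { return x.Equals(y); }

    public static void Main()
    {
        B b1 = new B(1, 2);
        A a2 = new B(1, 3);

        Console.WriteLine(B.CompareFromInside(b1, a2));  // True, although b1 and a2 differ
        Console.WriteLine(CompareFromOutside(b1, a2));   // False
    }
}
```

Naming the helpers EqualMembers, as in the first listing, removes the Equals(A) overload from the candidates, so both call sites end up in the virtual Equals(object) and agree.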

To finish, here is a set of test cases:

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Equal
{
class Program
{
 static void Main(string[] args)
 {
     B b1 = new B(1, 2);
     B b2 = new B(1, 3);
     B b3 = new B(1, 2);

     // casts
     A a1 = b1;
     A a2 = b2;
     A a3 = b3;

     object o1 = b1;
     object o2 = b2;
     object o3 = b3;

     bool r = true;

     r &= b1.Equals(b1);
     r &= b1.Equals(a1);
     r &= b1.Equals(o1);

     r &= a1.Equals(b1);
     r &= a1.Equals(a1);
     r &= a1.Equals(o1);

     r &= o1.Equals(b1);
     r &= o1.Equals(a1);
     r &= o1.Equals(o1);

     r &= b1 == b1;
     r &= b1 == a1;
     r &= (b1 == o1);  // reference equality

     r &= a1 == b1;
     r &= a1 == a1;
     r &= (a1 == o1);  // reference equality

     r &= (o1 == b1);  // reference equality
     r &= (o1 == a1);  // reference equality
     r &= (o1 == o1);  // reference equality


     r &= !(b1.Equals(b2));
     r &= !(b1.Equals(a2));
     r &= !(b1.Equals(o2));

     r &= !(a1.Equals(b2));
     r &= !(a1.Equals(a2));
     r &= !(a1.Equals(o2));

     r &= !(o1.Equals(b2));
     r &= !(o1.Equals(a2));
     r &= !(o1.Equals(o2));

     r &= !(b1 == b2);
     r &= !(b1 == a2);
     r &= !(b1 == o2);  // reference equality

     r &= !(a1 == b2);
     r &= !(a1 == a2);
     r &= !(a1 == o2);  // reference equality

     r &= !(o1 == b2);  // reference equality
     r &= !(o1 == a2);  // reference equality
     r &= !(o1 == o2);  // reference equality


     r &= b1.Equals(b3);
     r &= b1.Equals(a3);
     r &= b1.Equals(o3);

     r &= a1.Equals(b3);
     r &= a1.Equals(a3);
     r &= a1.Equals(o3);

     r &= o1.Equals(b3);
     r &= o1.Equals(a3);
     r &= o1.Equals(o3);

     r &= b1 == b3;
     r &= b1 == a3;
     r &= !(b1 == o3);  // reference equality

     r &= a1 == b3;
     r &= a1 == a3;
     r &= !(a1 == o3);  // reference equality

     r &= !(o1 == b3);  // reference equality
     r &= !(o1 == a3);  // reference equality
     r &= !(o1 == o3);  // reference equality

     Console.WriteLine(r ? "All tests passed." : "Some tests failed.");
 }
}
}