
The Function Object pattern

2 March 2016

While playing with Homework 5 of Sedgewick's Algorithms course, I got fed up with trying to make iterative tree walks work, and instead wrote all of them as pairs of mutually recursive functions.

I then abstracted these signatures into an interface, and discovered/re-invented the Visitor pattern. (Note [Feb 2017]: I later learned of the research on visitors and catamorphisms, see 1 2 3.)

Here I try to show how the pattern simply arises if you write Java in a principled manner (i.e. encoding constraints and structure into your data types).

Problem

A two-dimensional k-d tree is a search tree that stores points in \(\mathbb{R}^2\), organising them by partitioning the plane at each node. There are two kinds of partitions (and correspondingly, two kinds of nodes), vertical and horizontal, and they alternate by level.

In Haskell, this constraint can be expressed as

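type Point  = (Double, Double)
data VNode  = VLeaf | VNode Point HNode HNode   -- vertical split; children split horizontally
data HNode  = HLeaf | HNode Point VNode VNode   -- horizontal split; children split vertically
type KdTree = VNode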

But in Java you would simply write

class Node {
  private Point2D point;
  private Node lt, rt;
  private boolean partition;    // true: vertical, false: horizontal
}

This approach strikes me as unsafe and somewhat gross: nothing in the types stops two vertical nodes from ending up on adjacent levels.

I want to encode the constraints directly into the types [1] (à la the Haskell type above), so that the type system tells me when I inevitably mess it up.

So here's the type I ended up with:

abstract class Node {
  protected Point2D point;
}

final class VertNode extends Node {
  private HorizNode lt, rt;
}

final class HorizNode extends Node {
  private VertNode lt, rt;
}
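To make the point concrete, here's a minimal, compilable sketch of those types with constructors added. It assumes a hypothetical stand-in class Pt in place of algs4's Point2D (so it compiles without the course library), and the child fields are package-private rather than private so the small, hypothetical TypedNodesDemo helper can wire nodes together directly. The commented-out line is the kind of mistake the compiler now rejects.

```java
// Hypothetical stand-in for algs4's Point2D, so the sketch compiles on its own.
final class Pt {
  final double x, y;
  Pt(double x, double y) { this.x = x; this.y = y; }
}

abstract class Node {
  protected final Pt point;
  protected Node(Pt point) { this.point = point; }
}

// A vertical node may only have horizontal children...
final class VertNode extends Node {
  HorizNode lt, rt;
  VertNode(Pt p) { super(p); }
}

// ...and a horizontal node may only have vertical children.
final class HorizNode extends Node {
  VertNode lt, rt;
  HorizNode(Pt p) { super(p); }
}

final class TypedNodesDemo {
  static VertNode sample() {
    VertNode root = new VertNode(new Pt(0.5, 0.5));
    root.lt = new HorizNode(new Pt(0.2, 0.3));
    root.lt.rt = new VertNode(new Pt(0.1, 0.9));
    // root.lt = new VertNode(new Pt(0.2, 0.3));  // rejected: incompatible types
    return root;
  }
}
```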

Iterative traversal

The natural thing is to then write the traversal functions in an iterative style. This is Java, after all.

public void insert(Point2D p) {
  if (root == null) {
    root = new VertNode(p);
    return;
  }
  else {
    Node cur = root;
    while (true) {
      if (cur instanceof VertNode) {
        if (p.equals(cur.point)) return;

        else if (p.x() < cur.point.x()) {
          if (cur.lt == null) {
            cur.lt = new HorizNode(p);
            return;
          } else cur = cur.lt;
        }

        else {
          if (cur.rt == null) {
            cur.rt = new HorizNode(p);
            return;
          } else cur = cur.rt;
        }

      } else if (cur instanceof HorizNode) {
            /* Case is symmetric */
      }
    }
  }
}

There are a few problems with this approach.

Firstly, you have to repeat this loop for every tree traversal. It would be nice to have the basic structure captured somewhere so every traversal function has a starting point.

But more concerning is that this doesn't actually compile, because lt and rt aren't defined in the Node abstract class. I ended up having to define abstract getters and setters there to get around this problem when accessing the children of Node cur.

But the setters don't take VertNodes or HorizNodes, only generic Nodes, so I have to do checks like this:

final class HorizNode extends Node {
  // ...
  @Override
  protected void setLt(Node lt) {
    if (lt instanceof VertNode)
      this.lt = (VertNode) lt;   // downcast, checked at runtime
    else
      throw new IllegalArgumentException("the child of a HorizNode must be a VertNode");
  }
}

This is bad because (a) it breaks static type safety (the argument is implicitly upcast to Node and has to be explicitly downcast again), so (b) it can fail at runtime.
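Here's a stripped-down sketch of that failure mode (points omitted; SetterDemo is a hypothetical helper). Both setLt calls type-check, because any node is accepted as a Node; the malformed one is only caught when the instanceof check throws at runtime.

```java
abstract class Node {
  // The abstract setter can only promise "some Node", not the right kind of Node.
  protected abstract void setLt(Node lt);
}

final class VertNode extends Node {
  HorizNode lt;
  @Override
  protected void setLt(Node lt) {
    if (lt instanceof HorizNode) this.lt = (HorizNode) lt;
    else throw new IllegalArgumentException("the child of a VertNode must be a HorizNode");
  }
}

final class HorizNode extends Node {
  VertNode lt;
  @Override
  protected void setLt(Node lt) {
    if (lt instanceof VertNode) this.lt = (VertNode) lt;
    else throw new IllegalArgumentException("the child of a HorizNode must be a VertNode");
  }
}

final class SetterDemo {
  // Returns true if the malformed call is rejected (which only happens at runtime).
  static boolean wrongChildRejected() {
    HorizNode h = new HorizNode();
    try {
      h.setLt(new HorizNode());  // compiles fine: HorizNode is upcast to Node
      return false;
    } catch (IllegalArgumentException e) {
      return true;
    }
  }
}
```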

Ideally, the setter ought to be precise enough that the compiler catches any attempt to make a HorizNode the child of a HorizNode, so I don't have to throw exceptions.

Sadly, there's no way to express this in the abstract class. You can overload the setter, but then every subclass has to implement both overloads, and one of them will always throw an exception.

So setters are out.

Mutual recursion

The problem stems from the fact that the cur pointer is upcast to Node in order to accommodate both node types, which forces us into instanceof checks, getters and setters, and all that business.

If we write the function recursively instead, we never need to store that pointer, and we avoid all the headaches that come with it.

private void insert(VertNode node, Point2D p) {
  if (node == null) return;
  if (p.equals(node.point)) return;

  else if (p.x() < node.point.x()) {
    if (node.lt == null) {
      node.lt = new HorizNode(p);
      return;
    } else insert(node.lt, p);
  }
  else {
    if (node.rt == null) {
      node.rt = new HorizNode(p);
      return;
    } else insert(node.rt, p);
  }
}

private void insert(HorizNode node, Point2D p) {
  /* Case is symmetric */
}

No casting here!
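Filling in the symmetric case, a complete version of the recursive insert might look like the sketch below. Pt is a hypothetical stand-in for Point2D, and the size counter and contains lookup are additions for illustration. Each recursive call is resolved by ordinary overloading: node.lt has a precise static type, so the compiler picks the right helper without any cast.

```java
// Hypothetical stand-in for algs4's Point2D.
final class Pt {
  final double x, y;
  Pt(double x, double y) { this.x = x; this.y = y; }
  boolean sameAs(Pt o) { return x == o.x && y == o.y; }
}

abstract class Node {
  final Pt point;
  Node(Pt point) { this.point = point; }
}
final class VertNode extends Node { HorizNode lt, rt; VertNode(Pt p) { super(p); } }
final class HorizNode extends Node { VertNode lt, rt; HorizNode(Pt p) { super(p); } }

final class KdTree {
  private VertNode root;
  private int size;

  int size() { return size; }

  void insert(Pt p) {
    if (root == null) { root = new VertNode(p); size = 1; }
    else insert(root, p);
  }

  // Vertical node: compare x-coordinates; children are HorizNodes.
  private void insert(VertNode node, Pt p) {
    if (p.sameAs(node.point)) return;               // ignore duplicates
    if (p.x < node.point.x) {
      if (node.lt == null) { node.lt = new HorizNode(p); size++; }
      else insert(node.lt, p);                      // resolves to insert(HorizNode, Pt)
    } else {
      if (node.rt == null) { node.rt = new HorizNode(p); size++; }
      else insert(node.rt, p);
    }
  }

  // Horizontal node: the symmetric case, comparing y-coordinates.
  private void insert(HorizNode node, Pt p) {
    if (p.sameAs(node.point)) return;
    if (p.y < node.point.y) {
      if (node.lt == null) { node.lt = new VertNode(p); size++; }
      else insert(node.lt, p);                      // resolves to insert(VertNode, Pt)
    } else {
      if (node.rt == null) { node.rt = new VertNode(p); size++; }
      else insert(node.rt, p);
    }
  }

  boolean contains(Pt p) { return contains(root, p); }

  private boolean contains(VertNode node, Pt p) {
    if (node == null) return false;
    if (p.sameAs(node.point)) return true;
    HorizNode next = p.x < node.point.x ? node.lt : node.rt;
    return contains(next, p);
  }

  private boolean contains(HorizNode node, Pt p) {
    if (node == null) return false;
    if (p.sameAs(node.point)) return true;
    VertNode next = p.y < node.point.y ? node.lt : node.rt;
    return contains(next, p);
  }
}
```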

Extracting an interface

Now we can tackle the other problem and give tree traversal functions a common skeleton. If you squint at the pair of insert helpers, you'll spot that together they perform a pattern match on the type of the node.

So we can extract an interface:

interface TreeFunction {
  void match (HorizNode node);
  void match (VertNode node);
}

Now, every tree traversal function pair becomes a class. I call this a Function Object, because it's a function written as an object.

Next we need a way to apply the function, so each node type implements another interface:

interface TreeApply {
  void apply (TreeFunction function);
}

The implementation [2] of TreeApply is simple; it's always

@Override
public void apply(TreeFunction function) {
  function.match(this);
}

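Calling apply on a node starts the traversal at that node. Inside each node class, this has that class's static type, so the compiler selects the matching overload of match.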
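Here's a small, self-contained sketch of how the two interfaces fit together, with a hypothetical CountNodes function object (Pt again stands in for Point2D). Because TreeFunction returns void, the count is accumulated in a field. Note that apply still picks the right overload when called through a plain Node reference: the virtual call selects the node class, and the overload inside that class selects the match method. That's the double dispatch the Wikipedia definition mentions.

```java
// Hypothetical stand-in for algs4's Point2D.
final class Pt {
  final double x, y;
  Pt(double x, double y) { this.x = x; this.y = y; }
}

interface TreeFunction {
  void match(HorizNode node);
  void match(VertNode node);
}

interface TreeApply {
  void apply(TreeFunction function);
}

abstract class Node implements TreeApply {
  final Pt point;
  Node(Pt point) { this.point = point; }
}

final class VertNode extends Node {
  HorizNode lt, rt;
  VertNode(Pt p) { super(p); }

  // 'this' is statically a VertNode here, so this calls match(VertNode).
  @Override public void apply(TreeFunction function) { function.match(this); }
}

final class HorizNode extends Node {
  VertNode lt, rt;
  HorizNode(Pt p) { super(p); }

  // 'this' is statically a HorizNode here, so this calls match(HorizNode).
  @Override public void apply(TreeFunction function) { function.match(this); }
}

// Hypothetical function object: counts the nodes reachable from where it is applied.
final class CountNodes implements TreeFunction {
  int count = 0;

  @Override public void match(VertNode node) {
    count++;
    if (node.lt != null) node.lt.apply(this);
    if (node.rt != null) node.rt.apply(this);
  }

  @Override public void match(HorizNode node) {
    count++;
    if (node.lt != null) node.lt.apply(this);
    if (node.rt != null) node.rt.apply(this);
  }
}
```

For example, with Node n = root.lt, the call n.apply(new CountNodes()) still ends up in match(HorizNode).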

So now we have something that looks like this:

public void insert(Point2D p) {
  if (root == null)
    root = new VertNode(p);
  else
    root.apply(new InsertPoint(p));
}

Beautiful.
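InsertPoint itself isn't shown above; one way to write it is to move the two recursive insert helpers into a function object, as in this sketch. Pt is a hypothetical stand-in for Point2D, and fields are package-private for brevity. Recursing via node.lt.apply(this) keeps everything going through the interface; calling match(node.lt) directly would also work here, since node.lt's static type is known.

```java
// Hypothetical stand-in for algs4's Point2D.
final class Pt {
  final double x, y;
  Pt(double x, double y) { this.x = x; this.y = y; }
  boolean sameAs(Pt o) { return x == o.x && y == o.y; }
}

interface TreeFunction {
  void match(HorizNode node);
  void match(VertNode node);
}

interface TreeApply {
  void apply(TreeFunction function);
}

abstract class Node implements TreeApply {
  final Pt point;
  Node(Pt point) { this.point = point; }
}

final class VertNode extends Node {
  HorizNode lt, rt;
  VertNode(Pt p) { super(p); }
  @Override public void apply(TreeFunction function) { function.match(this); }
}

final class HorizNode extends Node {
  VertNode lt, rt;
  HorizNode(Pt p) { super(p); }
  @Override public void apply(TreeFunction function) { function.match(this); }
}

// The pair of recursive insert helpers, packaged as one function object.
final class InsertPoint implements TreeFunction {
  private final Pt p;
  InsertPoint(Pt p) { this.p = p; }

  @Override public void match(VertNode node) {    // vertical node: compare x
    if (p.sameAs(node.point)) return;             // ignore duplicates
    if (p.x < node.point.x) {
      if (node.lt == null) node.lt = new HorizNode(p);
      else node.lt.apply(this);
    } else {
      if (node.rt == null) node.rt = new HorizNode(p);
      else node.rt.apply(this);
    }
  }

  @Override public void match(HorizNode node) {   // horizontal node: compare y
    if (p.sameAs(node.point)) return;
    if (p.y < node.point.y) {
      if (node.lt == null) node.lt = new VertNode(p);
      else node.lt.apply(this);
    } else {
      if (node.rt == null) node.rt = new VertNode(p);
      else node.rt.apply(this);
    }
  }
}

final class KdTree {
  VertNode root;

  public void insert(Pt p) {
    if (root == null)
      root = new VertNode(p);
    else
      root.apply(new InsertPoint(p));
  }
}
```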

Then I remembered... isn't this supposed to be a pattern? Turns out, it is.

So what is a Visitor?

Wikipedia gives the following definition for Visitor:

The visitor design pattern is a way of separating an algorithm from an object structure on which it operates. A practical result of this separation is the ability to add new operations to existing object structures without modifying those structures. It is one way to follow the open/closed principle.

In essence, the visitor allows one to add new virtual functions to a family of classes without modifying the classes themselves; instead, one creates a visitor class that implements all of the appropriate specializations of the virtual function. The visitor takes the instance reference as input, and implements the goal through double dispatch.

Given what we've just learned, perhaps an alternate definition would help clarify what the OOP mumbo-jumbo means.

The function object pattern is a way to hack "pattern matching" and higher-order functions into OO data types.

where algebraic data types are represented in OO as

Visitor is simply a poor man's way of emulating pattern matching in unexpressive [3] object-oriented languages.

You don't quite get the safety of Haskell (no exhaustiveness checks), and people can use your function objects to mess around with fields separately in sum types, but you've got to admit it looks pretty neat.

Sum types

The function object interface for the tree shows how the pattern works for product types; let's now see how to apply it to sum types. It doesn't work as well, but it still works.

Here's the Visitor example from Wikipedia:

A Car has four wheels, an engine, and a body.

Correspondingly, you have a visitor interface with brain-dead method names like:

interface ICarElementVisitor {
  void visit (Wheel wheel);
  void visit (Engine engine);
  void visit (Body body);
  void visit (Car car);
}

interface ICarElement {
  void accept (ICarElementVisitor visitor);
}

Here's the same thing but with sane function names:

interface CarFunction {
  void match (Wheel wheel);
  void match (Engine engine);
  void match (Body body);
  void match (Car car);
}

interface CarApplyF {
  void apply (CarFunction function);
}

Better?

An implementation would look like this:

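class Car implements CarApplyF {
  private Wheel fl, fr, rl, rr;
  private Engine engine;
  private Body body;

  public Car() { ... }

  // accessors used by PrintCar below
  public Wheel[] getWheels() { return new Wheel[] { fl, fr, rl, rr }; }
  public Engine getEngine()  { return engine; }
  public Body getBody()      { return body; }

  @Override
  public void apply(CarFunction function) {
    function.match(this);
  }
}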

// same deal for the other parts...

class PrintCar implements CarFunction {
  public void match (Wheel wheel)   { System.out.println("i have a wheel");   }
  public void match (Engine engine) { System.out.println("i have an engine"); }
  public void match (Body body)     { System.out.println("i have a body");    }
  public void match (Car car) {
    System.out.println("i am a car, here are my parts");
    for (Wheel wheel : car.getWheels()) { match (wheel); }
    match (car.getEngine());
    match (car.getBody());
  }
}

// to run
Car car = new Car();
car.apply(new PrintCar());
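Here's a self-contained, runnable sketch of the car example with the other parts also implementing CarApplyF. DescribeCar is a hypothetical variant of PrintCar that records its lines in a list instead of printing them, and it recurses into the parts via apply rather than calling match directly.

```java
import java.util.ArrayList;
import java.util.List;

interface CarFunction {
  void match(Wheel wheel);
  void match(Engine engine);
  void match(Body body);
  void match(Car car);
}

interface CarApplyF {
  void apply(CarFunction function);
}

// Each part dispatches to the overload for its own type.
final class Wheel  implements CarApplyF { public void apply(CarFunction f) { f.match(this); } }
final class Engine implements CarApplyF { public void apply(CarFunction f) { f.match(this); } }
final class Body   implements CarApplyF { public void apply(CarFunction f) { f.match(this); } }

final class Car implements CarApplyF {
  private final Wheel fl = new Wheel(), fr = new Wheel(), rl = new Wheel(), rr = new Wheel();
  private final Engine engine = new Engine();
  private final Body body = new Body();

  Wheel[] getWheels() { return new Wheel[] { fl, fr, rl, rr }; }
  Engine getEngine()  { return engine; }
  Body getBody()      { return body; }

  public void apply(CarFunction f) { f.match(this); }
}

// Like PrintCar, but collects the lines so they can be inspected.
final class DescribeCar implements CarFunction {
  final List<String> lines = new ArrayList<>();

  public void match(Wheel wheel)   { lines.add("i have a wheel"); }
  public void match(Engine engine) { lines.add("i have an engine"); }
  public void match(Body body)     { lines.add("i have a body"); }
  public void match(Car car) {
    lines.add("i am a car, here are my parts");
    for (Wheel w : car.getWheels()) w.apply(this);
    car.getEngine().apply(this);
    car.getBody().apply(this);
  }
}
```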

Doesn't this look like a pattern match?

Well, not exactly, because now the sum type also looks like a product type. I guess you could also define something like this:

interface CarFunction {
  // all possible subsets of fields...
  void match (Wheel fl, Wheel fr, Wheel rl, Wheel rr,
              Engine engine, Body body);
  void match (Engine engine, Body body);
  void match (Engine engine);
  void match (Body body);
  void match (Car car);
}

But it definitely works better for product types.

(If you didn't know this already...)

Strategy is a similar concept, but more lightweight.

From Wikipedia:

The strategy pattern

Again, the buzzword-soup definition makes basically no sense, so let me define it in a different way:

The function object pattern parameterises an operation on a data structure.

What this means is that Strategy is just a terribly limited way of simulating lambdas in Java. (So the Visitor should really be called "Function Objects with pattern matching".)
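For instance, java.util.Comparator is a textbook strategy: a sort is parameterised by a comparison. The sketch below (StrategyDemo is a hypothetical name) writes the same strategy twice, once as a pre-Java 8 anonymous class and once as a Java 8 lambda.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

final class StrategyDemo {
  // Pre-Java 8: the strategy is an anonymous class implementing Comparator.
  static final Comparator<String> BY_LENGTH_OLD = new Comparator<String>() {
    @Override
    public int compare(String a, String b) {
      return Integer.compare(a.length(), b.length());
    }
  };

  // Java 8+: the same strategy written as a lambda.
  static final Comparator<String> BY_LENGTH =
      (a, b) -> Integer.compare(a.length(), b.length());

  // The sorting operation is parameterised by whichever strategy is passed in.
  static List<String> sorted(List<String> xs, Comparator<String> strategy) {
    List<String> copy = new ArrayList<>(xs);
    Collections.sort(copy, strategy);
    return copy;
  }
}
```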

Exercises

I don't have time to explore these, so I'll just note them here:

  1. Rename the methods and interface names of Strategy to better reflect what's going on.

  2. Can you think of a better name than match? [4]

  3. Make the interfaces generic.

  4. Add support for accumulating/returning values (in the pattern itself, not just the implementations), then

  5. Write interfaces for Functor, Traversable, Foldable... etc. Is this even possible?

  6. Implement lambdas in Java 7. Look at Functional Java.

  7. Extend the syntax and support pattern matching and lambdas natively, by desugaring to these patterns. This is probably the first step towards Scala.
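As one possible take on exercises 3 and 4, the sketch below makes TreeFunction generic in its result type R, so traversals return values instead of accumulating into fields. Size and Height are hypothetical examples, and Pt again stands in for Point2D. It's only a sketch of the idea, not a worked answer to the remaining exercises.

```java
// Hypothetical stand-in for algs4's Point2D.
final class Pt {
  final double x, y;
  Pt(double x, double y) { this.x = x; this.y = y; }
}

// Generic in the result type R, so a traversal can return a value.
interface TreeFunction<R> {
  R match(HorizNode node);
  R match(VertNode node);
}

interface TreeApply {
  <R> R apply(TreeFunction<R> function);
}

abstract class Node implements TreeApply {
  final Pt point;
  Node(Pt point) { this.point = point; }
}

final class VertNode extends Node {
  HorizNode lt, rt;
  VertNode(Pt p) { super(p); }
  @Override public <R> R apply(TreeFunction<R> function) { return function.match(this); }
}

final class HorizNode extends Node {
  VertNode lt, rt;
  HorizNode(Pt p) { super(p); }
  @Override public <R> R apply(TreeFunction<R> function) { return function.match(this); }
}

// Number of nodes in a subtree.
final class Size implements TreeFunction<Integer> {
  // An empty child contributes 0.
  private int of(Node child) {
    if (child == null) return 0;
    Integer s = child.apply(this);
    return s;
  }
  @Override public Integer match(VertNode n)  { return 1 + of(n.lt) + of(n.rt); }
  @Override public Integer match(HorizNode n) { return 1 + of(n.lt) + of(n.rt); }
}

// Number of levels in a subtree.
final class Height implements TreeFunction<Integer> {
  // An empty child contributes 0.
  private int of(Node child) {
    if (child == null) return 0;
    Integer h = child.apply(this);
    return h;
  }
  @Override public Integer match(VertNode n)  { return 1 + Math.max(of(n.lt), of(n.rt)); }
  @Override public Integer match(HorizNode n) { return 1 + Math.max(of(n.lt), of(n.rt)); }
}
```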

Further reading

See the code here.



  1. https://twitter.com/mattmight/status/689478703907090433

  2. As usual, you've got to write the plumbing yourself in Java. Ugh.

  3. Scala does not suffer the same problem.

  4. Currently it's doing double duty as the "pattern match" keyword as well as the "recurse" keyword. (We can't actually recurse on the "function".)