Django and Model Inheritance

I was recently helping someone with the seemingly easy task of getting a list of related sub-models. This was working with a "tree" structure, where the base model had a ForeignKey to itself.

Whilst django-model-utils and django-polymorphic provide solid implementations for this, I have derived my own solution and tuned it up.

This solution assumes each base class instance belongs to exactly one sub-class. If an instance has rows in several sub-class tables, only the first match is returned; if it has none, it is skipped.

The setup

from django.db import models

class Node(models.Model):
    parent = models.ForeignKey('self', null=True, blank=True, related_name="children")

class SlugNode(Node):
    slug = models.SlugField()

class IpNode(Node):
    ip = models.IPAddressField()

The goal

To get a list of "children" for a given node, each as its sub-class model (SlugNode or IpNode, but not Node).

How it works

Every sub-class has an implicit OneToOneField to its parent class, so the parent has a reverse relation to each sub-class. What we're going to do is delve into the Node class's _meta and find all of its related objects.

sub_classes = [
    rel
    for rel in Node._meta.get_all_related_objects()
]

But we need to filter this for:

  1. fields that are OneToOneFields
  2. fields that relate to a subclass of Node

So we use:

sub_classes = [
    rel
    for rel in Node._meta.get_all_related_objects()
    if isinstance(rel.field, models.OneToOneField) and issubclass(rel.field.model, Node)
]

OK. Now we have a list of relation fields - not quite as useful as we'd like, but we'll come back to this.

Next, we want to get the list of child nodes.

children = self.children.all()

Simple enough. But to avoid triggering another query on each child to fetch its sub-class, we want to use select_related. The name of each relation that we need is on rel.var_name.

children = self.children.all().select_related(*[rel.var_name for rel in sub_classes])

Once we have this list, we need to find which of the sub-classes actually exists. We need to iterate through the sub_class fields, and find which of them is not blank.

We could do the obvious solution:

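for rel in sub_classes:
    try:
        # No default here: a missing row must raise so the except clause runs
        schild = getattr(child, rel.var_name)
    except Node.DoesNotExist:
        # This child is not an instance of this sub-class; try the next one
        continue
    else:
        break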

However, there is an easier [though dirtier] solution. When we select_related, the related objects are hidden away on the Model instance. The name of the attribute for that is given by rel.get_cache_name().

for rel in sub_classes:
    schild = child.__dict__.get(rel.get_cache_name())
    if schild is not None:
        break
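
To see why reading __dict__ skips the extra query, here is a minimal pure-Python sketch. It does not use Django: CachedRelation, Child and the "_slugnode_cache" name are hypothetical stand-ins for a relation descriptor that caches its result on the instance, and the lookups counter stands in for a database query.

```python
class CachedRelation:
    """Hypothetical stand-in for a relation descriptor that caches on the instance."""

    def __init__(self, cache_name, loader):
        self.cache_name = cache_name
        self.loader = loader
        self.lookups = 0  # counts how often the expensive path runs

    def __get__(self, instance, owner=None):
        if instance is None:
            return self
        if self.cache_name not in instance.__dict__:
            self.lookups += 1  # stands in for a database query
            instance.__dict__[self.cache_name] = self.loader(instance)
        return instance.__dict__[self.cache_name]


class Child:
    # Attribute access goes through the descriptor...
    slugnode = CachedRelation("_slugnode_cache", lambda inst: "slug-object")


child = Child()

# ...but reading __dict__ directly never runs the descriptor.
assert child.__dict__.get("_slugnode_cache") is None
assert Child.__dict__["slugnode"].lookups == 0

# Normal attribute access populates the cache (one "query").
assert child.slugnode == "slug-object"
assert Child.__dict__["slugnode"].lookups == 1

# From now on the cached value is visible through __dict__ as well.
assert child.__dict__.get("_slugnode_cache") == "slug-object"
```

The trade-off is the one the "dirtier" label hints at: the cache attribute is an implementation detail, so this only works while the framework keeps cached objects in the instance's __dict__.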

Now, moving all of our work outside the inner loops, we can get the cache name and var_name in the one loop.

# Build a map of (rel field -> cache attribute name) for each sub-class
sub_classes = dict(
    (rel.var_name, rel.get_cache_name())
    for rel in Node._meta.get_all_related_objects()
    if isinstance(rel.field, models.OneToOneField) and issubclass(rel.field.model, Node)
)

children = []
for child in self.children.select_related(*sub_classes.keys()):
    for sub in sub_classes.values():
        sub_child = child.__dict__.get(sub)
        if sub_child is not None:
            children.append(sub_child)
            break

return children

Now, this will construct a list and return it. If you expect the children list to be large, you could just as easily yield each child as you go.

for child in self.children.select_related(*sub_classes.keys()):
    for sub in sub_classes.values():
        sub_child = child.__dict__.get(sub)
        if sub_child is not None:
            yield sub_child
            break
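
Since yield only works inside a function, the loop has to live in a method or helper. As a rough sketch, the same logic can be tried without a database: iter_sub_children is a hypothetical helper, and the SimpleNamespace objects and cache names are made-up stand-ins for Node instances fetched with select_related.

```python
from types import SimpleNamespace


def iter_sub_children(children, cache_names):
    """Yield the first cached sub-class object found on each child."""
    for child in children:
        for name in cache_names:
            sub_child = child.__dict__.get(name)
            if sub_child is not None:
                yield sub_child
                break


# Hypothetical data: each namespace plays the role of a Node whose
# sub-class caches were filled in (or left empty) by select_related.
nodes = [
    SimpleNamespace(_slugnode_cache="slug-a", _ipnode_cache=None),
    SimpleNamespace(_slugnode_cache=None, _ipnode_cache="ip-b"),
    SimpleNamespace(_slugnode_cache=None, _ipnode_cache=None),  # plain Node
    SimpleNamespace(_slugnode_cache="slug-both", _ipnode_cache="ip-both"),
]

found = list(iter_sub_children(nodes, ["_slugnode_cache", "_ipnode_cache"]))
print(found)  # ['slug-a', 'ip-b', 'slug-both']
```

The last two entries show the caveat from the start of the post: a node with no sub-class is silently dropped, and a node with two sub-classes only yields the first one checked.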