Refactoring machine learning code - einops
Einops is a great library for improving your machine learning code. It supports NumPy, PyTorch, TensorFlow, and many other machine learning libraries. It helps give your code more semantic meaning and can also save you a lot of headaches when transforming data.
As a primer, let's look at a typical use case in machine learning where you have a bunch of data and want to reshape it so that some dimensions are merged together, like this:
x = x.view(x.shape[0], -1)
What does this code do? Well, it reshapes your data, so the first dimension is kept and the remaining dimensions are merged together.
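To make this concrete, here is a minimal sketch, assuming a PyTorch tensor with a made-up (batch, channel, height, width) shape:
import torch

x = torch.randn(32, 3, 64, 64)   # assumed layout: (batch, channel, height, width)
x = x.view(x.shape[0], -1)       # keep the batch dimension, merge everything else
print(x.shape)                   # torch.Size([32, 12288])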
Now compare this with the einops notation:
from einops import rearrange
x = rearrange(x, 'batch channel height width -> batch (channel height width)')
I like this a lot more. Why is that, you may be wondering?
Well, first of all, it tells you exactly and explicitly what is happening, i.e. you have 4 dimensions and you want to transform them into only 2 dimensions (batch and channel * height * width).
This explicit notation also means that the code will fail if you erroneously pass a tensor with 5 dimensions instead of the expected 4 dimensions.
The view-based code would silently keep working and pretend that all is good while it's not.
It also tells you exactly what those 4 dimensions are, so you don't have to dig through earlier code to figure out what's going on.
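For example, if an upstream step accidentally adds an extra dimension, the view-based version keeps running while rearrange complains immediately. A small sketch with an assumed shape:
import torch
from einops import rearrange

x = torch.randn(32, 3, 64, 64, 1)   # accidental extra trailing dimension
x.view(x.shape[0], -1)              # silently succeeds with shape (32, 12288)
rearrange(x, 'batch channel height width -> batch (channel height width)')
# raises an EinopsError because the pattern expects 4 dimensions, not 5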
Furthermore, you can even make the dimensions fully explicit:
x = rearrange(x, 'batch channel height width -> batch (channel height width)', channel=3)
If you pass a tensor whose channel dimension is not 3, you will get an error indicating the mismatch.
Thus, you can encode your assumptions explicitly in the code instead of leaving comments that will inevitably go stale.
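As a quick illustration, feeding grayscale images (a single channel) through the pattern above fails loudly. Again a sketch with an assumed shape:
import torch
from einops import rearrange

x = torch.randn(32, 1, 64, 64)   # grayscale images: only 1 channel, not 3
rearrange(x, 'batch channel height width -> batch (channel height width)', channel=3)
# raises an EinopsError because the channel axis has length 1, but 3 was declared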
Another example use case is reduce, where you want to get rid of, or average over, a couple of dimensions.
Let's do a global pooling over our spatial dimensions:
from einops import reduce
result = reduce(x, 'batch channel height width -> batch channel', reduction='mean')
So in this case, on the right side we omit the dimensions height and width, which tells reduce that these are the dimensions we want to get rid of. We also pass the reduction operation, which is mean here, so reduce knows how to do the job.
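A quick sanity check of the resulting shape, again with assumed input dimensions:
import torch
from einops import reduce

x = torch.randn(32, 3, 64, 64)   # (batch, channel, height, width)
result = reduce(x, 'batch channel height width -> batch channel', reduction='mean')
print(result.shape)              # torch.Size([32, 3]) - spatial dimensions averaged away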
Let’s consider an example I’ve seen in real life where the original code wanted to flatten all the non-channel dimensions as follows:
def _flatten_by_channel(prediction):
    batch_size, channels, height, width = prediction.shape
    # move channels last: (batch, height, width, channel)
    permuted = prediction.permute((0, 2, 3, 1))
    # flatten everything except the channel dimension
    final = permuted.contiguous().view((batch_size * height * width, channels))
    return final
With einops, this becomes a simple one-liner that is much more readable as well:
def _flatten_by_channel(prediction):
    return rearrange(prediction, 'batch channel height width -> (batch height width) channel')
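To convince yourself the refactor is equivalent, a quick check with an assumed input shape shows both versions produce the exact same tensor:
import torch

prediction = torch.randn(8, 21, 16, 16)                         # e.g. logits for 21 classes
new = _flatten_by_channel(prediction)                           # einops version from above
old = prediction.permute(0, 2, 3, 1).contiguous().view(-1, 21)  # original implementation
assert torch.equal(new, old)
print(new.shape)                                                # torch.Size([2048, 21])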
Einops also provides neural network layers like Rearrange() and Reduce() for PyTorch and other deep learning libraries like TensorFlow, which you can use directly inside nn.Sequential models.
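Here is a minimal sketch of how such a layer could slot into a model; the architecture itself is made up for illustration:
import torch.nn as nn
from einops.layers.torch import Rearrange, Reduce

model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    Reduce('batch channel height width -> batch channel', 'mean'),   # global average pooling as a layer
    nn.Linear(16, 10),
)
Rearrange() works the same way when you want to reshape instead of pool, e.g. as a drop-in replacement for nn.Flatten().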
Check out some great PyTorch examples here.