Plot PyTorch tensors with matplotlib
Have you ever tried to plot a PyTorch tensor with matplotlib, like this:
plt.plot(tensor)
and then received the following error?
AttributeError: 'Tensor' object has no attribute 'ndim'
Recent PyTorch releases define Tensor.ndim themselves, but on older versions you can work around this by giving every PyTorch tensor an ndim attribute:
torch.Tensor.ndim = property(lambda self: len(self.shape))
This uses Python's built-in property (the same function usually applied as a decorator) to define ndim as a read-only property whose value is the length of self.shape.
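To see the mechanism in isolation, here is a minimal sketch that uses a hypothetical stand-in class (FakeTensor) instead of torch.Tensor, so it runs without PyTorch installed. The helper add_ndim is also hypothetical: it attaches the same property(lambda self: len(self.shape)) to a class, but only if the class does not already have an ndim, so it would leave a built-in ndim untouched.

```python
def add_ndim(cls):
    """Give cls a read-only ndim property computed from self.shape.

    Hypothetical helper: it only patches classes that lack ndim,
    so an existing ndim attribute is left as it is.
    """
    if not hasattr(cls, "ndim"):
        cls.ndim = property(lambda self: len(self.shape))
    return cls


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: it has a shape but no ndim."""

    def __init__(self, *shape):
        self.shape = tuple(shape)


t = FakeTensor(3, 4)        # created before the class is patched
add_ndim(FakeTensor)
print(t.ndim)               # 2: existing instances see the new property too
print(FakeTensor(5).ndim)   # 1

try:
    t.ndim = 3              # the property has no setter, so this fails
except AttributeError:
    print("ndim is read-only")
```

Because the property lives on the class rather than on each object, tensors created before the patch pick it up as well. Under the same assumptions, calling add_ndim(torch.Tensor) would be equivalent to the one-line fix, minus the overwrite on versions that already provide ndim.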
Once this is defined, every PyTorch tensor has an ndim attribute, so it can be plotted as shown here:
import torch
import matplotlib.pyplot as plt
x = torch.linspace(-5, 5, 100)
x_squared = x * x
try:
    plt.plot(x, x_squared)  # Fails on older PyTorch: 'Tensor' object has no attribute 'ndim'
except AttributeError as err:
    print(err)
torch.Tensor.ndim = property(lambda self: len(self.shape))  # Fix it
plt.plot(x, x_squared)  # Works now
plt.show()
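The ndim patch only supplies the missing attribute. matplotlib still converts its inputs to NumPy arrays, and that conversion fails for tensors that require gradients or live on a GPU. A common alternative is to convert explicitly before plotting. The sketch below assumes a hypothetical helper named to_plottable; it duck-types on the detach(), cpu() and numpy() methods that torch.Tensor provides, so it needs no torch import of its own.

```python
def to_plottable(value):
    """Convert a tensor-like value into something matplotlib can plot.

    Hypothetical helper: values without a detach() method (lists,
    NumPy arrays, ...) are returned unchanged.
    """
    if hasattr(value, "detach"):
        # detach() drops the autograd graph, cpu() copies GPU data to
        # host memory, and numpy() returns an array that already has ndim.
        return value.detach().cpu().numpy()
    return value
```

With it, the plotting call becomes plt.plot(to_plottable(x), to_plottable(x_squared)), which avoids patching torch.Tensor at all.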