Plot PyTorch tensors with matplotlib
Have you ever tried to plot a PyTorch tensor with matplotlib like this:

    plt.plot(tensor)

and then received the following error?

    AttributeError: 'Tensor' object has no attribute 'ndim'

You can get around this easily by teaching all PyTorch tensors how to respond to ndim:

    torch.Tensor.ndim = property(lambda self: len(self.shape))

Basically, this uses Python's built-in property to attach ndim to the Tensor class as a read-only attribute whose value is the length of self.shape, that is, the number of dimensions.
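As a slightly more defensive variant, here is a minimal sketch that applies the patch only when ndim is missing (recent PyTorch releases already define it) and converts the tensor to a NumPy array before plotting. The helper name to_plottable is hypothetical and not part of PyTorch or matplotlib.

```python
import torch

# Patch only when the attribute is missing, so PyTorch versions that
# already define Tensor.ndim are left untouched.
if not hasattr(torch.Tensor, "ndim"):
    torch.Tensor.ndim = property(lambda self: len(self.shape))


def to_plottable(tensor):
    """Hypothetical helper: return a NumPy array suitable for plotting."""
    # detach() drops autograd tracking; cpu() moves GPU tensors to host memory.
    return tensor.detach().cpu().numpy()


values = torch.tensor([1.0, 4.0, 9.0], requires_grad=True)
print(values.ndim)            # 1
print(to_plottable(values))   # [1. 4. 9.]
```

Passing the converted array, as in plt.plot(to_plottable(values)), avoids the attribute lookup on the tensor altogether, so it should work regardless of whether the patch was applied.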
Shapeshifting PyTorch
An important consideration in machine learning is the shape of your data and your variables. You are often shifting and transforming data and then combining it. Thus, it is essential to know how to do this and what shortcuts are available.
Let’s start with a tensor with a single dimension:

    import torch
    test = torch.tensor([1, 2, 3])
    test.shape  # torch.Size([3])

Now assume we have built some machine learning model which takes batches of such one-dimensional tensors as input and returns some output.
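To make the idea of a batch concrete, here is a small sketch of common ways to turn one-dimensional tensors into a batch. It assumes the batch dimension comes first, which is a common convention but not something the model described here is known to require.

```python
import torch

test = torch.tensor([1, 2, 3])

# Add a leading batch dimension of size 1 to a single sample.
single_batch = test.unsqueeze(0)
print(single_batch.shape)   # torch.Size([1, 3])

# Indexing with None is an equivalent shortcut.
same_batch = test[None, :]
print(same_batch.shape)     # torch.Size([1, 3])

# Stack several samples along a new leading dimension.
batch = torch.stack([test, test * 2])
print(batch.shape)          # torch.Size([2, 3])
```

Both unsqueeze(0) and indexing with None wrap a single sample into a batch of one, while torch.stack combines several samples of equal shape into one batch.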