“TypeError: Invalid dimensions for image data” when drawing with Matplotlib's imshow() function

The key to solving this problem is understanding what matplotlib.pyplot.imshow() accepts as input: either a two-dimensional NumPy array of shape (M, N), or a three-dimensional NumPy array of shape (M, N, 3) (RGB) or (M, N, 4) (RGBA). If the third dimension has size 1, use np.squeeze() to compress the data into a two-dimensional array first.

In my case I am working in PyTorch, and the output is a tensor of shape (batch_size, channel, height, width). First I call detach() to cut the tensor off from the autograd graph used for backpropagation. Note that imshow cannot display a tensor directly, so I use .cpu() to move it to the CPU and then .numpy() to turn it into a NumPy array.

My usage is a little unusual because there is an extra batch_size dimension, but that is fine: I set batch_size to 1. Calling .squeeze() then removes that size-1 dimension and leaves a (channel, height, width) array, which still does not match what imshow expects. So we need transpose to move the channel dimension (= 3) to the end, which is exactly what .transpose(1, 2, 0) does. Of course, if the image itself has only one channel, squeeze() removes that dimension too, and the resulting two-dimensional array can be passed to imshow directly, without the transpose.
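The shape rules above can be captured in a small NumPy helper. This is a minimal sketch, not part of Matplotlib: `to_imshow_array` and its `channels_first` flag are hypothetical names, and the default assumes channel-first input like the PyTorch output described here. It squeezes away size-1 axes, moves the channel axis to the end when needed, and raises an error for anything imshow would reject.

```python
import numpy as np


def to_imshow_array(arr, channels_first=True):
    """Hypothetical helper: reshape an array into a shape that
    matplotlib.pyplot.imshow() accepts, i.e. (H, W), (H, W, 3) or (H, W, 4).

    channels_first=True assumes (C, H, W) or (1, C, H, W) input, as produced
    by PyTorch; pass False for data that is already channel-last.
    """
    # Drop every size-1 axis: the batch axis, a single channel axis, or a
    # trailing depth of 1. Caution: this also drops H or W if either is 1.
    out = np.squeeze(arr)
    if out.ndim == 2:
        return out  # grayscale image, usable as-is
    if out.ndim == 3 and channels_first:
        out = out.transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
    if out.ndim == 3 and out.shape[-1] in (3, 4):
        return out  # RGB or RGBA
    raise ValueError(f"shape {np.shape(arr)} cannot be shown with imshow")


# Usage: a batch of one RGB image in PyTorch layout (N, C, H, W).
batch = np.random.rand(1, 3, 32, 48)
print(to_imshow_array(batch).shape)  # (32, 48, 3)
```

The explicit `channels_first` flag avoids guessing the layout from the array's sizes, which would be ambiguous for shapes such as (3, H, 3).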

import matplotlib.pyplot as plt
plt.imshow(img2.detach().cpu().squeeze().numpy().transpose(1, 2, 0))
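To try the one-liner end to end, here is a sketch that builds a dummy tensor in place of the real `img2`. The tensor is an assumption: random values, shape (1, 3, 32, 48), requires_grad=True, placed on the GPU when one is available. Skipping detach() would make .numpy() fail on a tensor that requires grad, and skipping cpu() would make it fail on a CUDA tensor. The Agg backend is only there so the sketch also runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line to use plt.show()
import matplotlib.pyplot as plt
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical stand-in for a network output: batch_size=1, 3 channels,
# height 32, width 48, still attached to the autograd graph.
img2 = torch.rand(1, 3, 32, 48, device=device, requires_grad=True)

# detach -> move to CPU -> drop batch axis -> NumPy -> (C, H, W) to (H, W, C)
rgb = img2.detach().cpu().squeeze().numpy().transpose(1, 2, 0)
plt.figure()
plt.imshow(rgb)  # rgb.shape == (32, 48, 3)

# Single-channel case: squeeze() leaves (H, W), so no transpose is needed.
gray_t = torch.rand(1, 1, 32, 48, device=device, requires_grad=True)
gray = gray_t.detach().cpu().squeeze().numpy()
plt.figure()
plt.imshow(gray, cmap="gray")  # gray.shape == (32, 48)
```

The cmap="gray" argument is optional; without it Matplotlib maps two-dimensional data through its default colormap.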

 
