The key to solving this problem is understanding the input that matplotlib.pyplot.imshow() expects. It must be a NumPy array that is either two-dimensional, (H, W), or three-dimensional with a last dimension of size 3 (RGB) or 4 (RGBA). When the third dimension has size 1, np.squeeze() reduces the data to a two-dimensional array.

Because I am working in PyTorch, my result has the shape (batch_size, channel, height, width), so several conversions are needed. First, detach() separates the tensor from the autograd graph, cutting off backpropagation. PyTorch refuses to convert a tensor that still requires gradients to NumPy. Second, a tensor on the GPU cannot be converted to a NumPy array directly, so .cpu() copies it to host memory (it is a no-op for a tensor already on the CPU). After that, .numpy() can turn it into an array that imshow understands.

As mentioned above, imshow needs a two-dimensional array or a three-dimensional array whose last dimension is 3 or 4. My case is a bit unusual because there is an extra batch_size dimension, but that is fine: my batch_size is 1, so .squeeze() removes it and leaves an array of shape (channel, height, width). That still does not meet imshow's requirements, because the channel dimension comes first. Therefore transpose is used to move channel (= 3) to the end, which is why the call is .transpose(1, 2, 0), giving (height, width, channel). Of course, if the image to display has only one channel, squeeze() removes that dimension as well, and the resulting two-dimensional array can be passed to imshow directly, without the transpose.
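The shape changes described above can be traced step by step on a stand-in tensor. This is a minimal sketch: `img2` here is a hypothetical random tensor of shape (1, 3, 4, 5) created with requires_grad=True to imitate a network output, and it lives on the CPU, so .cpu() does nothing in this run.

```python
import torch

# Hypothetical stand-in for the network output: batch_size=1, 3 channels, 4x5 pixels.
# requires_grad=True mimics a tensor that is still attached to the autograd graph.
img2 = torch.rand(1, 3, 4, 5, requires_grad=True)

try:
    img2.numpy()  # fails: PyTorch will not convert a tensor that requires grad
except RuntimeError as err:
    print("numpy() without detach():", err)

step = img2.detach()          # (1, 3, 4, 5), no longer tracked by autograd
step = step.cpu()             # same shape, in host memory (no-op here, already on CPU)
step = step.squeeze()         # (3, 4, 5): the size-1 batch dimension is dropped
arr = step.numpy()            # NumPy array of shape (3, 4, 5), i.e. (C, H, W)
hwc = arr.transpose(1, 2, 0)  # (4, 5, 3), i.e. (H, W, C), the layout imshow expects
print(arr.shape, "->", hwc.shape)
```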
import matplotlib.pyplot as plt
plt.imshow(img2.detach().cpu().squeeze().numpy().transpose(1, 2, 0))  # (1, 3, H, W) tensor -> (H, W, 3) array
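A small helper can cover both cases from the text, a color image and a single-channel image, so the one-liner does not have to be edited by hand. This is a sketch, assuming the tensor is either (1, C, H, W) or (C, H, W) with C equal to 1, 3 or 4. The function name `tensor_to_imshow_array` is hypothetical, and the Agg backend line is only there so the example runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; remove it to get a window
import matplotlib.pyplot as plt
import torch


def tensor_to_imshow_array(t):
    """Convert a (1, C, H, W) or (C, H, W) tensor with C in {1, 3, 4} into an array imshow accepts.

    Hypothetical helper, not part of PyTorch or matplotlib.
    """
    arr = t.detach().cpu().numpy()        # leave the autograd graph, move to CPU, convert
    if arr.ndim == 4:
        if arr.shape[0] != 1:
            raise ValueError(f"expected batch_size 1, got shape {arr.shape}")
        arr = arr[0]                      # drop only the batch dimension -> (C, H, W)
    if arr.ndim != 3 or arr.shape[0] not in (1, 3, 4):
        raise ValueError(f"expected (C, H, W) with C in 1/3/4, got shape {arr.shape}")
    if arr.shape[0] == 1:
        return arr[0]                     # single channel: (H, W), drawn with a colormap
    return arr.transpose(1, 2, 0)         # color: (H, W, C)


rgb = torch.rand(1, 3, 4, 5)
gray = torch.rand(1, 1, 4, 5)
fig, (ax1, ax2) = plt.subplots(1, 2)
im_rgb = ax1.imshow(tensor_to_imshow_array(rgb))
im_gray = ax2.imshow(tensor_to_imshow_array(gray), cmap="gray")
print(im_rgb.get_array().shape, im_gray.get_array().shape)
```

Indexing with arr[0] instead of calling squeeze() is a deliberate choice here: squeeze() removes every size-1 dimension, so an image that happens to be one pixel high or wide would lose that dimension too.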