torch.max Example (How to Use)

torch.max(input, dim)

pred = torch.max(input, dim)

Returns a named tuple (values, indices): the maximum value in each row (dim = 1) or each column (dim = 0), together with the position of that maximum.
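A small sketch with a fixed 2x3 matrix (the numbers are made up for illustration) shows that both parts of the result can be read either by field name (values, indices) or by unpacking the tuple in (values, indices) order:

```python
import torch

# A small fixed 2x3 matrix so the result is easy to check by hand
x = torch.tensor([[1.0, 5.0, 3.0],
                  [4.0, 2.0, 6.0]])

result = torch.max(x, 1)  # dim=1: one maximum per row

# The named tuple exposes its two parts by name...
row_max = result.values    # tensor([5., 6.])
row_idx = result.indices   # tensor([1, 2])

# ...or it can be unpacked positionally, values first, indices second
values, indices = torch.max(x, 0)  # dim=0: one maximum per column
print(values)   # tensor([4., 5., 6.])
print(indices)  # tensor([1, 0, 1])
```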

_, pred = torch.max(input, dim)

Unpacking the tuple and discarding the first element (the values) into _ keeps only the indices, i.e. the position of the maximum value in each row (dim = 1) or column (dim = 0).
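The name pred hints at the common classification use: with a matrix of scores shaped (samples, classes), the row-wise indices are the predicted class labels. The sketch below assumes hypothetical scores and labels invented for illustration; it also checks that the indices agree with torch.argmax over the same dimension.

```python
import torch

# Hypothetical classifier scores: 4 samples x 3 classes (made-up numbers)
logits = torch.tensor([[0.1, 2.0, 0.3],
                       [1.5, 0.2, 0.1],
                       [0.0, 0.4, 0.9],
                       [0.7, 0.6, 0.2]])
labels = torch.tensor([1, 0, 2, 1])  # hypothetical ground-truth classes

# Discard the max scores; keep only the column index of each row's max
_, pred = torch.max(logits, 1)
print(pred)  # tensor([1, 0, 2, 0])

# The indices match torch.argmax over the same dimension
assert torch.equal(pred, torch.argmax(logits, dim=1))

# Fraction of samples whose predicted class equals the label
accuracy = (pred == labels).float().mean().item()
print(accuracy)  # 0.75
```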

Example:

import torch

# Construct a 5x3 randomly initialized matrix
x = torch.rand(5, 3)
print('input: ', x)
print('-'*10)
y1 = torch.max(x, 1)
print('max by row: ', y1)
print('-'*10)
y2 = torch.max(x, 0)
print('max by col: ', y2)
print('-'*10)
_, y3 = torch.max(x, 1)
print('max index by row: ', y3)
print('-'*10)
_, y4 = torch.max(x, 0)
print('max index by col: ', y4)

Output (the numbers differ from run to run, since the input is random):

input:  tensor([[0.5504, 0.3160, 0.2448],
        [0.8694, 0.3295, 0.2085],
        [0.5530, 0.9984, 0.3531],
        [0.2874, 0.1025, 0.9419],
        [0.0867, 0.4234, 0.8334]])
----------
max by row:  torch.return_types.max(
values=tensor([0.5504, 0.8694, 0.9984, 0.9419, 0.8334]),
indices=tensor([0, 0, 1, 2, 2]))
----------
max by col:  torch.return_types.max(
values=tensor([0.8694, 0.9984, 0.9419]),
indices=tensor([1, 2, 3]))
----------
max index by row:  tensor([0, 0, 1, 2, 2])
----------
max index by col:  tensor([1, 2, 3])
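The output shows that reducing over a dimension removes it: the 5x3 input yields 5 row results and 3 column results. The sketch below repeats the example with a fixed seed (torch.manual_seed(0), an addition not in the original) so reruns are repeatable, checks those shapes, and shows the keepdim=True option, which keeps the reduced dimension with size 1 so the result broadcasts back against the input. The row-max subtraction and the gather check are illustrative extras.

```python
import torch

torch.manual_seed(0)  # fixed seed so reruns produce the same random matrix
x = torch.rand(5, 3)

# dim=1 collapses the 3 columns: one result per row (5 of them)
row_vals, row_idx = torch.max(x, 1)
# dim=0 collapses the 5 rows: one result per column (3 of them)
col_vals, col_idx = torch.max(x, 0)
print(row_vals.shape, col_vals.shape)  # torch.Size([5]) torch.Size([3])

# keepdim=True keeps the reduced dimension with size 1
row_vals_kd, _ = torch.max(x, 1, keepdim=True)
print(row_vals_kd.shape)  # torch.Size([5, 1])

# The (5, 1) result broadcasts against x: each row minus its own max is <= 0
shifted = x - row_vals_kd
assert (shifted <= 0).all()

# Picking x at the returned indices recovers the max values
picked = x.gather(1, row_idx.unsqueeze(1)).squeeze(1)
assert torch.equal(picked, row_vals)
```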
