[Solved] RuntimeError: gather(): Expected dtype int64 for index

Error message:

RuntimeError: gather(): Expected dtype int64 for index

Code that raised the error:

a_batch = torch.tensor(a_batch.astype(int, copy=False),device=device)
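A likely reason this line produces a non-int64 tensor is that torch.tensor() keeps the dtype of the NumPy array it receives, and astype(int) maps to the platform's default integer. That default is int32 on Windows with NumPy versions before 2.0. The sketch below checks which dtype you actually get. The float values in a_batch are hypothetical and stand in for whatever your batch contains.

```python
import numpy as np
import torch

# Hypothetical batch of action indices stored as floats.
a_batch = np.array([0.0, 3.0, 1.0])

converted = a_batch.astype(int, copy=False)
print(converted.dtype)  # int64 on Linux/macOS; int32 on Windows with NumPy < 2.0

t = torch.tensor(converted)
print(t.dtype)  # follows the NumPy dtype: torch.int64 or torch.int32
```

If the second print shows torch.int32, gather() will reject the tensor as an index.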

Solution:

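Specify the index dtype explicitly with dtype=torch.int64. torch.gather() expects its index tensor to be int64 (torch.long), but torch.tensor() keeps the dtype of the NumPy array it is given, and astype(int) yields int32 on Windows with NumPy versions before 2.0.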

a_batch = torch.tensor(a_batch.astype(int, copy=False),dtype=torch.int64,device=device)
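Here is a minimal sketch of the fix in context. It assumes a DQN-style setup where a_batch holds one action index per row and is used with gather() to pick values from a Q-value table. The q_values tensor, its shape, and the CPU device are assumptions for illustration only. Whether the int32 attempt actually raises can depend on the PyTorch version, so it is wrapped in try/except.

```python
import numpy as np
import torch

device = torch.device("cpu")  # assumption: CPU for this sketch

# Hypothetical Q-values: 3 states x 4 actions -> [[0..3], [4..7], [8..11]]
q_values = torch.arange(12, dtype=torch.float32, device=device).reshape(3, 4)
# int32 indices, as astype(int) produces on Windows with NumPy < 2.0
a_batch = np.array([0, 3, 1], dtype=np.int32)

# Without an explicit dtype, torch.tensor keeps int32; on PyTorch versions
# that require int64 indices, gather() raises the RuntimeError.
bad_index = torch.tensor(a_batch, device=device)
try:
    q_values.gather(1, bad_index.unsqueeze(1))
except RuntimeError as err:
    print("failed:", err)

# The fix: request int64 explicitly when building the index tensor.
index = torch.tensor(a_batch, dtype=torch.int64, device=device)
chosen = q_values.gather(1, index.unsqueeze(1)).squeeze(1)
print(chosen.tolist())  # [0.0, 7.0, 9.0]
```

Calling .long() on an existing tensor is an equivalent way to get an int64 index.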
