Error message:
RuntimeError: gather(): Expected dtype int64 for index
Code that raised the error:
a_batch = torch.tensor(a_batch.astype(int, copy=False), device=device)
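A quick way to check which integer type this conversion produces on a given machine is to print the dtypes. This is a minimal sketch: `a_batch` here is a small hypothetical array standing in for the real batch. NumPy's `int` maps to the platform's default integer, so the result can differ between systems (for example, int32 on Windows with NumPy versions before 2.0, int64 on typical 64-bit Linux and macOS).

```python
import numpy as np
import torch

# Hypothetical stand-in for the real index batch
a_batch = np.array([0, 2, 1])

converted = a_batch.astype(int, copy=False)
t = torch.tensor(converted)  # no dtype given: inherits the NumPy dtype

print(converted.dtype)  # platform dependent, e.g. int32 or int64
print(t.dtype)          # torch.int32 or torch.int64, matching the line above
```

If `t.dtype` prints `torch.int32`, `gather()` will reject the tensor as an index.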
Solution:
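Specify the dtype explicitly with dtype=torch.int64, so the index tensor is int64 regardless of which integer type NumPy's `int` maps to on the current platform:
a_batch = torch.tensor(a_batch.astype(int, copy=False), dtype=torch.int64, device=device)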
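The sketch below shows the fix in context. It assumes a DQN-style use of `gather()` (picking one value per row with an action index), which the name `a_batch` suggests but the original does not state; `q_values`, the sample actions, and the CPU `device` are all hypothetical.

```python
import numpy as np
import torch

device = torch.device("cpu")  # assumption: CPU for this sketch

# Hypothetical Q-values: 3 states x 2 actions
q_values = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], device=device)
# Hypothetical action indices, one per state
a_batch = np.array([0, 1, 1])

# Force int64 so gather() accepts the tensor as an index;
# unsqueeze(1) gives shape (3, 1) to match q_values along dim 0
index = torch.tensor(a_batch, dtype=torch.int64, device=device).unsqueeze(1)

# For each row i, take q_values[i, index[i, 0]]
selected = q_values.gather(1, index).squeeze(1)
print(selected)  # tensor([1., 4., 6.])
```

An equivalent conversion is `torch.as_tensor(a_batch, device=device).long()`, since `.long()` casts to `torch.int64`.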