Solution for AttributeError: 'GlobalStorage' object has no attribute 'train_mask'
def create_masks(data, *, num_splits=20, val_ratio=0.1, test_ratio=0.8):
    """
    Attach random train/validation/test masks to ``data`` if it has none yet.

    When ``data`` has no ``val_mask`` attribute, ``num_splits`` independent
    random (not stratified) splits are drawn. Each split is a boolean mask
    over the entries of ``data.y``; the masks of all splits are stacked so that
    ``data.train_mask``, ``data.val_mask`` and ``data.test_mask`` each have
    shape ``(num_splits, num_nodes)``. Within one split every node belongs to
    exactly one of the three masks.

    When ``data`` already has a ``val_mask`` (the WikiCS case), the existing
    ``train_mask`` and ``val_mask`` are transposed and nothing else is changed.

    The data is modified in-place. Randomness comes from NumPy's global
    random state, so seed it with ``np.random.seed`` for reproducible splits.

    :param data: Data object exposing a label tensor ``y``
    :param num_splits: Number of random splits to generate (at least 1)
    :param val_ratio: Fraction of the nodes placed in each validation split
    :param test_ratio: Fraction of the nodes placed in each test split
    :return: The modified data
    :raises ValueError: If ``num_splits`` is smaller than 1
    """
    if hasattr(data, "val_mask"):
        # in the case of WikiCS: the masks already exist, only flip their axes
        data.train_mask = data.train_mask.T
        data.val_mask = data.val_mask.T
        return data

    if num_splits < 1:
        raise ValueError("num_splits must be at least 1")

    # Only the number of labels matters, so read it from the shape instead of
    # converting the labels to NumPy (which would fail for non-CPU tensors).
    num_nodes = data.y.shape[0]
    val_size = int(num_nodes * val_ratio)
    test_size = int(num_nodes * test_ratio)

    # Collect the masks in plain lists instead of using a ``None`` placeholder
    # attribute on ``data``: PyTorch Geometric storages do not keep attributes
    # that are assigned ``None``, so reading such a placeholder back raises
    # "AttributeError: 'GlobalStorage' object has no attribute 'train_mask'".
    train_masks, val_masks, test_masks = [], [], []
    for _ in range(num_splits):
        perm = torch.as_tensor(np.random.permutation(num_nodes), dtype=torch.long)

        test_mask = torch.zeros(num_nodes, dtype=torch.bool)
        test_mask[perm[:test_size]] = True

        val_mask = torch.zeros(num_nodes, dtype=torch.bool)
        val_mask[perm[test_size:test_size + val_size]] = True

        # Whatever is neither validation nor test is used for training.
        train_masks.append(~(val_mask | test_mask))
        val_masks.append(val_mask)
        test_masks.append(test_mask)

    data.train_mask = torch.stack(train_masks, dim=0)
    data.val_mask = torch.stack(val_masks, dim=0)
    data.test_mask = torch.stack(test_masks, dim=0)
    return data
AttributeError: 'GlobalStorage' object has no attribute 'train_mask'
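Cause: recent PyTorch Geometric versions do not store an attribute that is assigned None, so `data.train_mask = None` creates nothing and the later read of `data.train_mask` raises the error above.
Fix (line 33): change
if data.train_mask is None:
to
if 'train_mask' not in data: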