diff --git a/senteval/tools/relatedness.py b/senteval/tools/relatedness.py
index d8bacce5..25d8dc1a 100644
--- a/senteval/tools/relatedness.py
+++ b/senteval/tools/relatedness.py
@@ -19,6 +19,7 @@
 from scipy.stats import pearsonr
 
+device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 
 
 class RelatednessPytorch(object):
     # Can be used for SICK-Relatedness, and STS14
@@ -26,8 +27,8 @@ def __init__(self, train, valid, test, devscores, config):
         # fix seed
         np.random.seed(config['seed'])
         torch.manual_seed(config['seed'])
-        assert torch.cuda.is_available(), 'torch.cuda required for Relatedness'
-        torch.cuda.manual_seed(config['seed'])
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed(config['seed'])
 
         self.train = train
         self.valid = valid
@@ -58,12 +59,12 @@ def __init__(self, train, valid, test, devscores, config):
 
     def prepare_data(self, trainX, trainy, devX, devy, testX, testy):
         # Transform probs to log-probs for KL-divergence
-        trainX = torch.from_numpy(trainX).float().cuda()
-        trainy = torch.from_numpy(trainy).float().cuda()
-        devX = torch.from_numpy(devX).float().cuda()
-        devy = torch.from_numpy(devy).float().cuda()
-        testX = torch.from_numpy(testX).float().cuda()
-        testy = torch.from_numpy(testy).float().cuda()
+        trainX = torch.from_numpy(trainX).float().to(device)
+        trainy = torch.from_numpy(trainy).float().to(device)
+        devX = torch.from_numpy(devX).float().to(device)
+        devy = torch.from_numpy(devy).float().to(device)
+        testX = torch.from_numpy(testX).float().to(device)
+        testy = torch.from_numpy(testy).float().to(device)
 
         return trainX, trainy, devX, devy, testX, testy
@@ -107,7 +108,7 @@ def trainepoch(self, X, y, nepoches=1):
             all_costs = []
             for i in range(0, len(X), self.batch_size):
                 # forward
-                idx = torch.from_numpy(permutation[i:i + self.batch_size]).long().cuda()
+                idx = torch.from_numpy(permutation[i:i + self.batch_size]).long().to(device)
                 Xbatch = X[idx]
                 ybatch = y[idx]
                 output = self.model(Xbatch)
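
For context, this patch follows PyTorch's standard device-agnostic idiom: resolve a torch.device once at import time, seed the CUDA RNG only when CUDA is present, and move tensors with .to(device) instead of the hard .cuda() calls, so the same code runs on CPU-only machines. Below is a minimal standalone sketch of that pattern; the model, seed, and data here are illustrative placeholders, not SentEval's actual classifier.

import numpy as np
import torch
from torch import nn

# Resolve the device once; falls back to CPU when no GPU is available.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed the CPU RNG unconditionally, the CUDA RNG only if it exists.
torch.manual_seed(1234)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1234)

# Placeholder model and batch, moved with .to(device) like in the patch.
model = nn.Linear(8, 1).to(device)
X = torch.from_numpy(np.random.randn(32, 8).astype(np.float32)).to(device)

output = model(X)  # runs on GPU when present, otherwise on CPU
print(output.device)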