diff --git a/examples/gensen.py b/examples/gensen.py index 46dbc5b6..4e1e3bc6 100644 --- a/examples/gensen.py +++ b/examples/gensen.py @@ -31,8 +31,8 @@ def prepare(params, samples): def batcher(params, batch): batch = [' '.join(sent) if sent != [] else '.' for sent in batch] - _, reps_h_t = gensen.get_representation( - sentences, pool='last', return_numpy=True, tokenize=True + _, reps_h_t = params['gensen'].get_representation( + batch, pool='last', return_numpy=True, tokenize=True ) embeddings = reps_h_t return embeddings @@ -49,9 +49,6 @@ def batcher(params, batch): pretrained_emb='../data/embedding/glove.840B.300d.h5' ) gensen_encoder = GenSen(gensen_1, gensen_2) -reps_h, reps_h_t = gensen.get_representation( - sentences, pool='last', return_numpy=True, tokenize=True -) # Set params for SentEval params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5}