gensim doc2vec & IMDB sentiment dataset

TODO: section on introduction & motivation

TODO: prerequisites + dependencies (statsmodels, patsy, ?)

Load corpus

Fetch and prep the data exactly as in Mikolov's go.sh shell script. (Note: this cell checks for the existence of the required files, so the steps won't repeat once the final summary file aclImdb/alldata-id.txt is available alongside this notebook.)

In [2]:
import locale
import glob
import os.path
import requests
import tarfile

dirname = 'aclImdb'
filename = 'aclImdb_v1.tar.gz'
locale.setlocale(locale.LC_ALL, 'C')


# Convert text to lower-case and separate punctuation/symbols from words
def normalize_text(text):
    norm_text = text.lower()

    # Replace breaks with spaces
    norm_text = norm_text.replace('<br />', ' ')

    # Pad punctuation with spaces on both sides
    for char in ['.', '"', ',', '(', ')', '!', '?', ';', ':']:
        norm_text = norm_text.replace(char, ' ' + char + ' ')

    return norm_text


if not os.path.isfile('aclImdb/alldata-id.txt'):
    if not os.path.isdir(dirname):
        if not os.path.isfile(filename):
            # Download IMDB archive
            url = 'http://ai.stanford.edu/~amaas/data/sentiment/' + filename
            r = requests.get(url)
            with open(filename, 'wb') as f:
                f.write(r.content)

        tar = tarfile.open(filename, mode='r')
        tar.extractall()
        tar.close()

    # Concatenate and normalize the test/train data
    folders = ['train/pos', 'train/neg', 'test/pos', 'test/neg', 'train/unsup']
    alldata = u''

    for fol in folders:
        temp = u''
        output = fol.replace('/', '-') + '.txt'

        # Gather all the review files in this folder
        txt_files = glob.glob('/'.join([dirname, fol, '*.txt']))

        for txt in txt_files:
            with open(txt, 'r', encoding='utf-8') as t:
                control_chars = [chr(0x85)]
                t_clean = t.read()

                for c in control_chars:
                    t_clean = t_clean.replace(c, ' ')

                temp += t_clean

            temp += "\n"

        temp_norm = normalize_text(temp)
        with open('/'.join([dirname, output]), 'w', encoding='utf-8') as n:
            n.write(temp_norm)

        alldata += temp_norm

    with open('/'.join([dirname, 'alldata-id.txt']), 'w', encoding='utf-8') as f:
        for idx, line in enumerate(alldata.splitlines()):
            num_line = "_*{0} {1}\n".format(idx, line)
            f.write(num_line)
In [3]:
import os.path
assert os.path.isfile("aclImdb/alldata-id.txt"), "alldata-id.txt unavailable"

The data is small enough to be read into memory.

In [1]:
import gensim
from gensim.models.doc2vec import TaggedDocument
from collections import namedtuple

SentimentDocument = namedtuple('SentimentDocument', 'words tags split sentiment')

alldocs = []  # will hold all docs in original order
with open('aclImdb/alldata-id.txt', encoding='utf-8') as alldata:
    for line_no, line in enumerate(alldata):
        tokens = gensim.utils.to_unicode(line).split()
        words = tokens[1:]
        tags = [line_no] # `tags = [tokens[0]]` would also work at extra memory cost
        split = ['train','test','extra','extra'][line_no//25000]  # 25k train, 25k test, 50k extra
        sentiment = [1.0, 0.0, 1.0, 0.0, None, None, None, None][line_no//12500] # [12.5K pos, 12.5K neg]*2 then unknown
        alldocs.append(SentimentDocument(words, tags, split, sentiment))

train_docs = [doc for doc in alldocs if doc.split == 'train']
test_docs = [doc for doc in alldocs if doc.split == 'test']
doc_list = alldocs[:]  # for reshuffling per pass

print('%d docs: %d train-sentiment, %d test-sentiment' % (len(doc_list), len(train_docs), len(test_docs)))
100000 docs: 25000 train-sentiment, 25000 test-sentiment

Set-up Doc2Vec Training & Evaluation Models

We approximate the experiment of Le & Mikolov's paper "Distributed Representations of Sentences and Documents", also guided by Mikolov's example go.sh:

./word2vec -train ../alldata-id.txt -output vectors.txt -cbow 0 -size 100 -window 10 -negative 5 -hs 0 -sample 1e-4 -threads 40 -binary 0 -iter 20 -min-count 1 -sentence-vectors 1

Our parameter choices below vary from the above:

  • 100-dimensional vectors, as the 400d vectors of the paper don't seem to offer much benefit on this task
  • similarly, frequent word subsampling seems to decrease sentiment-prediction accuracy, so it's left out
  • cbow=0 means skip-gram, which is equivalent to the paper's 'PV-DBOW' mode; it is matched in gensim with dm=0
  • added to that DBOW model are two DM models, one which averages context vectors (dm_mean) and one which concatenates them (dm_concat, resulting in a much larger, slower, more data-hungry model)
  • a min_count=2 saves quite a bit of model memory, discarding only words that appear in a single doc (and are thus no more expressive than the unique-to-each-doc vectors themselves)
In [2]:
from gensim.models import Doc2Vec
import gensim.models.doc2vec
from collections import OrderedDict
import multiprocessing

cores = multiprocessing.cpu_count()
assert gensim.models.doc2vec.FAST_VERSION > -1, "this will be painfully slow otherwise"

simple_models = [
    # PV-DM w/concatenation - window=5 (both sides) approximates paper's 10-word total window size
    Doc2Vec(dm=1, dm_concat=1, size=100, window=5, negative=5, hs=0, min_count=2, workers=cores),
    # PV-DBOW 
    Doc2Vec(dm=0, size=100, negative=5, hs=0, min_count=2, workers=cores),
    # PV-DM w/average
    Doc2Vec(dm=1, dm_mean=1, size=100, window=10, negative=5, hs=0, min_count=2, workers=cores),
]

# speed up setup by sharing the results of the 1st model's vocabulary scan
simple_models[0].build_vocab(alldocs)  # PV-DM w/concat requires one special NULL word, so it serves as the template
print(simple_models[0])
for model in simple_models[1:]:
    model.reset_from(simple_models[0])
    print(model)

models_by_name = OrderedDict((str(model), model) for model in simple_models)
Doc2Vec(dm/c,d100,n5,w5,mc2,t4)
Doc2Vec(dbow,d100,n5,mc2,t4)
Doc2Vec(dm/m,d100,n5,w10,mc2,t4)

Following the paper, we also evaluate models in pairs. These wrappers return the concatenation of the vectors from each model. (Only the constituent models are trained; the pairs just reuse their vectors.)

In [5]:
from gensim.test.test_doc2vec import ConcatenatedDoc2Vec
models_by_name['dbow+dmm'] = ConcatenatedDoc2Vec([simple_models[1], simple_models[2]])
models_by_name['dbow+dmc'] = ConcatenatedDoc2Vec([simple_models[1], simple_models[0]])
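
For intuition, here is a minimal sketch of the idea behind such a wrapper. It is a hypothetical stand-in, not the actual ConcatenatedDoc2Vec class (which also supports pairing freshly-inferred vectors): it simply looks up the same tag in each wrapped model and joins the vectors end-to-end.

import numpy as np

class ConcatenatedDocvecsSketch(object):
    """Hypothetical stand-in: join the same tag's vector from several models."""
    def __init__(self, models):
        self.models = models
    def __getitem__(self, tag):
        # two 100d models yield a single 200d vector for the downstream classifier
        return np.concatenate([model.docvecs[tag] for model in self.models])

So each paired entry supplies 200-dimensional regressors to the logistic-regression evaluation below, at no additional training cost.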

Predictive Evaluation Methods

Helper methods for evaluating error rate.

In [8]:
import numpy as np
import statsmodels.api as sm
from random import sample

# for timing
from contextlib import contextmanager
from timeit import default_timer

@contextmanager
def elapsed_timer():
    start = default_timer()
    elapser = lambda: default_timer() - start
    # yield an indirect call, so the reported duration keeps updating while the block runs...
    yield lambda: elapser()
    # ...then freeze it at the value reached when the block exited
    end = default_timer()
    elapser = lambda: end - start

def logistic_predictor_from_data(train_targets, train_regressors):
    logit = sm.Logit(train_targets, train_regressors)
    predictor = logit.fit(disp=0)
    #print(predictor.summary())
    return predictor

def error_rate_for_model(test_model, train_set, test_set, infer=False, infer_steps=3, infer_alpha=0.1, infer_subsample=0.1):
    """Report error rate on test_doc sentiments, using supplied model and train_docs"""

    train_targets, train_regressors = zip(*[(doc.sentiment, test_model.docvecs[doc.tags[0]]) for doc in train_set])
    train_regressors = sm.add_constant(train_regressors)
    predictor = logistic_predictor_from_data(train_targets, train_regressors)

    test_data = test_set
    if infer:
        if infer_subsample < 1.0:
            test_data = sample(test_data, int(infer_subsample * len(test_data)))
        test_regressors = [test_model.infer_vector(doc.words, steps=infer_steps, alpha=infer_alpha) for doc in test_data]
    else:
        test_regressors = [test_model.docvecs[doc.tags[0]] for doc in test_data]
    test_regressors = sm.add_constant(test_regressors)
    
    # predict & evaluate
    test_predictions = predictor.predict(test_regressors)
    corrects = sum(np.rint(test_predictions) == [doc.sentiment for doc in test_data])
    errors = len(test_predictions) - corrects
    error_rate = float(errors) / len(test_predictions)
    return (error_rate, errors, len(test_predictions), predictor)

Bulk Training

We use the explicit multiple-pass, alpha-reduction approach sketched in the gensim doc2vec blog post, with added shuffling of the corpus on each pass.

Note that vector training occurs on all documents of the dataset, which includes all TRAIN/TEST/EXTRA docs.

Evaluation of each model's sentiment-predictive power is repeated after each pass, as an error rate (lower is better), to see the rates-of-relative-improvement. The base numbers reuse the TRAIN and TEST vectors stored in the models for the logistic regression, while the inferred results use newly-inferred TEST vectors.

(On a 4-core 2.6GHz Intel Core i7, these 20 passes of training and evaluating the 3 main models take about an hour.)

In [9]:
from collections import defaultdict
best_error = defaultdict(lambda: 1.0)  # to selectively print only the best errors achieved
In [10]:
from random import shuffle
import datetime

alpha, min_alpha, passes = (0.025, 0.001, 20)
alpha_delta = (alpha - min_alpha) / passes

print("START %s" % datetime.datetime.now())

for epoch in range(passes):
    shuffle(doc_list)  # shuffling gets best results
    
    for name, train_model in models_by_name.items():
        # train
        duration = 'na'
        train_model.alpha, train_model.min_alpha = alpha, alpha
        with elapsed_timer() as elapsed:
            train_model.train(doc_list)
            duration = '%.1f' % elapsed()
            
        # evaluate
        eval_duration = ''
        with elapsed_timer() as eval_elapsed:
            err, err_count, test_count, predictor = error_rate_for_model(train_model, train_docs, test_docs)
        eval_duration = '%.1f' % eval_elapsed()
        best_indicator = ' '
        if err <= best_error[name]:
            best_error[name] = err
            best_indicator = '*' 
        print("%s%f : %i passes : %s %ss %ss" % (best_indicator, err, epoch + 1, name, duration, eval_duration))

        if ((epoch + 1) % 5) == 0 or epoch == 0:
            eval_duration = ''
            with elapsed_timer() as eval_elapsed:
                infer_err, err_count, test_count, predictor = error_rate_for_model(train_model, train_docs, test_docs, infer=True)
            eval_duration = '%.1f' % eval_elapsed()
            best_indicator = ' '
            if infer_err < best_error[name + '_inferred']:
                best_error[name + '_inferred'] = infer_err
                best_indicator = '*'
            print("%s%f : %i passes : %s %ss %ss" % (best_indicator, infer_err, epoch + 1, name + '_inferred', duration, eval_duration))

    print('completed pass %i at alpha %f' % (epoch + 1, alpha))
    alpha -= alpha_delta
    
print("END %s" % str(datetime.datetime.now()))
START 2015-06-28 20:34:29.500839
*0.417080 : 1 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,t8) 84.5s 1.0s
*0.363200 : 1 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,t8)_inferred 84.5s 14.9s
*0.219520 : 1 passes : Doc2Vec(dbow,d100,n5,mc2,t8) 19.0s 0.6s
*0.184000 : 1 passes : Doc2Vec(dbow,d100,n5,mc2,t8)_inferred 19.0s 4.6s
*0.277080 : 1 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,t8) 35.0s 0.6s
*0.230800 : 1 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,t8)_inferred 35.0s 6.4s
*0.207840 : 1 passes : dbow+dmm 0.0s 1.5s
*0.185200 : 1 passes : dbow+dmm_inferred 0.0s 11.2s
*0.220720 : 1 passes : dbow+dmc 0.0s 1.1s
*0.189200 : 1 passes : dbow+dmc_inferred 0.0s 19.3s
completed pass 1 at alpha 0.025000
*0.357120 : 2 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,t8) 73.1s 0.6s
*0.144360 : 2 passes : Doc2Vec(dbow,d100,n5,mc2,t8) 19.8s 0.6s
*0.225640 : 2 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,t8) 36.2s 1.0s
*0.141160 : 2 passes : dbow+dmm 0.0s 1.1s
*0.144800 : 2 passes : dbow+dmc 0.0s 1.2s
completed pass 2 at alpha 0.023800
*0.326840 : 3 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,t8) 73.6s 0.6s
*0.125880 : 3 passes : Doc2Vec(dbow,d100,n5,mc2,t8) 20.1s 0.7s
*0.202680 : 3 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,t8) 36.0s 0.6s
*0.123280 : 3 passes : dbow+dmm 0.0s 1.6s
*0.126040 : 3 passes : dbow+dmc 0.0s 1.2s
completed pass 3 at alpha 0.022600
*0.302360 : 4 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,t8) 72.6s 0.6s
*0.113640 : 4 passes : Doc2Vec(dbow,d100,n5,mc2,t8) 19.9s 0.7s
*0.189880 : 4 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,t8) 35.8s 0.6s
*0.114200 : 4 passes : dbow+dmm 0.0s 1.2s
*0.115640 : 4 passes : dbow+dmc 0.0s 1.6s
completed pass 4 at alpha 0.021400
*0.281480 : 5 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,t8) 72.7s 0.7s
*0.109720 : 5 passes : Doc2Vec(dbow,d100,n5,mc2,t8) 21.5s 0.7s
*0.181360 : 5 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,t8) 37.8s 0.7s
*0.109760 : 5 passes : dbow+dmm 0.0s 1.3s
*0.110400 : 5 passes : dbow+dmc 0.0s 1.6s
completed pass 5 at alpha 0.020200
*0.264640 : 6 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,t8) 72.0s 0.7s
*0.292000 : 6 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,t8)_inferred 72.0s 13.3s
*0.107440 : 6 passes : Doc2Vec(dbow,d100,n5,mc2,t8) 21.6s 0.7s
*0.116000 : 6 passes : Doc2Vec(dbow,d100,n5,mc2,t8)_inferred 21.6s 4.7s
*0.176040 : 6 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,t8) 37.4s 1.1s
*0.213600 : 6 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,t8)_inferred 37.4s 6.4s
*0.107000 : 6 passes : dbow+dmm 0.0s 1.2s
*0.108000 : 6 passes : dbow+dmm_inferred 0.0s 11.2s
*0.107880 : 6 passes : dbow+dmc 0.0s 1.2s
*0.124400 : 6 passes : dbow+dmc_inferred 0.0s 18.3s
completed pass 6 at alpha 0.019000
*0.254200 : 7 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,t8) 65.7s 1.1s
*0.106720 : 7 passes : Doc2Vec(dbow,d100,n5,mc2,t8) 19.5s 0.7s
*0.172880 : 7 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,t8) 35.6s 0.7s
*0.106080 : 7 passes : dbow+dmm 0.0s 1.2s
*0.106320 : 7 passes : dbow+dmc 0.0s 1.2s
completed pass 7 at alpha 0.017800
*0.245880 : 8 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,t8) 68.6s 0.7s
*0.104920 : 8 passes : Doc2Vec(dbow,d100,n5,mc2,t8) 20.0s 1.0s
*0.171000 : 8 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,t8) 35.4s 0.7s
*0.104760 : 8 passes : dbow+dmm 0.0s 1.3s
*0.105600 : 8 passes : dbow+dmc 0.0s 1.3s
completed pass 8 at alpha 0.016600
*0.238400 : 9 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,t8) 66.1s 0.6s
*0.104520 : 9 passes : Doc2Vec(dbow,d100,n5,mc2,t8) 21.2s 1.1s
*0.167600 : 9 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,t8) 37.5s 0.7s
*0.103680 : 9 passes : dbow+dmm 0.0s 1.2s
*0.103480 : 9 passes : dbow+dmc 0.0s 1.2s
completed pass 9 at alpha 0.015400
*0.232160 : 10 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,t8) 69.0s 0.7s
*0.103680 : 10 passes : Doc2Vec(dbow,d100,n5,mc2,t8) 21.8s 0.7s
*0.166000 : 10 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,t8) 35.4s 1.1s
*0.101920 : 10 passes : dbow+dmm 0.0s 1.2s
 0.103560 : 10 passes : dbow+dmc 0.0s 1.2s
completed pass 10 at alpha 0.014200
*0.227760 : 11 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,t8) 66.4s 0.7s
*0.242400 : 11 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,t8)_inferred 66.4s 13.0s
*0.102160 : 11 passes : Doc2Vec(dbow,d100,n5,mc2,t8) 19.7s 0.6s
*0.113200 : 11 passes : Doc2Vec(dbow,d100,n5,mc2,t8)_inferred 19.7s 5.0s
*0.163480 : 11 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,t8) 35.4s 0.6s
*0.208800 : 11 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,t8)_inferred 35.4s 6.2s
*0.101560 : 11 passes : dbow+dmm 0.0s 1.2s
*0.102000 : 11 passes : dbow+dmm_inferred 0.0s 11.4s
*0.101920 : 11 passes : dbow+dmc 0.0s 1.6s
*0.109600 : 11 passes : dbow+dmc_inferred 0.0s 17.4s
completed pass 11 at alpha 0.013000
*0.225960 : 12 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,t8) 61.8s 0.7s
*0.101720 : 12 passes : Doc2Vec(dbow,d100,n5,mc2,t8) 20.2s 0.7s
*0.163000 : 12 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,t8) 35.5s 0.7s
*0.100840 : 12 passes : dbow+dmm 0.0s 1.2s
*0.101920 : 12 passes : dbow+dmc 0.0s 1.7s
completed pass 12 at alpha 0.011800
*0.222360 : 13 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,t8) 65.2s 0.7s
 0.103120 : 13 passes : Doc2Vec(dbow,d100,n5,mc2,t8) 20.0s 0.7s
*0.161960 : 13 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,t8) 35.2s 0.6s
 0.101640 : 13 passes : dbow+dmm 0.0s 1.2s
 0.102600 : 13 passes : dbow+dmc 0.0s 1.2s
completed pass 13 at alpha 0.010600
*0.220960 : 14 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,t8) 65.3s 1.1s
 0.102920 : 14 passes : Doc2Vec(dbow,d100,n5,mc2,t8) 19.9s 0.7s
*0.160160 : 14 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,t8) 36.0s 0.7s
 0.101720 : 14 passes : dbow+dmm 0.0s 1.2s
 0.102560 : 14 passes : dbow+dmc 0.0s 1.2s
completed pass 14 at alpha 0.009400
*0.219400 : 15 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,t8) 64.0s 1.0s
*0.101440 : 15 passes : Doc2Vec(dbow,d100,n5,mc2,t8) 19.5s 0.7s
 0.160640 : 15 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,t8) 38.6s 0.7s
*0.100160 : 15 passes : dbow+dmm 0.0s 1.2s
*0.101880 : 15 passes : dbow+dmc 0.0s 1.3s
completed pass 15 at alpha 0.008200
*0.216880 : 16 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,t8) 64.1s 1.1s
*0.232400 : 16 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,t8)_inferred 64.1s 12.8s
 0.101760 : 16 passes : Doc2Vec(dbow,d100,n5,mc2,t8) 19.1s 0.7s
*0.111600 : 16 passes : Doc2Vec(dbow,d100,n5,mc2,t8)_inferred 19.1s 4.7s
*0.159800 : 16 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,t8) 34.9s 0.6s
*0.184000 : 16 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,t8)_inferred 34.9s 6.5s
 0.100640 : 16 passes : dbow+dmm 0.0s 1.6s
*0.094800 : 16 passes : dbow+dmm_inferred 0.0s 11.7s
*0.101320 : 16 passes : dbow+dmc 0.0s 1.2s
 0.109600 : 16 passes : dbow+dmc_inferred 0.0s 17.5s
completed pass 16 at alpha 0.007000
 0.217160 : 17 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,t8) 58.6s 0.6s
 0.101760 : 17 passes : Doc2Vec(dbow,d100,n5,mc2,t8) 19.5s 0.7s
*0.159640 : 17 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,t8) 37.0s 1.1s
 0.100760 : 17 passes : dbow+dmm 0.0s 1.3s
 0.101480 : 17 passes : dbow+dmc 0.0s 1.3s
completed pass 17 at alpha 0.005800
*0.216080 : 18 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,t8) 60.7s 0.6s
 0.101520 : 18 passes : Doc2Vec(dbow,d100,n5,mc2,t8) 19.6s 0.6s
*0.158760 : 18 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,t8) 34.9s 1.0s
 0.100800 : 18 passes : dbow+dmm 0.0s 1.2s
 0.101760 : 18 passes : dbow+dmc 0.0s 1.2s
completed pass 18 at alpha 0.004600
*0.215560 : 19 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,t8) 62.6s 0.7s
*0.101000 : 19 passes : Doc2Vec(dbow,d100,n5,mc2,t8) 20.6s 0.7s
 0.159080 : 19 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,t8) 35.9s 0.7s
*0.099920 : 19 passes : dbow+dmm 0.0s 1.7s
 0.102280 : 19 passes : dbow+dmc 0.0s 1.2s
completed pass 19 at alpha 0.003400
*0.215160 : 20 passes : Doc2Vec(dm/c,d100,n5,w5,mc2,t8) 58.3s 0.6s
 0.101360 : 20 passes : Doc2Vec(dbow,d100,n5,mc2,t8) 19.5s 0.7s
 0.158920 : 20 passes : Doc2Vec(dm/m,d100,n5,w10,mc2,t8) 33.6s 0.6s
 0.100480 : 20 passes : dbow+dmm 0.0s 1.5s
 0.102160 : 20 passes : dbow+dmc 0.0s 1.1s
completed pass 20 at alpha 0.002200
END 2015-06-28 21:20:48.994706

Achieved Sentiment-Prediction Accuracy

In [12]:
# print best error rates achieved
for rate, name in sorted((rate, name) for name, rate in best_error.items()):
    print("%f %s" % (rate, name))
0.094800 dbow+dmm_inferred
0.099920 dbow+dmm
0.101000 Doc2Vec(dbow,d100,n5,mc2,t8)
0.101320 dbow+dmc
0.109600 dbow+dmc_inferred
0.111600 Doc2Vec(dbow,d100,n5,mc2,t8)_inferred
0.158760 Doc2Vec(dm/m,d100,n5,w10,mc2,t8)
0.184000 Doc2Vec(dm/m,d100,n5,w10,mc2,t8)_inferred
0.215160 Doc2Vec(dm/c,d100,n5,w5,mc2,t8)
0.232400 Doc2Vec(dm/c,d100,n5,w5,mc2,t8)_inferred

In my testing, unlike the paper's report, DBOW performs best. Concatenating vectors from different models offers only a small predictive improvement. The best results I've seen are still just under a 10% error rate, a ways short of the paper's reported 7.42%.

Examining Results

Are inferred vectors close to the precalculated ones?

In [13]:
doc_id = np.random.randint(simple_models[0].docvecs.count)  # pick random doc; re-run cell for more examples
print('for doc %d...' % doc_id)
for model in simple_models:
    inferred_docvec = model.infer_vector(alldocs[doc_id].words)
    print('%s:\n %s' % (model, model.docvecs.most_similar([inferred_docvec], topn=3)))
for doc 25430...
Doc2Vec(dm/c,d100,n5,w5,mc2,t8):
 [(25430, 0.6583491563796997), (27314, 0.4142411947250366), (16479, 0.40846431255340576)]
Doc2Vec(dbow,d100,n5,mc2,t8):
 [(25430, 0.9325973987579346), (49281, 0.5766637921333313), (79679, 0.5634804964065552)]
Doc2Vec(dm/m,d100,n5,w10,mc2,t8):
 [(25430, 0.7970066666603088), (97818, 0.6925815343856812), (230, 0.690807580947876)]

(Yes, here the stored vector from 20 epochs of training is usually one of the closest to a freshly-inferred vector for the same words. Note the defaults for inference are very abbreviated – just 3 steps starting at a high alpha – and likely need tuning for other applications.)
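
Where closer stored/inferred agreement matters, slower and more thorough inference may help. A sketch with purely illustrative values (assumed, not tuned recommendations):

# sketch: more inference steps, starting from a lower alpha (values illustrative only)
doc_words = alldocs[doc_id].words
careful_vec = simple_models[1].infer_vector(doc_words, steps=20, alpha=0.025)
print(simple_models[1].docvecs.most_similar([careful_vec], topn=3))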

In [14]:
import random

doc_id = np.random.randint(simple_models[0].docvecs.count)  # pick random doc, re-run cell for more examples
model = random.choice(simple_models)  # and a random model
sims = model.docvecs.most_similar(doc_id, topn=model.docvecs.count)  # get *all* similar documents
print(u'TARGET (%d): «%s»\n' % (doc_id, ' '.join(alldocs[doc_id].words)))
print(u'SIMILAR/DISSIMILAR DOCS PER MODEL %s:\n' % model)
for label, index in [('MOST', 0), ('MEDIAN', len(sims)//2), ('LEAST', len(sims) - 1)]:
    print(u'%s %s: «%s»\n' % (label, sims[index], ' '.join(alldocs[sims[index][0]].words)))
TARGET (72927): «this is one of the best films of this year . for a year that was fueled by controversy and crap , it was nice to finally see a film that had a true heart to it . from the opening scene to the end , i was so moved by the love that will smith has for his son . basically , if you see this movie and walk out of it feeling nothing , there is something that is very wrong with you . loved this movie , it's the perfect movie to end the year with . the best part was after the movie , my friends and i all got up and realized that this movie had actually made the four of us tear up ! it's an amazing film and if will smith doesn't get at least an oscar nom , then the oscars will just suck . in fact will smith should actually just win an oscar for this role . ! ! ! i loved this movie ! ! ! ! everybody needs to see especially the people in this world that take everything for granted , watch this movie , it will change you !»

SIMILAR/DISSIMILAR DOCS PER MODEL Doc2Vec(dm/m,d100,n5,w10,mc2,t8):

MOST (2046, 0.7372332215309143): «i thought this movie would be dumb , but i really liked it . people i know hate it because spirit was the only horse that talked . well , so what ? the songs were good , and the horses didn't need to talk to seem human . i wouldn't care to own the movie , and i would love to see it again . 8/10»

MEDIAN (6999, 0.4129640758037567): «okay , the recent history of star trek has not been good . the next generation faded in its last few seasons , ds9 boldly stayed where no one had stayed before , and voyager started very bad and never really lived up to its promise . so , when they announced a new star trek series , i did not have high expectations . and , the first episode , broken bow , did have some problems . but , overall it was solid trek material and a good romp . i'll get the nits out of the way first . the opening theme is dull and i don't look forward to sitting through it regularly , but that's what remotes are for . what was really bad was the completely gratuitous lotion rubbing scene that just about drove my wife out of the room . they need to cut that nonsense out . but , the plot was strong and moved along well . the characters , though still new , seem to be well rounded and not always what you would expect . the vulcans are clearly being presented very differently than before , with a slightly ominous theme . i particularly liked the linguist , who is the first star trek character to not be able to stand proud in the face of death , but rather has to deal with her phobias and fears . they seemed to stay true to trek lore , something that has been a significant problem in past series , though they have plenty of time to bring us things like shooting through shields , the instant invention of technology that can fix anything , and the inevitable plethora of time-travel stories . anyone want to start a pool on how long before the borg show up ? all in all , the series has enormous potential . they are seeing the universe with fresh eyes . we have the chance to learn how things got the way they were in the later series . how did the klingons go from just insulting to war ? how did we meet the romulans ? how did the federation form and just who put earth in charge . why is the prime directive so important ? if they address these things rather than spitting out time travel episodes , this will be an interesting series . my favorite line : zephram cochran saying " where no man has gone before " ( not " no one " )»

LEAST (16617, 0.015464222989976406): «i saw this movie during a tolkien-themed interim class during my sophomore year of college . i was seated unfortunately close to the screen and my professor chose me to serve as a whipping boy- everyone else was laughing , but they weren't within constant eyesight . let's get it out of the way : the peter jackson 'lord of the rings' films do owe something to the bakshi film . in jackson's version of the fellowship of the ring , for instance , the scene in which the black riders assault the empty inn beds is almost a complete carbon copy of the scene in bakshi's film , shot by shot . you could call this plagiarism or homage , depending on your agenda . i'm sure the similarities don't stop there . i'm not going to do any research to find out what they are , because that would imply i have some mote of respect for this film . i'm sure others have outlined the similarities- look around . this movie is a complete train wreck in every sense of the metaphor , and many , many people died in the accident . i've decided to list what i can remember in a more or less chronological fashion- if i've left out anything else that offended me it's because i'm completely overwhelmed , confronted with a wealth of failure ( and , at high points , mediocrity ) . *due to heavy use of rotoscoping , gandalf is no longer a gentle , wise wizard but a wildly flailing prophet of doom ( whose hat inexplicably changes color once or twice during the course of the film ) . *saruman the white is sometimes referred to as 'aruman' during the film , without explanation . he wears purple and red for some mysterious reason . *sam is flat out hideous . the portrayal of his friendship with frodo is strangely childlike and unsatisfying . yes , hobbits are small like children , but they are not children . *merry and pippin are never introduced--they simply appear during a scene change with a one-sentence explanation . the film is filled with sloppy editing like this . *frodo , sam , pippin and merry are singing merrily as they skip through along the road . one of the hobbits procures a lute at least twice as large as he is from behind his back--which was not visible before--and begins strumming in typical fantasy bard fashion as they all break into " la-la-la " s . awful . *aragorn , apparently , is a native american dressed in an extremely stereotypical fantasy tunic ( no pants ) , complete with huge , square pilgrim belt buckle . he is arguably the worst swordsman in the entire movie--oftentimes he gets one wobbly swing in before being knocked flat on his ass . *the black riders appear more like lepers than menacing instruments of evil . they limp everywhere they go at a painfully slow pace . this is disturbing to be sure , but not frightening . *the scene before the black riders attempt to cross the ford of bruinen ( in which they stare at frodo , who is on the other side on horseback ) goes on forever , during which time the riders rear their horses in a vaguely threatening manner and . . . do nothing else . the scene was probably intended to illustrate frodo's hallucinatory decline as he succumbs to his wound . it turns out to be more plodding than anything else . *gimli the dwarf is just as tall as legolas the elf . he's a dwarf . there is simply no excuse for that . he also looks like a bastardized david the gnome . it's a crude but accurate description . *boromir appears to have pilfered elmer fudd's golden viking armor from that bugs bunny opera episode . he looks ridiculous . 
*despite the similarity to tolkien's illustration , the balrog is howl inducing and the least-threatening villain in the entire film . it looks like someone wearing pink bedroom slippers , and it's barely taller than gandalf . " purists " may prefer this balrog , but i'll take jackson's version any day . *the battle scenes are awkward and embarrassing . almost none of the characters display any level of competency with their armaments . i'm not asking for action-packed scenes like those in jackson's film , but they are supposed to be fighting . *treebeard makes a very short appearance , and i was sorry he bothered to show up at all . watch the film , you'll see what i mean . alright , now for the good parts of the film . *some of the voice acting is pretty good . it isn't that aragorn sounds bad , he just looks kind of like the jolly green giant . *galadriel is somewhat interesting in this portrayal ; like tom bombadil , she seems immune to the ring's powers of temptation , and her voice actress isn't horrible either . *boromir's death isn't as heart wrenching as in jackson's portrayal of the same scene , but it's still appropriately dramatic ( and more true to his death in the book , though i don't believe jackson made a mistake shooting it the way he did ) . *as my professor pointed out ( between whispered threats ) , the orcs ( mainly at helm's deep , if i'm correct ) resemble the war-ravaged corpses of soldiers , a political statement that works pretty well if you realize what's being attempted . *while this isn't really a positive point about the film , bakshi can't be blamed for the majority of the failures in this movie , or so i've been told--the project was on a tight budget , and late in its production he lost creative control to some of the higher-ups ( who i'm sure hadn't read the books ) . let me be clear : i respect bakshi for even attempting something of this magnitude . i simply have a hard time believing he was happy with the final product . overall , i cannot in any way recommend this blasphemous adaptation of tolkien's classic trilogy even for laughs , unless you've already read the books and have your own visualizations of the characters , places and events . i'm sure somebody , somewhere , will pick a copy of this up in confusion ; if you do , keep an open mind and glean what good you can from it .»

(Somewhat, in terms of reviewer tone, movie genre, etc. – the MOST cosine-similar docs usually seem more like the TARGET than the MEDIAN or LEAST do.)

Do the word vectors show useful similarities?

In [15]:
word_models = simple_models[:]
In [17]:
import random
from IPython.display import HTML
# pick a random word with a suitable number of occurrences
while True:
    word = random.choice(word_models[0].index2word)
    if word_models[0].vocab[word].count > 10:
        break
# or uncomment the line below to pick a word from the relevant domain:
#word = 'comedy/drama'
similars_per_model = [str(model.most_similar(word, topn=20)).replace('), ','),<br>\n') for model in word_models]
similar_table = ("<table><tr><th>" +
    "</th><th>".join([str(model) for model in word_models]) + 
    "</th></tr><tr><td>" +
    "</td><td>".join(similars_per_model) +
    "</td></tr></table>")
print("most similar words for '%s' (%d occurences)" % (word, simple_models[0].vocab[word].count))
HTML(similar_table)
most similar words for 'comedy/drama' (38 occurrences)
Out[17]:
Doc2Vec(dm/c,d100,n5,w5,mc2,t8):
[('comedy', 0.7255545258522034),
('thriller', 0.6946465969085693),
('drama', 0.6763534545898438),
('romance', 0.6251884698867798),
('dramedy', 0.6217159032821655),
('melodrama', 0.6156137585639954),
('adventure', 0.6091135740280151),
('farce', 0.6034293174743652),
('chiller', 0.5948368906974792),
('romantic-comedy', 0.5876704454421997),
('fantasy', 0.5863304138183594),
('mystery/comedy', 0.577541708946228),
('whodunit', 0.572147011756897),
('biopic', 0.5679721832275391),
('thriller/drama', 0.5630226731300354),
('sitcom', 0.5574496984481812),
('slash-fest', 0.5573585033416748),
('mystery', 0.5542301535606384),
('potboiler', 0.5519827604293823),
('mockumentary', 0.5490710139274597)]

Doc2Vec(dbow,d100,n5,mc2,t8):
[('1000%', 0.42290645837783813),
("gymnast's", 0.4180164337158203),
('hollywoodland', 0.3898555636405945),
('cultures', 0.3857914209365845),
('hooda', 0.3851744532585144),
('cites', 0.38047513365745544),
("78's", 0.3792475461959839),
("dormael's", 0.3775535225868225),
('jokester', 0.3725704252719879),
('impelled', 0.36853262782096863),
('lia', 0.3684236407279968),
('snivelling', 0.3683513104915619),
('astral', 0.36715900897979736),
('euro-exploitation', 0.35853487253189087),
("serra's", 0.3578598201274872),
('down-on-their-luck', 0.3576606214046478),
('rowles', 0.3567575514316559),
('romantica', 0.3549702763557434),
('bonham-carter', 0.354231059551239),
('1877', 0.3541453182697296)]

Doc2Vec(dm/m,d100,n5,w10,mc2,t8):
[('comedy-drama', 0.6274900436401367),
('comedy', 0.5986765623092651),
('thriller', 0.5765297412872314),
('road-movie', 0.5615973472595215),
('dramedy', 0.5580120086669922),
('time-killer', 0.5497636795043945),
('potboiler', 0.5456510782241821),
('comedy/', 0.5439876317977905),
('actioner', 0.5423712134361267),
('diversion', 0.541743278503418),
('romcom', 0.5402226448059082),
('rom-com', 0.5358527302742004),
('drama', 0.5320745706558228),
('chiller', 0.5229591727256775),
('romp', 0.5228806734085083),
('horror/comedy', 0.5219299793243408),
('weeper', 0.5195824503898621),
('mockumentary', 0.5149033069610596),
('camp-fest', 0.5122634768486023),
('mystery/comedy', 0.5020694732666016)]

Do the DBOW words look meaningless? That's because the gensim DBOW model doesn't train word vectors – they remain at their randomly-initialized values – unless you ask for concurrent word-training with the dbow_words=1 initialization parameter. Concurrent word-training slows DBOW mode significantly, and offers little improvement (and sometimes a little worsening) of the error rate on this IMDB sentiment-prediction task.
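
A sketch of that option (assumed parameters mirroring the DBOW model above; the added window governs the concurrent skip-gram word training, and the model would still need its own vocabulary setup and training):

# sketch: DBOW plus concurrent skip-gram word-vector training (slower, per the note above)
dbow_w_model = Doc2Vec(dm=0, dbow_words=1, size=100, window=10, negative=5, hs=0,
                       min_count=2, workers=cores)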

Words from DM models tend to show meaningfully similar words when there are many examples in the training data (as with 'plot' or 'actor'). (All DM modes inherently involve word vector training concurrent with doc vector training.)
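
To spot-check one such frequent word against the word-training models (output not shown; 'plot' is just an illustrative choice):

# compare a frequent word's neighbors in the two DM models, which do train word vectors
for model in [simple_models[0], simple_models[2]]:
    print('%s:\n %s' % (model, model.most_similar('plot', topn=5)))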

Are the word vectors from this dataset any good at analogies?

In [26]:
# assuming something like
# https://word2vec.googlecode.com/svn/trunk/questions-words.txt
# is in the local directory
# note: this takes many minutes
for model in word_models:
    sections = model.accuracy('questions-words.txt')
    correct, incorrect = len(sections[-1]['correct']), len(sections[-1]['incorrect'])
    print('%s: %0.2f%% correct (%d of %d)' % (model, float(correct*100)/(correct+incorrect), correct, correct+incorrect))
Doc2Vec(dm/c,d100,n5,w5,mc2,t8): 28.70% correct (2873 of 10012)
Doc2Vec(dbow,d100,n5,mc2,t8): 0.01% correct (1 of 10012)
Doc2Vec(dm/m,d100,n5,w10,mc2,t8): 27.24% correct (2727 of 10012)

Even though this is a tiny, domain-specific dataset, it shows some meager capability on the general word analogies – at least for the DM/concat and DM/mean models, which actually train word vectors. (The untrained, randomly-initialized words of the DBOW model of course fail miserably.)

Slop

In [ ]:
This cell left intentionally erroneous. 

To mix the Google dataset (if locally available) into the word tests...

In [ ]:
from gensim.models import Word2Vec
w2v_g100b = Word2Vec.load_word2vec_format('GoogleNews-vectors-negative300.bin.gz', binary=True)
w2v_g100b.compact_name = 'w2v_g100b'
word_models.append(w2v_g100b)

To get copious logging output from above steps...

In [ ]:
import logging
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
rootLogger = logging.getLogger()
rootLogger.setLevel(logging.INFO)

To auto-reload python code while developing...

In [ ]:
%load_ext autoreload
%autoreload 2