
word2gm's Introduction

Word2GM (Word to Gaussian Mixture)

This is an implementation of the model in Athiwaratkun and Wilson, Multimodal Word Distributions, ACL 2017.

We represent each word in the dictionary as a Gaussian mixture distribution and train it using a max-margin objective based on the expected likelihood kernel as the energy function.
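For reference, each word w is modeled as a Gaussian mixture $f_w(x) = \sum_{i=1}^{K} p_{w,i}\, \mathcal{N}(x;\, \mu_{w,i}, \Sigma_{w,i})$, and the energy between two words is the expected likelihood kernel, which has a closed form for Gaussian mixtures (notation lightly adapted from the paper):

$$E(f, g) = \int f(x)\, g(x)\, dx = \sum_{i=1}^{K} \sum_{j=1}^{K} p_i\, q_j\, \mathcal{N}\big(0;\; \mu_{f,i} - \mu_{g,j},\; \Sigma_{f,i} + \Sigma_{g,j}\big)$$

The max-margin objective pushes the log-energy of a word w with an observed context word c above that of a negative sample c' by a margin m:

$$L_\theta(w, c, c') = \max\big(0,\; m - \log E_\theta(w, c) + \log E_\theta(w, c')\big)$$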

The BibTeX entry for the paper is:

@InProceedings{athiwilson2017,
    author = {Ben Athiwaratkun and Andrew Gordon Wilson},
    title = {Multimodal Word Distributions},
    booktitle = {Conference of the Association for Computational Linguistics (ACL)},
    year = {2017}
}

Updates

Feb 27, 2018: We updated the code to be compatible with TensorFlow 1.0+. Training on large datasets also no longer requires installing TensorFlow from source. In this version, we provide modified skip-gram C ops to handle large-dataset training.

Dependencies

This code is tested on TensorFlow 1.5.0 and should be compatible with TensorFlow 1.0 and above.

Note: This repository was previously compatible with TensorFlow 0.12, but support for pre-1.0 TensorFlow will not be maintained. However, you can still access that version at this commit.

For plotting, we use ggplot

pip install -U ggplot
# or 
conda install -c conda-forge ggplot
# or
pip install git+https://github.com/yhat/ggplot.git

Training Data

The data used in the paper is the concatenation of ukWaC and WaCkypedia_EN, both of which can be requested here.

We include a script get_text8.sh to download text8, a small dataset that can be used to train word embeddings. We note that polysemous behaviour can be observed even on a dataset as small as text8. For instance, the word 'rock' has one Gaussian component close to 'jazz', 'pop', 'blue' and another Gaussian component close to 'stone', 'sediment', 'basalt', etc.

Training

For text8, the training script with the proper hyperparameters is train_text8.sh.

For UKWAC+Wackypedia, the training script train_wac.sh contains our command to replicate the results.

Steps

Below are the steps for training and visualization with text8 dataset.

  1. Compile the C skip-gram module for TensorFlow training. This generates the word2vec_ops.so file, which is imported by the Python code (see the sketch after these steps for how it is loaded). Note that this version of the code supports training on large datasets without compiling the entire TensorFlow library from source (unlike the previous version of our code).
chmod +x compile_skipgram_ops.sh
./compile_skipgram_ops.sh
  2. Obtain the dataset and train.
bash get_text8.sh
python word2gm_trainer.py --num_mixtures 2 --train_data data/text8 --spherical --embedding_size 50 --epochs_to_train 10 --var_scale 0.05 --save_path modelfiles/t8-2s-e10-v05-lr05d-mc100-ss5-nwout-adg-win10 --learning_rate 0.05  --subsample 1e-5 --adagrad  --min_count 5 --batch_size 128 --max_to_keep 100 --checkpoint_interval 500 --window_size 10
# or simply run ./train_text8.sh

See the end of this page for details on the training options.

  3. The model will be saved at modelfiles/t8-2s-e10-v05-lr05d-mc100-ss5-nwout-adg-win10. The code to analyze the model and visualize the results is in Analyze Text8 Model.ipynb. See the model API below.

  4. We can visualize the word embeddings themselves by executing the following command in the notebook:

w2gm_text8_2s.visualize_embeddings()

This command prepares the word embeddings to be visualized with TensorFlow's TensorBoard. Once the embeddings are prepared, the visualization can be launched with the shell command:

tensorboard --logdir=modelfiles/t8-2s-e10-v05-lr05d-mc100-ss5-nwout-adg-win10_emb --port=6006

Then, navigate your browser to http://localhost:6006 (or the URL of the appropriate machine that hosts the model) and click on the Embeddings tab. Note that the logdir folder is the original save path plus the suffix "_emb".
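For reference, the compiled kernel from step 1 is loaded into TensorFlow at runtime with tf.load_op_library. Below is a minimal sketch of that loading step; the exact path handling and op names used in word2gm_trainer.py may differ.

import os
import tensorflow as tf

# Load the custom skip-gram kernels built by compile_skipgram_ops.sh.
# word2vec_ops.so is expected to sit next to the training script.
so_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'word2vec_ops.so')
word2vec = tf.load_op_library(so_path)

# The returned module exposes the custom ops (e.g. the skip-gram batch
# generator) as Python functions, named after the REGISTER_OP declarations
# in the C++ source.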

Visualization

The TensorBoard embedding visualization tools (please use Firefox or Chrome) allow nearest-neighbor queries, in addition to PCA and t-SNE visualization. We use the following notation: x:i refers to the i-th mixture component of word 'x'. For instance, querying 'bank:0' yields 'river:1', 'confluence:0', 'waterway:1' as the nearest neighbors, which means that this component of 'bank' corresponds to the river bank. On the other hand, querying 'bank:1' gives the nearest neighbors 'banking:1', 'banker:0', 'ATM:0', indicating that this component of 'bank' corresponds to the financial bank.
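Outside TensorBoard, the same kind of query can be run directly on the learned component means. Below is a minimal numpy sketch; mus (a [vocab_size, K, D] array of component means), word2id, and id2word are hypothetical stand-ins for what the analysis notebook loads from a checkpoint.

import numpy as np

def nearest_components(query_word, query_cluster, mus, word2id, id2word, top_n=5):
    # Flatten the [vocab_size, K, D] component means to [vocab_size * K, D].
    vocab_size, K, D = mus.shape
    flat = mus.reshape(vocab_size * K, D)
    # Cosine similarity between the query component and every component.
    flat = flat / (np.linalg.norm(flat, axis=1, keepdims=True) + 1e-8)
    query = flat[word2id[query_word] * K + query_cluster]
    sims = flat.dot(query)
    # Report the closest components in the 'word:cluster' notation used above.
    results = []
    for idx in np.argsort(-sims):
        word, cluster = id2word[idx // K], idx % K
        if word == query_word:
            continue
        results.append('%s:%d (%.3f)' % (word, cluster, sims[idx]))
        if len(results) == top_n:
            break
    return results

# e.g. nearest_components('bank', 0, mus, word2id, id2word)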

Trained Model

We provide a trained model for K=2 here. To analyze the model, see Analyze Model.ipynb. The code expects the model to be extracted to directory modelfiles/w2gm-k2-d50/.

Training Options

arguments:
  -h, --help            show this help message and exit
  --save_path SAVE_PATH
                        Directory to write the model and training summaries.
                        (required)
  --train_data TRAIN_DATA
                        Training text file. (required)
  --embedding_size EMBEDDING_SIZE
                        The embedding dimension size.
  --epochs_to_train EPOCHS_TO_TRAIN
                        Number of epochs to train. Each epoch processes the
                        training data once completely.
  --learning_rate LEARNING_RATE
                        Initial learning rate.
  --batch_size BATCH_SIZE
                        Number of training examples processed per step (size
                        of a minibatch).
  --concurrent_steps CONCURRENT_STEPS
                        The number of concurrent training steps.
  --window_size WINDOW_SIZE
                        The number of words to predict to the left and right
                        of the target word.
  --min_count MIN_COUNT
                        The minimum number of word occurrences for it to be
                        included in the vocabulary.
  --subsample SUBSAMPLE
                        Subsample threshold for word occurrence. Words that
                        appear with higher frequency will be randomly down-
                        sampled. Set to 0 to disable.
  --statistics_interval STATISTICS_INTERVAL
                        Print statistics every n seconds.
  --summary_interval SUMMARY_INTERVAL
                        Save training summary to file every n seconds (rounded
                        up to statistics interval).
  --checkpoint_interval CHECKPOINT_INTERVAL
                        Checkpoint the model (i.e. save the parameters) every
                        n seconds (rounded up to statistics interval).
  --num_mixtures NUM_MIXTURES
                        Number of mixture components for the mixture of Gaussians.
  --spherical [SPHERICAL]
                        Whether the model should be spherical or diagonal. The
                        default is spherical.
  --nospherical
  --var_scale VAR_SCALE
                        Variance scale
  --ckpt_all [CKPT_ALL]
                        Keep all checkpoints (warning: this requires a large
                        amount of disk space).
  --nockpt_all
  --norm_cap NORM_CAP   The upper bound of norm of mean vector
  --lower_sig LOWER_SIG
                        The lower bound for sigma element-wise
  --upper_sig UPPER_SIG
                        The upper bound for sigma element-wise
  --mu_scale MU_SCALE   The average norm will be around mu_scale
  --objective_threshold OBJECTIVE_THRESHOLD
                        The threshold for the objective
  --adagrad [ADAGRAD]   Use Adagrad optimizer instead
  --noadagrad
  --loss_epsilon LOSS_EPSILON
                        epsilon parameter for loss function
  --constant_lr [CONSTANT_LR]
                        Use constant learning rate
  --noconstant_lr
  --wout [WOUT]         Whether to use a separate wout
  --nowout
  --max_pe [MAX_PE]     Using maximum of partial energy instead of the sum
  --nomax_pe
  --max_to_keep MAX_TO_KEEP
                        The maximum number of checkpoint files to keep
  --normclip [NORMCLIP]
                        Whether to perform norm clipping (very slow)
  --nonormclip

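For reference, --subsample appears to follow the usual word2vec-style frequency subsampling heuristic (stated here as an assumption, since the actual sampling happens inside the reused skip-gram C op): a word whose fraction of the corpus is f is kept with probability roughly (sqrt(f / t) + 1) * t / f, where t is the threshold. A small illustrative helper:

import math

def keep_probability(word_count, total_words, threshold=1e-5):
    # word2vec-style subsampling: very frequent words are randomly dropped.
    # The exact behaviour inside word2vec_ops.so may differ slightly.
    if threshold <= 0:
        return 1.0  # --subsample 0 disables down-sampling
    freq = word_count / float(total_words)
    return min(1.0, (math.sqrt(freq / threshold) + 1.0) * threshold / freq)

# e.g. keep_probability(1000000, 17000000) is about 0.01, so a word as frequent
# as 'the' in text8 is down-sampled aggressively, while rare words are always kept.
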
word2gm's People

Contributors

andrewgordonwilson, benathi


word2gm's Issues

Increasing Throughput

In running word2gm_trainer.py on the text8 data (via train_text8.sh), I'm getting a throughput of about 6,500 words/sec. This seems a bit low; is this about what you would expect? Is there something I'm missing that would help increase the throughput?

Number of negative samples = 1: for efficiency?

Hi Ben,

Thank you for the open-source code, it has been very helpful! One question, however: I was wondering why, in both the current paper and the prob-fast-text paper, the number of negative samples is hardcoded to 1. Was it for efficiency reasons, or is there reason to believe this is optimal for the current model? Thank you!

Projectutil missing

The code imports find_list_ckpts from projectutil, but that module is not included in the repository or referenced anywhere else.
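A minimal stand-in, under the assumption that find_list_ckpts is meant to enumerate the checkpoint step numbers saved in a model directory (this is a guess; the original projectutil module is not available):

import glob
import os
import re

def find_list_ckpts(model_dir):
    # Collect the global-step numbers of all 'model.ckpt-<step>' files saved
    # by tf.train.Saver in model_dir, sorted in ascending order.
    steps = set()
    for path in glob.glob(os.path.join(model_dir, 'model.ckpt-*')):
        match = re.search(r'model\.ckpt-(\d+)', os.path.basename(path))
        if match:
            steps.add(int(match.group(1)))
    return sorted(steps)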

Tensorflow Error

Hi,

Thank you for sharing this beautiful work. I tried to run the code but I received two errors for the following two commands:

  1. ./compile_skipgram_ops.sh
    It gives me the following error message:
    ./compile_skipgram_ops.sh: 1: Syntax error: "(" unexpected.

  2. I found a word2vec_ops.so file in a different GitHub repository and downloaded it. Then, when I tried to run word2gm_trainer.py, I received the following error message:

[screenshot of the error message attached]

Can you please help me?
Thanks in advance.

ImportError: No module named models.embedding

I ran train_text8.sh and the following message popped up:

Traceback (most recent call last):
  File "word2gm_trainer.py", line 27, in <module>
    from tensorflow.models.embedding import gen_word2vec as word2vec
ImportError: No module named models.embedding

Apparently it's related to the compilation of the tutorials from the tensorflow models repo, so I have tried recompiling TensorFlow from source with the models included, but in vain.

Any thoughts about this please ?

'module' object has no attribute 'SummaryWriter'

The code (w2gm_text8_2s.visualize_embeddings()) in Analyze Text8 Model.ipynb generates the following error.

saver = tf.train.Saver()
saver.save(session, os.path.join(emb_logdir, "model.ckpt"), 0)
summary_writer = tf.summary.FileWriter(emb_logdir, session.g)
config = projector.ProjectorConfig()
embedding = config.embeddings.add()

Actually, the original code contains tf.train.SummaryWriter. When I ran it, I got this error, so I changed it to tf.summary.FileWriter, but the program still outputs the same error.
