
dayihengliu / fine-grained-style-transfer


Code for AAAI2020 paper: "Revision in Continuous Space: Unsupervised Text Style Transfer without Adversarial Learning"

Home Page: https://arxiv.org/abs/1905.12304

Jupyter Notebook 77.62% Python 22.38%
text-sentiment-transfer text-rewriting

fine-grained-style-transfer's Introduction

Revision in Continuous Space: Unsupervised Text Style Transfer without Adversarial Learning

This repo contains the code and data of the following paper:

Revision in Continuous Space: Unsupervised Text Style Transfer without Adversarial Learning, Dayiheng Liu, Jie Fu, Yidan Zhang, Chris Pal, Jiancheng Lv, AAAI20 [arXiv]

Overview

We explore a novel task setting for text style transfer in which multiple fine-grained attributes must be manipulated simultaneously. We propose to address it by revising the original sentences in a continuous space using gradient-based optimization.
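The core idea, revising a sentence's latent vector by gradient-based optimization until a predictor assigns it the target attribute, can be illustrated with a toy sketch. It is not the repository's implementation: it assumes the latent vector z is already given (no encoder or decoder), uses a hand-set logistic attribute predictor with weights w and bias b, and the function name revise_latent, the step size, and the 0.9 confidence threshold are hypothetical choices. Any content-preservation or fluency terms the actual method may use are omitted.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def revise_latent(z, w, b, target=1, step_size=0.5, threshold=0.9, max_steps=100):
    """Toy sketch: gradient ascent on log p(y=target | z) under a logistic predictor.

    Returns the revised latent vector and the number of update steps taken.
    Stops as soon as the predictor's confidence in `target` reaches `threshold`.
    """
    z = np.array(z, dtype=float)  # work on a copy; the caller's vector is untouched
    for step in range(max_steps):
        p = sigmoid(w @ z + b)  # predictor's p(y=1 | z)
        p_target = p if target == 1 else 1.0 - p
        if p_target >= threshold:
            return z, step
        # For a logistic model, d/dz log p(y=target | z) = (target - p) * w
        grad = (target - p) * w
        z += step_size * grad
    return z, max_steps


if __name__ == "__main__":
    w = np.array([1.0, -2.0, 0.5, 3.0])  # hypothetical predictor weights
    z0 = -0.2 * w  # a latent vector the predictor scores as attribute 0
    z_new, n_steps = revise_latent(z0, w, b=0.0, target=1)
    print("p(y=1) before:", sigmoid(w @ z0), "after:", sigmoid(w @ z_new), "steps:", n_steps)
```

In a full system, the revised vector would presumably be decoded back into a sentence; this sketch only shows the latent update and the confidence-based stopping rule.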

Dataset

Prerequisites

  • Jupyter notebook 4.4.0
  • Python 3.6
  • TensorFlow 1.6.0+
  • NumPy
  • nltk 3.3
  • kenlm 0.0.0
  • Moses

Usage

  • TextCNN.ipynb: Pretrain a Text-CNN on the training set for predictor training.
  • TextBiLSTM.ipynb: Pretrain a Text-BiLSTM on the whole dataset for evaluation.
  • KenLM / Moses: Pretrain a language model.
  • Text_Style_Transfer_Pipeline.ipynb: The pipeline (training, inference, and evaluation) for text sentiment transfer and text gender style transfer.
  • Multi_Finegrained_Control.ipynb: The pipeline (training and inference) for multiple fine-grained attribute control.
  • Eval_Multi.ipynb: Evaluation of multiple fine-grained attribute control.

Output Samples

To make it easier for other researchers to compare with our method, we release its outputs on YELP and AMAZON.

For each dataset, we provide three kinds of outputs of our method (content-strengthen, content-style-balanced, and style-strengthen), which can be found in outputs/.

fine-grained-style-transfer's People

Contributors

dayihengliu


fine-grained-style-transfer's Issues

Error while training the TextCNN

Hello. I am getting an error when running TextCNN.ipynb. The error is:

ValueError: setting an array element with a sequence

It happens after the first epoch, in the test method of TextCNN_Util, apparently because of the line

self.D.dropout_input: 1.0

Could you please help me out.

Full error

Writing to Model/YELP/TextCNN/
Epoch 1/30 | Batch 0/1731 | Train_loss: 4.134 Acc 0.367 | Test_loss: 1.580 Acc 0.770 | Time_cost:0.271
Test Input: it was large and good enough for <UNK> | Output: 0 | GroundTruth: 1
Test Input: service was average but could not make up for the poor food and drink | Output: 0 | GroundTruth: 0
Test Input: this place is a terrible place to live | Output: 0 | GroundTruth: 0
Train Input: perfect spot for a date or a girls night out | Output: 0 | GroundTruth: 1
Train Input: they have a nice sized outdoor seating area with umbrellas at each table | Output: 0 | GroundTruth: 1

Epoch 1/30 | Batch 577/1731 | Train_loss: 1.478 Acc 0.761 | Test_loss: 1.308 Acc 0.746 | Time_cost:21.817
Test Input: <UNK> food was <UNK> <UNK> would go there | Output: 0 | GroundTruth: 1
Test Input: i ordered a cesar salad with a side of blackened chicken | Output: 1 | GroundTruth: 0
Test Input: <UNK> always deliver and keep promises very happy | Output: 1 | GroundTruth: 1
Train Input: the plus side of this restaurant is that it is fast and cheap | Output: 1 | GroundTruth: 0
Train Input: the italian sub is fantastic | Output: 1 | GroundTruth: 1

Epoch 1/30 | Batch 1154/1731 | Train_loss: 1.186 Acc 0.817 | Test_loss: 0.807 Acc 0.859 | Time_cost:20.641
Test Input: <UNK> hours of my life <UNK> was happy to give | Output: 1 | GroundTruth: 1
Test Input: the answer came | Output: 0 | GroundTruth: 1
Test Input: this room that he found also reeked of smoke | Output: 1 | GroundTruth: 0
Train Input: loved their 1/2 pound meatball | Output: 1 | GroundTruth: 1
Train Input: do not go to this restaurant if you ca n't use stairs | Output: 0 | GroundTruth: 0

ipykernel_launcher:36: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
ValueError: setting an array element with a sequence.
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
TypeError: int() argument must be a string, a bytes-like object or a number, not 'list'

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
<ipython-input-8-914ea13f228a> in <module>
      1 util = TextCNN_Util(dp=dp, model=D)
----> 2 util.fit(train_dir=train_dir)

<ipython-input-4-6a835914f626> in fit(self, train_dir)
     54             tic = time.time()
     55             train_c_loss, train_acc = self.train(epoch)
---> 56             test_c_loss, test_acc = self.test()
     57             print("Epoch %d/%d | Train_loss: %.3f Acc %.3f | Test_loss: %.3f Acc %.3f" % 
     58                   (epoch, self.n_epoch, train_c_loss, train_acc, test_c_loss, test_acc))

<ipython-input-4-6a835914f626> in test(self)
     38                 {self.D.input_x: X_test_batch, 
     39                 self.D.input_y: C_test_batch,
---> 40                 self.D.dropout_input: 1})
     41             avg_c_loss += loss
     42             avg_acc += acc

~/.conda/envs/deep_seq_model/lib/python3.6/site-packages/tensorflow_core/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    954     try:
    955       result = self._run(None, fetches, feed_dict, options_ptr,
--> 956                          run_metadata_ptr)
    957       if run_metadata:
    958         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

~/.conda/envs/deep_seq_model/lib/python3.6/site-packages/tensorflow_core/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1147             feed_handles[subfeed_t] = subfeed_val
   1148           else:
-> 1149             np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)
   1150 
   1151           if (not is_tensor_handle_feed and

~/.conda/envs/deep_seq_model/lib/python3.6/site-packages/numpy/core/_asarray.py in asarray(a, dtype, order)
     81 
     82     """
---> 83     return array(a, dtype, copy=False, order=order)
     84 
     85 

ValueError: setting an array element with a sequence.
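A likely reading of this traceback: the arrow at line 40 marks the last line of the multi-line sess.run call, not necessarily the faulty value. The VisibleDeprecationWarning printed just before it ("ragged nested sequences") and the inner TypeError ("int() argument must be ... not 'list'") suggest that X_test_batch holds token-id lists of different lengths, which np.asarray cannot turn into a rectangular integer array. The sketch below shows that failure and one common remedy, right-padding each batch. The helper name pad_batch, the padding id 0, and the optional max_len are assumptions; if the model's input_x placeholder has a fixed sequence length, max_len should match it, and the real fix may belong in whatever builds the test batches.

```python
import numpy as np


def pad_batch(batch, pad_id=0, max_len=None):
    """Right-pad variable-length token-id lists into a rectangular int32 array.

    pad_id=0 is an assumption; use the id your vocabulary reserves for padding.
    Sequences longer than max_len are truncated.
    """
    if max_len is None:
        max_len = max(len(seq) for seq in batch)
    out = np.full((len(batch), max_len), pad_id, dtype=np.int32)
    for i, seq in enumerate(batch):
        seq = list(seq)[:max_len]
        out[i, : len(seq)] = seq
    return out


if __name__ == "__main__":
    ragged = [[4, 8, 15], [16, 23], [42]]
    try:
        # Same conversion TensorFlow performs on each feed_dict value
        np.asarray(ragged, dtype=np.int32)
    except ValueError as err:
        print("ragged batch fails:", err)

    padded = pad_batch(ragged)
    print(padded.shape)
    print(padded)
```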

Missing files

Hello. I am running your code, and the following four files loaded in TextCNN.ipynb

w2id, id2w = pickle.load(open('/workspace/Data/yelp/w2id_id2w.pkl','rb'))
Y_train, C_train = pickle.load(open('/workspace/Data/yelp/XC_train.pkl','rb'))
Y_dev, C_dev = pickle.load(open('/workspace/Data/yelp/XC_dev.pkl','rb'))
Y_test, C_test = pickle.load(open('/workspace/Data/yelp/XC_test.pkl','rb'))

are missing: they are not included in the YELP dataset you linked to. Could you please let me know how I can get them?

How to get the 'pkl' datasets

Hello, I have tried every method mentioned in the README to get the datasets, but files such as 'w2id_id2w.pkl', 'w2id_id2w_indices_labels_all.pkl', and 'XC_train.pkl', which the training code uses, cannot be found. Could you please help me?
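Until the original files are available, the loading code quoted above at least reveals their outer shape: each pickle unpacks into a 2-tuple, (w2id, id2w) for the vocabulary and (token ids, labels) for each split. The sketch below builds files of that shape from raw sentences. Everything beyond the 2-tuple shape is an assumption: the special tokens (only <UNK> is confirmed, by the training log), whitespace tokenization, id2w as a dict, plain integer labels, and the helper names build_vocab, encode, and save_split are all hypothetical. The notebooks may expect extra tokens, length limits, or a different label format, and the contents of 'w2id_id2w_indices_labels_all.pkl' are not inferred here.

```python
import os
import pickle
import tempfile
from collections import Counter

# Hypothetical special tokens; only <UNK> is confirmed by the training log.
SPECIAL_TOKENS = ["<PAD>", "<UNK>"]


def build_vocab(sentences, min_count=1):
    """Return (w2id, id2w) built from whitespace-tokenized sentences."""
    counts = Counter(tok for s in sentences for tok in s.split())
    words = list(SPECIAL_TOKENS)
    words += [w for w, c in counts.most_common() if c >= min_count and w not in SPECIAL_TOKENS]
    w2id = {w: i for i, w in enumerate(words)}
    id2w = {i: w for w, i in w2id.items()}
    return w2id, id2w


def encode(sentences, w2id):
    """Map each sentence to a list of token ids, using <UNK> for unseen words."""
    unk = w2id["<UNK>"]
    return [[w2id.get(tok, unk) for tok in s.split()] for s in sentences]


def save_split(path, sentences, labels, w2id):
    """Pickle a (token_ids, labels) 2-tuple, matching `Y_train, C_train = pickle.load(...)`."""
    with open(path, "wb") as f:
        pickle.dump((encode(sentences, w2id), list(labels)), f)


if __name__ == "__main__":
    train_sents = ["the food was great", "the service was terrible"]  # toy data
    train_labels = [1, 0]
    out_dir = tempfile.mkdtemp()

    w2id, id2w = build_vocab(train_sents)
    with open(os.path.join(out_dir, "w2id_id2w.pkl"), "wb") as f:
        pickle.dump((w2id, id2w), f)
    save_split(os.path.join(out_dir, "XC_train.pkl"), train_sents, train_labels, w2id)

    with open(os.path.join(out_dir, "XC_train.pkl"), "rb") as f:
        Y_train, C_train = pickle.load(f)
    print(Y_train, C_train)
```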
