zlinao / mintl

MinTL: Minimalist Transfer Learning for Task-Oriented Dialogue Systems

License: MIT License

Shell 1.32% Python 98.68%
language-model task-oriented-dialogue transfer-learning transformer

mintl's Issues

AttributeError: 'str' object has no attribute 'size'

Thank you for this work.

While training the model, I got this error:

Traceback (most recent call last):
  File "/content/drive/MyDrive/MinTL/train.py", line 388, in <module>
    main()
  File "/content/drive/MyDrive/MinTL/train.py", line 377, in main
    m.train()
  File "/content/drive/MyDrive/MinTL/train.py", line 130, in train
    lm_labels=inputs["response"]
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/drive/MyDrive/MinTL/T5.py", line 69, in forward
    head_mask=head_mask,
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/transformers/models/t5/modeling_t5.py", line 924, in forward
    encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
AttributeError: 'str' object has no attribute 'size'

What can I do to solve it, please?

MultiWoZ 2.1

Can you please provide the data files for running on MultiWoZ 2.1, or the preprocessing scripts to generate annotated_user_da_with_span_full.json.zip?

Evaluation and dialog state

Hi, first thank you for your work!

I am trying to understand the way you evaluate models (especially the inform and success rates on MultiWOZ) when using the MinTL framework. For generating a response, the model has to encode the previous dialogue state and the context, and predict the state update which is combined with the original state to form a new one. This new dialog state is used for querying the database etc., right?

During the evaluation (these lines?), is the model given the ground-truth belief state from the previous turn, or does it use the "cumulative" one that was predicted over the previous turns of that particular conversation?

I see some problems in both cases. When using the ground-truth belief state from the previous turn, the metrics might be overestimated. On the other hand, when using the fully-predicted last state, the ground-truth user response is used and it might not be consistent with the previous state. So I would actually expect the metrics to be underestimated, am I right?

Thank you in advance 🙂
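
The two evaluation regimes this question contrasts can be made concrete with a small sketch. It is not MinTL's code: the dictionary state format, the `NULL` deletion marker, and the names `apply_update` and `rollout` are all assumptions chosen for illustration. It only shows how an early wrong prediction is erased when the ground-truth previous state is fed in, and how it persists when the predicted state is carried forward.

```python
from typing import Dict, List

State = Dict[str, str]   # e.g. {"hotel-area": "north"}
DELETE = "NULL"          # hypothetical marker meaning "drop this slot"


def apply_update(prev_state: State, update: State) -> State:
    """Merge a predicted state update into the previous dialogue state."""
    new_state = dict(prev_state)
    for slot, value in update.items():
        if value == DELETE:
            new_state.pop(slot, None)
        else:
            new_state[slot] = value
    return new_state


def rollout(updates: List[State], gold_states: List[State],
            use_gold_prev: bool) -> List[State]:
    """Return the dialogue state after every turn.

    use_gold_prev=True  -> each turn starts from the gold state of the previous turn.
    use_gold_prev=False -> each turn starts from the state built from predictions.
    """
    states: List[State] = []
    prev: State = {}
    for turn, update in enumerate(updates):
        if use_gold_prev and turn > 0:
            prev = gold_states[turn - 1]
        state = apply_update(prev, update)
        states.append(state)
        prev = state
    return states


# Toy dialogue: the first predicted update is wrong ("south" instead of "north").
predicted_updates = [{"hotel-area": "south"}, {"hotel-stars": "4"}]
gold_states = [
    {"hotel-area": "north"},
    {"hotel-area": "north", "hotel-stars": "4"},
]

with_gold = rollout(predicted_updates, gold_states, use_gold_prev=True)
cumulative = rollout(predicted_updates, gold_states, use_gold_prev=False)
print(with_gold[-1])    # the turn-1 error is gone after turn 2
print(cumulative[-1])   # the turn-1 error is still present after turn 2
```

In this toy case the first regime reports a correct final state even though the model made a mistake, which is the overestimation concern raised above.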

Hyper-parameters for reproducing

Hello, thanks for your amazing work ^-^
I tried to reproduce the experimental results reported in the paper, using the end-to-end setting shown in run.py.
[image]
The result I got is about 3 points lower than the result in the paper.
I followed run.py to run my experiment.
I use Python 3.6, 1 × V100, and transformers==2.8.0; the other Python packages are the same as in requirements.txt.
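
When comparing a run against published numbers, it can help to confirm that the environment really matches the pinned packages. The sketch below is one way to do that with the standard library only. The helper names `parse_pins`, `installed_version`, and `mismatches` are hypothetical, and `importlib.metadata` requires Python 3.8 or newer, so it would not run unchanged on the Python 3.6 setup mentioned above. A clean result does not by itself explain a score gap.

```python
from importlib import metadata
from typing import Dict, Optional, Tuple


def parse_pins(requirements_text: str) -> Dict[str, str]:
    """Collect 'name==version' pins, ignoring comments and unpinned lines."""
    pins: Dict[str, str] = {}
    for line in requirements_text.splitlines():
        line = line.split("#", 1)[0].strip()
        if "==" in line:
            name, version = line.split("==", 1)
            pins[name.strip()] = version.strip()
    return pins


def installed_version(name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def mismatches(pins: Dict[str, str]) -> Dict[str, Tuple[str, Optional[str]]]:
    """Map package -> (pinned, installed) for every pin that is not satisfied."""
    result: Dict[str, Tuple[str, Optional[str]]] = {}
    for name, pinned in pins.items():
        found = installed_version(name)
        if found != pinned:
            result[name] = (pinned, found)
    return result


# Example: check a single pin taken from the report above.
for package, (wanted, found) in mismatches(parse_pins("transformers==2.8.0")).items():
    print(f"{package}: pinned {wanted}, installed {found}")
```

To check a whole file, pass its contents to `parse_pins`, for example `parse_pins(open("requirements.txt").read())`.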

AttributeError: 'SimBART' object has no attribute 'shared' (End2End Bart)

I tried to run the code with BART using MBartForConditionalGeneration.

In Bart.py I removed these lines:

class MiniBART(MBartModel):
    def __init__(self, config):
        super().__init__(config)
        self.dst_decoder = type(self.decoder)(config, self.shared)
        self.dst_decoder.load_state_dict(self.decoder.state_dict())

    def tie_decoder(self):
        self.shared.padding_idx = self.config.pad_token_id
        self.dst_decoder = type(self.decoder)(self.config, self.shared)
        self.dst_decoder.load_state_dict(self.decoder.state_dict())

and used these lines instead:

class SimBART(MBartForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
    def tie_decoder(self):
        pass

but I got an error when running:
python train.py --mode train --context_window 2 --pretrained_checkpoint facebook/mbart-large-50 --gradient_accumulation_steps 8 --lr 3e-5 --back_bone bart --cfg seed=557 batch_size=8

The error is:

Traceback (most recent call last):
  File "C:\Users\E\train.py", line 363, in <module>
    main()
  File "C:\Users\E\train.py", line 351, in main
    m = Model(args)
  File "C:\Users\train.py", line 35, in __init__
    self.model =  SimBART.from_pretrained(args.model_path if test else 'facebook/mbart-large-50')
  File "C:\Users\E\miniconda3\envs\torch\lib\site-packages\transformers\modeling_utils.py", line 1224, in from_pretrained
    model.tie_weights()
  File "C:\Users\E\miniconda3\envs\torch\lib\site-packages\transformers\modeling_utils.py", line 522, in tie_weights
    output_embeddings = self.get_output_embeddings()
  File "C:\Users\E\BART.py", line 193, in get_output_embeddings
    return _make_linear_from_emb(self.shared)  # make it on the fly
  File "C:\Users\E\miniconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py", line 947, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'SimBART' object has no attribute 'shared'

What can I do to fix this error?
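
One reading of the traceback is that `get_output_embeddings` in BART.py looks up `self.shared` directly on the class. In transformers 4.x the shared embedding of `MBartForConditionalGeneration` belongs to its inner `model` attribute (an `MBartModel`), not to the wrapper itself. The toy classes below are plain-Python stand-ins, not the transformers classes, and the two method names are hypothetical. They only illustrate why the lookup fails one level up and succeeds one level down.

```python
class BaseModel:
    """Stand-in for a bare encoder-decoder such as MBartModel."""

    def __init__(self):
        # In the real model this would be an embedding layer.
        self.shared = "embedding-table"


class ForGeneration:
    """Stand-in for a generation wrapper such as MBartForConditionalGeneration."""

    def __init__(self):
        # The wrapper holds the base model; the embedding lives inside it.
        self.model = BaseModel()

    def get_output_embeddings_broken(self):
        # Fails: 'shared' is not an attribute of the wrapper.
        return self.shared

    def get_output_embeddings_fixed(self):
        # Works: reach through the wrapper to the base model.
        return self.model.shared


wrapper = ForGeneration()
try:
    wrapper.get_output_embeddings_broken()
except AttributeError as err:
    print("broken lookup:", err)

print("fixed lookup:", wrapper.get_output_embeddings_fixed())
```

If this reading is right, code written against the bare model (such as the `self.shared` reference at BART.py line 193) would need to go through `self.model` once the parent class is changed to the generation wrapper. That is an inference from the traceback rather than a confirmed fix.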

Unable to train: numpy.AxisError: axis 1 is out of bounds for array of dimension 1

Hello again 🙂

I am trying to run the code in this repository, without success. I installed all requirements, then updated and ran setup.sh.
I would like to train t5-small, so I used the suggested command:

python train.py --mode train --context_window 2 --pretrained_checkpoint t5-small --cfg seed=557 batch_size=32

However, it fails at startup with:

utils.py:448: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
Traceback (most recent call last):
  File "train.py", line 359, in <module>
    main()
  File "train.py", line 348, in main
    m.train()
  File "train.py", line 67, in train
    inputs = self.reader.convert_batch(turn_batch, py_prev, first_turn=first_turn, dst_start_token=self.model.config.decoder_start_token_id)
  File "utils.py", line 448, in convert_batch
    inputs["response_input"] = torch.tensor( np.concatenate( ( np.array(batch['input_pointer']), response_input[:,:-1]), axis=1 ) ,dtype=torch.long)
  File "<__array_function__ internals>", line 6, in concatenate
numpy.AxisError: axis 1 is out of bounds for array of dimension 1

The first array has shape (32,), the second (32, 40).

These are my installed packages:

blis==0.4.1
certifi==2020.12.5
chardet==4.0.0
click==7.1.2
cymem==2.0.5
en-core-web-sm==2.2.5
filelock==3.0.12
idna==2.10
importlib-metadata==3.10.0
joblib==1.0.1
murmurhash==1.0.5
nltk==3.4.5
numpy==1.20.2
packaging==20.9
plac==1.1.3
preshed==3.0.5
pyparsing==2.4.7
regex==2021.3.17
requests==2.25.1
sacremoses==0.0.43
sentencepiece==0.1.95
six==1.15.0
spacy==2.2.2
srsly==1.0.5
thinc==7.3.1
tokenizers==0.10.1
torch==1.4.0
tqdm==4.59.0
transformers==4.4.2
typing-extensions==3.7.4.3
urllib3==1.26.4
wasabi==0.8.2
zipp==3.4.1

Am I missing something?
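
The shape mismatch reported above, a `(32,)` array concatenated with a `(32, 40)` array along axis 1, can be reproduced in isolation. The sketch below assumes the first array became one-dimensional because its rows had unequal lengths, which is what the `VisibleDeprecationWarning` about ragged sequences suggests. The sample data and the `pad_rows` helper are hypothetical and not part of MinTL. Whether padding is the right remedy depends on why the rows differ in the first place. If the preprocessed data or a library version is at fault, fixing that would be the better route.

```python
import numpy as np

# Rows of unequal length, standing in for batch['input_pointer'] (assumption).
pointer_rows = [[1, 0, 0], [0, 1], [1, 1, 0, 0]]
response_input = np.zeros((3, 5), dtype=np.int64)

# A ragged list cannot form a 2-D array; with dtype=object NumPy builds a
# 1-D array whose elements are the Python lists.
ragged = np.array(pointer_rows, dtype=object)
print(ragged.shape)  # (3,)

try:
    np.concatenate((ragged, response_input[:, :-1]), axis=1)
except ValueError as err:  # NumPy's AxisError is a subclass of ValueError
    print(type(err).__name__, err)


def pad_rows(rows, pad_value=0):
    """Right-pad every row to the length of the longest row (hypothetical helper)."""
    width = max(len(row) for row in rows)
    return np.array(
        [list(row) + [pad_value] * (width - len(row)) for row in rows],
        dtype=np.int64,
    )


padded = pad_rows(pointer_rows)
merged = np.concatenate((padded, response_input[:, :-1]), axis=1)
print(merged.shape)  # (3, 8)
```

Printing `[len(row) for row in batch['input_pointer']]` just before the failing line in utils.py would show whether the rows really differ in length.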
