This project forked from huybery/r2sql

🌶️ R²SQL: "Dynamic Hybrid Relation Network for Cross-Domain Context-Dependent Semantic Parsing." (AAAI 2021)

Home Page: https://arxiv.org/abs/2101.01686

R²SQL

The PyTorch implementation of the paper "Dynamic Hybrid Relation Network for Cross-Domain Context-Dependent Semantic Parsing" (AAAI 2021).

Requirements

The model was tested with Python 3.6 and the following requirements:

torch==1.0.0
transformers==2.10.0
sqlparse
pymysql
progressbar
nltk
numpy
six
spacy

All experiments on the SParC and CoSQL datasets were run on an NVIDIA V100 GPU with 32 GB of memory.

  • Tip: A GPU with 16 GB of memory may hit out-of-memory errors.

Setup

The SParC and CoSQL experiments live in two separate folders. Download each dataset from [SParC | CoSQL] into the corresponding {sparc|cosql}/data folder. Other related data files can be downloaded from EditSQL. Then download the SQLite database files from [here] and place them in data/database.

Download the pretrained BERT model from [here] and save it as model/bert/data/annotated_wikisql_and_PyTorch_bert_param/pytorch_model_uncased_L-12_H-768_A-12.bin.

Download the GloVe embeddings file (glove.840B.300d.txt) and set GLOVE_PATH to its location in all scripts.

Download the reranker models from [SParC reranker | CoSQL reranker] and save them as submit_models/reranker_roberta.pt. In addition, the roberta-base model can be downloaded from here and placed in ./[sparc|cosql]/local_param/.
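To sanity-check the downloaded embeddings before training, the file can be parsed directly. The following is a minimal sketch, not the loader used in this repo: the function name `load_glove` and its `vocab` filter are hypothetical, and it assumes the standard GloVe text format (one token followed by space-separated floats per line). Some tokens in glove.840B.300d.txt contain spaces, so the vector is taken from the end of the line.

```python
from typing import Dict, List, Optional, Set


def load_glove(path: str, dim: int = 300,
               vocab: Optional[Set[str]] = None) -> Dict[str, List[float]]:
    """Parse a GloVe text file into a {token: vector} dictionary.

    Hypothetical helper for illustration; `path` would be your GLOVE_PATH.
    """
    embeddings: Dict[str, List[float]] = {}
    with open(path, encoding="utf-8") as handle:
        for line in handle:
            parts = line.rstrip("\n").split(" ")
            if len(parts) <= dim:
                continue  # skip blank or malformed lines
            # The last `dim` fields are the vector; the rest is the token,
            # which may itself contain spaces.
            word = " ".join(parts[:-dim])
            if vocab is not None and word not in vocab:
                continue  # keep memory low by loading only needed tokens
            embeddings[word] = [float(x) for x in parts[-dim:]]
    return embeddings
```

Passing a `vocab` set restricts loading to the tokens you need, which helps because the full file is several gigabytes.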

Usage

Train the model from scratch:

./sparc_train.sh

Test the model on a specific checkpoint:

./sparc_test.sh

The dev prediction file will then appear in the results folder, with a name like save_%d_predictions.json.

Get the evaluation result from the prediction file:

./sparc_evaluate.sh

The final result will appear in the results folder, in a file named *.eval.

The CoSQL experiments can be reproduced in the same way.

You can download our trained checkpoint and results here:

Reranker

If you want to train your own reranker model, you can download the training file from here:

You can then train, test, and predict with it:

train:

python -m reranker.main --train --batch_size 64 --epoches 50

test:

python -m reranker.main --test --batch_size 64

predict:

python -m reranker.predict

Improvements

We have improved on the original version (described in the paper) and obtained further performance gains 🥳!

Compared with the original version, we made the following improvements:

  • Added a self-ensemble strategy for prediction, which uses checkpoints from different epochs to produce the final result. To make this strategy easy to apply, we removed the task-related representation from the Reranker module.
  • Removed the decay function in DCRI. We found that DCRI is unstable with the decay function, so we let DCRI degenerate into vanilla cross attention.
  • Replaced the BERT-based model with a RoBERTa-based model in the Reranker module.
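For readers unfamiliar with the term, cross attention lets each vector in one sequence take a weighted average over the vectors of another sequence. The sketch below is a generic scaled dot-product version in pure Python, without learned projections or multiple heads. It is an illustration under that assumption, not the DCRI code in this repo, and the names `cross_attention` and `softmax` are chosen here for the example.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax over a list of scores."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(queries: Matrix, keys: Matrix, values: Matrix) -> Matrix:
    """Scaled dot-product attention from `queries` over `keys`/`values`.

    Illustrative only: no decay term, no learned projections.
    """
    scale = math.sqrt(len(keys[0]))
    width = len(values[0])
    output: Matrix = []
    for q in queries:
        # Similarity of this query to every key.
        scores = [sum(a * b for a, b in zip(q, k)) / scale for k in keys]
        weights = softmax(scores)
        # Weighted average of the value vectors.
        output.append([sum(w * v[j] for w, v in zip(weights, values))
                       for j in range(width)])
    return output
```

In a text-to-SQL setting the queries might be utterance token vectors and the keys and values schema item vectors, although the exact pairing used by the model is not stated here.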

The final performance comparison on the dev sets is as follows:

Model                      SParC QM  SParC IM  CoSQL QM  CoSQL IM
EditSQL                    47.2      29.5      39.9      12.3
R²SQL v1 (original paper)  54.1      35.2      45.7      19.5
R²SQL v2 (this repo)       54.0      35.2      46.3      19.5
R²SQL v2 + ensemble        55.1      36.8      47.3      20.9
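The ensemble row combines predictions from checkpoints saved at different epochs. This README does not spell out how the predictions are combined, so the sketch below assumes a simple majority vote per question, with ties going to the earliest checkpoint. The function names `ensemble_vote` and `normalize_sql` are hypothetical, and the repo itself may instead rely on the reranker to pick the final query.

```python
from collections import Counter
from typing import List, Sequence


def normalize_sql(query: str) -> str:
    """Lowercase and collapse whitespace, used only to compare candidates."""
    return " ".join(query.lower().split())


def ensemble_vote(predictions_per_checkpoint: Sequence[Sequence[str]]) -> List[str]:
    """Pick, for each question, the SQL string most checkpoints agree on.

    Assumed strategy for illustration: majority vote, ties broken in favor
    of the checkpoint listed first.
    """
    final: List[str] = []
    # zip(*...) groups the candidates for the same question together.
    for candidates in zip(*predictions_per_checkpoint):
        counts = Counter(normalize_sql(c) for c in candidates)
        best = max(counts.values())
        winner = next(c for c in candidates
                      if counts[normalize_sql(c)] == best)
        final.append(winner)  # keep the original, un-normalized string
    return final
```

Each inner sequence holds one checkpoint's predictions in the same question order, for example the queries read from several save_%d_predictions.json files.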

Citation

Please star this repo and cite the paper if you use it in your work.

Acknowledgments

This implementation is based on "Editing-Based SQL Query Generation for Cross-Domain Context-Dependent Questions" (EMNLP 2019).
