
DPOT: Auto-Regressive Denoising Operator Transformer for Large-Scale PDE Pre-Training (ICML'2024)

Code for the paper DPOT: Auto-Regressive Denoising Operator Transformer for Large-Scale PDE Pre-Training (ICML'2024). It pre-trains neural operator transformers (from 7M to 1B parameters) on multiple PDE datasets. Pre-trained weights can be found at https://huggingface.co/hzk17/DPOT.


Our pre-trained DPOT achieves state-of-the-art performance on multiple PDE datasets and can be fine-tuned for different types of downstream PDE problems.


Usage

Pre-trained models

We provide five pre-trained checkpoints of different sizes. The weights are available at https://huggingface.co/hzk17/DPOT.

Size    Attention dim  MLP dim  Layers  Heads  Model size
Tiny    512            512      4       4      7M
Small   1024           1024     6       8      30M
Medium  1024           4096     12      8      122M
Large   1536           6144     24      16     509M
Huge    2048           8092     27      8      1.03B

Here is example code for loading a pre-trained model (the Tiny configuration from the table above):

import torch
from models.dpot import DPOTNet

model = DPOTNet(img_size=128, patch_size=8, mixing_type='afno', in_channels=4, in_timesteps=10,
                out_timesteps=1, out_channels=4, normalize=False, embed_dim=512, modes=32,
                depth=4, n_blocks=4, mlp_ratio=1, out_layer_dim=32, n_cls=12)
model.load_state_dict(torch.load('model_Ti.pth')['model'])  # weights are stored under the 'model' key
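
As a quick sanity check, you can push a random input through the loaded model. The input layout below, (batch, height, width, timesteps, channels), is an assumption inferred from the constructor arguments (img_size=128, in_timesteps=10, in_channels=4); verify it against the forward signature in models/dpot.py.

# Assumed input layout (batch, x, y, timesteps, channels) -- check models/dpot.py
x = torch.randn(1, 128, 128, 10, 4)
model.eval()
with torch.no_grad():
    out = model(x)
# The model may return auxiliary outputs (e.g. class logits, given n_cls=12)
print(out.shape if torch.is_tensor(out) else [o.shape for o in out])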
Dataset Protocol

All datasets are stored in HDF5 format and contain a data field. Some datasets are split across individual HDF5 files; others are stored in a single HDF5 file.
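
For instance, you can inspect a preprocessed file with h5py. The file name below is illustrative; the data field is the one described above.

import h5py

# Inspect a preprocessed dataset file (illustrative file name)
with h5py.File('data/ns2d_fno_1e-5.hdf5', 'r') as f:
    print(list(f.keys()))                    # should include 'data'
    print(f['data'].shape, f['data'].dtype)  # layout varies per dataset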

The script data_generation/preprocess.py preprocesses the datasets from each source. Download the original files from the sources below and preprocess them into the data/ folder.

Dataset Link
FNO data Here
PDEBench data Here
PDEArena data Here
CFDbench data Here

utils/make_master_file.py contains all dataset configurations. When new datasets are added, you should add a configuration dict for them. All paths are stored as relative paths, so the code can run from any location.
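
As a rough sketch of what such an entry might look like (the field names here are hypothetical; copy the structure of an existing entry in utils/make_master_file.py rather than this snippet):

# Hypothetical configuration entry -- field names are illustrative only
DATASET_DICT = {
    'ns2d_fno_1e-5': {
        'path': 'data/ns2d_fno_1e-5.hdf5',  # relative path under the data root
        'n_train': 1000,                    # illustrative split sizes
        'n_test': 200,
        'resolution': 64,
    },
}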

mkdir data
Single GPU Pre-training

train_temporal.py is the single-GPU pre-training script. Start a training run with

python train_temporal.py --model DPOT --train_paths ns2d_fno_1e-5 --test_paths ns2d_fno_1e-5 --gpu 0

Alternatively, write a configuration file such as configs/ns2d.yaml and launch it with trainer.py, which automatically schedules the run on free GPUs:

python trainer.py --config_file ns2d.yaml
Multiple GPU Pre-training
python parallel_trainer.py --config_file ns2d_parallel.yaml
Configuration file

Configuration files are written in YAML; each key sets the corresponding command-line argument. To run multiple tasks from one file, move the swept parameters under the tasks key:

model: DPOT
width: 512
tasks:
  lr: [0.001, 0.0001]
  batch_size: [256, 32]

Submitting this configuration to trainer.py starts two tasks: the lists are paired element-wise, so the first task uses lr=0.001 with batch_size=256 and the second uses lr=0.0001 with batch_size=32.
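
For reference, this element-wise pairing can be reproduced in a few lines of Python. This is a sketch of the behavior described above, not trainer.py's actual implementation:

# Sketch: lists under 'tasks' are paired element-wise, so two 2-element
# lists yield two tasks (not a 2x2 grid). Not the actual trainer.py code.
base = {'model': 'DPOT', 'width': 512}
tasks = {'lr': [0.001, 0.0001], 'batch_size': [256, 32]}
n_tasks = len(next(iter(tasks.values())))
runs = [dict(base, **{k: v[i] for k, v in tasks.items()}) for i in range(n_tasks)]
# runs[0] -> {'model': 'DPOT', 'width': 512, 'lr': 0.001, 'batch_size': 256}
# runs[1] -> {'model': 'DPOT', 'width': 512, 'lr': 0.0001, 'batch_size': 32}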

Requirements

Install the following packages via conda (PyTorch from the pytorch and nvidia channels, the rest from conda-forge):

conda install pytorch==2.0.0 torchvision==0.15.0 torchaudio==2.0.0 pytorch-cuda=11.7 -c pytorch -c nvidia
conda install matplotlib scikit-learn scipy pandas h5py -c conda-forge
conda install timm einops tensorboard -c conda-forge

Code Structure

  • README.md
  • train_temporal.py: main script for single-GPU pre-training of the auto-regressive model
  • trainer.py: framework for automatically scheduling training tasks during parameter tuning
  • utils/
    • criterion.py: relative-error loss functions (see the sketch after this list)
    • griddataset.py: mixed dataset of temporal uniform-grid datasets
    • make_master_file.py: datasets config file
    • normalizer: normalization methods (#TODO: implement instance reversible norm)
    • optimizer: Adam/AdamW/Lamb optimizers with complex-number support
    • utilities.py: other auxiliary functions
  • configs/: configuration files for pre-training or fine-tuning
  • models/
    • dpot.py: DPOT model
    • fno.py: FNO with group normalization
    • mlp.py
  • data_generation/: code for preprocessing data (ask hzk if you want to use it)
    • darcy/
    • ns2d/
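
The relative error implemented in utils/criterion.py refers to the per-sample relative L2 error that is standard in operator learning. A generic sketch of the metric (not a copy of criterion.py):

import torch

def rel_l2_error(pred, target, eps=1e-8):
    # ||pred - target||_2 / ||target||_2 per sample, averaged over the batch
    b = pred.shape[0]
    diff = (pred - target).reshape(b, -1).norm(dim=1)
    ref = target.reshape(b, -1).norm(dim=1)
    return (diff / (ref + eps)).mean()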

Citation

If you use DPOT in your research, please cite it using the following BibTeX entry.

@article{hao2024dpot,
  title={DPOT: Auto-Regressive Denoising Operator Transformer for Large-Scale PDE Pre-Training},
  author={Hao, Zhongkai and Su, Chang and Liu, Songming and Berner, Julius and Ying, Chengyang and Su, Hang and Anandkumar, Anima and Song, Jian and Zhu, Jun},
  journal={arXiv preprint arXiv:2403.03542},
  year={2024}
}


