

QGFN: Controllable Greediness with Action Values


Description

This repository contains the code to run the experiments and visualize the results highlighted in the paper QGFN: Controllable Greediness with Action Values.

Overview

Our codebase builds on a fork of the public Recursion gflownet repo, which provides the environment setup to run the GFlowNet framework on graph domains. Our main edits to the forked repo are found in the following files:

Setup

To set up the project, first run the following commands to install the required base packages:

pip install --no-index torch-scatter -f https://data.pyg.org/whl/torch-1.13.1+cu116.html
pip install --no-index torch-sparse -f https://data.pyg.org/whl/torch-1.13.1+cu116.html
pip install --no-index torch-cluster -f https://data.pyg.org/whl/torch-1.13.1+cu116.html

Then, install the gflownet package from local source:

pip install -e . --find-links https://data.pyg.org/whl/torch-1.13.1+cu116.html
# It may take a while to build wheels for `torch-cluster`, `torch-scatter` & `torch-sparse`

Alternatively, follow the installation guidelines in the public Recursion gflownet repo.

How to run

After setting up the virtual environment, you can try out QGFN by running:

python interactive_script.py

interactive_script.py has three flags, one for each QGFN variant. Setting one of them to True runs the corresponding variant:

  • p_greedy_sample
  • p_of_max_sample
  • p_quantile_sample
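As we read the QGFN paper, each variant combines the GFlowNet forward policy P_F with a learned action-value function Q, and p sets how greedy the combination is. The sketch below computes the resulting action distribution for a single state under that reading. It is not the repository's implementation, and the function name `qgfn_policy` and its variant strings are hypothetical. It also assumes that P_F and Q are NumPy arrays aligned over the same actions and that Q-values are non-negative, which the p-of-max rule relies on.

```python
import numpy as np


def qgfn_policy(pf_probs: np.ndarray, q_values: np.ndarray, p: float, variant: str) -> np.ndarray:
    """Return action probabilities for one state (hypothetical helper, not the repo's API).

    pf_probs: forward-policy probabilities P_F(a|s), summing to 1.
    q_values: action values Q(s, a); assumed non-negative for "p_of_max".
    p: greediness level in [0, 1].
    """
    if variant == "p_greedy":
        # With probability p act greedily w.r.t. Q; otherwise sample from P_F.
        greedy = np.zeros_like(pf_probs)
        greedy[np.argmax(q_values)] = 1.0
        return (1.0 - p) * pf_probs + p * greedy
    if variant == "p_of_max":
        # Keep only actions whose value is at least p times the best value.
        mask = q_values >= p * q_values.max()
    elif variant == "p_quantile":
        # Keep only actions whose value is at or above the p-quantile of Q(s, .).
        mask = q_values >= np.quantile(q_values, p)
    else:
        raise ValueError(f"unknown variant: {variant}")
    # Sample from P_F restricted to the surviving actions, renormalized.
    masked = pf_probs * mask
    total = masked.sum()
    if total == 0.0:
        # Degenerate case: P_F puts no mass on surviving actions; fall back to argmax Q.
        fallback = np.zeros_like(pf_probs)
        fallback[np.argmax(q_values)] = 1.0
        return fallback
    return masked / total
```

For example, with pf = np.array([0.5, 0.3, 0.2]), q = np.array([1.0, 3.0, 2.0]) and p = 0.6, p-greedy gives roughly [0.2, 0.72, 0.08]. The p-of-max rule drops the first action and renormalizes to roughly [0, 0.6, 0.4], while p-quantile keeps only the best action. Setting p = 0 recovers P_F in all three cases, and larger p makes sampling greedier.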

Usage

The current project supports training a mixture policy from scratch and visualizing the results of one or multiple runs in the same plots.

Training from scratch

Our setup is designed for jobs run with the Slurm workload manager. To train a mixture policy from scratch, first edit the utils/template.sh file to customize the Slurm executable script for your GPU environment. Next, to generate the executables for a given job, edit the hyperparameters in the main function of gen.py and run the file from the project root directory.

Example

First, select the task on which you'd like to train the network. Next, set the hyperparameters in gen.py. The main hyperparameters of interest are as follows:

BASE_HPS: Config = {
    "log_dir": 'path.to.log.directory',
    "num_training_steps": 10000,
    # ... other hyperparameters
    "cond": {
        "temperature": {
            "sample_dist": "constant",
            "dist_params": 32.0,        # Set reward exponent beta
        }
    },
    "algo": {
        "p_greedy_sample": True,        # Set to True to run p-greedy sampling
        "p_of_max_sample": False,       # Set to True to run p-of-max sampling
        "p_quantile_sample": False,     # Set to True to run p-quantile sampling
        "p": 0.6,                       # Set value of `p` to control greediness
        "dqn_n_step": 30,               # Set n-step returns for Q-learning
        # ... other hyperparameters
    },
    "task": {
        "qm9": {
            "h5_path": "path.to.dataset/qm9.h5",                 # Set path to qm9 dataset
            "model_path": "path.to.model/mxmnet_gap_model.pt"    # Set path to MXMNet proxy model for qm9 task reward
        },
        "bitseq": {
            "variant": "prepend-append",                         # Bit sequence generation mode: `prepend-append` or `autoregressive`
            "modes_path": "path.to.ref.sequences/modes.pkl",     # Set path to the pre-determined set of modes for the task
            # ... other hyperparameters
        }
    }
}
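The How to run section suggests that each run should enable exactly one QGFN variant. If so, a quick check on the "algo" section can catch a misconfigured grid before jobs are submitted. The helper below is a hypothetical sketch and is not part of gen.py. It assumes only the key names shown in BASE_HPS above and that p is a greediness level in [0, 1].

```python
from typing import Any, Dict

# Key names taken from the "algo" section of BASE_HPS.
SAMPLING_FLAGS = ("p_greedy_sample", "p_of_max_sample", "p_quantile_sample")


def check_algo_config(algo: Dict[str, Any]) -> str:
    """Validate the QGFN sampling settings and return the enabled flag (hypothetical helper)."""
    enabled = [flag for flag in SAMPLING_FLAGS if algo.get(flag, False)]
    if len(enabled) != 1:
        raise ValueError(f"expected exactly one sampling flag set to True, got {enabled}")
    p = algo.get("p")
    # `p is None` short-circuits, so the range check only runs on a real value.
    if p is None or not 0.0 <= p <= 1.0:
        raise ValueError(f"'p' must be in [0, 1], got {p!r}")
    return enabled[0]
```

Calling check_algo_config(BASE_HPS["algo"]) on the config above would return "p_greedy_sample".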

Finally, you may set lists of hyperparameter values in the main() function of gen.py to run a grid search. To generate the runs, run the file:

python gen.py
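A grid search produces one run per combination of the listed values. The sketch below illustrates that expansion under our own assumptions: the grid is a dict mapping dotted key paths (e.g. "algo.p") to lists of values, and each combination is written into a deep copy of the base config. The names `expand_grid` and `set_by_path` are hypothetical, and the grid format actually used in gen.py may differ.

```python
import copy
import itertools
from typing import Any, Dict, List


def set_by_path(config: Dict[str, Any], dotted_key: str, value: Any) -> None:
    """Set config["a"]["b"] = value for dotted_key "a.b", creating nested dicts as needed."""
    *parents, leaf = dotted_key.split(".")
    node = config
    for key in parents:
        node = node.setdefault(key, {})
    node[leaf] = value


def expand_grid(base: Dict[str, Any], grid: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Return one config per combination in the Cartesian product of the grid values."""
    keys = list(grid)
    configs = []
    for values in itertools.product(*(grid[k] for k in keys)):
        config = copy.deepcopy(base)  # never mutate the shared base config
        for key, value in zip(keys, values):
            set_by_path(config, key, value)
        configs.append(config)
    return configs
```

For instance, a grid of {"algo.p": [0.4, 0.6, 0.8], "algo.dqn_n_step": [10, 30]} expands to 3 × 2 = 6 configs. Each config would then correspond to one run folder.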

For every hyperparameter combination you specified, this will generate a corresponding run folder at jobs/<current_date>/<run_id-current_time>. This folder will contain the following files:

  • run.sh: the Slurm executable script for the job
  • run.py: the main executable script for the job
  • howto.txt: a text file containing the command to submit the job to Slurm
  • config.json: a JSON file containing the hyperparameters for the job
  • run_object.json: a JSON file containing the serialized run object, which can be used to re-instantiate it for downstream analysis and plotting
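Because config.json is plain JSON, you can inspect a run's hyperparameters without re-running gen.py. Below is a minimal standard-library sketch that assumes the run folder layout listed above. `load_run_config` is a hypothetical helper name.

```python
import json
from pathlib import Path
from typing import Any, Dict


def load_run_config(run_dir: str) -> Dict[str, Any]:
    """Read the hyperparameters stored in <run_dir>/config.json."""
    with (Path(run_dir) / "config.json").open() as f:
        return json.load(f)
```

For example, load_run_config("jobs/2023-11-19/001-01-54-02") would return the stored hyperparameters as nested dicts, assuming that run folder exists.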

To submit the job to Slurm, run the command specified in howto.txt from the run folder. For example, if the command is sbatch --array=0-4 run.sh config, run the following:

cd jobs/<current_date>/<run_id-current_time>
sbatch --array=0-4 run.sh config

Plotting results

This repository also provides a variety of plotting functions to visualize and compare results across multiple runs. The main plotting script is vis.py, and it currently supports the following plot types:

  • AVERAGE_REWARD: plots the average reward per episode over the number of sampled trajectories during training
  • NUMBER_OF_MODES: plots the number of unique modes above a certain reward threshold found by the mixture policy over the course of training
  • TOP_K_REWARD: plots the average reward of the k highest-reward trajectories in the run
  • TOP_K_SIMILARITY: plots the Tanimoto similarity (or another similarity measure) among the k highest-reward trajectories in the run
  • NUMBER_OF_MODES_AT_K: plots the number of modes at the k-th trajectory, or over the last k trajectories, of the run
  • AVERAGE_REWARD_AT_K: plots the average reward at the k-th trajectory, or over the last k trajectories, of the run
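The NUMBER_OF_MODES and TOP_K_SIMILARITY metrics both depend on a similarity measure between samples. The sketch below shows the mode-counting rule commonly used in GFlowNet molecule experiments. A sample counts as a new mode if its reward exceeds min_reward and its Tanimoto similarity to every previously found mode is below sim_threshold. This is an illustration, not the repository's code. It represents each fingerprint as a Python set of on-bit indices rather than an RDKit fingerprint, and `tanimoto` and `count_modes` are hypothetical names. The defaults mirror the get_modes signature in utils/plotting.py.

```python
from typing import List, Set, Tuple


def tanimoto(a: Set[int], b: Set[int]) -> float:
    """Tanimoto similarity |A ∩ B| / |A ∪ B| between two sets of on-bits."""
    union = len(a | b)
    return len(a & b) / union if union else 1.0


def count_modes(samples: List[Tuple[Set[int], float]], min_reward: float = 0.9,
                sim_threshold: float = 0.7) -> List[int]:
    """Return the running number of modes after each (fingerprint, reward) sample."""
    modes: List[Set[int]] = []
    counts: List[int] = []
    for fingerprint, reward in samples:
        # New mode: high reward and dissimilar from every mode found so far.
        is_new = reward > min_reward and all(
            tanimoto(fingerprint, mode) < sim_threshold for mode in modes
        )
        if is_new:
            modes.append(fingerprint)
        counts.append(len(modes))
    return counts
```

The running counts returned by count_modes correspond to a NUMBER_OF_MODES curve over the sampled trajectories.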

To produce a given plot, add the paths to the run folders (see above) of the runs you want to compare to the runs list in the main function of vis.py, along with a name and color for each run. For example, if you want to compare the results of two runs with paths jobs/2023-11-19/001-01-54-02 and jobs/2023-12-09/002-23-23-29, you could write:

runs = [
    {
        'path': 'jobs/2023-11-19/001-01-54-02',
        'name': 'run1',
        'color': 'blue'
    },
    {
        'path': 'jobs/2023-12-09/002-23-23-29',
        'name': 'run2',
        'color': 'red'
    },
]

Then, specify the plot type you want to produce in the plot_type variable. Available plot types are specified in the PlotType enum in utils/plotting.py. For example, to produce a number-of-modes plot, you would write:

plot_type = PlotType.NUMBER_OF_MODES

Finally, you can override the default plot parameters in the plot_params dictionary. For example, if you want to change the title of the plot, you could write:

plot_params = {
    **DEFAULT_SHARED_PARAMS,
    **DEFAULT_UNIQUE_PARAMS[plot_type],

    # Add your custom parameters here.
    'title': 'My custom title'
}
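In the plot_params dictionary above, entries unpacked later override earlier ones. Your custom keys therefore take precedence over the plot-specific defaults, which in turn take precedence over the shared defaults. The toy illustration below uses made-up stand-ins for the two default dictionaries. The real DEFAULT_SHARED_PARAMS and DEFAULT_UNIQUE_PARAMS are defined in vis.py and contain different keys.

```python
# Hypothetical stand-ins for the dictionaries defined in vis.py.
DEFAULT_SHARED_PARAMS = {"title": "Default title", "figsize": (6, 4)}
DEFAULT_UNIQUE_PARAMS = {"NUMBER_OF_MODES": {"title": "Number of modes", "min_reward": 0.9}}

plot_type = "NUMBER_OF_MODES"  # a plain string here instead of the PlotType enum
plot_params = {
    **DEFAULT_SHARED_PARAMS,
    **DEFAULT_UNIQUE_PARAMS[plot_type],
    "title": "My custom title",  # custom keys come last, so they win
}
print(plot_params)
# {'title': 'My custom title', 'figsize': (6, 4), 'min_reward': 0.9}
```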

The full specification of modifiable parameters for each plot type can be found in the DEFAULT_UNIQUE_PARAMS and DEFAULT_SHARED_PARAMS dictionaries in vis.py.

Once you have specified the runs, plot type, and plot parameters, you must run the script from the project root directory, and your plot will be saved in the results folder under a unique identifier based on the current date and plot type.

python vis.py

Plotting for other task environments

The script generation and plotting code above was written to support the standard molecules environment in the gflownet package. However, the plotting code can be adapted to other environments by inheriting from the PlottableRunObject class in utils/plotting.py. This class defines the interface for a plottable run object: a few simple methods that extract the relevant data from the runs into a universal format, which is then passed to the downstream PlotConfig class for standard plotting. The methods that need to be implemented are:

from typing import List

import pandas as pd

from utils.plotting import PlottableRunObject


class MyRunObject(PlottableRunObject):  # hypothetical name for your environment's run object
    def load_raw_data(self, sqlite_cols: List[str]):
        """
        Loads the raw run data from the runs. Please override this method
        in your child class depending on your data format. The base implementation
        supports the molecules environment, which reads data from a sqlite database
        using the specified columns.
        """
        pass

    def get_average_reward(self, save_df: bool = False) -> pd.DataFrame:
        """Returns a pandas dataframe with the average reward for each worker at each step"""
        pass

    @staticmethod
    def is_new_mode(obj, modes, sim_threshold=0.7) -> bool:
        """Returns True if obj is a new mode, False otherwise"""
        pass

    def get_modes(self, min_reward: float = 0.9, sim_threshold: float = 0.7, save_df: bool = False) -> pd.DataFrame:
        """Returns a pandas dataframe with the number of modes with reward > min_reward
        and similarity below sim_threshold to every previously found mode, for each
        worker at each step"""
        pass

    def get_top_k_reward(self):
        """Returns a pandas dataframe with the top k reward for each worker at each step"""
        pass

    def get_top_k_similarity(self):
        """Returns the data for the TOP_K_SIMILARITY plot"""
        pass

    def get_reward_distribution(self):
        pass
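For a non-molecular environment, most of the work is turning your logs into a DataFrame in load_raw_data and aggregating it in the getters. The standalone sketch below mimics load_raw_data, get_average_reward and get_top_k_reward for rewards that are already held in memory. It does not inherit from PlottableRunObject, so that it runs on its own. The class name, its constructor, the `records` argument, the `k` parameter and the worker/step/reward column names are all assumptions, not the format PlotConfig actually expects.

```python
from typing import Dict, List, Optional

import pandas as pd


class InMemoryRunObject:
    """Hypothetical run object whose samples are already held in memory."""

    def __init__(self, records: List[Dict]):
        self.records = records
        self.df: Optional[pd.DataFrame] = None

    def load_raw_data(self, cols: List[str]) -> None:
        # Build a DataFrame and keep only the requested columns.
        self.df = pd.DataFrame(self.records)[cols]

    def get_average_reward(self) -> pd.DataFrame:
        """Average reward for each worker at each step."""
        return self.df.groupby(["worker", "step"], as_index=False)["reward"].mean()

    def get_top_k_reward(self, k: int = 2) -> pd.DataFrame:
        """Mean of the k highest rewards seen so far, for each worker at each step."""
        rows = []
        for worker, group in self.df.groupby("worker"):
            seen: List[float] = []
            # groupby sorts step keys, so rewards accumulate in step order.
            for step, step_df in group.groupby("step"):
                seen.extend(step_df["reward"].tolist())
                top = sorted(seen, reverse=True)[:k]
                rows.append({"worker": worker, "step": step, "top_k_reward": sum(top) / len(top)})
        return pd.DataFrame(rows)
```

Keeping one row per (worker, step) pair follows the per-worker, per-step format described in the docstrings above.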

Updates

2024-05-15

  • Included hps.yaml configuration file: we have added an hps.yaml file, located in the config directory, that contains all the hyperparameters used in our experiments so that you can easily replicate our training conditions and results.

