Atharva Sehgal, Arya Grayeli, Jennifer J. Sun, Swarat Chaudhuri

[Website][Paper]

This repository contains the code for the paper "Neurosymbolic Grounding for Compositional Generalization in Object-Oriented World Models".

Installation

We use two conda environments for this project: cosmos, which is used for training and includes PyTorch, and cosmos_datagen, which is used for data generation. To create the environments, run the following commands:

$ conda env create --name cosmos -f artifacts/env/cosmos.yml
$ conda env create --name cosmos_datagen -f artifacts/env/cosmos_datagen.yml

By default, all training runs are logged to comet.ml. The comet.ml arguments are loaded from src/comet_init.py, which throws an error if the API key is not set.
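
A minimal pre-flight check you could run before training so that a missing key fails fast instead of partway into a run. This assumes the key is supplied through Comet's standard COMET_API_KEY environment variable; whether src/comet_init.py actually reads it that way is an assumption, so check that file. require_comet_key is a hypothetical helper, not part of the repo.

```shell
# Hypothetical pre-flight check: fail early if no Comet API key is visible.
# Assumption: src/comet_init.py picks up the key from COMET_API_KEY.
require_comet_key() {
  if [ -z "${COMET_API_KEY:-}" ]; then
    echo "COMET_API_KEY is not set; comet.ml logging will fail." >&2
    return 1
  fi
  return 0
}

# Usage:
#   export COMET_API_KEY=<your-key>
#   require_comet_key && python -m src.cosmos.train_transition_model \
#     --config src/cosmos/configs/2d_shape_3obj_entity_comp.py
```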

We need pretrained weights for SAM and CLIP. These are not included in the repo. You can download the weights from the following links:

  • SAM:
  • CLIP:
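
Before running the training pipeline, you might confirm that the weights are where the training steps expect them (artifacts/sam_vit/ and artifacts/clip/). A minimal sketch: check_pretrained_weights is a hypothetical helper, and it only checks that each directory exists and is non-empty, not that the files are the correct checkpoints.

```shell
# Hypothetical helper: verify that the SAM and CLIP weight directories
# used by the training steps exist and contain at least one entry.
# It does not validate the checkpoint files themselves.
check_pretrained_weights() {
  root="${1:-.}"   # repo root; defaults to the current directory
  status=0
  for dir in artifacts/sam_vit artifacts/clip; do
    # `ls -A` prints nothing for an empty directory.
    if [ ! -d "$root/$dir" ] || [ -z "$(ls -A "$root/$dir")" ]; then
      echo "missing or empty: $root/$dir" >&2
      status=1
    fi
  done
  return "$status"
}

# Usage (from the repo root): check_pretrained_weights
```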

Data Generation

We provide two data-generation domains: 2D shapeworld and 3D shapeworld.

2D shapeworld

$ conda activate cosmos_datagen
# Step 1: Generates all combinations of valid 2D shape configurations.
# Stored in datasets/shapeworld_2d/{n_objects}_objects/sampled_scenes.pkl 
(cosmos_datagen) $ python -m src.datagen.shapeworld_2d.sample_scenes --params src/cosmos/configs/2d_shape_3obj_entity_comp.py
# Step 2: Eyeball the generated scenes.
(cosmos_datagen) $ python -m src.datagen.shapeworld_2d.visualize --params src/cosmos/configs/2d_shape_3obj_entity_comp.py
# Step 3: Generates training, validation, and test set with the specified params.
# The validation and test set use the same scene configurations.
# Warning! CPU intensive, can take a few hours.
# The generated data is stored in datasets/shapeworld_2d/{n_objects}_objects/{composition_type}/{split}/seq_{i}.npz
(cosmos_datagen) $ python -m src.datagen.shapeworld_2d.generate --params src/cosmos/configs/2d_shape_3obj_entity_comp.py
# Step 4: Eyeball compositional split statistics for each split.
(cosmos_datagen) $ python -m src.datagen.shapeworld_2d.stats --params src/cosmos/configs/2d_shape_3obj_entity_comp.py --split train

For relational composition, run Step 3 and Step 4 with src/cosmos/configs/2d_shape_3obj_rel_comp_sticky.py or src/cosmos/configs/2d_shape_3obj_rel_comp_team.py.
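
A sketch that runs Step 3 and Step 4 for both relational configs back to back. run_or_echo and generate_relational are hypothetical helpers, and the config paths assume the relational configs live in src/cosmos/configs/ alongside the others. With DRY_RUN=1 the functions only print the commands, which is a cheap way to check them before committing hours of CPU time.

```shell
# Run a command, or just print it when DRY_RUN=1.
run_or_echo() {
  if [ "${DRY_RUN:-0}" = 1 ]; then
    echo "$*"
  else
    "$@"
  fi
}

# Hypothetical driver: generate data (Step 3) and print split statistics
# (Step 4) for each relational config. Run inside cosmos_datagen.
generate_relational() {
  for cfg in src/cosmos/configs/2d_shape_3obj_rel_comp_sticky.py \
             src/cosmos/configs/2d_shape_3obj_rel_comp_team.py; do
    run_or_echo python -m src.datagen.shapeworld_2d.generate --params "$cfg" || return 1
    run_or_echo python -m src.datagen.shapeworld_2d.stats --params "$cfg" --split train || return 1
  done
}

# Usage: DRY_RUN=1 generate_relational   # preview
#        generate_relational             # actually run
```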

3D shapeworld

# Download Mujoco v1.26 and place it in artifacts/mujoco_bin/...
$ conda activate cosmos_datagen
$ bash -i scripts/setup_mujoco.sh
# This generates a training, validation, and test set with the specified params.
# Warning! CPU intensive, can take a few hours.
# The generated data is stored in datasets/shapeworld_3d/{n_objects}_objects/{composition_type}/{split}/seq_{i}.npz
(cosmos_datagen) $ python -m src.datagen.shapeworld_3d.generate --params src/cosmos/configs/3d_shape_3obj_entity_comp.py

The 3D dataset isn't used in the paper, but we include it as a "challenge" domain for future work.
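
Since generation can take hours, it can help to confirm that every split was written. The sketch below counts the seq_{i}.npz files per split under the documented layout datasets/{domain}/{n_objects}_objects/{composition_type}/{split}/. count_sequences is a hypothetical helper; it assumes no particular split or composition-type names and only counts files, without validating their contents.

```shell
# Hypothetical helper: print "<split> <count>" for every split directory
# under datasets/<domain>/<n_objects>_objects/<composition_type>/.
count_sequences() {
  base="datasets/$1/$2_objects/$3"
  if [ ! -d "$base" ]; then
    echo "no such directory: $base" >&2
    return 1
  fi
  for split_dir in "$base"/*/; do
    [ -d "$split_dir" ] || continue          # glob matched nothing
    n=$(find "$split_dir" -maxdepth 1 -name 'seq_*.npz' | wc -l | tr -d ' ')
    echo "$(basename "$split_dir") $n"
  done
}

# Usage (from the repo root):
#   count_sequences shapeworld_2d 3 <composition_type>
#   count_sequences shapeworld_3d 3 <composition_type>
```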

Training COSMOS

# Download SAM ViT weights and place them in artifacts/sam_vit/...
# Download CLIP weights and place them in artifacts/clip/...
$ conda activate cosmos
# Step 1: Extract the object segmentations with SAM.
# Warning! GPU intensive, can take a few hours.
# The extracted masks follow a directory structure similar to that of the generation scripts.
# The masks are stored in datasets/segment_anything_masks/shapeworld_2d/{n_objects}_objects/{composition_type}/{split}/seq_{i}_masks.npz
(cosmos) $ CUDA_VISIBLE_DEVICES=0 python -m src.cosmos.extract_object_masks_sam --config src/cosmos/configs/2d_shape_3obj_entity_comp.py
# Step 2: Warm-start the ResNet encoder and spatial decoder.
# The artifacts created during training are stored in checkpoints/cosmos/shapeworld_2d/{n_objects}_objects/{composition_type}/...
# Notably the best model is stored in checkpoints/cosmos/shapeworld_2d/{n_objects}_objects/{composition_type}/best_model.pt
(cosmos) $ CUDA_VISIBLE_DEVICES=0 python -m src.cosmos.train_slot_autoencoder --config src/cosmos/configs/2d_shape_3obj_entity_comp.py
# Step 3: Train COSMOS. This uses the train and validation set.
(cosmos) $ CUDA_VISIBLE_DEVICES=0 python -m src.cosmos.train_transition_model --config src/cosmos/configs/2d_shape_3obj_entity_comp.py
# Step 3 (alternative): Train COSMOS with multiple GPUs. I added this for convenience.
(cosmos) $ CUDA_VISIBLE_DEVICES=0,1,2,3 python -m src.cosmos.train_transition_model --gpus 4 --config src/cosmos/configs/2d_shape_3obj_entity_comp.py
# Step 4: Evaluate COSMOS. This runs the test set.
(cosmos) $ CUDA_VISIBLE_DEVICES=0 python -m src.cosmos.evaluate_transition_model --config src/cosmos/configs/2d_shape_3obj_entity_comp.py
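
Step 1 writes one mask file per generated sequence, so before moving on to Step 2 you can list the sequences SAM has not processed yet by mapping each seq_{i}.npz to its documented mask path. A sketch: mask_path_for and missing_masks are hypothetical helpers, and they assume paths are given relative to the repo root (starting with datasets/).

```shell
# Map a sequence file to the mask file Step 1 is documented to write:
#   datasets/shapeworld_2d/3_objects/<comp>/<split>/seq_0.npz
#   -> datasets/segment_anything_masks/shapeworld_2d/3_objects/<comp>/<split>/seq_0_masks.npz
mask_path_for() {
  rel="${1#datasets/}"                                   # strip leading "datasets/"
  echo "datasets/segment_anything_masks/${rel%.npz}_masks.npz"
}

# Print every sequence under the given directory that has no mask file yet.
missing_masks() {
  find "$1" -name 'seq_*.npz' ! -name '*_masks.npz' | sort | while read -r seq; do
    [ -f "$(mask_path_for "$seq")" ] || echo "$seq"
  done
}

# Usage (from the repo root): missing_masks datasets/shapeworld_2d/3_objects
```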

(Additional) Training COSMOS (with slot attention)

This repo also contains code to try out BO-QSA for unsupervised segmentation learning. The original repo is here: ``.

$ conda activate cosmos
# Step 1: Warm start the object detector:
(cosmos) $ python -m src.cosmos.train_slot_autoencoder --config src/cosmos/configs/3d_shape_3obj_entity_comp_withQSA.py
# Skip Step 2 and train COSMOS directly with the warm-started object detector...

Feel free to contact me @ [email protected] with any questions!
