
Argo_Llama 🚗 🦙

The goal of this project is to predict future trajectories using the simplest possible input representation and network. We aim to minimize domain-specific knowledge and rely on the transformer model to handle all aspects of the task.

  • We chose the Argoverse v1 dataset because it is easy to use.
  • The model is a pure self-attention transformer (adapted from LLaMA2.c).
  • A checkpoint is available on Hugging Face.

Usage 🛠️

  • Install dependencies:

    conda create -n argo_llama python=3.10
    conda activate argo_llama
    pip install -r requirements.txt
  • Install the Argoverse API.

  • Prepare input data for training:

    conda activate argo_llama
    python3 run.py prep
  • Train the model. For a small debug run on a single GPU, for example:

    python3 train.py

    To run with DDP on 4 GPUs on a single node, for example:

    torchrun --standalone --nproc_per_node=4 train.py
  • Evaluate the model and visualize the results (see the sample data and checkpoints):

    python3 run.py viz

Model:

Input/output representation:

Input shape: [batch_size, seq_len, c]

  • "c" is the channel dimension. Object paths and map elements are both represented as typed line segments [x0, y0, x1, y1, type], where (x0, y0) is the start point and (x1, y1) is the end point.
  • "x" and "y" are expressed relative to the AV's position at prediction time.
  • Past object paths are sampled at a fixed time step, but timestamps are not included in the input, so the model has to learn the timing on its own.
  • "seq_len" is the sequence length. To keep the input size fixed, sequences longer than the maximum length are truncated, and shorter ones are padded with [0, 0, 0, 0, -1] tokens. No attention mask is applied to the padding; the model learns to recognize it as a special token.
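A minimal sketch of this input encoding in plain Python, assuming each object path or map element arrives as a list of (x, y) world-coordinate points plus an integer type id. The function name `encode_scene` and the type ids used below are hypothetical; the real preprocessing (`run.py prep`) may sample, order, or type the segments differently.

```python
PAD_TOKEN = [0.0, 0.0, 0.0, 0.0, -1.0]  # padding token, as described above


def encode_scene(polylines, av_xy, max_len):
    """Encode polylines as a fixed-length list of [x0, y0, x1, y1, type] tokens.

    polylines: list of (points, type_id); points are (x, y) world coordinates.
    av_xy: AV position at prediction time; all coordinates become relative to it.
    """
    ax, ay = av_xy
    tokens = []
    for points, type_id in polylines:
        # Each pair of consecutive points becomes one line-segment token.
        for (x0, y0), (x1, y1) in zip(points, points[1:]):
            tokens.append([x0 - ax, y0 - ay, x1 - ax, y1 - ay, float(type_id)])
    tokens = tokens[:max_len]  # truncate sequences that are too long
    # Pad short sequences; no mask is built, matching the description above.
    tokens += [list(PAD_TOKEN) for _ in range(max_len - len(tokens))]
    return tokens


# Hypothetical type ids: 0 = lane, 1 = past object path.
scene = encode_scene(
    [([(10.0, 5.0), (12.0, 5.0)], 0), ([(11.0, 4.0), (11.5, 4.5), (12.0, 5.0)], 1)],
    av_xy=(10.0, 5.0),
    max_len=5,
)
for token in scene:
    print(token)
```

Note that the timestamp of each past-path segment is not stored anywhere in the token; only the fixed sampling step implies it.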

Changes to the LLaMA model:

  • Positional encoding is removed, so the sequence is treated as a bag of tokens.
  • The attention layer's is_causal parameter is set to False so that every token can attend to the entire sequence.
  • The model aggregates information into the first token and uses a linear layer to project its output to the desired size.

Sample result visualization: ground truth (magenta), prediction (green).
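The changes to the LLaMA model listed above can be illustrated with a single attention head that has no learned weights (Q = K = V = the input tokens). This simplification is an assumption made for brevity: the real model keeps LLaMA's learned projections, multiple heads, and MLP blocks. The sketch shows that, with no positional encoding and no causal mask, the output read from the first token does not depend on the order of the remaining tokens. `predict_from_first_token`, `head_w`, and `head_b` are hypothetical names.

```python
import math


def self_attention(x):
    """Single-head, non-causal scaled dot-product attention.

    For brevity Q = K = V = x (no learned projections, an assumption).
    """
    d = len(x[0])
    out = []
    for q in x:
        # No causal mask: every token attends to every token in the sequence.
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in x]
        m = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j] for w, v in zip(weights, x)) / total for j in range(d)])
    return out


def predict_from_first_token(x, head_w, head_b):
    """Take the first token after attention and apply a linear output layer."""
    h = self_attention(x)[0]
    return [sum(w * hi for w, hi in zip(row, h)) + b for row, b in zip(head_w, head_b)]


tokens = [[0.0, 0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 2.0, 0.0, 0.0], [0.0, 1.0, 0.0, 2.0, 0.0]]
shuffled = [tokens[0], tokens[2], tokens[1]]  # same bag of tokens, different order
head_w = [[1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 0.0]]
head_b = [0.0, 0.0]
print(predict_from_first_token(tokens, head_w, head_b))
print(predict_from_first_token(shuffled, head_w, head_b))
```

The two printed outputs agree up to floating-point rounding, which is what "bag of tokens" means in practice: only the first token's position carries meaning.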

Loss:

Our model generates multiple candidate paths, whereas the ground truth contains only one. At a junction, for example, we expect the model to produce an option for every possible turn. We therefore use the loss of the predicted path with the smallest L2 loss, which encourages the model to explore different possibilities. Because the gradient of the min function is discontinuous (it jumps whenever a different path becomes the closest), we use the softmin of the per-path losses as weighting factors for smoother optimization.
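A minimal sketch of a softmin-weighted multi-path loss. It assumes the per-path L2 loss is the mean Euclidean distance between corresponding points and introduces a hypothetical `temperature` parameter; the repository may instead use squared error, a different temperature, or treat the weights differently during backpropagation. `path_l2` and `softmin_multipath_loss` are illustrative names.

```python
import math


def path_l2(pred, gt):
    """Mean Euclidean distance between corresponding points of two paths."""
    return sum(math.dist(p, g) for p, g in zip(pred, gt)) / len(gt)


def softmin_multipath_loss(preds, gt, temperature=1.0):
    """Weight each candidate path's loss by softmin over all candidates' losses."""
    losses = [path_l2(p, gt) for p in preds]
    m = min(losses)  # shift by the minimum for numerical stability
    exps = [math.exp(-(l - m) / temperature) for l in losses]
    z = sum(exps)
    weights = [e / z for e in exps]  # softmin: lower loss gets a larger weight
    return sum(w * l for w, l in zip(weights, losses)), weights


gt = [(1.0, 0.0), (2.0, 0.0)]  # ground truth: going straight
preds = [
    [(1.0, 0.0), (2.0, 0.0)],  # candidate 1: straight (matches the ground truth)
    [(1.0, 0.0), (1.0, 1.0)],  # candidate 2: turning left
]
loss, weights = softmin_multipath_loss(preds, gt)
print(loss, weights)
```

As the temperature approaches zero the weighted loss approaches the hard minimum; larger temperatures spread gradient over more candidates.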

Data loader:

During data loading, a random rotation is applied to all vectors. This significantly reduces overfitting and helps the model learn rotation invariance ⚠️.
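A minimal sketch of this augmentation, assuming the rotation is about the origin (the AV position, since coordinates are already relative to it), that the angle is drawn uniformly from [0, 2π), and that the ground-truth path is rotated by the same angle. `rotate_point` and `augment` are hypothetical helpers. Because the padding token has zero coordinates, it is left unchanged by the rotation.

```python
import math
import random


def rotate_point(x, y, angle):
    """Rotate (x, y) counter-clockwise by `angle` radians about the origin."""
    c, s = math.cos(angle), math.sin(angle)
    return c * x - s * y, s * x + c * y


def augment(tokens, target, rng=random):
    """Apply one random rotation to all input segments and the target path."""
    angle = rng.uniform(0.0, 2.0 * math.pi)
    rotated_tokens = []
    for x0, y0, x1, y1, seg_type in tokens:
        rx0, ry0 = rotate_point(x0, y0, angle)
        rx1, ry1 = rotate_point(x1, y1, angle)
        rotated_tokens.append([rx0, ry0, rx1, ry1, seg_type])  # type is unchanged
    rotated_target = [rotate_point(x, y, angle) for x, y in target]
    return rotated_tokens, rotated_target


tokens = [[1.0, 0.0, 2.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, -1.0]]  # one segment + padding
target = [(3.0, 0.0), (4.0, 0.0)]
print(augment(tokens, target, rng=random.Random(0)))
```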

Use on your own data 🔄

When transferring the model to your own data, we found that rotating the coordinates by 45 degrees improves performance ⚠️. The model is not fully rotation invariant, even though random rotation is applied during training. We attribute this to the loss function: when motion is nearly aligned with one of the axes, one dimension dominates the loss. Rotating the coordinates by 45 degrees compensates for this imbalance, since the x and y coordinates then contribute roughly equally to the loss.
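One way to apply this at inference time, sketched under assumptions: rotate the input tokens by 45 degrees about the AV position, run the model, then rotate the predicted paths back by -45 degrees so they are expressed in the original frame. `model` is a hypothetical callable that maps a token list to a list of predicted paths, and `dummy_model` is a stand-in for testing; the repository's actual inference entry point may differ.

```python
import math


def rotate_point(x, y, angle):
    """Rotate (x, y) counter-clockwise by `angle` radians about the origin."""
    c, s = math.cos(angle), math.sin(angle)
    return c * x - s * y, s * x + c * y


def predict_rotated(model, tokens, angle=math.pi / 4):
    """Rotate inputs by `angle`, run `model`, and rotate its paths back by -angle."""
    rotated = []
    for x0, y0, x1, y1, seg_type in tokens:
        rx0, ry0 = rotate_point(x0, y0, angle)
        rx1, ry1 = rotate_point(x1, y1, angle)
        rotated.append([rx0, ry0, rx1, ry1, seg_type])
    paths = model(rotated)  # hypothetical: returns a list of paths of (x, y) points
    return [[rotate_point(x, y, -angle) for x, y in path] for path in paths]


def dummy_model(tokens):
    """Stand-in for the real network: predicts the end point of the first segment."""
    return [[(tokens[0][2], tokens[0][3])]]


print(predict_rotated(dummy_model, [[0.0, 0.0, 2.0, 1.0, 0.0]]))
```

Because `dummy_model` simply echoes its (rotated) input, the round trip returns the original end point, which checks that the inverse rotation is applied correctly.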
