
uber-research / lanegcn


[ECCV2020 Oral] Learning Lane Graph Representations for Motion Forecasting

Home Page: https://arxiv.org/abs/2007.13732

License: Other

Python 98.34% Shell 1.66%
self-driving motion-estimation graph-neural-networks artificial-intelligence motion-forecasting

lanegcn's Introduction

Caution

The Amazon AWS S3 bucket argoai-argoverse, which holds many of the demo files, has been compromised. The files may be corrupted.
Files referencing this S3 bucket have been modified, and any retrievals from this bucket are commented out. Please proceed with caution.

LaneGCN: Learning Lane Graph Representations for Motion Forecasting

Paper | Slides | Project Page | ECCV 2020 Oral Video

Ming Liang, Bin Yang, Rui Hu, Yun Chen, Renjie Liao, Song Feng, Raquel Urtasun

Ranked 1st in the Argoverse Motion Forecasting Competition


Install Dependencies

You need to install the following packages to run the code:

  1. The following example creates an environment from scratch with Anaconda (you can use pip as well):
conda create --name lanegcn python=3.7
conda activate lanegcn
conda install pytorch==1.5.1 torchvision cudatoolkit=10.2 -c pytorch # the code was released with pytorch 1.5.1

# install the Argoverse API
pip install git+https://github.com/argoai/argoverse-api.git

# install other dependencies
pip install scikit-image IPython tqdm ipdb
  2. [Optional but Recommended] Install Horovod and mpi4py for distributed training. Horovod is more efficient than nn.DataParallel for multi-GPU training and easier to use than nn.DistributedDataParallel. Before installing Horovod, make sure Open MPI is installed (sudo apt-get install -y openmpi-bin).
pip install mpi4py

# install horovod with GPU support (this may take a while)
HOROVOD_GPU_OPERATIONS=NCCL pip install horovod==0.19.4

# if you have only a single GPU, install horovod anyway for code compatibility
pip install horovod

If you have any issues with Horovod, please refer to the Horovod GitHub repository.

Prepare Data: Argoverse Motion Forecasting

You can check the scripts and download the preprocessed data instead of spending hours generating it yourself.

bash get_data.sh

Training

[Recommended] Training with Horovod-multigpus

# single node with 4 gpus
horovodrun -np 4 -H localhost:4 python /path/to/train.py -m lanegcn

# 2 nodes, each with 4 gpus
horovodrun -np 8 -H serverA:4,serverB:4 python /path/to/train.py -m lanegcn

Training the model takes 8 hours on 4 GPUs (RTX 5000) with Horovod.

We also provide a training log to help you debug.

[Recommended] Training/debugging with Horovod on a single GPU

python train.py -m lanegcn

Testing

You can download the pretrained model from here

Inference on the test set (for submission)

python test.py -m lanegcn --weight=/absolute/path/to/36.000.ckpt --split=test

Inference on the validation set (for metrics)

python test.py -m lanegcn --weight=36.000.ckpt --split=val

Qualitative results

Labels (red), predictions (green), other agents (blue)


Quantitative results

License

See LICENSE.

Citation

If you use our source code, please consider citing the following:

@InProceedings{liang2020learning,
  title={Learning lane graph representations for motion forecasting},
  author={Liang, Ming and Yang, Bin and Hu, Rui and Chen, Yun and Liao, Renjie and Feng, Song and Urtasun, Raquel},
  booktitle = {ECCV},
  year={2020}
}

If you have any questions about the code, please open an issue and mention @chenyuntc.

lanegcn's People

Contributors: chenyuntc, jonathanbaker7, wqi


lanegcn's Issues

confused about the 'theta' param in actor preprocessing

Hi @chenyuntc, thanks for your source code.
I've recently been confused about the theta param in data:get_obj_feats.
I usually get theta from arctan2((pos[19].y - pos[18].y), (pos[19].x - pos[18].x)).
What does your calculation mean here?

theta = np.pi - np.arctan2(pos[18].y-pos[19].y, pos[18].x - pos[19].x)
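A quick numerical check suggests the two formulas differ only in sign. With dy = y19 - y18 and dx = x19 - x18, arctan2(-dy, -dx) is arctan2(dy, dx) shifted by ±π, so π - arctan2(-dy, -dx) equals -arctan2(dy, dx) modulo 2π. The sketch below verifies this. It also assumes, without having checked data.py, that θ is then fed into a standard counter-clockwise rotation matrix, and under that assumption it shows that the rotation maps the agent's last displacement onto the +x axis. The helper names are hypothetical.

```python
import numpy as np

def heading(p18, p19):
    """Conventional heading of the last observed step: atan2(dy, dx)."""
    return np.arctan2(p19[1] - p18[1], p19[0] - p18[0])

def theta_from_issue(p18, p19):
    """The expression quoted above: pi minus atan2 of the reversed step."""
    return np.pi - np.arctan2(p18[1] - p19[1], p18[0] - p19[0])

def rotation(theta):
    """Standard counter-clockwise 2D rotation matrix (assumed usage of theta)."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])

def same_angle(a, b):
    """True if angles a and b are equal modulo 2*pi."""
    d = np.mod(a - b, 2 * np.pi)
    return bool(np.isclose(d, 0.0) or np.isclose(d, 2 * np.pi))

p18, p19 = np.array([1.0, 2.0]), np.array([2.0, 3.5])
theta = theta_from_issue(p18, p19)
print(same_angle(theta, -heading(p18, p19)))  # theta is the negated heading
d = rotation(theta) @ (p19 - p18)             # last step in the rotated frame
print(d / np.linalg.norm(d))                  # approximately [1, 0]
```

In other words, the quoted expression appears to compute the angle that undoes the heading, so the agent ends up facing +x. That is a different quantity from the heading itself.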

BR~

Training details

Hi,
According to Section 4.1 of the paper (implementation details), you use a batch size of 128 and train for 36 epochs with a learning rate of 0.001, decayed to 0.0001 at epoch 32.

According to the provided code, the batch size is 32:

config["batch_size"] = 32

Does it give the same performance?
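For reference, here is how the numbers in this question could fit together. This assumes config["batch_size"] is the per-process batch size; a later issue on this page describes its batch of 128 as 4*32 on 4 GPUs. Each data-parallel process draws its own mini-batch, so one optimizer step sees 4 × 32 = 128 samples. The schedule quoted from the paper is a piecewise-constant step decay. Both helpers below are hypothetical illustrations, not functions from the repository.

```python
def effective_batch_size(per_process_batch: int, num_processes: int) -> int:
    """Samples consumed per optimizer step in synchronous data-parallel training."""
    # each process draws its own mini-batch; gradients are averaged across processes
    return per_process_batch * num_processes


def step_lr(epoch: float, lrs=(1e-3, 1e-4), milestones=(32,)) -> float:
    """Piecewise-constant schedule: lrs[i] applies until milestones[i] is reached."""
    idx = sum(1 for m in milestones if epoch >= m)
    return lrs[idx]


print(effective_batch_size(32, 4))  # 128
print(step_lr(10), step_lr(33))     # 0.001 0.0001
```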

One more question about the loss function: can you give more insight into the classification loss? Why do you need it, and have you tried training without it?

Thanks a lot for the great work.

Can you tell me the reason for the randomness?

Hello, thanks for your nice project.

I trained the model several times without editing the code and found that performance differed quite a lot between runs.
Testing a single checkpoint multiple times gives nearly identical results, so the randomness seems to come from the training process.
Do you have any idea what the reason could be?

Thank you!
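When repeated evaluations of one checkpoint agree but training runs differ, the usual suspects are unseeded random number generators (weight initialization, data shuffling, augmentation) and non-deterministic GPU kernels. The sketch below seeds the Python and NumPy generators, and PyTorch's too if it is installed. The function name seed_everything is hypothetical. Even with fixed seeds, PyTorch's reproducibility notes warn that some CUDA operations, such as the atomic-add based index_add_, are not deterministic, so some run-to-run variation may remain.

```python
import os
import random

import numpy as np


def seed_everything(seed: int) -> None:
    """Seed the common RNGs; PyTorch is optional so the sketch runs without it."""
    random.seed(seed)
    np.random.seed(seed)
    # only affects Python subprocesses started after this point
    os.environ["PYTHONHASHSEED"] = str(seed)
    try:
        import torch
    except ImportError:
        return
    torch.manual_seed(seed)           # seeds the CPU and CUDA generators
    torch.cuda.manual_seed_all(seed)  # explicit; silently ignored without CUDA
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
```

Call it once at the start of each process, before the model and data loaders are built. With multi-process data loading, each worker's own RNG state also matters.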

"left" and "right" not in graph

for k1 in ["left", "right"]:
    graph[k1] = dict()
    for k2 in ["u", "v"]:
        temp = [graphs[i][k1][k2] + counts[i] for i in range(batch_size)]
        temp = [
            x if x.dim() > 0 else graph["pre"][0]["u"].new().resize_(0)
            for x in temp
        ]
        graph[k1][k2] = torch.cat(temp)

KeyError: 'left'

How to visualize the results

Your work has been a great help to me as a beginner, but I don't understand how to visualize the results. Where is the imported cv2 module used? I hope to hear from you. Thank you again for your work.
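One way to plot results without cv2 is to use matplotlib with the color convention from the README's qualitative results: labels in red, predictions in green, other agents in blue. This is a hypothetical helper, not code from the repository. It assumes you have already extracted the following as NumPy arrays in the same coordinate frame: the ground-truth future (T × 2), the K predicted modes (K × T × 2), and the other agents' trajectories.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: render to files, no display needed
import matplotlib.pyplot as plt
import numpy as np


def plot_scene(gt, preds, others, ax=None):
    """gt: (T, 2) ground truth; preds: (K, T, 2) predicted modes;
    others: iterable of (T_i, 2) trajectories of surrounding agents."""
    if ax is None:
        _, ax = plt.subplots(figsize=(6, 6))
    for traj in others:
        ax.plot(traj[:, 0], traj[:, 1], color="blue", linewidth=1)
    for mode in preds:
        ax.plot(mode[:, 0], mode[:, 1], color="green", linewidth=1)
    ax.plot(gt[:, 0], gt[:, 1], color="red", linewidth=2)  # drawn last, on top
    ax.set_aspect("equal")
    return ax
```

For example, plot_scene(gt, preds, others).figure.savefig("scene.png") writes one scene to disk. Lane centerlines from the map could be added as extra gray lines in the same way.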

multi agent trajectory prediction?

Thank you for sharing your work and neat code!

In your paper (Section 3.5, Learning),
the loss is used to train multiple agents (M is the number of agents in the scene).

But the pretrained checkpoint (36.000.ckpt) seems to use only the first agent's predictions when computing scores such as minADE.

Do you also have a checkpoint trained for multiple agents?
If so, could you share it?
Training is too heavy for my server ;(
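For context on the metric mentioned here: with K predicted modes per agent, minADE and minFDE score only the best mode for each agent. The hypothetical sketch below follows one common convention. It selects the mode with the smallest final displacement and reports that mode's average error; some codebases instead take the minimum ADE independently. Scoring several agents per scene only requires stacking them along the first axis.

```python
import numpy as np


def min_ade_fde(preds, gt):
    """preds: (N, K, T, 2) K predicted modes per agent; gt: (N, T, 2).
    Returns (minADE, minFDE) averaged over the N agents, where each agent's
    mode is the one with the smallest final displacement error."""
    err = np.linalg.norm(preds - gt[:, None], axis=-1)  # (N, K, T) per-step L2 error
    best = err[:, :, -1].argmin(axis=1)                  # best mode by final error
    best_err = err[np.arange(len(best)), best]           # (N, T)
    return best_err.mean(), best_err[:, -1].mean()
```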

Training is much slower than described in the paper

Hi, I recently tried to reproduce your results. I can get the metrics described in the paper, but training takes much longer (almost 3 days) than the paper describes (less than 12 hours).

Environment:

  • 4 × Titan X (same as the paper)
  • batch size 128 (4 × 32, same as the paper)
  • switched the distributed framework from Horovod to PyTorch DDP, since Horovod is really hard to set up (even with the official Horovod Docker image I still got errors I couldn't resolve)

Did I do something wrong? I'm sure I use DDP correctly, and I'm also sure the training bottleneck is the optimization itself (not I/O or anything else). Has anyone else run into the same problem?
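One way to check whether optimization rather than I/O is the bottleneck is to time the wait for each batch separately from the training step. The helper below is a hypothetical sketch. On a GPU, CUDA kernels run asynchronously, so step_fn should end with torch.cuda.synchronize(); otherwise part of the compute time is attributed to the next batch fetch.

```python
import time


def profile_loop(loader, step_fn, max_iters=50):
    """Return (seconds waiting for batches, seconds inside step_fn) over up to max_iters steps."""
    data_time = compute_time = 0.0
    it = iter(loader)
    for _ in range(max_iters):
        t0 = time.perf_counter()
        try:
            batch = next(it)  # time spent here is data loading / worker starvation
        except StopIteration:
            break
        t1 = time.perf_counter()
        step_fn(batch)        # forward, backward and optimizer step
        t2 = time.perf_counter()
        data_time += t1 - t0
        compute_time += t2 - t1
    return data_time, compute_time
```

If data_time is a large share of the total, more loader workers or faster storage will help more than changes to the distributed setup.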

How are graph["left_pairs"] and graph["right_pairs"] defined in the lane graph? Is there a bug in these keys?

I printed the values of graph["left_pairs"] and graph["right_pairs"] and found that graph["left_pairs"] contains some duplicate node pairs, for example [37, 64] and [64, 37]. It is hard to understand how nodes 37 and 64 can each be the left neighbor of the other. There are no such repeated pairs in graph["right_pairs"]. Why?
Shouldn't graph["left_pairs"] equal the reverse of graph["right_pairs"]? For instance, graph["left_pairs"]: [[1,2],[3,4]] and
graph["right_pairs"]: [[2,1],[4,3]].

The printed value of graph["left_pairs"] and graph["right_pairs"] are as follows:
len graph[left_pairs] 67 len graph[right_pairs] 27 graph[left_pairs] tensor([[15, 91], [19, 86], [24, 97], [29, 30], [30, 29], [31, 43], [32, 68], [33, 60], [34, 39], [35, 51], [36, 66], [37, 64], [38, 69], [39, 34], [40, 67], [41, 95], [42, 51], [43, 31], [44, 29], [45, 78], [46, 63], [47, 69], [48, 58], [49, 67], [50, 58], [51, 35], [52, 15], [53, 47], [54, 37], [55, 56], [56, 55], [57, 96], [58, 48], [59, 75], [61, 48], [62, 82], [63, 72], [64, 37], [65, 49], [66, 98], [67, 49], [68, 32], [69, 47], [70, 64], [71, 30], [72, 63], [73, 74], [74, 82], [75, 59], [76, 77], [77, 76], [79, 85], [81, 80], [82, 74], [83, 85], [84, 91], [85, 83], [88, 83], [89, 35], [90, 72], [91, 15], [93, 92], [94, 98], [95, 41], [96, 57], [97, 24], [98, 66]], device='cuda:0') graph[right_pairs] tensor([[15, 52], [29, 44], [30, 71], [35, 89], [37, 54], [47, 53], [48, 61], [49, 65], [51, 42], [58, 50], [60, 33], [63, 46], [64, 70], [66, 36], [67, 40], [69, 38], [72, 90], [74, 73], [78, 45], [80, 81], [82, 62], [83, 88], [85, 79], [86, 19], [91, 84], [92, 93], [98, 94]], device='cuda:0')
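The observation can be checked systematically with the two hypothetical helpers below. They are pure Python, so call .tolist() on the tensors first. The first lists mutually reversed pairs within one list. The second compares left_pairs against the reversed right_pairs. In the printed output above, for example, [52, 15] appears in left_pairs and [15, 52] in right_pairs, while [37, 64] and [64, 37] both appear in left_pairs. The check only measures the asymmetry; it does not settle whether mutual left pairs are a bug or a side effect of how neighbors are extracted from the map.

```python
def mutual_pairs(pairs):
    """Return the pairs (i, j) with i < j such that both (i, j) and (j, i) occur."""
    s = {tuple(p) for p in pairs}
    return sorted((i, j) for i, j in s if i < j and (j, i) in s)


def left_right_mismatch(left_pairs, right_pairs):
    """Compare left pairs with reversed right pairs.
    Returns (pairs only in left, reversed pairs only in right)."""
    left = {tuple(p) for p in left_pairs}
    right_rev = {(j, i) for i, j in right_pairs}
    return sorted(left - right_rev), sorted(right_rev - left)
```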

Running preprocess_data.py raises EOFError

(lanegcn) zzj@zzj-OMEN-25L-Desktop-GT12-1xxx:/mnt/data/pycharmmm/LaneGCN-master$ python preprocess_data.py -m lanegcn

71%|████████████████████████████████████████████████▋ | 1030/1459 [06:31<02:42, 2.63it/s]

Traceback (most recent call last):
  File "preprocess_data.py", line 415, in <module>
    main()
  File "preprocess_data.py", line 56, in main
    val(config)
  File "preprocess_data.py", line 130, in val
    for i, data in enumerate(tqdm(val_loader)):
  File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/site-packages/tqdm/std.py", line 1185, in __iter__
    for obj in iterable:
  File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 838, in _next_data
    return self._process_data(data)
  File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 881, in _process_data
    data.reraise()
  File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/site-packages/torch/_utils.py", line 395, in reraise
    raise self.exc_type(msg)
IndexError: Caught IndexError in DataLoader worker process 6.
Original Traceback (most recent call last):
  File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/mnt/data/pycharmmm/LaneGCN-master/data.py", line 85, in __getitem__
    data = self.get_obj_feats(data)
  File "/mnt/data/pycharmmm/LaneGCN-master/data.py", line 149, in get_obj_feats
    orig = data['trajs'][0][19].copy().astype(np.float32)
IndexError: index 19 is out of bounds for axis 0 with size 16

Exception in thread Thread-2:
Traceback (most recent call last):
  File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/threading.py", line 926, in _bootstrap_inner
    self.run()
  File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/site-packages/torch/utils/data/_utils/pin_memory.py", line 25, in _pin_memory_loop
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
  File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/multiprocessing/queues.py", line 113, in get
    return _ForkingPickler.loads(res)
  File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/site-packages/torch/multiprocessing/reductions.py", line 294, in rebuild_storage_fd
    fd = df.detach()
  File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/multiprocessing/resource_sharer.py", line 87, in get_connection
    c = Client(address, authkey=process.current_process().authkey)
  File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/multiprocessing/connection.py", line 498, in Client
    answer_challenge(c, authkey)
  File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/multiprocessing/connection.py", line 747, in answer_challenge
    response = connection.recv_bytes(256) # reject large message
  File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
  File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/multiprocessing/connection.py", line 383, in _recv
    raise EOFError
EOFError

What can I do?
I use one RTX 2080S on Ubuntu 20.04, with 16 GB of memory and PyCharm's max memory size set to 4 GB.
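The traceback shows that data['trajs'][0], the agent's trajectory, had only 16 rows, while get_obj_feats indexes row 19 (the 20th observed step). One plausible cause, not confirmed in this thread, is a truncated or corrupted raw CSV; see also the caution about the argoai-argoverse S3 bucket at the top of this page. The hypothetical helper below uses only the standard library to scan a split directory and list sequences whose AGENT track has fewer than 20 timestamps. It assumes the Argoverse 1 forecasting CSV columns TIMESTAMP and OBJECT_TYPE.

```python
import csv
from pathlib import Path


def short_agent_sequences(data_dir, min_steps=20):
    """Yield (path, n_steps) for CSVs whose AGENT track has fewer than min_steps timestamps."""
    for path in sorted(Path(data_dir).glob("*.csv")):
        with open(path, newline="") as f:
            stamps = {
                row["TIMESTAMP"]
                for row in csv.DictReader(f)
                if row["OBJECT_TYPE"] == "AGENT"
            }
        if len(stamps) < min_steps:
            yield path, len(stamps)
```

Any file this reports can be re-downloaded or re-extracted before running preprocess_data.py again.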

Getting data: 403 Forbidden

bash get_data.sh

--2023-07-07 02:37:29-- https://s3.amazonaws.com/argoai-argoverse/hd_maps.tar.gz
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.229.120, 52.216.39.56, 52.217.207.96, ...
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.229.120|:443... connected.
HTTP request sent, awaiting response... 403 Forbidden
2023-07-07 02:37:30 ERROR 403: Forbidden.

How can I get the results in Table 1 of the paper?

I ran this command: python test.py -m lanegcn --weight=/absolute/path/to/36.000.ckpt --split=test, but nothing was printed except this:
2442it [15:14, 2.67it/s]
78143/78143--Return--
None

> /laneGCN/test.py(114)main()
    113     generate_forecasting_h5(preds, f"{config['save_dir']}/submit.h5") # this might take awhile
--> 114     import ipdb;ipdb.set_trace()
    115

How can I get results on the test set?

I can get results on the validation set by running: python test.py -m=lanegcn --weight=[ckpt PATH] --split=val --preprocess=True

but when I ran this command

python test.py -m=lanegcn --weight=[ckpt PATH] --split=val --preprocess=True

nothing was printed. How can I get results on the test set?

Does it really generate graph['node_idcs']?

Hi, thanks for releasing the code!

I have one question. When I generate the graph with preprocess_data.py, I don't see graph['node_idcs'] being saved in the graph dict, but when training with lanegcn.py, I found the following code in the forward function of the MapNet module:
def forward(self, graph):
    if (
        len(graph["feats"]) == 0
        or len(graph["pre"][-1]["u"]) == 0
        or len(graph["suc"][-1]["u"]) == 0
    ):
        temp = graph["feats"]
        return (
            temp.new().resize_(0),
            [temp.new().long().resize_(0) for x in graph["node_idcs"]],
            temp.new().resize_(0),
        )
Although there is no error when I run the original lanegcn.py, is that because the if condition is never true, or is graph['node_idcs'] really generated somewhere I didn't find? If this branch is executed, the output node features will be empty.
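Python evaluates the body of an if statement only when the condition is true. A missing key referenced inside the branch therefore raises nothing unless that branch actually runs. Note that graph_gather, quoted in a later issue on this page, stores the batched node indices under graph["idcs"], not graph["node_idcs"]. The toy function below is hypothetical and mirrors only the structure of the guard; it shows both cases.

```python
def forward(graph):
    """Toy stand-in for the quoted guard: the early-return branch reads a key
    ("node_idcs") that the input dict does not necessarily contain."""
    if len(graph["feats"]) == 0:
        return [x for x in graph["node_idcs"]]  # evaluated only for empty feats
    return graph["feats"]


print(forward({"feats": [1, 2]}))  # [1, 2]: no KeyError, branch never ran
```

Calling forward({"feats": []}) raises KeyError: 'node_idcs'. So the guard in MapNet would fail the same way if it were ever triggered on a batch built by graph_gather.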

inference time

Sorry to post a random question here. It appears every object needs to go through the network once to predict its trajectory. I am wondering about the online inference time of your network (perhaps after exporting to ONNX). Say you have 20 objects on the map: can that be done in real time (say, within a couple of milliseconds)?

Evaluation results mismatch raw source file

I recently wanted to visualize your prediction results, but I got a little confused about the correspondence between the index of the preprocessed val data and $ARGO_RAW_DATA/val/data/argo_id.csv.
In ArgoTestDataset:

LaneGCN/data.py

Line 382 in 7e9b51d

data['argo_id'] = int(self.avl.seq_list[idx].name[:-4]) #160547

I downloaded your preprocessed val data and found that the predicted paths, matched by argo_id, don't match the source file $ARGO_RAW_DATA/val/data/argo_id.csv. My source data was downloaded from the Argoverse website (version 1.1). Does the order of the preprocessed data match argo_id, or did I do something wrong?
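The quoted line derives argo_id from the file name: name[:-4] strips the ".csv" suffix, so 160547.csv becomes 160547. If the loader's sequence list is not sorted, position idx can map to different files on different machines or file systems. That is one possible source of such a mismatch, though it has not been verified against the Argoverse API. A more robust pattern is to key predictions by argo_id and open the raw file by that id, as in this hypothetical sketch.

```python
from pathlib import Path


def argo_id_from_path(path) -> int:
    """'.../val/data/160547.csv' -> 160547 (same as int(name[:-4]) for .csv files)."""
    return int(Path(path).stem)


def raw_csv_for(argo_id: int, data_dir) -> Path:
    """Locate the raw sequence file for a prediction keyed by argo_id."""
    return Path(data_dir) / f"{argo_id}.csv"


def key_by_id(argo_ids, preds):
    """Pair predictions with their ids instead of relying on dataset order."""
    return dict(zip(argo_ids, preds))
```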

Cannot download the data

When I run python preprocess_data.py -m lanegcn, I encounter the following error:

Traceback (most recent call last):
  File "preprocess_data.py", line 21, in <module>
    from data import ArgoDataset as Dataset, from_numpy, ref_copy, collate_fn
  File "/home/ht1/LaneGCN/data.py", line 12, in <module>
    from argoverse.map_representation.map_api import ArgoverseMap
  File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/argoverse/map_representation/map_api.py", line 21, in <module>
    from argoverse.utils.cv2_plotting_utils import get_img_contours
  File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/argoverse/utils/cv2_plotting_utils.py", line 9, in <module>
    from .calibration import CameraConfig, proj_cam_to_uv
  File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/argoverse/utils/calibration.py", line 14, in <module>
    from argoverse.utils.camera_stats import (
  File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/argoverse/utils/camera_stats.py", line 8, in <module>
    from argoverse.sensor_dataset_config import ArgoverseConfig
  File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/argoverse/sensor_dataset_config.py", line 55, in <module>
    cfg = hydra.compose(config_name=f"{DATASET_NAME}.yaml")
  File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/hydra/compose.py", line 33, in compose
    with_log_configuration=False,
  File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/hydra/_internal/hydra.py", line 550, in compose_config
    from_shell=from_shell,
  File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/hydra/_internal/config_loader_impl.py", line 150, in load_configuration
    from_shell=from_shell,
  File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/hydra/_internal/config_loader_impl.py", line 244, in _load_configuration_impl
    skip_missing=run_mode == RunMode.MULTIRUN,
  File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/hydra/_internal/defaults_list.py", line 724, in create_defaults_list
    skip_missing=skip_missing,
  File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/hydra/_internal/defaults_list.py", line 695, in _create_defaults_list
    skip_missing=skip_missing,
  File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/hydra/_internal/defaults_list.py", line 343, in _create_defaults_tree
    overrides=overrides,
  File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/hydra/_internal/defaults_list.py", line 420, in _create_defaults_tree_impl
    return _expand_virtual_root(repo, root, overrides, skip_missing)
  File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/hydra/_internal/defaults_list.py", line 268, in _expand_virtual_root
    overrides=overrides,
  File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/hydra/_internal/defaults_list.py", line 427, in _create_defaults_tree_impl
    config_not_found_error(repo=repo, tree=root)
  File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/hydra/_internal/defaults_list.py", line 776, in config_not_found_error
    options=options,
hydra.errors.MissingConfigException: Cannot find primary config 'argoverse-v1.1.yaml'. Check that it's in your config search path.

Config search path:
provider=hydra, path=pkg://hydra.conf
provider=main, path=pkg://argoverse.config
provider=schema, path=structured://

Preprocessing data is very slow

Dear authors,
thank you for sharing your code.
I downloaded the dataset from Argoverse and then wanted to preprocess it. When I run 'python preprocess_data.py -m lanegcn', it runs for 5 hours with no output; CPU usage is high but GPU usage is low.
Thank you very much!

left and right in the lane graph

Hi,
How do you find the "left" and "right" connections using 'pre', 'suc', 'pre_pairs', 'suc_pairs', 'left_pairs', 'right_pairs'?

Cannot download the pretrained model

Hi @chenyuntc,

I want to download a pretrained model.
I clicked the "here" link on your GitHub page (shown in the picture below), but nothing happened.
Could you check whether there is an error? Or is there another way to download the pretrained model?

Thank you,
download_pretrained_model

Learning Rate Drop

Hi! Thank you for your cool project!

I have a question which might be stupid. I noticed that the repository uses neither warm-up nor per-iteration learning-rate decay. Since these techniques are common in other deep learning applications, I am wondering whether the current schedule works better for LaneGCN, or whether LaneGCN's structure is already good enough without such tricks.

Thanks!
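For comparison, the step schedule discussed in an earlier issue (1e-3 dropped to 1e-4 at epoch 32) keeps the rate constant between milestones, whereas the tricks mentioned here update it on every iteration. The hypothetical function below sketches linear warm-up followed by per-iteration cosine decay. It is a swapped-in technique for illustration, not something the repository uses, and its default values are arbitrary.

```python
import math


def warmup_cosine_lr(it, total_iters, base_lr=1e-3, warmup_iters=1000, min_lr=1e-5):
    """Linear warm-up to base_lr, then cosine decay to min_lr at total_iters."""
    if it < warmup_iters:
        return base_lr * (it + 1) / warmup_iters
    progress = (it - warmup_iters) / max(1, total_iters - warmup_iters)
    progress = min(progress, 1.0)  # stay at min_lr after total_iters
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))
```

Whether this helps LaneGCN would have to be measured. A schedule like this mainly matters for large batches or unstable early training.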

Memory issue for multi-gpu training

Hi,

Thanks for the great work.

I am trying to train with Horovod on 4 GPUs (RTX 2080 Ti) with 80 GB of CPU memory. However, after some time, before the first training epoch starts, I get the following error:

mpirun noticed that process rank 0 with PID 0 on node dagobert exited on signal 9 (Killed).

According to the Horovod GitHub, this seems to be an out-of-memory issue. I would therefore like to know the system requirements for training on 4 GPUs: GPU memory, CPU memory, number of CPUs, etc. Do you have any advice for multi-GPU training?

IndexError: list index out of range

python test.py -m lanegcn --weight=/home/jovyan/LaneGCN/36.000.ckpt --split=test

0it [00:00, ?it/s]
Traceback (most recent call last):
  File "/home/jovyan/LaneGCN/test.py", line 118, in <module>
    main()
  File "/home/jovyan/LaneGCN/test.py", line 82, in main
    for ii, data in tqdm(enumerate(data_loader)):
  File "/home/venv/lib/python3.9/site-packages/tqdm/std.py", line 1178, in __iter__
    for obj in iterable:
  File "/home/venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 634, in __next__
    data = self._next_data()
  File "/home/venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 678, in _next_data
    data = self._dataset_fetcher.fetch(index) # may raise StopIteration
  File "/home/venv/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/venv/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/jovyan/LaneGCN/data.py", line 384, in __getitem__
    data['argo_id'] = int(self.avl.seq_list[idx].name[:-4]) #160547
IndexError: list index out of range
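The failure on the very first batch, inside self.avl.seq_list[idx], suggests that the raw file list for the test split was shorter than the dataset being iterated. That could happen, for example, if the raw test directory path is wrong or empty. This is an inference from the traceback, not a confirmed diagnosis. A hypothetical pre-flight check is sketched below. For the test split, expected=78143 is a reasonable guess, since that is the count printed during the test-set run in another issue on this page.

```python
from pathlib import Path


def check_split_dir(split_dir, expected=None):
    """Count raw Argoverse CSVs in split_dir; fail early if none (or too few) exist."""
    n = len(list(Path(split_dir).glob("*.csv")))
    if n == 0:
        raise FileNotFoundError(f"no .csv files in {split_dir}; check the raw data path")
    if expected is not None and n < expected:
        raise ValueError(f"{split_dir} has {n} .csv files, expected {expected}")
    return n
```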

Pretrained model

Hi all,

Thank you for open-sourcing the code! It has been a great help!

However, I have recently been trying to use your pretrained model for inference, and the link in the README seems to be invalid. Would you be kind enough to update the link or share the pretrained model with me? I really appreciate your help!

Best,

Ziqi

Preprocessed data download link

Dear authors,
thank you for sharing your code. I was trying to train LaneGCN from scratch and had problems generating the training data. The data generation code freezes at the beginning, and the download link seems to have connection issues. Do you have any idea what is going on?
Thank you very much!

Question: What do "u" and "v" mean in data.py?

Hello,

Thanks for the great code.
Can you explain "u" and "v" for graph['pre'] and graph['suc'] in data.py?

pre, suc = dict(), dict()
for key in ['u', 'v']:
    pre[key], suc[key] = [], []
for i, lane_id in enumerate(lane_ids):
    lane = lanes[lane_id]
    idcs = node_idcs[i]
    
    pre['u'] += idcs[1:]
    pre['v'] += idcs[:-1]
    if lane.predecessors is not None:
        for nbr_id in lane.predecessors:
            if nbr_id in lane_ids:
                j = lane_ids.index(nbr_id)
                pre['u'].append(idcs[0])
                pre['v'].append(node_idcs[j][-1])
            
    suc['u'] += idcs[:-1]
    suc['v'] += idcs[1:]
    if lane.successors is not None:
        for nbr_id in lane.successors:
            if nbr_id in lane_ids:
                j = lane_ids.index(nbr_id)
                suc['u'].append(idcs[-1])
                suc['v'].append(node_idcs[j][0])
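Reading the loop above: for pre, u takes idcs[1:] and v takes idcs[:-1], so each edge says "v is the predecessor of u". For suc the roles flip, so each edge says "v is the successor of u". Cross-lane edges follow the same rule. For pre, the first node of a lane gets the last node of its predecessor lane as v; for suc, the last node of a lane gets the first node of its successor lane as v. A common way to consume such edge lists is to scatter-add each v's features into u; that is assumed here, not taken from lanegcn.py. The NumPy sketch below uses np.add.at as a stand-in for PyTorch's index_add_, and the function names are hypothetical.

```python
import numpy as np


def lane_edges(node_idcs):
    """Within-lane pre/suc edges for lanes given as lists of consecutive node ids."""
    pre = {"u": [], "v": []}
    suc = {"u": [], "v": []}
    for idcs in node_idcs:
        pre["u"] += idcs[1:]   # v is the predecessor of u
        pre["v"] += idcs[:-1]
        suc["u"] += idcs[:-1]  # v is the successor of u
        suc["v"] += idcs[1:]
    return pre, suc


def aggregate(feats, edges):
    """Sum the features of each edge's v into its u (an unbuffered scatter-add)."""
    out = np.zeros_like(feats)
    u = np.asarray(edges["u"], dtype=int)
    v = np.asarray(edges["v"], dtype=int)
    np.add.at(out, u, feats[v])
    return out
```

For a single lane [0, 1, 2], aggregating over pre gives node 1 the features of node 0 and node 2 the features of node 1, while node 0 receives nothing.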

This also shows up in lanegcn.py:


def graph_gather(graphs):
    batch_size = len(graphs)
    node_idcs = []
    count = 0
    counts = []
    for i in range(batch_size):
        counts.append(count)
        idcs = torch.arange(count, count + graphs[i]["num_nodes"]).to(
            graphs[i]["feats"].device
        )
        node_idcs.append(idcs)
        count = count + graphs[i]["num_nodes"]

    graph = dict()
    graph["idcs"] = node_idcs
    graph["ctrs"] = [x["ctrs"] for x in graphs]

    for key in ["feats", "turn", "control", "intersect"]:
        graph[key] = torch.cat([x[key] for x in graphs], 0)

    for k1 in ["pre", "suc"]:
        graph[k1] = []
        for i in range(len(graphs[0]["pre"])):
            graph[k1].append(dict())
            for k2 in ["u", "v"]:
                graph[k1][i][k2] = torch.cat(
                    [graphs[j][k1][i][k2] + counts[j] for j in range(batch_size)], 0
                )

    for k1 in ["left", "right"]:
        graph[k1] = dict()
        for k2 in ["u", "v"]:
            temp = [graphs[i][k1][k2] + counts[i] for i in range(batch_size)]
            temp = [
                x if x.dim() > 0 else graph["pre"][0]["u"].new().resize_(0)
                for x in temp
            ]
            graph[k1][k2] = torch.cat(temp)
    return graph
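The counts list in graph_gather exists because each sample's node ids start at 0. Before samples are concatenated into one batch graph, every edge endpoint is shifted by the number of nodes in the preceding samples so that ids stay unique. Here is a plain-Python sketch of the same idea, using a hypothetical helper and lists instead of tensors:

```python
def batch_edges(graphs):
    """graphs: list of dicts with 'num_nodes' and local edge lists 'u', 'v'.
    Returns global (u, v, total_nodes), shifting sample i's ids by the total
    number of nodes in samples 0..i-1, like `counts` in graph_gather."""
    u, v, offset = [], [], 0
    for g in graphs:
        u += [x + offset for x in g["u"]]
        v += [x + offset for x in g["v"]]
        offset += g["num_nodes"]
    return u, v, offset
```

A sample that lacks one of the keys being gathered (for example "left" or "right") fails at this step with a KeyError, which matches the error reported in an earlier issue.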


train.py

When I run train.py, it shows an error. According to my analysis, it never enters the training function; the error occurs while loading the data. Train1.py has the same error. Can you give me some advice? @chenyuntc

Has anyone encountered the same issue below?

93%|██████████████████████████████████▍ | 5996/6436 [1:11:10<07:28, 1.02s/it]
Traceback (most recent call last):
  File "/home/luyu/miniconda3/envs/lanegcn-new/lib/python3.7/multiprocessing/resource_sharer.py", line 142, in _serve
    with self._listener.accept() as conn:
  File "/home/luyu/miniconda3/envs/lanegcn-new/lib/python3.7/multiprocessing/connection.py", line 456, in accept
    answer_challenge(c, self._authkey)
  File "/home/luyu/miniconda3/envs/lanegcn-new/lib/python3.7/multiprocessing/connection.py", line 742, in answer_challenge
    message = connection.recv_bytes(256) # reject large message
  File "/home/luyu/miniconda3/envs/lanegcn-new/lib/python3.7/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/home/luyu/miniconda3/envs/lanegcn-new/lib/python3.7/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
  File "/home/luyu/miniconda3/envs/lanegcn-new/lib/python3.7/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
ConnectionResetError: [Errno 104] Connection reset by peer
killed

dilated_nbrs bug

Should mat = mat * mat be changed to mat = mat * csr in the function dilated_nbrs?
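The two versions compute different scales. With scipy.sparse matrices such as csr_matrix, * is matrix multiplication (with the newer sparse array classes it is elementwise). Assume the loop starts from mat = csr, the adjacency A, and records mat after each update. Then mat = mat * mat yields A, A^2, A^4, A^8, ... (hops of 1, 2, 4, 8), whereas mat = mat * csr yields A, A^2, A^3, ... (hops of 1, 2, 3). Which one is intended depends on the dilation sizes the paper specifies, so check that before changing the code. The dense NumPy sketch below uses hypothetical helpers and models a single straight lane as a chain graph to make the difference concrete.

```python
import numpy as np


def chain_adj(n):
    """A[i, j] = 1 if j == i + 1, i.e. node j is the successor of node i on one lane."""
    return np.eye(n, k=1, dtype=int)


def scales_by_squaring(adj, num_scales):
    mats, mat = [], adj.copy()
    for _ in range(num_scales):
        mats.append(mat)
        mat = mat @ mat  # A, A^2, A^4, A^8, ...
    return mats


def scales_by_multiplying(adj, num_scales):
    mats, mat = [], adj.copy()
    for _ in range(num_scales):
        mats.append(mat)
        mat = mat @ adj  # A, A^2, A^3, A^4, ...
    return mats


def hop(m):
    """For a power of a chain adjacency, the (constant) offset j - i of its nonzeros."""
    i, j = np.nonzero(m)
    return int(j[0] - i[0])


A = chain_adj(20)
print([hop(m) for m in scales_by_squaring(A, 4)])     # [1, 2, 4, 8]
print([hop(m) for m in scales_by_multiplying(A, 4)])  # [1, 2, 3, 4]
```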
