Generalized Differentiable RANSAC ($\nabla$-RANSAC)

License: GNU General Public License v3.0

📰 PyTorch Implementation of the ICCV 2023 paper: Generalized Differentiable RANSAC ($\nabla$-RANSAC).

Tong Wei, Yash Patel, Alexander Shekhovtsov, Jiri Matas and Daniel Barath.

| paper | poster | arxiv | diff_ransac_models | diff_ransac_data for E/F | 3d_match_data | Ransac-tutorial-data

Updates

[2023.10. Solvers in Kornia] Our implemented 5PC solver for essential matrix estimation is integrated into [Kornia](https://github.com/kornia/kornia)! Install it from source by
$ pip install git+https://github.com/kornia/kornia

An example of importing 5PC from Kornia is shown here.
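As a hedged sketch of how this can be used (assuming the find_essential API of recent Kornia versions; the output shape and the random placeholder data below are assumptions, not the repo's code):

import torch
import kornia.geometry.epipolar as epi

# Five normalized correspondences per batch: the 5PC minimal sample.
pts1 = torch.rand(1, 5, 2)
pts2 = torch.rand(1, 5, 2)
weights = torch.ones(1, 5)  # optional per-point weights

# find_essential returns a batch of candidate essential matrices;
# the five-point problem has up to 10 real solutions.
E = epi.find_essential(pts1, pts2, weights)
print(E.shape)  # e.g. (1, 10, 3, 3)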

Implementation and Environment

Here are the required packages:

python = 3.7.11
pytorch = 1.12.1
opencv = 3.4.2
tqdm
kornia
kornia_moons
tensorboardX = 2.2
scikit-learn
einops
yacs

or try with conda create --name <env> --file requirements.txt

Installation

The SOTA results are obtained with our method integrated into MAGSAC++; install its Python bindings as follows.

$ git clone https://github.com/weitong8591/magsac.git --recursive
$ cd magsac
$ mkdir build
$ cd build
$ cmake ..
$ make
$ cd ..
$ python setup.py install

Note that the proposed Gumbel Softmax sampler is activated by sampler=3.
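For instance, a hedged sketch of calling it from Python (the argument list follows the findFundamentalMatrix signature quoted in the issues section below; the returned tuple being the model and its inlier mask, as well as the image sizes and random data here, are assumptions for illustration):

import numpy as np
import pymagsac  # built from the magsac fork above

# correspondences: (N, 4) array of [x1, y1, x2, y2] point pairs;
# probabilities: (N,) importance scores, e.g. predicted by the network.
correspondences = np.random.rand(100, 4)
probabilities = np.random.rand(100)

F, mask = pymagsac.findFundamentalMatrix(
    correspondences.astype(np.float64),
    1024.0, 768.0, 1024.0, 768.0,  # w1, h1, w2, h2: the two image sizes
    probabilities=probabilities.astype(np.float64),
    sampler=3,                     # 3 activates the Gumbel Softmax sampler
    use_magsac_plus_plus=True,
    sigma_th=2.0,
)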

Evaluation

Download the pretrained models we provide here, and test them as follows.

[Two-view epipolar geometry estimation]

Download the RootSIFT features of PhotoTourism from here, and run

$ python test_magsac.py -nf 2000 -m pretrained_models/saved_model_5PC_l_epi/model.net -bs 32 -fmat 0 -sam 1 -bm 1 -t 2 -pth <>

Add -fmat 1 to activate fundamental matrix estimation; use -ds <scene_name> instead of -bm 1 to test on a specific scene.

AUC scores thresholded at [5, 10, 20] are reported for E estimation; F1 scores and median epipolar errors are the evaluation metrics for F estimation.
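As a hedged sketch of the AUC metric, here is the common SuperGlue-style computation of the area under the pose-error recall curve; it is illustrative, not necessarily this repo's exact evaluation code, and the thresholds are assumed to be in degrees:

import numpy as np

def pose_auc(errors, thresholds=(5, 10, 20)):
    # Sort the per-pair pose errors and build the recall curve.
    errors = np.sort(np.asarray(errors))
    recall = (np.arange(len(errors)) + 1) / len(errors)
    errors = np.concatenate(([0.0], errors))
    recall = np.concatenate(([0.0], recall))
    aucs = []
    for t in thresholds:
        # Integrate recall up to the threshold and normalize by it.
        last = np.searchsorted(errors, t)
        e = np.concatenate((errors[:last], [t]))
        r = np.concatenate((recall[:last], [recall[last - 1]]))
        aucs.append(np.trapz(r, x=e) / t)
    return aucs

print(pose_auc([1.2, 3.5, 8.0, 25.0]))  # AUC@5/10/20 for four image pairs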

Note: SuperPoint+SuperGlue features on ScanNet are coming soon.

[3D point cloud registration]

Download the 3DMatch and 3DLoMatch data from here, and run

$ python test_magsac_point.py -m diff_ransac_models/point_model.net -d cpu -us 0 -max 50000 -pth <>

The evaluation metrics follow the registration and utils modules of GeoTransformer.
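For reference, a minimal sketch of the standard registration metrics (relative rotation error in degrees, relative translation error), assuming GeoTransformer-style definitions; the helper below is illustrative, not the repo's code:

import numpy as np

def registration_errors(R_est, t_est, R_gt, t_gt):
    # RRE: angle of the relative rotation between estimate and ground truth.
    cos = np.clip((np.trace(R_gt.T @ R_est) - 1.0) / 2.0, -1.0, 1.0)
    rre = np.degrees(np.arccos(cos))
    # RTE: Euclidean distance between the two translations.
    rte = np.linalg.norm(t_est - t_gt)
    return rre, rte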

[Learned robust feature matching]

Download the images, camera intrinsics and extrinsics of PhotoTourism from here, and test with three protocols (-ransac): 0 - OpenCV-RANSAC; 1 - OpenCV-MAGSAC; 2 - MAGSAC++ with PROSAC.

$ python test_ransac_loftr.py -nf 2000 -tr 1 -bs 1 -lr 0.000001 -t 3. -sam 3 -fmat 1 -sid loftr -m2 diff_ransac_models/loftr_model.pth -pth <>

Training

[Two-view epipolar geometry estimation]

Train an importance-score prediction model (e.g., a CLNet backbone predicting an importance score for each input tentative correspondence) end to end with $\nabla$-RANSAC.

$ python train.py -nf 2000 -m pretrained_models/weights_init_net_3_sampler_0_epoch_1000_E_rs_r0.80_t0.00_w1_1.00_.net -bs 32 -fmat 0 -sam 2 -tr 1 -w2 1 -t 0.75 -pth <>

Notes: the initialized weights are applied; 5PC is used for essential matrix estimation (-sam 2 -fmat 0); 7PC (-sam 2 -fmat 1) and 8PC (-sam 3 -fmat 1) can be used for F estimation.

In terms of training loss, -w2 (mean epipolar error) works best in terms of AUC scores; however, using a linear combination of the classification loss (-w1 1) with -w2 as the objective leads to more stable learning (a consistent downward trend, but lower AUC scores at inference). The epipolar loss (-w2) alone is not stable across training trials.
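A minimal sketch of this combination (the function name is illustrative; the actual weighting lives in the training script):

# loss = w1 * L_classification + w2 * L_epipolar, controlled by -w1/-w2
def combined_loss(classification_loss, epipolar_loss, w1=1.0, w2=1.0):
    return w1 * classification_loss + w2 * epipolar_loss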

The train/valid data has been updated to an offline split; the training image pair list and the train/val data of St. Peter's Square are available here.

[3D point cloud registration]

Train an importance-score prediction model (e.g., a CLNet backbone predicting an importance score for each input tentative correspondence) end to end with $\nabla$-RANSAC for point cloud registration. A rigid transformation solver is implemented.

The 3DMatch train/val dataset is used:

$ python train_point.py -nf 2000 -sam 2 -tr 1 -t 0.75 -pth <>

[Learned robust feature matching]

End-to-end training of a feature matcher (e.g., LoFTR) with $\nabla$-RANSAC to improve the predicted matches and confidences. Download 'outdoor_ds.ckpt' from diff_ransac_models and install the loftr package.

$ python train_ransac_loftr.py -nf 2000 -tr 1 -bs 1 -lr 1e-6 -t 0.75 -sam 3 -fmat 1 -w2 1 -sid loftr -e 50 -p 0 -topk 1 -m2 diff_ransac_models/outdoor_ds.ckpt -pth RANSAC-Tutorial-Data/train/

Demo test

Test E estimation on one scene without local optimization; no installation of MAGSAC++ is needed.

$ python test.py -nf 2000 -m pretrained_models/saved_model_5PC_l_epi/model.net -bs 32 -fmat 1 -sam 3 -ds sacre_coeur -t 2 -pth <data_path>

Notes

[Useful parameters]
-pth: the source path of all datasets
-sam: samplers, 0 - Uniform sampler, 1,2 - Gumbel Sampler for 5PC/7PC, 3 - Gumbel Sampler for 8PC, default=0
-w0, -w1, -w2: coefficients of the loss combination: L_pose, L_classification, L_essential
-fmat: 0 - E, 1 - F, default=0
-lr: learning rate, default=1e-4
-t: threshold, default=0.75
-e: epochs, default=10
-bs: batch size, default=32
-rbs: batch size of RANSAC iterations, default=64
-tr: train or test mode, default=0
-nf: number of features, default=2000
-m: pretrained model or trained model
-snn: the threshold of the SNN ratio filter
-ds: a single dataset (scene name) to test on
-bm: batch mode, using all 12 testing scenes defined in utils.py
-p: probabilities, 0 - normalized weights, 1 - unnormalized weights, 2 - logits, default=2
-topk: whether the loss is averaged over the top-k models or over all models
-sch: 0 - no learning rate scheduler, 1 - use a scheduler from lr down to eta_min, default=0 (see the sketch after this list)
-eta_min: float, the lower bound for the lr scheduler
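As a hedged illustration of -sch 1, a standard PyTorch scheduler decaying from lr to eta_min; whether the repo uses cosine annealing specifically is an assumption:

import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # -lr
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=10, eta_min=1e-6  # T_max ~ -e epochs, eta_min ~ -eta_min
)
for epoch in range(10):
    # ... one training epoch ...
    scheduler.step()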
[Referenced code]

The minimal solvers, model scoring functions, local optimization, etc. are re-implemented in PyTorch with reference to MAGSAC. Thanks also to the public repos of CLNet and NG-RANSAC, and to the PyTorch and Kornia libraries.

Citation

More details are covered in our paper; feel free to cite it if you find it useful:

@InProceedings{wei2023generalized,
  title={Generalized differentiable RANSAC},
  author={Wei, Tong and Patel, Yash and Shekhovtsov, Alexander and Matas, Jiri and Barath, Daniel},
  booktitle={ICCV},
  year={2023}
}

Contact me at [email protected]


differentiable_ransac's Issues

findFundamentalMatrix function of pymagsac is incompatible with the code

I haven't changed any code in the repo and test_magsac.py gives this error.

TypeError: findFundamentalMatrix(): incompatible function arguments. The following argument types are supported:
    1. (correspondences: numpy.ndarray[numpy.float64], w1: float, h1: float, w2: float, h2: float, probabilities: numpy.ndarray[numpy.float64], sampler: int = 4, use_magsac_plus_plus: bool = True, sigma_th: float = 1.0, conf: float = 0.99, min_iters: int = 50, max_iters: int = 1000, partition_num: int = 5) -> tuple

Invoked with: array([[443.9876 , 239.61026],
       [443.9876 , 239.61026],
       [443.9876 , 239.61026],
       ...,
       [328.67685, 423.61395],
       [328.67685, 423.61395],
       [328.67685, 423.61395]], dtype=float32), array([[423.3007 , 328.98468],
       [423.3007 , 328.98468],
       [423.3007 , 328.98468],
       ...,
       [502.42374, 161.1907 ],
       [502.42374, 161.1907 ],
       [502.42374, 161.1907 ]], dtype=float32), 575.0, 1077.0, 694.0, 1067.0; kwargs: probabilities=array([-2.132277e-03, -2.132277e-03, -2.132277e-03, ..., -2.458080e+01,
       -2.458080e+01, -2.458080e+01], dtype=float32), use_magsac_plus_plus=True, sigma_th=2.0, sampler_id=1, save_samples=True

pymagsac version is 0.3.dev0

Question regarding MatchLoss

# Snippet from the MatchLoss class; requires torch, numpy as np, cv2,
# and the repo helpers normalize_keypoints_tensor, denormalize_pts, batch_episym.
def forward(self, models, gt_E, pts1, pts2, K1, K2, im_size1, im_size2, topk_flag=False, k=1):
    essential_loss = []
    for b in range(gt_E.shape[0]):
        if self.fmat:
            # fundamental-matrix mode: convert F to E and move the points
            # onto the normalized image plane
            Es = K2[b].transpose(-1, -2) @ models[b] @ K1[b]
            pts1_1 = normalize_keypoints_tensor(denormalize_pts(pts1[b].clone(), im_size1[b]), K1[b])
            pts2_2 = normalize_keypoints_tensor(denormalize_pts(pts2[b].clone(), im_size2[b]), K2[b])
        else:
            pts1_1 = pts1[b].clone()
            pts2_2 = pts2[b].clone()
            Es = models[b]

        _, gt_R_1, gt_t_1, gt_inliers = cv2.recoverPose(
            gt_E[b].astype(np.float64),
            pts1_1.unsqueeze(1).cpu().detach().numpy(),
            pts2_2.unsqueeze(1).cpu().detach().numpy(),
            np.eye(3, dtype=gt_E.dtype)
        )

        # find the ground-truth inliers (np.bool is removed in recent NumPy)
        gt_mask = np.where(gt_inliers.ravel() > 0, 1.0, 0.0).astype(bool)
        gt_mask = torch.from_numpy(gt_mask).to(pts1_1.device)

        # symmetric epipolar errors of all models w.r.t. the gt inliers
        geod = batch_episym(
            pts1_1[gt_mask].repeat(Es.shape[0], 1, 1),
            pts2_2[gt_mask].repeat(Es.shape[0], 1, 1),
            Es
        )
        e_l = torch.min(geod, geod.new_ones(geod.shape))  # clamp the error at 1
        if torch.isnan(e_l.mean()).any():
            print("nan values in pose loss")

        if topk_flag:
            # keep only the k models with the smallest mean error
            topk_indices = torch.topk(e_l.mean(1), k=k, largest=False).indices
            essential_loss.append(e_l[topk_indices].mean())
        else:
            essential_loss.append(e_l.mean())
    # average over the batch
    return sum(essential_loss) / gt_E.shape[0]

def batch_episym(x1, x2, F, eps=1e-10):
    """
    Symmetric epipolar error from CLNet.
    x1, x2: (B, N, 2)
    F: (B, 3, 3)
    """
    batch_size, num_pts = x1.shape[0], x1.shape[1] # (B, N, 2)
    x1 = torch.cat([x1, x1.new_ones(batch_size, num_pts, 1)], dim=-1).reshape(batch_size, num_pts, 3, 1) # (B, N, 3, 1)
    x2 = torch.cat([x2, x2.new_ones(batch_size, num_pts, 1)], dim=-1).reshape(batch_size, num_pts, 3, 1) # (B, N, 3, 1)
    F = F.reshape(-1, 1, 3, 3).repeat(1, num_pts, 1, 1) # (B, N, 3, 3)
    x2Fx1 = torch.matmul(x2.transpose(2, 3), torch.matmul(F, x1)).reshape(batch_size, num_pts) # (B, N)
    Fx1 = torch.matmul(F, x1).reshape(batch_size, num_pts, 3) # (B, N, 3)
    Ftx2 = torch.matmul(F.transpose(2, 3), x2).reshape(batch_size, num_pts, 3)
    ys = x2Fx1 ** 2 * (1.0 / (Fx1[:, :, 0] ** 2 + Fx1[:, :, 1] ** 2 + eps) + 1.0 / (Ftx2[:, :, 0] ** 2 + Ftx2[:, :, 1] ** 2 + eps))
    if torch.isnan(ys).any():
        print("ys is nan in batch_episym")
    return ys

When self.fmat is on, the F argument of batch_episym contains a fundamental matrix, so x1 and x2 of batch_episym must be points in pixel space. However, according to your code, they will be on the normalized image plane because of the code below.

pts1_1 = normalize_keypoints_tensor(denormalize_pts(pts1[b].clone(), im_size1[b]), K1[b])
pts2_2 = normalize_keypoints_tensor(denormalize_pts(pts2[b].clone(), im_size2[b]), K2[b])

Is this intended?

[Docs] Link to the paper!

I love the implementation of the work and find it amazing. Could you please share the link to the original paper?

About Local Optimization

Hello, thanks for your exciting work!
In the arXiv paper, you mention applying RANSAC-based local optimization to improve accuracy.
I see that the variable lo in the RANSAC class defaults to 0, and I don't see this value being modified anywhere in the code. So I was wondering whether you use local optimization in your test inference, and if so, how much it improves the performance?

Question on singular_filter

At this line, I am wondering why the singularity check is made with torch.linalg.matrix_rank instead of torch.linalg.cond.

[Screenshot: a nearly singular matrix that is still reported with rank 3]

According to the screenshot, the matrix is extremely singular but still has a rank of 3, so I guess it would be better to check singularity with torch.linalg.cond?
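A hedged illustration of the point being made (a synthetic matrix, not the one from the repo):

import torch

# A nearly singular matrix can still report full rank at the default
# tolerance, while its condition number exposes the near-singularity.
A = torch.diag(torch.tensor([1.0, 1.0, 1e-6]))

print(torch.linalg.matrix_rank(A))  # tensor(3): full rank at default tolerance
print(torch.linalg.cond(A))         # ~1e6: numerically ill-conditioned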

Question about datas for training

Dear author,

Hello. Following your advice, I used the methods suggested in ngransac to generate some pairs.npy files for training. However, the pairs.npy generated by this method differs in some aspects and cannot be used to train differentiable_ransac. Therefore, I plan to use the methods in ars_magsac to generate training data. Could you please guide me on how to generate the series of h5 files required by ars_magsac? Alternatively, how can I modify the data generation method in ngransac to produce data suitable for training differentiable_ransac?

Question about the dataset of Application 1: Train $\nabla$-RANSAC for two-view geometry estimation

Thank you for your great work. I have some questions about the dataset of Application 1. What is it? Is it diff_ransac_data? I noticed that the dataset contains pair_.npy files instead of images; how do you convert images into pair_.npy files? Is the input of $\nabla$-RANSAC training a set of tentative correspondences? I read in the paper that the feature matcher and the trainable quality function f can be trained and optimized together; is this not done in Application 1? Sorry for asking so many questions at once, looking forward to your reply!

Camera system

Hi Tong,

When using the find_essential function implemented in Kornia, may I ask which camera coordinate system the result is defined in? I previously thought it was the COLMAP camera coordinate system (x right, y down, z towards the image), but the result seems strange.

Question about the inputs of the RANSAC layer

Thank you for your amazing work. I am new to this field and would like to ask what the K1 and K2 inputs of the RANSAC layer represent.

What is the optimal configuration to train the model from scratch?

This work is really interesting!
I am investigating the use of this package for a monocular-VO-based problem I am currently working on, which mainly deals with estimating relative camera poses using ground-truth poses.

Given that I have extracted all image-pair keypoint correspondences and features, what do you recommend for training a differentiable RANSAC model completely from scratch for fundamental matrix estimation?

Thanks.
