wjmaddox / drbayes
Code Repo for "Subspace Inference for Bayesian Deep Learning"
License: BSD 2-Clause "Simplified" License
Hello @wjmaddox,
Thanks for sharing your code. I found your paper very inspiring.
I wonder if the sampling of SWAG is correct. In particular, it appears that the standard Gaussian sample is multiplied by the vector of variances for the SWAG-Diagonal part:
Whereas in the original SWAG paper and its implementation the standard Gaussian is multiplied by the vector of standard deviations to sample from N(0, diag(variances)):
Am I missing something? Thanks for your time.
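To make the distinction in this question concrete, here is a minimal standard-library sketch, not the repository's code. It covers only the diagonal part and leaves out the low-rank term and the 1/sqrt(2) scaling used in the SWAG paper. The function name `sample_diag`, the flag `scale_by_std`, and the toy variance 0.04 are assumptions made for illustration.

```python
import math
import random
import statistics


def sample_diag(mean, var, rng, scale_by_std=True):
    """Return mean + scale * z with z ~ N(0, I), elementwise.

    scale_by_std=True  scales z by sqrt(var), giving a sample from N(mean, diag(var)).
    scale_by_std=False scales z by var itself, giving N(mean, diag(var**2)).
    """
    sample = []
    for m, v in zip(mean, var):
        scale = math.sqrt(v) if scale_by_std else v
        sample.append(m + scale * rng.gauss(0.0, 1.0))
    return sample


rng = random.Random(0)
mean = [0.0]
var = [0.04]  # toy target variance, so the standard deviation is 0.2

n = 20000
by_std = [sample_diag(mean, var, rng, scale_by_std=True)[0] for _ in range(n)]
by_var = [sample_diag(mean, var, rng, scale_by_std=False)[0] for _ in range(n)]

emp_var_std = statistics.pvariance(by_std)  # should be close to 0.04
emp_var_var = statistics.pvariance(by_var)  # should be close to 0.04**2 = 0.0016
print(emp_var_std, emp_var_var)
```

When the variances are below 1, as in this toy case, scaling by the variance instead of the standard deviation shrinks the spread of the samples; when they are above 1, it inflates it.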
Hello,
The approach and results presented in the paper "Subspace Inference for Bayesian Deep Learning" impress me a lot and I would like to reproduce the experiments, specifically the UCI experiments, using this repository.
I followed the instructions and process described in the README and managed to get everything running. However, I failed to obtain similar numbers for certain datasets.
To be more precise:
- The library versions match.
- The datasets were downloaded from Google Drive, as suggested in the README under drbayes/experiments/uci_exps.
- All experiments on the small UCI datasets were launched with the script run_ucismall.sh under drbayes/experiments/uci_exps/bayesian_benchmarks/tasks.
- The reproduced unnormalized test likelihoods for the yacht dataset differ significantly from those reported in the paper, as shown below.
Method | Paper result | Reproduced result |
---|---|---|
PCA+ESS (SI) | -0.225 ± 0.400 | -2.493 ± 0.067 |
SWAG | -0.404 ± 0.418 | -2.545 ± 0.053 |
I would be extremely grateful if anyone could suggest what I could've done wrong in my reproduction experiments.
Thank you in advance!
Hi,
In subspaces.py you use two different SVD algorithms: TruncatedSVD, which seems to be unused, and then later randomized_svd. Which one is the correct one?
Regards,
Maciej
Hello!
I was trying to run the visualisation Jupyter notebook and had a couple of issues (I installed the repo as a package, as mentioned):
setup.py requires scikit_learn>=0.20.2. However, version 0.24 currently gets installed, which throws some errors (I think some things have been refactored in the newest versions). I had to downgrade to 0.20.2 to get it to work, so I believe this should be changed to scikit_learn==0.20.2. (I just noticed that this is specified correctly in requirements.txt, just not in setup.py.)
And one more thing: in the curve subspaces section of the notebook it says "Note that for this to work you need to add the repo https://github.com/timgaripov/dnn-mode-connectivity to your Python path". However, I think the version of the curves.py file in that repo might not be the same as the one used in the notebook. In fact, the notebook calls
model = curves.CurveNet(curve, architecture, 3, fix_end=True, fix_start=True,
                        architecture_kwargs=model_cfg.kwargs)
but in the original repo that function is written as:
class CurveNet(Module):
    def __init__(self, num_classes, curve, architecture, num_bends, fix_start=True, fix_end=True,
                 architecture_kwargs={}):
so I believe the num_classes argument is missing in the function call in the notebook.
Thanks!
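The mismatch described above can be reproduced without either repository. The sketch below uses a stand-in class that copies only the constructor signature quoted in this issue; it is not the real `curves.CurveNet`. The placeholder values for `curve`, `architecture`, and `architecture_kwargs`, and the choice of `num_classes=10`, are assumptions for illustration. The stand-in also uses `None` rather than `{}` as the default for `architecture_kwargs`.

```python
class CurveNet:
    """Stand-in that mirrors only the constructor signature quoted in the issue."""

    def __init__(self, num_classes, curve, architecture, num_bends,
                 fix_start=True, fix_end=True, architecture_kwargs=None):
        self.num_classes = num_classes
        self.curve = curve
        self.architecture = architecture
        self.num_bends = num_bends
        self.fix_start = fix_start
        self.fix_end = fix_end
        self.architecture_kwargs = architecture_kwargs or {}


# Hypothetical placeholders for the objects the notebook passes in.
curve, architecture, arch_kwargs = "curve", "architecture", {}

# The notebook-style call: three positional arguments against a signature that
# expects four, so every value shifts one slot and num_bends is left unfilled.
try:
    CurveNet(curve, architecture, 3, fix_end=True, fix_start=True,
             architecture_kwargs=arch_kwargs)
    notebook_call_failed = False
except TypeError as err:
    notebook_call_failed = True
    print("notebook-style call failed:", err)

# Passing everything by keyword makes the call robust to argument order.
model = CurveNet(num_classes=10, curve=curve, architecture=architecture,
                 num_bends=3, fix_start=True, fix_end=True,
                 architecture_kwargs=arch_kwargs)
print(model.num_classes, model.num_bends)
```

Whether the notebook was written against a different version of curves.py, as suggested above, cannot be determined from this sketch; it only shows why the call as quoted does not fit the signature as quoted.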