
jax_verify's Introduction

jax_verify: Neural Network Verification in JAX


jax_verify is a library containing JAX implementations of many widely used neural network verification techniques.

Overview

If you just want to get started with using jax_verify to verify your neural networks, the main thing to know is that we provide a simple, consistent interface for a variety of verification algorithms:

output_bounds = jax_verify.verification_technique(network_fn, input_bounds)

Here, network_fn is any JAX function, input_bounds define bounds over possible inputs to network_fn, and output_bounds will be the computed bounds over possible outputs of network_fn. verification_technique can be one of many algorithms implemented in jax_verify, such as interval_bound_propagation or crown_bound_propagation.
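To make the interface above concrete, here is a minimal hand-rolled sketch of what interval bound propagation computes for a small ReLU network written in JAX. It is not jax_verify's implementation: the helper names (affine_bounds, relu_bounds, interval_bounds) and the params layout (a list of (w, b) pairs) are assumptions made for illustration. With the library itself, the equivalent computation goes through the jax_verify.interval_bound_propagation(network_fn, input_bounds) call described above.

```python
import jax
import jax.numpy as jnp


def affine_bounds(lower, upper, w, b):
    # Interval arithmetic for y = x @ w + b using the centre/radius form.
    centre = (upper + lower) / 2.0
    radius = (upper - lower) / 2.0
    y_centre = centre @ w + b
    y_radius = radius @ jnp.abs(w)
    return y_centre - y_radius, y_centre + y_radius


def relu_bounds(lower, upper):
    # ReLU is monotone, so applying it to each bound gives valid output bounds.
    return jax.nn.relu(lower), jax.nn.relu(upper)


def interval_bounds(params, lower, upper):
    # Propagate a box through affine layers, with ReLU between hidden layers.
    for i, (w, b) in enumerate(params):
        lower, upper = affine_bounds(lower, upper, w, b)
        if i < len(params) - 1:
            lower, upper = relu_bounds(lower, upper)
    return lower, upper


def network_fn(params, x):
    # The concrete network whose outputs the bounds should enclose.
    for i, (w, b) in enumerate(params):
        x = x @ w + b
        if i < len(params) - 1:
            x = jax.nn.relu(x)
    return x


# Hypothetical two-layer network and an L-infinity ball of radius epsilon.
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
params = [(jax.random.normal(k1, (2, 3)), jnp.zeros(3)),
          (jax.random.normal(k2, (3, 1)), jnp.zeros(1))]
x = jnp.array([0.5, -0.2])
epsilon = 0.1
lower, upper = interval_bounds(params, x - epsilon, x + epsilon)
print(lower, upper)

# Sample random points inside the input box and evaluate the network on them.
samples = (x - epsilon) + 2 * epsilon * jax.random.uniform(k3, (100, 2))
outputs = network_fn(params, samples)
```

The centre/radius form is used because, for a single affine layer, multiplying the input radius by |w| gives the exact range of each output coordinate over the box; the looseness of this method comes from re-approximating with a box after every layer. Here network_fn takes params explicitly; to match the single-argument interface above, one would close over them, e.g. with functools.partial(network_fn, params).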

The overall approach is to use JAX’s powerful program transformation system, which allows us to analyze general network structures defined by network_fn and then to define corresponding functions for calculating verified bounds for these networks.
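The program-transformation idea can be seen directly with jax.make_jaxpr, which traces a Python function into a jaxpr, the intermediate graph representation that a library like jax_verify can walk. The small network below is a hypothetical example, and the printed jaxpr differs across JAX versions (some jnp helpers show up as nested call primitives rather than as top-level equations).

```python
import jax
import jax.numpy as jnp


def network_fn(x, w, b):
    # One affine layer followed by a ReLU.
    return jnp.maximum(jnp.dot(x, w) + b, 0.0)


x = jnp.ones((2,))
w = jnp.ones((2, 3))
b = jnp.zeros((3,))

# Trace the function into a ClosedJaxpr (a Jaxpr plus any captured constants).
closed_jaxpr = jax.make_jaxpr(network_fn)(x, w, b)
print(closed_jaxpr)

# Top-level equations of the graph; each one names the primitive it applies.
print([eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns])
```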

Verification Techniques

The methods currently provided by jax_verify include:

Installation

Stable: Just run pip install jax_verify and you can import jax_verify from any of your Python code.

Latest: Clone this repository and run pip install . from the repository root.

Getting Started

We suggest starting by looking at the minimal examples in the examples/ directory. For example, all the bound propagation techniques can be run with the run_boundprop.py script:

cd examples/
python3 run_boundprop.py --boundprop_method=interval_bound_propagation

For documentation, please refer to the API reference page.

Notes

Contributions of additional verification techniques are very welcome. Please open an issue first to let us know.

License

All code is made available under the Apache 2.0 License. Model parameters are made available under the Creative Commons Attribution 4.0 International (CC BY 4.0) License. See https://creativecommons.org/licenses/by/4.0/legalcode for more details.

Disclaimer

This is not an official Google product.

jax_verify's People

Contributors

bunelr, juesato, lberrada

jax_verify's Issues

Location of neural networks

Hello,

When I run this script:

python3 run_sdp_verify.py --epsilon=0.1 --dataset=mnist \
  --model_name=models/raghunathan18_pgdnn.pkl --use_exact_eig_train=True \
  --use_exact_eig_eval=True --opt_name=adam --lam_coeff=0.1 --nu_coeff=0.03 \
  --custom_kappa_coeff=10000 --anneal_lengths=10000,4000,1000 \
  --kappa_zero_after=2000

I see that the neural network is stored at the path models/raghunathan18_pgdnn.pkl, but I couldn't find the "models" folder. Could you tell me where I can find these files and how I can use my own networks?

Thanks in advance.

Having Trouble Installing Dependencies

I created a fresh conda environment and ran pip install on the requirements.txt, only to realize that this had not installed a GPU-compatible JAX. After a little searching, I installed JAX using conda install jax cuda-nvcc -c conda-forge -c nvidia, as recommended by this page, but then I got the following warnings telling me that JAX was not using the GPU, which I'd like to use:

(jax_verify) chelseas@server:~/jax_verify$ python3 examples/run_boundprop.py --boundprop_method=backward_crown_bound_propagation                                                                                                  
I1228 22:37:51.228809 140315511407680 xla_bridge.py:170] Remote TPU is not linked into jax; skipping remote TPU.   
I1228 22:37:51.228915 140315511407680 xla_bridge.py:355] Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'                                                                                         
I1228 22:37:51.228991 140315511407680 xla_bridge.py:355] Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'                                                                     
I1228 22:37:51.229050 140315511407680 xla_bridge.py:355] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'                                                                     
I1228 22:37:51.229235 140315511407680 xla_bridge.py:355] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'                                                                          
I1228 22:37:51.229310 140315511407680 xla_bridge.py:355] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
W1228 22:37:51.229365 140315511407680 xla_bridge.py:362] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

I opened an issue in the jax repo initially but I was thinking that someone maintaining this repo might also be able to help.
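When chasing warnings like the ones above, a quick way to confirm which backend JAX actually selected is the short check below. It only reports what JAX sees through the standard jax.default_backend() and jax.devices() calls; it does not fix a missing or mismatched CUDA-enabled jaxlib, and the printed hint is illustrative wording, not a JAX message.

```python
import jax

# Report which backend JAX selected and which devices it can see.
backend = jax.default_backend()
devices = jax.devices()
print("backend:", backend)
print("devices:", devices)

if backend == "cpu":
    # Illustrative hint: no GPU/TPU backend could be initialized.
    print("JAX is running on CPU; a GPU-enabled jaxlib was not found or failed to load.")
```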

einshape requirement use of git://

Hi Rudy -

I was just trying to pip install a local clone of this repo and got this error:

(venv3.10) mfe@Michaels-MacBook-Air robustness_analysis % python -m pip install -e jax_verify
Obtaining file:///Users/mfe/code/robustness_analysis/jax_verify
  Preparing metadata (setup.py) ... done
Collecting einshape@ git+git://github.com/deepmind/einshape.git
  Cloning git://github.com/deepmind/einshape.git to /private/var/folders/cq/gg1tn9g54pn2vg_fh99fh3_00000gn/T/pip-install-s6x2khvr/einshape_5dc538cc56a94e73a5e1c2474dedfd11
  Running command git clone --filter=blob:none --quiet git://github.com/deepmind/einshape.git /private/var/folders/cq/gg1tn9g54pn2vg_fh99fh3_00000gn/T/pip-install-s6x2khvr/einshape_5dc538cc56a94e73a5e1c2474dedfd11
  fatal: unable to connect to github.com:
  github.com[0: 140.82.112.3]: errno=Operation timed out

  error: subprocess-exited-with-error

  × git clone --filter=blob:none --quiet git://github.com/deepmind/einshape.git /private/var/folders/cq/gg1tn9g54pn2vg_fh99fh3_00000gn/T/pip-install-s6x2khvr/einshape_5dc538cc56a94e73a5e1c2474dedfd11 did not run successfully.
  │ exit code: 128
  ╰─> See above for output.

  note: This error originates from a subprocess, and is likely not a problem with pip.
error: subprocess-exited-with-error

× git clone --filter=blob:none --quiet git://github.com/deepmind/einshape.git /private/var/folders/cq/gg1tn9g54pn2vg_fh99fh3_00000gn/T/pip-install-s6x2khvr/einshape_5dc538cc56a94e73a5e1c2474dedfd11 did not run successfully.
│ exit code: 128
╰─> See above for output.

note: This error originates from a subprocess, and is likely not a problem with pip.

I fixed this by changing the einshape line in requirements.txt to use https:// instead of git:// per this post.
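The change described above amounts to a one-line substitution. The sketch below applies it to requirement text; the requirement string is copied from the pip output above, the exact line in requirements.txt may be written differently, and use_https_for_git_urls is a hypothetical helper name.

```python
def use_https_for_git_urls(requirements_text):
    """Rewrite pip requirement URLs from git+git:// to git+https://."""
    # GitHub no longer serves the unauthenticated git:// protocol,
    # so the same repository is fetched over HTTPS instead.
    return requirements_text.replace("git+git://", "git+https://")


line = "einshape@ git+git://github.com/deepmind/einshape.git"
print(use_https_for_git_urls(line))
```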

extensions is not a module

Hi!

Problem Description

I tried to run the examples for the functional_lagrangian extensions (https://github.com/google-deepmind/jax_verify/tree/master/jax_verify/extensions/functional_lagrangian), but the script fails because the extensions module was not found.

Reproduce

The precise steps I took:

# clone jax_verify repository
# create new conda environment with python3.8
pip install jax_verify
cd jax_verify/jax_verify/extensions/functional_lagrangian/run/
python3 run_functional_lagrangian.py --config=configs/config_adv_stochastic_input.py

The outcome is:

Traceback (most recent call last):
  File "run_functional_lagrangian.py", line 26, in <module>
    from jax_verify.extensions.functional_lagrangian import attacks
ModuleNotFoundError: No module named 'jax_verify.extensions'

I guess this is because the extensions directory is not a Python package (it has no __init__.py).
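The guess above is consistent with how setuptools.find_packages() behaves: it only collects directories that contain an __init__.py and does not descend below a directory that lacks one, so such a subtree can be left out of an installed distribution. The stdlib-only helper below (dirs_missing_init is a hypothetical name) lists directories under a package root that contain Python files somewhere below them but have no __init__.py of their own. It will also flag script-only folders such as a run/ directory, which may be intentional.

```python
import tempfile
from pathlib import Path
from typing import List


def dirs_missing_init(package_root: str) -> List[str]:
    """List subdirectories that contain .py files below them but no __init__.py."""
    root = Path(package_root)
    missing = []
    for path in sorted(p for p in root.rglob("*") if p.is_dir()):
        if path.name == "__pycache__":
            continue  # Bytecode caches are never packages.
        contains_python = any(path.rglob("*.py"))
        if contains_python and not (path / "__init__.py").exists():
            missing.append(path.relative_to(root).as_posix())
    return missing


# Demo on a throwaway tree shaped like the layout described in the issue.
with tempfile.TemporaryDirectory() as tmp:
    pkg = Path(tmp) / "demo_pkg"
    sub = pkg / "extensions" / "functional_lagrangian"
    sub.mkdir(parents=True)
    (pkg / "__init__.py").write_text("")
    (sub / "__init__.py").write_text("")
    (sub / "attacks.py").write_text("")
    found = dirs_missing_init(str(pkg))
print(found)
```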

Where to Find Model and Dataset Files

I'm currently using your repository and I'm looking for the model and dataset files that are used in the project. I've reviewed the repository's README and documentation, but I couldn't locate specific instructions on where to find these files.

Could you please provide guidance on where I can find the following:

  1. Model Files: I'm interested in the pre-trained model files that the project uses for inference. Are these available in the repository, or should I download them from a specific source?

  2. Dataset Files: Additionally, I'd like to access the dataset used for training or evaluation. Could you please specify where I can obtain these dataset files?

It would be immensely helpful if you could provide details or direct links to these files if they exist within the repository or if they are hosted externally.

Thank you for your assistance, and I'm looking forward to exploring and working with your project.

'ClosedJaxpr' object has no attribute 'invars' on newer versions of jax

I'm working on a machine with CUDA 12.1 installed, so I needed to install a newer version of JAX (at least 0.4.7, I believe), which led to the following error when calling functions such as jax_verify.interval_bound_propagation or jax_verify.backward_crown_bound_propagation:

Traceback (most recent call last):
  File "jax_verify/jax_verify/src/ibp.py", line 485, in interval_bound_propagation
    output_bound, _ = bound_propagation.bound_propagation(
  File "jax_verify/jax_verify/src/bound_propagation.py", line 229, in bound_propagation
    simplified_graph = synthetic_primitives.simplify_graph(
  File "jax_verify/jax_verify/src/synthetic_primitives.py", line 68, in simplify_graph
    return _simplify_graph(var_is_bound, graph, graph_simplifier)
  File "jax_verify/jax_verify/src/synthetic_primitives.py", line 88, in _simplify_graph
    return functools.reduce(
  File "jax_verify/jax_verify/src/synthetic_primitives.py", line 94, in _simplify_graph
    simplified_graph = graph_simplifier(graph, var_is_bound)
  File "jax_verify/jax_verify/src/synthetic_primitives.py", line 1061, in <lambda>
    activation_simplifier = lambda graph, _: activation_detector(graph)
  File "jax_verify/jax_verify/src/synthetic_primitives.py", line 1060, in <lambda>
    activation_detector = lambda graph: detect(activation_specs(), graph)
  File "jax_verify/jax_verify/src/synthetic_primitives.py", line 273, in detect
    eqn, eqn_idx = _next_equation(synthetic_primitive_specs, graph, eqn_idx)
  File "jax_verify/jax_verify/src/synthetic_primitives.py", line 299, in _next_equation
    ) = _matches(spec.graph, spec.capture_literals, graph, eqn_idx)
  File "jax_verify/jax_verify/src/synthetic_primitives.py", line 417, in _matches
    ) = _matches(subspec, capture_literals, subgraph, 0)
  File "jax_verify/jax_verify/src/synthetic_primitives.py", line 356, in _matches
    spec_invar_indices = {invar: j for j, invar in enumerate(spec.invars)}
AttributeError: 'ClosedJaxpr' object has no attribute 'invars'

After a little bit of digging, I found that in this version of JAX, invars and outvars are not directly available on a ClosedJaxpr; instead, they can be accessed through its jaxpr member.

I fixed this specific line with an if statement that uses spec.jaxpr.invars instead when type(spec) == ClosedJaxpr, but the same issue arose elsewhere in the code. So, as a quicker fix, I edited the ClosedJaxpr class in jax/_src/core.py, adding the following properties:

invars = property(lambda self: self.jaxpr.invars)
outvars = property(lambda self: self.jaxpr.outvars)

which seemed to fix the issue everywhere. Is there a better solution to this issue?
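A less invasive alternative to patching jax/_src/core.py (an edit that is lost whenever JAX is reinstalled) is to normalize the object inside jax_verify before reading invars or outvars. The sketch below assumes that every affected call site, such as the spec.invars access in _matches shown in the traceback, can be routed through a small helper; as_open_jaxpr is a hypothetical name, not part of jax_verify or JAX.

```python
import jax
import jax.numpy as jnp


def as_open_jaxpr(spec):
    """Return an object exposing invars/outvars for a Jaxpr or a ClosedJaxpr."""
    # A plain Jaxpr exposes invars/outvars directly; a ClosedJaxpr wraps one
    # in its .jaxpr attribute together with the captured constants.
    return spec if hasattr(spec, "invars") else spec.jaxpr


closed = jax.make_jaxpr(lambda x, y: jnp.tanh(x) + y)(1.0, 2.0)
inner = as_open_jaxpr(closed)

# Mirrors the failing line in _matches, now independent of the wrapper type.
spec_invar_indices = {invar: j for j, invar in enumerate(inner.invars)}
print(len(spec_invar_indices), len(inner.outvars))
```

Because the helper relies on duck typing rather than importing ClosedJaxpr, it does not depend on where a given JAX version exports that class.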

Unable to run the code on GPU

I am trying to run the code on GPU, but getting the following error:

  File "/home/hesam/anaconda3/lib/python3.10/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: INTERNAL: nvlink exited with non-zero error code 256, output: nvlink fatal : Input file '/tmp/tempfile-hesam-System-Product-Name-9fa0db7-13565-6002898000b0d.cubin' newer than toolkit (122 vs 120)
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/hesam/Desktop/UCSD/SOC/codes/jax_verify-master/examples/run_sdp_verify.py", line 244, in
    app.run(main)
  File "/home/hesam/anaconda3/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/hesam/anaconda3/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/hesam/Desktop/UCSD/SOC/codes/jax_verify-master/examples/run_sdp_verify.py", line 180, in main
    run_verification(PickleWriter())
  File "/home/hesam/Desktop/UCSD/SOC/codes/jax_verify-master/examples/run_sdp_verify.py", line 202, in run_verification
    verif_instance = get_verif_instance(
  File "/home/hesam/Desktop/UCSD/SOC/codes/jax_verify-master/examples/run_sdp_verify.py", line 112, in get_verif_instance
    bounds = boundprop_utils.boundprop(
  File "/home/hesam/Desktop/UCSD/SOC/codes/jax_verify-master/examples/jax_verify/extensions/sdp_verify/boundprop_utils.py", line 56, in boundprop
    return boundprop_method(params, x, epsilon, input_bounds,
  File "/home/hesam/Desktop/UCSD/SOC/codes/jax_verify-master/examples/jax_verify/extensions/sdp_verify/boundprop_utils.py", line 67, in _crown_ibp_boundprop
    jnp.maximum(x - epsilon, input_bounds[0]),
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: nvlink exited with non-zero error code 256, output: nvlink fatal : Input file '/tmp/tempfile-hesam-System-Product-Name-9fa0db7-13565-6002898000b0d.cubin' newer than toolkit (122 vs 120)

CUDA version: 12.0
CUDNN version: 8.9
