Thank you for the great job. I find it very helpful for me to solve the inverse function. However, I have a problem as follows:
torch_model = torch.nn.Sequential(
torch.nn.Linear(3, 1),
torch.nn.ReLU(),
)
def func1(y, A): # example function
return A - torch_model(y)
A = 10*torch.ones((2,1)).requires_grad_() # output shape of the torch_model_1 is (2,1)
y0 = torch.zeros((2,3)) # the torch_model accepts input shape of (2,3)
# finding a root
yroot = rootfinder(func1, y0, params=(A,))
The torch_model accepts input shape of (batch_size,3) and the outputs (batch_size, 1)
Given the outputs, I want to find the inputs values. So, I choose to let the param A to be the given outputs and try to get the yroot.
But the code show the errors:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[42], line 9
6 y0 = torch.zeros((2,3)) #
8 # finding a root
----> 9 yroot = rootfinder(func1, y0, params=(A,))
File [~/anaconda3/envs/dpc/lib/python3.8/site-packages/xitorch/optimize/rootfinder.py:93](https://file+.vscode-resource.vscode-cdn.net/media/lk/disk1/DL_lk/202205_RZ_GAN%E6%A8%A1%E5%9E%8B/model/EvoOpt/~/anaconda3/envs/dpc/lib/python3.8/site-packages/xitorch/optimize/rootfinder.py:93), in rootfinder(fcn, y0, params, bck_options, method, **fwd_options)
91 pfunc = get_pure_function(fcn)
92 fwd_options["method"] = _get_rootfinder_default_method(method)
---> 93 return _RootFinder.apply(pfunc, y0, pfunc, False, fwd_options, bck_options,
94 len(params), *params, *pfunc.objparams())
File [~/anaconda3/envs/dpc/lib/python3.8/site-packages/torch/autograd/function.py:506](https://file+.vscode-resource.vscode-cdn.net/media/lk/disk1/DL_lk/202205_RZ_GAN%E6%A8%A1%E5%9E%8B/model/EvoOpt/~/anaconda3/envs/dpc/lib/python3.8/site-packages/torch/autograd/function.py:506), in Function.apply(cls, *args, **kwargs)
503 if not torch._C._are_functorch_transforms_active():
504 # See NOTE: [functorch vjp and autograd interaction]
505 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 506 return super().apply(*args, **kwargs) # type: ignore[misc]
508 if cls.setup_context == _SingleLevelFunction.setup_context:
509 raise RuntimeError(
510 'In order to use an autograd.Function with functorch transforms '
511 '(vmap, grad, jvp, jacrev, ...), it must override the setup_context '
512 'staticmethod. For more details, please see '
513 'https://pytorch.org/docs/master/notes/extending.func.html')
File [~/anaconda3/envs/dpc/lib/python3.8/site-packages/xitorch/optimize/rootfinder.py:302](https://file+.vscode-resource.vscode-cdn.net/media/lk/disk1/DL_lk/202205_RZ_GAN%E6%A8%A1%E5%9E%8B/model/EvoOpt/~/anaconda3/envs/dpc/lib/python3.8/site-packages/xitorch/optimize/rootfinder.py:302), in _RootFinder.forward(ctx, fcn, y0, fwd_fcn, is_opt_method, options, bck_options, nparams, *allparams)
300 name = "rootfinder" if not is_opt_method else "minimizer"
301 method_fcn = get_method(name, methods, method)
--> 302 y = method_fcn(fwd_fcn, y0, params, **config)
304 ctx.fcn = fcn
305 ctx.is_opt_method = is_opt_method
File [~/anaconda3/envs/dpc/lib/python3.8/site-packages/xitorch/_impls/optimize/root/rootsolver.py:181](https://file+.vscode-resource.vscode-cdn.net/media/lk/disk1/DL_lk/202205_RZ_GAN%E6%A8%A1%E5%9E%8B/model/EvoOpt/~/anaconda3/envs/dpc/lib/python3.8/site-packages/xitorch/_impls/optimize/root/rootsolver.py:181), in broyden1(fcn, x0, params, **kwargs)
167 @functools.wraps(_nonlin_solver, assigned=('__annotations__',)) # takes only the signature
168 def broyden1(fcn, x0, params=(), **kwargs):
169 """
170 Solve the root finder or linear equation using the first Broyden method [1]_.
171 It can be used to solve minimization by finding the root of the
(...)
179 https://web.archive.org/web/20161022015821/http://www.math.leidenuniv.nl/scripties/Rotten.pdf
180 """
--> 181 return _nonlin_solver(fcn, x0, params, "broyden1", **kwargs)
File [~/anaconda3/envs/dpc/lib/python3.8/site-packages/xitorch/_impls/optimize/root/rootsolver.py:123](https://file+.vscode-resource.vscode-cdn.net/media/lk/disk1/DL_lk/202205_RZ_GAN%E6%A8%A1%E5%9E%8B/model/EvoOpt/~/anaconda3/envs/dpc/lib/python3.8/site-packages/xitorch/_impls/optimize/root/rootsolver.py:123), in _nonlin_solver(fcn, x0, params, method, alpha, uv0, max_rank, maxiter, f_tol, f_rtol, x_tol, x_rtol, line_search, verbose, **unused)
118 raise ValueError("Jacobian inversion yielded zero vector. "
119 "This indicates a bug in the Jacobian "
120 "approximation.")
122 if line_search:
--> 123 s, xnew, ynew, y_norm_new = _nonline_line_search(func, x, y, dx,
124 search_type=line_search)
125 else:
126 s = 1.0
File [~/anaconda3/envs/dpc/lib/python3.8/site-packages/xitorch/_impls/optimize/root/rootsolver.py:277](https://file+.vscode-resource.vscode-cdn.net/media/lk/disk1/DL_lk/202205_RZ_GAN%E6%A8%A1%E5%9E%8B/model/EvoOpt/~/anaconda3/envs/dpc/lib/python3.8/site-packages/xitorch/_impls/optimize/root/rootsolver.py:277), in _nonline_line_search(func, x, y, dx, search_type, rdiff, smin)
274 return (phi(s + ds, store=False) - phi(s)) [/](https://file+.vscode-resource.vscode-cdn.net/) ds
276 if search_type == 'armijo':
--> 277 s, phi1 = _scalar_search_armijo(phi, tmp_phi[0], -tmp_phi[0],
278 amin=smin)
280 if s is None:
281 # No suitable step length found. Take the full Newton step,
282 # and hope for the best.
283 s = 1.0
File [~/anaconda3/envs/dpc/lib/python3.8/site-packages/xitorch/_impls/optimize/root/rootsolver.py:295](https://file+.vscode-resource.vscode-cdn.net/media/lk/disk1/DL_lk/202205_RZ_GAN%E6%A8%A1%E5%9E%8B/model/EvoOpt/~/anaconda3/envs/dpc/lib/python3.8/site-packages/xitorch/_impls/optimize/root/rootsolver.py:295), in _scalar_search_armijo(phi, phi0, derphi0, c1, alpha0, amin, max_niter)
294 def _scalar_search_armijo(phi, phi0, derphi0, c1=1e-4, alpha0=1, amin=0, max_niter=20):
--> 295 phi_a0 = phi(alpha0)
296 if phi_a0 <= phi0 + c1 * alpha0 * derphi0:
297 return alpha0, phi_a0
File [~/anaconda3/envs/dpc/lib/python3.8/site-packages/xitorch/_impls/optimize/root/rootsolver.py:263](https://file+.vscode-resource.vscode-cdn.net/media/lk/disk1/DL_lk/202205_RZ_GAN%E6%A8%A1%E5%9E%8B/model/EvoOpt/~/anaconda3/envs/dpc/lib/python3.8/site-packages/xitorch/_impls/optimize/root/rootsolver.py:263), in _nonline_line_search..phi(s, store)
261 if s == tmp_s[0]:
262 return tmp_phi[0]
--> 263 xt = x + s * dx
264 v = func(xt)
265 p = _safe_norm(v)**2
RuntimeError: The size of tensor a (6) must match the size of tensor b (2) at non-singleton dimension 0