执行到损失函数所在的单元格 In[18] 时，抛出如下异常：
2022-08-01 21:06:21,049 ERROR worker.py:94 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::SPURuntime.run() (pid=16468, ip=172.16.4.140, repr=<secretflow.device.device.spu.SPURuntime object at 0x7f38e1bd0be0>)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/secretflow/device/device/spu.py", line 196, in run
cfn, output = jax.xla_computation(fn, return_shape=True)(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/api.py", line 807, in computation_maker
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1779, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1816, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/tmp/ipykernel_15144/2203275946.py", line 4, in fit
File "/tmp/ipykernel_15144/4019579140.py", line 5, in train_step
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/api.py", line 982, in value_and_grad_f
ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/api.py", line 2441, in _vjp
out_primal, out_vjp = ad.vjp(
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/ad.py", line 129, in vjp
out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/ad.py", line 116, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 606, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/tmp/ipykernel_15144/1730905303.py", line 17, in loss
File "/tmp/ipykernel_15144/1730905303.py", line 12, in predict
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/api.py", line 466, in cache_miss
out_flat = xla.xla_call(
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1771, in bind
return call_bind(self, fun, *args, **params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1787, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/ad.py", line 344, in process_call
result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1771, in bind
return call_bind(self, fun, *args, **params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1787, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 216, in process_call
out = primitive.bind(_update_annotation(f_, f.in_type, in_knowns),
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1771, in bind
return call_bind(self, fun, *args, **params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1787, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1534, in process_call
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1816, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2692, in dot
return lax.dot(a, b, precision=precision)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 656, in dot
raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Incompatible shapes for dot: got (113, 45) and (30,).
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
ray::SPURuntime.run() (pid=16468, ip=172.16.4.140, repr=<secretflow.device.device.spu.SPURuntime object at 0x7f38e1bd0be0>)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/secretflow/device/device/spu.py", line 196, in run
cfn, output = jax.xla_computation(fn, return_shape=True)(*args, **kwargs)
File "/tmp/ipykernel_15144/2203275946.py", line 4, in fit
File "/tmp/ipykernel_15144/4019579140.py", line 5, in train_step
File "/tmp/ipykernel_15144/1730905303.py", line 17, in loss
File "/tmp/ipykernel_15144/1730905303.py", line 12, in predict
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2692, in dot
return lax.dot(a, b, precision=precision)
TypeError: Incompatible shapes for dot: got (113, 45) and (30,).
2022-08-01 21:06:21,054 ERROR worker.py:94 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::SPURuntime.run() (pid=16468, ip=172.16.4.140, repr=<secretflow.device.device.spu.SPURuntime object at 0x7f38e1bd0be0>)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/secretflow/device/device/spu.py", line 196, in run
cfn, output = jax.xla_computation(fn, return_shape=True)(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/api.py", line 807, in computation_maker
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1779, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1816, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/tmp/ipykernel_15144/2203275946.py", line 4, in fit
File "/tmp/ipykernel_15144/4019579140.py", line 5, in train_step
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/api.py", line 982, in value_and_grad_f
ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/api.py", line 2441, in _vjp
out_primal, out_vjp = ad.vjp(
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/ad.py", line 129, in vjp
out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/ad.py", line 116, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 606, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/tmp/ipykernel_15144/1730905303.py", line 17, in loss
File "/tmp/ipykernel_15144/1730905303.py", line 12, in predict
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/api.py", line 466, in cache_miss
out_flat = xla.xla_call(
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1771, in bind
return call_bind(self, fun, *args, **params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1787, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/ad.py", line 344, in process_call
result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1771, in bind
return call_bind(self, fun, *args, **params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1787, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 216, in process_call
out = primitive.bind(_update_annotation(f_, f.in_type, in_knowns),
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1771, in bind
return call_bind(self, fun, *args, **params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1787, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1534, in process_call
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1816, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2692, in dot
return lax.dot(a, b, precision=precision)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 656, in dot
raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Incompatible shapes for dot: got (113, 45) and (30,).
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
ray::SPURuntime.run() (pid=16468, ip=172.16.4.140, repr=<secretflow.device.device.spu.SPURuntime object at 0x7f38e1bd0be0>)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/secretflow/device/device/spu.py", line 196, in run
cfn, output = jax.xla_computation(fn, return_shape=True)(*args, **kwargs)
File "/tmp/ipykernel_15144/2203275946.py", line 4, in fit
File "/tmp/ipykernel_15144/4019579140.py", line 5, in train_step
File "/tmp/ipykernel_15144/1730905303.py", line 17, in loss
File "/tmp/ipykernel_15144/1730905303.py", line 12, in predict
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2692, in dot
return lax.dot(a, b, precision=precision)
TypeError: Incompatible shapes for dot: got (113, 45) and (30,).
2022-08-01 21:06:21,056 ERROR worker.py:94 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::SPURuntime.run() (pid=16468, ip=172.16.4.140, repr=<secretflow.device.device.spu.SPURuntime object at 0x7f38e1bd0be0>)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/secretflow/device/device/spu.py", line 196, in run
cfn, output = jax.xla_computation(fn, return_shape=True)(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/api.py", line 807, in computation_maker
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1779, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1816, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/tmp/ipykernel_15144/2203275946.py", line 4, in fit
File "/tmp/ipykernel_15144/4019579140.py", line 5, in train_step
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/api.py", line 982, in value_and_grad_f
ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/api.py", line 2441, in _vjp
out_primal, out_vjp = ad.vjp(
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/ad.py", line 129, in vjp
out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/ad.py", line 116, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 606, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/tmp/ipykernel_15144/1730905303.py", line 17, in loss
File "/tmp/ipykernel_15144/1730905303.py", line 12, in predict
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/api.py", line 466, in cache_miss
out_flat = xla.xla_call(
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1771, in bind
return call_bind(self, fun, *args, **params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1787, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/ad.py", line 344, in process_call
result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1771, in bind
return call_bind(self, fun, *args, **params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1787, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 216, in process_call
out = primitive.bind(_update_annotation(f_, f.in_type, in_knowns),
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1771, in bind
return call_bind(self, fun, *args, **params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1787, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1534, in process_call
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1816, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2692, in dot
return lax.dot(a, b, precision=precision)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 656, in dot
raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Incompatible shapes for dot: got (113, 45) and (30,).
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
ray::SPURuntime.run() (pid=16468, ip=172.16.4.140, repr=<secretflow.device.device.spu.SPURuntime object at 0x7f38e1bd0be0>)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/secretflow/device/device/spu.py", line 196, in run
cfn, output = jax.xla_computation(fn, return_shape=True)(*args, **kwargs)
File "/tmp/ipykernel_15144/2203275946.py", line 4, in fit
File "/tmp/ipykernel_15144/4019579140.py", line 5, in train_step
File "/tmp/ipykernel_15144/1730905303.py", line 17, in loss
File "/tmp/ipykernel_15144/1730905303.py", line 12, in predict
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2692, in dot
return lax.dot(a, b, precision=precision)
TypeError: Incompatible shapes for dot: got (113, 45) and (30,).
2022-08-01 21:06:21,305 ERROR worker.py:94 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::SPURuntime.run() (pid=32105, ip=172.16.4.141, repr=<secretflow.device.device.spu.SPURuntime object at 0x7fdfbc16cac0>)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/secretflow/device/device/spu.py", line 196, in run
cfn, output = jax.xla_computation(fn, return_shape=True)(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/api.py", line 807, in computation_maker
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1779, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1816, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/tmp/ipykernel_15144/2203275946.py", line 4, in fit
File "/tmp/ipykernel_15144/4019579140.py", line 5, in train_step
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/api.py", line 982, in value_and_grad_f
ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/api.py", line 2441, in _vjp
out_primal, out_vjp = ad.vjp(
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/ad.py", line 129, in vjp
out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/ad.py", line 116, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 606, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/tmp/ipykernel_15144/1730905303.py", line 17, in loss
File "/tmp/ipykernel_15144/1730905303.py", line 12, in predict
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/api.py", line 466, in cache_miss
out_flat = xla.xla_call(
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1771, in bind
return call_bind(self, fun, *args, **params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1787, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/ad.py", line 344, in process_call
result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1771, in bind
return call_bind(self, fun, *args, **params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1787, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 216, in process_call
out = primitive.bind(_update_annotation(f_, f.in_type, in_knowns),
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1771, in bind
return call_bind(self, fun, *args, **params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1787, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1534, in process_call
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1816, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2692, in dot
return lax.dot(a, b, precision=precision)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 656, in dot
raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Incompatible shapes for dot: got (113, 45) and (30,).
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
ray::SPURuntime.run() (pid=32105, ip=172.16.4.141, repr=<secretflow.device.device.spu.SPURuntime object at 0x7fdfbc16cac0>)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/secretflow/device/device/spu.py", line 196, in run
cfn, output = jax.xla_computation(fn, return_shape=True)(*args, **kwargs)
File "/tmp/ipykernel_15144/2203275946.py", line 4, in fit
File "/tmp/ipykernel_15144/4019579140.py", line 5, in train_step
File "/tmp/ipykernel_15144/1730905303.py", line 17, in loss
File "/tmp/ipykernel_15144/1730905303.py", line 12, in predict
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2692, in dot
return lax.dot(a, b, precision=precision)
TypeError: Incompatible shapes for dot: got (113, 45) and (30,).
2022-08-01 21:06:21,311 ERROR worker.py:94 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::SPURuntime.run() (pid=32105, ip=172.16.4.141, repr=<secretflow.device.device.spu.SPURuntime object at 0x7fdfbc16cac0>)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/secretflow/device/device/spu.py", line 196, in run
cfn, output = jax.xla_computation(fn, return_shape=True)(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/api.py", line 807, in computation_maker
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1779, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1816, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/tmp/ipykernel_15144/2203275946.py", line 4, in fit
File "/tmp/ipykernel_15144/4019579140.py", line 5, in train_step
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/api.py", line 982, in value_and_grad_f
ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/api.py", line 2441, in _vjp
out_primal, out_vjp = ad.vjp(
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/ad.py", line 129, in vjp
out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/ad.py", line 116, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 606, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/tmp/ipykernel_15144/1730905303.py", line 17, in loss
File "/tmp/ipykernel_15144/1730905303.py", line 12, in predict
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/api.py", line 466, in cache_miss
out_flat = xla.xla_call(
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1771, in bind
return call_bind(self, fun, *args, **params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1787, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/ad.py", line 344, in process_call
result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1771, in bind
return call_bind(self, fun, *args, **params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1787, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 216, in process_call
out = primitive.bind(_update_annotation(f_, f.in_type, in_knowns),
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1771, in bind
return call_bind(self, fun, *args, **params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1787, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1534, in process_call
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1816, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2692, in dot
return lax.dot(a, b, precision=precision)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 656, in dot
raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Incompatible shapes for dot: got (113, 45) and (30,).
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
ray::SPURuntime.run() (pid=32105, ip=172.16.4.141, repr=<secretflow.device.device.spu.SPURuntime object at 0x7fdfbc16cac0>)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/secretflow/device/device/spu.py", line 196, in run
cfn, output = jax.xla_computation(fn, return_shape=True)(*args, **kwargs)
File "/tmp/ipykernel_15144/2203275946.py", line 4, in fit
File "/tmp/ipykernel_15144/4019579140.py", line 5, in train_step
File "/tmp/ipykernel_15144/1730905303.py", line 17, in loss
File "/tmp/ipykernel_15144/1730905303.py", line 12, in predict
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2692, in dot
return lax.dot(a, b, precision=precision)
TypeError: Incompatible shapes for dot: got (113, 45) and (30,).
2022-08-01 21:06:21,314 ERROR worker.py:94 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::SPURuntime.run() (pid=32105, ip=172.16.4.141, repr=<secretflow.device.device.spu.SPURuntime object at 0x7fdfbc16cac0>)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/secretflow/device/device/spu.py", line 196, in run
cfn, output = jax.xla_computation(fn, return_shape=True)(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/api.py", line 807, in computation_maker
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1779, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1816, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/tmp/ipykernel_15144/2203275946.py", line 4, in fit
File "/tmp/ipykernel_15144/4019579140.py", line 5, in train_step
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/api.py", line 982, in value_and_grad_f
ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/api.py", line 2441, in _vjp
out_primal, out_vjp = ad.vjp(
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/ad.py", line 129, in vjp
out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/ad.py", line 116, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 606, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/tmp/ipykernel_15144/1730905303.py", line 17, in loss
File "/tmp/ipykernel_15144/1730905303.py", line 12, in predict
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/api.py", line 466, in cache_miss
out_flat = xla.xla_call(
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1771, in bind
return call_bind(self, fun, *args, **params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1787, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/ad.py", line 344, in process_call
result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1771, in bind
return call_bind(self, fun, *args, **params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1787, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 216, in process_call
out = primitive.bind(_update_annotation(f_, f.in_type, in_knowns),
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1771, in bind
return call_bind(self, fun, *args, **params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/core.py", line 1787, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1534, in process_call
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1816, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2692, in dot
return lax.dot(a, b, precision=precision)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 656, in dot
raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Incompatible shapes for dot: got (113, 45) and (30,).
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
ray::SPURuntime.run() (pid=32105, ip=172.16.4.141, repr=<secretflow.device.device.spu.SPURuntime object at 0x7fdfbc16cac0>)
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/secretflow/device/device/spu.py", line 196, in run
cfn, output = jax.xla_computation(fn, return_shape=True)(*args, **kwargs)
File "/tmp/ipykernel_15144/2203275946.py", line 4, in fit
File "/tmp/ipykernel_15144/4019579140.py", line 5, in train_step
File "/tmp/ipykernel_15144/1730905303.py", line 17, in loss
File "/tmp/ipykernel_15144/1730905303.py", line 12, in predict
File "/opt/software/anaconda3/envs/secretflow/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2692, in dot
return lax.dot(a, b, precision=precision)
TypeError: Incompatible shapes for dot: got (113, 45) and (30,).