I am using a dataset that sometimes contains tiles that don't overlap. I've added an image of the coarse offsets. We can see that one offset (the offset between tiles (0, 0) and (0, 1) is not computed).
I use the workflow from the em_stitching Colab notebook. I compute the coarse offsets and mesh with this code:
from sofima import stitch_rigid
cx, cy = stitch_rigid.compute_coarse_offsets(grid_size, tile_map)
f, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].quiver((0, 1, 2), (0, 1, 2), cx[0, 0, ...], cx[1, 0, ...])
ax[0].set_ylim(-0.5, 2.5)
ax[0].set_xlim(-0.5, 1.5)
ax[0].set_title('horizontal NNs')
ax[1].quiver((0, 1, 2), (0, 1, 2), cy[0, 0, ...], cy[1, 0, ...])
ax[1].set_ylim(-0.5, 1.5)
ax[1].set_xlim(-0.5, 2.5)
ax[1].set_title('vertical NNs')
coarse_mesh = stitch_rigid.optimize_coarse_mesh(cx, cy)
from sofima import stitch_elastic, flow_utils, mesh
stride = 20
cx = np.squeeze(cx)
cy = np.squeeze(cy)
fine_x, offsets_x = stitch_elastic.compute_flow_map(tile_map, cx, 0, stride=(stride, stride), batch_size=4) # (x,y) -> (x+1,y)
fine_y, offsets_y = stitch_elastic.compute_flow_map(tile_map, cy, 1, stride=(stride, stride), batch_size=4) # (x,y) -> (x,y+1)
# "min_peak_ratio": 1.4, "min_peak_sharpness": 1.4, "max_deviation": 5, "max_magnitude": 0}
kwargs = {"min_peak_ratio": 1.4, "min_peak_sharpness": 1.4, "max_deviation": 5, "max_magnitude": 0}
fine_x = {k: flow_utils.clean_flow(v[:, np.newaxis, ...], **kwargs)[:, 0, :, :] for k, v in fine_x.items()}
fine_y = {k: flow_utils.clean_flow(v[:, np.newaxis, ...], **kwargs)[:, 0, :, :] for k, v in fine_y.items()}
kwargs = {"min_patch_size": 10, "max_gradient": -1, "max_deviation": -1}
fine_x = {k: flow_utils.reconcile_flows([v[:, np.newaxis, ...]], **kwargs)[:, 0, :, :] for k, v in fine_x.items()}
fine_y = {k: flow_utils.reconcile_flows([v[:, np.newaxis, ...]], **kwargs)[:, 0, :, :] for k, v in fine_y.items()}
data_x = (cx, fine_x, offsets_x)
data_y = (cy, fine_y, offsets_y)
fx, fy, x, nbors, key_to_idx = stitch_elastic.aggregate_arrays(
data_x, data_y, list(tile_map.keys()),
coarse_mesh[:, 0, ...], stride=(stride, stride),
tile_shape=next(iter(tile_map.values())).shape)
@jax.jit
def prev_fn(x):
target_fn = ft.partial(stitch_elastic.compute_target_mesh, x=x, fx=fx,
fy=fy, stride=(stride, stride))
x = jax.vmap(target_fn)(nbors)
return jnp.transpose(x, [1, 0, 2, 3])
config = mesh.IntegrationConfig(dt=0.001, gamma=0., k0=0.01, k=0.1, stride=stride,
num_iters=1000, max_iters=20000, stop_v_max=0.001,
dt_max=100, prefer_orig_order=True,
start_cap=0.1, final_cap=10., remove_drift=True)
x, ekin, t = mesh.relax_mesh(x, None, config, prev_fn=prev_fn)
OverflowError Traceback (most recent call last)
in <cell line: 16>()
14
15 # Compute flow maps for horizontal and vertical directions
---> 16 fine_x, offsets_x = stitch_elastic.compute_flow_map(tile_map, cx, 0, stride=(stride, stride), batch_size=4) # (x,y) -> (x+1,y)
17 fine_y, offsets_y = stitch_elastic.compute_flow_map(tile_map, cy, 1, stride=(stride, stride), batch_size=4) # (x,y) -> (x,y+1)
18 /usr/local/lib/python3.10/dist-packages/sofima/stitch_elastic.py in compute_flow_map(tile_map, offset_map, axis, patch_size, stride, batch_size)
241 rounded_offset = stride[::-1] * np.round(offset / stride[::-1])
242
--> 243 overlap = -int(offset[axis])
244 overlap = pre.shape[1 - axis] - (
245 (pre.shape[1 - axis] - overlap) // stride[1 - axis] * stride[1 - axis]
OverflowError: cannot convert float infinity to integer