# perform binary search
eps_gx_UB = 1000000.0
eps_gx_LB = 0.0
is_pos = True
is_neg = True
# eps = eps_gx_LB*2
# eps = args.eps
while eps_gx_UB - eps_gx_LB > 0.00001:
ptb = PerturbationLpNorm(norm=2, eps=eps)
image = BoundedTensor(input, ptb)
pred = model(image)
label = torch.argmax(pred, dim=1).cpu().numpy()
# for method in ['IBP', 'IBP+backward (CROWN-IBP)', 'backward (CROWN)']:
lb, ub = model.compute_bounds(x=(image,), method='IBP+backward')
gap_gx = torch.min(lb)
lb = lb.detach().cpu().numpy()
ub = ub.detach().cpu().numpy()
print("Bounding method:", method)
for i in range(N):
print("Image {} top-1 prediction {} ground-truth {}".format(i, label[i], true_label[i]))
for j in range(n_classes):
indicator = '(ground-truth)' if j == true_label[i] else ''
print("f_{j}(x_0): {l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f} {ind}".format(
j=j, l=lb[i][j], u=ub[i][j], ind=indicator))
print()
if gap_gx > 0:
if gap_gx < 0.01:
eps_gx_LB = eps
return eps
break
if is_pos: # so far always > 0, haven't found eps_UB
eps_gx_LB = eps
eps *= 10
else:
eps_gx_LB = eps
eps = (eps_gx_LB + eps_gx_UB) / 2
is_neg = False
else:
if is_neg: # so far always < 0, haven't found eps_LB
eps_gx_UB = eps
eps /= 10
else:
eps_gx_UB = eps
eps = (eps_gx_LB + eps_gx_UB) / 2
is_pos = False
counter += 1
if counter >= 500:
return eps
break
print("[L2][binary search] step = {}, eps = {:.5f}, gap_gx = {:.2f}".format(counter, eps, gap_gx))