Giter Site home page Giter Site logo

visual-attention-pytorch's People

Contributors

rohithreddy024 avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar

visual-attention-pytorch's Issues

Fixed: "TypeError: type Tensor doesn't define __round__ method"

def get_image(opt, info, istrain):
    """Load the context image for one datapoint and crop it to a square.

    ``info`` is moved to the host and converted to NumPy once. A Tensor
    cannot be handed to ``PIL.Image.crop``, which calls ``round`` on each
    box coordinate ("type Tensor doesn't define __round__ method").

    Args:
        opt: Options object; ``opt.train_files`` / ``opt.test_files`` are
            indexed as ``files[ind][0][0]`` to obtain the image filename.
        info: Per-datapoint tensor whose first three entries are the crop
            origin ``(from_w, from_h)`` and the file index ``ind``.
        istrain: Use ``opt.train_files`` when true, else ``opt.test_files``.

    Returns:
        ``(img, size)``: the square RGB crop starting at ``(from_w, from_h)``
        and its side length, which is the shorter side of the source image.
    """
    # One device->host transfer instead of three separate .cpu().numpy() calls.
    values = info.cpu().numpy()
    from_w, from_h = values[0], values[1]
    ind = int(values[2])  # plain Python int for list indexing
    files = opt.train_files if istrain else opt.test_files
    filename = files[ind][0][0]

    # NOTE(review): ``images_folder`` is a module-level name defined outside this snippet.
    # ``convert`` loads the pixel data, so the source file can be closed right away
    # instead of leaking one open handle per call.
    with Image.open(os.path.join(images_folder, filename)) as src:
        img = src.convert('RGB')
    width, height = img.size
    size = min(width, height)
    img = img.crop((from_w, from_h, from_w + size, from_h + size))
    return img, size

def get_patch(i, l, size, info, opt, istrain):
    """Crop and transform the glimpse patch for datapoint ``i`` of a batch.

    The patch is centred on ``l[i]`` (a location in [-1, 1], mapped onto the
    context image). Its side is ``(context_size // 4) * size`` pixels before
    ``opt.my_transform`` compresses it. If the patch would extend past the
    image border, the image is first padded with black on every side.

    Returns:
        The transformed patch with a leading batch dimension of 1.
    """
    context, context_size = get_image(opt, info[i], istrain)  # context image
    # Original side length of the patch before it is compressed to 96x96.
    side = (context_size // 4) * size
    half = side // 2

    # Map the location from [-1, 1] onto [0, context_size].
    centre = (0.5 * context_size * (1 + l[i])).astype(int)
    left = centre[0] - half
    top = centre[1] - half
    right = left + side
    bottom = top + side

    exceeds_border = (
        left < 0
        or top < 0
        or right > context_size
        or bottom > context_size
    )
    if exceeds_border:
        border = half + 1
        context = ImageOps.expand(context, border=border, fill='black')
        # Shift the box into the padded image's coordinates.
        left, top, right, bottom = (c + border for c in (left, top, right, bottom))

    patch = context.crop((left, top, right, bottom))
    return opt.my_transform(patch).unsqueeze(0)

def extract_patches_batch(l, size, info, opt, istrain):
    """Extract one square patch per datapoint in the batch.

    Each patch is produced by ``get_patch`` using that datapoint's location
    from ``l`` and the given ``size``.

    Args:
        l: Per-datapoint locations (``len(l)`` is the batch size).
        size: Patch scale factor forwarded to ``get_patch``.
        info: Per-datapoint image info forwarded to ``get_patch``.
        opt: Options object; ``opt.n_jobs`` sets the thread count.
        istrain: Selects training vs. test files in ``get_patch``.

    Returns:
        All patches concatenated along dim 0 and passed through ``get_cuda``.
    """
    batch_size = len(l)
    # get_patch is independent per datapoint, so run it on a thread pool.
    # The opening parenthesis of the call MUST stay on the same line as
    # ``Parallel(...)``. Otherwise Python parses two statements: ``patches``
    # is bound to the Parallel object and the generator is never executed.
    patches = Parallel(n_jobs=opt.n_jobs, backend="threading")(
        delayed(get_patch)(i, l, size, info, opt, istrain) for i in range(batch_size)
    )
    return get_cuda(T.cat(patches, dim=0))

def retina(l, info, opt, istrain):
    """Build the multi-scale glimpse for a batch of locations.

    Returns a list of ``opt.k`` patch batches (one per scale). Batch ``j``
    is extracted with size ``opt.start_size * 2**j``, so each scale doubles
    the previous one.
    """
    locations = l.cpu().numpy()  # helpers index/scale the locations as NumPy
    return [
        extract_patches_batch(locations, opt.start_size * 2 ** level, info, opt, istrain)
        for level in range(opt.k)
    ]

Fixed: CERTIFICATE_VERIFY_FAILED error

#crash line
resnet = torchmodels.resnet50(pretrained=True)

#error message
urllib.error.URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1091)>

# Workaround that makes the pretrained-weight download succeed: disable HTTPS
# certificate verification.
# WARNING: this replaces the stdlib's default HTTPS context factory for the
# whole process, so stdlib clients relying on it (e.g. urllib) stop verifying
# certificates. Prefer fixing the local CA certificate installation instead.
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

I got this error when running main.py.

Hi, @rohithreddy024

I got this error when running main.py.

tezro@tezro:~/Deep07/Visual-Attention-Pytorch$ python3 main.py
Namespace(batch_size=32, epochs=101, k=2, lr=0.0001, n_c=120, n_glimpses=1, n_jobs=4, n_samples=20, num_workers=4, resume_training=False, rnn_hidden=2048, start_size=2, std_dev=0.2, task='train', valid_size=0.3)
/home/tezro/.local/lib/python3.8/site-packages/torch/nn/functional.py:1628: UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.
warnings.warn("nn.functional.tanh is deprecated. Use torch.tanh instead.")
Traceback (most recent call last):
File "main.py", line 278, in
training()
File "main.py", line 187, in training
la, lb, lr = train_model(x_batch, labels, info)
File "main.py", line 121, in train_model
_, _, _, _, output = my_model(l, hc1, cv, info, last = True)
File "/home/tezro/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/tezro/.local/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 161, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/home/tezro/.local/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 171, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/home/tezro/.local/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
output.reraise()
File "/home/tezro/.local/lib/python3.8/site-packages/torch/_utils.py", line 428, in reraise
raise self.exc_type(msg)
TypeError: Caught TypeError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/home/tezro/.local/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
output = module(*input, **kwargs)
File "/home/tezro/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/tezro/Deep07/Visual-Attention-Pytorch/models.py", line 166, in forward
g = self.glimpse(l_prev, info) #Extract glimpse based on tuple
File "/home/tezro/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/tezro/Deep07/Visual-Attention-Pytorch/models.py", line 68, in forward
phi = retina(l_prev, info, self.opt, self.training) #Extracts high, medium and low resolution patches corresponding to a location centre; Also compressed to 96x96 size
File "/home/tezro/Deep07/Visual-Attention-Pytorch/helper_functions.py", line 64, in retina
phi.append(extract_patches_batch(l, size))
File "/home/tezro/Deep07/Visual-Attention-Pytorch/helper_functions.py", line 54, in extract_patches_batch
patches = Parallel(n_jobs=opt.n_jobs, backend="threading")(
File "/home/tezro/.local/lib/python3.8/site-packages/joblib/parallel.py", line 1061, in __call__
self.retrieve()
File "/home/tezro/.local/lib/python3.8/site-packages/joblib/parallel.py", line 940, in retrieve
self._output.extend(job.get(timeout=self.timeout))
File "/usr/lib/python3.8/multiprocessing/pool.py", line 771, in get
raise self._value
File "/usr/lib/python3.8/multiprocessing/pool.py", line 125, in worker
result = (True, func(*args, **kwds))
File "/home/tezro/.local/lib/python3.8/site-packages/joblib/_parallel_backends.py", line 595, in __call__
return self.func(*args, **kwargs)
File "/home/tezro/.local/lib/python3.8/site-packages/joblib/parallel.py", line 262, in __call__
return [func(*args, **kwargs)
File "/home/tezro/.local/lib/python3.8/site-packages/joblib/parallel.py", line 262, in
return [func(*args, **kwargs)
File "/home/tezro/Deep07/Visual-Attention-Pytorch/helper_functions.py", line 34, in get_patch
img, imgsize = get_image(opt, info[i], istrain) #Get context image
File "/home/tezro/Deep07/Visual-Attention-Pytorch/helper_functions.py", line 23, in get_image
img = img.crop((from_w, from_h, from_w + size, from_h + size))
File "/usr/local/lib/python3.8/dist-packages/PIL/Image.py", line 1128, in crop
return self._new(self._crop(self.im, box))
File "/usr/local/lib/python3.8/dist-packages/PIL/Image.py", line 1142, in _crop
x0, y0, x1, y1 = map(int, map(round, box))
TypeError: type Tensor doesn't define __round__ method

What am I doing wrong?

Thanks.

@bemoregt.

Fixed: loading the checkpoint on a CPU-only machine

# Use the first CUDA device when one is available, otherwise fall back to CPU.
# load_resnet passes this as map_location, so the checkpoint also loads on
# CPU-only machines.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def load_resnet():
    """Build the frozen ResNet-50 backbone used for glimpse features.

    The stem convolution is replaced by a 5x5, stride-1, padding-2 conv with
    64 output channels and no bias. The last two children (avgpool and fc in
    torchvision's ResNet) are dropped. All weights are then loaded from the
    ``pretrained_glimpsemodel`` checkpoint and frozen.

    Returns:
        The frozen ``nn.Sequential`` backbone, passed through ``get_cuda``.
    """
    # No ImageNet download: load_state_dict (strict by default) below overwrites
    # every parameter and buffer from the checkpoint, so those weights were never
    # used. Skipping the download also avoids network/SSL failures such as
    # CERTIFICATE_VERIFY_FAILED.
    resnet = torchmodels.resnet50(pretrained=False)

    resnet.conv1 = nn.Conv2d(3, 64, 5, 1, 2, bias=False)
    resnet = nn.Sequential(*list(resnet.children())[:-2])

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = T.load(pretrained_glimpsemodel, map_location=device)
    resnet.load_state_dict(checkpoint["model_dict"])
    # We fix the parameters of resnet and do not train it.
    for param in resnet.parameters():
        param.requires_grad = False

    return get_cuda(resnet)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Something interesting about the web. A new door to the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Something interesting about games, to make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.