rohithreddy024 / visual-attention-pytorch Goto Github PK
View Code? Open in Web Editor NEW Implementation of Attention for Fine-Grained Categorization in Pytorch
Implementation of Attention for Fine-Grained Categorization in Pytorch
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def load_resnet():
    """Build the frozen ResNet-50 feature extractor from the saved checkpoint.

    The network is a torchvision ResNet-50 whose first convolution is replaced
    by a 3->64 channel, 5x5, stride-1, padding-2 convolution and whose last two
    child modules are dropped. Its weights are read from the ``"model_dict"``
    entry of the checkpoint stored at ``pretrained_glimpsemodel``.

    Returns:
        The network, with every parameter's ``requires_grad`` set to False,
        passed through ``get_cuda``.
    """
    # Do not request the ImageNet weights: load_state_dict() below is strict,
    # so every parameter and buffer is replaced by the checkpoint anyway.
    # Skipping the download also avoids a network round-trip (and the SSL
    # certificate failures it can trigger) on every start-up.
    resnet = torchmodels.resnet50()
    resnet.conv1 = nn.Conv2d(3, 64, 5, 1, 2, bias=False)
    resnet = nn.Sequential(*list(resnet.children())[:-2])
    checkpoint = T.load(pretrained_glimpsemodel, map_location=device)
    resnet.load_state_dict(checkpoint["model_dict"])
    # We fix the parameters of resnet and do not train it
    for param in resnet.parameters():
        param.requires_grad = False
    return get_cuda(resnet)
Hi, @rohithreddy024
I get no result image when I run "python main.py --task=view_glimpses",
but there is no error, and the "imgs" folder is also created.
What am I doing wrong?
Thanks in advance.
Best,
@bemoregt.
Hi, @rohithreddy024
I get this error when I run main.py.
tezro@tezro:~/Deep07/Visual-Attention-Pytorch$ python3 main.py
Namespace(batch_size=32, epochs=101, k=2, lr=0.0001, n_c=120, n_glimpses=1, n_jobs=4, n_samples=20, num_workers=4, resume_training=False, rnn_hidden=2048, start_size=2, std_dev=0.2, task='train', valid_size=0.3)
/home/tezro/.local/lib/python3.8/site-packages/torch/nn/functional.py:1628: UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.
warnings.warn("nn.functional.tanh is deprecated. Use torch.tanh instead.")
Traceback (most recent call last):
File "main.py", line 278, in
training()
File "main.py", line 187, in training
la, lb, lr = train_model(x_batch, labels, info)
File "main.py", line 121, in train_model
_, _, _, _, output = my_model(l, hc1, cv, info, last = True)
File "/home/tezro/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/tezro/.local/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 161, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/home/tezro/.local/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 171, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/home/tezro/.local/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
output.reraise()
File "/home/tezro/.local/lib/python3.8/site-packages/torch/_utils.py", line 428, in reraise
raise self.exc_type(msg)
TypeError: Caught TypeError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/home/tezro/.local/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
output = module(*input, **kwargs)
File "/home/tezro/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/tezro/Deep07/Visual-Attention-Pytorch/models.py", line 166, in forward
g = self.glimpse(l_prev, info) #Extract glimpse based on tuple
File "/home/tezro/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/tezro/Deep07/Visual-Attention-Pytorch/models.py", line 68, in forward
phi = retina(l_prev, info, self.opt, self.training) #Extracts high, medium and low resolution patches corresponding to a location centre; Also compressed to 96x96 size
File "/home/tezro/Deep07/Visual-Attention-Pytorch/helper_functions.py", line 64, in retina
phi.append(extract_patches_batch(l, size))
File "/home/tezro/Deep07/Visual-Attention-Pytorch/helper_functions.py", line 54, in extract_patches_batch
patches = Parallel(n_jobs=opt.n_jobs, backend="threading")(
File "/home/tezro/.local/lib/python3.8/site-packages/joblib/parallel.py", line 1061, in call
self.retrieve()
File "/home/tezro/.local/lib/python3.8/site-packages/joblib/parallel.py", line 940, in retrieve
self._output.extend(job.get(timeout=self.timeout))
File "/usr/lib/python3.8/multiprocessing/pool.py", line 771, in get
raise self._value
File "/usr/lib/python3.8/multiprocessing/pool.py", line 125, in worker
result = (True, func(*args, **kwds))
File "/home/tezro/.local/lib/python3.8/site-packages/joblib/_parallel_backends.py", line 595, in call
return self.func(*args, **kwargs)
File "/home/tezro/.local/lib/python3.8/site-packages/joblib/parallel.py", line 262, in call
return [func(*args, **kwargs)
File "/home/tezro/.local/lib/python3.8/site-packages/joblib/parallel.py", line 262, in
return [func(*args, **kwargs)
File "/home/tezro/Deep07/Visual-Attention-Pytorch/helper_functions.py", line 34, in get_patch
img, imgsize = get_image(opt, info[i], istrain) #Get context image
File "/home/tezro/Deep07/Visual-Attention-Pytorch/helper_functions.py", line 23, in get_image
img = img.crop((from_w, from_h, from_w + size, from_h + size))
File "/usr/local/lib/python3.8/dist-packages/PIL/Image.py", line 1128, in crop
return self._new(self._crop(self.im, box))
File "/usr/local/lib/python3.8/dist-packages/PIL/Image.py", line 1142, in _crop
x0, y0, x1, y1 = map(int, map(round, box))
TypeError: type Tensor doesn't define round method
What am I doing wrong?
Thanks.
#crash line
resnet = torchmodels.resnet50(pretrained=True)
#error message
urllib.error.URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1091)>
# adding this fixes it
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
def get_image(opt, info, istrain):
    """Load the context image of one datapoint and crop it to a square.

    Args:
        opt: Options object. ``opt.train_files`` / ``opt.test_files`` are
            indexed as ``files[ind][0][0]`` to obtain the image file name.
        info: Tensor whose first three entries are ``(from_w, from_h, ind)``:
            the left and top offset of the crop and the index into the file
            list.
        istrain: When truthy the name is taken from ``opt.train_files``,
            otherwise from ``opt.test_files``.

    Returns:
        Tuple ``(img, size)``: the RGB crop and its side length, which is
        ``min(width, height)`` of the image on disk.
    """
    # One device-to-host copy instead of three. tolist() yields native Python
    # numbers, which PIL's crop() (it calls round() on each coordinate) and
    # sequence indexing both accept.
    from_w, from_h, ind = info.cpu().numpy()[:3].tolist()
    ind = int(ind)  # info may be a float tensor; an index must be an integer
    files = opt.train_files if istrain else opt.test_files
    filename = files[ind][0][0]
    # The context manager releases the file handle as soon as the pixels have
    # been decoded by convert().
    with Image.open(os.path.join(images_folder, filename)) as raw:
        img = raw.convert('RGB')
    width, height = img.size
    size = min(width, height)
    img = img.crop((from_w, from_h, from_w + size, from_h + size))
    return img, size
def get_patch(i, l, size, info, opt, istrain):
    """Cut the patch for datapoint ``i`` of a batch out of its context image.

    Args:
        i: Position of the datapoint inside the batch.
        l: Per-datapoint locations; ``l[i]`` must support NumPy arithmetic and
            ``.astype`` and is treated as an (x, y) centre in [-1, 1].
        size: Multiplier applied to a quarter of the context image side to
            obtain the patch side.
        info: Per-datapoint info; ``info[i]`` is forwarded to ``get_image``.
        opt: Options object providing ``my_transform``.
        istrain: Forwarded to ``get_image``.

    Returns:
        ``opt.my_transform(patch).unsqueeze(0)`` for the cropped patch.
    """
    context, side = get_image(opt, info[i], istrain)  # context image and its side

    # Side of the patch in context-image pixels, before opt.my_transform.
    patch_len = (side // 4) * size
    half = patch_len // 2

    # Rescale the location from [-1, 1] to pixel coordinates in [0, side].
    centre = (0.5 * side * (1 + l[i])).astype(int)
    left = centre[0] - half
    top = centre[1] - half
    right = left + patch_len
    bottom = top + patch_len

    inside = left >= 0 and top >= 0 and right <= side and bottom <= side
    if not inside:
        # The patch sticks out of the image: add a black border and shift the
        # crop box by the border width.
        margin = half + 1
        context = ImageOps.expand(context, border=margin, fill='black')
        left, top = left + margin, top + margin
        right, bottom = right + margin, bottom + margin

    patch = context.crop((left, top, right, bottom))
    return opt.my_transform(patch).unsqueeze(0)
def extract_patches_batch(l, size, info, opt, istrain):
    """Extract one square patch per datapoint of a batch.

    Each patch is centred on the datapoint's location in ``l`` and its side is
    derived from ``size`` by ``get_patch``.

    Args:
        l: Per-datapoint locations; ``len(l)`` is taken as the batch size.
        size: Patch size multiplier forwarded to ``get_patch``.
        info: Per-datapoint info forwarded to ``get_patch``.
        opt: Options object; ``opt.n_jobs`` sets the number of joblib workers.
        istrain: Forwarded to ``get_patch``.

    Returns:
        The per-datapoint patches concatenated along dimension 0 and passed
        through ``get_cuda``.
    """
    batch_size = len(l)
    # get_patch runs independently for every datapoint, so the calls are
    # fanned out over joblib threads. The generator of delayed calls has to be
    # the argument of the Parallel object's call, i.e. part of the same
    # expression, otherwise no job is ever executed.
    patches = Parallel(n_jobs=opt.n_jobs, backend="threading")(
        delayed(get_patch)(i, l, size, info, opt, istrain)
        for i in range(batch_size)
    )
    patches = get_cuda(T.cat(patches, dim=0))
    return patches
def retina(l, info, opt, istrain):
    """Gather patches around each location at ``opt.k`` successive scales.

    Args:
        l: Tensor of locations; it is moved to the CPU and converted to a
            NumPy array before patch extraction.
        info: Per-datapoint info forwarded to ``extract_patches_batch``.
        opt: Options object providing ``start_size`` (size of the first scale)
            and ``k`` (number of scales).
        istrain: Forwarded to ``extract_patches_batch``.

    Returns:
        List with one batch of patches per scale; the size passed to
        ``extract_patches_batch`` doubles from one scale to the next.
    """
    locations = l.cpu().numpy()
    base = opt.start_size
    return [
        extract_patches_batch(locations, base * 2 ** level, info, opt, istrain)
        for level in range(opt.k)
    ]
A declarative, efficient, and flexible JavaScript library for building user interfaces.
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
An Open Source Machine Learning Framework for Everyone
The Web framework for perfectionists with deadlines.
A PHP framework for web artisans
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
Something interesting about the web. A new door to the world.
A server is a program made to process requests and deliver data to clients.
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
Something interesting about visualization: using data as art.
Something interesting about games: making everyone happy.
We are working to build community through open source technology. NB: members must have two-factor auth.
Open source projects and samples from Microsoft.
Google ❤️ Open Source for everyone.
Alibaba Open Source for everyone
Data-Driven Documents codes.
China tencent open source team.