Comments (7)
What I mean is using DataParallel in the code to accelerate feature extraction with more GPUs. I have changed some code as follows:
- Add an attribute to DinoV2ExtractFeatures: self.batch_hooks = []
- Add self.batch_hooks.append(output.to("cuda:0")) in the function def _generate_forward_hook(self)
- Change: class DinoV2ExtractFeatures --> def __call__ -->
res = self._hook_out[:, 1:, ...]
to:
batch_hook = torch.cat(self.batch_hooks,dim=0)
res = batch_hook[:, 1:, ...]
- Change: class DinoV2ExtractFeatures --> def __call__ -->
self._hook_out = None # Reset the hook
to:
self._hook_out = None # Reset the hook
self.batch_hooks = []
- Change: class DinoV2ExtractFeatures --> def __init__ -->
self.fh_handle = self.dino_model.blocks[self.layer]
to:
self.fh_handle = self.dino_model.module.blocks[self.layer]
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms as T
from torch.nn.parallel import DataParallel


class DinoV2ExtractFeatures(nn.Module):
    """
    Extract features from an intermediate layer in Dino-v2
    """
    def __init__(self, dino_model: str = "dinov2_vitg14", layer: int = 31,
                 facet: str = "key", use_cls=False,
                 norm_descs=True, device: str = "cpu") -> None:
        """
        Parameters:
        - dino_model:   The DINO-v2 model to use
        - layer:        The layer to extract features from
        - facet:        "query", "key", or "value" for the attention
                        facets. "token" for the output of the layer.
        - use_cls:      If True, the CLS token (first item) is also
                        included in the returned list of descriptors.
                        Otherwise, only patch descriptors are used.
        - norm_descs:   If True, the descriptors are normalized
        - device:       PyTorch device to use (ignored below: the
                        multi-GPU setup always uses device_ids[0])
        """
        super().__init__()
        self.vit_type: str = dino_model
        # self.dino_model: nn.Module = torch.hub.load(
        #         'facebookresearch/dinov2', dino_model)
        self.dino_model: nn.Module = torch.hub.load(
            repo_or_dir="/home/ly/hub/dinov2-main",
            model=dino_model, trust_repo=True, source='local')
        for param in self.dino_model.parameters():
            param.requires_grad = False
        # Multi-GPU setup: the model lives on device_ids[0] and is
        # replicated to the other GPUs on every forward pass
        device_ids = [0, 1, 2]
        device = torch.device(f"cuda:{device_ids[0]}")
        self.device = torch.device(device)
        self.dino_model.to(device)
        self.dino_model_parallel = DataParallel(self.dino_model,
                                                device_ids=device_ids)
        self.dino_model = self.dino_model_parallel.eval()
        # self.dino_model = self.dino_model.eval()
        self.layer: int = layer
        self.facet = facet
        # The wrapped model is reached through .module under DataParallel
        if self.facet == "token":
            self.fh_handle = self.dino_model.module.blocks[self.layer]. \
                register_forward_hook(
                    self._generate_forward_hook())
        else:
            self.fh_handle = self.dino_model.module.blocks[self.layer]. \
                attn.qkv.register_forward_hook(
                    self._generate_forward_hook())
        self.use_cls = use_cls
        self.norm_descs = norm_descs
        # Hook data
        self._hook_out = None
        self.batch_hooks = []

    def _generate_forward_hook(self):
        def _forward_hook(module, inputs, output):
            self._hook_out = output
            # print(output.device)
            # The hook fires once per replica. NOTE: replicas run in
            # parallel threads, so the order of these appends is not
            # guaranteed to match the order of the input batch.
            self.batch_hooks.append(output.to("cuda:0"))
        return _forward_hook

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
        - img:  The input image
        """
        # print("__call__")
        with torch.no_grad():
            self.dino_model(img)
            torch.cuda.synchronize()
            # Merge the per-replica hook outputs into a single batch
            res = torch.cat(self.batch_hooks, dim=0)
            if not self.use_cls:
                res = res[:, 1:, ...]  # Drop the CLS token
            if self.facet in ["query", "key", "value"]:
                # qkv output is [query | key | value] along the last dim
                d_len = res.shape[2] // 3
                if self.facet == "query":
                    res = res[:, :, :d_len]
                elif self.facet == "key":
                    res = res[:, :, d_len:2 * d_len]
                else:
                    res = res[:, :, 2 * d_len:]
        if self.norm_descs:
            res = F.normalize(res, dim=-1)
        self._hook_out = None  # Reset the hook
        self.batch_hooks = []
        return res

    def __del__(self):
        self.fh_handle.remove()
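One caveat with collecting hook outputs in a plain list: DataParallel runs its replicas in parallel threads, so the order in which the hook appends results need not match the order of the input batch. Below is a minimal sketch of an order-preserving collector, assuming each output is keyed by the device index of the replica that produced it. ReplicaOutputCollector and its methods are hypothetical names for illustration, not part of AnyLoc or PyTorch, and plain lists stand in for tensors so the sketch runs without a GPU.

```python
import threading
from typing import Any, Dict, List, Sequence


class ReplicaOutputCollector:
    """Hypothetical helper: stores one chunk per device, returns them in order."""

    def __init__(self, device_ids: Sequence[int]) -> None:
        # DataParallel splits the batch along dim 0 in device_ids order
        self.device_ids = list(device_ids)
        self._chunks: Dict[int, Any] = {}
        self._lock = threading.Lock()

    def add(self, device_index: int, chunk: Any) -> None:
        # Called from the forward hook, possibly from several threads
        with self._lock:
            self._chunks[device_index] = chunk

    def ordered(self) -> List[Any]:
        # Small batches may not reach every device, so skip missing ones
        return [self._chunks[d] for d in self.device_ids if d in self._chunks]

    def clear(self) -> None:
        with self._lock:
            self._chunks.clear()


collector = ReplicaOutputCollector(device_ids=[0, 1, 2])
# Simulate replicas finishing out of order
collector.add(2, ["img4", "img5"])
collector.add(0, ["img0", "img1"])
collector.add(1, ["img2", "img3"])
merged = [item for chunk in collector.ordered() for item in chunk]
```

In the extractor, the hook would call something like collector.add(output.device.index, output), and __call__ would concatenate with torch.cat([c.to("cuda:0") for c in collector.ordered()], dim=0) before calling collector.clear().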
from anyloc.
There may be some mistakes in this code; it is offered only as a suggestion.
Hey @VClay
Thanks for your observations
What I found is that the image input for extract_patch_descriptors is processed one by one. Would it be better to improve efficiency using batch size (if there are more GPUs)?
I guess you're talking about this segment
AnyLoc/scripts/dino_v2_global_vocab_vlad.py
Lines 343 to 361 in 2ae462d
Yes, you could use libraries like joblib.Parallel to accelerate this. It uses about 6 GB VRAM, so if you have a GPU with say 32 GB VRAM, you could run five of these in parallel on the single GPU.
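The worker count above follows from simple arithmetic: 32 GB divided by about 6 GB per job leaves room for five jobs. A rough sketch of that idea follows. It uses the standard library's concurrent.futures as a stand-in for joblib.Parallel so that it runs without extra dependencies, and extract_descriptor is a hypothetical placeholder for the real per-image extraction, not a function from the AnyLoc codebase.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import List, Sequence


def max_parallel_jobs(total_vram_gb: float, per_job_vram_gb: float) -> int:
    """How many jobs fit in VRAM at once (at least one)."""
    return max(1, int(total_vram_gb // per_job_vram_gb))


def extract_descriptor(image_path: str) -> str:
    # Hypothetical placeholder: the real function would load the image
    # and run the DINOv2 extractor on it
    return f"descriptor:{image_path}"


def extract_all(image_paths: Sequence[str], total_vram_gb: float = 32,
                per_job_vram_gb: float = 6) -> List[str]:
    workers = max_parallel_jobs(total_vram_gb, per_job_vram_gb)
    with ThreadPoolExecutor(max_workers=workers) as pool:
        # pool.map returns results in the order of the inputs
        return list(pool.map(extract_descriptor, image_paths))


workers = max_parallel_jobs(32, 6)
results = extract_all(["a.png", "b.png", "c.png"])
```

The 6 GB figure is the estimate quoted above; actual usage depends on the model and image size, so it is worth leaving some headroom.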
Does the AnyLoc project just load the pretrained dinov2_vitg14 model to extract patch descriptors?
I might not have understood you correctly here. What we do is create a global dataset class that holds datasets of the same domain, and use it only to obtain the cluster centers.
AnyLoc/scripts/dino_v2_global_vocab_vlad.py
Lines 369 to 381 in 2ae462d
We then extract VLAD descriptors of the given dataset (the global dataset is not used after this)
AnyLoc/scripts/dino_v2_global_vocab_vlad.py
Line 387 in 2ae462d
AnyLoc/scripts/dino_v2_global_vocab_vlad.py
Line 411 in 2ae462d
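For readers unfamiliar with the two steps described above (cluster centers first, then VLAD descriptors per image), here is a minimal NumPy sketch of generic hard-assignment VLAD aggregation. It is an illustration under stated assumptions, not the implementation in dino_v2_global_vocab_vlad.py: the function name vlad_descriptor is hypothetical, the cluster centers are assumed to be given, and the per-cluster (intra) normalization is one common variant.

```python
import numpy as np


def vlad_descriptor(descs: np.ndarray, centers: np.ndarray) -> np.ndarray:
    """Aggregate local descriptors [N, D] into one vector of length K * D,
    given K cluster centers [K, D]."""
    # Squared distance from every descriptor to every center: [N, K]
    d2 = ((descs[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
    assign = d2.argmin(axis=1)  # Nearest center per descriptor
    num_clusters, dim = centers.shape
    out = np.zeros((num_clusters, dim), dtype=np.float64)
    for k in range(num_clusters):
        members = descs[assign == k]
        if len(members):
            # Sum of residuals to the assigned center
            out[k] = (members - centers[k]).sum(axis=0)
    # Intra-normalization (assumed variant): L2-normalize each cluster
    norms = np.linalg.norm(out, axis=1, keepdims=True)
    out = out / np.maximum(norms, 1e-12)
    flat = out.reshape(-1)
    # Final L2 normalization of the whole vector
    return flat / max(np.linalg.norm(flat), 1e-12)


centers = np.array([[0.0, 0.0], [10.0, 10.0]])
descs = np.array([[1.0, 0.0], [0.0, 1.0], [11.0, 10.0]])
v = vlad_descriptor(descs, centers)
```

Here the descriptors would be the patch features of one image, and the centers would come from clustering the features of the global (same-domain) dataset.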
Let me know if you have other issues.
Also, if you're just testing our model, you can now try using torch.hub
We're working on restructuring the README to include this.
Closing due to inactivity. Feel free to reopen if any problems occur.
I hope these changes help improve the code.
Okay, I see what you mean. Thanks for the insight and for proposing a DataParallel implementation for the codebase. We should keep this in mind for future releases. We're currently in the process of restructuring the repository to make it easier to use.