Some questions that need help about anyloc (closed)

anyloc commented on June 6, 2024
Some questions that need help

from anyloc.

Comments (7)

VClay commented on June 6, 2024

What I mean is using DataParallel in the code to accelerate feature extraction with more GPUs. I have changed some code as follows:

  1. Add the attribute `self.batch_hooks = []` to `DinoV2ExtractFeatures`.
  2. In `_generate_forward_hook(self)`, append every hooked output to `self.batch_hooks`, together with the index of the GPU that produced it. Under DataParallel the hook fires once per GPU, so `self._hook_out` alone would only hold the last replica's part of the batch.
  3. In `DinoV2ExtractFeatures.__call__`, replace
     `res = self._hook_out[:, 1:, ...]`
     with a concatenation of the per-GPU outputs (`torch.cat(..., dim=0)`), sorted back into batch order and moved to the first GPU, before the CLS token is dropped.
  4. In `DinoV2ExtractFeatures.__call__`, reset `self.batch_hooks = []` on every call, in place of `self._hook_out = None  # Reset the hook`.
  5. In `DinoV2ExtractFeatures.__init__`, change
     `self.fh_handle = self.dino_model.blocks[self.layer]`
     to
     `self.fh_handle = self.dino_model.module.blocks[self.layer]`
     because the model is now wrapped in DataParallel.

```python
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.parallel import DataParallel


class DinoV2ExtractFeatures(nn.Module):
    """
        Extract features from an intermediate layer in Dino-v2, splitting
        each batch across several GPUs with DataParallel
    """

    def __init__(self, dino_model: str = "dinov2_vitg14", layer: int = 31,
                 facet: str = "key", use_cls=False,
                 norm_descs=True, device: str = "cuda:0",
                 device_ids=(0, 1, 2)) -> None:
        """
            Parameters:
            - dino_model:   The DINO-v2 model to use
            - layer:        The layer to extract features from
            - facet:    "query", "key", or "value" for the attention
                        facets. "token" for the output of the layer.
            - use_cls:  If True, the CLS token (first item) is also
                        included in the returned list of descriptors.
                        Otherwise, only patch descriptors are used.
            - norm_descs:   If True, the descriptors are normalized
            - device:   Unused (kept so existing calls still work); the
                        model runs on cuda:device_ids[0]
            - device_ids:   GPUs to split each batch across
        """
        super().__init__()
        self.vit_type: str = dino_model
        # self.dino_model: nn.Module = torch.hub.load(
        #     'facebookresearch/dinov2', dino_model)
        # Loaded from a local copy of the dinov2 repo (machine-specific path)
        self.dino_model: nn.Module = torch.hub.load(
            repo_or_dir="/home/ly/hub/dinov2-main", model=dino_model,
            trust_repo=True, source='local')
        for param in self.dino_model.parameters():
            param.requires_grad = False
        # More GPUs: DataParallel expects the parameters on device_ids[0]
        self.device_ids = list(device_ids)
        self.device = torch.device(f"cuda:{self.device_ids[0]}")
        self.dino_model.to(self.device)
        self.dino_model = DataParallel(
            self.dino_model, device_ids=self.device_ids).eval()
        self.layer: int = layer
        self.facet = facet
        # Register on the wrapped model (.module); the replicas that
        # DataParallel creates on every forward pass reuse these hooks
        block = self.dino_model.module.blocks[self.layer]
        if self.facet == "token":
            self.fh_handle = block.register_forward_hook(
                self._generate_forward_hook())
        else:
            self.fh_handle = block.attn.qkv.register_forward_hook(
                self._generate_forward_hook())
        self.use_cls = use_cls
        self.norm_descs = norm_descs
        # Hook data: one (gpu_index, output) pair per replica
        self.batch_hooks = []

    def _generate_forward_hook(self):
        def _forward_hook(module, inputs, output):
            # Runs once per GPU, each in its own DataParallel thread
            self.batch_hooks.append((output.device.index, output))
        return _forward_hook

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        """
            Parameters:
            - img:   The input batch of images, shape [B, 3, H, W]
        """
        self.batch_hooks = []  # Reset the hook
        with torch.no_grad():
            self.dino_model(img)
            torch.cuda.synchronize()
            # Chunk i of the batch goes to device_ids[i], but the threads
            # may finish in any order, so sort before concatenating
            rank = {d: i for i, d in enumerate(self.device_ids)}
            chunks = sorted(self.batch_hooks, key=lambda p: rank[p[0]])
            res = torch.cat([out.to(self.device) for _, out in chunks],
                            dim=0)
            if not self.use_cls:
                res = res[:, 1:, ...]
            if self.facet in ["query", "key", "value"]:
                d_len = res.shape[2] // 3
                if self.facet == "query":
                    res = res[:, :, :d_len]
                elif self.facet == "key":
                    res = res[:, :, d_len:2 * d_len]
                else:
                    res = res[:, :, 2 * d_len:]
        if self.norm_descs:
            res = F.normalize(res, dim=-1)
        self.batch_hooks = []  # Free the per-GPU outputs
        return res

    def __del__(self):
        self.fh_handle.remove()
```
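Two details of this split are easy to get wrong, so here is a small pure-Python sketch of them (chunk_sizes and order_by_device are hypothetical helper names, not part of AnyLoc or PyTorch). First, to my understanding DataParallel scatters a batch the way torch.chunk splits a tensor: chunks of ceil(B / number of GPUs) rows, so a batch of 4 on 3 GPUs only uses 2 of them. Second, each replica runs in its own thread, so the per-GPU hook outputs arrive in whatever order the threads finish and need to be put back in device_ids order before torch.cat.

```python
import math
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def chunk_sizes(batch: int, num_devices: int) -> List[int]:
    """Rows per GPU if the batch is split like torch.chunk(x, num_devices)."""
    if batch == 0:
        return []
    size = math.ceil(batch / num_devices)
    # The last chunk may be smaller, and there may be fewer chunks than GPUs
    return [min(size, batch - start) for start in range(0, batch, size)]


def order_by_device(outputs: Sequence[Tuple[int, T]],
                    device_ids: Sequence[int]) -> List[T]:
    """Sort (gpu_index, chunk) pairs into device_ids order, i.e. batch order."""
    rank = {dev: i for i, dev in enumerate(device_ids)}
    return [chunk for _, chunk in sorted(outputs, key=lambda p: rank[p[0]])]
```

In the extractor, the pairs would be (output.device.index, output) as collected by the forward hook, and the sorted chunks are what gets concatenated.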

VClay commented on June 6, 2024

There may still be some things wrong in this code; it is just meant as a suggestion.

TheProjectsGuy commented on June 6, 2024

Hey @VClay

Thanks for your observations

> What I found is that the image input for extract_patch_descriptors is processed one by one. Would it be better to improve efficiency by using a batch size (if there are more GPUs)?

I guess you're talking about this segment:

```python
from typing import Literal

import torch
from torchvision import transforms as tvf
from tqdm import tqdm


# verbose, vpr_ds, vpr_distractor_ds, glob_ds, dino and device are
# defined elsewhere in the script
def extract_patch_descriptors(indices,
        use_set: Literal["vpr", "distractor", "global"]="vpr"):
    patch_descs = []
    for i in tqdm(indices, disable=not verbose):
        if use_set == "vpr":
            img = vpr_ds[i][0]
        elif use_set == "distractor":
            img = vpr_distractor_ds[i][0]
        elif use_set == "global":
            img = glob_ds[i][0]
        else:
            raise ValueError(f"Invalid use set: {use_set}")
        c, h, w = img.shape
        # Crop to a multiple of the ViT patch size (14 px)
        h_new, w_new = (h // 14) * 14, (w // 14) * 14
        img_in = tvf.CenterCrop((h_new, w_new))(img)[None, ...]
        ret = dino(img_in.to(device))
        patch_descs.append(ret.cpu())
    patch_descs = torch.cat(patch_descs, dim=0)    # [N, n_p, d_dim]
    return patch_descs
```
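One way to batch this loop, sketched under assumptions: every image in a stacked batch must have the same shape, so images are first grouped by their cropped size and each group is then cut into batches of batch_size indices. cropped_size and make_batches are hypothetical helpers, and the patch size of 14 matches the crop above. Note that the final torch.cat(..., dim=0) already requires every image to yield the same number of patches, so a dataset that works with the loop above will usually form a single group.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

PATCH = 14  # ViT patch size used by the crop above


def cropped_size(h: int, w: int, patch: int = PATCH) -> Tuple[int, int]:
    """Largest height and width that are multiples of the patch size."""
    return (h // patch) * patch, (w // patch) * patch


def make_batches(shapes: Sequence[Tuple[int, int]],
                 batch_size: int) -> List[List[int]]:
    """Group image indices by cropped size, then cut each group into batches.

    shapes[i] is the (h, w) of image i; the result lists dataset indices.
    """
    groups: Dict[Tuple[int, int], List[int]] = defaultdict(list)
    for idx, (h, w) in enumerate(shapes):
        groups[cropped_size(h, w)].append(idx)
    batches = []
    for idxs in groups.values():  # dicts keep insertion order
        for start in range(0, len(idxs), batch_size):
            batches.append(idxs[start:start + batch_size])
    return batches
```

Each list of indices could then be center-cropped, stacked with torch.stack into a [B, 3, h_new, w_new] tensor and passed to dino in one call; since the batches keep the dataset indices, the descriptors can be written back in dataset order.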

Yes, you could use libraries like joblib.Parallel to accelerate this. Each instance uses about 6 GB of VRAM, so if you have a GPU with, say, 32 GB of VRAM, you could run five of these in parallel on a single GPU.
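A sketch of the same idea using the standard library's concurrent.futures.ThreadPoolExecutor in place of joblib.Parallel (swapped in only to keep the example dependency-free). split, run_parallel and workers_per_gpu are hypothetical helpers; fn stands for something like extract_patch_descriptors applied to one slice of indices. Because the hook-based extractor stores its result in an attribute, concurrent calls must not share one instance: each worker would need its own extractor, i.e. its own model copy of roughly 6 GB, which is where the arithmetic above comes in.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Sequence, TypeVar

R = TypeVar("R")


def workers_per_gpu(gpu_mem_gb: float, per_worker_gb: float) -> int:
    """How many extractor copies fit on one GPU (at least one)."""
    return max(1, int(gpu_mem_gb // per_worker_gb))


def split(indices: Sequence[int], n: int) -> List[List[int]]:
    """Split indices into n contiguous parts whose sizes differ by at most one."""
    k, r = divmod(len(indices), n)
    parts, start = [], 0
    for i in range(n):
        end = start + k + (1 if i < r else 0)
        parts.append(list(indices[start:end]))
        start = end
    return parts


def run_parallel(fn: Callable[[List[int]], R],
                 indices: Sequence[int], n_workers: int) -> List[R]:
    """Run fn on n_workers slices of indices concurrently.

    Results come back in slice order (not completion order), so they can be
    concatenated directly. fn must not share mutable state between calls.
    """
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        return list(pool.map(fn, split(indices, n_workers)))
```

For tensor results, torch.cat(run_parallel(...), dim=0) would rebuild the [N, n_p, d_dim] array. joblib.Parallel's default backend (loky) uses separate processes instead of threads, so each worker there loads its own model anyway.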

> Does the AnyLoc project just load the pretrained dinov2_vitg14 model to extract patch descriptors?

I might not have understood you correctly here. What we do is create a global dataset class that combines datasets from the same domain, and use it only to get the VLAD cluster centers:

```python
import einops as ein
import numpy as np

# verbose, glob_ds, largs and vlad are defined elsewhere in the script
# Get cluster centers using global vocabulary
if verbose:
    print("Building VLAD cluster centers...")
num_db = len(glob_ds)
db_indices = np.arange(0, num_db, largs.sub_sample_db_vlad)
# Database descriptors (for VLAD clusters): [n_db, n_d, d_dim]
full_db_vlad = extract_patch_descriptors(db_indices, "global")
if verbose:
    print(f"Database descriptors shape: {full_db_vlad.shape}")
d_dim = full_db_vlad.shape[2]
if verbose:
    print(f"Descriptor dimensionality: {d_dim}")
vlad.fit(ein.rearrange(full_db_vlad, "n k d -> (n k) d"))
```
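vlad.fit receives the descriptors of all sampled database images flattened into one [N * n_p, d] matrix. A minimal NumPy sketch of that step, assuming plain Euclidean Lloyd's k-means; it is not the clustering code inside AnyLoc's VLAD class, and flatten_descriptors and kmeans are hypothetical names.

```python
import numpy as np


def flatten_descriptors(descs: np.ndarray) -> np.ndarray:
    """[N, n_p, d] -> [N * n_p, d], like einops 'n k d -> (n k) d'."""
    n, k, d = descs.shape
    return descs.reshape(n * k, d)


def kmeans(x: np.ndarray, num_clusters: int, iters: int = 20,
           seed: int = 0) -> np.ndarray:
    """Plain Lloyd's k-means with Euclidean distance; returns [K, d] centers."""
    x = np.asarray(x, dtype=np.float64)
    rng = np.random.default_rng(seed)
    # Start from K distinct descriptors picked at random
    centers = x[rng.choice(len(x), size=num_clusters, replace=False)].copy()
    for _ in range(iters):
        # Squared distance of every descriptor to every center: [M, K]
        dists = ((x[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
        labels = dists.argmin(axis=1)
        for c in range(num_clusters):
            members = x[labels == c]
            if len(members):  # keep the old center if a cluster is empty
                centers[c] = members.mean(axis=0)
    return centers
```

The broadcast distance computation holds an [M, K, d] temporary, so for real descriptor counts it would have to be done in chunks, or replaced by a library k-means.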

We then extract VLAD descriptors of the given dataset (the global dataset is not used after this):

```python
# Get VLADs of the queries
```
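Once the centers exist, each image's patch descriptors are aggregated into a single VLAD vector. A sketch under assumptions: hard assignment to the nearest center, per-cluster sums of residuals, intra-normalization, then a global L2 normalization. AnyLoc's own VLAD class may differ in details such as the distance it uses; vlad_descriptor is a hypothetical name.

```python
import numpy as np


def vlad_descriptor(descs: np.ndarray, centers: np.ndarray) -> np.ndarray:
    """Hard-assignment VLAD for one image.

    descs: [n_p, d] patch descriptors, centers: [K, d] -> returns [K * d].
    """
    descs = np.asarray(descs, dtype=np.float64)
    centers = np.asarray(centers, dtype=np.float64)
    # Nearest center for every patch descriptor
    dists = ((descs[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
    labels = dists.argmin(axis=1)
    residuals = np.zeros_like(centers)
    for c in range(len(centers)):
        members = descs[labels == c]
        if len(members):
            residuals[c] = (members - centers[c]).sum(axis=0)
    # Intra-normalization: unit length per cluster (empty clusters stay zero)
    norms = np.linalg.norm(residuals, axis=1, keepdims=True)
    residuals = np.divide(residuals, norms, out=np.zeros_like(residuals),
                          where=norms > 0)
    # Flatten and L2-normalize the whole vector
    v = residuals.reshape(-1)
    total = np.linalg.norm(v)
    return v / total if total > 0 else v
```

Since every vector built this way has unit length, database and query VLADs can be compared with a dot product (cosine similarity).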

Let me know if you have other issues.

TheProjectsGuy commented on June 6, 2024

Also, if you're just testing our model, you can now try loading it through torch.hub.

We're working on restructuring the README to include this.

TheProjectsGuy commented on June 6, 2024

Closing due to inactivity. Feel free to reopen it if any problems occur.

VClay commented on June 6, 2024

I hope these can help improve the code.

TheProjectsGuy commented on June 6, 2024

Okay, I see what you mean. Thanks for the insight and proposing a DataParallel implementation for the codebase. We should keep this in mind for future releases. We're currently in the process of restructuring the repository to allow easier use.
