Comments (10)

zihangJiang commented on May 20, 2024

You may refer to the code here to compare the output (prediction) and the target (ground truth).

TokenLabeling/validate.py

Lines 238 to 242 in 09bb641

# measure accuracy and record loss
acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
losses.update(loss.item(), input.size(0))
top1.update(acc1.item(), input.size(0))
top5.update(acc5.item(), input.size(0))
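
For readers without the repository at hand, the top-k comparison between prediction and ground truth can be sketched in plain Python. This is only an illustration of the idea, not the implementation behind accuracy(): topk_accuracy is a hypothetical helper and the scores are made-up numbers.

```python
def topk_accuracy(scores, targets, k=1):
    """Return the percentage of samples whose target is among the k highest scores.

    scores  -- list of per-class score lists, one per sample
    targets -- list of integer class labels (the ground truth)
    """
    correct = 0
    for row, target in zip(scores, targets):
        # Indices of the k largest scores for this sample.
        topk = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        if target in topk:
            correct += 1
    return 100.0 * correct / len(targets)


# Made-up scores for three samples over four classes.
scores = [
    [0.1, 0.7, 0.1, 0.1],  # highest score: class 1
    [0.4, 0.3, 0.2, 0.1],  # highest score: class 0
    [0.2, 0.2, 0.5, 0.1],  # highest score: class 2
]
targets = [1, 1, 2]

acc1 = topk_accuracy(scores, targets, k=1)
acc2 = topk_accuracy(scores, targets, k=2)
```

Here the second sample is wrong at top-1 but correct at top-2, which is the kind of gap the acc1 and acc5 meters above are tracking.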

from tokenlabeling.

Williamlizl commented on May 20, 2024

And what if I want to get the image directory (path) along with each prediction?

zihangJiang commented on May 20, 2024

To get the path of the images, you may refer to

class ImageDatasetWithIndex(ImageDataset):
    def __getitem__(self, index):
        img, target = self.parser[index]
        try:
            img = img.read() if self.load_bytes else Image.open(img).convert('RGB')
        except Exception as e:
            _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
            self._consecutive_errors += 1
            if self._consecutive_errors < _ERROR_RETRY:
                return self.__getitem__((index + 1) % len(self.parser))
            else:
                raise e
        self._consecutive_errors = 0
        if self.transform is not None:
            img = self.transform(img)
        if target is None:
            target = torch.tensor(-1, dtype=torch.long)
        return img, target, index

Williamlizl commented on May 20, 2024

Is there no test.py for inference?

zihangJiang commented on May 20, 2024

You can use this colab notebook for inference. It uses a VOLO model, but you can switch to LV-ViT by importing it instead (from tlt.models import lvvit_s) and downloading the pre-trained model here.

Williamlizl commented on May 20, 2024

from tlt.models import lvvit_s
from PIL import Image
from tlt.utils import load_pretrained_weights
from timm.data import create_transform
model = lvvit_s(img_size=384)
load_pretrained_weights(model=model, checkpoint_path='/home/lbc/GitHub/c/train/LV-ViT/20210912-114053-lvvit_s-384/model_best.pth.tar')
model.eval()
transform = create_transform(input_size=384, crop_pct=model.default_cfg['crop_pct'])
image = Image.open('/home/lbc/GitHub/c/train/LV-ViT/validation/1_val/323_l2.jpg')
input_image = transform(image).unsqueeze(0)

RuntimeError Traceback (most recent call last)
in
4 from timm.data import create_transform
5 model = lvvit_s(img_size=384)
----> 6 load_pretrained_weights(model=model, checkpoint_path='/home/lbc/GitHub/c/train/LV-ViT/20210912-114053-lvvit_s-384/model_best.pth.tar')
7 model.eval()
8 transform = create_transform(input_size=384, crop_pct=model.default_cfg['crop_pct'])

~/.local/lib/python3.7/site-packages/tlt/utils/utils.py in load_pretrained_weights(model, checkpoint_path, use_ema, strict, num_classes)
109 def load_pretrained_weights(model, checkpoint_path, use_ema=False, strict=True, num_classes=1000):
110 state_dict = load_state_dict(checkpoint_path, model, use_ema, num_classes)
--> 111 model.load_state_dict(state_dict, strict=strict)
112
113

~/.local/lib/python3.7/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
1222 if len(error_msgs) > 0:
1223 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1224 self.__class__.__name__, "\n\t".join(error_msgs)))
1225 return _IncompatibleKeys(missing_keys, unexpected_keys)
1226

RuntimeError: Error(s) in loading state_dict for LV_ViT:
Missing key(s) in state_dict: "head.weight", "head.bias", "aux_head.weight", "aux_head.bias".

zihangJiang commented on May 20, 2024

Please use the latest version of our repo (pip install tlt==0.2.0).
This was a bug in an early version of the weight-loading function in tlt/utils.py, which deleted all classification heads in order to support transfer learning.
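
The effect described here can be reproduced in miniature. The sketch below is a simplified illustration in plain Python, not the tlt code: strip_heads and missing_keys are hypothetical helpers, state dicts are modelled as ordinary dicts, and "patch_embed.weight" is a placeholder key name. It shows why a checkpoint whose head entries were removed makes a strict load report exactly those entries as missing.

```python
HEAD_PREFIXES = ("head.", "aux_head.")


def strip_heads(state_dict):
    """Drop classification-head entries, as one might before transfer learning."""
    return {k: v for k, v in state_dict.items() if not k.startswith(HEAD_PREFIXES)}


def missing_keys(model_keys, state_dict):
    """Keys the model expects but the checkpoint does not provide."""
    return [k for k in model_keys if k not in state_dict]


model_keys = [
    "patch_embed.weight",  # placeholder for the backbone parameters
    "head.weight",
    "head.bias",
    "aux_head.weight",
    "aux_head.bias",
]
checkpoint = {k: 0.0 for k in model_keys}  # placeholder values instead of tensors

stripped = strip_heads(checkpoint)
missing = missing_keys(model_keys, stripped)
```

The resulting list matches the four keys named in the RuntimeError above, which is why the heads must stay in the state dict when the goal is inference rather than transfer learning.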

Williamlizl commented on May 20, 2024

from tlt.models import lvvit_s
from PIL import Image
from tlt.utils import load_pretrained_weights
from timm.data import create_transform
model = lvvit_s(img_size=384)
load_pretrained_weights(model=model, checkpoint_path='/home/lbc/GitHub/c/train/LV-ViT/20210912-114053-lvvit_s-384/model_best.pth.tar', strict=False, num_classes=2)
model.eval()
print(model)
transform = create_transform(input_size=384, crop_pct=model.default_cfg['crop_pct'])
image = Image.open('/home/lbc/GitHub/c/train/LV-ViT/validation/1_val/323_l2.jpg')
input_image = transform(image).unsqueeze(0)

If I use model = lvvit_s(img_size=384), it loads the official model, but how do I load my fine-tuned model?

zihangJiang commented on May 20, 2024

If the number of classes is not 1000, you should also pass num_classes to the model, e.g. model = lvvit_s(img_size=384, num_classes=2).
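
If the class count of a fine-tuned checkpoint is not known in advance, it can usually be read from the shape of the head weight. The sketch below assumes the head is a linear layer stored under "head.weight" (the key name that appears in the traceback earlier in this thread) with shape (num_classes, features), and that the state dict has already been extracted from the checkpoint file. infer_num_classes is a hypothetical helper, and a SimpleNamespace stands in for a real tensor so the example runs without PyTorch.

```python
from types import SimpleNamespace


def infer_num_classes(state_dict, head_key="head.weight"):
    """Read the number of classes from the first dimension of the head weight."""
    return state_dict[head_key].shape[0]


# Stand-in for a fine-tuned two-class checkpoint; real entries would be tensors.
# The feature width 384 is an arbitrary placeholder.
fake_state_dict = {"head.weight": SimpleNamespace(shape=(2, 384))}

num_classes = infer_num_classes(fake_state_dict)
```

The value found this way is what would be passed as num_classes when constructing the model, as suggested above.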

Williamlizl commented on May 20, 2024

It does work, thank you
