
License: MIT License

Topics: augmentation, deep-learning, test-time-augmentation, tta, tta-wrapper, pytorch, computer-vision, classification, segmentation, keypoint-detection

ttach's Introduction

TTAch

Image Test Time Augmentation with PyTorch!

Similar to what Data Augmentation is doing to the training set, the purpose of Test Time Augmentation is to perform random modifications to the test images. Thus, instead of showing the regular, “clean” images, only once to the trained model, we will show it the augmented images several times. We will then average the predictions of each corresponding image and take that as our final guess [1].

           Input
             |           # input batch of images 
        / / /|\ \ \      # apply augmentations (flips, rotation, scale, etc.)
       | | | | | | |     # pass augmented batches through model
       | | | | | | |     # reverse transformations for each batch of masks/labels
        \ \ \ / / /      # merge predictions (mean, max, gmean, etc.)
             |           # output batch of masks/labels
           Output
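
A minimal sketch of this pipeline with a single horizontal flip, written in plain PyTorch (`model` and `images` are placeholders for any trained model and a (B, C, H, W) batch):

import torch

with torch.no_grad():
    pred = model(images)                              # prediction on the clean batch
    pred_flip = model(torch.flip(images, dims=[-1]))  # prediction on the horizontally flipped batch
    pred_flip = torch.flip(pred_flip, dims=[-1])      # reverse the flip (needed for masks/keypoints)
    final = (pred + pred_flip) / 2                    # merge the predictions (mean)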

Table of Contents

  1. Quick Start
  2. Transforms
  3. Aliases
  4. Merge modes
  5. Installation

Quick start

Segmentation model wrapping [docstring]:
import ttach as tta
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode='mean')
Classification model wrapping [docstring]:
tta_model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform())
Keypoints model wrapping [docstring]:
tta_model = tta.KeypointsTTAWrapper(model, tta.aliases.flip_transform(), scaled=True)

Note: the model must return keypoints in the format torch([x1, y1, ..., xn, yn])
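
The wrapped model behaves like a regular torch.nn.Module, so inference is unchanged (a minimal sketch; `images` is an assumed (B, C, H, W) batch):

import torch

tta_model.eval()
with torch.no_grad():
    prediction = tta_model(images)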

Advanced Examples

Custom transform:
# defines 2 * 2 * 3 * 3 = 36 augmentations!
transforms = tta.Compose(
    [
        tta.HorizontalFlip(),
        tta.Rotate90(angles=[0, 180]),
        tta.Scale(scales=[1, 2, 4]),
        tta.Multiply(factors=[0.9, 1, 1.1]),        
    ]
)

tta_model = tta.SegmentationTTAWrapper(model, transforms)
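
Each combination of the listed parameter values becomes a separate augment/de-augment pair, and the composed object can be iterated over directly (a quick sanity check, assuming the composition above):

assert len(list(transforms)) == 2 * 2 * 3 * 3  # 36 transformers
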
Custom model (multi-input / multi-output)
# Example of how to process ONE batch of images with TTA
# Here `image`/`mask` are 4D tensors (B, C, H, W), `label` is a 2D tensor (B, N)

labels = []
masks = []

for transformer in transforms:  # custom transforms or e.g. tta.aliases.d4_transform()

    # augment image
    augmented_image = transformer.augment_image(image)

    # pass to model
    model_output = model(augmented_image, another_input_data)

    # reverse augmentation for mask and label
    deaug_mask = transformer.deaugment_mask(model_output['mask'])
    deaug_label = transformer.deaugment_label(model_output['label'])

    # save results
    masks.append(deaug_mask)
    labels.append(deaug_label)

# reduce results as you want, e.g. mean/max/min
label = torch.stack(labels).mean(dim=0)
mask = torch.stack(masks).mean(dim=0)

Transforms

Transform      | Parameters                          | Values
---------------|-------------------------------------|------------------------------------------------------------
HorizontalFlip | -                                   | -
VerticalFlip   | -                                   | -
Rotate90       | angles                              | List[0, 90, 180, 270]
Scale          | scales, interpolation               | List[float], "nearest"/"linear"
Resize         | sizes, original_size, interpolation | List[Tuple[int, int]], Tuple[int, int], "nearest"/"linear"
Add            | values                              | List[float]
Multiply       | factors                             | List[float]
FiveCrops      | crop_height, crop_width             | int, int
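
Transform parameters map directly onto the constructors above, for example (the values below are purely illustrative):

transforms = tta.Compose(
    [
        tta.VerticalFlip(),
        tta.Rotate90(angles=[0, 90, 180, 270]),
        tta.Resize(sizes=[(512, 512), (768, 768)], original_size=(512, 512)),
    ]
)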

Aliases

  • flip_transform (horizontal + vertical flips)
  • hflip_transform (horizontal flip)
  • d4_transform (flips + rotation 0, 90, 180, 270)
  • multiscale_transform (scale transform, take scales as input parameter)
  • five_crop_transform (corner crops + center crop)
  • ten_crop_transform (five crops + five crops on horizontal flip)
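
The parameterized aliases take their arguments directly, for example (the scales and crop size below are illustrative):

scaled = tta.aliases.multiscale_transform(scales=[1, 2])
cropped = tta.aliases.five_crop_transform(384, 384)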

Merge modes

  • mean
  • gmean (geometric mean)
  • sum
  • max
  • min
  • tsharpen
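
Any of these can be passed to a wrapper through the merge_mode argument, for example:

tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode='gmean')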

Installation

PyPI:

$ pip install ttach

Source:

$ pip install git+https://github.com/qubvel/ttach

Run tests

docker build -f Dockerfile.dev -t ttach:dev . && docker run --rm ttach:dev pytest -p no:cacheprovider

ttach's People

Contributors

gavrin-s, jzcruiser, qubvel


ttach's Issues

what's the difference between scale and resize?

Hello,
Thanks for your great work. My question may seem a little silly. If I want to keep a fixed size such as 512x512 and implement scale TTA (e.g. first scale the 512 image by some factor, then resize it back to 512), what should I do?
Thanks!

use on 5D data?

Thanks for your work! Does it support 5D data? For example, MRI data will be (B, C, H, W, D); at test time B will be 1. Thanks again!

augment_image a batch of one

Hello!

I am trying to augment a batch of one

If I provide a tensor with shape (1,3,320,320):
augmented_image = transformer.augment_image(img_t.unsqueeze_(0))
I get:

ttach\functional.py:47, in scale(x, scale_factor, interpolation, align_corners)
45 def scale(x, scale_factor, interpolation="nearest", align_corners=None):
46 """scale batch of images by scale_factor with given interpolation mode"""
---> 47 h, w = x.shape[2:]
48 new_h = int(h * scale_factor)
49 new_w = int(w * scale_factor)
ValueError: too many values to unpack (expected 2)

If I provide a tensor with shape (3,320,320):
augmented_image = transformer.augment_image(img_t)
I get:

ttach\functional.py:7, in rot90(x, k)
5 def rot90(x, k=1):
6 """rotate batch of images by 90 degrees k times"""
----> 7 return torch.rot90(x, k, (2, 3))
RuntimeError: Rotation dim1 out of range, dim1 = 3

What should I do?

Thank you!

About OOM problems

I am doing segmentation with a standard U-Net. When I run inference using the TTA library, an OOM error occurs. I am using an NVIDIA 2080 Ti. Plain inference without TTA works fine. Does anyone know how to solve this problem?

TTA for segmentation

Please clarify the TTA concept for segmentation.

Let's say I have one test image and I apply flip-left and flip-right augmentations at test time.
I pass those three images [original, flip-left, flip-right] to the model for prediction.
I will get three outputs; after that, should I directly average those predictions, or should I first reverse the augmentation (i.e. flip the flipped predictions back to the original orientation) and then average them?

Please clarify the right way to merge the predictions.

FiveCrop for Segmentation

Hello,
Thank you for sharing your code.

I often use a sliding window method to inference a large image.
When I saw your ttach code, I thought it would be better to use FiveCrops or TenCrops for segmentation.

However, I found that FiveCrops does not work for segmentation.
Since we know the shape of the original image and the crop size, de-augmentation should be possible.
Can you implement FiveCrops and TenCrops for segmentation?

Thank you very much.

the usage of TTA

Sorry, I don't know how to use TTA with my own model. Could you please give an example?

Multi-GPU processing

Hi!

Is it possible to perform TTA on a multi-GPU system?
I have enough resources to compute a complex TTA, but it uses only one GPU.
So I get an "out of memory" error on one of my GPUs, although the others are free.

So slow...

I don't understand why the model's prediction speed is so slow when running TTA.

how can we use it with meta data?

tta_model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform(384,384))
output = tta_model(images,meta)

Used like this, it doesn't work.

Custom transforms

Thanks for the library!

Is it also possible to add custom transforms to the TTA pipeline from, say, albumentations?

AttributeError: 'SegmentationTTAWrapper' object has no attribute 'predict'

import torch
import ttach as tta
import timm
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2

model = torch.load('E:/PhD_Projects/egmentation models/new model weights/UNet_mitb2_thresh0.3.pth')

tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode="mean")

image_dir = 'E:/PhD_Projects/segmentation models/patches'
image_filename_2 = 'image__02_02.tif'
image_path = os.path.join(image_dir, image_filename_2)
image = tiff.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

preprocessing_fn_inference = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
preprocessing_inference=get_preprocessing(preprocessing_fn_inference)
sample = preprocessing_inference(image=image)
image = sample['image']

x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
pr_mask = tta_model.predict(x_tensor)
pr_mask = (pr_mask.squeeze().cpu().numpy().round())
pr_mask = (pr_mask.astype('float') * 255.0/16)
#pr_mask = (pr_mask.astype('float') * 255.0/16).astype('uint8')

=============================================================================

plt.imshow(pr_mask)
plt.show()

Can anyone help me with this prediction problem? Thank you. @qubvel

merge_type = "tsharpen" cause "nan"

I use the example like this:
model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode="tsharpen")

When I use this model to predict, I find that the model output contains "nan" values.

Looking at the source code of this project, I found that in tsharpen mode it does:
x = x**0.5

Could negative values in the tensors passing through this operation be the cause of the "nan"?
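
For reference, a minimal standalone check of that hypothesis:

import torch

x = torch.tensor([-0.5, 0.25])
print(x ** 0.5)  # the negative entry becomes nan; fractional powers of negative floats are undefined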
