JAX Perspective Transform / Warp

A port (largely copy-pasted) of Kornia's image perspective transform from PyTorch to JAX/NumPy. It may be buggy. Most docstrings and comments in the code are carried over from the original Torch code and should largely be ignored. The main function is warp_perspective; its docstring should be correct regarding the available options and parameters. You'll also need get_perspective_transform, as shown in the examples below.

Example

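pip install git+https://github.com/josephrocca/jax-perspective-transform.git

# Example based on: https://kornia-tutorials.readthedocs.io/en/latest/warp_perspective.html
import jax.numpy as np  # assumed: `np` is jax.numpy throughout this example
import numpy as onp     # plain NumPy, used when handing arrays to PIL
import jpt
import PIL.Image as Image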

# source points: corners of the region to crop, as [x, y] pairs, clockwise from top left
points_src = np.array([
    [125., 150.], [562., 40.], [562., 282.], [54., 328.], # corners of bruce lee poster
])

# destination points: corners of the output image, as [x, y] pairs, clockwise from top left
dst_h, dst_w = 64, 128
points_dst = np.array([
    [0., 0.], [dst_w - 1., 0.], [dst_w - 1., dst_h - 1.], [0., dst_h - 1.],
])

img = np.array(Image.open('bruce.png').convert("RGB"))  
img = img.transpose(2, 0, 1) # CxHxW / np.uint8
print(img.shape)

# compute perspective transform
M = jpt.get_perspective_transform(points_src, points_dst)

# warp the original image by the found transform
img_warped = jpt.warp_perspective(img.astype('float32'), M, dsize=(dst_h, dst_w))
print(img_warped.shape)

# convert back to HxWxC
img = img.transpose(1, 2, 0)
img_warped = img_warped.transpose(1, 2, 0).astype('uint8')

Image.fromarray(onp.array(img)).show()
Image.fromarray(onp.array(img_warped)).show()

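To make the role of the matrix M in the example above more concrete, here is a minimal NumPy sketch of how a 3x3 perspective matrix can be recovered from four point pairs. It is not jpt's implementation: `solve_homography` and `apply_homography` are hypothetical helper names, and the sketch assumes the Kornia convention that the matrix maps source points to destination points in homogeneous coordinates.

```python
import numpy as onp  # plain NumPy, following the `onp` naming used in this README


def solve_homography(src, dst):
    """Solve for a 3x3 matrix H (with H[2, 2] fixed to 1) mapping src -> dst.

    Each pair (x, y) -> (u, v) contributes two linear equations in the
    eight unknown entries of H, so four pairs give an 8x8 system.
    """
    rows, rhs = [], []
    for (x, y), (u, v) in zip(src, dst):
        rows.append([x, y, 1, 0, 0, 0, -u * x, -u * y])
        rows.append([0, 0, 0, x, y, 1, -v * x, -v * y])
        rhs.extend([u, v])
    h = onp.linalg.solve(onp.array(rows, dtype=onp.float64),
                         onp.array(rhs, dtype=onp.float64))
    return onp.append(h, 1.0).reshape(3, 3)


def apply_homography(H, pts):
    """Map Nx2 points through H, dividing by the homogeneous coordinate."""
    pts_h = onp.concatenate([pts, onp.ones((len(pts), 1))], axis=1)
    out = pts_h @ H.T
    return out[:, :2] / out[:, 2:3]


# Same corner points as the example above (dst_h, dst_w = 64, 128).
src = onp.array([[125., 150.], [562., 40.], [562., 282.], [54., 328.]])
dst = onp.array([[0., 0.], [127., 0.], [127., 63.], [0., 63.]])

H = solve_homography(src, dst)
mapped = apply_homography(H, src)
print(onp.round(mapped, 3))  # should reproduce the destination corners
```

Mapping the four source corners through the solved matrix is a cheap sanity check; the same check could be applied to the matrix returned by get_perspective_transform, assuming it follows the same source-to-destination convention.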

Random Transform Example

E.g. for image augmentation:

import jax
import jax.numpy as np  # assumed: `np` is jax.numpy; plain NumPy is `onp`
import jpt
import time
import PIL.Image as Image
import numpy as onp

def r(a=0, b=1): # generate a random float between `a` and `b`
    key = jax.random.PRNGKey(int(time.time()*1000)) # just for demo!
    return a + jax.random.uniform(key) * (b - a)

img = np.array(Image.open('happy_dog.png').convert("RGB"))  
img = img.transpose(2, 0, 1) # CxHxW / np.uint8
print(img.shape)

# source points: the four corners of the input image, as [x, y] pairs, clockwise from top left
in_w, in_h = (img.shape[2], img.shape[1])
points_src = np.array([
    [0, 0], [in_w, 0], [in_w, in_h], [0, in_h],
])

# destination points: the image corners, each randomly shifted, as [x, y] pairs, clockwise from top left
dst_w, dst_h = (img.shape[2], img.shape[1])
scale = 0.1 # <-- move each corner by up to 10% of image width/height
points_dst = np.array([
    [0+r(-scale*dst_w, scale*dst_w), 0+r(-scale*dst_h, scale*dst_h)],
    [dst_w+r(-scale*dst_w, scale*dst_w), 0+r(-scale*dst_h, scale*dst_h)],
    [dst_w+r(-scale*dst_w, scale*dst_w), dst_h+r(-scale*dst_h, scale*dst_h)],
    [0+r(-scale*dst_w, scale*dst_w), dst_h+r(-scale*dst_h, scale*dst_h)],
])

# compute perspective transform
M = jpt.get_perspective_transform(points_src, points_dst)

# warp the original image by the found transform
img_warped = jpt.warp_perspective(img.astype('float32'), M, dsize=(dst_h, dst_w))
print(img_warped.shape)

# convert back to HxWxC
img = img.transpose(1, 2, 0)
img_warped = img_warped.transpose(1, 2, 0).astype('uint8')

Image.fromarray(onp.array(img)).show()
Image.fromarray(onp.array(img_warped)).show()

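The helper `r` above re-seeds from the clock on every call, so two calls within the same millisecond return the same value. For real augmentation pipelines, a more conventional JAX pattern is to pass an explicit key and draw all eight offsets in one call. The sketch below illustrates that; `random_corner_offsets` is a hypothetical helper, not part of jpt, and the 128x64 image size is just an illustrative value.

```python
import jax
import jax.numpy as np  # following the README's `np` = jax.numpy convention


def random_corner_offsets(key, width, height, scale=0.1):
    """Return a (4, 2) array of (dx, dy) offsets.

    Each dx lies within +/- scale * width and each dy within
    +/- scale * height, drawn independently from the given key.
    """
    max_shift = scale * np.array([width, height], dtype=np.float32)
    unit = jax.random.uniform(key, shape=(4, 2), minval=-1.0, maxval=1.0)
    return unit * max_shift


# Split the key so the parent key can be reused for later draws.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)

dst_w, dst_h = 128, 64  # illustrative image size
corners = np.array([[0., 0.], [dst_w, 0.], [dst_w, dst_h], [0., dst_h]])
offsets = random_corner_offsets(subkey, dst_w, dst_h)
points_dst = corners + offsets
print(points_dst)
```

The resulting `points_dst` could be passed to get_perspective_transform in place of the hand-built array above. Using a fixed seed also makes an augmentation run reproducible, which the clock-based seed does not.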
