
visualsearch_mxnet's Introduction

Visual Search with MXNet Gluon and HNSW

How does it work?

In this tutorial we will create a visual search engine for browsing 1M Amazon product images.

The first step is to index the image dataset by computing the image embeddings, using a pre-trained network as a featurizer.

The second step is to query the index using an efficient k-NN search algorithm. Here we use Hierarchical Navigable Small World graphs (HNSW).
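
To make the query step concrete, here is a minimal sketch of the exact (brute-force) k-NN search that HNSW approximates. It is not part of the original notebook: the function names and the toy 2-D vectors are assumptions chosen for illustration, and the squared L2 distance mirrors the 'l2' space used for the index later on. An exact scan costs one distance computation per indexed image, which is why an approximate index is used for 1M images.

```python
import heapq

def squared_l2(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def exact_knn(query, vectors, k):
    """Return the indices of the k vectors closest to `query`, nearest first."""
    scored = ((squared_l2(query, v), i) for i, v in enumerate(vectors))
    return [i for _, i in heapq.nsmallest(k, scored)]

# Toy embeddings standing in for the 512-dimensional image features
embeddings = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [5.0, 5.0]]
nearest = exact_knn([0.9, 0.1], embeddings, k=2)
print(nearest)
```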

Prerequisites:

import mxnet as mx
from mxnet import gluon, nd
from mxnet.gluon.model_zoo import vision
import multiprocessing
from mxnet.gluon.data.vision.datasets import ImageFolderDataset
from mxnet.gluon.data import DataLoader
import numpy as np
import imghdr
import json
import pickle
import hnswlib
import glob, os, time
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import urllib.parse
import urllib.request
import gzip
%matplotlib inline

Data originally from here: http://jmcauley.ucsd.edu/data/amazon/

J. McAuley, C. Targett, J. Shi, A. van den Hengel. Image-based recommendations on styles and substitutes. SIGIR, 2015.

Downloading images

We only use a subset of the total number of images, here 1M (it takes about 40 minutes to download all the data on an EC2 instance)

subset_num = 1000000

Beware: with the full dataset this will download 300GB of images, so make sure you have the appropriate hardware and connection! Alternatively, just set images_path to a directory containing images named with the format ID.jpg
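
If you bring your own image directory, a small helper like the one below can recover the IDs from the file names. This is a sketch only: index_images_by_id is a hypothetical name that does not appear in the notebook, and it assumes every image sits directly in the directory and is named ID.jpg.

```python
import glob
import os

def index_images_by_id(images_dir):
    """Map each ID to its image path, assuming files are named ID.jpg."""
    paths = sorted(glob.glob(os.path.join(images_dir, '*.jpg')))
    return {os.path.splitext(os.path.basename(p))[0]: p for p in paths}
```

The resulting dictionary makes it easy to go from a search result's file path back to the product ID.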

data_path = 'metadata.json'
images_path = '/data/amazon_images_subset'

Download the metadata.json file that contains the URLs of the images

if not os.path.isfile(data_path):
    # Downloading the metadata, 3.1GB, unzipped 9GB
    !wget -nv https://s3.us-east-2.amazonaws.com/mxnet-public/stanford_amazon/metadata.json.gz
    !gzip -d metadata.json.gz

Once the file is available, count its lines to check that it holds enough entries for the requested subset

# Count the lines lazily, without loading the 9GB file into memory
with open(data_path) as f:
    num_lines = sum(1 for line in f)
assert num_lines >= subset_num, "Subset size must be smaller than or equal to the total number of examples"

if not os.path.isdir(images_path):
    os.makedirs(images_path)
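import ast
import urllib.request

def parse(path, num_cpu, modulo):
    # The original code parsed each line with eval, which executes arbitrary code.
    # ast.literal_eval only accepts Python literals such as dicts, lists and strings.
    with open(path, 'r') as g:
        for i, l in enumerate(g):
            # Keep the last subset_num lines and give every num_cpu-th line to this worker
            if i >= num_lines - subset_num and i % num_cpu == modulo:
                yield ast.literal_eval(l)

def download_files(modulo):
    for data in parse(data_path, NUM_CPU, modulo):
        if 'imUrl' in data and data['imUrl'] is not None and 'categories' in data and data['imUrl'].split('.')[-1] == 'jpg':
            url = data['imUrl']
            try:
                path = os.path.join(images_path, data['asin']+'.jpg')
                # Skip images that were already downloaded in a previous run
                if not os.path.isfile(path):
                    urllib.request.urlretrieve(url, path)
            except Exception:
                print("Error downloading {}".format(url))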
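
The filter inside parse() splits the work between workers by line number. The sketch below reproduces that filter on a toy file of 10 lines to show that every line of the subset goes to exactly one worker. shard_line_numbers is a hypothetical helper written for this illustration and is not part of the notebook.

```python
def shard_line_numbers(num_lines, subset_num, num_workers, worker):
    """Line numbers handled by `worker`, mirroring the filter used in parse()."""
    start = num_lines - subset_num  # only the last `subset_num` lines are kept
    return [i for i in range(num_lines)
            if i >= start and i % num_workers == worker]

# 10 lines in the file, a subset of 6, split between 3 workers
shards = [shard_line_numbers(10, 6, 3, w) for w in range(3)]
print(shards)
```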

Downloading the images using 10 times more processes than cores

NUM_CPU = multiprocessing.cpu_count()*10
# The context manager shuts the worker processes down once the downloads finish
with multiprocessing.Pool(processes=NUM_CPU) as pool:
    results = pool.map(download_files, list(range(NUM_CPU)))
# Removing all the fake jpegs
list_files = glob.glob(os.path.join(images_path, '*.jpg'))
for file in list_files:
    if imghdr.what(file) != 'jpeg':
        print('Removed {} it is a {}'.format(file, imghdr.what(file)))
        os.remove(file)
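
Note that the imghdr module was deprecated in Python 3.11 and removed in Python 3.13. On recent interpreters, one possible stand-in is to check the first bytes of each file for the JPEG start-of-image marker. The helper below is a sketch under that assumption: looks_like_jpeg is a hypothetical name, and the check is weaker than actually decoding the image, since a truncated file would still pass.

```python
def looks_like_jpeg(path):
    """Return True when the file starts with the JPEG start-of-image marker."""
    with open(path, 'rb') as f:
        # JPEG files begin with FF D8, followed by the FF byte of the next marker
        return f.read(3) == b'\xff\xd8\xff'
```

In the clean-up loop, the condition `imghdr.what(file) != 'jpeg'` would then become `not looks_like_jpeg(file)`.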

Generate the image embeddings

BATCH_SIZE = 256
EMBEDDING_SIZE = 512
SIZE = (224, 224)
MEAN_IMAGE= mx.nd.array([0.485, 0.456, 0.406])
STD_IMAGE = mx.nd.array([0.229, 0.224, 0.225])

Featurizer

We use a pre-trained model from the model zoo

ctx = mx.gpu()

Networks from the model zoo follow the convention that the feature extractor is exposed on the .features property and the classifier on the .output property. This makes it very easy to turn any pre-trained network into a featurizer.

net = vision.resnet18_v2(pretrained=True, ctx=ctx)
net = net.features

Data Transform

We convert the images to the shape and value range expected by the network

def transform(image, label):
    resized = mx.image.resize_short(image, SIZE[0]).astype('float32')
    cropped, crop_info = mx.image.center_crop(resized, SIZE)
    cropped /= 255.
    normalized = mx.image.color_normalize(cropped,
                                      mean=MEAN_IMAGE,
                                      std=STD_IMAGE) 
    transposed = nd.transpose(normalized, (2,0,1))
    return transposed, label
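
To show what the normalization and the transpose do numerically, here is a plain-Python sketch of the same arithmetic on a single pixel and on a tiny image. It reuses the MEAN_IMAGE and STD_IMAGE values defined above, while normalize_pixel and hwc_to_chw are hypothetical helpers written for this illustration, not functions from MXNet.

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb):
    """Scale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    return tuple((value / 255.0 - m) / s for value, m, s in zip(rgb, MEAN, STD))

def hwc_to_chw(image):
    """Reorder nested lists from (height, width, channel) to (channel, height, width)."""
    channels = len(image[0][0])
    return [[[pixel[c] for pixel in row] for row in image] for c in range(channels)]

white = normalize_pixel((255, 255, 255))
tiny = hwc_to_chw([[[1, 2, 3], [4, 5, 6]]])  # 1 row, 2 pixels, 3 channels
```

The transpose matters because the network expects channels first, whereas decoded images store the channel last.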

Data Loading

import os, tempfile, glob
empty_folder = tempfile.mkdtemp()
# Create an empty image Folder Data Set
dataset = ImageFolderDataset(root=empty_folder, transform=transform)
list_files = glob.glob(os.path.join(images_path, '*.jpg'))

Because of the data validation and invalid URLs, our actual subset is smaller than the one requested

dataset.items = list(zip(list_files, [0]*len(list_files)))

We load the dataset in a dataloader with as many workers as CPU cores

dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, last_batch='keep', shuffle=False, num_workers=multiprocessing.cpu_count())

Featurization

features = np.zeros((len(dataset), EMBEDDING_SIZE), dtype=np.float32)
%%time
tick = time.time()
n_print = 100
for i, (data, label) in enumerate(dataloader):
    data = data.as_in_context(ctx)
    if i%n_print == 0 and i > 0:
        print("{0} batches, {1} images, {2:.3f} img/sec".format(i, i*BATCH_SIZE, BATCH_SIZE*n_print/(time.time()-tick)))
        tick = time.time()
    output = net(data)
    # Write each batch at its row offset; the last batch may hold fewer than BATCH_SIZE images
    start = i*BATCH_SIZE
    features[start:start+len(output), :] = output.asnumpy().reshape(len(output), -1)
100 batches, 25600 images, 1333.611 img/sec
200 batches, 51200 images, 2097.873 img/sec
300 batches, 76800 images, 2108.257 img/sec
400 batches, 102400 images, 2119.740 img/sec
500 batches, 128000 images, 2007.043 img/sec
600 batches, 153600 images, 2104.296 img/sec
700 batches, 179200 images, 2155.201 img/sec
800 batches, 204800 images, 2105.456 img/sec
900 batches, 230400 images, 2106.616 img/sec
1000 batches, 256000 images, 2128.810 img/sec
1100 batches, 281600 images, 2125.134 img/sec
1200 batches, 307200 images, 2141.244 img/sec
1300 batches, 332800 images, 2103.341 img/sec
1400 batches, 358400 images, 2116.504 img/sec
1500 batches, 384000 images, 2090.445 img/sec
1600 batches, 409600 images, 2138.420 img/sec
1700 batches, 435200 images, 2088.554 img/sec
1800 batches, 460800 images, 2127.671 img/sec
1900 batches, 486400 images, 2118.631 img/sec
2000 batches, 512000 images, 2084.014 img/sec
2100 batches, 537600 images, 2111.905 img/sec
2200 batches, 563200 images, 2125.523 img/sec
2300 batches, 588800 images, 2106.901 img/sec
2400 batches, 614400 images, 2123.917 img/sec
2500 batches, 640000 images, 2064.876 img/sec
2600 batches, 665600 images, 2117.610 img/sec
2700 batches, 691200 images, 2112.028 img/sec
2800 batches, 716800 images, 2066.120 img/sec
2900 batches, 742400 images, 2068.632 img/sec
3000 batches, 768000 images, 2095.919 img/sec
3100 batches, 793600 images, 2104.414 img/sec
3200 batches, 819200 images, 2090.150 img/sec
3300 batches, 844800 images, 2068.915 img/sec
3400 batches, 870400 images, 2113.243 img/sec
3500 batches, 896000 images, 2105.340 img/sec
3600 batches, 921600 images, 2127.197 img/sec
3700 batches, 947200 images, 2123.200 img/sec
CPU times: user 4min 43s, sys: 3min 22s, total: 8min 5s
Wall time: 7min 42s
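
Because the dataloader is created with last_batch='keep', the final batch can be smaller than BATCH_SIZE. The sketch below shows the row ranges that each batch should occupy in the features array. batch_row_ranges is a hypothetical helper used only to illustrate the offsets, with toy sizes instead of the real ones.

```python
def batch_row_ranges(num_items, batch_size):
    """Yield (start, stop) row ranges, with a shorter final range if needed."""
    for start in range(0, num_items, batch_size):
        yield start, min(start + batch_size, num_items)

# 10 images in batches of 4: the last batch only fills rows 8 and 9
ranges = list(batch_row_ranges(10, 4))
print(ranges)
```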

Create the search index

# Number of elements in the index
num_elements = len(features)
labels_index = np.arange(num_elements)
%%time 

# Declaring index
p = hnswlib.Index(space = 'l2', dim = EMBEDDING_SIZE) # possible options are l2, cosine or ip

# Initing index - the maximum number of elements should be known beforehand
p.init_index(max_elements = num_elements, ef_construction = 100, M = 16)

# Element insertion (can be called several times):
int_labels = p.add_items(features, labels_index)

# Controlling the recall by setting ef:
p.set_ef(100) # ef should always be > k
CPU times: user 31min 34s, sys: 16.4 s, total: 31min 51s
Wall time: 1min
p.save_index('index.idx')
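
The space argument selects the distance that the index minimizes. As far as the hnswlib documentation describes them, 'l2' is the squared Euclidean distance, 'ip' is one minus the inner product, and 'cosine' is one minus the cosine similarity. The plain-Python functions below are a sketch of those formulas for illustration only, not hnswlib code.

```python
import math

def l2_distance(a, b):
    """Squared Euclidean distance, as used by space='l2'."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def ip_distance(a, b):
    """One minus the inner product, as used by space='ip'."""
    return 1.0 - sum(x * y for x, y in zip(a, b))

def cosine_distance(a, b):
    """One minus the cosine similarity, as used by space='cosine'."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a) * sum(y * y for y in b))
    return 1.0 - dot / norm
```

With 'l2' the magnitude of the embeddings influences the ranking, while 'cosine' only compares their direction.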

Testing

We test the results by sampling random images from the dataset and searching for their k nearest neighbors

def plot_predictions(images):
    gs = gridspec.GridSpec(3, 3)
    fig = plt.figure(figsize=(15, 15))
    gs.update(hspace=0.1, wspace=0.1)
    for i, (gg, image) in enumerate(zip(gs, images)):
        gg2 = gridspec.GridSpecFromSubplotSpec(10, 10, subplot_spec=gg)
        ax = fig.add_subplot(gg2[:,:])
        ax.imshow(image, cmap='Greys_r')
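        # Hide ticks and tick labels; recent matplotlib versions expect booleans, not 'off'
        ax.tick_params(axis='both',
                       which='both',
                       bottom=False,
                       top=False,
                       left=False,
                       right=False,
                       labelleft=False,
                       labelbottom=False)
        ax.axes.set_title("result [{}]".format(i))
        if i == 0:
            # The first image is the query itself, highlighted in red
            plt.setp(ax.spines.values(), color='red')
            ax.axes.set_title("SEARCH")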
def search(N, k):
    # Query dataset, k - number of closest elements (returns 2 numpy arrays)
    q_labels, q_distances = p.knn_query([features[N]], k = k)
    images = [plt.imread(dataset.items[label][0]) for label in q_labels[0]]
    plot_predictions(images)
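
Beyond eyeballing the plots, a common way to quantify the effect of ef is recall@k: the fraction of the true nearest neighbors that the approximate search returns. The sketch below assumes you have both the labels returned by the index and the labels from an exact search for the same query. recall_at_k is a hypothetical helper and is not part of the notebook.

```python
def recall_at_k(approx_labels, exact_labels):
    """Fraction of the exact neighbours that the approximate search also returned."""
    exact = set(exact_labels)
    if not exact:
        return 1.0  # nothing to find, so nothing was missed
    return len(exact & set(approx_labels)) / len(exact)

# The approximate search found 3 of the 4 true neighbours
score = recall_at_k([1, 2, 3, 9], [1, 2, 3, 4])
print(score)
```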

Random testing

%%time
index = np.random.randint(0,len(features))
k = 6
search(index, k)
CPU times: user 292 ms, sys: 0 ns, total: 292 ms
Wall time: 287 ms


Manual testing

path = 'dress.jpg'
p.set_ef(300) # ef should always be > k
image = plt.imread(path)[:,:,:3]
image_t, _ = transform(nd.array(image), 1)
output = net(image_t.expand_dims(axis=0).as_in_context(ctx))
labels, distances = p.knn_query([output.asnumpy().reshape(-1,)], k = 5)
images = [image]
images += [plt.imread(dataset.items[label][0]) for label in labels[0]]
plot_predictions(images)


visualsearch_mxnet's People

Contributors

srochel, thomasdelteil


visualsearch_mxnet's Issues

Own Dataset

Hey, any reference to add our own dataset in this code?

Unable to locate visualsearch.model file

The model server configuration references a file visualsearch.model. However, I could not locate the file. I also tried to recreate it but could not.

Kindly explain more on this thanks

HNSW model

Hello, thanks very much for your contribution. I have a question I would like to ask. Every time I search for an image, do I have to reload the HNSW model? I have 20 million images here, and loading the model will be very slow. Is there any way to turn the search into a service? Or how do you get the results so quickly? That is so cool, thank you.

Service input file

Thank you for providing a nice program.
Please teach me.
Is the input file a json file?
Please tell me the format of the input.json file.

curl -X POST http://127.0.0.1:8080/visualsearch/predict -F "[email protected]"

signature.json
{
"inputs": [
{
"data_name": "data",
"data_shape": [0, 3, 224, 224]
}
],
"input_type": "application/json",  <--- inputfile = input.json ?

Fine tuned model

Hi Thomas,
Great work!

I have a very fundamental question.

I understand that your idea behind the presentation and sharing codes is to spread knowledge, but do you think that instead of simply using the standard resnet model to get the features, one should first fine-tune the model on a specific dataset and then extract features?

For instance,
If I intend to find features for fashion data, then I should first finetune the standard resnet model by unfreezing the initial layers and then extract features for new fashion images?
Do you think this would improve accuracy in the search?

Thanks

Again, Great work!

How to train model in our own data

Hi, thanks for the code. Can you please tell me how to train the model on our own data, and how we can add more custom models to it?

Thanks

IndexError: index 0 is out of bounds for axis 0 with size 0

%%time

p = hnswlib.Index(space = 'l2', dim = EMBEDDING_SIZE) # possible options are l2, cosine or ip

p.init_index(max_elements = num_elements, ef_construction = 100, M = 16)

int_labels = p.add_items(features, labels_index, num_threads = -1)

p.set_ef(100) # ef should always be > k

IndexError: index 0 is out of bounds for axis 0 with size 0

I am getting this error here.

Can we have the results with hyperlinks to the page?

If the latency is too large (if query for each page ), at least we may want to have the title in full length so that users can copy and search.

Or another step further:

  1. Add a button to each image to "search in amazon".
  2. A search bar on top for users to copy/paste title into it.

Same result when multiple calls at the same time

Hi Thomas,
This is absolutely a great work!
I ran this service on my local machine and it works fine, but I have an issue. When people call the service at the same time, we see each other's results. For example, when I upload a T-shirt picture and, at the same time, my friend uploads a shoe picture, I see shoe pictures instead of T-shirt pictures!
Do you know what is the problem and how can I fix this?
Thanks in advance.

Failure when Running Training Model on AWS Batch

I was able to follow the tutorial and reproduce the model and results for my use case. However, when I schedule the model training on AWS Batch, EC2 instance m4.4xlarge, it fails with the error below while extracting features. See line

for i, (data, label) in enumerate(data_loader):
        data = data.as_in_context(ctx)
        if i % n_print == 0 and i > 0:
            print(
                "{0} batches, {1} images, {2:.3f} img/sec".format(
                    i, i*BATCH_SIZE, BATCH_SIZE*n_print/(time.time()-tick)
                )
            )
            tick = time.time()
        output = net(data)
        features[i * BATCH_SIZE:(i+1)*max(BATCH_SIZE, len(output)), :] = output.asnumpy().squeeze()

Error message

save(x)
File "/usr/lib/python2.7/pickle.py", line 286, in save
f(self, obj) # Call unbound method with explicit self
File "/usr/lib/python2.7/multiprocessing/forking.py", line 66, in dispatcher
rv = reduce(obj)
File "/usr/local/lib/python2.7/dist-packages/mxnet/gluon/data/dataloader.py", line 43, in reduce_ndarray
return rebuild_ndarray, data._to_shared_mem()
File "/usr/local/lib/python2.7/dist-packages/mxnet/ndarray/ndarray.py", line 200, in _to_shared_mem
self.handle, ctypes.byref(shared_pid), ctypes.byref(shared_id)))
File "/usr/local/lib/python2.7/dist-packages/mxnet/base.py", line 149, in check_call
raise MXNetError(py_str(_LIB.MXGetLastError()))
MXNetError: [14:48:14] src/operator/tensor/../tensor/elemwise_unary_op.h:301: Check failed: inputs[0].dptr_ == outputs[0].dptr_ (0x7fe0beffc040 vs. 0x7fe0bf001600)
Stack trace returned 10 entries:
[bt] (0) /usr/local/lib/python2.7/dist-packages/mxnet/libmxnet.so(+0x17ec9d) [0x7fe11ec74c9d]
[bt] (1) /usr/local/lib/python2.7/dist-packages/mxnet/libmxnet.so(+0x17f068) [0x7fe11ec75068]
[bt] (2) /usr/local/lib/python2.7/dist-packages/mxnet/libmxnet.so(+0x8f7034) [0x7fe11f3ed034]
[bt] (3) /usr/local/lib/python2.7/dist-packages/mxnet/libmxnet.so(+0x2825020) [0x7fe12131b020]
[bt] (4) /usr/local/lib/python2.7/dist-packages/mxnet/libmxnet.so(+0x27a3ad8) [0x7fe121299ad8]
[bt] (5) /usr/local/lib/python2.7/dist-packages/mxnet/libmxnet.so(+0x27a3b13) [0x7fe121299b13]
[bt] (6) /usr/local/lib/python2.7/dist-packages/mxnet/libmxnet.so(+0x27ab954) [0x7fe1212a1954]
[bt] (7) /usr/local/lib/python2.7/dist-packages/mxnet/libmxnet.so(+0x27af461) [0x7fe1212a5461]
[bt] (8) /usr/local/lib/python2.7/dist-packages/mxnet/libmxnet.so(+0x27ac01b) [0x7fe1212a201b]
[bt] (9) /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xb8c80) [0x7fe130197c80]

I have tried to figure this out but nothing so far.

Please help if you have any ideas.

Custom Codes in visualsearch.py

Hi Thomas,
Can you tell me how do I put custom codes in visualsearch.py

Let's say that I want to do some Object Detection using some other framework (Keras-retinanet) and then I want to find feature vector for every object detected in the image.
I was thinking to add these steps inside the _preprocess function of visualsearch.py

VisualSearch.py would look like -

  1. Importing additional keras_retinanet libraries
import keras
import tensorflow as tf
from keras_retinanet.models import load_model # Helps in loading object detection model
from keras_retinanet.utils.image import read_image_bgr, preprocess_image, resize_image
  2. A URL on S3 where the .h5 file is located:
    MODEL_URL = '........amazonaws.com/mod_objdet.h5'

  3. class VisualSearchService(MXNetBaseService) will have the code to download and load the model

model_url = os.environ.get('MODEL_URL', MODEL_URL)
mx.test_utils.download(model_url, dirname=data_dir)
self.model = load_model(os.path.join(data_dir, 'mod_objdet.h5'), backbone_name='resnet50')

_preprocess function should have the codes that will generate bounding boxes, cropping the image and then finding the feature vector for each cropped item.

I tried it locally by running the mxnet-model-server after archiving all the files, but it doesn't seem to work.

It gives me the following error:
ModuleNotFoundError: No module named 'keras_retinanet'

I am running this inside an env which has Keras-retinanet and mxnet and their dependencies.

  1. Is this approach appropriate?
  2. Can you suggest me something else which is better than this approach?

Inconsistencies in Predictions

@ThomasDelteil, thanks for this great tutorial. I have some questions in this regard. I get inconsistent results when the net function is not used on the transformed image array. See below:

  1. When the net function is set
p.set_ef(300) # ef should always be > k
image = plt.imread(path)[:,:,:3]
image_t, _ = transform(nd.array(image), 1)
output = net(image_t.expand_dims(axis=0).as_in_context(ctx))
labels, distances = p.knn_query([output.asnumpy().reshape(-1,)], k = 25)
images = [image]
images += [plt.imread(dataset.items[label][0]) for label in labels[0]]

screen shot 2019-01-10 at 7 34 40 am

  2. When the net function is not set
p.set_ef(300) # ef should always be > k
image = plt.imread(path)[:,:,:3]
image_t, _ = transform(nd.array(image), 1)
output = image_t.expand_dims(axis=0).as_in_context(ctx)
labels, distances = p.knn_query([output.asnumpy().reshape(-1,)], k = 25)
images = [image]
images += [plt.imread(dataset.items[label][0]) for label in labels[0]]

screen shot 2019-01-10 at 7 35 03 am

Looking at the code for the model server, the net function is not used and the results from your deployment seem to work fine. However, when I did the same, I got inconsistent results.

Can you please help explain why this is?

Thanks in advance
