
fangpin / siamese-pytorch


A PyTorch implementation of Siamese Networks for one-shot learning on images, trained and tested on the Omniglot dataset.

Python 100.00%
siamese siamese-network one-shot-learning one-shot pytorch

siamese-pytorch's Introduction

Siamese Networks for One-Shot Learning

A reimplementation of the original paper in PyTorch, with training and testing on the Omniglot dataset.

requirement

  • pytorch
  • torchvision
  • python3.5+
  • python-gflags

See requirements.txt

run step

  • download dataset
git clone https://github.com/brendenlake/omniglot.git
cd omniglot/python
unzip images_evaluation.zip
unzip images_background.zip
cd ../..
# setup directory for saving models
mkdir models
  • train and test by running
python3 train.py --train_path omniglot/python/images_background \
                 --test_path  omniglot/python/images_evaluation \
                 --gpu_ids 0 \
                 --model_path models

experiment result

The loss value is sampled every 200 batches. [figure: loss plot] My final precision is 89.5%, a little lower than the result reported in the paper (92%).

This small gap might be caused by differences between my implementation and the paper's. The differences are as follows:

  • learning rate

Instead of SGD with momentum, I simply use Adam.

  • parameters initialization and settings

Instead of using separate initialization methods, learning rates, and regularization rates for different layers, I simply use PyTorch's default settings and keep them the same across all layers.
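The two differences listed above can be sketched in PyTorch as follows. This is a minimal illustration under stated assumptions, not the repository's code: the tiny model built by `make_net`, and every numeric value (learning rates, momentum, weight decay, initialization standard deviations), are placeholders chosen for the example rather than the paper's or the author's settings.

```python
import torch
from torch import nn


def make_net():
    """Hypothetical two-layer model standing in for the Siamese network."""
    return nn.Sequential(
        nn.Conv2d(1, 8, kernel_size=3),  # 8x8 input -> 8 channels of 6x6
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(8 * 6 * 6, 1),
    )


# Paper-style setup (placeholder values): each layer gets its own
# initialization, learning rate, and L2 regularization, and the optimizer
# is SGD with momentum.
paper_net = make_net()
conv, fc = paper_net[0], paper_net[3]
nn.init.normal_(conv.weight, mean=0.0, std=0.01)
nn.init.normal_(fc.weight, mean=0.0, std=0.2)
sgd = torch.optim.SGD(
    [
        {"params": conv.parameters(), "lr": 0.01, "weight_decay": 1e-4},
        {"params": fc.parameters(), "lr": 0.001, "weight_decay": 1e-3},
    ],
    lr=0.01,
    momentum=0.5,
)

# Setup described in this README: default initialization and one Adam
# optimizer that applies the same settings to every layer.
readme_net = make_net()
adam = torch.optim.Adam(readme_net.parameters(), lr=1e-4)

print(len(sgd.param_groups), "parameter groups for SGD;",
      len(adam.param_groups), "for Adam")
```

Per-layer settings go through optimizer parameter groups, so moving from one scheme to the other only changes how the optimizer is constructed; the training loop stays the same.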

siamese-pytorch's People

Contributors

fangpin, urbanophile


siamese-pytorch's Issues

makedataset

Hi.
How does make_dataset.py work?
I ran it, but nothing happens!

Thank you in advance.
Best,
Saeid.

Shared memory error

Hello, I am trying to run your code, but I am getting errors about shared memory that don't make sense.

The output is below, along with some information about the machine (~50 GB RAM, a V100, and an 8-core Xeon).

python3 train.py --train_path omniglot/python/images_background \
                 --test_path  omniglot/python/images_evaluation \
                 --gpu_ids 0 \
                 --model_path models

use gpu: 0 to train.
begin loading training dataset to memory
finish loading training dataset to memory
begin loading test dataset to memory
finish loading test dataset to memory
/usr/local/lib/python3.7/dist-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.
warnings.warn(warning.format(ret))
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)
return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 990, in _try_get_data
data = self._data_queue.get(timeout=timeout)
File "/usr/lib/python3.7/multiprocessing/queues.py", line 113, in get
return _ForkingPickler.loads(res)
File "/usr/local/lib/python3.7/dist-packages/torch/multiprocessing/reductions.py", line 289, in rebuild_storage_fd
fd = df.detach()
File "/usr/lib/python3.7/multiprocessing/resource_sharer.py", line 57, in detach
with _resource_sharer.get_connection(self._id) as conn:
File "/usr/lib/python3.7/multiprocessing/resource_sharer.py", line 87, in get_connection
c = Client(address, authkey=process.current_process().authkey)
File "/usr/lib/python3.7/multiprocessing/connection.py", line 498, in Client
answer_challenge(c, authkey)
File "/usr/lib/python3.7/multiprocessing/connection.py", line 741, in answer_challenge
message = connection.recv_bytes(256) # reject large message
File "/usr/lib/python3.7/multiprocessing/connection.py", line 216, in recv_bytes
buf = self._recv_bytes(maxlength)
File "/usr/lib/python3.7/multiprocessing/connection.py", line 407, in _recv_bytes
buf = self._recv(4)
File "/usr/lib/python3.7/multiprocessing/connection.py", line 379, in _recv
chunk = read(handle, remaining)
ConnectionResetError: [Errno 104] Connection reset by peer

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "train.py", line 79, in
for batch_id, (img1, img2, label) in enumerate(trainLoader, 1):
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 521, in next
data = self._next_data()
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1186, in _next_data
idx, data = self._get_data()
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1152, in _get_data
success, data = self._try_get_data()
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1014, in _try_get_data
fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)]
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1014, in
fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)]
File "/usr/lib/python3.7/tempfile.py", line 677, in NamedTemporaryFile
prefix, suffix, dir, output_type = _sanitize_params(prefix, suffix, dir)
File "/usr/lib/python3.7/tempfile.py", line 265, in _sanitize_params
dir = gettempdir()
File "/usr/lib/python3.7/tempfile.py", line 433, in gettempdir
tempdir = _get_default_tempdir()
File "/usr/lib/python3.7/tempfile.py", line 349, in _get_default_tempdir
fp.write(b'blat')
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
_error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 488) is killed by signal: Bus error. It is possible that dataloader's workers are out of shared memory. Please try to raise your shared memory limit.

itamblyn@host free -m
              total        used        free      shared  buff/cache   available
Mem:          52329         716       41895           8        9717       51060
Swap:             0           0           0

itamblyn@host ipcs -lm

------ Shared Memory Limits --------
max number of segments = 4096
max seg size (kbytes) = 18014398509465599
max total shared memory (kbytes) = 18014398509481980
min seg size (bytes) = 1

nvidia-smi
Fri Aug 20 01:01:45 2021
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.73.01    Driver Version: 460.73.01    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   37C    P0    38W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
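
The RuntimeError above suggests that the DataLoader workers ran out of shared memory. `ipcs -lm` reports System V shared memory limits, whereas PyTorch workers on Linux typically pass tensors through POSIX shared memory backed by `/dev/shm`, so the size of that mount is probably the number worth checking; inside a container it is often far smaller than host RAM. The following is a minimal sketch for inspecting it before training, assuming a Linux host with `/dev/shm` mounted. The helper names `shm_free_bytes` and `pick_num_workers` and the 1 GiB threshold are hypothetical, chosen only for illustration.

```python
import os
import shutil


def shm_free_bytes(path="/dev/shm"):
    """Return free bytes on the shared-memory mount, or None if it is absent."""
    if not os.path.isdir(path):
        return None
    return shutil.disk_usage(path).free


def pick_num_workers(requested, free_bytes, min_free_bytes=1 << 30):
    """Hypothetical policy: use worker processes only if shm looks large enough."""
    if free_bytes is None or free_bytes < min_free_bytes:
        # Load batches in the main process; no worker shared memory is needed.
        return 0
    return requested


free = shm_free_bytes()
workers = pick_num_workers(requested=4, free_bytes=free)
print("free shm bytes:", free, "-> num_workers =", workers)
```

The resulting value could be passed as the `num_workers` argument of `torch.utils.data.DataLoader`. With `num_workers=0` loading is slower but does not depend on shared memory at all.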

92.5% precision

Thanks for your code. I got 92.5% precision on my first training run.

Can I use torch 1.6?

The requirements say the torch version is 1.0.1. Can I use torch 1.6 to train and test the network?

dataset

Thanks for your code.
I have 2 questions:

  1. Is make_dataset.py used in this code?
  2. mydataset.py -> OmniglotTrain -> loadToMem: even for the same character, different labels are assigned depending on the rotation. Is this right?
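
On the second question: treating each rotation of a character as its own class is a common augmentation for Omniglot, since a rotated glyph can be regarded as a new character. Below is a minimal sketch of such a labeling scheme, assuming four rotations and a nested alphabet/character layout. The function `assign_labels` and the sample names are hypothetical, and this is not the code from `mydataset.py`.

```python
def assign_labels(alphabets, rotations=(0, 90, 180, 270)):
    """Map every (rotation, alphabet, character) triple to a distinct class id."""
    labels = {}
    next_id = 0
    for degrees in rotations:
        for alphabet, characters in alphabets.items():
            for character in characters:
                labels[(degrees, alphabet, character)] = next_id
                next_id += 1
    return labels


# Hypothetical miniature dataset: 2 alphabets, 3 characters in total.
alphabets = {"Latin": ["character01", "character02"], "Greek": ["character01"]}
labels = assign_labels(alphabets)

# 3 characters x 4 rotations give 12 classes.
print(len(labels))
# The same glyph gets a different id at 0 and 90 degrees.
print(labels[(0, "Latin", "character01")], labels[(90, "Latin", "character01")])
```

Under this scheme a "same class" training pair must share both the character and the rotation, which quadruples the number of classes available for sampling pairs.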

python3 train.py gives the following error: line 23, in find_classes FileNotFoundError: [Errno 2] No such file or directory: 'background'

python3 train.py
gives the following error:
Traceback (most recent call last):
File "train.py", line 29, in
train_dataset = dset.ImageFolder(root=train_path)
File "/home/sumes/anaconda3/lib/python3.6/site-packages/torchvision-0.2.1-py3.6.egg/torchvision/datasets/folder.py", line 178, in init
File "/home/sumes/anaconda3/lib/python3.6/site-packages/torchvision-0.2.1-py3.6.egg/torchvision/datasets/folder.py", line 75, in init
File "/home/sumes/anaconda3/lib/python3.6/site-packages/torchvision-0.2.1-py3.6.egg/torchvision/datasets/folder.py", line 23, in find_classes
FileNotFoundError: [Errno 2] No such file or directory: 'background'
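
The traceback shows `ImageFolder` apparently being handed the relative path `background`, which does not exist in the working directory. The following is a minimal sketch of a check that fails early with a clearer message. `require_dir` is a hypothetical helper, not part of the repository or of torchvision.

```python
import os
import tempfile


def require_dir(path, flag_name):
    """Return the absolute path if it is a directory, else raise a clear error."""
    full = os.path.abspath(path)
    if not os.path.isdir(full):
        raise FileNotFoundError(
            "{}={!r} resolves to {!r}, which is not a directory".format(
                flag_name, path, full
            )
        )
    return full


# Demonstration with a temporary directory standing in for the dataset folder.
with tempfile.TemporaryDirectory() as tmp:
    print(require_dir(tmp, "--train_path"))

# A missing directory produces an error that names the flag and the full path.
try:
    require_dir("background-missing-example", "--train_path")
except FileNotFoundError as err:
    print(err)
```

Passing explicit `--train_path` and `--test_path` values that point at the unzipped `images_background` and `images_evaluation` folders, as in the run step of the README, should avoid the error.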
