amazon-science / semimtr-text-recognition

Multimodal Semi-Supervised Learning for Text Recognition (SemiMTR)

License: Apache License 2.0

Python 92.57% Jupyter Notebook 7.43%
computer-vision consistency-regularization contrastive-learning ocr scene-text-recognition semi-supervised-learning text-recognition deep-learning pytorch self-supervised-learning

semimtr-text-recognition's Introduction

Multimodal Semi-Supervised Learning for Text Recognition

The official code implementation of SemiMTR: Paper | Pretrained Models | SeqCLR Paper | Citation | Demo.

Aviad Aberdam, Roy Ganz, Shai Mazor, Ron Litman

We introduce a multimodal semi-supervised learning algorithm for text recognition, which is customized for modern vision-language multimodal architectures. To this end, we present a unified one-stage pretraining method for the vision model, which suits scene text recognition. In addition, we offer a sequential, character-level, consistency regularization in which each modality teaches itself. Extensive experiments demonstrate state-of-the-art performance on multiple scene text recognition benchmarks.

Figures

Figure 1: SemiMTR vision model pretraining: Contrastive learning

Figure 2: SemiMTR model fine-tuning: Consistency regularization

Getting Started

Run Demo with Pretrained Model (Open In Colab)

Dependencies

  • Inference and the demo require PyTorch >= 1.7.1
  • For training and evaluation, install the dependencies:
pip install -r requirements.txt
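When checking the PyTorch requirement programmatically, note that comparing version strings as text would rank 1.13 below 1.7. A minimal sketch of a numeric comparison follows; the helper names `parse_version` and `meets_minimum` are hypothetical and not part of this repository.

```python
import re

def parse_version(version: str) -> tuple:
    """Return the leading numeric components, e.g. '1.13.1+cu117' -> (1, 13, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))

def meets_minimum(installed: str, minimum: str = "1.7.1") -> bool:
    """Compare versions numerically, component by component (hypothetical helper)."""
    return parse_version(installed) >= parse_version(minimum)
```

With PyTorch installed, `meets_minimum(torch.__version__)` would perform the check against the installed build; local suffixes such as `+cu117` are ignored by the parser.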

Pretrained Models

Download pretrained models:

Pretrained vision models:

Pretrained language model:

To fine-tune SemiMTR without running the vision and language pretraining yourself, place the above models in a workdir directory, as follows:

workdir
├── semimtr_vision_model_real_l_and_u.pth
├── abinet_language_model.pth
└── semimtr_real_l_and_u.pth
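Before launching a run, it can help to confirm that the layout above is in place. The following is a small sketch, not part of the repository; the function name `missing_checkpoints` is hypothetical, and the file names are copied from the listing above.

```python
from pathlib import Path

# File names copied from the directory listing above.
EXPECTED_CHECKPOINTS = (
    "semimtr_vision_model_real_l_and_u.pth",
    "abinet_language_model.pth",
    "semimtr_real_l_and_u.pth",
)

def missing_checkpoints(workdir="workdir"):
    """Return the expected checkpoint files that are absent from `workdir`."""
    root = Path(workdir)
    return [name for name in EXPECTED_CHECKPOINTS if not (root / name).is_file()]
```

An empty list from `missing_checkpoints()` means all three files were found.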

SemiMTR Models Accuracy

Training Data     IIIT   SVT    IC13   IC15   SVTP   CUTE   Avg.   COCO   RCTW   Uber   ArT    LSVT   MLT19  ReCTS  Avg.
Synth (ABINet)    96.4   93.2   95.1   82.1   89.0   89.2   91.2   63.1   59.7   39.6   68.3   59.5   85.0   86.7   52.0
Real-L+U          97.0   95.8   96.1   84.7   90.7   94.1   92.8   72.2   76.1   58.5   71.6   77.1   90.4   92.4   65.4
Real-L+U+Synth    97.4   96.8   96.5   84.7   92.9   95.1   93.3   73.0   75.7   58.6   72.4   77.5   90.4   93.1   65.8
Real-L+U+TextOCR  97.3   97.7   96.9   86.0   92.2   94.4   93.7   73.8   77.7   58.6   73.5   78.3   91.3   93.3   66.1

Datasets

  • Download the preprocessed LMDB dataset for training and evaluation. Link
  • For training the language model, download WikiText103. Link
  • The final structure of the data directory can be found in DATA.md.

Training

  1. Pretrain vision model
    CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config configs/semimtr_pretrain_vision_model.yaml
    
  2. Pretrain language model
    CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config configs/pretrain_language_model.yaml
    
  3. Train SemiMTR
    CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config configs/semimtr_finetune.yaml
    

Note:

  • You can set the checkpoint paths of the vision and language models separately to load specific pretrained models, or set them to None to train from scratch.
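The three training stages above can also be chained from a script. This is a sketch under assumptions: the function names `stage_commands` and `run_stages` are hypothetical, and it assumes the checkpoint paths inside the configs already point at the outputs of the earlier stages. The config paths themselves are copied from the steps above.

```python
import os
import subprocess

# Config paths copied from the three training steps above.
STAGE_CONFIGS = (
    "configs/semimtr_pretrain_vision_model.yaml",
    "configs/pretrain_language_model.yaml",
    "configs/semimtr_finetune.yaml",
)

def stage_commands(configs=STAGE_CONFIGS):
    """Build one `python main.py --config ...` argument list per stage."""
    return [["python", "main.py", "--config", config] for config in configs]

def run_stages(gpus="0,1,2,3", dry_run=True):
    """Print each stage command; execute them in order unless dry_run is set."""
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=gpus)
    for command in stage_commands():
        print(" ".join(command))
        if not dry_run:
            # check=True stops the pipeline at the first failing stage.
            subprocess.run(command, env=env, check=True)
```

Calling `run_stages()` only prints the commands; `run_stages(dry_run=False)` would execute them from the repository root.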

Training ABINet

  1. Pretrain vision model
    CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config configs/abinet_pretrain_vision_model.yaml
    
  2. Pretrain language model
    CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config configs/pretrain_language_model.yaml
    
  3. Train ABINet
    CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config configs/abinet_finetune.yaml
    

Evaluation

CUDA_VISIBLE_DEVICES=0 python main.py --config configs/semimtr_finetune.yaml --run_only_test

Arguments:

  • --checkpoint /path/to/checkpoint: path of the model to evaluate
  • --test_root /path/to/dataset: path of the evaluation dataset
  • --model_eval [alignment|vision]: which sub-model to evaluate
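To make the flag semantics concrete, here is an illustrative `argparse` parser that mirrors the documented options. It is a sketch only and is not the parser used by `main.py`; the defaults, help texts, and the example checkpoint path are assumptions.

```python
import argparse

def build_eval_parser():
    """Hypothetical parser mirroring the documented flags; not the repository's code."""
    parser = argparse.ArgumentParser(description="Evaluation flags (illustrative)")
    parser.add_argument("--config", required=True, help="path to the YAML config")
    parser.add_argument("--run_only_test", action="store_true",
                        help="skip training and only evaluate")
    parser.add_argument("--checkpoint", default=None, help="path of the model to evaluate")
    parser.add_argument("--test_root", default=None, help="path of the evaluation dataset")
    parser.add_argument("--model_eval", choices=["alignment", "vision"], default=None,
                        help="which sub-model to evaluate")
    return parser

# Example values only; the checkpoint path is taken from the workdir listing.
args = build_eval_parser().parse_args([
    "--config", "configs/semimtr_finetune.yaml",
    "--run_only_test",
    "--checkpoint", "workdir/semimtr_real_l_and_u.pth",
    "--model_eval", "vision",
])
print(args.checkpoint, args.model_eval, args.run_only_test)
```

Restricting `--model_eval` with `choices` makes a misspelled sub-model name fail at parse time rather than midway through evaluation.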

Citation

If you find our method useful for your research, please cite

@article{aberdam2022multimodal,
  title={Multimodal Semi-Supervised Learning for Text Recognition},
  author={Aberdam, Aviad and Ganz, Roy and Mazor, Shai and Litman, Ron},
  journal={arXiv preprint arXiv:2205.03873},
  year={2022}
}

@inproceedings{aberdam2021sequence,
  title={Sequence-to-sequence contrastive learning for text recognition},
  author={Aberdam, Aviad and Litman, Ron and Tsiper, Shahar and Anschel, Oron and Slossberg, Ron and Mazor, Shai and Manmatha, R and Perona, Pietro},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={15302--15312},
  year={2021}
}

Acknowledgements

This implementation is based on the repository ABINet.

Security

See CONTRIBUTING for more information.

License

This project is licensed under the Apache-2.0 License.

Contact

Feel free to contact us with any questions: Aviad Aberdam

semimtr-text-recognition's Contributors

aaberdam, amazon-auto

semimtr-text-recognition's Issues

Why is the loss still calculated during evaluation? An IndexError is also reported

[2022-10-09 21:10:13,443 main.py:283 INFO consistency-regularization] Construct dataset.
[2022-10-09 21:10:13,478 main.py:129 INFO consistency-regularization] 4485536 training items found.
[2022-10-09 21:10:13,478 main.py:131 INFO consistency-regularization] 62944 valid items found.
[2022-10-09 21:10:13,478 main.py:133 INFO consistency-regularization] 147209 test items found.
[2022-10-09 21:10:13,478 main.py:289 INFO consistency-regularization] Construct model.
[2022-10-09 21:10:13,789 model_vision.py:38 INFO consistency-regularization] Read vision model from workdir/semimtr_vision_model_real_l_and_u.pth.
[2022-10-09 21:10:33,487 model_language.py:38 INFO consistency-regularization] Read language model from workdir/abinet_language_model.pth.
[2022-10-09 21:10:33,537 main.py:292 INFO consistency-regularization] Construct learner.
[2022-10-09 21:10:33,597 main.py:301 INFO consistency-regularization] Start testing
Traceback (most recent call last):
File "main.py", line 306, in <module>
main()
File "main.py", line 302, in main
test_on_each_ds(learner)
File "/media/disk4/flbl/cdsme/semimtr/semimtr/utils/test.py", line 19, in test_on_each_ds
last_metrics = learner.validate(dl=dl)
File "/home/lbl/miniconda3/envs/semimtr/lib/python3.7/site-packages/fastai/basic_train.py", line 391, in validate
val_metrics = validate(self.model, dl, self.loss_func, cb_handler)
File "/home/lbl/miniconda3/envs/semimtr/lib/python3.7/site-packages/fastai/basic_train.py", line 59, in validate
val_loss = loss_batch(model, xb, yb, loss_func, cb_handler=cb_handler)
File "/home/lbl/miniconda3/envs/semimtr/lib/python3.7/site-packages/fastai/basic_train.py", line 30, in loss_batch
loss = loss_func(out, *yb)
File "/home/lbl/miniconda3/envs/semimtr/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/media/disk4/flbl/cdsme/semimtr/semimtr/losses/consistency_regularization_loss.py", line 73, in forward
pt_lengths_teacher, *args[2:], mask=threshold_mask)
File "/home/lbl/miniconda3/envs/semimtr/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/media/disk4/flbl/cdsme/semimtr/semimtr/losses/losses.py", line 62, in forward
gt_labels, gt_lengths = gt_dict['label'], gt_dict['length']
IndexError: too many indices for tensor of dimension 3

I ran !pip install -r requirements.txt, adapted the config to my setup, and then ran into this error.

[2023-08-26 09:23:35,166 main.py:283 INFO pretrain-language-model] Construct dataset.
/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
warnings.warn(_create_warning_msg(
[2023-08-26 09:23:36,980 main.py:84 INFO pretrain-language-model] 102085 training items found.
[2023-08-26 09:23:36,980 main.py:86 INFO pretrain-language-model] 50000 valid items found.
[2023-08-26 09:23:36,980 main.py:289 INFO pretrain-language-model] Construct model.
[2023-08-26 09:23:41,078 main.py:292 INFO pretrain-language-model] Construct learner.
[2023-08-26 09:23:46,050 main.py:296 INFO pretrain-language-model] Start training.
<IPython.core.display.HTML object>
<IPython.core.display.HTML object>
[2023-08-26 09:23:46,569 callbacks.py:179 INFO pretrain-language-model] Train ended
[2023-08-26 09:23:46,569 callbacks.py:143 INFO pretrain-language-model] average data time = 0.0000s, average running time = 0.0000s
<IPython.core.display.HTML object>
<IPython.core.display.HTML object>
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/fastai/basic_train.py", line 99, in fit
for xb,yb in progress_bar(learn.data.train_dl, parent=pbar):
File "/usr/local/lib/python3.10/dist-packages/fastprogress/fastprogress.py", line 39, in __iter__
if self.total != 0: self.update(0)
File "/usr/local/lib/python3.10/dist-packages/fastprogress/fastprogress.py", line 59, in update
self.update_bar(0)
File "/usr/local/lib/python3.10/dist-packages/fastprogress/fastprogress.py", line 81, in update_bar
self.on_update(val, f'{pct}[{val}/{tot} {elapsed_t}{self.lt}{remaining_t}{end}]')
File "/usr/local/lib/python3.10/dist-packages/fastprogress/fastprogress.py", line 134, in on_update
elif self.parent is not None: self.parent.show()
File "/usr/local/lib/python3.10/dist-packages/fastprogress/fastprogress.py", line 177, in show
self.out.update(HTML(self.html_code))
AttributeError: 'NoneType' object has no attribute 'update'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/content/drive/.shortcut-targets-by-id/12Fpyqv7ac7VkBnm5mHDZrlJ-pRtYx53q/AIO/Competition_Module/BKAI/semimtr-text-recognition/main.py", line 306, in <module>
main()
File "/content/drive/.shortcut-targets-by-id/12Fpyqv7ac7VkBnm5mHDZrlJ-pRtYx53q/AIO/Competition_Module/BKAI/semimtr-text-recognition/main.py", line 297, in main
learner.fit(epochs=config.training_epochs,
File "/usr/local/lib/python3.10/dist-packages/fastai/basic_train.py", line 200, in fit
fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)
File "/usr/local/lib/python3.10/dist-packages/fastai/basic_train.py", line 112, in fit
finally: cb_handler.on_train_end(exception)
File "/usr/local/lib/python3.10/dist-packages/fastai/callback.py", line 323, in on_train_end
self('train_end', exception=exception)
File "/usr/local/lib/python3.10/dist-packages/fastai/callback.py", line 251, in __call__
for cb in self.callbacks: self._call_and_update(cb, cb_name, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/fastai/callback.py", line 241, in _call_and_update
new = ifnone(getattr(cb, f'on_{cb_name}')(**self.state_dict, **kwargs), dict())
File "/content/drive/.shortcut-targets-by-id/12Fpyqv7ac7VkBnm5mHDZrlJ-pRtYx53q/AIO/Competition_Module/BKAI/semimtr-text-recognition/semimtr/callbacks/callbacks.py", line 180, in on_train_end
self._eval_model()
File "/content/drive/.shortcut-targets-by-id/12Fpyqv7ac7VkBnm5mHDZrlJ-pRtYx53q/AIO/Competition_Module/BKAI/semimtr-text-recognition/semimtr/callbacks/callbacks.py", line 146, in _eval_model
last_metrics = self._validate()
File "/content/drive/.shortcut-targets-by-id/12Fpyqv7ac7VkBnm5mHDZrlJ-pRtYx53q/AIO/Competition_Module/BKAI/semimtr-text-recognition/semimtr/callbacks/callbacks.py", line 62, in _validate
val_metrics = validate(self.learn.model, dl, self.loss_func, cb_handler)
File "/usr/local/lib/python3.10/dist-packages/fastai/basic_train.py", line 57, in validate
for xb,yb in progress_bar(dl, parent=pbar, leave=(pbar is not None)):
File "/usr/local/lib/python3.10/dist-packages/fastprogress/fastprogress.py", line 39, in __iter__
if self.total != 0: self.update(0)
File "/usr/local/lib/python3.10/dist-packages/fastprogress/fastprogress.py", line 59, in update
self.update_bar(0)
File "/usr/local/lib/python3.10/dist-packages/fastprogress/fastprogress.py", line 81, in update_bar
self.on_update(val, f'{pct}[{val}/{tot} {elapsed_t}{self.lt}{remaining_t}{end}]')
File "/usr/local/lib/python3.10/dist-packages/fastprogress/fastprogress.py", line 133, in on_update
if self.display: self.out.update(HTML(self.progress))
AttributeError: 'NoneType' object has no attribute 'update'

Question about training time.

Very nice work! I noticed in your implementation details that ST is used for training, with 25 epochs and a batch size of 304. I would like to ask:

  1. Is this batch size the total across 4 GPUs, meaning the batch size on each V100 is 76?

  2. How long does it take to train one epoch?

I have also tried to train ABINet. However, each epoch seemed to take a very long time.

Image Size in One Batch.

Hello, thanks for your great work. I have some questions about the inputs for the SeqCLR model. Should the images in the same batch have the same size (especially the same width)? Since the text in images varies in length, images with longer text may be distorted if all images are scaled to the same size. If the images are padded instead, some frames in the sequence feature will lose their semantics. So how do you preprocess the image size within one batch?
Looking forward to your response. Thanks very much!
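The trade-off raised in this question can be quantified with simple arithmetic. The sketch below compares the two options the question contrasts; it does not describe how this repository preprocesses images, and the 32x128 target size and both function names are assumptions chosen for illustration.

```python
def stretch_to_fixed(width, height, target_height=32, target_width=128):
    """Option 1: resize every image to one fixed size.

    Returns the horizontal distortion factor; 1.0 means the aspect ratio is preserved.
    """
    scale_y = target_height / height
    scale_x = target_width / width
    return scale_x / scale_y

def resize_and_pad(width, height, target_height=32, target_width=128):
    """Option 2: keep the aspect ratio, then pad on the right.

    Returns (resized_width, pad_width); the resized width is clipped to target_width.
    """
    resized_width = min(target_width, max(1, round(width * target_height / height)))
    return resized_width, target_width - resized_width

# A short word (64x32) is stretched 2x by option 1, or half-padded by option 2.
print(stretch_to_fixed(64, 32), resize_and_pad(64, 32))
# A long line (512x32) is squeezed to a quarter of its width by option 1.
print(stretch_to_fixed(512, 32), resize_and_pad(512, 32))
```

Under these assumptions, padding leaves the right half of the 64x32 image empty, which is the "frames losing their semantics" concern, while stretching changes the character shapes instead.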

NaN in input tensor

The language model fails with "NaN or Inf found in input tensor."
train.txt
Can you help me understand why it fails to train on non-English characters?

The shape of projector's input

Hi, I'm a little confused about the forward function in SeqCLR. In line 59 of seqclr_proj.py, the output of the visual backbone is reshaped as # (N, E, H, W) -> (N, H*W, E), but in OCR tasks it is usually reshaped as # (N, E, H, W) -> (N, W, E*H). The paper also notes that "the sequence length depends on the width of the input image." So what is the right input shape for the projector? Thanks!
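For reference, the two reshaping conventions mentioned in this question differ in both sequence length and per-frame feature size. A small shape-bookkeeping sketch in plain Python follows; the function names and the example dimensions (N=2, E=512, H=8, W=32) are hypothetical and not taken from the repository.

```python
def flatten_all_positions(n, e, h, w):
    """(N, E, H, W) -> (N, H*W, E): one frame per spatial position."""
    return (n, h * w, e)

def flatten_columns(n, e, h, w):
    """(N, E, H, W) -> (N, W, E*H): one frame per image column."""
    return (n, w, e * h)

# Hypothetical backbone output: batch 2, 512 channels, 8x32 feature map.
print(flatten_all_positions(2, 512, 8, 32))  # sequence length H*W, feature size E
print(flatten_columns(2, 512, 8, 32))        # sequence length W, feature size E*H
```

Both layouts hold the same number of values, and for a fixed feature-map height the sequence length grows with the width in either case.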

Training stops after 5 epochs during SemiMTR fine-tuning

A problem occurred when trying to train SemiMTR:

Training automatically stops after 5 epochs, no matter what number of epochs I set in the config file 'configs/semimtr_finetune.yaml'.

However, if I set the number of epochs to less than 5, everything seems to work fine.

Evaluation results?

Hello, thank you for your contribution. When I evaluate the downloaded model, how do I see the accuracy on each dataset? What do ccr, cwr, ted, ned, and ted/w mean?
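One plausible reading of these abbreviations, which is an assumption and not confirmed by this page, is: character correct rate (ccr), correct word rate (cwr), total edit distance (ted), normalized edit distance (ned), and total edit distance per word (ted/w). A self-contained sketch of how such metrics are commonly computed is shown below; the repository's exact definitions may differ, and `text_metrics` is a hypothetical name.

```python
def edit_distance(a: str, b: str) -> int:
    """Levenshtein distance via the classic dynamic-programming recurrence."""
    previous = list(range(len(b) + 1))
    for i, char_a in enumerate(a, start=1):
        current = [i]
        for j, char_b in enumerate(b, start=1):
            cost = 0 if char_a == char_b else 1
            current.append(min(previous[j] + 1,          # deletion
                               current[j - 1] + 1,       # insertion
                               previous[j - 1] + cost))  # substitution
        previous = current
    return previous[-1]

def text_metrics(predictions, targets):
    """Assumed definitions: ccr/cwr are accuracies, ted/ned are summed edit distances."""
    correct_chars = total_chars = correct_words = 0
    ted = 0
    ned = 0.0
    for pred, gt in zip(predictions, targets):
        correct_chars += sum(p == g for p, g in zip(pred, gt))  # position-wise matches
        total_chars += len(gt)
        correct_words += int(pred == gt)
        distance = edit_distance(pred, gt)
        ted += distance
        ned += distance / max(len(pred), len(gt), 1)
    words = len(targets)
    return {"ccr": correct_chars / total_chars, "cwr": correct_words / words,
            "ted": ted, "ned": ned, "ted/w": ted / words}

print(text_metrics(["hello", "wor1d"], ["hello", "world"]))
```

Under these assumed definitions, higher is better for ccr and cwr, while lower is better for ted, ned, and ted/w.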

imgaug augmenters

AttributeError: module 'imgaug.augmenters' has no attribute 'MultiplyBrightness'
train.txt

I tried to check from which path each of these files imports imgaug. I couldn't find the highlighted packages.
Screen Shot 2022-10-11 at 10 39 57 PM
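To see which installation a module is actually imported from and whether it defines a given name, a generic diagnostic like the one below can help. It is a sketch with a hypothetical helper name, `describe_attribute`. If the installed imgaug turns out to predate the release that introduced `MultiplyBrightness`, upgrading it would be the likely fix, but that is an assumption to verify against the imgaug changelog.

```python
import importlib

def describe_attribute(module_name, attribute):
    """Report where a module is imported from, its version, and whether it defines `attribute`."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return {"installed": False}
    # The version string usually lives on the top-level package.
    top_level = importlib.import_module(module_name.split(".")[0])
    return {
        "installed": True,
        "path": getattr(module, "__file__", None),
        "version": getattr(top_level, "__version__", None),
        "has_attribute": hasattr(module, attribute),
    }

# For the error above, one would call:
# describe_attribute("imgaug.augmenters", "MultiplyBrightness")
print(describe_attribute("math", "sqrt"))
```

The reported `path` shows which site-packages directory the import resolves to, which is useful when several environments are installed side by side.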

Question about frame-to-instance implementation

Hi, thanks for your great work!
When I ran the code, I had a question about the following line:

features = features.permute(0, 2, 3, 1).flatten(1, 2) # (N, E, H, W) -> (N, H*W, E)

It seems that this implementation flattens the feature map row by row, so the subsequent average pooling runs over concatenated rows. In the original paper, however, the average pooling is column-wise, which would correspond to:
features = features.permute(0, 3, 2, 1).flatten(1, 2)
I am not sure whether this difference matters.
Thank you~
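The difference between the two permutations can be made concrete by listing which (row, column) position each frame comes from after `permute(...).flatten(1, 2)`. The sketch below does this index bookkeeping in plain Python so it runs without PyTorch; `frame_order` and `pooled_windows` are hypothetical helpers, and the pooling is assumed to use equal, non-overlapping windows over the sequence.

```python
def frame_order(h, w, permutation):
    """List the (row, column) position of each frame after permute + flatten(1, 2).

    (0, 2, 3, 1): (N, E, H, W) -> (N, H, W, E) -> (N, H*W, E), rows are the outer loop.
    (0, 3, 2, 1): (N, E, H, W) -> (N, W, H, E) -> (N, W*H, E), columns are the outer loop.
    """
    if permutation == (0, 2, 3, 1):
        return [(row, col) for row in range(h) for col in range(w)]
    if permutation == (0, 3, 2, 1):
        return [(row, col) for col in range(w) for row in range(h)]
    raise ValueError("unsupported permutation")

def pooled_windows(order, num_windows):
    """Group consecutive frames into equal windows (assumes len(order) % num_windows == 0)."""
    size = len(order) // num_windows
    return [order[i * size:(i + 1) * size] for i in range(num_windows)]

for perm in [(0, 2, 3, 1), (0, 3, 2, 1)]:
    print(perm, pooled_windows(frame_order(2, 4, perm), 4))
```

For a 2x4 feature map pooled into 4 windows, the first permutation groups neighbouring positions within a row, while the second groups each column. When H is 1 the two orders coincide, so the choice only matters for feature maps with more than one row.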
