wangyuxin87 / VisionLAN
A PyTorch implementation of "From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network" (ICCV 2021)
How do you occlude the characters in the images? Which tools did you use? Could you please explain?
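For reference, the kind of operation I imagine (a minimal sketch that fills a character's bounding box with the mean image color; the helper and the box format are my own assumption, not necessarily the authors' method):

import cv2
import numpy as np

def occlude_char(img: np.ndarray, box: tuple) -> np.ndarray:
    """Fill one character's bounding box (x, y, w, h) with the mean image color."""
    x, y, w, h = box
    out = img.copy()
    out[y:y + h, x:x + w] = img.mean(axis=(0, 1)).astype(img.dtype)
    return out

# e.g. occluded = occlude_char(cv2.imread('word.jpg'), (10, 5, 20, 30))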
Hello author, the ideas in your paper have inspired me a lot!
I would like to ask two questions:
1. The training procedure described in the paper contains only the language-free and language-aware stages, which correspond to LF_2 and LA in the code. However, the code adds an extra LF_1 stage that specifically pretrains the backbone + VRM part, and in the LF_2 stage the optimizer applies a different lr to the params already trained in LF_1 (see the sketch below). How much difference is there between the two training schemes LF_1 --> LF_2 --> LA and LF_2 --> LA?
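A minimal sketch of the per-group lr setup I am referring to (backbone / vrm / head are placeholder attribute names, not the ones actually used in this repo):

import torch
import torch.nn as nn

# toy stand-in model; 'backbone' / 'vrm' / 'head' are placeholder names
model = nn.Module()
model.backbone = nn.Linear(8, 8)
model.vrm = nn.Linear(8, 8)
model.head = nn.Linear(8, 37)

# LF_1-pretrained parameters get a smaller lr than freshly initialized ones
pretrained_ids = {id(p) for m in (model.backbone, model.vrm) for p in m.parameters()}
optimizer = torch.optim.Adam([
    {'params': [p for p in model.parameters() if id(p) in pretrained_ids], 'lr': 1e-5},
    {'params': [p for p in model.parameters() if id(p) not in pretrained_ids], 'lr': 1e-4},
])
print([len(g['params']) for g in optimizer.param_groups])  # [4, 2]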
2. I don't quite understand the implementation logic of the attention. Here I compare it with the attention process in the visual part (PVAM) of SRN:
(a) The attention process in SRN-PVAM (pseudocode, assuming q, k, and v all have dimension d_model):
import torch
import torch.nn as nn
import torch.nn.functional as F

# e.g. d_model = 512, max_seq_len = seq_len_q = 25, vocab_size = 37
d_model, max_seq_len, vocab_size = 512, 25, 37
encoder_out = torch.randn(4, 256, d_model)  # stand-in visual features: (batch_size, seq_len_k, d_model)
key2att = nn.Linear(d_model, d_model)
query2att = nn.Linear(d_model, d_model)
embedding = nn.Embedding(max_seq_len, d_model)
score = nn.Linear(d_model, 1)
classifier = nn.Linear(d_model, vocab_size)
# input is encoder_out
reading_order = torch.arange(max_seq_len, dtype=torch.long)
Q = embedding(reading_order)  # (max_seq_len, d_model)
K = V = encoder_out  # (batch_size, seq_len_k, d_model)
# Computing att_weight here is easy to follow -- it is the same additive attention
# used by classic attention models such as ASTER:
######
att_q = query2att(Q).unsqueeze(0).unsqueeze(2)  # (1, seq_len_q, 1, d_model)
att_k = key2att(K).unsqueeze(1)  # (batch_size, 1, seq_len_k, d_model)
att_weight = score(torch.tanh(att_q + att_k)).squeeze(3)  # (batch_size, seq_len_q, seq_len_k)
######
att_weight = F.softmax(att_weight, dim=-1)
decoder_out = torch.bmm(att_weight, V)  # (batch_size, seq_len_q, d_model)
logits = classifier(decoder_out)  # (batch_size, seq_len_q, vocab_size)
(b) The attention process in VisionLAN:
# e.g. d_model = 512, max_seq_len = seq_len_q = 25, vocab_size = 37
# (same imports and sizes as in (a); seq_len_k is the number of visual positions)
seq_len_k = 256
embedding = nn.Embedding(max_seq_len, d_model)
w0 = nn.Linear(max_seq_len, seq_len_k)
wv = nn.Linear(d_model, d_model)
we = nn.Linear(d_model, max_seq_len)
classifier = nn.Linear(d_model, vocab_size)
# input is encoder_out
K = V = encoder_out  # (batch_size, seq_len_k, d_model)
reading_order = torch.arange(max_seq_len, dtype=torch.long)
# How should the att_weight computation below be understood?
######
reading_order = embedding(reading_order)  # (seq_len_q, d_model)
reading_order = reading_order.unsqueeze(0).expand(K.size(0), -1, -1)  # (batch_size, seq_len_q, d_model)
t = w0(reading_order.permute(0, 2, 1))  # (batch_size, d_model, seq_len_q) ==> (batch_size, d_model, seq_len_k)
t = torch.tanh(t.permute(0, 2, 1) + wv(K))  # (batch_size, seq_len_k, d_model)
att_weight = we(t)  # (batch_size, seq_len_k, d_model) ==> (batch_size, seq_len_k, seq_len_q)
att_weight = att_weight.permute(0, 2, 1)  # (batch_size, seq_len_q, seq_len_k)
######
att_weight = F.softmax(att_weight, dim=-1)
decoder_out = torch.bmm(att_weight, V)  # (batch_size, seq_len_q, d_model)
logits = classifier(decoder_out)  # (batch_size, seq_len_q, vocab_size)
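My tentative reading: (b) is still an additive attention, just with the steps reordered. w0 first linearly mixes the reading-order (query) embeddings across the position dimension, producing one d_model vector per key position; after the tanh with wv(K), we then emits the scores for all seq_len_q queries at once. In (a), by contrast, each (query, key) pair is scored independently through score(...). Is this interpretation correct, and is there a reason to prefer the VisionLAN formulation?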
Looking forward to your reply, thank you!
Hi, I really want to test the code. How will I know when I can get it?
Where is the paper?
Hi, I cannot use the Baidu Netdisk downloads and the other links have expired. Could you create new links and update datasets.md?
https://github.com/HCIILAB/Scene-Text-Recognition-Recommendations/blob/main/datasets.md
Thanks!
Thanks for your excellent contributions!
I tried to use your pre-trained LF_2 model to visualize the mask map. I picked the same image shown in the "Visualization character-wise mask map (P=0)" section of the README, but got a different result from yours. I resized the mask map to the original height and width and added it to the original image, with this result:
Then I simply resized the mask map to the original height and width and visualized it directly, with this result:
Both visualizations differ from yours; is something wrong?
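For reference, this is roughly how I produced the overlay (a sketch; I assume the mask is a single-channel float map and the image is a uint8 BGR array):

import cv2
import numpy as np

def overlay_mask(img: np.ndarray, mask: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """Resize a single-channel float mask to the image size and blend it as a heat map."""
    mask = cv2.resize(mask.astype(np.float32), (img.shape[1], img.shape[0]))
    mask = (255 * (mask - mask.min()) / (np.ptp(mask) + 1e-6)).astype(np.uint8)
    heat = cv2.applyColorMap(mask, cv2.COLORMAP_JET)
    return cv2.addWeighted(img, 1.0 - alpha, heat, alpha, 0)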
This is excellent work. As the title says, I would like to cite this paper; could you provide the citation? Thanks!
The cha_encdec class in utils.py encodes characters that are not in the dictionary as len(self.dict) + 1, which makes the CrossEntropyLoss in the training program raise an error.
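A minimal reproduction of that failure mode (CrossEntropyLoss and NLLLoss require every target to lie in [0, n_classes - 1]):

import torch
import torch.nn as nn

n_classes = 37  # e.g. a 36-char dict + 1
criterion = nn.CrossEntropyLoss()
logits = torch.randn(1, n_classes)

print(criterion(logits, torch.tensor([36])))  # fine: valid targets are 0 .. n_classes - 1
try:
    criterion(logits, torch.tensor([37]))  # len(self.dict) + 1 for a 36-char dict -> out of range
except (IndexError, RuntimeError) as e:
    print('invalid target:', e)  # on GPU this surfaces as a device-side assert instead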
How many images do you use from each of the benchmark datasets IIIT5K, SVT, IC13, IC15, SVTP, and CUTE80? I am confused about the 4832 images mentioned in your paper.
Thank you!
Dear yuxin,
Sorry to bother you again. While using your code I ran into two new questions: 1. When I executed python train_LF_1.py, I got a CUDA error in ClassNLLCriterion.cu. 2. When I modified the code for Chinese training, the model could not converge.
The error from ClassNLLCriterion.cu:
THCudaCheck FAIL file=/pytorch/aten/src/THC/generic/THCTensorMath.cu line=29 error=710 : device-side assert triggered
/pytorch/aten/src/THCUNN/ClassNLLCriterion.cu:108: cunn_ClassNLLCriterion_updateOutput_kernel: block: [0,0,0], thread: [12,0,0] Assertion `t >= 0 && t < n_classes` failed.
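(A side note for anyone hitting this: device-side asserts are raised asynchronously, so the Python traceback can point at an unrelated op. Forcing synchronous kernel launches makes the error surface at the real call site:)

import os
# equivalently: CUDA_LAUNCH_BLOCKING=1 python train_LF_1.py
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # must be set before CUDA is initialized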
The problem seems like a bug (my PyTorch version is 1.7.1). When I changed nclass from 37 to 38, the problem went away. I think 38 is reasonable: 36 normal chars, 1 <EOS>, and 1 <unknown>. I modified these two lines in VisionLAN.py:
71: self.Prediction = Prediction(n_position=256, N_max_character=26, n_class=37) # N_max_character = 1 eos + 25 characters
72: self.nclass = 37
I also modified the code for Chinese training, but the model does not converge; the loss drops very slowly. Did you change the training config when training on TRW15?
Could you provide the Chinese training dataset and model used in the paper?
Hello, at https://github.com/wangyuxin87/VisionLAN/blob/main/train_LF_2.py#L74-L76 different parameters are selected by id. When I print them, the first group is empty, i.e. no parameters match id_total. Is something wrong here?
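For context, my understanding of the pattern those lines use is roughly the following (a sketch, not the repo's exact code); the group comes out empty when the id list is built from the parameters of a different module instance than the one being filtered:

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))

# ids collected from the SAME instance that is later filtered -> group is non-empty
id_total = list(map(id, model[0].parameters()))
group_a = [p for p in model.parameters() if id(p) in id_total]

# ids collected from a freshly constructed (different) instance -> group is empty
id_other = list(map(id, nn.Linear(4, 4).parameters()))
group_b = [p for p in model.parameters() if id(p) in id_other]

print(len(group_a), len(group_b))  # 2 0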
Thank you for the code. After running the command CUDA_VISIBLE_DEVICES=0 python eval.py, the following error is reported:
Traceback (most recent call last):
  File "eval.py", line 11, in <module>
    import cfgs.cfgs_eval as cfgs
  File "/workspace/VisionLAN-main/cfgs/cfgs_eval.py", line 5, in <module>
    from data.dataset_scene import *
  File "/workspace/VisionLAN-main/data/dataset_scene.py", line 16, in <module>
    from transforms import CVColorJitter, CVDeterioration, CVGeometry
ModuleNotFoundError: No module named 'transforms'
It looks like an environment problem, but my environment was configured strictly following requirements.txt. What could be going on? I am eagerly looking forward to your answer.
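In case it helps others: the error just means Python cannot find a top-level transforms module. If the file actually lives at data/transforms.py (an assumption on my part), either of the following should work:

# option 1: make the import package-qualified in data/dataset_scene.py
from data.transforms import CVColorJitter, CVDeterioration, CVGeometry

# option 2: add the data directory to the search path before importing
import sys
sys.path.append('./data')
from transforms import CVColorJitter, CVDeterioration, CVGeometry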
Hello! This is excellent work that has inspired me a lot ;D
As I read your code, I found something strange (maybe a typo):
line145: text_pre, test_rem, text_mas, att_mask_sub = model(data, label_id, cfgs.global_cfgs['step'])
Should "test_rem" be modified to "text_rem"?
Thanks for your code. My evaluation results differ from those posted in README.md: I ran the provided models with eval.py and got zero accuracy on all datasets. I used the test datasets and models that you provided on Ruike.
The command I executed:
python eval.py
The result I get:
------Average on 6 benchmarks--------
test accuracy:
Accuracy: 0.000000, AR: 0.815073, CER: 0.184927, WER: 1.000000, best_acc: 0.000000
------IIIT--------
test accuracy:
Accuracy: 0.000000, AR: 0.821556, CER: 0.178444, WER: 1.000000, best_acc: 0.000000
------IC13--------
test accuracy:
Accuracy: 0.000000, AR: 0.846731, CER: 0.153269, WER: 1.000000, best_acc: 0.000000
------IC15--------
test accuracy:
Accuracy: 0.000000, AR: 0.790097, CER: 0.209903, WER: 1.000000, best_acc: 0.000000
------SVT--------
test accuracy:
Accuracy: 0.000000, AR: 0.831195, CER: 0.168805, WER: 1.000000, best_acc: 0.000000
------SVTP--------
test accuracy:
Accuracy: 0.000000, AR: 0.797383, CER: 0.202617, WER: 1.000000, best_acc: 0.000000
------CUTE--------
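Since AR is around 0.8 while word accuracy is exactly 0, the predictions themselves look reasonable and the failure is probably in the string comparison, e.g. label case or an EOS marker left in the prediction. A minimal check of what I mean (hypothetical strings; '$' as an assumed EOS symbol, check what eval.py actually uses):

def word_match(pred: str, gt: str) -> bool:
    """Case-insensitive word comparison after cutting at an EOS marker."""
    pred = pred.split('$')[0]
    return pred.lower() == gt.lower()

print(word_match('HELLO$', 'hello'))  # True, although 'HELLO$' != 'hello'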
Hope for your response~
Hi, thanks for your wonderful work. I am trying to train the model on my own dataset but am getting this error. Can you help me with this?
Thanks for sharing the code and data of your amazing work! I tried to download the OST dataset from both Baidu and Ruike, but I could not; a login is required. Could you please upload it to Google Drive or share a link that I can download from directly?
Hello, Table 1 in the paper lists the ST-VQA image set as 4000 images, but each of the test_task1/2/3.json files I downloaded contains more than 4000 images. How exactly were the ST-VQA images filtered?
Which version of Access should be used to open the .mdb files of databases such as SynthText and MJSynth? When I open them with Access 2021, it reports an unrecognized database format.
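Note: those .mdb files are almost certainly LMDB databases (a key-value store commonly used to package scene-text-recognition datasets), not Microsoft Access files, which is why Access cannot open them. A minimal reader, assuming the usual CRNN-style num-samples / label-%09d / image-%09d key layout:

import lmdb  # pip install lmdb

env = lmdb.open('path/to/MJSynth', readonly=True, lock=False, readahead=False)
with env.begin(write=False) as txn:
    num_samples = int(txn.get(b'num-samples'))          # total sample count
    label = txn.get(b'label-%09d' % 1).decode('utf-8')  # label of sample 1
    image_bytes = txn.get(b'image-%09d' % 1)            # encoded image bytes of sample 1
print(num_samples, label)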
Many thanks for the author's reply!
Hello, thank you for your work; it has been very inspiring. I have a small question about the VRM in the paper: what exactly is its structure? Is it only a stack of transformer encoder layers, or is parallel attention also applied after the transformer?
Looking forward to your reply!
There is no password for the test dataset (Ruike); I hope this gets fixed soon.