티스토리 뷰
네이버 deep-text-recognition 모델을 custom data로 학습 & 아키텍쳐 분석
hanyangrobot 2020. 5. 18. 19:22작성자 : 한양대학원 융합로봇시스템학과 유승환 석사과정 (CAI LAB)
네이버 Clova AI팀에서 연구한 OCR 딥러닝 모델을 custom data로 학습하는 과정을 정리해보겠습니다~!
* 2021년 3월 8일자 기준으로 내용 보완 중 입니다. (현재 6. 코드 분석 보완 중)
[필자 PC 환경]
OS : Ubuntu 18.04.03 LTS (네이버 클로바 공식 깃헙에서는 16.04로 진행함) & Window 10
그래픽 카드 : GTX 1080 Ti (Ubuntu) & RTX 2070 (Window)
CUDA : 10.1 (Ubuntu) & 10.2 (Ubuntu)
cuDNN : 7.5.0 (Ubuntu) & 7.6.5 (window)
python : 3.6.9 (Ubuntu 18.04의 default 값) & 3.6.12(Window)
pytorch : 1.3.1 (Ubuntu) & 1.8.0(Window)
torchvision : 0.4.2 (Ubuntu) & 0.9.0(Window)
[error 1] : lmdb.MapFullError: mdb_put: MDB_MAP_FULL: Environment mapsize limit reached
- 해결책 : create_lmdb_dataset.py에서 line 40의 env 변수에 map_size를 자신의 데이터 크기에 맞게 늘리기
- 원인 : lmdb.open 함수의 map_size의 default 값은 10485760으로, 데이터량이 10MB가 넘어가면 메모리 부족 문제 발생
- 필자는 데이터 용량이 300~500 MB여서 아래와 같이 수정함
- ex) env = lmdb.open(outputPath, map_size = 500000000)
- 참고 링크 : yochin47.blogspot.com/2016/10/mdbmapfull-environment-mapsize-limit.html
링크 0 (논문 원본) : https://arxiv.org/pdf/1904.01906.pdf
링크 1 (논문 리뷰) : https://ropiens.tistory.com/23
링크 2 (깃헙 코드) : https://github.com/clovaai/deep-text-recognition-benchmark
0. 딥러닝 환경 셋팅 (우분투) : https://ropiens.tistory.com/34
1. 딥러닝 환경 셋팅 (window) : ykarma1996.tistory.com/99
1. 추가적인 환경 설치
- 저는 Anaconda 환경에서 설치를 진행하였습니다.
# anaconda env 생성
conda create -n ocr python=3.6
# install pytorch (UBUNTU)
conda install torch==1.3.1 torchvision==0.4.2
# install pytorch (window)
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
# install lib for OCR
conda install -c anaconda pillow nltk natsort
conda install -c conda-forge fire python-lmdb
# install cv2 (선택)
conda install -c conda-forge opencv
# git을 설치 안한 경우 (Ubuntu)
sudo apt-get install git
# download Naver-OCR CODE
git clone https://github.com/clovaai/deep-text-recognition-benchmark.git
2. Demo 실행 (테스트용)
- pre-trained model(TPS-ResNet-BiLSTM-Attn.pth)을 다운받고, deep-text-recognition-benchmark 폴더 안에 넣기
- window에서는 python3 명령어 대신 python으로 하면 됩니다.
# OCR 폴더로 이동 (ubuntu)
cd ~/deep-text-recognition-benchmark
# window
cd deep-text-recognition-benchmark
# demo.py 실행 (UBUNTU)
CUDA_VISIBLE_DEVICES=0 python3 demo.py --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction Attn --image_folder demo_image/ --saved_model TPS-ResNet-BiLSTM-Attn.pth
# window
python demo.py --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction Attn --image_folder demo_image/ --saved_model TPS-ResNet-BiLSTM-Attn.pth
3. 커스텀 데이터셋 제작
- (3-1) deep-text-recognition-benchmark 폴더 안에 'data'와 'data_lmdb' 폴더 생성 (폴더 이름은 아무거나 해도 상관 없습니다.)
- (3-2) data, data_lmdb 폴더 안에 'training', 'validation' 폴더 생성 (폴더 이름은 크게 상관 없습니다. 두 폴더 중에 어느 것이 학습용이고 검증용인지만 결정하면 됩니다.)
- (3-3) training, validation 폴더 안에 학습 및 검증용 이미지 넣기 (추천 데이터 수 비율 = 6(학습) : 2(검증): 2(테스트))
- (3-4) data 폴더 안에 train_gt.txt, valid_gt.txt 파일 생성 후 라벨링하기 (이미지 경로와 라벨링 사이에 space가 아니라 꼭 tab을 넣어야 합니다!!!)
# 라벨링 하는 방법
# {imagepath}\tab{label}\n
# 라벨링 예시
A.jpg(\t)apple(\n)
B.jpg maplestroy
C.jpg hanyang@
D.jpg 안녕하세요
- (3-5) 이미지와 라벨링 텍스트 파일을 lmdb 형태로 변환 (window는 python으로 입력하기)
# create train_lmdb file
python3 create_lmdb_dataset.py --inputPath data/training --gtFile data/train_gt.txt --outputPath data_lmdb/training
# create valid_lmdb file
python3 create_lmdb_dataset.py --inputPath data/validation --gtFile data/valid_gt.txt --outputPath data_lmdb/validation
- 아래와 같은 내용이 나오면 테스트 성공!!
4. ~/deep-text-recognition-benchmark/train.py 코드 수정
- if __name__ == '__main__': 문 아래에 있는 add_argument들 수정
- 학습할 때 터미널 창에서 명령어 변수를 추가해도 되지만, 저는 코드 수정이 편해서 이렇게 진행합니다!
- --batch_size에서 default=192의 값을 자신의 그래픽카드 환경에 맞게 수정하기
- 저는 V-Ram이 부족해서 batch_size를 32로 수정했습니다.
- --select_data에서 default='MJ-ST'를 '/' 로 수정
- 우리는 OCR 학습데이터로 유명한 MJ, ST 데이터가 아니라, 커스텀 데이터셋을 사용하므로 그 경로에 맞게 수정
- --batch_ratio에서 default='0.5-0.5'를 '1' 로 수정
- 커스텀 데이터셋 1종류를 사용하므로, 1 사용
- --character에서 default='0123456789abcdefghijklmnopqrstuvwxyz'를 '(학습할 문자들)'로 수정
- 저의 경우 한글과 특수문자도 학습해야해서, default='@#$0123456789abcdefghijklmnopqrstuvwxyz한글대한민국만세' 으로 수정 (한글은 무조건 가, 나, 다 와 같이 자음+모음인 한 단어로 작성해야합니다.)
- --valInterval 인자
- 이 인자에 설정된 값만큼 해당 에포크에 검증 결과가 나옵니다. 처음에는 2000으로 설정되어 있어서, 2000에포크마다 결과가 나오는데, 저는 100 에포크마다 보고 싶어서 값을 100으로 설정했습니다.
5. 모델 학습
- 다음의 명령어를 터미널 창에 입력하면 학습이 진행됩니다!
- saved_models 폴더에서 log_train.txt, opt.txt 파일에서 학습 관련 로그를 확인할 수 있습니다!
- Window 에서 Multi GPU로 학습을 진행하는 경우, "TypeError : Can't Pickle Environment'라는 에러가 뜰 수 있습니다. 이 때에는 인자 --workers 0 으로 설정하면 됩니다. (train.py에서 주석 참고)
# Ubuntu
CUDA_VISIBLE_DEVICES=0 python3 train.py --train_data data_lmdb/training --valid_data data_lmdb/validation --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction Attn
# Window
python train.py --train_data data_lmdb/training --valid_data data_lmdb/validation --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction Attn
6. 코드 분석 (아키텍쳐 분석)
6-1) train.py (모델 학습 코드)
(1) 최적화 함수 (Optimizer)
default는 Adadelta 입니다. 그리고 추가적인 인자를 넣어서 Adam으로 변경할 수 있습니다. 각 알고리즘의 자세한 설명은 구글링이 훨씬 이해하기 좋을 것입니다! 추가적인 인자를 넣는 방법은 2가지가 있습니다. 첫 번째는 위의 train.py 명령어 뒤에 --adam을 추가하는 것입니다. 두 번째는 코드 자체에서 optimizer를 adam으로 설정하는 것입니다.
(2) Loss 함수
Loss 함수는 Prediction 모듈에 따라서 달라집니다. 먼저 Prediction 모듈을 CTC(Connectionist Temporal Classificaton)로 설정했을 경우, CTCLoss라는 CTC 전용 Loss 함수를 사용합니다. 이 함수는 torch의 기본 nn 라이브러리에서 제공됩니다. (CTCLoss함수의 코드 분석은 CTC 및 Attn 모듈을 분석한 다음에 작성하겠습니다.)
(만약 opt 인자로 baiduCTC를 추가하면, warpctc_pytorch에서 제공하는 CTCLoss 함수로 변경됩니다. 이 함수도 분석하고 싶지만, warpctc_pytorch 라이브러리 코드를 못찾은 관계로... 나중에 기회가 되면 분석해보겠습니다.)
Prediction 모듈을 Attn(Attention-based Sequence Prediction)으로 설정하면, Loss 함수는 우리가 흔히 알고 있는 CrossEntropyLoss로 변경됩니다.
왜 CTC에서는 CTC Loss를 사용하고, Attn은 CrossEntropyLoss를 사용하는지는 분석을 못했습니다. 이것도 추후 공부하면 업로드 하겠습니다.
7-2) modules 분석 (모델 모듈 코드)
(1) transformation.py (Trans : NONE | TPS)
TPS는 제가 아직은 관심이 없는 모듈이기에... OCR을 전부 분석한 다음에 업로드하겠습니다.
(2) feature_extraction.py (Feat : VGG | RESNET)
이미지의 feature를 추출하는 부분으로 VGG, RCNN, ResNet ,GRCL(Gated RCNN)으로 총 4종류가 있습니다. Feature를 추출하는 Backbone에 대한 설명은 주가 아니므로 설명을 생략하겠습니다.
(3) sequence_modeling.py (Seq : None | BiLSTM)
LSTM도 구글링을 해서 이해하는 것이 더 빠를 것입니다! BiLSTM과 LSTM의 차이점은 미래에서 과거로의 학습이 가능하냐의 차이점입니다. 코드에서는 nn.LSTM의 인자에 bidirectional=True만 추가하면 BiLSTM이 구현됩니다ㅋㅋ BiLSTM에 대한 설명이 잘 된 블로그가 있어서 아래에 첨부하겠습니다.
* BiLSTM이란? : intelligence.korea.ac.kr/members/wschoi/nlp/deeplearning/Bidirectional-RNN-and-LSTM/
(4-1) prediction.py (Pred : CTC | Attn)
제가 가장 중요하다고 생각되는 모듈입니다. prediction.py에는 Attention에 관한 코드 밖에 없는데요, CTC 및 Attn의 자세한 구현은 utils.py에 class CTCLabelConverter와 class AttnLabelConverter에 작성되어있습니다. 코드 분석을 하기 전에, 이론을 먼저 공부해보겠습니다.
(4-2) CTC : Connectionist Temporal Classificaton
내용 입력
* CTC 참고 링크 (영문) : distill.pub/2017/ctc/
OCR 딥러닝 실습 끝!! 여기까지 따라오느라 고생하셨습니다.
'sinanju06 > 딥러닝 환경 셋팅 및 코드 분석' 카테고리의 다른 글
NVIDIA Jeston 환경 셋팅 3편 (ROS, RealSense-SDK, RealSense-ROS설치) (16) | 2020.11.17 |
---|---|
NVIDIA Jeston 환경 셋팅 1-1편 (JetPack 설치 On AGX Xavier) (8) | 2020.10.30 |
YOLO V5 환경 셋팅 및 모델 아키텍쳐 분석하기 (104) | 2020.08.18 |
한글 + 영문 데이터 제작 with python-trdg (4) | 2020.05.25 |
Ubuntu 18.04 환경에서 CUDA 10.1, cuDNN 7.5.0, NVIDIA-Driver 그리고 Pytorch 1.3.1 설치하기 (9) | 2020.05.18 |