티스토리 뷰

editor, Junyeob Baek
Robotics Software Engineer /RL, Motion Planning and Control, SLAM, Vision

Linkedin Badge Github Badge


이번에는 얼마전 Github에 공개한 오픈소스 패키지에 대해 소개해보려한다. :)

 

policy distillation은 현재 연구중인 논문과 관련해 찾아보다가 유용하게 쓸 수 있겠다 싶어 자세히 공부하고 있었던 개념이다. 근데 생각보다 Github에 control task를 위한 policy distillation 모듈이 제대로 구현되어있는 repo가 없다는게 함정이다. 나름 DeepMind에서 나온 논문이고 쓸만하다고 생각하는 개념인데 인기가 생각보다 없나보다...T.T

 

policy distillation 검색결과.. 없어도 너무없다 ㅜ (만든지 2일 지난 내 패키지가 세번째에 똭..! 뿌듯)

 

 

어짜피 하던 연구를 진행하려면 제대로 된 policy distillation 모듈이 필요했고, 코드로 작동원리도 익혀볼 겸 간단히 프로토타이핑을 진행해보기로 했다.

 


 

이 때 참고한 오픈소스는 총 두가지인데,

첫번째 패키지인 Mee321/policy-distillation는 README에 제대로 된 가이드가 없어 일단 받아보고 전체 구조를 파악해보기로했다.

 

Mee321/policy-distillation

Contribute to Mee321/policy-distillation development by creating an account on GitHub.

github.com

 

근데 웬걸 생각보다 코드가 너무 좋았다.

기본적인 distillation 모듈은 물론 학습방법이나 loss function을 위한 옵션들도 갖추고 있었다.

(ray를 이용한 다중 학습환경까지 되어있었으니 말 다했다..)

Manual과 몇 개의 사진이나 영상만 있었어도 좋았을텐데..!

 

사실 그냥 이 모듈 그대로 연구주제에 적용해보기 시작하면 조금 더 빨리 할 수도 있었는지만..

조금은 욕심이 났다. 처음부터 baseline으로 학습시킨 어려운 control task를 distillation해도 잘 작동할지 궁금했다.


두번째 패키지인 DLR-RM/rl-baselines3-zoo은 독일 항공우주센터 DLR-RM에서 진행된 프로젝트인 stable_baselines 시리즈의 pytorch버전이다. 해당 패키지는 baselines3 라이브러리를 통해 학습시키고 실행시켜볼 수 있는 환경을 제공한다.

 

DLR-RM/rl-baselines3-zoo

A collection of pre-trained RL agents using Stable Baselines3, training and hyperparameter optimization included. - DLR-RM/rl-baselines3-zoo

github.com

해당 패키지로부터 잘 학습된 강화학습 에이전트와 이를 작동시킬 환경을 얻을 수 있었다.

 

자 이제 policy distillation 모듈과 합쳐주기만 하면 된다!

말은 쉽지만.. 간단히 인터페이스만 하면 될 줄 알았는데 생각보다 할 일이 많았다..
기본적으로 사용하고 있던 환경 자체가 달랐기 때문에 sample data를 수집하는 방법... 학습할때 생기는 문제... policy 형태가 다를 때 생기는 문제들.. 학습된 에이전트의 알고리즘 종류 때문에 생기는 문제들.. 할많하않.. TL;DR..

 


 

그리고 드디어..! 약 7일간의 고군분투 끝에.. 드디어 전체적인 시스템이 완성되었다.

(약간 그림이 깨지는 것 같으니 본 패키지에서 확인바람)

github.com/CUN-bjy/policy-distillation-baselines

 

CUN-bjy/policy-distillation-baselines

Pytorch Implementation of Policy Distillation for control, which has well-trained teachers via stable_baselines3. - CUN-bjy/policy-distillation-baselines

github.com

overview of the repo 'policy-distillation-baselines'

Flow)

1. Teacher(trained agent) policy로부터 sample data를 취득한다. (환경을 실제로 돌려 다양한 state/action쌍을 수집)

2. Teacher로부터 수집한 Experience Data와 같은 states를 넣었을 때 Student(distilled agent) policy로부터 계산된 action을 비교(KL loss)한다. 이 차이를 이용해 Student policy network을 학습시킨다.

3. 2번을 충분히 반복한다.

4. Evaluation을 위해 Student 모델을 환경에서 돌려본 후 평균 reward를 구해주며 Teacher모델과 비교한다.

5. 충분한 학습을 위해 3,4번을 반복하며, 일정 주기로 1번을 다시 실행시켜 sample에 대한 overfitting을 피해준다.

6. 충분히 reward가 비슷한 수준까지 학습되면 학습 종료.

 

policy distillation을 통한 policy 경량화 및 주입(?) 


위 데모영상은 Gym 환경에서 TD3알고리즘을 이용해 Ant task를 학습한 trained agent와 policy distillation 알고리즘으로 경량화한 distilled agent이다.

 

무려 400x300사이즈에서 64x64사이즈로 줄였는데도 단 100초(600 iter)만에 훌륭히 policy를 전달한다.. ㄷㄷ

(실제로 average rewards가 3200대로 비슷하게 나와 성능손실이 전혀없음)


policy distillation의 장점이라면 빠르게 모델의 메모리와 연산량을 크게 줄이도록 경량화 할 수 있다는 것이며 성능적으로 손실이 거의 없다는 것에 있다.

그리고 모델간 압축이 가능하듯이 각 task를 학습한 teacher를 하나의 student에 경량화하여 밀어넣는게 가능하다..(학생 한명이 국영수를 각 선생님한테 배우듯이..)

 

이러한 부분 덕분에 policy distillation은 최근 RL에서 주목하고 있는 continual learning과 multi-task learning 등에서 다루고 있는 문제에 대한 방법들 중 하나로서 종종 언급된다.

 

특히 Robotics 분야에서는 쉽게 다양한 알고리즘의 결합을 할 수 있다는 점과 On-board 연산을 위한 경량화에 적합해 매우 매력적이라고 생각된다!

 

 

 

해당 프로젝트에 대해 충분히 설명하기 위해 Policy Distillation 자체에 대한 리뷰를 별도 작성하도록 하겠습니다!

Links

repo : github.com/CUN-bjy/policy-distillation-baselines

 

CUN-bjy/policy-distillation-baselines

Pytorch Implementation of Policy Distillation for control, which has well-trained teachers via stable_baselines3. - CUN-bjy/policy-distillation-baselines

github.com

 

댓글
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2024/04   »
1 2 3 4 5 6
7 8 9 10 11 12 13
14 15 16 17 18 19 20
21 22 23 24 25 26 27
28 29 30