티스토리 뷰

keep9oing

DRQN 구현

HTS3 2021. 2. 3. 14:08

Reference: arxiv.org/pdf/1507.06527.pdf


COMA 구현을 하다가 RNN을 포함하는 agent 업데이트를 해야해서 가장 기본적이라고 하는 DRQN을 구현 해봄.

Code

github.com/keep9oing/DRQN-Pytorch-CartPole-v1

 

keep9oing/DRQN-Pytorch-CartPole-v1

Deep recurrent Q learning on CartPole-v1 environment - keep9oing/DRQN-Pytorch-CartPole-v1

github.com

에러 제보 환영입니다. :)


POMDP (partially observable MDP)

대부분의 강화학습 문제는 MDP로 문제를 정의하고 최대 objective(reward, entropy 등)을 달성하는 agent를 학습하는 것을 목표로 한다. 풀고자 하는 문제가 MDP로써 잘 정의 된다면 좋겠지만, 실제로 RL을 적용해야하는 많은 문제들의 agent의 상황은 문제해결에 필요한 전체 state를 관측할 수 없는 입장이다. 이런 문제에 대한 좀더 일반적인 정의로써 POMDP개념이 제시됐다.

MDP와 그 variation들에 대해 알아보면 은근히 깊게 연구돼있는 것을 볼 수 있는데, 간단히 차이만 확인하자면 일반적은 MDP는 수학적으로

$$(\mathcal{S}, \mathcal{A}, \mathcal{P}, \mathcal{R})$$

이렇게 state, action, transition, reward 페어로 나타낼 수 있다. POMDP는 여기서 더 확장해서 observation에대한 정보를 추가한 문제 정의인데,

$$(\mathcal{S}, \mathcal{A}, \mathcal{P}, \mathcal{R}, \Omega, \mathcal{O})$$

이렇게 정의될 수 있다. 추가된 $\Omega, \mathcal{O}$ 는 각각, observation($\Omega$)과 observation이 특정 state에서 관측될 확률($\mathcal{O}$)을 뜻한다. 일반적으로 부분관측(partially observable)상황에서는 Q 값을 추정하기 아주 힘들다.

$$Q(o, a \mid \theta) \neq Q(s, a \mid \theta)$$

위 논문은 recurrent neural network(특히 LSTM)을 이용해 두 갭차이를 메우려는 시도를 했다.

 

POMDP Reference

[Udacity 강의]

www.youtube.com/watch?v=1TRoIuvWNuI

[POMDP 튜토리얼]
www.sciencedirect.com/science/article/pii/S0022249609000042

 

A tutorial on partially observable Markov decision processes

The partially observable Markov decision process (POMDP) model of environments was first explored in the engineering and operations research communiti…

www.sciencedirect.com

 

Stable Recurrent Udpates

DQN을 보면 각 action에 대한 한개의 transtion을 replay memory에 저장해서 샘플링 후 사용하는 알고리즘을 사용한다.

dnddnjs.gitbooks.io/rl/content/deep_q_networks.html

 

Deep Q Networks · Fundamental of Reinforcement Learning

 

dnddnjs.gitbooks.io

DRQN에서는 RNN을 이용해 agent를 업데이트 하기때문에 DQN의 메모리 저장 방식과 배치 샘플링 방식과 다른 방식으로 학습을 진행해야한다. DRQN 논문에서는 2가지 방식이 제안되었는데, 공통점은 일단 episode를 통째로 진행하고 episode들을 메모리에 저장한다는 것이다.

  1. Bootstrapped Sequential Updates
    • 저장된 episode 중 하나를 랜덤으로 선택하여 해당 episode에 대해 RNN을 rollout하고 그 값을 이용해 target Q, Q 신경망의 estimation 에러를 업데이트 한다.
    • 처음에 RNN을 어떤식으로 rollout하란건지 잘 몰라서 고생이 많았는데 pytorch의 경우 시퀀스에 initial hidden stae를 넣어주면 알아서 레이어가 rollout을 해주기 때문에 그냥 구할 수 있었다.
  2. Bootstrapped Random Updates
    • 저장된 episode중 batch size만큼을 랜덤으로 선택하는데 여기서 Sequential update와 다른 점은 전체 episode를 사용하는게 아니라 랜덤한 step에 대해서 원하는 time step만큼을 잘라서 이것들로 배치를 구성하여 업데이트를 하는 것이다.
  • Sequential update의 경우 episode의 길이가 각기 달라 batch 업데이트를 하지 않는것이 정신건강에 좋은 것같다.
  • Random upate의 경우 초반에 내가 설정한 time step 만큼 episode길이가 안나와서 해당 부분에 대해서 예외 처리를 잘 해줘야 해서 코드가 좀 복잡해졌다.
  • Rollout할 때 initial hidden state를 zero로 설정을 해줘야한다. (중요)
  • 논문에서는 Sequntial update는 DQN의 랜덤 샘플링 policy가정을 위반하고, Random update는 "zeroing initial hidden state"로 인해 학습이 어렵다고한다. 논문에서는 Random update만을 이용해 결과를 보여줌
  • 내가 실험한 Cartpole 환경에서는 sequential update가 좀 더 나았음.

학습 환경 설정 (POMDP of CartPole)

일반적으로 Openai gym Cartpole 환경에서는 아래와 같은 변수들을 state로 하여 학습을 진행한다.

카트의 위치/속도, 폴의 각도/각속도

이 모든 state를 다 이용할 수 있을때 Fully observable한 상황이고 MDP문제 정의에 대한 학습을 성공적으로 할 수 있다.

 

나는 POMDP상황을 만들어주기 위해 카트의 속도, 폴의 각속도를 배제한 Position, Angle만을 관측할 수 있도록 설정하여 학습을 진행했다.

 

학습 결과

텐서보드

  • DQN(Fully observe)[Orange]: 정상적으로 학습할 경우 최대 리워드를 달성 할 수 있다.
  • DQN(POMDP)[Blue]: 내가설정한 POMDP(위치/각도만 확인가능)한 상황에서는 전혀 학습을 할 기미를 보이지 않았다
  • DRQN(POMDP)[Red]: DQN(POMDP)에 비해 확실히 더 잘되는 모습을 보여준다. 종종 최대 리워드를 달성하는 모습도 보여준다. 하지만 역시 Partially observe환 환경에서 완벽한 대안책은 되지 않는듯 하다. 이부분은 논문에서 보여준 결과와도 어느정도 일치한다. 내 경우 Random update가 아니라 Sequential update를 사용했다.

 

논문의 결과


Comment

  • 논문에서 DRQN이 Fully observable MDP의 상황에서도 잘된다고 하는데, 나 또한 그것이 가능함을 확인함. 이 경우 Random update를 파라미터 조정하면 좀 더 좋은 결과가 나왔음 
  • 하지만 DRQN으로 확장함에 따라 내가 제어해야하는 하이퍼파라미터들이 더 다양해져서 설정이 더 힘들어지는 것은 단점. 나는 DQN과 최대한 동일한 하이퍼 파라미터를 통해 비교하려고 해봤다.
  • Random update의 경우 roll out step을 몇으로 하는지에 민감하게 반응하는 듯 함

해당 구현체의 모든 결과는 Cartpole에서만 해본것이기에 좀 더 다양한 환경에서 확인해봐야한다. 다만 논문의 결과를 재현해낸것에 만족하며 여기까지만 진행.

댓글
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2024/04   »
1 2 3 4 5 6
7 8 9 10 11 12 13
14 15 16 17 18 19 20
21 22 23 24 25 26 27
28 29 30