Pytorchを使って深層強化学習のモデルDQNを構築する 〜Deep Reinforcement Learning〜

スポンサーリンク

 
囲碁や将棋のコンピュータって強いですね。初期レベルでも全然勝てなくて、何度待ったをしたことか。
このようなゲームでは、ある手段を選択すると、状態が変化し、次の状態に移り、再び手段の選択をするということを繰り返し、最終的な勝ち負けが決まります。そして、取りうる手段と状態は有限です。まぁ、人間にとっては無限に感じますが。取りうる状態と手段が有限ならば、状態に合わせた選択手を学習して覚えればよいんじゃねぇということで、Deep Learning の出番なわけです。


これは、深層強化学習というテーマであり(ネタ的にはもう古いかもしれませんが)、ある状態に対して、選択した手段で状態が遷移し、それに合わせて報酬が得られるという環境において、最善の手段の組み合わせを求めるという強化学習を Deep Learning でやっています。強化学習の世界に真打ち登場といった感じでしょうか。


今回は、深層強化学習に一大ブームをもたらした Deep Q-Network 通称 DQN を構築します。Pytorch のチュートリアルであるREINFORCEMENT LEARNING (DQN) TUTORIALを参考にさせてもらい、自分なりにコードを再構築しました。今回のコードはこちら
に置いてあります。今回の実装によって、以下のようにプレイする DQN を構築できます。


f:id:dskomei:20211005131427g:plain:w400




準備


パラメータを保存しておく用の Python ファイルと必要なモジュールをインポートします。パラメータを保存するファイルを別途用意しておくことで、パラメータの管理がしやすくなります。


パラメータ保存用ファイル「settings.py」

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

BATCH_SIZE_TRAIN = 128
BATCH_SIZE_VALID = 128
GAMMA = 0.999
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200

EPOCHS = 1000



必要なモジュールのインポート


今回使用するモジュールで重要なものは「gym」と「torch」です。「gym」はゲームのシミュレーターを立ち上げるために使います。また「torch」は、 Deep Learning の構築で使用します。どちらも pip を使えば簡単にインストールできます。他のモジュールも pip で簡単に入れられます。「settings」は上に記載した内容のファイルを読み込んでいます。

import gym
import math
import random
import math
import copy
import numpy as np
import pandas as pd
import time
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from collections import namedtuple, deque
from itertools import count
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

import settings



ゲーム環境の構築


今回扱うゲームは OpenAI Gym の CartPole です。このゲームは、カートの上に棒が立てられており、棒が倒れきらないようにカートを動かして棒のバランスを整えるゲームです。1ステップごとに棒が倒れなければ、+1 の報酬が与えられ、棒が倒れた時点でゲームが終了し、-1 の得点が与えられます。つまり、棒を倒さないようにカートを左右に動かしながら、どれだけ長く保てるかを競うゲームです。


DQNの実装にすぐ入ってき行きたいのですが、ゲームのシミュレーション環境がないとどうしようもないので、先にゲーム環境を作ります。とはいっても、ライブラリが充実しているので、天下り的にコードを実行していけば、OKです。

env = gym.make('CartPole-v0').unwrapped


上記のコードによって、変数「env」に CartPole の必要な情報が入りました。以下のコードを実行すると、CartPole のウィンドウが立ち上がります。ここでは、コードだけを記すとして、このコード内の詳細に関しては、こちらをご覧ください。

def get_screen(env, resize):
    
    screen = env.render(mode='rgb_array').transpose((2, 0, 1))
    _, screen_height, screen_width = screen.shape
    screen = screen[:, int(screen_height * 0.4): int(screen_height * 0.8)]
    
    view_width = int(screen_width * 0.6)
    cart_location = get_cart_location(env=env, screen_width=screen_width)
    
    if cart_location < view_width // 2:
        slice_range = slice(view_width)
    elif cart_location > (screen_width - view_width // 2):
        slice_range = slice(-view_width, None)
    else:
        slice_range = slice(cart_location - view_width // 2, cart_location + view_width // 2)

    screen = screen[:, :, slice_range]
    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
    screen = torch.from_numpy(screen)

    return resize(screen).unsqueeze(0)

def get_cart_location(env, screen_width):
    world_width = env.x_threshold * 2
    scale = screen_width / world_width
    return int(env.state[0] * scale + screen_width / 2.0)

resize = T.Compose([
    T.ToPILImage(),
    T.Resize(40, interpolation=Image.CUBIC),
    T.ToTensor()
])

env.reset()
screen = get_screen(env=env, resize=resize)

f:id:dskomei:20211001183808p:plain:w500


DQNのモデルを構築する


今回構築するモデルでは、事前に学習データがあるわけではなく、Deep Learning が自らゲームをプレイし、学習データをためていきます。そして、その学習データでゲームをプレイしている Deep Learning を学習させていきます。強化学習ということもあって、ここらへんの流れは通常の Deep Learning の学習フローからすると特殊なので、アルゴリズムを先に見ておきます。


DQNのアルゴリズムを概観


ゲームは、膨大な種類の状態に対して最適な手の組み合わせを選択するものであり、状態の数が多ければ多いほど人間の思考能力では手に負えなくなってしまいます。ただ、状態に対する手の組み合わせにはパターンがあります。そのため、強化学習に Deep Learning を組み合わせれば人間を凌駕できるのではないかと考えられました。下の図のように、ある状態を入力として、Deep Learning にそれぞの手の価値を出力させるようにし、価値が最大である手を選択し、ゲームを進めます。


f:id:dskomei:20211005082335p:plain:w600


DQN では、手を選択する Deep Learning の学習データとして必要なデータは、自分自身が選択した手とそのときの状態、報酬、次の状態になります。そして、予測する値は、その状態で選択できる手それぞれに対して、手を選択したときのゲーム終了までの獲得価値(≒累積報酬)です。このとき、\( 状態_ t \) に対して出力する価値を \( \mathbf {Q}_t \) 、\( 状態_ t \) であるアクションを選んで得られた報酬を \(r_t\) とすると、次式が成り立ちます。


\[ \mathbf {Q}_t = r_t + \gamma \mathbf {Q}_{t + 1} \]


これは、ベルマン方程式と呼ばれます。この式の説明に関しては、多くの記事がすでにあるので、ここでは省かせてもらいます。この式で注目するのは、\( \mathbf {Q}_{t+1} \) です。これは次の状態の価値であり、これがわかっていれば Deep Learning の出力値の教師データとして使えるようになるのですが、学習中の Deep Learning を使ってその値を求めてそれを教師データとしてしまうと、 自分で自分を教えるという形になるため安定しなくなります。そこで、\( \mathbf {Q}_{t+1} \) を教えてくれる教師的な Deep Learning を使います。ただし、完璧な \( \mathbf {Q}_{t+1} \) を知っている Deep Learning はありません。もしすでにあるならば、最初からこのネットワークを使ってプレイすればよいわけです。なので、学習中の Deep Learning を適宜コピーして教師 Deep Learning とします。つまり、教師 Deep Learning も学習されていきます。


f:id:dskomei:20211005094203p:plain:w600


DQNのネットワーク構造


ゲームから得られる状態の情報は、画像情報であるため、今回構築する Deep Learning は CNN です。出力は、選択手の数分の実数値(価値)になるようにしています。

class DQN(nn.Module):
    
    def __init__(self, h, w, output_dim):
        
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(32)
        
        convw = self.conv2d_size_out(self.conv2d_size_out(self.conv2d_size_out(w)))
        convh = self.conv2d_size_out(self.conv2d_size_out(self.conv2d_size_out(h)))
        linear_input_size = convw * convh * 32
        self.head = nn.Linear(linear_input_size, output_dim)
        
    def forward(self, x):
        x = x.to(settings.device)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        return self.head(x.view(x.size(0), -1))
    
    def conv2d_size_out(self, size, kernel_size=5, stride=2):
        return (size - (kernel_size - 1) - 1) // stride + 1



学習データの保存


学習データは、直近の一定数分を保存しておくようにするため、「deque」を使っています。「capacity」 で指定した以上のデータ数になった場合は、古いデータから消されていきます。

class ReplayMemory(object):
    
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)
        
    def push(self, Transition, *args):
        self.memory.append(Transition(*args))
        
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
    
    def __len__(self):
        return len(self.memory)



手の選択


CNN を使って状態に対して選択手それぞれの価値を求め、最も高い価値の手を選択します。ただし、これでは選択する手が偏ってしまうため、ランダムに手を選択する処理を追加し、選択手が偏らないようにもします。その際に、ステップ数が増えるにつれランダム選択を採用しないようにし、学習の序盤ではランダム選択を重視しつつ、終盤では CNN の選択手を重視するようにしています。

def select_action(policy_net, state, n_action, steps_done):
    threshold = settings.EPS_END + (settings.EPS_START - settings.EPS_END) * math.exp(-1. * steps_done / settings.EPS_DECAY)
    if random.random() > threshold:
        with torch.no_grad():
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        return torch.tensor([[random.randrange(n_action)]], device=settings.device, dtype=torch.long)



DQNの学習


これまでの処理で、DQN を構築するためのコードができたので、学習データを作る関数と学習部分の関数を作り、実際に DQN の学習を行います。


学習データを作る関数と最適化関数


下のコードは、学習データと損失値確認用データを作るための関数です。「Transition」は後ほど定義がでてきますが、「namedtuple」型の変数です。名前付きで学習データを保存しておくことで、アクセスしやすいようにしています。
namedtupleで美しいpythonを書く!(翻訳) - Qiita


def make_train_and_valid_data(memory, Transition):
    
    transitions = memory.sample(settings.BATCH_SIZE_TRAIN + settings.BATCH_SIZE_VALID)
    indexes = list(range(settings.BATCH_SIZE_TRAIN + settings.BATCH_SIZE_VALID))
    random.shuffle(indexes)
    transitions_train = [transitions[i] for i in indexes[:settings.BATCH_SIZE_TRAIN]]
    transitions_valid = [transitions[i] for i in indexes[settings.BATCH_SIZE_TRAIN:]]
    batch_train = Transition(*zip(*transitions_train))
    batch_valid = Transition(*zip(*transitions_valid))

    return batch_train, batch_valid


「policy_net」に \( \mathbf{Q}_t \) を、「target_net」に \( \mathbf{Q}_{t+1} \) を予測させています。\( 状態_{t+1} \) がゲーム終了の時(「next_state」が None の時)、 \( \mathbf{Q}_t \) は0になるように学習しています。その他のときは、\( r + \gamma \mathbf{Q}_{t+1} \) になるように学習しています。損失値確認用関数は、最適化されないだけで学習用とほぼ同じ処理の流れです。


def train_model(policy_net, target_net, batch, Transition, criterion):

    policy_net.train()

    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=settings.device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    state_action_values = policy_net(state_batch).gather(1, action_batch)
    
    next_state_values = torch.zeros(settings.BATCH_SIZE_TRAIN, device=settings.device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    expected_state_action_values = (next_state_values * settings.GAMMA) + reward_batch

    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()

    return loss

def valid_model(policy_net, target_net, batch, Transition, criterion):

    policy_net.eval()
    with torch.no_grad():

        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=settings.device, dtype=torch.bool)
        non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        state_action_values = policy_net(state_batch).gather(1, action_batch)

        next_state_values = torch.zeros(settings.BATCH_SIZE_VALID, device=settings.device)
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
        expected_state_action_values = (next_state_values * settings.GAMMA) + reward_batch

        loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    return loss



DQNの学習


これまでに定義した関数と DQN のクラスを使って、学習させていきます。1回のゲームが終了するまで学習データをため、ゲームが終了するたびに学習させています。教師となる Deep Learning の重みは、学習中の Deep Learning で最も良い結果となったものにしています。ただし、100回更新がない場合は、その時点での学習中の Deep Learning の重みをコピーし、教師 Deep Learning が全く更新されなくなることを防いでいます。Deep Learning の入力は、現在の画像から前の画像を引いた差分です。これは、カートと棒がどのように動いていたかを画像情報からわかるようにするためです。

policy_net = DQN(h=screen_height, w=screen_width, output_dim=n_action).to(settings.device)
target_net = DQN(h=screen_height, w=screen_width, output_dim=n_action).to(settings.device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

memory = ReplayMemory(50000)

criterion = nn.SmoothL1Loss()
optimizer = optim.Adam(policy_net.parameters())
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))
steps_done = 0
best_net = None
best_loop = 0
counter = 1
results = []

for i_episode in range(settings.EPOCHS):

    env.reset()
    last_screen = get_screen(env=env, resize=resize)
    current_screen = get_screen(env=env, resize=resize)
    state = current_screen - last_screen
    
    for t in count():
        
        start_time = time.time()
        # DQNを使ったアクションの選択
        action = select_action(
            policy_net=policy_net, state=state, n_action=n_action, steps_done=steps_done
        )
        
        # 選択されたアクションで状態遷移し、報酬を得る
        _, reward, done, _ = env.step(action.item())
        reward = torch.tensor([reward], device=settings.device)
        
        # 状態遷移後の画像差分を計算
        last_screen = current_screen
        current_screen = get_screen(env=env, resize=resize)
        if not done:
            next_state = current_screen - last_screen
        else:
            next_state = None

        memory.push(Transition, state, action, next_state, reward)

        state = next_state
        
        if done:
            break

    steps_done += 1

    if len(memory) >= settings.BATCH_SIZE_TRAIN + settings.BATCH_SIZE_VALID:
        batch_train, batch_valid = make_train_and_valid_data(memory=memory, Transition=Transition)

        loss_train = train_model(
            policy_net=policy_net, target_net=target_net, batch=batch_train,
            Transition=Transition, criterion=criterion
        )
        loss_valid = train_model(
            policy_net=policy_net, target_net=target_net, batch=batch_valid,
            Transition=Transition, criterion=criterion
        )
        elapsed_time = time.time() - start_time
        print('[{}/{} max loop: {}] train loss: {:.4f}, valid loss: {:.4f} [{}{:.0f}s] {}{}'.format(
            i_episode, settings.EPOCHS, t,
            loss_train, loss_valid,
            str(int(math.floor(elapsed_time / 60))) + 'm' if math.floor(elapsed_time / 60) > 0 else '',
            elapsed_time % 60,
            counter,
            ' **' if t > best_loop else ''
        ))

        results.append([i_episode, t, loss_train.item(), loss_valid.item()])
            
    if t > best_loop:
        best_net = copy.deepcopy(policy_net)
        best_loop = t
        target_net.load_state_dict(policy_net.state_dict())
        counter = 1
    else:
        counter += 1
        if counter > 100:
            target_net.load_state_dict(policy_net.state_dict())
            counter = 1
            
results = pd.DataFrame(results, columns=['i_episode', 'max_loop', 'loss_train', 'loss_valid'])
print('Complete')
env.render()
env.close()

f:id:dskomei:20211004214151p:plain:w500


学習中のゲーム成績の確認


下の図では、学習ループ回数と棒を倒さずに続けられた回数の関係を見ています。最初は、全然続けられていないですが、学習のループ回数が600回を過ぎたあたりから続く回数が増え、600回以上続いている試行も見られます。しかし、低回数になっている箇所もあり、安定性の向上は必要です。


f:id:dskomei:20211005105937p:plain:w500


学習済みの DQN でゲームをプレイ


学習した DQN を使ってゲームをプレイします。その結果を見る前に、学習初期段階のプレイ状況を先に確認してみましょう。


f:id:dskomei:20211005131213g:plain:w400


ゲーム開始後すぐに棒が倒れてしまっています。棒を倒さないように動かしている様子は見られません。
次に、学習後の DQN でプレイした結果を見てみます。


f:id:dskomei:20211005131427g:plain:w400


棒を倒さないようにカートを動かしている様子が見られます。DQN が状況に合わせて、価値が高くなる行動を取れるようになっていることがわかります。


終わりに


今回は、モデルを動かして結果を確認することに主眼をおいたため、精度を向上させるための取り組みはあまりできなかったです。結果を見てみると、学習ステップがある程度経過しても、早い段階でゲーム終了になってしまっているケースが見られます。パッと思いつく限りでも、以下の点を実験する必要がありそうです。

  • 学習データが直近の結果だけを残すようになっているため、悪い結果が続くと、学習データに良い結果のものが残らなくなる
  • モデルサイズが小さい


これらのことに関して、今後試してみようと思います。