PyTorchを使って連続値制御の深層強化学習のSoft Actor Criticを構築する

スポンサーリンク

 
前回、深層強化学習における連続値制御のモデル構築について書きました。今回は、その時に構築したモデルを改良し、精度向上に取り組みます。前回の記事をまだ見ていない方は、ぜひご覧ください。


www.dskomei.com


連続値制御の深層強化学習を改良するために、今回は Soft Actor Critic というモデルを実装します。これは、学習時の選択手の幅をよりもたせるために、目的関数にエントロピー項を追加しています。従来の目的関数は期待報酬の合計でしたが、この目的関数にエントロピー項が加わり、それらを最大化させます。そのことで、期待報酬を大きくしながら選択手の探索幅を広げることも可能にしています。


本記事のコードを実行すると、棒を立て続けるゲームにおいて、以下の画像のようにほぼ揺れることなく立て続ける手を選択するモデルを構築できます。今回のコードはこちらに置いてあります。





SAC の理論


SACが記載されている論文は以下になります。今回はこれを参考にしながらモデルを構築しています。
[1801.01290] Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor
また、以下の SAC の解説記事も非常に参考にさせてもらいました。
[論文解説] Soft Actor-Critic - Qiita


上記の記事でアルゴリズムに関しては詳しく書かれているので、ここでは要所だけを書きます。


シンプルな Q 学習では、Q 関数を最大化するアクションを採用するので、探索力の弱さという問題を抱えています。これを解決しようと Q 学習の目的関数にエントロピー項を加えたのが Soft Q 学習です。強化学習の目的関数自体に方策エントロピー項を組み込むことにより、報酬の最大化を行いながら探索範囲を広げるということが自然にできるようになりました。この Soft Q 学習を使った Actor Critic が Soft Actor Critic(SAC) と呼ばれるものです。


SAC の良い点は以下2つが挙げられます。

  • off-policy であるため学習時のサンプル効率が高い
  • 方策エントロピー項により、学習時の探索範囲が広がり、学習が安定化する



Soft Q 関数の更新


Soft Q 関数は、通常の Q 関数の目的関数である期待報酬に方策エントロピー項を加えたものです。


\[J(\pi) = \sum_{t=0}^T \mathbb {E}_{(s_t, a_t)∼ρ_\pi}[r(s_t, a_t)+ α \mathcal{H}(\cdot | s_t)] \]


\(α\) は温度パラメータであり、\( \mathcal{H} \) はエントロピー関数です。\(J(\pi)\) を最大化させると、期待報酬に加えて方策エントロピーも最大化します。エントロピーが大きくなるということは複雑性が上がるということであり、強化学習においては選択手が多様になる状態です。選択手確率分布の分散が大きくなったような状態がイメージしやすいかと思います。


今回の方策エントロピーは以下の形になります。特定の選択手の確率だけが高ければエントロピーは小さくなり、複数の選択手の確率が高ければエントロピーは高くなります。


\[-\log \pi(a_t | s_t)\]


これらを踏まえて、Soft Q 関数はベルマン方程式を修正して以下のようになります。


\[\mathcal{T}^\pi Q(s_t, a_t) \triangleq r(s_t, a_t) + γ \mathbb{E}_{s_{t+1}∼p} [V(s_{t+1})] \]
\[ V(s_t) = \mathbb{E}_{a_t \sim \pi} [Q(s_t, a_t) − \log \pi(a_t|s_t)] \]


1番目の式はいつもどおりのベルマン方程式の形ですが、状態関数 \(V(s_t)\) が異なります。状態関数 \(V(s_t)\) にエントロピー項 \(-\log \pi(a_t|s_t)\) が入っています。
以上のことから、損失関数は以下になります。ここで、Q 関数の近似を \(f_\theta\) 、Q 関数のコピーを \(f_{\theta'}'\) としています。


\[\mathbf{Loss} = (f_\theta(s_t, a_t) - r(s_t, a_t) + γ [f_{\theta'}'(s_{t+1}, a_{t+1}) − \log \pi(a_{t+1}|s_{t+1})] )^2 \]


方策の更新


方策の更新に関しては、論文に以下の式を更新していくことで Soft Q 関数を最大化する方策に収束していくことが示されています。この式の導出に関しては論文を御覧ください。


\[ \pi_{new} = \arg \min_{\pi' \in Π} { \mathrm D_{\mathrm{KL}} \Biggl ( \pi'(\cdot | s_t) \, \bigg\Vert \, \frac {\exp( \frac {1} {\alpha} Q^{\pi_{old}} (s_t, \cdot))}{Z^{\pi_{old}}(s_t)} \Biggr ) } \]


上式の \( \mathrm D_{ \mathrm {KL} } \) は KLダイバージェンスのことであり、Soft Q 関数の指数分布と方策確率の分布の距離を求めるという式です。温度パラメータの \( \alpha \) が0に近づくと Soft Q 関数の値が重視されるため決定論的な方策選択になり、\(\alpha\) が \(\infty\) に近づくと Soft Q 関数は重視されなくなるので一様なランダム方策選択になります。つまり、方策の更新において \(\alpha\) をどうするかが大事です。今回の実装では、\(\alpha\) も最適化しています。


Actor を近似するネットワークは連続値制御であり、正規分布のノイズをつけた出力値であるため、以下の形で定義します。ここではネットワークのパラメータを \( \phi \) としています。


\[ a_t=f_\phi(\epsilon; s_t) \]


これをKLダイバージェンスに代入して目的関数を作ります。


\[ J_\pi(\phi) = \mathbb{E}_{s_t \sim \mathcal{D}, \epsilon_t \ sim \mathcal{N}} \Biggl [ { \mathrm D_{\mathrm{KL}} \biggl ( \pi_\phi (f_\phi(\epsilon_t; s_t) | s_t) \, \bigg\Vert \, \frac {\exp( \frac {1} {\alpha} Q_\theta (s_t, f_\phi(\epsilon; s_t) ))}{Z_\theta (s_t)} \biggr ) } \Biggr ] \]


\(Z_\theta (s_t)\) は \( \phi \) とは関係ないので無視すると、以下の式に変形できます。


\[ J_\pi(\phi) = \mathbb{E}_{s_t \sim \mathcal{D}, \epsilon_t \sim \mathcal{N} } \Bigl [ \alpha \log {\pi_\phi (f_\phi(\epsilon_t; s_t) | s_t) - Q_\theta(s_t, f_\phi (\epsilon_t; s_t))} \Bigr ]\]


方策の更新においては、この \(J_\pi(\phi)\) を最小化させます。


SACの構築


ここからは、SACを実装していきます。まずは必要なモジュールやパラメータの設定をします。


実装のための準備


今回使う主要なモジュールは「torch」と「gym」です。「torch」はディープラーニングを構築するためのもであり、「gym」は強化学習の環境を作るためのものです。今回使用する主要なモジュールのバージョンを以下に記載します。


モジュール名 バージョン
gym 0.24.1
torch 1.11.0
seaborn 0.11.2
numpy 1.21.5
pandas 1.4.1


モジュールをインポートします。

from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns
import torch
import gym


ゲームの環境のゲーム名やディレクトリ、ランダムのシードを設定します。

gym_game_name = 'Pendulum-v1'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_dir_path = Path('model')
if not model_dir_path.exists():
    model_dir_path.mkdir(parents=True)
    
result_dir_path = Path('result')
if not result_dir_path.exists():
    result_dir_path.mkdir(parents=True)

seed = 123456
torch.manual_seed(seed)
np.random.seed(seed)



ゲーム環境の構築


今回行う強化学習用のゲームは「Pendulm-v1」です。これは、振り子を立たせ続けることを目的とし、そのために -1 〜 1 の範囲で行動を選択するゲームです。

env = gym.make(gym_game_name)
env.action_space.seed(seed)



SAC の設計


Soft Actor Critic では、連続値制御の Actor Critic から両方のネットワークで変更が加えられています。それぞれについて見ていきます。


まず、Critic ではネットワーク内に2つのネットワークが作られています。これは Q 関数の過大見積もりを防ぐためです。Q 関数の学習では、Q 関数のコピーネットワークの値を使うため、それが過大に見積もっていた場合に、その影響を受けて学習してしまいます。これを防ぐために、ネットワーク内に2つのネットワークを用意し、出力値が小さいものを使うようにします。(Clipped Network)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal


class ClippedCriticNet(nn.Module):

    def __init__(self, input_dim, output_dim, hidden_size):

        super().__init__()

        self.linear1 = nn.Linear(input_dim, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, output_dim)

        self.linear4 = nn.Linear(input_dim, hidden_size)
        self.linear5 = nn.Linear(hidden_size, hidden_size)
        self.linear6 = nn.Linear(hidden_size, output_dim)

    def forward(self, state, action):
        xu = torch.cat([state, action], 1)

        x1 = F.relu(self.linear1(xu))
        x1 = F.relu(self.linear2(x1))
        x1 = self.linear3(x1)

        x2 = F.relu(self.linear4(xu))
        x2 = F.relu(self.linear5(x2))
        x2 = self.linear6(x2)

        return x1, x2


Actor では、ネットワークの出力の際にエントロピー項を追加しています。この値が方策の更新時と Soft Q 関数の更新時の損失値を求めるために使われます。エントロピー項には、出力値の対数確率に \( - \log (1 - y^2 ) + \varepsilon \) を足しています。これは、出力値が上下限に張り付かないようにしており、探索範囲の拡大に寄与しています。

LOG_SIG_MAX = 2
LOG_SIG_MIN = -20
epsilon = 1e-6


class SoftActorNet(nn.Module):

    def __init__(self, input_dim, output_dim, hidden_size, action_scale):

        super().__init__()

        self.linear1 = nn.Linear(input_dim, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)

        self.mean_linear = nn.Linear(hidden_size, output_dim)
        self.log_std_linear = nn.Linear(hidden_size, output_dim)

        self.action_scale = torch.tensor(action_scale)
        self.action_bias = torch.tensor(0.)

    def forward(self, state):
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))
        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
        return mean, log_std

    def sample(self, state):
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        x_t = normal.rsample()
        y_t = torch.tanh(x_t)
        action = y_t * self.action_scale + self.action_bias
        log_prob = normal.log_prob(x_t)
        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
        log_prob = log_prob.sum(1, keepdim=True)
        mean = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, mean

    def to(self, device):
        self.action_scale = self.action_scale.to(device)
        self.action_bias = self.action_bias.to(device)
        return super().to(device)



ネットワークの設計が完了したので、これらのネットワークを使って SAC モデルを設計します。ポイントになるのは、「update_parameters 関数」内での Critic と Actor の損失値を求めるところです。両方ともエントロピー項を加味した計算になっています。そして、最後に \( \alpha \) を最適化しています。

class SoftActorCriticModel(object):

    def __init__(self, state_dim, action_dim, action_scale, args, device):

        self.gamma = args['gamma']
        self.tau = args['tau']
        self.alpha = args['alpha']
        self.device = device
        self.target_update_interval = args['target_update_interval']

        self.actor_net = SoftActorNet(
            input_dim=state_dim, output_dim=action_dim, hidden_size=args['hidden_size'], action_scale=action_scale
        ).to(self.device)
        self.critic_net = ClippedCriticNet(input_dim=state_dim + action_dim, output_dim=1, hidden_size=args['hidden_size']).to(device=self.device)
        self.critic_net_target = ClippedCriticNet(input_dim=state_dim + action_dim, output_dim=1, hidden_size=args['hidden_size']).to(self.device)

        hard_update(self.critic_net_target, self.critic_net)
        convert_network_grad_to_false(self.critic_net_target)

        self.actor_optim = optim.Adam(self.actor_net.parameters())
        self.critic_optim = optim.Adam(self.critic_net.parameters())

        self.target_entropy = -torch.prod(torch.Tensor(action_dim).to(self.device)).item()
        self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
        self.alpha_optim = optim.Adam([self.log_alpha])

    def select_action(self, state, evaluate=False):
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        if not evaluate:
            action, _, _ = self.actor_net.sample(state)
        else:
            _, _, action = self.actor_net.sample(state)
        return action.cpu().detach().numpy().reshape(-1)

    def update_parameters(self, memory, batch_size, updates):

        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).unsqueeze(1).to(self.device)
        mask_batch = torch.FloatTensor(mask_batch).unsqueeze(1).to(self.device)

        with torch.no_grad():
            next_action, next_log_pi, _ = self.actor_net.sample(next_state_batch)
            next_q1_values_target, next_q2_values_target = self.critic_net_target(next_state_batch, next_action)
            next_q_values_target = torch.min(next_q1_values_target, next_q2_values_target) - self.alpha * next_log_pi
            next_q_values = reward_batch + mask_batch * self.gamma * next_q_values_target

        q1_values, q2_values = self.critic_net(state_batch, action_batch)
        critic1_loss = F.mse_loss(q1_values, next_q_values)
        critic2_loss = F.mse_loss(q2_values, next_q_values)
        critic_loss = critic1_loss + critic2_loss

        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

        action, log_pi, _ = self.actor_net.sample(state_batch)

        q1_values, q2_values = self.critic_net(state_batch, action)
        q_values = torch.min(q1_values, q2_values)

        actor_loss = ((self.alpha * log_pi) - q_values).mean()

        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()

        self.alpha = self.log_alpha.exp()

        if updates % self.target_update_interval == 0:
            soft_update(self.critic_net_target, self.critic_net, self.tau)

        return critic_loss.item(), actor_loss.item()


ここまでで SAC モデルの設計のための大部分が終わりました。
上記に出てきたネットワークをコピーする関数とネットワークのすべてのパラメータを学習できないようする関数は以下のとおりです。

def soft_update(target, source, tau):
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)


def hard_update(target, source):
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(param.data)


def convert_network_grad_to_false(network):
    for param in network.parameters():
        param.requires_grad = False


学習データを保存しておくためのメモリクラスを記載します。指定したサイズ分の学習データを保存し、指定したサイズ以上にはならないようにしています。

import random
import numpy as np


class ReplayMemory:

    def __init__(self, memory_size):
        self.memory_size = memory_size
        self.buffer = []
        self.position = 0

    def push(self, state, action, reward, next_state, mask):
        if len(self.buffer) < self.memory_size:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, mask)
        self.position = (self.position + 1) % self.memory_size

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)


以上で、SAC のすべての設計が完了です。次からは実際にゲームをプレイし、SAC モデルの学習を行います。


SAC モデルの学習


SAC モデルを学習します。ただ、モデルの設計の部分で今回のポイントとなるエントロピー項が含まれているため、ここでは通常の Actor Critic と同じ学習処理になっています。

args = {
    'gamma': 0.99,
    'tau': 0.005,
    'alpha': 0.2,
    'seed': 123456,
    'batch_size': 256,
    'hidden_size': 256,
    'start_steps': 1000,
    'updates_per_step': 1,
    'target_update_interval': 1,
    'memory_size': 100000,
    'epochs': 100,
    'eval_interval': 10
}

agent = SoftActorCriticModel(
    state_dim=env.observation_space.shape[0], action_dim=env.action_space.shape[0],
    action_scale=env.action_space.high[0], args=args, device=device
)
memory = ReplayMemory(args['memory_size'])

episode_reward_list = []
eval_reward_list = []

n_steps = 0
n_update = 0
for i_episode in range(1, args['epochs'] + 1):

    episode_reward = 0
    done = False
    state = env.reset()

    while not done:
        
        if args['start_steps'] > n_steps:
            action = env.action_space.sample()
        else:
            action = agent.select_action(state)

        if len(memory) > args['batch_size']:
            agent.update_parameters(memory, args['batch_size'], n_update)
            n_update += 1

        next_state, reward, done, _ = env.step(action)
        n_steps += 1
        episode_reward += reward

        memory.push(state=state, action=action, reward=reward, next_state=next_state, mask=float(not done))

        state = next_state

    episode_reward_list.append(episode_reward)

    if i_episode % args['eval_interval'] == 0:
        avg_reward = 0.
        for _  in range(args['eval_interval']):
            state = env.reset()
            episode_reward = 0
            done = False
            while not done:
                with torch.no_grad():
                    action = agent.select_action(state, evaluate=True)
                next_state, reward, done, _ = env.step(action)
                episode_reward += reward
                state = next_state
            avg_reward += episode_reward
        avg_reward /= args['eval_interval']
        eval_reward_list.append(avg_reward)

        print("Episode: {}, Eval Avg. Reward: {:.0f}".format(i_episode, avg_reward))

print('Game Done !! Max Reward: {:.2f}'.format(np.max(eval_reward_list)))

torch.save(agent.actor_net.to('cpu').state_dict(), model_dir_path.joinpath(f'{gym_game_name}_sac_actor.pth'))




モデルの学習中の評価獲得報酬の推移を見てみます。

plt.figure(figsize=(8, 6), facecolor='white')
g = sns.lineplot(
    data=pd.DataFrame({
        'episode': range(args['eval_interval'], args['eval_interval'] * (len(eval_reward_list) + 1), args['eval_interval']),
        'reward': eval_reward_list
    }),
    x='episode', y='reward', lw=2
)
plt.title('{}エピソードごとの学習済みモデルにおける\n評価報酬の平均値の推移'.format(args['eval_interval']), fontsize=18, weight='bold')
plt.xlabel('エピソード')
plt.ylabel('獲得報酬の平均値')
for tick in plt.yticks()[0]:
    plt.axhline(tick, color='grey', alpha=0.1)
plt.tight_layout()
plt.savefig(result_dir_path.joinpath('{}_sac_eval_reward_{}.png'.format(gym_game_name, args['eval_interval'])), dpi=500)


上図を見れば分かるとおり、学習中の獲得報酬はエピソードごとに上がっており、学習が進むと一定以上を保っています。つまり、モデルがゲームの報酬を上げるようにプレイできています。


学習済みモデルの検証


以前作成した連続値制御をする Actor Critic と Soft Actor Critic を比較します。


まずは、モデルの学習時の評価獲得報酬を比べます。Actor Critic と Soft Actor Critic のそれぞれで 10 回 モデルを作り、その時の評価獲得報酬の推移を見てみます。



上図を見ると、Soft Actor Critic の方が若干早く評価獲得報酬が高くなっているのがわかります。まぁ、若干ですが。


次に、学習したそれぞれのモデルで 100回ずつプレイした際の獲得報酬の平均を比較してみます。



上図のとおり、Soft Actor Critic の方が獲得報酬の平均値は高いです。方策エントロピー項による探索範囲の拡大の効果があることがわかります。


終わりに


強化学習を一から勉強したいと思った方は以下の本が非常に参考になったので、ぜひご覧ください。