from __future__ import annotations
from typing import Tuple
import numpy as np
import torch
import narla
from narla.neurons.neuron import Neuron as BaseNeuron
TAU = 0.005  # Soft-update rate for the target network
GAMMA = 0.99  # Discount factor for future rewards
EPSILON_START = 0.9  # Initial exploration rate
EPSILON_END = 0.05  # Final exploration rate
EPSILON_DECAY = 1000  # Decay time constant (in steps) for the exploration rate
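# With these defaults the exploration rate used by ``act`` decays from 0.9
# toward 0.05 with a time constant of EPSILON_DECAY steps: roughly 0.36 after
# 1,000 steps and roughly 0.09 after 3,000 steps.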
class Neuron(BaseNeuron):
    """
    The DeepQ Neuron is based on this `PyTorch example <https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html>`_

    :param observation_size: Size of the observation which the Neuron will receive
    :param number_of_actions: Number of actions available to the Neuron
    :param learning_rate: Learning rate for the Neuron's Network
    """
    def __init__(self, observation_size: int, number_of_actions: int, learning_rate: float = 1e-4):
        super().__init__(
            observation_size=observation_size,
            number_of_actions=number_of_actions,
            learning_rate=learning_rate,
        )

        network = narla.neurons.deep_q.Network(
            input_size=observation_size,
            output_size=number_of_actions,
        ).to(narla.experiment_settings.trial_settings.device)
        self._policy_network = network
        # The target network starts as a copy of the policy network and is pulled toward it via soft updates
        self._target_network = network.clone()

        self._number_of_steps = 0
        self._loss_function = torch.nn.SmoothL1Loss()
        self._optimizer = torch.optim.AdamW(self._policy_network.parameters(), lr=learning_rate, amsgrad=True)
    def act(self, observation: torch.Tensor) -> torch.Tensor:
        # Exploration rate decays exponentially from EPSILON_START toward EPSILON_END as steps accumulate
        eps_threshold = EPSILON_END + (EPSILON_START - EPSILON_END) * np.exp(-1.0 * self._number_of_steps / EPSILON_DECAY)
        self._number_of_steps += 1

        if np.random.rand() > eps_threshold:
            # Exploit: take the action with the highest predicted Q value
            with torch.no_grad():
                output = self._policy_network(observation)
                action = output.max(1)[1].view(1, 1)
        else:
            # Explore: take a random action
            action = torch.tensor(
                data=[[np.random.randint(0, self.number_of_actions)]],
                device=narla.experiment_settings.trial_settings.device,
                dtype=torch.long,
            )

        self._history.record(
            observation=observation,
            action=action,
        )

        return action
    def learn(self, *reward_types: narla.rewards.RewardTypes):
        if len(self._history) < narla.experiment_settings.trial_settings.batch_size:
            return

        state_batch, action_batch, reward_batch, non_final_next_states, non_final_mask = self.sample_history(*reward_types)

        # Compute Q(s_t, a) - the model computes Q(s_t), then we select the columns of the actions taken. These are
        # the actions which would have been taken for each batch state according to the policy network
        state_action_values = self._policy_network(state_batch).gather(1, action_batch)

        # Compute V(s_{t+1}) for all next states.
        # Expected values of actions for non_final_next_states are computed based on the "older" target network,
        # selecting their best reward with max(1)[0]. This is merged based on the mask, so that we have either the
        # expected state value or 0 in case the state was final.
        next_state_values = torch.zeros(
            narla.experiment_settings.trial_settings.batch_size,
            device=narla.experiment_settings.trial_settings.device,
        )
        with torch.no_grad():
            next_state_values[non_final_mask] = self._target_network(non_final_next_states).max(1)[0]

        # Compute the expected Q values
        expected_state_action_values = (next_state_values * GAMMA) + reward_batch

        # Compute Huber loss
        loss = self._loss_function(state_action_values, expected_state_action_values.unsqueeze(1))

        # Optimize the model
        self._optimizer.zero_grad()
        loss.backward()

        # In-place gradient clipping
        torch.nn.utils.clip_grad_value_(self._policy_network.parameters(), 100)
        self._optimizer.step()

        # Update the weights of the target network
        self.update_target_network()
    def sample_history(self, *reward_types: narla.rewards.RewardTypes) -> Tuple[torch.Tensor, ...]:
        *rewards, observations, actions, next_observations, terminated = self._history.sample(
            names=[
                *reward_types,
                narla.history.saved_data.OBSERVATION,
                narla.history.saved_data.ACTION,
                narla.history.saved_data.NEXT_OBSERVATION,
                narla.history.saved_data.TERMINATED,
            ],
            sample_size=narla.experiment_settings.trial_settings.batch_size,
        )

        # Combine all the rewards into a Tensor with shape (batch_size, number_of_rewards)
        rewards = torch.stack([torch.stack(reward_type).squeeze() for reward_type in rewards], dim=-1)
        # Then sum the reward types for each sample
        reward_batch = torch.sum(rewards, dim=-1)

        observation_batch = torch.cat(observations)
        action_batch = torch.cat(actions)

        # Keep only the next observations of transitions that did not end the episode
        non_final_next_states = torch.cat(
            [observation for observation, done in zip(next_observations, terminated) if not done]
        )
        non_final_mask = torch.tensor(
            data=[not done for done in terminated],
            device=narla.experiment_settings.trial_settings.device,
            dtype=torch.bool,
        )

        return observation_batch, action_batch, reward_batch, non_final_next_states, non_final_mask
    def update_target_network(self):
        # Soft update of the target network's weights:
        # θ′ ← τθ + (1 - τ)θ′
        target_network_state_dict = self._target_network.state_dict()
        policy_network_state_dict = self._policy_network.state_dict()
        for key in policy_network_state_dict:
            target_network_state_dict[key] = policy_network_state_dict[key] * TAU + target_network_state_dict[key] * (1 - TAU)
        self._target_network.load_state_dict(target_network_state_dict)
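
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, based on the methods above). It
# assumes `narla.experiment_settings` has already been configured with a device
# and a batch size, and that `reward_type` is a `narla.rewards.RewardTypes`
# member whose rewards are recorded into the neuron's history elsewhere.
#
#     neuron = Neuron(observation_size=4, number_of_actions=2)
#     observation = torch.zeros((1, 4), device=narla.experiment_settings.trial_settings.device)
#     action = neuron.act(observation)  # epsilon-greedy action of shape (1, 1)
#     neuron.learn(reward_type)         # no-op until the history holds a full batch
# ---------------------------------------------------------------------------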