Source code for narla.neurons.actor_critic.neuron

from __future__ import annotations

import torch

import narla
from narla.neurons.neuron import Neuron as BaseNeuron

GAMMA = 0.99


class Neuron(BaseNeuron):
    """
    The ActorCritic Neuron is based on this `PyTorch example
    <https://github.com/pytorch/examples/blob/main/reinforcement_learning/actor_critic.py>`_

    :param observation_size: Size of the observation which the Neuron will receive
    :param number_of_actions: Number of actions available to the Neuron
    :param learning_rate: Learning rate for the Neuron's Network
    """

    def __init__(self, observation_size: int, number_of_actions: int, learning_rate: float = 1e-4):
        super().__init__(
            observation_size=observation_size,
            number_of_actions=number_of_actions,
            learning_rate=learning_rate,
        )

        self._network = narla.neurons.actor_critic.Network(
            input_size=observation_size,
            output_size=number_of_actions,
        ).to(narla.experiment_settings.trial_settings.device)

        self._loss_function = torch.nn.SmoothL1Loss()
        self._optimizer = torch.optim.AdamW(self._network.parameters(), lr=learning_rate, amsgrad=True)
    def act(self, observation: torch.Tensor) -> torch.Tensor:
        action_probabilities, value = self._network(observation)

        distribution = torch.distributions.Categorical(action_probabilities)
        action = distribution.sample()

        self._history.record(
            observation=observation,
            log_probability=distribution.log_prob(action),
            state_value=value,
            action=action,
        )

        return action
    def learn(self, *reward_types: narla.rewards.RewardTypes):
        value_losses = []
        policy_losses = []

        returns = self.get_returns(*reward_types)
        state_values = self._history.get(narla.history.saved_data.STATE_VALUE)
        log_probabilities = self._history.get(narla.history.saved_data.LOG_PROBABILITY)

        for return_value, state_value, log_probability in zip(returns, state_values, log_probabilities):
            advantage = return_value - state_value

            # calculate actor (policy) loss
            policy_losses.append(-log_probability * advantage)

            # calculate critic (value) loss using L1 smooth loss
            value_loss = torch.nn.functional.smooth_l1_loss(state_value.squeeze(), return_value)
            value_losses.append(value_loss)

        # reset gradients
        self._optimizer.zero_grad()

        # sum up all the values of policy_losses and value_losses
        loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()

        # perform backprop
        loss.backward()
        self._optimizer.step()

        self._history.clear()
    def get_returns(self, *reward_types: narla.rewards.RewardTypes) -> torch.Tensor:
        rewards = self._history.stack(*reward_types)
        rewards = torch.sum(rewards, dim=-1)

        returns = []
        return_value = 0
        for reward in reversed(rewards):
            return_value = reward + GAMMA * return_value
            returns.insert(0, return_value)

        returns = torch.tensor(returns, device=narla.experiment_settings.trial_settings.device)
        returns = (returns - returns.mean()) / (returns.std() + 1e-7)

        return returns
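
The snippet below is not part of the module; it is a minimal sketch of the discounting and normalisation performed in ``get_returns``, using a plain list of rewards in place of the Neuron's history so the arithmetic is easy to follow. The reward values are made up for illustration.

    import torch

    GAMMA = 0.99
    rewards = [1.0, 0.0, 1.0]  # illustrative per-step rewards, most recent last

    returns, running = [], 0.0
    for reward in reversed(rewards):
        running = reward + GAMMA * running  # discounted sum of this and all future rewards
        returns.insert(0, running)

    returns = torch.tensor(returns)
    returns = (returns - returns.mean()) / (returns.std() + 1e-7)  # normalise to zero mean, unit variance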
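
A second sketch, also not part of the module, shows how the class is typically driven from outside. The observation shape, the loop length, and the placeholder ``REWARD_TYPE`` are assumptions for illustration; in practice the experiment code supplies observations and records rewards into the Neuron's history before ``learn`` is called.

    import torch
    import narla

    # Assumes narla.experiment_settings has already been configured by the experiment.
    device = narla.experiment_settings.trial_settings.device

    # Hypothetical problem size: 4-dimensional observations, 2 available actions.
    neuron = Neuron(observation_size=4, number_of_actions=2, learning_rate=1e-4)

    for _ in range(16):
        # Stand-in for an observation coming from an environment.
        observation = torch.randn(1, 4, device=device)
        action = neuron.act(observation)  # samples from the policy and records the step

    # learn() expects the chosen reward types to have been recorded into the
    # Neuron's history by the surrounding experiment code; REWARD_TYPE below is
    # a placeholder for a narla.rewards.RewardTypes member.
    # neuron.learn(REWARD_TYPE)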