Source code for narla.neurons.deep_q.network

from __future__ import annotations

import copy

import torch

from narla.neurons.network import Network as BaseNetwork


[docs]class Network(BaseNetwork): def __init__(self, input_size: int, output_size: int, embedding_size: int = 128): super().__init__() self._neural_network = self._build_network( input_size=input_size, output_size=output_size, embedding_size=embedding_size, ) @staticmethod def _build_network(input_size: int, output_size: int, embedding_size) -> torch.nn.Sequential: layers = [ torch.nn.Linear(input_size, embedding_size), torch.nn.LeakyReLU(), torch.nn.Linear(embedding_size, embedding_size), torch.nn.LeakyReLU(), torch.nn.Linear(embedding_size, output_size), ] return torch.nn.Sequential(*layers)
[docs] def clone(self) -> Network: cloned_network = copy.deepcopy(self) cloned_network._neural_network.load_state_dict(self._neural_network.state_dict()) return cloned_network
[docs] def forward(self, X: torch.Tensor): output = self._neural_network(X) return output