This repository has been archived by the owner on Dec 24, 2024. It is now read-only.

Some refactoring to files in pols directory #252

Open
tadashiK opened this issue Aug 20, 2019 · 4 comments


tadashiK commented Aug 20, 2019

Some files in the pols directory seem to need modification. For example, the comments in the BasePol class are difficult to understand. I suggest some improvements below.

tadashiK changed the title from "Poli" to "Some refactoring to files in pols directory" on Aug 20, 2019

tadashiK commented Aug 20, 2019

How about changing base.py to the following for readability? The main modifications are to the comments and to some if statements that make the code unnecessarily deeply nested.

import copy

import gym
import numpy as np
import torch.nn as nn

from machina.utils import get_device


class BasePol(nn.Module):
    """
    A base class for policies. This class works as a "head" on top of a given neural network.
    The head can handle an RNN, normalize the output range of the neural network,
    and run the computation in parallel.

    Parameters
    ----------
    observation_space : gym.Space
        Observation space
    action_space : gym.Space
        Action space
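    net : torch.nn.Module
        Neural network wrapped by this policy.
    rnn : bool
        If True, net is treated as a recurrent network and its hidden state is kept in self.hs.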
    normalize_ac : bool
        If True, the output of net is scaled such that it covers the entire action_space.
        It is assumed that the output is continuous, and each of its dimensions ranges between -1 and 1.
    data_parallel : bool or str
        If True, network computation is executed in parallel.
        If data_parallel is 'ddp', network computation is executed in distributed data parallel.
    parallel_dim : int
        Split dimension in data parallel.
    """

    def __init__(self, observation_space, action_space, net, rnn=False, normalize_ac=True, data_parallel=False, parallel_dim=0):
        nn.Module.__init__(self)
        self.observation_space = observation_space
        self.action_space = action_space
        self.net = net

        self.rnn = rnn
        self.hs = None  # A hidden state vector of the RNN.

        self.normalize_ac = normalize_ac
        self.data_parallel = data_parallel
        if data_parallel is True:
            self.dp_net = nn.DataParallel(self.net, dim=parallel_dim)
        elif data_parallel == 'ddp':
            self.net.to(get_device())
            self.dp_net = nn.parallel.DistributedDataParallel(
                self.net, device_ids=[get_device()], dim=parallel_dim)
        elif data_parallel is not False:
            raise ValueError(
                'data_parallel must be either Boolean value or str(ddp).')
        self.dp_run = False

        self.multi = isinstance(action_space, gym.spaces.MultiDiscrete)
        self.discrete =\
            self.multi or isinstance(action_space, gym.spaces.Discrete)

        if not self.discrete:
            self.a_i_shape = action_space.shape
        elif self.multi:
            nvec = action_space.nvec
            # Every dimension must have the same number of choices.
            # (any() would always pass, because nvec[0] == nvec[0].)
            assert all(nvec[0] == nv for nv in nvec)
            self.a_i_shape = (len(nvec), nvec[0])
        else:
            self.a_i_shape = (action_space.n, )

    def __getstate__(self):
        state = self.__dict__.copy()
        if 'dp_net' in state['_modules']:
            _modules = copy.deepcopy(state['_modules'])
            del _modules['dp_net']
            state['_modules'] = _modules
        return state

    def __setstate__(self, state):
        if 'dp_net' in state:
            state.pop('dp_net')
        self.__dict__.update(state)

    def convert_ac_for_real(self, x):
        """
        Scales an action (x), which is the output of self.net, such
        that the action is appropriate for a real world task.
        """
        if not self.discrete:
            lb, ub = self.action_space.low, self.action_space.high
            if self.normalize_ac:
                x = lb + (x + 1.) * 0.5 * (ub - lb)
            x = np.clip(x, lb, ub)
        return x

    def reset(self):
        """
        Resets a hidden state vector of the RNN.
        """
        self.hs = None

    def _check_obs_shape(self, obs):
        """
        Adds leading dimensions to obs until it has a batch axis
        (and a time axis when self.rnn is True).
        """
        additional_shape = 2 if self.rnn else 1
        if len(obs.shape) < additional_shape + len(self.observation_space.shape):
            for _ in range(additional_shape + len(self.observation_space.shape) - len(obs.shape)):
                obs = obs.unsqueeze(0)
        return obs
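
To make the rescaling in convert_ac_for_real concrete, here is a minimal dependency-free sketch of the same formula, `lb + (x + 1.) * 0.5 * (ub - lb)` followed by clipping. The function name `rescale_action` and the use of plain lists instead of NumPy arrays are assumptions made only for this illustration; it is not part of machina.

```python
def rescale_action(x, low, high):
    """Map each x[i] from [-1, 1] to [low[i], high[i]], then clip to the bounds.

    Hypothetical helper mirroring the normalize_ac=True branch of
    convert_ac_for_real, written with plain lists so it runs without NumPy.
    """
    scaled = [lo + (xi + 1.0) * 0.5 * (hi - lo)
              for xi, lo, hi in zip(x, low, high)]
    # Clipping guards against network outputs slightly outside [-1, 1].
    return [min(max(s, lo), hi) for s, lo, hi in zip(scaled, low, high)]


low = [0.0, 0.0, 0.0, 0.0]
high = [10.0, 10.0, 10.0, 10.0]
print(rescale_action([-1.0, 0.0, 1.0, 2.0], low, high))
```

The last input (2.0) lies outside the assumed [-1, 1] range, so it is scaled past the upper bound and then clipped back to it.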


tadashiK commented Aug 20, 2019

The code in categorical_pol.py and multi_categorical_pol.py seems to be almost identical. Rather than keeping two separate files, how about combining them into a single file called discrete.py with the following code?

import torch
from machina.pols import BasePol
from machina.pds.categorical_pd import CategoricalPd
from machina.pds.multi_categorical_pd import MultiCategoricalPd
from machina.utils import get_device


class CategoricalPol(BasePol):
    r"""
    A policy for a discrete action space with one dimension.
    For example, such action space is given as
        :math:`\{ 0, 1, \dots, n-1 \}`.

    Parameters
    ----------
    observation_space : gym.Space
        Observation space
    action_space : gym.Space
        Action space
        This must be an instance of gym.spaces.Discrete
    net : torch.nn.Module
    rnn : bool
    normalize_ac : bool
        If True, the output of net is scaled such that it covers the entire action_space.
        It is assumed that the output is continuous, and each of its dimensions ranges between -1 and 1.
    data_parallel : bool or str
        If True, network computation is executed in parallel.
        If data_parallel is 'ddp', network computation is executed in distributed data parallel.
    parallel_dim : int
        Split dimension in data parallel.
    """

    def __init__(self, observation_space, action_space, net, rnn=False,
                 normalize_ac=True, data_parallel=False, parallel_dim=0):
        BasePol.__init__(self, observation_space, action_space, net, rnn,
                         normalize_ac, data_parallel, parallel_dim)
        self.pd = CategoricalPd()
        self.to(get_device())

    def forward(self, obs, hs=None, h_masks=None):
        obs = self._check_obs_shape(obs)

        if self.rnn:
            pi = self.forward_with_rnn(obs, hs, h_masks, self.dp_run)
        else:
            pi = self.dp_net(obs) if self.dp_run else self.net(obs)

        ac = self.pd.sample(dict(pi=pi))
        ac_real = self.convert_ac_for_real(ac.detach().cpu().numpy())
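        # forward_with_rnn stores the updated hidden state in self.hs,
        # so return that rather than the (possibly stale) input hs.
        return ac_real, ac, dict(pi=pi, hs=self.hs if self.rnn else hs)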

    def forward_with_rnn(self, obs, hs=None, h_masks=None, dp_run=False):
        time_seq, batch_size, *_ = obs.shape

        # Fall back to the cached hidden state, then to a fresh initial state.
        # Explicit None checks are used because `or` on a multi-element
        # tensor raises an error about ambiguous truth values.
        if hs is None:
            hs = self.hs if self.hs is not None else self.net.init_hs(batch_size)
        hs = tuple(h.unsqueeze(0) for h in hs[0:2]) if dp_run else hs

        if h_masks is None:
            h_masks = hs[0].new(time_seq, batch_size, 1).zero_()
        h_masks = h_masks.reshape(time_seq, batch_size, 1)

        pi, hs = self.dp_net(obs, hs, h_masks) if dp_run else self.net(obs, hs, h_masks)
        self.hs = hs

        return pi

    def deterministic_ac_real(self, obs, hs=None, h_masks=None):
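        """
        Returns the greedy (argmax) action, intended for deployment.
        """
        obs = self._check_obs_shape(obs)
        # Only recurrent networks take hs and h_masks.
        if self.rnn:
            pi = self.forward_with_rnn(obs, hs, h_masks, dp_run=False)
        else:
            pi = self.net(obs)
        _, ac = torch.max(pi, dim=-1)
        ac_real = self.convert_ac_for_real(ac.detach().cpu().numpy())
        return ac_real, ac, dict(pi=pi, hs=self.hs if self.rnn else hs)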


class MultiCategoricalPol(CategoricalPol):
    r"""
    A policy for a discrete action space with multiple dimensions.
    For example, such action space is given as
        :math:`\{ 0, 1, \dots, n-1 \} \times \{ 0, 1, \dots, m-1 \}`.

    Parameters
    ----------
    observation_space : gym.Space
        Observation space
    action_space : gym.Space
        Action space
        This must be an instance of gym.spaces.MultiDiscrete
    net : torch.nn.Module
    rnn : bool
    normalize_ac : bool
        If True, the output of net is scaled such that it covers the entire action_space.
        It is assumed that the output is continuous, and each of its dimensions ranges between -1 and 1.
    data_parallel : bool or str
        If True, network computation is executed in parallel.
        If data_parallel is 'ddp', network computation is executed in distributed data parallel.
    parallel_dim : int
        Split dimension in data parallel.
    """

    def __init__(self, observation_space, action_space, net, rnn=False,
                 normalize_ac=True, data_parallel=False, parallel_dim=0):
        BasePol.__init__(self, observation_space, action_space, net, rnn,
                         normalize_ac, data_parallel, parallel_dim)
        self.pd = MultiCategoricalPd()
        self.to(get_device())
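
Since the two classes above differ only in which distribution object they create, one possible variation is to select the distribution through a class attribute so that the subclass does not need its own `__init__`. The sketch below illustrates only that pattern; `pd_cls`, `DiscretePolSketch`, `MultiDiscretePolSketch`, and the two stand-in distribution classes are hypothetical names, not machina APIs.

```python
class FakeCategoricalPd:
    """Stand-in for CategoricalPd (hypothetical, for illustration only)."""
    name = "categorical"


class FakeMultiCategoricalPd:
    """Stand-in for MultiCategoricalPd (hypothetical, for illustration only)."""
    name = "multi_categorical"


class DiscretePolSketch:
    # Subclasses override this attribute instead of overriding __init__.
    pd_cls = FakeCategoricalPd

    def __init__(self):
        # Instantiate whichever distribution class the concrete policy declares.
        self.pd = self.pd_cls()


class MultiDiscretePolSketch(DiscretePolSketch):
    pd_cls = FakeMultiCategoricalPd


print(DiscretePolSketch().pd.name)
print(MultiDiscretePolSketch().pd.name)
```

Whether this is preferable to overriding `__init__` is a design choice; it removes the duplicated constructor but makes the difference between the two policies less explicit.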

@rarilurelo
Contributor

Thanks for suggesting improvements to the documentation and the sharing of components between Categorical and MultiCategorical. I agree with the documentation improvements. I also agree with the sharing, but we need to check it carefully. Could you send two PRs?


tadashiK commented Aug 21, 2019

Thank you for the reply.

Yes, I will. However, before sending the PRs, I would like to ask two questions.

First, ArgmaxQfPol requires an instance of SAVfunc as its qfunc argument. However, in its forward method, ArgmaxQfPol calls SAVfunc.max, which SAVfunc does not have in general. Should I note this in a code comment, or should I change the required type from SAVfunc to CEMDeterministicSAVfunc, which seems to be the only subclass of SAVfunc that has a max method? I can implement max in the other subclasses of SAVfunc if you want me to.
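
To illustrate a third option, the requirement could be expressed as an explicit interface that the policy checks at construction time. This is only a sketch under assumptions: `QFuncWithMax`, `ArgmaxPolSketch`, and `ConstantQ` are hypothetical names, and the `(max_q, action)` return order of `max` is assumed here rather than taken from machina.

```python
from abc import ABC, abstractmethod


class QFuncWithMax(ABC):
    """Hypothetical interface: a Q-function that can maximise over actions."""

    @abstractmethod
    def max(self, obs):
        """Return (max_q, argmax_action) for the observation (assumed order)."""


class ArgmaxPolSketch:
    def __init__(self, qfunc):
        # Fail early instead of failing later inside forward().
        if not isinstance(qfunc, QFuncWithMax):
            raise TypeError("qfunc must implement max(obs)")
        self.qfunc = qfunc

    def act(self, obs):
        _, ac = self.qfunc.max(obs)
        return ac


class ConstantQ(QFuncWithMax):
    """Toy Q-function that always prefers action 0."""

    def max(self, obs):
        return 1.0, 0


print(ArgmaxPolSketch(ConstantQ()).act(obs=None))
```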

The second question is related to the first. When an MDP has only a finite number of actions, the max method can be drastically simpler (no optimization is needed). Furthermore, in that situation, a Q-value function is frequently represented by a neural network that accepts a state and outputs the Q-values of all actions. Is there any possibility of dividing SAVfunc into two classes, such as ContSDiscAVFunc and ContSContAVFunc? Again, if you want me to, I can give it a try.
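
As a rough sketch of why max is simpler in the finite-action case: if the Q-function returns one value per action, max reduces to a scan over that list. The helper name `max_over_actions` is hypothetical, and a plain list stands in for the network output so the snippet runs without PyTorch.

```python
def max_over_actions(q_values):
    """Return (best_q, best_action) given one Q-value per discrete action.

    Ties are broken in favour of the lowest action index.
    """
    best_action = max(range(len(q_values)), key=q_values.__getitem__)
    return q_values[best_action], best_action


# Q-values of three actions for a single state (made-up numbers).
print(max_over_actions([0.2, 1.5, -0.3]))
```

With a continuous action space no such enumeration is possible, which is why a class like CEMDeterministicSAVfunc has to search for the maximising action instead.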
