From 957f052821b4b0e6fdf1b31e9dbd0bad79ee9f22 Mon Sep 17 00:00:00 2001 From: MihaiAnca13 Date: Wed, 14 Oct 2020 18:22:31 +0100 Subject: [PATCH 1/3] Added option for multiple episodes per batch Re-arranged imports --- squiRL/common/data_stream.py | 16 +++++++----- squiRL/vpg/vpg.py | 47 ++++++++++++++++++++---------------- 2 files changed, 36 insertions(+), 27 deletions(-) diff --git a/squiRL/common/data_stream.py b/squiRL/common/data_stream.py index 8917e6a..93798c4 100644 --- a/squiRL/common/data_stream.py +++ b/squiRL/common/data_stream.py @@ -3,14 +3,16 @@ Attributes: Experience (namedtuple): An environment step experience """ -import numpy as np -from torch.utils.data.dataset import IterableDataset from collections import deque from collections import namedtuple -from squiRL.common.policies import MLP -import gym from typing import Tuple +import gym +import numpy as np +from torch.utils.data.dataset import IterableDataset + +from squiRL.common.policies import MLP + Experience = namedtuple('Experience', ('state', 'action', 'reward', 'done', 'last_state')) @@ -88,8 +90,9 @@ class RLDataset(IterableDataset): net (nn.Module): Policy network replay_buffer: Replay buffer """ + def __init__(self, replay_buffer: RolloutCollector, env: gym.Env, net: MLP, - agent) -> None: + agent, episodes_per_batch: int = 1) -> None: """Summary Args: @@ -102,6 +105,7 @@ def __init__(self, replay_buffer: RolloutCollector, env: gym.Env, net: MLP, self.env = env self.net = net self.agent = agent + self.episodes_per_batch = episodes_per_batch def populate(self) -> None: """ @@ -119,7 +123,7 @@ def __iter__(self): Yields: Tuple: Sampled experience """ - for i in range(1): + for i in range(self.episodes_per_batch): self.populate() states, actions, rewards, dones, new_states = self.replay_buffer.sample( ) diff --git a/squiRL/vpg/vpg.py b/squiRL/vpg/vpg.py index 91bd846..06663fb 100644 --- a/squiRL/vpg/vpg.py +++ b/squiRL/vpg/vpg.py @@ -2,8 +2,10 @@ """ import argparse from argparse import ArgumentParser +from collections import OrderedDict from copy import copy from typing import Tuple, List + import gym import numpy as np import pytorch_lightning as pl @@ -13,11 +15,10 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader from torch.utils.data._utils import collate -from collections import OrderedDict from squiRL.common import reg_policies -from squiRL.common.data_stream import RLDataset, RolloutCollector from squiRL.common.agents import Agent +from squiRL.common.data_stream import RLDataset, RolloutCollector class VPG(pl.LightningModule): @@ -44,6 +45,7 @@ def __init__(self, hparams: argparse.Namespace) -> None: self.env = gym.make(self.hparams.env) self.gamma = self.hparams.gamma self.eps = self.hparams.eps + self.episodes_per_batch = self.hparams.episodes_per_batch obs_size = self.env.observation_space.shape[0] n_actions = self.env.action_space.n @@ -84,6 +86,10 @@ def add_model_specific_args( type=int, default=20, help="num of dataloader cpu workers") + parser.add_argument("--episodes_per_batch", + type=int, + default=1, + help="number of episodes to be sampled per training step") return parser def reward_to_go(self, rewards: torch.Tensor) -> torch.tensor: @@ -104,8 +110,8 @@ def reward_to_go(self, rewards: torch.Tensor) -> torch.tensor: res.append(copy(sum_r)) return list(reversed(res)) - def vpg_loss(self, batch: Tuple[torch.Tensor, - torch.Tensor]) -> torch.Tensor: + def vpg_loss(self, + batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor: 
""" Calculates the loss based on the REINFORCE objective, using the discounted @@ -128,14 +134,14 @@ def vpg_loss(self, batch: Tuple[torch.Tensor, discounted_rewards = self.reward_to_go(rewards) discounted_rewards = torch.tensor(discounted_rewards) advantage = (discounted_rewards - discounted_rewards.mean()) / ( - discounted_rewards.std() + self.eps) + discounted_rewards.std() + self.eps) advantage = advantage.type_as(log_probs) loss = -advantage * log_probs return loss.sum() - def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], - nb_batch) -> OrderedDict: + def training_step(self, batch: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], + nb_batch) -> pl.TrainResult: """ Carries out an entire episode in env and calculates loss @@ -143,14 +149,19 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], OrderedDict: Training step result Args: - batch (Tuple[torch.Tensor, torch.Tensor]): Current mini batch of - replay data + batch (List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]): Current + mini batch of replay data nb_batch (TYPE): Current index of mini batch of replay data """ - _, _, rewards, _, _ = batch - episode_reward = rewards.sum().detach() + loss = None + for episode in batch: + _, _, rewards, _, _ = episode + episode_reward = rewards.sum().detach() - loss = self.vpg_loss(batch) + if loss is None: + loss = self.vpg_loss(episode) + else: + loss += self.vpg_loss(episode) if self.trainer.use_dp or self.trainer.use_ddp2: loss = loss.unsqueeze(0) @@ -191,13 +202,7 @@ def collate_fn(self, batch): """ batch = collate.default_convert(batch) - states = torch.cat([s[0] for s in batch]) - actions = torch.cat([s[1] for s in batch]) - rewards = torch.cat([s[2] for s in batch]) - dones = torch.cat([s[3] for s in batch]) - next_states = torch.cat([s[4] for s in batch]) - - return states, actions, rewards, dones, next_states + return batch def __dataloader(self) -> DataLoader: """Initialize the RL dataset used for retrieving experiences @@ -205,11 +210,11 @@ def __dataloader(self) -> DataLoader: Returns: DataLoader: Handles loading the data for training """ - dataset = RLDataset(self.replay_buffer, self.env, self.net, self.agent) + dataset = RLDataset(self.replay_buffer, self.env, self.net, self.agent, self.episodes_per_batch) dataloader = DataLoader( dataset=dataset, collate_fn=self.collate_fn, - batch_size=1, + batch_size=self.episodes_per_batch, ) return dataloader From b68d168895201f23b09f738dce8310077e50632c Mon Sep 17 00:00:00 2001 From: MihaiAnca13 Date: Sat, 17 Oct 2020 17:23:27 +0100 Subject: [PATCH 2/3] Prepare code for shared buffer Add num_workers to VPG --- squiRL/common/data_stream.py | 61 +++++++++++++++++++++++++++++------- squiRL/vpg/vpg.py | 11 +++++-- 2 files changed, 57 insertions(+), 15 deletions(-) diff --git a/squiRL/common/data_stream.py b/squiRL/common/data_stream.py index 93798c4..c012ca8 100644 --- a/squiRL/common/data_stream.py +++ b/squiRL/common/data_stream.py @@ -10,6 +10,8 @@ import gym import numpy as np from torch.utils.data.dataset import IterableDataset +import torch.multiprocessing as mp +import torch from squiRL.common.policies import MLP @@ -29,14 +31,33 @@ class RolloutCollector: capacity (int): Size of the buffer replay_buffer (deque): Experience buffer """ - def __init__(self, capacity: int) -> None: + def __init__(self, capacity: int, state_shape: tuple, action_shape: tuple, should_share: bool = False) -> None: """Summary Args: capacity (int): Description """ + + 
state_shape = [capacity] + list(state_shape) + action_shape = [capacity] + list(action_shape) + self.capacity = capacity - self.replay_buffer = deque(maxlen=self.capacity) + self.count = torch.tensor([0], dtype=torch.int64) + self.states = torch.zeros(state_shape, dtype=torch.float32) + self.actions = torch.zeros(action_shape, dtype=torch.float32) + self.rewards = torch.zeros((capacity), dtype=torch.float32) + self.dones = torch.zeros((capacity), dtype=torch.bool) + self.next_states = torch.zeros(state_shape, dtype=torch.float32) + + if should_share: + self.count.share_memory_() + self.states.share_memory_() + self.actions.share_memory_() + self.next_states.share_memory_() + self.rewards.share_memory_() + self.dones.share_memory_() + + self.lock = mp.Lock() def __len__(self) -> int: """Calculates length of buffer @@ -44,7 +65,7 @@ def __len__(self) -> int: Returns: int: Length of buffer """ - return len(self.replay_buffer) + return self.count.detach().numpy().item() def append(self, experience: Experience) -> None: """ @@ -52,9 +73,25 @@ def append(self, experience: Experience) -> None: Args: experience (Experience): Tuple (state, action, reward, done, - new_state) + last_state) """ - self.replay_buffer.append(experience) + + with self.lock: + if self.count[0] < self.capacity: + self.count[0] += 1 + + # count keeps the exact length, but indexing starts from 0 so we decrease by 1 + nr = self.count[0] - 1 + + self.states[nr] = torch.tensor(experience.state, dtype=torch.float32) + self.actions[nr] = torch.tensor(experience.action, dtype=torch.float32) + self.rewards[nr] = torch.tensor(experience.reward, dtype=torch.float32) + self.dones[nr] = torch.tensor(experience.done, dtype=torch.bool) + self.next_states[nr] = torch.tensor(experience.last_state, dtype=torch.float32) + + else: + exit("RolloutCollector: Buffer is full but samples are being added to it") + def sample(self) -> Tuple: """Sample experience from buffer @@ -62,17 +99,17 @@ def sample(self) -> Tuple: Returns: Tuple: Sampled experience """ - states, actions, rewards, dones, next_states = zip( - *[self.replay_buffer[i] for i in range(len(self.replay_buffer))]) - return (np.array(states), np.array(actions), - np.array(rewards, dtype=np.float32), - np.array(dones, dtype=np.bool), np.array(next_states)) + # count keeps the exact length, but indexing starts from 0 so we decrease by 1 + nr = self.count[0] - 1 + return (self.states[:nr], self.actions[:nr], self.rewards[:nr], self.dones[:nr], self.next_states[:nr]) def empty_buffer(self) -> None: - """Empty replay buffer + """Empty replay buffer by resetting the count (so old data gets overwritten) """ - self.replay_buffer.clear() + with self.lock: + # the [0] is very important, otherwise we throw the tensor out and the int that replaces it won't get shared + self.count[0] = 0 class RLDataset(IterableDataset): diff --git a/squiRL/vpg/vpg.py b/squiRL/vpg/vpg.py index 06663fb..87f16d0 100644 --- a/squiRL/vpg/vpg.py +++ b/squiRL/vpg/vpg.py @@ -46,11 +46,12 @@ def __init__(self, hparams: argparse.Namespace) -> None: self.gamma = self.hparams.gamma self.eps = self.hparams.eps self.episodes_per_batch = self.hparams.episodes_per_batch + self.num_workers = hparams.num_workers obs_size = self.env.observation_space.shape[0] n_actions = self.env.action_space.n self.net = reg_policies[self.hparams.policy](obs_size, n_actions) - self.replay_buffer = RolloutCollector(self.hparams.episode_length) + self.replay_buffer = RolloutCollector(self.hparams.episode_length, self.env.observation_space.shape, 
self.env.action_space.shape) self.agent = Agent(self.env, self.replay_buffer) @@ -128,8 +129,10 @@ def vpg_loss(self, action_logit = self.net(states.float()) log_probs = F.log_softmax(action_logit, - dim=-1).squeeze(0)[range(len(actions)), - actions] + dim=-1)[range(len(actions)), + actions.long()] + + discounted_rewards = self.reward_to_go(rewards) discounted_rewards = torch.tensor(discounted_rewards) @@ -215,6 +218,8 @@ def __dataloader(self) -> DataLoader: dataset=dataset, collate_fn=self.collate_fn, batch_size=self.episodes_per_batch, + num_workers=self.num_workers, + pin_memory=True ) return dataloader From 1587d9ff79229e999ab7c5322c38b97784d32769 Mon Sep 17 00:00:00 2001 From: MihaiAnca13 Date: Mon, 19 Oct 2020 10:39:10 +0100 Subject: [PATCH 3/3] Fix num_workers to 0 when debugging --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index a8bb7ed..b7e9d3c 100644 --- a/train.py +++ b/train.py @@ -22,7 +22,7 @@ def main(hparams) -> None: if hparams.debug: hparams.logger = None hparams.profiler = True - hparams.num_workers = None + hparams.num_workers = 0 else: hparams.logger = WandbLogger(project=hparams.project) seed_everything(hparams.seed)
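
Usage note (not part of the patches above): a minimal sketch of the pre-allocated RolloutCollector introduced in PATCH 2/3, exercising only what the diff defines — the (capacity, state_shape, action_shape, should_share) constructor and the append/sample/empty_buffer/__len__ methods. The CartPole-v0 environment, the capacity of 200 and the random-action rollout are illustrative assumptions; the constructor arguments mirror how vpg.py builds the buffer from the env's observation and action space shapes.

    # Sketch only: assumes a Gym env with a flat observation space.
    import gym
    from squiRL.common.data_stream import Experience, RolloutCollector

    env = gym.make('CartPole-v0')

    # Pre-allocated buffer of (state, action, reward, done, next_state) rows;
    # should_share=True moves the tensors into shared memory for worker processes.
    buffer = RolloutCollector(capacity=200,
                              state_shape=env.observation_space.shape,
                              action_shape=env.action_space.shape,
                              should_share=True)

    state = env.reset()
    done = False
    while not done and len(buffer) < 200:
        action = env.action_space.sample()
        next_state, reward, done, _ = env.step(action)
        buffer.append(Experience(state, action, reward, done, next_state))
        state = next_state

    # sample() returns tensor slices rather than stacked numpy arrays as before.
    states, actions, rewards, dones, next_states = buffer.sample()
    print(len(buffer), states.shape)

    # Resets the shared count so old rows get overwritten on the next rollout.
    buffer.empty_buffer()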