77 changes: 59 additions & 18 deletions squiRL/common/data_stream.py
@@ -3,14 +3,18 @@
Attributes:
Experience (namedtuple): An environment step experience
"""
import numpy as np
from torch.utils.data.dataset import IterableDataset
from collections import deque
from collections import namedtuple
from squiRL.common.policies import MLP
import gym
from typing import Tuple

import gym
import numpy as np
from torch.utils.data.dataset import IterableDataset
import torch.multiprocessing as mp
import torch

from squiRL.common.policies import MLP

Experience = namedtuple('Experience',
('state', 'action', 'reward', 'done', 'last_state'))

@@ -27,50 +31,85 @@ class RolloutCollector:
capacity (int): Size of the buffer
replay_buffer (deque): Experience buffer
"""
def __init__(self, capacity: int) -> None:
def __init__(self, capacity: int, state_shape: tuple, action_shape: tuple, should_share: bool = False) -> None:
"""Summary

Args:
capacity (int): Description
"""

state_shape = [capacity] + list(state_shape)
action_shape = [capacity] + list(action_shape)

self.capacity = capacity
self.replay_buffer = deque(maxlen=self.capacity)
self.count = torch.tensor([0], dtype=torch.int64)
self.states = torch.zeros(state_shape, dtype=torch.float32)
self.actions = torch.zeros(action_shape, dtype=torch.float32)
self.rewards = torch.zeros((capacity), dtype=torch.float32)
self.dones = torch.zeros((capacity), dtype=torch.bool)
self.next_states = torch.zeros(state_shape, dtype=torch.float32)

if should_share:
self.count.share_memory_()
self.states.share_memory_()
self.actions.share_memory_()
self.next_states.share_memory_()
self.rewards.share_memory_()
self.dones.share_memory_()

self.lock = mp.Lock()

def __len__(self) -> int:
"""Calculates length of buffer

Returns:
int: Length of buffer
"""
return len(self.replay_buffer)
return self.count.detach().numpy().item()

def append(self, experience: Experience) -> None:
"""
Add experience to the buffer

Args:
experience (Experience): Tuple (state, action, reward, done,
new_state)
last_state)
"""
self.replay_buffer.append(experience)

with self.lock:
if self.count[0] < self.capacity:
self.count[0] += 1

# count keeps the exact length, but indexing starts from 0 so we decrease by 1
nr = self.count[0] - 1

self.states[nr] = torch.tensor(experience.state, dtype=torch.float32)
self.actions[nr] = torch.tensor(experience.action, dtype=torch.float32)
self.rewards[nr] = torch.tensor(experience.reward, dtype=torch.float32)
self.dones[nr] = torch.tensor(experience.done, dtype=torch.bool)
self.next_states[nr] = torch.tensor(experience.last_state, dtype=torch.float32)

else:
raise RuntimeError("RolloutCollector: buffer is full but samples are still being added to it")


def sample(self) -> Tuple:
"""Sample experience from buffer

Returns:
Tuple: Sampled experience
"""
states, actions, rewards, dones, next_states = zip(
*[self.replay_buffer[i] for i in range(len(self.replay_buffer))])

return (np.array(states), np.array(actions),
np.array(rewards, dtype=np.float32),
np.array(dones, dtype=np.bool), np.array(next_states))
# count holds the number of filled slots, so slicing with [:count] returns exactly the stored experiences
nr = self.count[0]
return (self.states[:nr], self.actions[:nr], self.rewards[:nr], self.dones[:nr], self.next_states[:nr])

def empty_buffer(self) -> None:
"""Empty replay buffer
"""Empty replay buffer by resetting the count (so old data gets overwritten)
"""
self.replay_buffer.clear()
with self.lock:
# indexing with [0] is essential: assigning to self.count directly would replace the shared tensor with a plain int and break memory sharing (see the standalone sketch below)
self.count[0] = 0
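
The comment in empty_buffer is the crux of the shared-memory design: the counter must be mutated in place, never rebound. Below is a minimal standalone sketch (not part of this PR) illustrating the difference with torch.multiprocessing; the values and function names are purely illustrative.

import torch
import torch.multiprocessing as mp


def reset_in_place(count: torch.Tensor) -> None:
    # Writing through an index mutates the shared storage, so the
    # parent process observes the change.
    count[0] = 0


def rebind(count: torch.Tensor) -> None:
    # Rebinding the name creates a new, private tensor; the shared
    # tensor held by the parent process is left untouched.
    count = torch.tensor([0], dtype=torch.int64)


if __name__ == "__main__":
    count = torch.tensor([5], dtype=torch.int64)
    count.share_memory_()

    p = mp.Process(target=reset_in_place, args=(count,))
    p.start()
    p.join()
    print(count.item())  # 0 -- the in-place write is visible across processes

    count[0] = 5
    p = mp.Process(target=rebind, args=(count,))
    p.start()
    p.join()
    print(count.item())  # still 5 -- rebinding had no cross-process effect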


class RLDataset(IterableDataset):
@@ -88,8 +127,9 @@ class RLDataset(IterableDataset):
net (nn.Module): Policy network
replay_buffer: Replay buffer
"""

def __init__(self, replay_buffer: RolloutCollector, env: gym.Env, net: MLP,
agent) -> None:
agent, episodes_per_batch: int = 1) -> None:
"""Summary

Args:
@@ -102,6 +142,7 @@ def __init__(self, replay_buffer: RolloutCollector, env: gym.Env, net: MLP,
self.env = env
self.net = net
self.agent = agent
self.episodes_per_batch = episodes_per_batch

def populate(self) -> None:
"""
@@ -119,7 +160,7 @@ def __iter__(self):
Yields:
Tuple: Sampled experience
"""
for i in range(1):
for i in range(self.episodes_per_batch):
self.populate()
states, actions, rewards, dones, new_states = self.replay_buffer.sample(
)
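For orientation, here is a hedged usage sketch of the reworked RolloutCollector. The environment name, capacity, and random policy are illustrative only, and the classic Gym step API (obs, reward, done, info) is assumed, matching the gym import used in this file.

import gym

from squiRL.common.data_stream import Experience, RolloutCollector

env = gym.make("CartPole-v1")  # illustrative choice of environment
collector = RolloutCollector(
    capacity=500,                              # illustrative buffer size
    state_shape=env.observation_space.shape,   # (4,) for CartPole
    action_shape=env.action_space.shape,       # () for a Discrete action space
    should_share=True,                         # place the tensors in shared memory
)

state = env.reset()
done = False
while not done and len(collector) < collector.capacity:
    action = env.action_space.sample()         # random policy, for illustration
    next_state, reward, done, _ = env.step(action)
    collector.append(Experience(state, action, reward, done, next_state))
    state = next_state

states, actions, rewards, dones, next_states = collector.sample()
print(states.shape, rewards.shape)  # e.g. torch.Size([T, 4]) and torch.Size([T])

collector.empty_buffer()  # resets the count; old entries are overwritten later

Because should_share=True only shares the tensors, any process that appends or empties concurrently still has to go through the collector's lock, which append and empty_buffer already acquire.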
58 changes: 34 additions & 24 deletions squiRL/vpg/vpg.py
@@ -2,8 +2,10 @@
"""
import argparse
from argparse import ArgumentParser
from collections import OrderedDict
from copy import copy
from typing import Tuple, List

import gym
import numpy as np
import pytorch_lightning as pl
@@ -13,11 +15,10 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.utils.data._utils import collate
from collections import OrderedDict

from squiRL.common import reg_policies
from squiRL.common.data_stream import RLDataset, RolloutCollector
from squiRL.common.agents import Agent
from squiRL.common.data_stream import RLDataset, RolloutCollector


class VPG(pl.LightningModule):
@@ -44,11 +45,13 @@ def __init__(self, hparams: argparse.Namespace) -> None:
self.env = gym.make(self.hparams.env)
self.gamma = self.hparams.gamma
self.eps = self.hparams.eps
self.episodes_per_batch = self.hparams.episodes_per_batch
self.num_workers = hparams.num_workers
obs_size = self.env.observation_space.shape[0]
n_actions = self.env.action_space.n

self.net = reg_policies[self.hparams.policy](obs_size, n_actions)
self.replay_buffer = RolloutCollector(self.hparams.episode_length)
self.replay_buffer = RolloutCollector(self.hparams.episode_length, self.env.observation_space.shape, self.env.action_space.shape)

self.agent = Agent(self.env, self.replay_buffer)

@@ -84,6 +87,10 @@ def add_model_specific_args(
type=int,
default=20,
help="num of dataloader cpu workers")
parser.add_argument("--episodes_per_batch",
type=int,
default=1,
help="number of episodes to be sampled per training step")
return parser

def reward_to_go(self, rewards: torch.Tensor) -> torch.tensor:
@@ -104,8 +111,8 @@ def reward_to_go(self, rewards: torch.Tensor) -> torch.tensor:
res.append(copy(sum_r))
return list(reversed(res))

def vpg_loss(self, batch: Tuple[torch.Tensor,
torch.Tensor]) -> torch.Tensor:
def vpg_loss(self,
batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor:
"""
Calculates the loss based on the REINFORCE objective, using the
discounted
@@ -122,35 +129,42 @@ def vpg_loss(self, batch: Tuple[torch.Tensor,

action_logit = self.net(states.float())
log_probs = F.log_softmax(action_logit,
dim=-1).squeeze(0)[range(len(actions)),
actions]
dim=-1)[range(len(actions)),
actions.long()]



discounted_rewards = self.reward_to_go(rewards)
discounted_rewards = torch.tensor(discounted_rewards)
advantage = (discounted_rewards - discounted_rewards.mean()) / (
discounted_rewards.std() + self.eps)
discounted_rewards.std() + self.eps)
advantage = advantage.type_as(log_probs)

loss = -advantage * log_probs
return loss.sum()

def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor],
nb_batch) -> OrderedDict:
def training_step(self, batch: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
nb_batch) -> pl.TrainResult:
"""
Calculates the loss accumulated over one or more complete episodes sampled from env

Returns:
pl.TrainResult: Training step result

Args:
batch (Tuple[torch.Tensor, torch.Tensor]): Current mini batch of
replay data
batch (List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]): Current
mini batch of replay data
nb_batch (int): Current index of mini batch of replay data
"""
_, _, rewards, _, _ = batch
episode_reward = rewards.sum().detach()
loss = None
for episode in batch:
_, _, rewards, _, _ = episode
episode_reward = rewards.sum().detach()

loss = self.vpg_loss(batch)
if loss is None:
loss = self.vpg_loss(episode)
else:
loss += self.vpg_loss(episode)

if self.trainer.use_dp or self.trainer.use_ddp2:
loss = loss.unsqueeze(0)
@@ -191,25 +205,21 @@ def collate_fn(self, batch):
"""
batch = collate.default_convert(batch)

states = torch.cat([s[0] for s in batch])
actions = torch.cat([s[1] for s in batch])
rewards = torch.cat([s[2] for s in batch])
dones = torch.cat([s[3] for s in batch])
next_states = torch.cat([s[4] for s in batch])

return states, actions, rewards, dones, next_states
return batch

def __dataloader(self) -> DataLoader:
"""Initialize the RL dataset used for retrieving experiences

Returns:
DataLoader: Handles loading the data for training
"""
dataset = RLDataset(self.replay_buffer, self.env, self.net, self.agent)
dataset = RLDataset(self.replay_buffer, self.env, self.net, self.agent, self.episodes_per_batch)
dataloader = DataLoader(
dataset=dataset,
collate_fn=self.collate_fn,
batch_size=1,
batch_size=self.episodes_per_batch,
num_workers=self.num_workers,
pin_memory=True
)
return dataloader

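To make the new data flow in vpg.py concrete: collate_fn now returns the batch as a list of per-episode tuples, and training_step sums the REINFORCE loss over those episodes. Below is a minimal sketch of that loss computation; the gamma-discounted reward-to-go is an assumption (the diff only shows the tail of the repo's reward_to_go), and EPS and batch_loss are illustrative stand-ins for hparams.eps and the LightningModule method.

from typing import List, Tuple

import torch
import torch.nn.functional as F

EPS = 1e-8  # stand-in for hparams.eps


def reward_to_go(rewards: torch.Tensor, gamma: float = 0.99) -> torch.Tensor:
    # Discounted tail sums: G_t = r_t + gamma * G_{t+1}. The discounting
    # detail is assumed; only the end of the repo's implementation is shown above.
    res = []
    sum_r = 0.0
    for r in reversed(rewards.tolist()):
        sum_r = r + gamma * sum_r
        res.append(sum_r)
    return torch.tensor(list(reversed(res)))


def batch_loss(net: torch.nn.Module,
               episodes: List[Tuple[torch.Tensor, ...]]) -> torch.Tensor:
    # Mirrors the new training_step: accumulate the loss over every
    # episode in the list-valued batch produced by collate_fn.
    total = None
    for states, actions, rewards, dones, next_states in episodes:
        logits = net(states.float())
        log_probs = F.log_softmax(logits, dim=-1)[range(len(actions)),
                                                  actions.long()]
        returns = reward_to_go(rewards)
        advantage = (returns - returns.mean()) / (returns.std() + EPS)
        loss = -(advantage.type_as(log_probs) * log_probs).sum()
        total = loss if total is None else total + loss
    return total

For instance, with rewards [1, 1, 1] and gamma = 1 the reward-to-go vector is [3, 2, 1]; it is then standardized into an advantage before weighting the log-probabilities, as vpg_loss does above.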
2 changes: 1 addition & 1 deletion train.py
@@ -22,7 +22,7 @@ def main(hparams) -> None:
if hparams.debug:
hparams.logger = None
hparams.profiler = True
hparams.num_workers = None
hparams.num_workers = 0
else:
hparams.logger = WandbLogger(project=hparams.project)
seed_everything(hparams.seed)