7 changes: 6 additions & 1 deletion .gitignore
@@ -2,4 +2,9 @@ __pycache__/
data/
report.html
models/tinyphysics_*.onnx
*.swp
venv/
comma_steer_command/
PPO_control/checkpoints
PPO_control/results
PPO_control/training_logs
*.swp
109 changes: 109 additions & 0 deletions PPO_control/ActorCritic.py
@@ -0,0 +1,109 @@
import logging
from typing import Tuple, List, Union

import torch
import torch.nn as nn
from torch.distributions import Normal
import numpy as np

logger = logging.getLogger(__name__)

class ActorCritic(nn.Module):
def __init__(self, obs_dim: int, obs_seq_len: int, target_dim: int, action_dim: int, has_continuous_action: bool, action_scale: float=1):
super().__init__()
self.obs_dim = obs_dim
self.obs_seq_len = obs_seq_len
self.target_dim = target_dim
self.action_dim = action_dim
self.has_continuous_action = has_continuous_action
self.action_scale = action_scale
self.rng = np.random.default_rng()

self.feature_lstm = nn.LSTM(input_size=self.obs_dim, hidden_size=16, num_layers=2, batch_first=True)

if self.has_continuous_action:
self.actor = nn.Sequential(
nn.Linear(16 + self.target_dim, 128),
nn.Tanh(),
nn.Linear(128, 128),
nn.Tanh(),
nn.Linear(128, 128),
nn.Tanh(),
nn.Linear(128, self.action_dim),
nn.Tanh()
)
self.log_std = nn.Parameter(torch.zeros((self.action_dim), dtype=torch.float32))
else:
self.actor = nn.Sequential(
nn.Linear(16 + self.target_dim, 128),
nn.Tanh(),
nn.Linear(128, 128),
nn.Tanh(),
nn.Linear(128, 128),
nn.Tanh(),
nn.Linear(128, self.action_dim),
nn.Softmax(dim=-1)
)

self.critic = nn.Sequential(
nn.Linear(16 + self.target_dim, 128),
nn.Tanh(),
nn.Linear(128, 128),
nn.Tanh(),
nn.Linear(128, 128),
nn.Tanh(),
nn.Linear(128, 1)
)

        # Orthogonal init: LSTM weights with gain 1, hidden layer weights with gain sqrt(2), all biases to 0.
        # The output heads are then re-initialized: critic head weight with gain 1, actor head weight with gain 0.01.
for param in self.named_parameters():
if param[0].endswith('weight'):
if param[0].startswith('feature_lstm'):
nn.init.orthogonal_(param[1], gain=1)
else:
nn.init.orthogonal_(param[1], gain=np.sqrt(2))
elif param[0].endswith('bias'):
nn.init.zeros_(param[1])
nn.init.orthogonal_(list(self.critic.parameters())[-2], gain=1)
nn.init.orthogonal_(list(self.actor.parameters())[-2], gain=0.01)

def forward(self, past_obs, target):
condition, _ = self.feature_lstm(past_obs)
condition = condition[:, -1]
x = torch.hstack([condition, target])
action_logit = self.actor(x)
value = self.critic(x)

if self.has_continuous_action:
return Normal(action_logit.flatten(), torch.exp(self.log_std)), value
else:
return action_logit, value

def act(self, past_obs: torch.Tensor, target: torch.Tensor, eval: bool=False) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
with torch.no_grad():
self.eval()
action_logit, value = self.forward(past_obs, target)

if self.has_continuous_action:
# Continuous case
if eval:
action = action_logit.mean
action_prob = torch.exp(action_logit.log_prob(action)).cpu().numpy()
action = action.cpu().numpy() * self.action_scale
else:
action = action_logit.sample()
action_prob = torch.exp(action_logit.log_prob(action)).cpu().numpy()
action = action.cpu().numpy() * self.action_scale
else:
# Discrete case
action_logit = action_logit.cpu().numpy()
if eval:
action = action_logit.argmax(axis=1)
action_prob = action_logit[np.arange(len(action)), action]
else:
action = (action_logit.cumsum(axis=1) > self.rng.random(action_logit.shape[0])[:, np.newaxis]).argmax(axis=1) # Inverse transform sampling
action_prob = action_logit[np.arange(len(action)), action]

return action, action_prob, value.flatten().cpu().numpy()

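A minimal sketch of how ActorCritic might be driven at rollout time. The dimensions, the action_scale value, and the assumption that the script runs from inside PPO_control/ are illustrative, not taken from this PR.

import torch
from ActorCritic import ActorCritic

# One continuous action, conditioned on a 20-step observation history (illustrative sizes).
model = ActorCritic(obs_dim=4, obs_seq_len=20, target_dim=3, action_dim=1,
                    has_continuous_action=True, action_scale=2.0)

past_obs = torch.randn(8, 20, 4)  # (batch, obs_seq_len, obs_dim), fed to the feature LSTM
target = torch.randn(8, 3)        # (batch, target_dim), concatenated with the LSTM feature

# Stochastic rollout action; eval=True would return the distribution mean instead.
action, action_prob, value = model.act(past_obs, target, eval=False)
print(action.shape, action_prob.shape, value.shape)  # (8,), (8,), (8,)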
70 changes: 70 additions & 0 deletions PPO_control/ExperienceBuffer.py
@@ -0,0 +1,70 @@
import logging
from typing import Union, List, Tuple

import torch
from torch.utils.data import Dataset
import numpy as np

from PPO_Loss import calculate_value_target_vec, generalized_advantage_estimation_vec

logger = logging.getLogger(__name__)

class PPOExperienceBuffer(Dataset):
def __init__(self, discount_factor: float=0.99, td_decay: float=0.9) -> None:
super().__init__()
self.actions = []
self.observations = []
self.targets = []
self.samples = torch.zeros((0, 4))
self.discount_factor = discount_factor
self.td_decay = td_decay

def __len__(self) -> int:
return self.samples.shape[0]

def __getitem__(self, idx) -> Tuple[Union[int, float], torch.Tensor, torch.Tensor, torch.Tensor]:
return self.actions[idx], self.observations[idx], self.targets[idx], self.samples[idx]

def batch_add_trajectory(self, observation: torch.Tensor, target: torch.Tensor, action: torch.Tensor, action_p: torch.Tensor, reward: torch.Tensor, value_estimate: torch.Tensor) -> None:
"""
Add a batch of trajectories to the buffer.

Args:
            observation (batch_size, episode_len, ...): Batch of per-step observations (trailing dims are the observation features)
            target (batch_size, episode_len, ...): Batch of per-step targets (trailing dims are the target features)
action (batch_size, episode_len): Batch actions taken during the episode.
action_p (batch_size, episode_len): Batch action probabilities corresponding to each action.
reward (batch_size, episode_len): Batch rewards received for each step.
value_estimate (batch_size, episode_len+1): Batch value estimates, including the estimate for the final state.

Raises:
ValueError: If input lengths are inconsistent with the episode length.

Note:
The length of value_estimate should be one more than the other inputs to include the final state estimate.
"""

episode_len = action.shape[1]

if action_p.shape[1] != episode_len:
raise ValueError(f'action_p length mismatch. episode length: {episode_len}, action_p length: {action_p.shape[1]}')
if reward.shape[1] != episode_len:
raise ValueError(f'reward length mismatch. episode length: {episode_len}, reward length: {reward.shape[1]}')
if value_estimate.shape[1] != episode_len + 1:
            raise ValueError(f'value_estimate length mismatch. expected length: {episode_len + 1}, actual length: {value_estimate.shape[1]}')

self.actions += action.flatten().tolist()
self.observations += [x for x in observation.flatten(0, 1)]
self.targets += [x for x in target.flatten(0, 1)]

value_target = calculate_value_target_vec(reward, value_estimate, self.discount_factor)
gae = generalized_advantage_estimation_vec(reward, value_estimate, self.discount_factor, self.td_decay)

batch_trajectory = torch.cat([action_p.unsqueeze(2), value_estimate[:, :-1].unsqueeze(2), value_target.unsqueeze(2), gae.unsqueeze(2)], dim=2)
self.samples = torch.vstack([self.samples, batch_trajectory.flatten(0, 1)])

def reset(self):
self.actions = []
self.observations = []
self.targets = []
self.samples = torch.zeros((0, 4))
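
A minimal sketch of filling and reading the buffer; the batch size, episode length, and feature shapes below are illustrative assumptions about what the rollout loop produces.

import torch
from ExperienceBuffer import PPOExperienceBuffer

batch, T, seq_len, obs_dim, target_dim = 4, 16, 20, 4, 3
buffer = PPOExperienceBuffer(discount_factor=0.99, td_decay=0.9)

buffer.batch_add_trajectory(
    observation=torch.randn(batch, T, seq_len, obs_dim),   # per-step observation history
    target=torch.randn(batch, T, target_dim),
    action=torch.randn(batch, T),
    action_p=torch.rand(batch, T),             # probability of each taken action under the old policy
    reward=torch.randn(batch, T),
    value_estimate=torch.randn(batch, T + 1),  # one extra entry: bootstrap value of the final state
)

print(len(buffer))  # batch * T = 64
action, obs, tgt, sample = buffer[0]
# sample holds (old action probability, value estimate, value target, GAE advantage)

Because PPOExperienceBuffer is a standard torch Dataset, it can be wrapped in a DataLoader to draw shuffled minibatches during the PPO update epochs.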
92 changes: 92 additions & 0 deletions PPO_control/PPO_Loss.py
@@ -0,0 +1,92 @@
from typing import Tuple

import torch

def actor_loss(action_p: torch.Tensor, action_p_old: torch.Tensor, advantage: torch.Tensor, clip_eps: float) -> torch.Tensor:
eps = 1e-8
r = action_p / (action_p_old + eps)
L_clip = torch.mean(torch.min(r * advantage, torch.clip(r, 1 - clip_eps, 1 + clip_eps) * advantage))
return -L_clip

def value_loss(value: torch.Tensor, value_target: torch.Tensor) -> torch.Tensor:
L_value = torch.mean((value - value_target) ** 2)
return L_value

def generalized_advantage_estimation(reward: torch.Tensor, value: torch.Tensor, discount_factor: float, decay: float) -> torch.Tensor:
'''
reward: batched rewards for each time step (batch_size, T)
value: batched value estimation for each time step plus last state (batch_size, T+1)
'''
eps = 1e-8

advantage = torch.zeros_like(reward)
last_adv = 0
for t in reversed(range(reward.shape[1])):
delta = reward[:, t] + discount_factor * value[:, t + 1] - value[:, t]
advantage[:, t] = delta + discount_factor * decay * last_adv
last_adv = advantage[:, t]

# Normalize across batch
advantage = (advantage - advantage.mean()) / (advantage.std() + eps)

return advantage

def calculate_value_target(reward: torch.Tensor, value: torch.Tensor, discount_factor: float) -> torch.Tensor:
'''
reward: batched rewards for each time step (batch_size, T)
value: batched value estimation for each time step plus last state (batch_size, T+1)
'''
eps = 1e-8

v_target = torch.zeros_like(reward)
last_reward = value[:, -1]
for t in reversed(range(reward.shape[1])):
v_target[:, t] = reward[:, t] + discount_factor * last_reward
last_reward = v_target[:, t]

v_target = (v_target - v_target.mean()) / (v_target.std() + eps)

return v_target

def calculate_value_target_vec(reward: torch.Tensor, value: torch.Tensor, discount_factor: float) -> torch.Tensor:
'''
reward: batched rewards for each time step (batch_size, T)
value: batched value estimation for each time step plus last state (batch_size, T+1)
'''
T = reward.shape[1]
eps = 1e-8

# Calculate discounted sum of rewards
discount_factors = discount_factor ** torch.arange(T, dtype=torch.float32, device=reward.device)
future_discounted_rewards = torch.cumsum((reward * discount_factors.unsqueeze(0)).flip(1), dim=1).flip(1)
v_target = future_discounted_rewards / discount_factors

# Add the discounted final value estimate
v_target = v_target + torch.outer(value[:, -1], (discount_factor * discount_factors).flip(0))

v_target = (v_target - v_target.mean()) / (v_target.std() + eps)

return v_target

def generalized_advantage_estimation_vec(reward: torch.Tensor, value: torch.Tensor, discount_factor: float, decay: float) -> torch.Tensor:
'''
reward: batched rewards for each time step (batch_size, T)
value: batched value estimation for each time step plus last state (batch_size, T+1)
'''
T = reward.shape[1]
eps = 1e-8

# Calculate delta
delta = reward + discount_factor * value[:, 1:] - value[:, :-1]

# Calculate discount factors
discount_factors = (discount_factor * decay) ** torch.arange(T, dtype=torch.float32, device=reward.device)

# Calculate advantages
advantage = torch.cumsum((delta * discount_factors.unsqueeze(0)).flip(1), dim=1).flip(1)
advantage = advantage / discount_factors.unsqueeze(0)

# Normalize across batch
advantage = (advantage - advantage.mean()) / (advantage.std() + eps)

return advantage
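
A sketch of how these losses would typically be combined with the buffer and the actor-critic network in one PPO update epoch. The clip epsilon, value-loss weight, batch size, optimizer, and the assumption that the buffer stores actions after action_scale was applied in act() are illustrative, not taken from the training script in this PR.

import torch
from torch.utils.data import DataLoader

from ActorCritic import ActorCritic
from ExperienceBuffer import PPOExperienceBuffer
from PPO_Loss import actor_loss, value_loss

def ppo_update(model: ActorCritic, buffer: PPOExperienceBuffer,
               optimizer: torch.optim.Optimizer,
               clip_eps: float = 0.2, value_coef: float = 0.5, batch_size: int = 64) -> None:
    model.train()
    loader = DataLoader(buffer, batch_size=batch_size, shuffle=True)
    for action, obs, target, sample in loader:
        # sample columns: old action probability, value estimate, value target, GAE advantage
        action_p_old, _, value_target, advantage = sample.unbind(dim=1)

        dist, value = model(obs, target)  # continuous case: dist is a Normal over unscaled actions
        # Assumes the buffer stored the scaled action returned by act(); undo the scaling here.
        action_p = torch.exp(dist.log_prob(action.float() / model.action_scale))

        loss = (actor_loss(action_p, action_p_old, advantage, clip_eps)
                + value_coef * value_loss(value.flatten(), value_target))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Gradient clipping and an entropy bonus are commonly added to this objective in PPO implementations; they are omitted here for brevity.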