Add Policy Gradient to Reinforcement Learning examples (#1500)

* added policy_gradient, modified main, ddpg and README

* fixed typo in README

* removed unnecessary imports

* small refactor

* Use clap for picking up the subcommand to run.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
s-casci 2023-12-30 09:01:29 +01:00 committed by GitHub
parent 0a245e6fa4
commit 51e577a682
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 275 additions and 124 deletions

View File

@ -8,9 +8,16 @@ Python package with:
pip install "gymnasium[accept-rom-license]"
```
In order to run the example, use the following command. Note the additional
In order to run the examples, use the following commands. Note the additional
`--package` flag to ensure that there is no conflict with the `candle-pyo3`
crate.
For the Policy Gradient example:
```bash
cargo run --example reinforcement-learning --features=pyo3 --package candle-examples
cargo run --example reinforcement-learning --features=pyo3 --package candle-examples -- pg
```
For the Deep Deterministic Policy Gradient example:
```bash
cargo run --example reinforcement-learning --features=pyo3 --package candle-examples -- ddpg
```

View File

@ -8,6 +8,8 @@ use candle_nn::{
};
use rand::{distributions::Uniform, thread_rng, Rng};
use super::gym_env::GymEnv;
pub struct OuNoise {
mu: f64,
theta: f64,
@ -449,3 +451,106 @@ impl DDPG<'_> {
Ok(())
}
}
// The impact of the q value of the next state on the current state's q value.
const GAMMA: f64 = 0.99;
// The weight for updating the target networks.
const TAU: f64 = 0.005;
// The capacity of the replay buffer used for sampling training data.
const REPLAY_BUFFER_CAPACITY: usize = 100_000;
// The training batch size for each training iteration.
const TRAINING_BATCH_SIZE: usize = 100;
// The total number of episodes.
const MAX_EPISODES: usize = 100;
// The maximum length of an episode.
const EPISODE_LENGTH: usize = 200;
// The number of training iterations after one episode finishes.
const TRAINING_ITERATIONS: usize = 200;
// Ornstein-Uhlenbeck process parameters.
const MU: f64 = 0.0;
const THETA: f64 = 0.15;
const SIGMA: f64 = 0.1;
const ACTOR_LEARNING_RATE: f64 = 1e-4;
const CRITIC_LEARNING_RATE: f64 = 1e-3;
pub fn run() -> Result<()> {
let env = GymEnv::new("Pendulum-v1")?;
println!("action space: {}", env.action_space());
println!("observation space: {:?}", env.observation_space());
let size_state = env.observation_space().iter().product::<usize>();
let size_action = env.action_space();
let mut agent = DDPG::new(
&Device::Cpu,
size_state,
size_action,
true,
ACTOR_LEARNING_RATE,
CRITIC_LEARNING_RATE,
GAMMA,
TAU,
REPLAY_BUFFER_CAPACITY,
OuNoise::new(MU, THETA, SIGMA, size_action)?,
)?;
let mut rng = rand::thread_rng();
for episode in 0..MAX_EPISODES {
// let mut state = env.reset(episode as u64)?;
let mut state = env.reset(rng.gen::<u64>())?;
let mut total_reward = 0.0;
for _ in 0..EPISODE_LENGTH {
let mut action = 2.0 * agent.actions(&state)?;
action = action.clamp(-2.0, 2.0);
let step = env.step(vec![action])?;
total_reward += step.reward;
agent.remember(
&state,
&Tensor::new(vec![action], &Device::Cpu)?,
&Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
&step.state,
step.terminated,
step.truncated,
);
if step.terminated || step.truncated {
break;
}
state = step.state;
}
println!("episode {episode} with total reward of {total_reward}");
for _ in 0..TRAINING_ITERATIONS {
agent.train(TRAINING_BATCH_SIZE)?;
}
}
println!("Testing...");
agent.train = false;
for episode in 0..10 {
// let mut state = env.reset(episode as u64)?;
let mut state = env.reset(rng.gen::<u64>())?;
let mut total_reward = 0.0;
for _ in 0..EPISODE_LENGTH {
let mut action = 2.0 * agent.actions(&state)?;
action = action.clamp(-2.0, 2.0);
let step = env.step(vec![action])?;
total_reward += step.reward;
if step.terminated || step.truncated {
break;
}
state = step.state;
}
println!("episode {episode} with total reward of {total_reward}");
}
Ok(())
}

View File

@ -6,139 +6,32 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::Result;
use clap::{Parser, Subcommand};
mod gym_env;
mod vec_gym_env;
mod ddpg;
mod policy_gradient;
use candle::{Device, Result, Tensor};
use clap::Parser;
use rand::Rng;
// The impact of the q value of the next state on the current state's q value.
const GAMMA: f64 = 0.99;
// The weight for updating the target networks.
const TAU: f64 = 0.005;
// The capacity of the replay buffer used for sampling training data.
const REPLAY_BUFFER_CAPACITY: usize = 100_000;
// The training batch size for each training iteration.
const TRAINING_BATCH_SIZE: usize = 100;
// The total number of episodes.
const MAX_EPISODES: usize = 100;
// The maximum length of an episode.
const EPISODE_LENGTH: usize = 200;
// The number of training iterations after one episode finishes.
const TRAINING_ITERATIONS: usize = 200;
// Ornstein-Uhlenbeck process parameters.
const MU: f64 = 0.0;
const THETA: f64 = 0.15;
const SIGMA: f64 = 0.1;
const ACTOR_LEARNING_RATE: f64 = 1e-4;
const CRITIC_LEARNING_RATE: f64 = 1e-3;
#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
#[derive(Parser)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[command(subcommand)]
command: Command,
}
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[derive(Subcommand)]
enum Command {
Pg,
Ddpg,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let env = gym_env::GymEnv::new("Pendulum-v1")?;
println!("action space: {}", env.action_space());
println!("observation space: {:?}", env.observation_space());
let size_state = env.observation_space().iter().product::<usize>();
let size_action = env.action_space();
let mut agent = ddpg::DDPG::new(
&Device::Cpu,
size_state,
size_action,
true,
ACTOR_LEARNING_RATE,
CRITIC_LEARNING_RATE,
GAMMA,
TAU,
REPLAY_BUFFER_CAPACITY,
ddpg::OuNoise::new(MU, THETA, SIGMA, size_action)?,
)?;
let mut rng = rand::thread_rng();
for episode in 0..MAX_EPISODES {
// let mut state = env.reset(episode as u64)?;
let mut state = env.reset(rng.gen::<u64>())?;
let mut total_reward = 0.0;
for _ in 0..EPISODE_LENGTH {
let mut action = 2.0 * agent.actions(&state)?;
action = action.clamp(-2.0, 2.0);
let step = env.step(vec![action])?;
total_reward += step.reward;
agent.remember(
&state,
&Tensor::new(vec![action], &Device::Cpu)?,
&Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
&step.state,
step.terminated,
step.truncated,
);
if step.terminated || step.truncated {
break;
}
state = step.state;
}
println!("episode {episode} with total reward of {total_reward}");
for _ in 0..TRAINING_ITERATIONS {
agent.train(TRAINING_BATCH_SIZE)?;
}
}
println!("Testing...");
agent.train = false;
for episode in 0..10 {
// let mut state = env.reset(episode as u64)?;
let mut state = env.reset(rng.gen::<u64>())?;
let mut total_reward = 0.0;
for _ in 0..EPISODE_LENGTH {
let mut action = 2.0 * agent.actions(&state)?;
action = action.clamp(-2.0, 2.0);
let step = env.step(vec![action])?;
total_reward += step.reward;
if step.terminated || step.truncated {
break;
}
state = step.state;
}
println!("episode {episode} with total reward of {total_reward}");
match args.command {
Command::Pg => policy_gradient::run()?,
Command::Ddpg => ddpg::run()?,
}
Ok(())
}

View File

@ -0,0 +1,146 @@
use super::gym_env::{GymEnv, Step};
use candle::{DType, Device, Error, Module, Result, Tensor};
use candle_nn::{
linear, ops::log_softmax, ops::softmax, sequential::seq, Activation, AdamW, Optimizer,
ParamsAdamW, VarBuilder, VarMap,
};
use rand::{distributions::Distribution, rngs::ThreadRng, Rng};
fn new_model(
input_shape: &[usize],
num_actions: usize,
dtype: DType,
device: &Device,
) -> Result<(impl Module, VarMap)> {
let input_size = input_shape.iter().product();
let mut varmap = VarMap::new();
let var_builder = VarBuilder::from_varmap(&varmap, dtype, device);
let model = seq()
.add(linear(input_size, 32, var_builder.pp("lin1"))?)
.add(Activation::Relu)
.add(linear(32, num_actions, var_builder.pp("lin2"))?);
Ok((model, varmap))
}
fn accumulate_rewards(steps: &[Step<i64>]) -> Vec<f64> {
let mut rewards: Vec<f64> = steps.iter().map(|s| s.reward).collect();
let mut acc_reward = 0f64;
for (i, reward) in rewards.iter_mut().enumerate().rev() {
if steps[i].terminated {
acc_reward = 0.0;
}
acc_reward += *reward;
*reward = acc_reward;
}
rewards
}
fn weighted_sample(probs: Vec<f32>, rng: &mut ThreadRng) -> Result<usize> {
let distribution = rand::distributions::WeightedIndex::new(probs).map_err(Error::wrap)?;
let mut rng = rng;
Ok(distribution.sample(&mut rng))
}
pub fn run() -> Result<()> {
let env = GymEnv::new("CartPole-v1")?;
println!("action space: {:?}", env.action_space());
println!("observation space: {:?}", env.observation_space());
let (model, varmap) = new_model(
env.observation_space(),
env.action_space(),
DType::F32,
&Device::Cpu,
)?;
let optimizer_params = ParamsAdamW {
lr: 0.01,
weight_decay: 0.01,
..Default::default()
};
let mut optimizer = AdamW::new(varmap.all_vars(), optimizer_params)?;
let mut rng = rand::thread_rng();
for epoch_idx in 0..100 {
let mut state = env.reset(rng.gen::<u64>())?;
let mut steps: Vec<Step<i64>> = vec![];
loop {
let action = {
let action_probs: Vec<f32> =
softmax(&model.forward(&state.detach()?.unsqueeze(0)?)?, 1)?
.squeeze(0)?
.to_vec1()?;
weighted_sample(action_probs, &mut rng)? as i64
};
let step = env.step(action)?;
steps.push(step.copy_with_obs(&state));
if step.terminated || step.truncated {
state = env.reset(rng.gen::<u64>())?;
if steps.len() > 5000 {
break;
}
} else {
state = step.state;
}
}
let total_reward: f64 = steps.iter().map(|s| s.reward).sum();
let episodes: i64 = steps
.iter()
.map(|s| (s.terminated || s.truncated) as i64)
.sum();
println!(
"epoch: {:<3} episodes: {:<5} avg reward per episode: {:.2}",
epoch_idx,
episodes,
total_reward / episodes as f64
);
let batch_size = steps.len();
let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)?
.to_dtype(DType::F32)?
.detach()?;
let actions_mask = {
let actions: Vec<i64> = steps.iter().map(|s| s.action).collect();
let actions_mask: Vec<Tensor> = actions
.iter()
.map(|&action| {
// One-hot encoding
let mut action_mask = vec![0.0; env.action_space()];
action_mask[action as usize] = 1.0;
Tensor::from_vec(action_mask, env.action_space(), &Device::Cpu)
.unwrap()
.to_dtype(DType::F32)
.unwrap()
})
.collect();
Tensor::stack(&actions_mask, 0)?.detach()?
};
let states = {
let states: Vec<Tensor> = steps.into_iter().map(|s| s.state).collect();
Tensor::stack(&states, 0)?.detach()?
};
let log_probs = actions_mask
.mul(&log_softmax(&model.forward(&states)?, 1)?)?
.sum(1)?;
let loss = rewards.mul(&log_probs)?.neg()?.mean_all()?;
optimizer.backward_step(&loss)?;
}
Ok(())
}