Julia implementations of temporal difference Reinforcement Learning algorithms like Q-Learning and SARSA

JuliaPOMDP/TabularTDLearning.jl

TabularTDLearning

This repository provides Julia implementations of the following Temporal-Difference reinforcement learning algorithms:

  • Q-Learning
  • SARSA
  • SARSA lambda
  • Prioritized Sweeping

Note that these solvers are tabular: they keep one value estimate per state-action pair, so they only work with MDPs that have discrete state and action spaces.
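
To make the "tabular" idea concrete, here is a minimal sketch in plain Julia of the one-step Q-Learning and SARSA updates on a Q-table. It is an illustration only and is not the code used by this package: the 5-state chain environment, `step_chain`, `eps_greedy`, and `td_learn` are all hypothetical names made up for this example.

```julia
using Random

# Hypothetical toy problem (not part of TabularTDLearning):
# a chain of 5 states, actions 1 (left) and 2 (right).
# Entering state 5 yields reward 1 and ends the episode.
const N_STATES = 5
const N_ACTIONS = 2
const GOAL = 5

function step_chain(s, a)
    sp = a == 2 ? min(s + 1, N_STATES) : max(s - 1, 1)
    r = sp == GOAL ? 1.0 : 0.0
    return sp, r, sp == GOAL
end

# Epsilon-greedy selection over one row of the Q-table, with random tie-breaking.
function eps_greedy(rng, Q, s, eps)
    if rand(rng) < eps
        return rand(rng, 1:N_ACTIONS)
    end
    row = Q[s, :]
    best = findall(==(maximum(row)), row)
    return rand(rng, best)
end

# One table entry per (state, action) pair; this is why discrete spaces are required.
function td_learn(; sarsa=false, alpha=0.1, gamma=0.95, eps=0.1,
                    n_episodes=500, max_steps=50, rng=MersenneTwister(1))
    Q = zeros(N_STATES, N_ACTIONS)
    for _ in 1:n_episodes
        s = 1
        a = eps_greedy(rng, Q, s, eps)
        for _ in 1:max_steps
            sp, r, done = step_chain(s, a)
            ap = eps_greedy(rng, Q, sp, eps)
            # SARSA bootstraps from the action actually selected next;
            # Q-Learning bootstraps from the greedy action.
            next_value = done ? 0.0 : (sarsa ? Q[sp, ap] : maximum(Q[sp, :]))
            Q[s, a] += alpha * (r + gamma * next_value - Q[s, a])
            done && break
            s, a = sp, ap
        end
    end
    return Q
end

Q_q = td_learn(sarsa=false)  # Q-Learning
Q_s = td_learn(sarsa=true)   # SARSA

# Greedy action in each non-terminal state
greedy(Q) = [argmax(Q[s, :]) for s in 1:N_STATES-1]
println(greedy(Q_q))
```

The only difference between the two algorithms in this sketch is the `next_value` line. The solvers in this package expose the corresponding hyperparameters through keyword arguments such as `learning_rate` and `exploration_policy`, as shown in the example further down.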

Installation

using Pkg
Pkg.add("TabularTDLearning")

Example

using POMDPs
using TabularTDLearning
using POMDPModels
using POMDPTools

mdp = SimpleGridWorld()
# Use Q-Learning
exppolicy = EpsGreedyPolicy(mdp, 0.01)
solver = QLearningSolver(exploration_policy=exppolicy, learning_rate=0.1, n_episodes=5000, max_episode_length=50, eval_every=50, n_eval_traj=100)
policy = solve(solver, mdp)
# Use SARSA
solver = SARSASolver(exploration_policy=exppolicy, learning_rate=0.1, n_episodes=5000, max_episode_length=50, eval_every=50, n_eval_traj=100)
policy = solve(solver, mdp)
# Use SARSA lambda
solver = SARSALambdaSolver(exploration_policy=exppolicy, learning_rate=0.1, lambda=0.9, n_episodes=5000, max_episode_length=50, eval_every=50, n_eval_traj=100)
policy = solve(solver, mdp)
# Use Prioritized Sweeping (this example uses a deterministic grid world, tprob=1.0)
mdp_ps = SimpleGridWorld(tprob=1.0)
solver = PrioritizedSweepingSolver(exploration_policy=exppolicy, learning_rate=0.1, n_episodes=5000, max_episode_length=50, eval_every=50, n_eval_traj=100, pq_threshold=0.5)
policy = solve(solver, mdp_ps)
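
Once a solver returns a policy, it can be queried and evaluated with the generic POMDPs.jl interface. The sketch below assumes the `action`, `initialstate`, and `simulate` functions from POMDPs.jl and `RolloutSimulator` from POMDPTools; the `max_steps=50` value is simply chosen here to match `max_episode_length` above.

```julia
using POMDPs
using POMDPModels
using POMDPTools
using TabularTDLearning

mdp = SimpleGridWorld()
exppolicy = EpsGreedyPolicy(mdp, 0.01)
solver = QLearningSolver(exploration_policy=exppolicy, learning_rate=0.1, n_episodes=5000, max_episode_length=50, eval_every=50, n_eval_traj=100)
policy = solve(solver, mdp)

# Query the learned policy at a state sampled from the initial state distribution
s = rand(initialstate(mdp))
a = action(policy, s)
println("state $s -> action $a")

# Estimate the discounted return of a single rollout (assumed horizon: 50 steps)
sim = RolloutSimulator(max_steps=50)
ret = simulate(sim, mdp, policy)
println("discounted return of one rollout: $ret")
```

A single rollout is a noisy estimate; averaging `simulate` over several runs gives a more stable picture of policy quality.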
