Here is some minimal code that implements reinforcement learning (dynamic programming using value iteration) for a grid-world problem. Including comments, it is fewer than 100 lines of code.
With a few small changes, the code could be modified to use temporal-difference (TD) learning instead.
// -----------------------------------------------------------
// Minimal Reinforcement Learning in Scala
// -----------------------------------------------------------
// Problem: GridWorld with a goal terminal state
// Algorithm: Value Iteration
// Storage: Lookup table
// -----------------------------------------------------------
// Patryk Laurent (http://pakl.net)
// -----------------------------------------------------------
/**
 * Solves a deterministic GridWorld with value iteration, then prints the learned
 * value function and one example of the greedy policy derived from it.
 *
 * Every move costs -1 and the discount factor is 1, so after convergence `vf(s)` is
 * minus the number of moves on a shortest path from `s` to the goal.
 */
object SimpleRL extends App {
  case class State(x: Int, y: Int)
  case class Action(deltaX: Int, deltaY: Int)

  val maxX = 10; val maxY = 10
  // Every cell of the (maxX + 1) x (maxY + 1) grid.
  val world = for (i <- 0 to maxX; j <- 0 to maxY) yield State(i, j)
  val goal = State(0, 0)

  val north = Action(0, -1)
  val south = Action(0, 1)
  val east = Action(1, 0)
  val west = Action(-1, 0)
  val possibleActions = List(north, south, east, west)

  // Step size: the fraction of the Bellman error applied to a value per sweep
  // (despite the name, this is not an exploration rate).
  val epsilon = 0.8
  val discountFactor = 1.0
  val initialValue = 0.0
  // Upper bound on the number of sweeps; iteration stops earlier once a sweep changes nothing.
  val numSweeps = 10000

  var vf: Map[State, Double] = world.map(s => s -> initialValue).toMap
  var newvf: Map[State, Double] = world.map(s => s -> initialValue).toMap

  /** Reward for a move that lands on `arrivedAt`: a constant -1 for each move. */
  def rf(arrivedAt: State): Double = -1.0

  /** The cell reached by applying `a` in `s` (no bounds check; see [[isAllowed]]). */
  def nextState(s: State, a: Action): State = State(s.x + a.deltaX, s.y + a.deltaY)

  /** One-step lookahead value of taking `a` in `s`, judged by the given value table. */
  private def outcomeUnder(values: Map[State, Double], s: State, a: Action): Double = {
    val next = nextState(s, a)
    rf(next) + discountFactor * values(next)
  }

  /** One-step lookahead value of taking `a` in `s` under the current table `vf`. */
  def outcomeFor(s: State, a: Action): Double = outcomeUnder(vf, s, a)

  /** True only for the goal cell, whose value is never updated. */
  def isTerminal(s: State): Boolean = s == goal

  /** True when `a` keeps the agent on the grid; moves off the edge are disallowed. */
  def isAllowed(s: State, a: Action): Boolean = {
    val next = nextState(s, a)
    next.x >= 0 && next.x <= maxX && next.y >= 0 && next.y <= maxY
  }

  val nonTerminals = world.filterNot(isTerminal)

  /**
   * One synchronous Bellman backup: every non-terminal state moves `epsilon` of the
   * way toward its best one-step lookahead, all computed from the same `values`.
   */
  private def sweep(values: Map[State, Double]): Map[State, Double] =
    values ++ nonTerminals.map { s =>
      val best = possibleActions.filter(isAllowed(s, _)).map(outcomeUnder(values, s, _)).max
      s -> (values(s) + epsilon * (best - values(s)))
    }

  // -----------------------------------------------------------
  // Perform value iteration from each non-terminal state
  // -----------------------------------------------------------
  private def runValueIteration(): Unit = {
    var sweepsDone = 0
    var converged = false
    while (sweepsDone < numSweeps && !converged) {
      val updated = sweep(vf)
      // A sweep depends only on vf, so once it reproduces vf exactly every later
      // sweep would too: stopping here yields the same table as running on.
      converged = updated == vf
      newvf = vf
      vf = updated
      sweepsDone += 1
    }
  }
  runValueIteration()

  // -----------------------------------------------------------
  // Display value function
  // -----------------------------------------------------------
  println("Learned value function (values for each (x,y) in the grid):")
  for (x <- 0 to maxX) {
    // One printed line per x; each value is followed by a single space.
    println((0 to maxY).map(y => s"${vf(State(x, y))} ").mkString)
  }

  // -----------------------------------------------------------
  // Define how to make a policy from a value function.
  // -----------------------------------------------------------
  /**
   * Greedy policy with respect to a value table.
   *
   * @param world           states to build the policy for; terminal states are skipped
   * @param vf              value table used for the one-step lookahead
   * @param possibleActions candidate actions, filtered per state by [[isAllowed]]
   * @return the best allowed action for every non-terminal state in `world`; ties go
   *         to the action listed first in `possibleActions`
   */
  def policy(world: List[State], vf: Map[State, Double], possibleActions: List[Action]): Map[State, Action] =
    world.filterNot(isTerminal).map { s =>
      s -> possibleActions.filter(isAllowed(s, _)).maxBy(outcomeUnder(vf, s, _))
    }.toMap

  val optimalPolicy = policy(world.toList, vf, possibleActions)
  println(s"Policy Demo: 'From 5,5 the best action is ${optimalPolicy.get(State(5, 5)).fold("none")(_.toString)}'")
}