Reinforcement Learning on the Hamster Environment

In [1]:
import gym
import gym.spaces
import time
import numpy as np
from collections import defaultdict, deque
import sys
# Create the custom Hamster grid environment used throughout this notebook.
# NOTE(review): 'HamsterExperiment-v0' is not a stock gym id; it is presumably
# registered by the locally installed gym package — TODO confirm.
env = gym.make('HamsterExperiment-v0')
/Users/admin/miniconda2/lib/python2.7/site-packages/scipy/special/__init__.py:640: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
  from ._ufuncs import *
/Users/admin/miniconda2/lib/python2.7/site-packages/scipy/linalg/basic.py:17: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
  from ._solve_toeplitz import levinson
/Users/admin/miniconda2/lib/python2.7/site-packages/scipy/linalg/__init__.py:207: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
  from ._decomp_update import *
/Users/admin/miniconda2/lib/python2.7/site-packages/scipy/special/_ellip_harm.py:7: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
  from ._ellip_harm_2 import _ellipsoid, _ellipsoid_norm
/Users/admin/miniconda2/lib/python2.7/site-packages/scipy/interpolate/_bsplines.py:10: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
  from . import _bspl
/Users/admin/miniconda2/lib/python2.7/site-packages/scipy/sparse/lil.py:19: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
  from . import _csparsetools
/Users/admin/miniconda2/lib/python2.7/site-packages/scipy/sparse/csgraph/__init__.py:165: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
  from ._shortest_path import shortest_path, floyd_warshall, dijkstra,\
/Users/admin/miniconda2/lib/python2.7/site-packages/scipy/sparse/csgraph/_validation.py:5: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
  from ._tools import csgraph_to_dense, csgraph_from_dense,\
/Users/admin/miniconda2/lib/python2.7/site-packages/scipy/sparse/csgraph/__init__.py:167: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
  from ._traversal import breadth_first_order, depth_first_order, \
/Users/admin/miniconda2/lib/python2.7/site-packages/scipy/sparse/csgraph/__init__.py:169: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
  from ._min_spanning_tree import minimum_spanning_tree
/Users/admin/miniconda2/lib/python2.7/site-packages/scipy/sparse/csgraph/__init__.py:170: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
  from ._reordering import reverse_cuthill_mckee, maximum_bipartite_matching, \
/Users/admin/miniconda2/lib/python2.7/site-packages/scipy/spatial/__init__.py:95: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
  from .ckdtree import *
/Users/admin/miniconda2/lib/python2.7/site-packages/scipy/spatial/__init__.py:96: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
  from .qhull import *
/Users/admin/miniconda2/lib/python2.7/site-packages/scipy/spatial/_spherical_voronoi.py:18: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
  from . import _voronoi
/Users/admin/miniconda2/lib/python2.7/site-packages/scipy/spatial/distance.py:122: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
  from . import _hausdorff
/Users/admin/miniconda2/lib/python2.7/site-packages/scipy/optimize/_trlib/__init__.py:1: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
  from ._trlib import TRLIBQuadraticSubproblem
/Users/admin/miniconda2/lib/python2.7/site-packages/scipy/optimize/_numdiff.py:10: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
  from ._group_columns import group_dense, group_sparse
/Users/admin/miniconda2/lib/python2.7/site-packages/scipy/stats/_continuous_distns.py:18: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
  from . import _stats

SARSA Algorithm

In [2]:
from IPython.display import clear_output
def sarsa(env, num_episodes, alpha, gamma=1.0):
    """Estimate action values with on-policy TD control (SARSA).

    Each episode acts epsilon-greedily with respect to the current estimate,
    with epsilon = 1 / (1 + episode index) so exploration decays over time.
    The first episode (up to 1000 steps) and the episodes after
    ``num_episodes - 10`` are rendered so training can be watched.

    Parameters
    ----------
    env : gym-style environment
        Must provide ``nA`` (number of discrete actions), ``reset()``,
        ``render()`` and ``step(action) -> (next_state, reward, done, info)``.
        States are used as dictionary keys, so they must be hashable.
    num_episodes : int
        Number of training episodes.
    alpha : float
        Step size of the TD update.
    gamma : float, optional
        Discount factor (default 1.0, i.e. undiscounted).

    Returns
    -------
    collections.defaultdict
        state -> numpy array of ``env.nA`` action-value estimates.
    """
    # Unseen states start with all-zero action values.
    Q = defaultdict(lambda: np.zeros(env.nA))
    actions = np.arange(env.nA)
    # Number of steps of the first episode that are rendered as a demo.
    demo_steps_left = 1000

    for i_episode in range(1, num_episodes + 1):
        # Progress report. sys.stdout.write is used instead of print("...",)
        # because under Python 2 (the kernel this notebook runs on) that form
        # prints a one-element tuple rather than the string.
        if i_episode % 100 == 0:
            clear_output(wait=True)
            sys.stdout.write("\rEpisode {}/{}.\n".format(i_episode, num_episodes))
            sys.stdout.flush()

        # Exploration rate shrinks as training proceeds.
        epsilon = 1. / (1. + i_episode)

        # Observe S_0 and pick A_0 from the epsilon-greedy policy.
        state = env.reset()
        action = np.random.choice(
            actions, p=epsilon_greedy_policy(env, Q[state], epsilon))

        while True:
            # Take A_t, observe R_(t+1) and S_(t+1).
            state_next, reward, done, info = env.step(action)
            # On-policy: A_(t+1) comes from the same epsilon-greedy policy.
            action_next = np.random.choice(
                actions, p=epsilon_greedy_policy(env, Q[state_next], epsilon))

            # TD target. A terminal state has no future return, so do not
            # bootstrap from Q there.
            if done:
                target = reward
            else:
                target = reward + gamma * Q[state_next][action_next]
            Q[state][action] += alpha * (target - Q[state][action])

            # Render the first episode as a demo, for at most 1000 steps.
            if i_episode == 1 and demo_steps_left > 0:
                sys.stdout.write("Demo: {}\n".format(demo_steps_left))
                demo_steps_left -= 1
                clear_output(wait=True)
                env.render()
                sys.stdout.flush()
            # Render every step of the episodes after num_episodes - 10.
            if i_episode > num_episodes - 10:
                clear_output(wait=True)
                env.render()
                sys.stdout.flush()

            if done:
                break
            state, action = state_next, action_next

    return Q

def epsilon_greedy_policy(env, Q_state, epsilon):
    """Action probabilities of the epsilon-greedy policy for one state.

    Every one of the ``env.nA`` actions receives epsilon / nA probability
    mass. The highest-valued action in ``Q_state`` (the first one on ties)
    additionally receives the remaining 1 - epsilon.
    """
    n_actions = env.nA
    probs = epsilon * np.ones(n_actions) / n_actions
    best = np.argmax(Q_state)
    probs[best] = 1 - epsilon + epsilon / n_actions
    return probs

def greedy_policy(env, Q_state):
    """One-hot action distribution on the best action of ``Q_state``.

    Returns a float array of length ``env.nA`` with 1 at the argmax of
    ``Q_state`` (first index on ties) and 0 elsewhere.
    """
    # Row k of the identity matrix is exactly the one-hot vector for action k.
    return np.eye(env.nA)[np.argmax(Q_state)]

def final_policy(env, Q, delay=0.5, max_steps=None):
    """Roll out the greedy policy w.r.t. ``Q`` once, rendering every step.

    Parameters
    ----------
    env : gym-style environment
        Must provide ``nA``, ``get_start()``, ``get_shape()``, ``reset()``,
        ``step(action)`` and ``render()``.
    Q : mapping
        state -> array of action values (e.g. the result of ``sarsa``).
    delay : float, optional
        Seconds to pause after each rendered frame (default 0.5).
    max_steps : int or None, optional
        Stop after this many steps even if the episode is not done. A purely
        greedy policy can cycle forever if its chosen actions never reach a
        terminal state. ``None`` (default) runs until the environment reports
        ``done``.

    Returns
    -------
    list
        ``env.get_start()`` followed by the grid coordinates (from
        ``np.unravel_index`` over ``env.get_shape()``) of each state stepped
        into.
    """
    actions = np.arange(env.nA)
    coords = [env.get_start()]
    state = env.reset()
    # Greedy distribution is one-hot, so this always selects the argmax action.
    action = np.random.choice(actions, p=greedy_policy(env, Q[state]))

    steps = 0
    while True:
        # Take A_t, observe S_(t+1).
        state_next, reward, done, info = env.step(action)
        steps += 1

        # Choose A_(t+1) greedily for S_(t+1).
        action_next = np.random.choice(actions, p=greedy_policy(env, Q[state_next]))

        # Convert the flat state index into (row, col) on the grid.
        coords.append(np.unravel_index(state_next, env.get_shape()))

        clear_output(wait=True)
        env.render()
        time.sleep(delay)
        sys.stdout.flush()

        if done or (max_steps is not None and steps >= max_steps):
            break
        action = action_next

    return coords

Training

In [3]:
# Train SARSA and obtain the estimated optimal action-value function.
NUM_EPISODES = 5000   # number of training episodes
ALPHA = 0.01          # TD step size
Q_sarsa = sarsa(env, NUM_EPISODES, ALPHA)

# Take the grid shape from the environment (the same shape final_policy uses
# with np.unravel_index) so the tables below cover every state.
grid_shape = tuple(env.get_shape())
n_states = int(np.prod(grid_shape))

# Print the estimated optimal (greedy) policy; -1 marks states never visited.
policy_sarsa = np.array([np.argmax(Q_sarsa[key]) if key in Q_sarsa else -1
                         for key in np.arange(n_states)]).reshape(grid_shape)
print("\nEstimated Optimal Policy (UP = 0, RIGHT = 1, DOWN = 2, LEFT = 3, N/A = -1):")
print(policy_sarsa)

# Replay the greedy policy once and collect the visited (row, col) cells.
list_of_coords = final_policy(env, Q_sarsa)
print(list_of_coords)

# Estimated optimal state-value function, laid out on the grid
# (0 for states never visited).
V_sarsa = np.array([np.max(Q_sarsa[key]) if key in Q_sarsa else 0
                    for key in np.arange(n_states)]).reshape(grid_shape)
v1.1
□  □  □  □  □  □  □  □  □  □  □  □
□  □  □  □  □  □  □  □  □  □  □  □
□  □  □  □  □  □  □  □  □  □  □  □
□  □  □  □  □  □  □  □  □  □  □  □
□  □  □  □  □  □  □  □  □  □  □  □
□  ×  ×  ×  ×  ×  ×  ×  ×  ×  ×  ★
□  □  □  □  □  □  □  □  □  □  □  □

[(5, 0), (6, 0), (6, 1), (6, 2), (6, 3), (6, 4), (6, 5), (6, 6), (6, 7), (6, 8), (6, 9), (6, 10), (6, 11), (5, 11)]

Hamster Connection: Exporting the Learned Path

In [4]:
# Hamster connection: write the greedy path to a pickle file for the consumer.
import os
import pickle

# Default is the original location; set HAMSTER_COORDS_FILE to write elsewhere
# without editing the notebook.
outfile = os.environ.get('HAMSTER_COORDS_FILE',
                         '/Users/admin/Documents/gym/gym/projects/data.txt')

# Coordinates from np.unravel_index are numpy integers; store plain ints so
# unpickling the file does not require numpy.
coords_plain = [tuple(int(v) for v in cell) for cell in list_of_coords]

with open(outfile, 'wb') as fp:
    pickle.dump(coords_plain, fp)

# Read the file back to confirm it round-trips.
# NOTE: pickle.load is only safe here because this notebook wrote the file.
with open(outfile, 'rb') as fp:
    test = pickle.load(fp)
    print(str(test))
[(5, 0), (6, 0), (6, 1), (6, 2), (6, 3), (6, 4), (6, 5), (6, 6), (6, 7), (6, 8), (6, 9), (6, 10), (6, 11), (5, 11)]
In [ ]: