GitHub Open-Source Project Code Analysis: A Brief Look at CartPole in Reinforcement Learning

Today I'd like to walk through a code example from OpenAI Gym, a reinforcement learning environment hosted on GitHub. I hope it helps you see how reinforcement learning is actually implemented in code.

What Is CartPole

CartPole is also known as the inverted pendulum. A pole is mounted on a cart and, left alone, gravity makes it fall over. Our job is to keep the pole upright by moving the cart.

The Scripts

We will look at three files from the gym repository: cem.py (the agent), cartpole.py (the environment), and _policies.py (the policy).

The SAPR of CartPole

(If you are not sure what SAPR stands for, see the earlier reinforcement learning primer, "強化學習(Reinforcement Learning)知識整理".)

S(State)

Looking at the value of ob in cem.py, we find that the State is an array of 4 numbers. These four numbers are the cart position, the cart velocity, the pole angle, and the pole angular velocity (you can see this in the state tuple inside cartpole.py's step function further below).
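If you want to check this yourself, here is a minimal sketch using the same classic gym API that cem.py uses (the exact numbers differ on every run):

import gym

env = gym.make("CartPole-v0")
ob = env.reset()
# e.g. [-0.012  0.031 -0.024  0.047]: cart position, cart velocity, pole angle, pole angular velocity
print(ob)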

A(Action)

In the step function of cartpole.py, we see that CartPole has only two Actions: (0, 1).

if Action == 0: the cart is pushed to the left
elif Action == 1: the cart is pushed to the right
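The corresponding line inside cartpole.py's step function (quoted from the full source further below) is:

force = self.force_mag if action == 1 else -self.force_mag  # force_mag is 10.0: action 1 pushes right, action 0 pushes left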

P(Policy)

In the act function of BinaryActionLinearPolicy in _policies.py, the Policy is simply a function that maps the State to an Action.

y = w * x + b
if y >= 0: Action = 0
else: Action = 1

If you have studied linear algebra or machine learning, this should look familiar: it is just Y = f(weight * X + bias). The X here is the 4-number State described above.
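As a toy illustration (the weights, bias, and observation below are made-up numbers, not values from an actual training run):

import numpy as np

w = np.array([0.1, -0.2, 0.5, 0.3])         # hypothetical weights
b = -0.05                                    # hypothetical bias
ob = np.array([0.03, -0.01, 0.02, 0.04])     # one hypothetical State

y = ob.dot(w) + b                            # here y is about -0.023
action = 0 if y >= 0 else 1                  # same rule as BinaryActionLinearPolicy.act, so action == 1
print(y, action)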

R(Reward)

Still in the step function, we see:

if the pole has fallen: reward += 0
else: reward += 1
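The actual reward logic in step (quoted from the full cartpole.py source further below) is only slightly more involved; it still amounts to one point for every step the pole stays up:

if not done:
    reward = 1.0
elif self.steps_beyond_done is None:
    # the pole just fell on this step
    self.steps_beyond_done = 0
    reward = 1.0
else:
    self.steps_beyond_done += 1
    reward = 0.0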

Of the four SAPR elements, S, A, and R are all already packaged by gym, so we don't need to worry about them. And, just like writing a game, you can define your own learning environment without any machine learning knowledge at all. It could be Pong, it could be Jump Jump (跳一跳), or anything else.
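As a rough sketch of what that looks like (the class name, dynamics, and reward below are invented purely for illustration), a custom environment only needs to implement the same gym.Env interface that cartpole.py implements further below:

import gym
import numpy as np
from gym import spaces

class MyTinyEnv(gym.Env):
    """A made-up one-dimensional environment: keep a point near the origin."""

    def __init__(self):
        self.action_space = spaces.Discrete(2)                        # 0 = step left, 1 = step right
        self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(1,))
        self.state = None

    def reset(self):
        self.state = np.zeros(1)
        return self.state

    def step(self, action):
        self.state = self.state + (0.1 if action == 1 else -0.1)     # toy dynamics
        done = bool(abs(self.state[0]) > 1.0)                         # fail when we drift too far
        reward = 0.0 if done else 1.0                                 # 1 point per surviving step, like CartPole
        return self.state, reward, done, {}

Since this toy environment also has two discrete actions and a Box observation, the same cem.py training loop should work on it essentially unchanged.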

How the Policy Is Decided

Now we come to the main point of this article. We know that reinforcement learning is built around the Bellman equation, which identifies the action with the highest return at each step. So couldn't we just solve the Bellman equation and be done? Are we really about to solve the Bellman equation in code!? How exciting!

But in fact, no!

cem.py uses the cross-entropy method (CEM), which works much like a genetic algorithm:

1. Set the initial values th_mean = [0, 0, 0, 0, 0] and th_std = 1.
2. Use Equation 1 to generate 25 arrays of length 5; the first 4 numbers are w and the last number is b (a standalone NumPy sketch of this sampling step follows this list):
   ths = [rnd, rnd, rnd, rnd, rnd] * th_std + [th_mean0, th_mean1, th_mean2, th_mean3, th_mean4]   (Equation 1)
3. Plug each of the 25 ths into the environment simulator and compute the reward obtained under that policy (see the SAPR section above).
4. Pick the 5 samples with the highest reward, then set
   th_mean = the mean of those 5 samples
   th_std = the standard deviation of those 5 samples
5. Repeat steps 2 through 4 ten times (i.e. episode = 10).
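Here is step 2 (Equation 1) written out as a small standalone NumPy sketch:

import numpy as np

th_mean = np.zeros(5)    # [w1, w2, w3, w4, b], all starting at 0
th_std = np.ones(5)
batch_size = 25

# each row is one candidate parameter vector: Gaussian noise scaled by th_std, shifted by th_mean
ths = th_mean + th_std * np.random.randn(batch_size, th_mean.size)
print(ths.shape)  # (25, 5)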

With this procedure we can usually reach reward = 200 by around episode = 8.

By keeping only the 5 best-performing "genes" in each round, we move a little closer to the optimum every time, while the fresh randomness added on top of those samples preserves room for exploration.

Code Implementation

I have reordered and simplified the source code to make it easier to follow.

import gym
import numpy as np
from _policies import BinaryActionLinearPolicy

env = gym.make("CartPole-v0")
n_iter = 10
batch_size = 25
elite_frac = 0.2
num_steps = 200
th_mean = np.zeros(env.observation_space.shape[0] + 1)
initial_std = 1.0
n_elite = int(np.round(batch_size * elite_frac))
th_std = np.ones_like(th_mean) * initial_std

def do_rollout(agent, env, num_steps):
    # run one episode with the given policy and return its total reward
    total_rew = 0
    ob = env.reset()
    for t in range(num_steps):
        a = agent.act(ob)
        ob, reward, done, _info = env.step(a)
        total_rew += reward
        if done:
            break
    return total_rew, t + 1

def noisy_evaluation(theta):
    agent = BinaryActionLinearPolicy(theta)      # build a policy from this parameter vector
    rew, T = do_rollout(agent, env, num_steps)   # measure its reward
    return rew

for x in range(n_iter):
    # sample a batch of candidate parameter vectors around the current mean
    ths = np.array([th_mean + dth for dth in th_std[None, :] * np.random.randn(batch_size, th_mean.size)])
    ys = np.array([noisy_evaluation(th) for th in ths])   # reward of each candidate
    elite_inds = ys.argsort()[::-1][:n_elite]              # indices of the best-scoring candidates
    elite_ths = ths[elite_inds]
    th_mean = elite_ths.mean(axis=0)                       # refit the sampling distribution
    th_std = elite_ths.std(axis=0)

Source Code

I have pasted the scripts mentioned above here; if you are interested, you can also look them up directly on GitHub.

# gym/examples/agents/cem.py
from __future__ import print_function

import gym
from gym import wrappers, logger
import numpy as np
from six.moves import cPickle as pickle
import json, sys, os
from os import path
from _policies import BinaryActionLinearPolicy # Different file so it can be unpickled
import argparse

def cem(f, th_mean, batch_size, n_iter, elite_frac, initial_std=1.0):
    """
    Generic implementation of the cross-entropy method for maximizing a black-box function

    f: a function mapping from vector -> scalar
    th_mean: initial mean over input distribution (here it starts as [0., 0., 0., 0., 0.])
    batch_size: number of samples of theta to evaluate per batch
    n_iter: number of batches
    elite_frac: each batch, select this fraction of the top-performing samples
    initial_std: initial standard deviation over parameter vectors
    """
    n_elite = int(np.round(batch_size*elite_frac))
    th_std = np.ones_like(th_mean) * initial_std

    for _ in range(n_iter):
        ths = np.array([th_mean + dth for dth in th_std[None,:]*np.random.randn(batch_size, th_mean.size)])
        ys = np.array([f(th) for th in ths])
        elite_inds = ys.argsort()[::-1][:n_elite]
        elite_ths = ths[elite_inds]
        th_mean = elite_ths.mean(axis=0)
        th_std = elite_ths.std(axis=0)
        yield {'ys': ys, 'theta_mean': th_mean, 'y_mean': ys.mean()}

def do_rollout(agent, env, num_steps, render=False):
    total_rew = 0
    ob = env.reset()
    for t in range(num_steps):
        a = agent.act(ob)
        (ob, reward, done, _info) = env.step(a)
        total_rew += reward
        if render and t%3==0: env.render()
        if done: break
    return total_rew, t+1

if __name__ == '__main__':
    logger.set_level(logger.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument('--display', action='store_true')
    parser.add_argument('target', nargs="?", default="CartPole-v0")
    args = parser.parse_args()

    env = gym.make(args.target)
    env.seed(0)
    np.random.seed(0)
    params = dict(n_iter=10, batch_size=25, elite_frac=0.2)
    num_steps = 200

    # You provide the directory to write to (can be an existing
    # directory, but can't contain previous monitor results. You can
    # also dump to a tempdir if you'd like: tempfile.mkdtemp().
    outdir = '/tmp/cem-agent-results'
    env = wrappers.Monitor(env, outdir, force=True)

    # Prepare snapshotting
    # ----------------------------------------
    def writefile(fname, s):
        with open(path.join(outdir, fname), 'w') as fh: fh.write(s)
    info = {}
    info['params'] = params
    info['argv'] = sys.argv
    info['env_id'] = env.spec.id
    # ------------------------------------------

    def noisy_evaluation(theta):
        agent = BinaryActionLinearPolicy(theta)
        rew, T = do_rollout(agent, env, num_steps)
        return rew

    # Train the agent, and snapshot each stage
    for (i, iterdata) in enumerate(
        cem(noisy_evaluation, np.zeros(env.observation_space.shape[0]+1), **params)):  # note that parameters can be handed over with **params like this
        print('Iteration %2i. Episode mean reward: %7.3f'%(i, iterdata['y_mean']))
        agent = BinaryActionLinearPolicy(iterdata['theta_mean'])
        if args.display: do_rollout(agent, env, 200, render=True)
        writefile('agent-%.4i.pkl'%i, str(pickle.dumps(agent, -1)))

    # Write out the env at the end so we store the parameters of this
    # environment.
    writefile('info.json', json.dumps(info))

    env.close()

# gym/gym/envs/classic_control/cartpole.py
"""
Classic cart-pole system implemented by Rich Sutton et al.
Copied from http://incompleteideas.net/sutton/book/code/pole.c
permalink: incompleteideas.net - pole.c
"""

import math
import gym
from gym import spaces, logger
from gym.utils import seeding
import numpy as np

class CartPoleEnv(gym.Env):
    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 50
    }

    def __init__(self):
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = (self.masspole + self.masscart)
        self.length = 0.5 # actually half the pole's length
        self.polemass_length = (self.masspole * self.length)
        self.force_mag = 10.0
        self.tau = 0.02  # seconds between state updates

        # Angle at which to fail the episode
        self.theta_threshold_radians = 12 * 2 * math.pi / 360
        self.x_threshold = 2.4

        # Angle limit set to 2 * theta_threshold_radians so failing observation is still within bounds
        high = np.array([
            self.x_threshold * 2,
            np.finfo(np.float32).max,
            self.theta_threshold_radians * 2,
            np.finfo(np.float32).max])

        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Box(-high, high)

        self.seed()
        self.viewer = None
        self.state = None

        self.steps_beyond_done = None

    def seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def step(self, action):
        assert self.action_space.contains(action), "%r (%s) invalid"%(action, type(action))
        state = self.state
        x, x_dot, theta, theta_dot = state
        force = self.force_mag if action==1 else -self.force_mag
        costheta = math.cos(theta)
        sintheta = math.sin(theta)
        temp = (force + self.polemass_length * theta_dot * theta_dot * sintheta) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (self.length * (4.0/3.0 - self.masspole * costheta * costheta / self.total_mass))
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass
        x = x + self.tau * x_dot
        x_dot = x_dot + self.tau * xacc
        theta = theta + self.tau * theta_dot
        theta_dot = theta_dot + self.tau * thetaacc
        self.state = (x, x_dot, theta, theta_dot)
        done = x < -self.x_threshold \
               or x > self.x_threshold \
               or theta < -self.theta_threshold_radians \
               or theta > self.theta_threshold_radians
        done = bool(done)

        if not done:
            reward = 1.0
        elif self.steps_beyond_done is None:
            # Pole just fell!
            self.steps_beyond_done = 0
            reward = 1.0
        else:
            if self.steps_beyond_done == 0:
                logger.warn("You are calling 'step()' even though this environment has already returned done = True. You should always call 'reset()' once you receive 'done = True' -- any further steps are undefined behavior.")
            self.steps_beyond_done += 1
            reward = 0.0

        return np.array(self.state), reward, done, {}

    def reset(self):
        self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
        self.steps_beyond_done = None
        return np.array(self.state)

    def render(self, mode='human'):
        screen_width = 600
        screen_height = 400

        world_width = self.x_threshold * 2
        scale = screen_width / world_width
        carty = 100 # TOP OF CART
        polewidth = 10.0
        polelen = scale * 1.0
        cartwidth = 50.0
        cartheight = 30.0

        if self.viewer is None:
            from gym.envs.classic_control import rendering
            self.viewer = rendering.Viewer(screen_width, screen_height)
            l, r, t, b = -cartwidth/2, cartwidth/2, cartheight/2, -cartheight/2
            axleoffset = cartheight/4.0
            cart = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)])
            self.carttrans = rendering.Transform()
            cart.add_attr(self.carttrans)
            self.viewer.add_geom(cart)
            l, r, t, b = -polewidth/2, polewidth/2, polelen-polewidth/2, -polewidth/2
            pole = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)])
            pole.set_color(.8, .6, .4)
            self.poletrans = rendering.Transform(translation=(0, axleoffset))
            pole.add_attr(self.poletrans)
            pole.add_attr(self.carttrans)
            self.viewer.add_geom(pole)
            self.axle = rendering.make_circle(polewidth/2)
            self.axle.add_attr(self.poletrans)
            self.axle.add_attr(self.carttrans)
            self.axle.set_color(.5, .5, .8)
            self.viewer.add_geom(self.axle)
            self.track = rendering.Line((0, carty), (screen_width, carty))
            self.track.set_color(0, 0, 0)
            self.viewer.add_geom(self.track)

        if self.state is None: return None

        x = self.state
        cartx = x[0]*scale + screen_width/2.0 # MIDDLE OF CART
        self.carttrans.set_translation(cartx, carty)
        self.poletrans.set_rotation(-x[2])

        return self.viewer.render(return_rgb_array = mode=='rgb_array')

    def close(self):
        if self.viewer: self.viewer.close()

# gym/examples/agents/_policies.py
# Support code for cem.py

class BinaryActionLinearPolicy(object):
    def __init__(self, theta):
        self.w = theta[:-1]
        self.b = theta[-1]
    def act(self, ob):
        y = ob.dot(self.w) + self.b
        a = int(y < 0)
        return a

class ContinuousActionLinearPolicy(object):
    def __init__(self, theta, n_in, n_out):
        assert len(theta) == (n_in + 1) * n_out
        self.W = theta[0 : n_in * n_out].reshape(n_in, n_out)
        self.b = theta[n_in * n_out : None].reshape(1, n_out)
    def act(self, ob):
        a = ob.dot(self.W) + self.b
        return a
