<h1 id="使用policy-gradient玩乒乓球游戏">使用Policy Gradient玩乒乓球游戏</h1>
|
||
|
<h2 id="安装-gym">安装 gym</h2>
|
||
|
<p>想要玩乒乓球游戏,首先得有乒乓球游戏。OpenAI 的 gym 为我们提供了模拟游戏的环境。使得我们能够很方便地得到游戏的环境状态,并作出动作。想要安装 gym 非常简单,只要在命令行中输入<code>pip install gym</code>即可。</p>
|
||
|
<h2 id="安装-ataripy">安装 atari_py</h2>
|
||
|
<p>由于乒乓球游戏是雅达利游戏机上的游戏,所以需要安装 atari_py 来实现雅达利环境的模拟。安装 atari_py 也很方便,只需在命令行中输入<code>pip install --no-index -f https://github.com/Kojoley/atari-py/releases atari_py</code> 即可。</p>
|
||
|
<h2 id="开启游戏">开启游戏</h2>
|
||
|
<p>当安装好所需要的库之后,我们可以使用如下代码开始游戏:</p>
|
||
|
<pre><code class="lang-python"><span class="hljs-comment"># 开启乒乓球游戏环境</span>
|
||
|
<span class="hljs-keyword">import</span> gym
|
||
|
|
||
|
env = gym.make(<span class="hljs-string">'Pong-v0'</span>)
|
||
|
|
||
|
<span class="hljs-comment"># 一直渲染游戏画面</span>
|
||
|
<span class="hljs-keyword">while</span> <span class="hljs-keyword">True</span>:
|
||
|
env.render()
|
||
|
<span class="hljs-comment"># 随机做动作,并得到做完动作之后的环境(observation),反馈(reward),是否结束(done)</span>
|
||
|
observation, reward, done, _ = env.step(env.action_space.sample())
|
||
|
</code></pre>
## Preprocessing the game frames

The observation returned by `env.step` is a three-channel RGB image, but how the paddle should move depends only on the paddles and the ball. We can therefore reduce each frame to a binary image in which the paddles and the ball are 1 and the background is 0.
```python
import numpy as np

# Preprocess a game frame
def prepro(I):
    I = I[35:195]       # crop out the scoreboard at the top
    I = I[::2, ::2, 0]  # downsample by 2, keep one channel: I is now an 80x80 image
    I[I == 144] = 0     # erase background (type 1)
    I[I == 109] = 0     # erase background (type 2)
    I[I != 0] = 1       # paddles and ball become 1
    return I.astype(float).ravel()  # flatten the 2D image into a 1D array

# cur_x is the preprocessed game frame
cur_x = prepro(observation)
```
The game is rendered frame by frame. If we subtract the previous frame from the current one, we get a difference image that captures the motion between frames, which makes a good input for the neural network.
```python
# On the very first step there is no previous frame yet
prev_x = np.zeros_like(cur_x)

# x is the frame-difference image
x = cur_x - prev_x
# The current frame becomes the previous frame for the next step
prev_x = cur_x
```
## Building the neural network

You can design the network however you like. Here I use the simplest possible model: two fully connected layers. Since the paddle has only two actions, up and down, the final activation is a sigmoid.
```python
# Network dimensions: D input pixels (the 80x80 frame difference), H hidden neurons
D = 80 * 80
H = 200

# The network's parameters
model = {}
# Randomly initialize the first layer: 200 hidden neurons
model['W1'] = np.random.randn(H, D) / np.sqrt(D)
# Randomly initialize the second layer: a single output unit with H weights
model['W2'] = np.random.randn(H) / np.sqrt(H)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Forward pass; x is the frame-difference image
def policy_forward(x):
    h = np.dot(model['W1'], x)
    # ReLU
    h[h < 0] = 0
    logp = np.dot(model['W2'], h)
    # Sigmoid activation
    p = sigmoid(logp)
    # p is the probability of moving up (action 2); h is the hidden-layer state
    return p, h

# Backward pass: compute the gradients of each layer's parameters.
# eph stacks one episode's hidden states; epdlogp stacks the reward-weighted
# gradients of the log-probabilities; epx (the stacked inputs) is read from
# the enclosing scope.
def policy_backward(eph, epdlogp):
    dW2 = np.dot(eph.T, epdlogp).ravel()
    dh = np.outer(epdlogp, model['W2'])
    dh[eph <= 0] = 0  # backprop through the ReLU
    dW1 = np.dot(dh.T, epx)
    return {'W1': dW1, 'W2': dW2}
```
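To make the shapes concrete, a quick check (an illustrative sketch, not part of the original code) runs one forward pass on a random frame difference:

```python
# One forward pass on a random 80x80 frame difference
x_demo = np.random.randn(D)
p, h = policy_forward(x_demo)
print(h.shape)  # (200,): the hidden activations
print(p)        # a scalar probability in (0, 1)
```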
## Training the network
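The training loop below relies on a few pieces the page does not show: the hyperparameters, the per-episode buffers, and the `discount_rewards` helper that turns per-step rewards into discounted returns. Here is a minimal sketch of those missing pieces; the hyperparameter values are assumptions, not values given in the original.

```python
import pickle

# Hyperparameters (illustrative values, assumed rather than given in the text)
batch_size = 10       # update the parameters every 10 episodes
learning_rate = 1e-3
gamma = 0.99          # reward discount factor
decay_rate = 0.99     # RMSProp decay factor

# Compute the discounted return for every time step of an episode
def discount_rewards(r):
    discounted_r = np.zeros_like(r)
    running_add = 0
    for t in reversed(range(0, r.size)):
        if r[t] != 0:
            running_add = 0  # a nonzero reward marks a point boundary in Pong: reset the sum
        running_add = running_add * gamma + r[t]
        discounted_r[t] = running_add
    return discounted_r

# Per-episode buffers for inputs, hidden states, log-probability gradients, rewards
xs, hs, dlogps, drs = [], [], [], []
# Gradient accumulator and RMSProp cache, one entry per parameter matrix
grad_buffer = {k: np.zeros_like(v) for k, v in model.items()}
rmsprop_cache = {k: np.zeros_like(v) for k, v in model.items()}

episode_number = 0
reward_sum = 0
observation = env.reset()
prev_x = None
```

With these in place, the training loop is: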
```python
while True:
    env.render()

    # Preprocess the game frame
    cur_x = prepro(observation)
    # Compute the frame difference (a zero image on the very first step)
    x = cur_x - prev_x if prev_x is not None else np.zeros(D)
    # The current frame becomes the previous frame
    prev_x = cur_x

    # Forward pass
    aprob, h = policy_forward(x)
    # Sample from the action distribution: action 2 moves the paddle up, action 3 moves it down
    action = 2 if np.random.uniform() < aprob else 3

    # Record the input
    xs.append(x)
    # Record the hidden state
    hs.append(h)
    # Encode the action as 1 (up) or 0 (down) to match the sigmoid output,
    # whose derivative is f(x) * (1 - f(x))
    y = 1 if action == 2 else 0
    dlogps.append(y - aprob)

    # Send the sampled action back to the environment
    observation, reward, done, info = env.step(action)
    # reward is +1 when we score a point and -1 when we lose one
    reward_sum += reward

    # Record the reward
    drs.append(reward)

    # The game ends when one side reaches 21 points
    if done:
        episode_number += 1

        epx = np.vstack(xs)
        eph = np.vstack(hs)
        epdlogp = np.vstack(dlogps)
        epr = np.vstack(drs)
        # Reset the per-episode buffers
        xs, hs, dlogps, drs = [], [], [], []

        discounted_epr = discount_rewards(epr)
        # z-score the discounted rewards; this helps training
        discounted_epr -= np.mean(discounted_epr)
        discounted_epr /= np.std(discounted_epr)

        # Weight the log-probability gradients by the discounted rewards
        epdlogp *= discounted_epr
        # Compute the parameter gradients
        grad = policy_backward(eph, epdlogp)
        for k in model:
            grad_buffer[k] += grad[k]

        # Update the parameters every batch_size episodes
        if episode_number % batch_size == 0:
            # RMSProp gradient ascent
            for k, v in model.items():
                g = grad_buffer[k]
                rmsprop_cache[k] = decay_rate * rmsprop_cache[k] + (1 - decay_rate) * g ** 2
                model[k] += learning_rate * g / (np.sqrt(rmsprop_cache[k]) + 1e-5)
                grad_buffer[k] = np.zeros_like(v)

        # Save the model every 100 episodes
        if episode_number % 100 == 0:
            pickle.dump(model, open('save.p', 'wb'))
        reward_sum = 0
        # Reset the game
        observation = env.reset()
        prev_x = None
```
## Loading the model and playing

After the long training process, we can load the trained model and start playing.
```python
import numpy as np
import pickle
import gym

# prepro and policy_forward are the functions defined above

model = pickle.load(open('save.p', 'rb'))

env = gym.make("Pong-v0")
observation = env.reset()
prev_x = None

while True:
    env.render()
    cur_x = prepro(observation)
    x = cur_x - prev_x if prev_x is not None else np.zeros(80 * 80)
    prev_x = cur_x
    aprob, h = policy_forward(x)
    # Sample from the action distribution
    action = 2 if np.random.uniform() < aprob else 3
    observation, reward, done, info = env.step(action)

    if done:
        observation = env.reset()
        prev_x = None
```