mirror of
https://github.com/ZhangXinNan/DL-with-Python-and-PyTorch2.git
synced 2025-10-20 23:34:18 +08:00
324 lines
16 KiB
Plaintext
324 lines
16 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 18.4 SARSA 算法\n",
|
||
"SARSA(State-Action-Reward-State-Action)算法与Q-Learning算法非常相似,不同的就是更新Q值时, \n",
|
||
"SARSA的现实值取$r+\\gamma Q(s_{t+1},a_{t+1})$,而不是$r+\\gamma\\underset{a}{\\max} Q(s_{t+1},a)$。"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### 18.4.1 SARSA算法主要步骤\n",
|
||
"SARSA算法更新Q函数的步骤为: \n",
|
||
"1)获取初始状态s,并根据当前的Q表选定初始行动a。 \n",
|
||
"2)执行上一步选择的行动a,获得奖励r和新状态next_s。 \n",
|
||
"3)在新状态next_s,根据当前的Q表,选定要执行的下一行动next_a。 \n",
|
||
"4)用r、next_a、next_s,根据SARSA逻辑更新Q表。 \n",
|
||
"5)把next_s赋给s,把next_a赋给a。 \n",
|
||
"### 18.4.2 用PyTorch实现SARSA算法\n",
|
||
"SARSA算法与Q-Learning算法基本相同,只有少部分不同,为此,只要修改实现Q-Learning算法中的部分代码即可。"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import time\n",
|
||
"import numpy as np\n",
|
||
"import tkinter as tk\n",
|
||
"from PIL import ImageTk, Image\n",
|
||
"\n",
|
||
"np.random.seed(1)\n",
|
||
"PhotoImage = ImageTk.PhotoImage\n",
|
||
"UNIT = 100\n",
|
||
"HEIGHT = 5\n",
|
||
"WIDTH = 5\n",
|
||
"\n",
|
||
"\n",
|
||
"class Env(tk.Tk):\n",
|
||
" def __init__(self):\n",
|
||
" super(Env, self).__init__()\n",
|
||
" self.action_space = ['u', 'd', 'l', 'r']\n",
|
||
" self.n_actions = len(self.action_space)\n",
|
||
" self.title('Q Learning')\n",
|
||
" self.geometry('{0}x{1}'.format(HEIGHT * UNIT, HEIGHT * UNIT))\n",
|
||
" self.shapes = self.load_images()\n",
|
||
" self.canvas = self._build_canvas()\n",
|
||
" self.texts = []\n",
|
||
"\n",
|
||
" def _build_canvas(self):\n",
|
||
" canvas = tk.Canvas(self, bg='white',\n",
|
||
" height=HEIGHT * UNIT,\n",
|
||
" width=WIDTH * UNIT)\n",
|
||
" # create grids\n",
|
||
" for c in range(0, WIDTH * UNIT, UNIT): # 0~400 by 100\n",
|
||
" x0, y0, x1, y1 = c, 0, c, HEIGHT * UNIT\n",
|
||
" canvas.create_line(x0, y0, x1, y1)\n",
|
||
" for r in range(0, HEIGHT * UNIT, UNIT): # 0~400 by 100\n",
|
||
" x0, y0, x1, y1 = 0, r, HEIGHT * UNIT, r\n",
|
||
" canvas.create_line(x0, y0, x1, y1)\n",
|
||
"\n",
|
||
" # 把图标加载到环境中\n",
|
||
" self.rectangle = canvas.create_image(50, 50, image=self.shapes[0])\n",
|
||
" self.tree1 = canvas.create_image(250, 150, image=self.shapes[1])\n",
|
||
" self.tree2 = canvas.create_image(150, 250, image=self.shapes[1])\n",
|
||
" self.star = canvas.create_image(250, 250, image=self.shapes[2])\n",
|
||
"\n",
|
||
" # 对环境进行包装\n",
|
||
" canvas.pack()\n",
|
||
"\n",
|
||
" return canvas\n",
|
||
"\n",
|
||
" def load_images(self):\n",
|
||
" rectangle = PhotoImage(\n",
|
||
" Image.open(\"img/bob.png\").resize((65, 65)))\n",
|
||
" tree = PhotoImage(\n",
|
||
" Image.open(\"img/tree.png\").resize((65, 65)))\n",
|
||
" star = PhotoImage(\n",
|
||
" Image.open(\"img/star.jpg\").resize((65, 65)))\n",
|
||
"\n",
|
||
" return rectangle, tree, star\n",
|
||
"\n",
|
||
" def text_value(self, row, col, contents, action, font='Helvetica', size=10,\n",
|
||
" style='normal', anchor=\"nw\"):\n",
|
||
" if action == 0:\n",
|
||
" origin_x, origin_y = 7, 42\n",
|
||
" elif action == 1:\n",
|
||
" origin_x, origin_y = 85, 42\n",
|
||
" elif action == 2:\n",
|
||
" origin_x, origin_y = 42, 5\n",
|
||
" else:\n",
|
||
" origin_x, origin_y = 42, 77\n",
|
||
"\n",
|
||
" x, y = origin_y + (UNIT * col), origin_x + (UNIT * row)\n",
|
||
" font = (font, str(size), style)\n",
|
||
" text = self.canvas.create_text(x, y, fill=\"black\", text=contents,\n",
|
||
" font=font, anchor=anchor)\n",
|
||
" return self.texts.append(text)\n",
|
||
"\n",
|
||
" def print_value_all(self, q_table):\n",
|
||
" for i in self.texts:\n",
|
||
" self.canvas.delete(i)\n",
|
||
" self.texts.clear()\n",
|
||
" for i in range(HEIGHT):\n",
|
||
" for j in range(WIDTH):\n",
|
||
" for action in range(0, 4):\n",
|
||
" state = [i, j]\n",
|
||
" if str(state) in q_table.keys():\n",
|
||
" temp = q_table[str(state)][action]\n",
|
||
" self.text_value(j, i, round(temp, 2), action)\n",
|
||
"\n",
|
||
" def coords_to_state(self, coords):\n",
|
||
" x = int((coords[0] - 50) / 100)\n",
|
||
" y = int((coords[1] - 50) / 100)\n",
|
||
" return [x, y]\n",
|
||
"\n",
|
||
" def state_to_coords(self, state):\n",
|
||
" x = int(state[0] * 100 + 50)\n",
|
||
" y = int(state[1] * 100 + 50)\n",
|
||
" return [x, y]\n",
|
||
"\n",
|
||
" def reset(self):\n",
|
||
" self.update()\n",
|
||
" time.sleep(0.5)\n",
|
||
" x, y = self.canvas.coords(self.rectangle)\n",
|
||
" self.canvas.move(self.rectangle, UNIT / 2 - x, UNIT / 2 - y)\n",
|
||
" self.render()\n",
|
||
" # return observation\n",
|
||
" return self.coords_to_state(self.canvas.coords(self.rectangle))\n",
|
||
"\n",
|
||
" def step(self, action):\n",
|
||
" state = self.canvas.coords(self.rectangle)\n",
|
||
" base_action = np.array([0, 0])\n",
|
||
" self.render()\n",
|
||
"\n",
|
||
" if action == 0: # up\n",
|
||
" if state[1] > UNIT:\n",
|
||
" base_action[1] -= UNIT\n",
|
||
" elif action == 1: # down\n",
|
||
" if state[1] < (HEIGHT - 1) * UNIT:\n",
|
||
" base_action[1] += UNIT\n",
|
||
" elif action == 2: # left\n",
|
||
" if state[0] > UNIT:\n",
|
||
" base_action[0] -= UNIT\n",
|
||
" elif action == 3: # right\n",
|
||
" if state[0] < (WIDTH - 1) * UNIT:\n",
|
||
" base_action[0] += UNIT\n",
|
||
"\n",
|
||
" # 移动\n",
|
||
" self.canvas.move(self.rectangle, base_action[0], base_action[1])\n",
|
||
" self.canvas.tag_raise(self.rectangle)\n",
|
||
" next_state = self.canvas.coords(self.rectangle)\n",
|
||
" # 判断得分条件\n",
|
||
" if next_state == self.canvas.coords(self.star):\n",
|
||
" reward = 100\n",
|
||
" done = True\n",
|
||
" elif next_state in [self.canvas.coords(self.tree1),\n",
|
||
" self.canvas.coords(self.tree2)]:\n",
|
||
" reward = -100\n",
|
||
" done = True\n",
|
||
" else:\n",
|
||
" reward = 0\n",
|
||
" done = False\n",
|
||
"\n",
|
||
" next_state = self.coords_to_state(next_state)\n",
|
||
" return next_state, reward, done\n",
|
||
"\n",
|
||
" # 渲染环境\n",
|
||
" def render(self):\n",
|
||
" time.sleep(0.03)\n",
|
||
" self.update()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"1. 修改学习函数"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"import random\n",
|
||
"from collections import defaultdict\n",
|
||
"\n",
|
||
"\n",
|
||
"class QLearningAgent:\n",
|
||
" def __init__(self, actions):\n",
|
||
" # 四种动作分别用序列表示:[0, 1, 2, 3]\n",
|
||
" self.actions = actions\n",
|
||
" self.learning_rate = 0.01\n",
|
||
" self.discount_factor = 0.9\n",
|
||
" #epsilon贪婪策略取值\n",
|
||
" self.epsilon = 0.1\n",
|
||
" self.q_table = defaultdict(lambda: [0.0, 0.0, 0.0, 0.0])\n",
|
||
"\n",
|
||
" # 采样 <s, a, r,a',s'>\n",
|
||
" def learn(self, state, action, reward,next_action,next_state):\n",
|
||
" current_q = self.q_table[state][action]\n",
|
||
" # 更新Q表\n",
|
||
" new_q = reward + self.discount_factor * (self.q_table[next_state][next_action])\n",
|
||
" self.q_table[state][action] += self.learning_rate * (new_q - current_q)\n",
|
||
"\n",
|
||
" # 从Q-table中选取动作\n",
|
||
" def get_action(self, state):\n",
|
||
" if np.random.rand() < self.epsilon:\n",
|
||
" # 贪婪策略随机探索动作\n",
|
||
" action = np.random.choice(self.actions)\n",
|
||
" else:\n",
|
||
" # 从q表中选择\n",
|
||
" state_action = self.q_table[state]\n",
|
||
" action = self.arg_max(state_action)\n",
|
||
" return action\n",
|
||
"\n",
|
||
" @staticmethod\n",
|
||
" def arg_max(state_action):\n",
|
||
" max_index_list = []\n",
|
||
" max_value = state_action[0]\n",
|
||
" for index, value in enumerate(state_action):\n",
|
||
" if value > max_value:\n",
|
||
" max_index_list.clear()\n",
|
||
" max_value = value\n",
|
||
" max_index_list.append(index)\n",
|
||
" elif value == max_value:\n",
|
||
" max_index_list.append(index)\n",
|
||
" return random.choice(max_index_list)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"2.修改训练代码 \n",
|
||
"这里主要修改的内容如下: \n",
|
||
"1)新增获取下一步动作的语句next_action = agent.get_action(str(next_state)); \n",
|
||
"2)把next_action赋给action。 \n",
|
||
"具体实现代码如下:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"ename": "TclError",
|
||
"evalue": "invalid command name \".!canvas\"",
|
||
"output_type": "error",
|
||
"traceback": [
|
||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
||
"\u001b[1;31mTclError\u001b[0m Traceback (most recent call last)",
|
||
"\u001b[1;32m<ipython-input-3-5849bab91513>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 8\u001b[0m \u001b[0menv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 9\u001b[0m \u001b[1;31m#获取新的状态、奖励分数\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 10\u001b[1;33m \u001b[0mnext_state\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mreward\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdone\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0menv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0maction\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 11\u001b[0m \u001b[1;31m#产生新的动作\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 12\u001b[0m \u001b[0mnext_action\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0magent\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_action\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||
"\u001b[1;32m<ipython-input-1-aa377bdeb320>\u001b[0m in \u001b[0;36mstep\u001b[1;34m(self, action)\u001b[0m\n\u001b[0;32m 122\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 123\u001b[0m \u001b[1;31m# 移动\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 124\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcanvas\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmove\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrectangle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbase_action\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbase_action\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 125\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcanvas\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtag_raise\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrectangle\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 126\u001b[0m \u001b[0mnext_state\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcanvas\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcoords\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrectangle\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||
"\u001b[1;32m~\\Anaconda3\\lib\\tkinter\\__init__.py\u001b[0m in \u001b[0;36mmove\u001b[1;34m(self, *args)\u001b[0m\n\u001b[0;32m 2589\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mmove\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2590\u001b[0m \u001b[1;34m\"\"\"Move an item TAGORID given in ARGS.\"\"\"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 2591\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtk\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcall\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_w\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'move'\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 2592\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mpostscript\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcnf\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m{\u001b[0m\u001b[1;33m}\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkw\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2593\u001b[0m \"\"\"Print the contents of the canvas to a postscript\n",
|
||
"\u001b[1;31mTclError\u001b[0m: invalid command name \".!canvas\""
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"env = Env()\n",
|
||
"agent = QLearningAgent(actions=list(range(env.n_actions)))\n",
|
||
"#共进行200次游戏\n",
|
||
"for episode in range(200):\n",
|
||
" state = env.reset()\n",
|
||
" action = agent.get_action(str(state))\n",
|
||
" while True:\n",
|
||
" env.render()\n",
|
||
" #获取新的状态、奖励分数\n",
|
||
" next_state, reward, done = env.step(action)\n",
|
||
" #产生新的动作\n",
|
||
" next_action = agent.get_action(str(next_state))\n",
|
||
" # 更新Q表,sarsa根据新的状态及动作获取Q表的值\n",
|
||
" #而不是基于新状态对所有动作的最大值\n",
|
||
" agent.learn(str(state), action, reward, next_action,str(next_state))\n",
|
||
" state = next_state\n",
|
||
" action=next_action\n",
|
||
" env.print_value_all(agent.q_table)\n",
|
||
" # 当到达终点就终止游戏开始新一轮训练\n",
|
||
" if done:\n",
|
||
" break"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.7.4"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|