{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 18.4 SARSA 算法\n",
"SARSA(State-Action-Reward-State-Action)算法与Q-Learning算法非常相似不同的就是更新Q值时 \n",
"Sarsa的现实值取$r+γQ(s_{t+1},a_{t+1})$,而不是$r+γ\\underset{a}{max} Q(s_{t+1},a)$。"
]
},
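{
"cell_type": "markdown",
"metadata": {},
"source": [
"Written out as full update rules (the standard tabular form, with learning rate $\\alpha$ and discount factor $\\gamma$; the agent below uses $\\alpha=0.01$, $\\gamma=0.9$), the two algorithms differ only in the bootstrap term: \n",
"SARSA: $Q(s_t,a_t) \\leftarrow Q(s_t,a_t)+\\alpha\\,[\\,r+\\gamma Q(s_{t+1},a_{t+1})-Q(s_t,a_t)\\,]$ \n",
"Q-Learning: $Q(s_t,a_t) \\leftarrow Q(s_t,a_t)+\\alpha\\,[\\,r+\\gamma\\,\\underset{a}{\\max}\\,Q(s_{t+1},a)-Q(s_t,a_t)\\,]$"
]
},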
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 18.4.1 SARSA算法主要步骤\n",
"SARSA算法更新Q函数的步骤为 \n",
"1获取初始状态s。 \n",
"2执行上一步选择的行动a获得奖励r和新状态next_s。 \n",
"3在新状态next_s根据当前的Q表选定要执行的下一行动next_a。 \n",
"4用r、next_a、next_s,更加sarsa逻辑更新Q表。 \n",
"5把next_s赋给s把next_a赋给a。 \n",
"### 18.4.2 用PyTorch实现SARSA算法\n",
"SARSA算法与Q-Learning算法基本相同只有少部分不同为此只要修改实现Q-Learning算法中的部分代码即可。"
]
},
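{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before modifying that code, the next cell gives a minimal, self-contained sketch of a single SARSA update on a toy Q table. The state keys and values are illustrative only; the agent class defined later applies exactly the same update inside its learn() method."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal, self-contained sketch of one SARSA update on a toy Q table.\n",
"# The state keys, actions and values below are illustrative only; the agent\n",
"# class defined later applies exactly the same update inside learn().\n",
"from collections import defaultdict\n",
"\n",
"learning_rate, discount_factor = 0.01, 0.9\n",
"q_table = defaultdict(lambda: [0.0, 0.0, 0.0, 0.0])\n",
"\n",
"state, action, reward = '[0, 0]', 3, 0      # current state, action taken, reward received\n",
"next_state, next_action = '[1, 0]', 2       # new state and the action chosen there (step (3))\n",
"q_table[next_state] = [0.2, 0.5, 0.1, 0.0]  # pretend these entries were already learned\n",
"\n",
"# SARSA target: r + gamma * Q(next_state, next_action), not the max over actions\n",
"target = reward + discount_factor * q_table[next_state][next_action]\n",
"q_table[state][action] += learning_rate * (target - q_table[state][action])\n",
"print(q_table[state])   # the entry for action 3 moves towards the target 0.09"
]
},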
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"import numpy as np\n",
"import tkinter as tk\n",
"from PIL import ImageTk, Image\n",
"\n",
"np.random.seed(1)\n",
"PhotoImage = ImageTk.PhotoImage\n",
"UNIT = 100\n",
"HEIGHT = 5\n",
"WIDTH = 5\n",
"\n",
"\n",
"class Env(tk.Tk):\n",
" def __init__(self):\n",
" super(Env, self).__init__()\n",
" self.action_space = ['u', 'd', 'l', 'r']\n",
" self.n_actions = len(self.action_space)\n",
" self.title('Q Learning')\n",
" self.geometry('{0}x{1}'.format(HEIGHT * UNIT, HEIGHT * UNIT))\n",
" self.shapes = self.load_images()\n",
" self.canvas = self._build_canvas()\n",
" self.texts = []\n",
"\n",
" def _build_canvas(self):\n",
" canvas = tk.Canvas(self, bg='white',\n",
" height=HEIGHT * UNIT,\n",
" width=WIDTH * UNIT)\n",
" # create grids\n",
" for c in range(0, WIDTH * UNIT, UNIT): # 0~400 by 100\n",
" x0, y0, x1, y1 = c, 0, c, HEIGHT * UNIT\n",
" canvas.create_line(x0, y0, x1, y1)\n",
" for r in range(0, HEIGHT * UNIT, UNIT): # 0~400 by 100\n",
" x0, y0, x1, y1 = 0, r, HEIGHT * UNIT, r\n",
" canvas.create_line(x0, y0, x1, y1)\n",
"\n",
" # 把图标加载到环境中\n",
" self.rectangle = canvas.create_image(50, 50, image=self.shapes[0])\n",
" self.tree1 = canvas.create_image(250, 150, image=self.shapes[1])\n",
" self.tree2 = canvas.create_image(150, 250, image=self.shapes[1])\n",
" self.star = canvas.create_image(250, 250, image=self.shapes[2])\n",
"\n",
" # 对环境进行包装\n",
" canvas.pack()\n",
"\n",
" return canvas\n",
"\n",
" def load_images(self):\n",
" rectangle = PhotoImage(\n",
" Image.open(\"img/bob.png\").resize((65, 65)))\n",
" tree = PhotoImage(\n",
" Image.open(\"img/tree.png\").resize((65, 65)))\n",
" star = PhotoImage(\n",
" Image.open(\"img/star.jpg\").resize((65, 65)))\n",
"\n",
" return rectangle, tree, star\n",
"\n",
" def text_value(self, row, col, contents, action, font='Helvetica', size=10,\n",
" style='normal', anchor=\"nw\"):\n",
" if action == 0:\n",
" origin_x, origin_y = 7, 42\n",
" elif action == 1:\n",
" origin_x, origin_y = 85, 42\n",
" elif action == 2:\n",
" origin_x, origin_y = 42, 5\n",
" else:\n",
" origin_x, origin_y = 42, 77\n",
"\n",
" x, y = origin_y + (UNIT * col), origin_x + (UNIT * row)\n",
" font = (font, str(size), style)\n",
" text = self.canvas.create_text(x, y, fill=\"black\", text=contents,\n",
" font=font, anchor=anchor)\n",
" return self.texts.append(text)\n",
"\n",
" def print_value_all(self, q_table):\n",
" for i in self.texts:\n",
" self.canvas.delete(i)\n",
" self.texts.clear()\n",
" for i in range(HEIGHT):\n",
" for j in range(WIDTH):\n",
" for action in range(0, 4):\n",
" state = [i, j]\n",
" if str(state) in q_table.keys():\n",
" temp = q_table[str(state)][action]\n",
" self.text_value(j, i, round(temp, 2), action)\n",
"\n",
" def coords_to_state(self, coords):\n",
" x = int((coords[0] - 50) / 100)\n",
" y = int((coords[1] - 50) / 100)\n",
" return [x, y]\n",
"\n",
" def state_to_coords(self, state):\n",
" x = int(state[0] * 100 + 50)\n",
" y = int(state[1] * 100 + 50)\n",
" return [x, y]\n",
"\n",
" def reset(self):\n",
" self.update()\n",
" time.sleep(0.5)\n",
" x, y = self.canvas.coords(self.rectangle)\n",
" self.canvas.move(self.rectangle, UNIT / 2 - x, UNIT / 2 - y)\n",
" self.render()\n",
" # return observation\n",
" return self.coords_to_state(self.canvas.coords(self.rectangle))\n",
"\n",
" def step(self, action):\n",
" state = self.canvas.coords(self.rectangle)\n",
" base_action = np.array([0, 0])\n",
" self.render()\n",
"\n",
" if action == 0: # up\n",
" if state[1] > UNIT:\n",
" base_action[1] -= UNIT\n",
" elif action == 1: # down\n",
" if state[1] < (HEIGHT - 1) * UNIT:\n",
" base_action[1] += UNIT\n",
" elif action == 2: # left\n",
" if state[0] > UNIT:\n",
" base_action[0] -= UNIT\n",
" elif action == 3: # right\n",
" if state[0] < (WIDTH - 1) * UNIT:\n",
" base_action[0] += UNIT\n",
"\n",
" # 移动\n",
" self.canvas.move(self.rectangle, base_action[0], base_action[1])\n",
" self.canvas.tag_raise(self.rectangle)\n",
" next_state = self.canvas.coords(self.rectangle)\n",
" # 判断得分条件\n",
" if next_state == self.canvas.coords(self.star):\n",
" reward = 100\n",
" done = True\n",
" elif next_state in [self.canvas.coords(self.tree1),\n",
" self.canvas.coords(self.tree2)]:\n",
" reward = -100\n",
" done = True\n",
" else:\n",
" reward = 0\n",
" done = False\n",
"\n",
" next_state = self.coords_to_state(next_state)\n",
" return next_state, reward, done\n",
"\n",
" # 渲染环境\n",
" def render(self):\n",
" time.sleep(0.03)\n",
" self.update()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. 修改学习函数"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import random\n",
"from collections import defaultdict\n",
"\n",
"\n",
"class QLearningAgent:\n",
" def __init__(self, actions):\n",
" # 四种动作分别用序列表示:[0, 1, 2, 3]\n",
" self.actions = actions\n",
" self.learning_rate = 0.01\n",
" self.discount_factor = 0.9\n",
" #epsilon贪婪策略取值\n",
" self.epsilon = 0.1\n",
" self.q_table = defaultdict(lambda: [0.0, 0.0, 0.0, 0.0])\n",
"\n",
" # 采样 <s, a, r,a',s'>\n",
" def learn(self, state, action, reward,next_action,next_state):\n",
" current_q = self.q_table[state][action]\n",
" # 更新Q表\n",
" new_q = reward + self.discount_factor * (self.q_table[next_state][next_action])\n",
" self.q_table[state][action] += self.learning_rate * (new_q - current_q)\n",
"\n",
" # 从Q-table中选取动作\n",
" def get_action(self, state):\n",
" if np.random.rand() < self.epsilon:\n",
" # 贪婪策略随机探索动作\n",
" action = np.random.choice(self.actions)\n",
" else:\n",
" # 从q表中选择\n",
" state_action = self.q_table[state]\n",
" action = self.arg_max(state_action)\n",
" return action\n",
"\n",
" @staticmethod\n",
" def arg_max(state_action):\n",
" max_index_list = []\n",
" max_value = state_action[0]\n",
" for index, value in enumerate(state_action):\n",
" if value > max_value:\n",
" max_index_list.clear()\n",
" max_value = value\n",
" max_index_list.append(index)\n",
" elif value == max_value:\n",
" max_index_list.append(index)\n",
" return random.choice(max_index_list)"
]
},
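{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, the cell below is a sketch (not part of the book's code) of the same toy update written in the Q-Learning form quoted in 18.4. Because it bootstraps from the maximum Q value of the next state rather than from the action actually chosen, the learn() method of a Q-Learning agent needs no next_action argument; everything else in the agent above would stay the same."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch for comparison only (not part of the book's code): the same toy update\n",
"# in the Q-Learning form quoted in 18.4. It bootstraps from max_a Q(next_state, a)\n",
"# rather than Q(next_state, next_action), so no next_action argument is needed.\n",
"from collections import defaultdict\n",
"\n",
"learning_rate, discount_factor = 0.01, 0.9\n",
"q_table = defaultdict(lambda: [0.0, 0.0, 0.0, 0.0])\n",
"\n",
"state, action, reward, next_state = '[0, 0]', 3, 0, '[1, 0]'\n",
"q_table[next_state] = [0.2, 0.5, 0.1, 0.0]  # same toy values as in the SARSA sketch above\n",
"\n",
"current_q = q_table[state][action]\n",
"new_q = reward + discount_factor * max(q_table[next_state])  # max over all actions: 0.5\n",
"q_table[state][action] += learning_rate * (new_q - current_q)\n",
"print(q_table[state])   # larger update than SARSA's, because the greedy value 0.5 > 0.1"
]
},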
{
"cell_type": "markdown",
"metadata": {},
"source": [
"2.修改训练代码 \n",
"这里主要修改的内容如下: \n",
"1新增获取下一步动作的函数next_action = agent.get_action(str(state)) \n",
"2把next_action赋给action。 \n",
"具体实现代码如下:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"ename": "TclError",
"evalue": "invalid command name \".!canvas\"",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mTclError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-3-5849bab91513>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 8\u001b[0m \u001b[0menv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 9\u001b[0m \u001b[1;31m#获取新的状态、奖励分数\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 10\u001b[1;33m \u001b[0mnext_state\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mreward\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdone\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0menv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0maction\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 11\u001b[0m \u001b[1;31m#产生新的动作\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 12\u001b[0m \u001b[0mnext_action\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0magent\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_action\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m<ipython-input-1-aa377bdeb320>\u001b[0m in \u001b[0;36mstep\u001b[1;34m(self, action)\u001b[0m\n\u001b[0;32m 122\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 123\u001b[0m \u001b[1;31m# 移动\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 124\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcanvas\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmove\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrectangle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbase_action\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbase_action\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 125\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcanvas\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtag_raise\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrectangle\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 126\u001b[0m \u001b[0mnext_state\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcanvas\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcoords\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrectangle\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\Anaconda3\\lib\\tkinter\\__init__.py\u001b[0m in \u001b[0;36mmove\u001b[1;34m(self, *args)\u001b[0m\n\u001b[0;32m 2589\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mmove\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2590\u001b[0m \u001b[1;34m\"\"\"Move an item TAGORID given in ARGS.\"\"\"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 2591\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtk\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcall\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_w\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'move'\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 2592\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mpostscript\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcnf\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m{\u001b[0m\u001b[1;33m}\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkw\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2593\u001b[0m \"\"\"Print the contents of the canvas to a postscript\n",
"\u001b[1;31mTclError\u001b[0m: invalid command name \".!canvas\""
]
}
],
"source": [
"env = Env()\n",
"agent = QLearningAgent(actions=list(range(env.n_actions)))\n",
"#共进行200次游戏\n",
"for episode in range(200):\n",
" state = env.reset()\n",
" action = agent.get_action(str(state))\n",
" while True:\n",
" env.render()\n",
" #获取新的状态、奖励分数\n",
" next_state, reward, done = env.step(action)\n",
" #产生新的动作\n",
" next_action = agent.get_action(str(state))\n",
" # 更新Q表sarsa根据新的状态及动作获取Q表的值\n",
" #而不是基于新状态对所有动作的最大值\n",
" agent.learn(str(state), action, reward, next_action,str(next_state))\n",
" state = next_state\n",
" action=next_action\n",
" env.print_value_all(agent.q_table)\n",
" # 当到达终点就终止游戏开始新一轮训练\n",
" if done:\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}