强化学习实战 | 自定义Gym环境之扫雷

发布时间:2022-06-21 发布网站:脚本宝典
脚本宝典收集整理的这篇文章主要介绍了强化学习实战 | 自定义Gym环境之扫雷脚本宝典觉得挺不错的,现在分享给大家,也给大家做个参考。

开始之前

先考虑几个问题:

  • Q1:如何展开无雷区?
  • Q2:如何计算格子的提示数?
  • Q3:如何表示扫雷游戏的状态?

A1:可以使用递归函数,或是堆栈。

A2:一般的做法是,需要打开某格子时,再去统计周围的雷数。如果有方便的二维卷积函数可以调用,这会是个更简洁的方法:

$$begin{bmatrix}1 & 0 & 0 & 1 & 0\ 0 & 1 & 0 & 0 & 1\ 1 & 0 & 1 & 0 & 0\ 0 & 0 & 0 & 0 & 0\ 0 & 1 & 0 & 0 & 1end{bmatrix}Bigstar begin{bmatrix}1 & 1 & 1\ 1 & 0 & 1\ 1 & 1 & 1end{bmatrix}=begin{bmatrix}1 & 2 & 2 & 1 & 2\ 3 & 3 & 3 & 3 & 1\ 1 & 3 & 1 & 2 & 1\ 2 & 3 & 2 & 2 & 1\ 1 & 0 & 1 & 1 & 0end{bmatrix}$$

 不妨用 $bigstar$ 表示二维卷积运算。等号左边的5×5矩阵表示了雷的分布情况,值1表示有雷,值0表示无雷;等号左边的3×3矩阵是求解周围雷数的卷积核(或称滤波器,特征提取器);等号右边的矩阵即是所有格子的周围雷数。

代码实现起来也非常简单

From scipy import signal
import numpy as np
state_mine = np.array([[1,0,0,1,0],[0,1,0,0,1],[1,0,1,0,0],[0,0,0,0,0],[0,1,0,0,1]])
KERNAL = np.array([[1,1,1],[1,0,1],[1,1,1]])
state_num = signal.convolve2d(state_mine, KERNAL, 'same')

A3:对于玩家来说,游戏状态是不完全观测的,也即需要区分观测状态环境状态。环境状态包括雷分布矩阵,和提示数矩阵(也即上式提到的);观测状态是玩家部分可见的环境状态,需要根据格子的打开状态对雷分布矩阵进行部分屏蔽。观测状态不包括雷分布矩阵,因为一旦触雷即游戏结束,所以游戏中所有非终止状态都是无雷的。

那么对于一个大小为$M times N$的扫雷游戏,环境状态可以表示为 $M times N times 2$ 的张量:频道1是雷分布矩阵,频道2是提示数矩阵;观测状态可以表示为 $M times N times 2$ 的张量:频道1是表示格子打开状态的矩阵(值1为打开,值0为未打开),并以此矩阵对 提示数矩阵 进行元素乘,完成对环境状态的部分屏蔽,作为第二个频道。对于numpy.array而言,元素乘是容易的:

observe_num = state_num * state_oPEn

以下图的游戏状态为例说明:

强化学习实战 | 自定义Gym环境之扫雷

环境状态为:

$$begin{bmatrix} &  &  &  & \  & 1 &  &  & \  &  &  &  & \  &  &  & 1 & \ 1 & 1 &  &  & end{bmatrix}times begin{bmatrix}1 & 1 & 1 & 0 & 0\ 1 & 0 & 1 & 0 & 0\ 1 & 1 & 2 & 1 & 1\ 2 & 2 & 2 & 0 & 1\ 1 & 1 & 2 & 1 & 1end{bmatrix}$$

观测状态为:

$$begin{bmatrix}1 & 0 & 1 & 0 & 0\ 1 & 0 & 1 & 0 & 0\ 1 & 0 & 2 & 1 & 1\ 2 & 2 & 0 & 0 & 1\ 1 & 0 & 0 & 1 & 0end{bmatrix}times begin{bmatrix}1 &  & 1 & 1 & 1\ 1 &  & 1 & 1 & 1\ 1 &  & 1 & 1 & 1\ 1 & 1 &  &  & 1\ 1 &  &  & 1 & end{bmatrix}$$

 但这种表示方式不是唯一的,比如我们可以把提示数矩阵拆成9个频道,分别表示0~8的提示数。那么观测状态就变成了 $M times N times 10$ 的张量:

$$begin{bmatrix} & & & 1 & 1\ & & & 1 & 1\ & & & & \ & & & & \ & & & & end{bmatrix}times begin{bmatrix}1 & & 1 & & \ 1 & & 1 & & \ 1 & & & 1 & 1\ & & & & 1\ & & & 1 & end{bmatrix}times begin{bmatrix} & & & & \ & & & & \ & & 1 & & \ 1 & 1 & & & \ & & & & end{bmatrix}timesbegin{bmatrix} & & & & \ & & & & \ & & & & \ & & & & \ & & & & end{bmatrix}timescdots timesbegin{bmatrix} & & & & \ & & & & \ & & & & \ & & & & \ & & & & end{bmatrix}timesbegin{bmatrix}1 & & 1 & 1 & 1\ 1 & & 1 & 1 & 1\ 1 & & 1 & 1 & 1\ 1 & 1 & & & 1\ 1 & & & 1 & end{bmatrix}$$

状态空间的设计是灵活的,唯一的评价的标准是完整的学习系统的性能表现。如果采用以上多频道式的状态空间设计,那么后续可以很方便地使用卷积神经网络开展学习任务。你也可以把张量阵展成一维的向量,然后用全连接神经网络处理。本文后续的实现将采用 $M times N times 2$ 的状态空间表达。

步骤1:新建文件

为了运行pytorch,我使用anaconda的环境管理操作创建了名为pytorch1.1的环境名,并在这个环境下安装了openai gym,因此我来到目录:D:Anacondaenvspytorch1.1LibsITe-packagesgymenvSUSEr 下,新建文件 __init__.pyMineSweeper_env.py。

步骤2:编写文件 MineSweeper_env.py

一个标准的gym env类包含三个方法:reset(),step(action),和render()。

  • reset() 用于初始化环境;
  • step(action) 有四个返回值:state,reward,done,和info,因此我们需要在该函数中完成扫雷游戏的全部逻辑;
  • render() 用于可视化环境。我在网上没有找到gym的原生方法rendering可以显示文字的说法(如果有知晓的朋友请留言,感谢!),所以是通过pyglet + 动态变量名的方式实现大量字符的显示,具体做法可见 强化学习实战 | 自定义Gym环境之显示字符串。

MineSweeper_env.py 的整体代码如下:

import gym
import random
import time
import numpy as np
from scipy import signal # 二维卷积
import pyglet # 显示文字
from gym.envs.classic_control import rendering


class DrawText: # 用于在rendering中显示文字
    def __init__(self, label:pyglet.text.Label):
        self.label=label
    def render(self):
        self.label.draw()


class MineSweeperEnv(gym.Env):
    def __init__(self):
        self.MINE_NUM = 20
        self.ROW, self.COL = 12, 12
        self.SIZE = 40
        WIDTH = self.COL * self.SIZE
        HeiGHT = self.ROW * self.SIZE
        self.viewer = rendering.Viewer(WIDTH, HEIGHT)
        self.state_mine = None
        self.state_num = None
        self.state_open = None
        self.gameOver = False
        
        
    def reset(self):
        # 初始化:布雷状态
        MINE_NUM = self.MINE_NUM
        self.state_mine = np.zeros(self.ROW * self.COL) 
        self.state_mine[:MINE_NUM] = 1
        random.shuffle(self.state_mine)
        self.state_mine = self.state_mine.reShape(self.ROW, self.COL)
        # 初始化:提示数字
        KERNAL = np.array([[1,1,1], [1,0,1], [1,1,1]])
        self.state_num = signal.convolve2d(self.state_mine, KERNAL, 'same')
        # 初始化:打开状态
        self.state_open = np.zeros((self.ROW, self.COL))
        # 初始化:游戏是否结束
        self.gameOver = False
        
    
    def getRoundSet(self, x, y):
        roundSet = []
        for i in range(x-1, x+2):
            for j in range(y-1, y+2):
                if 0 <= i < self.ROW and 0 <= j < self.COL and (i, j) != (x, y):
                    roundSet.append((i, j))
        return roundSet
    
    
    def step(self, action):
        # 执行动作
        x, y = action
        # 若打开数字不为0
        if self.state_num[x, y] >= 1:
            self.state_open[x, y] = 1
        # 若打开数字为0 则展开无雷区
        if self.state_num[x, y] == 0:
            stack = []
            stack.append((x, y))
            while len(stack):
                row, col = stack.pop()
                self.state_open[row, col] = 1
                for one in self.getRoundSet(row, col):
                    # 排除已经打开的格子
                    if self.state_open[one] == 1:
                        continue
                    if self.state_num[one] >= 1:
                        self.state_open[one] = 1
                    else:
                        stack.append(one)         
    
        # 是否获胜或失败/获得奖励
        done, reward = False, 0
        # 若打开雷 则游戏失败
        if self.state_mine[x, y] == 1:
            self.state_open[x, y] = 1
            self.gameOver = True
            done, reward = True, -1
        # 若剩余未打开的格子数 = 雷数 则获胜
        if ROW*COL - self.state_open.sum() == self.MINE_NUM:
            self.gameOver = True
            done, reward = True, 1
        
        # 报告(维持gym step的标准格式)
        info = {}
        # 观测状态
        observe_num = self.state_num * self.state_open
        observe = [observe_num, self.state_open]
        return observe, reward, done, info
    
    
    def render(self, mode='human'):
        ROW, COL, SIZE = self.ROW, self.COL, self.SIZE
        # 画方块
        for i in range(ROW):
            for j in range(COL):
                X, Y = j*SIZE, (ROW-i-1)*SIZE
                tile = rendering.make_polygon([(X,Y), (X+SIZE,Y), (X+SIZE,Y+SIZE), (X,Y+SIZE)], filled=True)
                if self.state_open[i,j] == 0:
                    tile.set_color(106/255,116/255,166/255)
                if self.state_open[i,j] == 1 and self.state_mine[i,j] == 0:
                    tile.set_color(255/255,242/255,204/255)
                if self.state_open[i,j] == 1 and self.state_mine[i,j] == 1:
                    tile.set_color(220/255,20/255,60/255)
                self.viewer.add_geom(tile)
        # 画分隔线
        WIDTH = COL*SIZE
        HEIGHT = ROW*SIZE
        for i in range(ROW+1):
            line = rendering.Line((0, i*SIZE), (WIDTH, i*SIZE))
            line.set_color(80/255, 80/255, 80/255)
            self.viewer.add_geom(line)
        for j in range(COL+1):
            line = rendering.Line((j*SIZE, 0), (j*SIZE, HEIGHT))
            line.set_color(80/255, 80/255, 80/255)
            self.viewer.add_geom(line)
        # 画数字
        for i in range(ROW):
            for j in range(COL):
                exec('label_{}_{} = {}'.format(i, j, None))
                names = locals()
                NUM = int(self.state_num[i,j])
                COLOR = (255, 255, 255, 255)
                if NUM == 1:
                    COLOR = (46, 117, 182, 255)
                elif NUM == 2:
                    COLOR = (84, 130, 53, 255)
                elif NUM == 3:
                    COLOR = (192, 0, 0, 255)
                elif NUM == 4:
                    COLOR = (112, 48, 160, 255)
                elif NUM == 5:
                    COLOR = (132, 60, 12, 255)
                elif NUM == 6:
                    COLOR = (191, 144, 0, 255)
                elif NUM == 7:
                    COLOR = (32, 56, 100, 255)
                elif NUM == 8:
                    COLOR = (13, 13, 13, 255)
                names['label_' + str(i) + '_' + str(j)] = pyglet.text.Label('{}'.format(NUM), font_size=15,
                                  x=(j+0.32)*SIZE, y=(ROW-i-1+0.23)*SIZE, anchor_x='left', anchor_y='bottom',
                                  color=COLOR)
                label = names['label_{}_{}'.format(i, j)]
                label.draw()
                if self.state_mine[i,j] == 0 and self.state_open[i,j] == 1 and self.state_num[i,j] >= 1:
                    self.viewer.add_geom(DrawText(label))
                # 画雷
                if self.gameOver == True:
                    if self.state_mine[i,j] == 1:
                        mine = rendering.make_circle(10, 6, filled=True)
                        mine.set_color(30/255, 30/255, 30/255)
                        translation = rendering.Transform(translation=((j+0.5)*SIZE, (ROW-i-1+0.5)*SIZE))
                        mine.add_attr(translation)
                        self.viewer.add_geom(mine)
                
        return self.viewer.render(return_rgb_array=mode == 'rgb_array')
        

# 测试代码:以随机策略执行动作
if __name__ == '__main__': 
    MineSweeper = MineSweeperEnv()
    ROW, COL = MineSweeper.ROW, MineSweeper.COL
    MineSweeper.reset()
    MineSweeper.render()
    while MineSweeper.gameOver is not True:
        while True:
            rand = random.choice(range(ROW*COL))
            x, y = rand//ROW, rand%ROW
            if MineSweeper.state_open[x, y] == 0:
                action = (x, y)
                break
        state, reward, done, info = MineSweeper.step(action)
        MineSweeper.render()
        time.sleep(0.5)

直接运行文件,执行测试代码(以随机策略执行动作):

强化学习实战 | 自定义Gym环境之扫雷

步骤3:编写 __init__.py

在 __init__.py 中引入类的信息,添加:

from gym.envs.user.MineSweeper_env import MineSweeperEnv

步骤4:注册环境

来到目录:D:Anacondaenvspytorch1.1Libsite-packagesgym,打开 __init__.py,添加代码:

register(
    id="MineSweeperEnv-v0",
    entry_point="gym.envs.user:MineSweeperEnv",
    max_episode_steps=200,    
)

步骤5:测试环境

在相同的conda环境下,输入代码:

import gym
env = gym.make('MineSweeperEnv-v0')env.reset()env.render()

若无报错,则说明gym环境注册成功。

 

脚本宝典总结

以上是脚本宝典为你收集整理的强化学习实战 | 自定义Gym环境之扫雷全部内容,希望文章能够帮你解决强化学习实战 | 自定义Gym环境之扫雷所遇到的问题。

如果觉得脚本宝典网站内容还不错,欢迎将脚本宝典推荐好友。

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
如您有任何意见或建议可联系处理。小编QQ:384754419,请注明来意。