Tutorial: Custom gym Environment¶
In this notebook, a custom environment with a continuous observation_space and a discrete action_space is set up, checked against the Gym API, and used to train agents with Stable-Baselines3.
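For reference, this is what the two kinds of spaces look like when sampled directly; a minimal sketch (the sampled values are random and will differ on each run):

from gym.spaces import Box, Discrete
import numpy as np

# Continuous 1-D observation space between 0 and 100 degrees
Box(low=np.array([0.0]), high=np.array([100.0])).sample()   # e.g. array([41.3], dtype=float32)
# Discrete action space with three possible actions
Discrete(3).sample()                                         # e.g. 2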
Importing Dependencies¶
# Importing Libraries
import gym
from gym import Env
from gym.spaces import Discrete, Box, Dict, Tuple, MultiBinary, MultiDiscrete
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random
import os
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3 import DQN
from stable_baselines3 import PPO
Shower Environment¶
class ShowerEnv(Env):
    def __init__(self):
        # Actions we can take: down, stay, up
        self.action_space = Discrete(3)
        # Temperature array
        self.observation_space = Box(low=np.array([0]), high=np.array([100]))
        # Set start temp (same shape/dtype as reset() returns)
        self.state = np.array([38 + random.randint(-3, 3)]).astype(float)
        # Set shower length
        self.shower_length = 60

    def step(self, action):
        # Apply action: 0 -> -1 degree, 1 -> no change, 2 -> +1 degree
        self.state += action - 1
        # Reduce shower length by 1 second
        self.shower_length -= 1
        # Calculate reward: +1 inside the comfortable 37-39 degree band, -1 otherwise
        if self.state >= 37 and self.state <= 39:
            reward = 1
        else:
            reward = -1
        # Check if shower is done
        if self.shower_length <= 0:
            done = True
        else:
            done = False
        # Apply temperature noise
        # self.state += random.randint(-1, 1)
        # Set placeholder for info
        info = {}
        # Return step information
        return self.state, reward, done, info

    def render(self):
        # Implement viz
        pass

    def reset(self):
        # Reset shower temperature
        self.state = np.array([38 + random.randint(-3, 3)]).astype(float)
        # Reset shower time
        self.shower_length = 60
        return self.state
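Before running the automated checker, the environment can be exercised by hand; a minimal sketch (the exact temperature depends on the random start):

env = ShowerEnv()
obs = env.reset()
# Action 2 raises the temperature by 1 degree
obs, reward, done, info = env.step(2)
print(obs, reward, done, info)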
Checking Environment¶
# Checking whether our custom environment conforms with the necessary assertions
env = ShowerEnv()
check_env(env, warn=False)
# Observe the state initialization
env.reset()
env.render()
# Explore the state/observation and action spaces
print(env.observation_space)
print(env.action_space)
action_samples = []
obs_samples = []
n_samples = 25
print('\nNum of Samples: ', n_samples)
for _ in range(n_samples):
action_samples.append(env.action_space.sample())
obs_samples.append(env.observation_space.sample())
print(len(np.unique(action_samples)), ':\t', np.unique(action_samples))
print(len(np.unique(obs_samples)), ':\t', np.unique(obs_samples), '\n')
# See that Box has no attribute 'n'
print(env.action_space.n)
print(env.observation_space.n)
del n_samples, action_samples, obs_samples
Box(0.0, 100.0, (1,), float32)
Discrete(3)
Num of Samples: 25
3 : [0 1 2]
25 : [ 3.5996962 6.4574146 9.578733 13.806476 16.718254 31.083012
32.470306 39.09603 41.21765 41.593693 45.524384 46.594196
51.84048 60.28984 64.59786 69.74104 73.02744 76.58201
78.07482 78.15038 88.742775 92.168594 92.570755 95.42639
97.20068 ]
3
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-12-1ddd63988e80> in <module>
15 # See that box has no attribte 'n'
16 print(env.action_space.n)
---> 17 print(env.observation_space.n)
18
19 del n_samples, action_samples, obs_samples
AttributeError: 'Box' object has no attribute 'n'
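As the traceback shows, a Box space has no n attribute; the analogous information lives in its shape, low and high attributes, e.g.:

print(env.observation_space.shape)  # (1,)
print(env.observation_space.low)    # [0.]
print(env.observation_space.high)   # [100.]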
Random action episodes¶
# Perform episodes of random actions, i.e. no decision making
episodes = 5
for episode in range(1, episodes + 1):
    state = env.reset()
    done = False
    score = 0

    while not done:
        env.render()
        action = env.action_space.sample()
        n_state, reward, done, info = env.step(action)
        score += reward
    print('Episode:{} Score:{}'.format(episode, score))
env.close()
Episode:1 Score:-44
Episode:2 Score:-54
Episode:3 Score:-30
Episode:4 Score:-38
Episode:5 Score:-30
Defining DQN model¶
model = DQN("MlpPolicy", env, verbose=0)
model.learn(total_timesteps=100000)
<stable_baselines3.dqn.dqn.DQN at 0x2837d4df320>
evaluate_policy(model, env, n_eval_episodes=100, render=False)
C:\Users\sarth\.conda\envs\master\lib\site-packages\stable_baselines3\common\evaluation.py:69: UserWarning: Evaluation environment is not wrapped with a ``Monitor`` wrapper. This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. Consider wrapping environment first with ``Monitor`` wrapper.
UserWarning,
(-12.0, 58.787753826796276)
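The warning above appears because the evaluation environment is not wrapped in a Monitor. A minimal sketch of wrapping it first (eval_env is a name introduced here):

from stable_baselines3.common.monitor import Monitor

eval_env = Monitor(ShowerEnv())
evaluate_policy(model, eval_env, n_eval_episodes=100, render=False)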
Training the model further¶
model.learn(total_timesteps=100000)
<stable_baselines3.dqn.dqn.DQN at 0x2837d4df320>
evaluate_policy(model, env, n_eval_episodes=100, render=False)
(-9.6, 59.227020860414704)
model.learn(total_timesteps=100000)
<stable_baselines3.dqn.dqn.DQN at 0x2837d4df320>
evaluate_policy(model, env, n_eval_episodes=100, render=False)
(-7.2, 59.56643350075611)
Defining PPO model¶
model = PPO("MlpPolicy", env, verbose=0)
model.learn(total_timesteps=100000)
<stable_baselines3.ppo.ppo.PPO at 0x2837d4f4278>
evaluate_policy(model, env, n_eval_episodes=100, render=False)
(59.7, 0.714142842854285)
model.learn(total_timesteps=100000)
<stable_baselines3.ppo.ppo.PPO at 0x2837d4f4278>
evaluate_policy(model, env, n_eval_episodes=100, render=False)
(59.48, 0.8772684879784526)
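Once the reward is acceptable, the trained model can be persisted and reloaded later; a minimal sketch ("ppo_shower" is an arbitrary filename chosen here):

model.save("ppo_shower")                    # filename is an assumption, any path works
loaded_model = PPO.load("ppo_shower", env=env)
evaluate_policy(loaded_model, env, n_eval_episodes=10, render=False)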