· 6 years ago · Nov 06, 2019, 09:56 AM
1# -*- coding: utf-8 -*-
2import warnings
3from copy import deepcopy
4
5import numpy as np
6from keras.callbacks import History
7
8from rl.callbacks import TestLogger, TrainEpisodeLogger, TrainIntervalLogger, Visualizer, CallbackList
9
10
class Agent(object):
    """Abstract base class for all implemented agents.

    Each agent interacts with the environment (as defined by the `Env` class) by first observing the
    state of the environment. Based on this observation the agent changes the environment by performing
    an action.

    Do not use this abstract base class directly but instead use one of the concrete agents implemented.
    Each agent realizes a reinforcement learning algorithm. Since all agents conform to the same
    interface, you can use them interchangeably.

    To implement your own agent, you have to implement the following methods:

    - `forward`
    - `backward`
    - `compile`
    - `load_weights`
    - `save_weights`
    - `layers`

    Concrete agents are also expected to maintain a boolean `compiled` attribute; `fit` and `test`
    refuse to run while it is missing or falsy.

    # Arguments
        processor (`Processor` instance): See [Processor](#processor) for details.
    """
    def __init__(self, processor=None):
        self.processor = processor
        self.training = False
        self.step = 0

    def get_config(self):
        """Configuration of the agent for serialization.

        # Returns
            An empty dictionary in this base class.
        """
        return {}

    def _build_callback_list(self, callbacks, env, params):
        """Wraps `callbacks` in a `CallbackList` bound to this agent, `env` and `params`."""
        callbacks = CallbackList(callbacks)
        # Depending on the installed version, the list exposes either the public
        # setters or their underscore-prefixed counterparts.
        if hasattr(callbacks, 'set_model'):
            callbacks.set_model(self)
        else:
            callbacks._set_model(self)
        callbacks._set_env(env)
        if hasattr(callbacks, 'set_params'):
            callbacks.set_params(params)
        else:
            callbacks._set_params(params)
        return callbacks

    def _perform_random_starts(self, env, callbacks, observation, nb_max_start_steps, start_step_policy):
        """Performs the random start steps of an episode and returns the resulting observation."""
        # `np.random.randint(n)` draws from [0, n), so `nb_max_start_steps` itself is never drawn.
        nb_random_start_steps = 0 if nb_max_start_steps == 0 else np.random.randint(nb_max_start_steps)
        for _ in range(nb_random_start_steps):
            if start_step_policy is None:
                action = env.action_space.sample()
            else:
                action = start_step_policy(observation)
            if self.processor is not None:
                action = self.processor.process_action(action)
            callbacks.on_action_begin(action)
            observation, reward, done, info = env.step(action)
            observation = deepcopy(observation)
            if self.processor is not None:
                observation, reward, done, info = self.processor.process_step(observation, reward, done, info)
            callbacks.on_action_end(action)
            if done:
                warnings.warn('Env ended before {} random steps could be performed at the start. You should probably lower the `nb_max_start_steps` parameter.'.format(nb_random_start_steps))
                observation = deepcopy(env.reset())
                if self.processor is not None:
                    observation = self.processor.process_observation(observation)
                break
        return observation

    @staticmethod
    def _accumulate_info(accumulated_info, info):
        """Adds every real-valued entry of `info` onto `accumulated_info` in place."""
        for key, value in info.items():
            if not np.isreal(value):
                continue
            if key not in accumulated_info:
                accumulated_info[key] = np.zeros_like(value)
            accumulated_info[key] += value

    @staticmethod
    def _named_metric_logs(metrics):
        """Maps the first three entries of `metrics` to 'loss', 'mae' and 'mse', using NaN for missing ones."""
        logs = {}
        for index, key in enumerate(('loss', 'mae', 'mse')):
            try:
                logs[key] = metrics[index]
            except (IndexError, TypeError):
                # Fewer than three metrics (or none at all) were returned by `backward`.
                logs[key] = np.nan
        return logs

    def fit(self, env, nb_steps, action_repetition=1, callbacks=None, verbose=1,
            visualize=False, nb_max_start_steps=0, start_step_policy=None, log_interval=10000,
            nb_max_episode_steps=None):
        """Trains the agent on the given environment.

        # Arguments
            env: (`Env` instance): Environment that the agent interacts with. See [Env](#env) for details.
            nb_steps (integer): Number of training steps to be performed.
            action_repetition (integer): Number of times the agent repeats the same action without
                observing the environment again. Setting this to a value > 1 can be useful
                if a single action only has a very small effect on the environment.
            callbacks (list of `keras.callbacks.Callback` or `rl.callbacks.Callback` instances):
                List of callbacks to apply during training. See [callbacks](/callbacks) for details.
            verbose (integer): 0 for no logging, 1 for interval logging (compare `log_interval`), 2 for episode logging
            visualize (boolean): If `True`, the environment is visualized during training. However,
                this is likely going to slow down training significantly and is thus intended to be
                a debugging instrument.
            nb_max_start_steps (integer): Upper limit on the number of steps that the agent performs at
                the beginning of each episode using `start_step_policy`. The exact number of steps is
                sampled uniformly from [0, nb_max_start_steps) at the beginning of each episode, i.e.
                the limit itself is exclusive.
            start_step_policy (`lambda observation: action`): The policy
                to follow if `nb_max_start_steps` > 0. If set to `None`, a random action is performed.
            log_interval (integer): If `verbose` = 1, the number of steps that are considered to be an interval.
            nb_max_episode_steps (integer): Number of steps per episode that the agent performs before
                automatically resetting the environment. Set to `None` if each episode should run
                (potentially indefinitely) until the environment signals a terminal state.

        # Returns
            A `keras.callbacks.History` instance that recorded the entire training process.

        # Raises
            RuntimeError: If the agent has not been compiled yet.
            ValueError: If `action_repetition` is smaller than 1.
        """
        # `compiled` is set by concrete agents; treat a missing attribute as "not compiled".
        if not getattr(self, 'compiled', False):
            raise RuntimeError('Your tried to fit your agent but it hasn\'t been compiled yet. Please call `compile()` before `fit()`.')
        if action_repetition < 1:
            raise ValueError('action_repetition must be >= 1, is {}'.format(action_repetition))

        self.training = True

        # Work on a copy so that the caller's list is not modified.
        callbacks = [] if not callbacks else callbacks[:]

        if verbose == 1:
            callbacks += [TrainIntervalLogger(interval=log_interval)]
        elif verbose > 1:
            callbacks += [TrainEpisodeLogger()]
        if visualize:
            callbacks += [Visualizer()]
        history = History()
        callbacks += [history]
        callbacks = self._build_callback_list(callbacks, env, {'nb_steps': nb_steps})
        self._on_train_begin()
        callbacks.on_train_begin()
        # The training loop below stops as soon as this flag becomes truthy.
        self.stop_training = False
        episode = 0
        self.step = 0
        observation = None
        episode_reward = None
        episode_step = None
        did_abort = False
        try:
            while self.step < nb_steps:
                if self.stop_training:
                    print("Early Stopping stopped training")
                    break
                if observation is None:  # start of a new episode
                    callbacks.on_episode_begin(episode)
                    episode_step = 0
                    episode_reward = 0.

                    # Obtain the initial observation by resetting the environment.
                    self.reset_states()
                    observation = deepcopy(env.reset())
                    if self.processor is not None:
                        observation = self.processor.process_observation(observation)
                    assert observation is not None

                    # Perform random starts at beginning of episode and do not record them into the experience.
                    # This slightly changes the start position between games.
                    observation = self._perform_random_starts(
                        env, callbacks, observation, nb_max_start_steps, start_step_policy)

                # At this point, we expect to be fully initialized.
                assert episode_reward is not None
                assert episode_step is not None
                assert observation is not None

                # Run a single step.
                callbacks.on_step_begin(episode_step)
                # This is where all of the work happens. We first perceive and compute the action
                # (forward step) and then use the reward to improve (backward step).
                action = self.forward(observation)
                if self.processor is not None:
                    action = self.processor.process_action(action)
                reward = 0.
                accumulated_info = {}
                done = False
                for _ in range(action_repetition):
                    callbacks.on_action_begin(action)
                    observation, r, done, info = env.step(action)
                    observation = deepcopy(observation)
                    if self.processor is not None:
                        observation, r, done, info = self.processor.process_step(observation, r, done, info)
                    self._accumulate_info(accumulated_info, info)
                    callbacks.on_action_end(action)
                    reward += r
                    if done:
                        break
                if nb_max_episode_steps and episode_step >= nb_max_episode_steps - 1:
                    # Force a terminal state.
                    done = True
                metrics = self.backward(reward, terminal=done)
                episode_reward += reward
                metric_logs = self._named_metric_logs(metrics)

                step_logs = {
                    'action': action,
                    'observation': observation,
                    'reward': reward,
                    'metrics': metrics,
                }
                step_logs.update(metric_logs)
                step_logs['episode'] = episode
                step_logs['info'] = accumulated_info
                callbacks.on_step_end(episode_step, step_logs)
                episode_step += 1
                self.step += 1

                if done:
                    # We are in a terminal state but the agent hasn't yet seen it. We therefore
                    # perform one more forward-backward call and simply ignore the action before
                    # resetting the environment. We need to pass in `terminal=False` here since
                    # the *next* state, that is the state of the newly reset environment, is
                    # always non-terminal by convention.
                    self.forward(observation)
                    self.backward(0., terminal=False)

                    # This episode is finished, report and reset. The reported metrics are
                    # those of the last regular step, not of the extra call above.
                    episode_logs = {
                        'episode_reward': episode_reward,
                        'nb_episode_steps': episode_step,
                        'nb_steps': self.step,
                        'reward': reward,
                    }
                    episode_logs.update(metric_logs)
                    callbacks.on_episode_end(episode, episode_logs)

                    episode += 1
                    observation = None
                    episode_step = None
                    episode_reward = None

        except KeyboardInterrupt:
            # We catch keyboard interrupts here so that training can be safely aborted.
            # This is so common that we've built this right into this function, which ensures that
            # the `on_train_end` method is properly called.
            did_abort = True
        callbacks.on_train_end(logs={'did_abort': did_abort})
        self._on_train_end()

        return history

    def test(self, env, nb_episodes=1, action_repetition=1, callbacks=None, visualize=True,
             nb_max_episode_steps=None, nb_max_start_steps=0, start_step_policy=None, verbose=1):
        """Runs the agent on the given environment for a fixed number of episodes.

        `self.training` is set to `False` for the duration of the run. Note that the callbacks
        are notified through their `on_train_begin` and `on_train_end` hooks.

        # Arguments
            env: (`Env` instance): Environment that the agent interacts with. See [Env](#env) for details.
            nb_episodes (integer): Number of episodes to perform.
            action_repetition (integer): Number of times the agent repeats the same action without
                observing the environment again.
            callbacks (list of `keras.callbacks.Callback` or `rl.callbacks.Callback` instances):
                List of callbacks to apply during the run.
            visualize (boolean): If `True`, a `Visualizer` callback is added.
            nb_max_episode_steps (integer): Number of steps per episode that the agent performs before
                the episode is forcibly ended. Set to `None` if each episode should run until the
                environment signals a terminal state.
            nb_max_start_steps (integer): Upper limit on the number of steps performed at the beginning
                of each episode using `start_step_policy`. The exact number is sampled uniformly
                from [0, nb_max_start_steps).
            start_step_policy (`lambda observation: action`): The policy
                to follow if `nb_max_start_steps` > 0. If set to `None`, a random action is performed.
            verbose (integer): 0 for no logging, >= 1 to add a `TestLogger` callback.

        # Returns
            A `keras.callbacks.History` instance that recorded the entire run.

        # Raises
            RuntimeError: If the agent has not been compiled yet.
            ValueError: If `action_repetition` is smaller than 1.
        """
        # `compiled` is set by concrete agents; treat a missing attribute as "not compiled".
        if not getattr(self, 'compiled', False):
            raise RuntimeError('Your tried to test your agent but it hasn\'t been compiled yet. Please call `compile()` before `test()`.')
        if action_repetition < 1:
            raise ValueError('action_repetition must be >= 1, is {}'.format(action_repetition))

        self.training = False
        self.step = 0

        # Work on a copy so that the caller's list is not modified.
        callbacks = [] if not callbacks else callbacks[:]

        if verbose >= 1:
            callbacks += [TestLogger()]
        if visualize:
            callbacks += [Visualizer()]
        history = History()
        callbacks += [history]
        callbacks = self._build_callback_list(callbacks, env, {'nb_episodes': nb_episodes})

        self._on_test_begin()
        callbacks.on_train_begin()
        for episode in range(nb_episodes):
            callbacks.on_episode_begin(episode)
            episode_reward = 0.
            episode_step = 0

            # Obtain the initial observation by resetting the environment.
            self.reset_states()
            observation = deepcopy(env.reset())
            if self.processor is not None:
                observation = self.processor.process_observation(observation)
            assert observation is not None

            # Perform random starts at beginning of episode and do not record them into the experience.
            # This slightly changes the start position between games.
            observation = self._perform_random_starts(
                env, callbacks, observation, nb_max_start_steps, start_step_policy)

            # Run the episode until we're done.
            done = False
            while not done:
                callbacks.on_step_begin(episode_step)

                action = self.forward(observation)
                if self.processor is not None:
                    action = self.processor.process_action(action)
                reward = 0.
                accumulated_info = {}
                for _ in range(action_repetition):
                    callbacks.on_action_begin(action)
                    observation, r, d, info = env.step(action)
                    observation = deepcopy(observation)
                    if self.processor is not None:
                        observation, r, d, info = self.processor.process_step(observation, r, d, info)
                    callbacks.on_action_end(action)
                    reward += r
                    self._accumulate_info(accumulated_info, info)
                    if d:
                        done = True
                        break
                if nb_max_episode_steps and episode_step >= nb_max_episode_steps - 1:
                    # Force a terminal state.
                    done = True
                self.backward(reward, terminal=done)
                episode_reward += reward

                step_logs = {
                    'action': action,
                    'observation': observation,
                    'reward': reward,
                    'episode': episode,
                    'info': accumulated_info,
                }
                callbacks.on_step_end(episode_step, step_logs)
                episode_step += 1
                self.step += 1

            # We are in a terminal state but the agent hasn't yet seen it. We therefore
            # perform one more forward-backward call and simply ignore the action before
            # resetting the environment. We need to pass in `terminal=False` here since
            # the *next* state, that is the state of the newly reset environment, is
            # always non-terminal by convention.
            self.forward(observation)
            self.backward(0., terminal=False)

            # Report end of episode.
            episode_logs = {
                'episode_reward': episode_reward,
                'nb_steps': episode_step,
            }
            callbacks.on_episode_end(episode, episode_logs)
        callbacks.on_train_end()
        self._on_test_end()

        return history

    def reset_states(self):
        """Resets all internally kept states after an episode is completed.
        """
        pass

    def forward(self, observation):
        """Takes an observation from the environment and returns the action to be taken next.
        If the policy is implemented by a neural network, this corresponds to a forward (inference) pass.

        # Argument
            observation (object): The current observation from the environment.

        # Returns
            The next action to be executed in the environment.
        """
        raise NotImplementedError()

    def backward(self, reward, terminal):
        """Updates the agent after having executed the action returned by `forward`.
        If the policy is implemented by a neural network, this corresponds to a weight update using back-prop.

        # Argument
            reward (float): The observed reward after executing the action returned by `forward`.
            terminal (boolean): `True` if the new state of the environment is terminal.

        # Returns
            The metrics of this update. `fit` reports the first three entries as
            'loss', 'mae' and 'mse'.
        """
        raise NotImplementedError()

    def compile(self, optimizer, metrics=[]):
        """Compiles an agent and the underlying models to be used for training and testing.

        # Arguments
            optimizer (`keras.optimizers.Optimizer` instance): The optimizer to be used during training.
            metrics (list of functions `lambda y_true, y_pred: metric`): The metrics to run during training.
        """
        raise NotImplementedError()

    def load_weights(self, filepath):
        """Loads the weights of an agent from an HDF5 file.

        # Arguments
            filepath (str): The path to the HDF5 file.
        """
        raise NotImplementedError()

    def save_weights(self, filepath, overwrite=False):
        """Saves the weights of an agent as an HDF5 file.

        # Arguments
            filepath (str): The path to where the weights should be saved.
            overwrite (boolean): If `False` and `filepath` already exists, raises an error.
        """
        raise NotImplementedError()

    @property
    def layers(self):
        """Returns all layers of the underlying model(s).

        If the concrete implementation uses multiple internal models,
        this method returns them in a concatenated list.
        """
        raise NotImplementedError()

    @property
    def metrics_names(self):
        """The human-readable names of the agent's metrics. Must return as many names as there
        are metrics (see also `compile`).
        """
        return []

    def _on_train_begin(self):
        """Callback that is called before training begins.
        """
        pass

    def _on_train_end(self):
        """Callback that is called after training ends.
        """
        pass

    def _on_test_begin(self):
        """Callback that is called before testing begins.
        """
        pass

    def _on_test_end(self):
        """Callback that is called after testing ends.
        """
        pass
463
464
class Processor(object):
    """Abstract base class for implementing processors.

    A processor sits between an `Agent` and its `Env` and translates observations, rewards,
    infos and actions from the form one side produces into the form the other side expects.
    Writing a custom processor therefore lets you adapt the two to each other without
    touching the implementation of either the agent or the environment.

    Every hook in this base class is the identity; override only the ones you need.
    """

    def process_step(self, observation, reward, done, info):
        """Processes a complete environment step.

        The observation, reward and info are passed (in that order) through
        `process_observation`, `process_reward` and `process_info`; `done` is
        returned untouched.

        # Arguments
            observation (object): An observation as obtained by the environment.
            reward (float): A reward as obtained by the environment.
            done (boolean): `True` if the environment is in a terminal state, `False` otherwise.
            info (dict): The debug info dictionary as obtained by the environment.

        # Returns
            The tuple (observation, reward, done, info) with all elements after being processed.
        """
        # Tuple elements are evaluated left to right, which keeps the hook call order.
        return (
            self.process_observation(observation),
            self.process_reward(reward),
            done,
            self.process_info(info),
        )

    def process_observation(self, observation):
        """Hook that converts an environment observation for the agent; identity by default.
        """
        return observation

    def process_reward(self, reward):
        """Hook that converts an environment reward for the agent; identity by default.
        """
        return reward

    def process_info(self, info):
        """Hook that converts the environment info for the agent; identity by default.
        """
        return info

    def process_action(self, action):
        """Hook that converts an action chosen by the agent before it is executed in the
        environment; identity by default.
        """
        return action

    def process_state_batch(self, batch):
        """Hook that converts an entire batch of states; identity by default.
        """
        return batch

    @property
    def metrics(self):
        """The metrics of the processor, which will be reported during training.

        # Returns
            List of `lambda y_true, y_pred: metric` functions (empty by default).
        """
        return []

    @property
    def metrics_names(self):
        """The human-readable names of the processor's metrics. Must return as many names as
        there are entries in `metrics` (empty by default).
        """
        return []
538
539
# Note: the API of the `Env` and `Space` classes is taken from the OpenAI Gym implementation.
541# https://github.com/openai/gym/blob/master/gym/core.py
542
543
class Env(object):
    """The abstract environment class that is used by all agents. This class has the exact
    same API that OpenAI Gym uses so that integrating with it is trivial. In contrast to the
    OpenAI Gym implementation, this class only defines the abstract methods without any actual
    implementation.
    """
    reward_range = (-np.inf, np.inf)
    action_space = None
    observation_space = None

    def step(self, action):
        """Run one timestep of the environment's dynamics.
        Accepts an action and returns a tuple (observation, reward, done, info).

        # Arguments
            action (object): The action to apply to the environment.

        # Returns
            observation (object): Agent's observation of the current environment.
            reward (float) : Amount of reward returned after previous action.
            done (boolean): Whether the episode has ended, in which case further step() calls will return undefined results.
            info (dict): Contains auxiliary diagnostic information (helpful for debugging, and sometimes learning).
        """
        raise NotImplementedError()

    def reset(self):
        """
        Resets the state of the environment and returns an initial observation.

        # Returns
            observation (object): The initial observation of the space. Initial reward is assumed to be 0.
        """
        raise NotImplementedError()

    def render(self, mode='human', close=False):
        """Renders the environment.
        The set of supported modes varies per environment. (And some
        environments do not support rendering at all.)

        # Arguments
            mode (str): The mode to render with.
            close (bool): Close all open renderings.
        """
        raise NotImplementedError()

    def close(self):
        """Override in your subclass to perform any necessary cleanup.
        `__del__` calls this method when the environment object is finalized.
        """
        raise NotImplementedError()

    def seed(self, seed=None):
        """Sets the seed for this env's random number generator(s).

        # Returns
            Returns the list of seeds used in this env's random number generators
        """
        raise NotImplementedError()

    def configure(self, *args, **kwargs):
        """Provides runtime configuration to the environment.
        This configuration should consist of data that tells your
        environment how to run (such as an address of a remote server,
        or path to your ImageNet data). It should not affect the
        semantics of the environment.
        """
        raise NotImplementedError()

    def __del__(self):
        # The default `close` is abstract and raises; an environment that has no
        # cleanup to do must not produce an error when it is finalized.
        try:
            self.close()
        except NotImplementedError:
            pass

    def __str__(self):
        return '<{} instance>'.format(type(self).__name__)
618
619
class Space(object):
    """Abstract description of a state or action space. The interface mirrors the one used by
    OpenAI Gym so that Gym spaces can be used in its place.
    """

    def sample(self, seed=None):
        """Draw an element of this space uniformly at random.

        Abstract: concrete spaces must override this method.
        """
        raise NotImplementedError

    def contains(self, x):
        """Tell whether `x` is a valid member of this space, as a boolean.

        Abstract: concrete spaces must override this method.
        """
        raise NotImplementedError
633 raise NotImplementedError()