import argparse
import gym
import numpy as np
import os
from itertools import count

import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributed.rpc import RRef, rpc_sync, rpc_async, remote, BackendType
from torch.distributions import Categorical
import platform

TOTAL_EPISODE_STEP = 5000
AGENT_NAME = "agent"
OBSERVER_NAME = "observer{}"
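# AGENT_NAME / OBSERVER_NAME are the RPC worker names registered in run_worker
# below; TOTAL_EPISODE_STEP is split evenly across the observers, so each
# observer runs TOTAL_EPISODE_STEP / (world_size - 1) steps per episode.
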
parser = argparse.ArgumentParser(description='PyTorch RPC RL example')
parser.add_argument('--world-size', type=int, default=2, metavar='W',
                    help='world size for RPC, rank 0 is the agent, others are observers')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor (default: 0.99)')
parser.add_argument('--seed', type=int, default=543, metavar='N',
                    help='random seed (default: 543)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='interval between training status logs (default: 10)')
args = parser.parse_args()

torch.manual_seed(args.seed)

def _call_method(method, rref, *args, **kwargs):
    r"""
    A helper function to call a method on the given RRef.
    """
    return method(rref.local_value(), *args, **kwargs)


def _remote_method(method, rref, *args, **kwargs):
    r"""
    A helper function to run a method on the owner of the given RRef and fetch
    back the result using RPC.
    """
    args = [method, rref] + list(args)
    return rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs)

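# _remote_method example: Observer.run_episode below calls
#     _remote_method(Agent.select_action, agent_rref, self.id, state)
# which runs Agent.select_action on the process that owns agent_rref (the
# agent) and returns the chosen action to the calling observer.
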

class Policy(nn.Module):
    r"""
    Borrowing the ``Policy`` class from the Reinforcement Learning example.
    Copying the code to make these two examples independent.
    See https://github.com/pytorch/examples/tree/master/reinforcement_learning
    """
    def __init__(self):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(4, 128)
        self.dropout = nn.Dropout(p=0.6)
        self.affine2 = nn.Linear(128, 2)

        self.saved_log_probs = []
        self.rewards = []

    def forward(self, x):
        x = self.affine1(x)
        x = self.dropout(x)
        x = F.relu(x)
        action_scores = self.affine2(x)
        return F.softmax(action_scores, dim=1)

class Observer:
    r"""
    An observer has exclusive access to its own environment. Each observer
    captures the state from its environment, and sends the state to the agent
    to select an action. Then, the observer applies the action to its
    environment and reports the reward to the agent.

    It is true that CartPole-v1 is a relatively inexpensive environment, and it
    might be overkill to use RPC to connect observers and trainers in this
    specific use case. However, the main goal of this tutorial is to show how
    to build an application using the RPC API. Developers can extend a similar
    idea to other applications with much more expensive environments.
    """
    def __init__(self):
        self.id = rpc.get_worker_info().id
        self.env = gym.make('CartPole-v1')
        self.env.seed(args.seed)

    def run_episode(self, agent_rref, n_steps):
        r"""
        Run one episode of n_steps.

        Arguments:
            agent_rref (RRef): an RRef referencing the agent object.
            n_steps (int): number of steps in this episode
        """
        state, ep_reward = self.env.reset(), 0
        for step in range(n_steps):
            # send the state to the agent to get an action
            action = _remote_method(Agent.select_action, agent_rref, self.id, state)

            # apply the action to the environment, and get the reward
            state, reward, done, _ = self.env.step(action)

            # report the reward to the agent for training purposes
            _remote_method(Agent.report_reward, agent_rref, self.id, reward)

            if done:
                break

class Agent:
    def __init__(self, world_size):
        self.ob_rrefs = []
        self.agent_rref = RRef(self)
        self.rewards = {}
        self.saved_log_probs = {}
        self.policy = Policy()
        self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)
        self.eps = np.finfo(np.float32).eps.item()
        self.running_reward = 0
        self.reward_threshold = gym.make('CartPole-v1').spec.reward_threshold
        for ob_rank in range(1, world_size):
            ob_info = rpc.get_worker_info(OBSERVER_NAME.format(ob_rank))
            self.ob_rrefs.append(remote(ob_info, Observer))
            self.rewards[ob_info.id] = []
            self.saved_log_probs[ob_info.id] = []

    def select_action(self, ob_id, state):
        r"""
        This function is mostly borrowed from the Reinforcement Learning example.
        See https://github.com/pytorch/examples/tree/master/reinforcement_learning
        The main difference is that instead of keeping all probs in one list,
        the agent keeps probs in a dictionary, one key per observer.

        NB: no need to enforce thread-safety here as the GIL will serialize
        executions.
        """
        state = torch.from_numpy(state).float().unsqueeze(0)
        probs = self.policy(state)
        m = Categorical(probs)
        action = m.sample()
        self.saved_log_probs[ob_id].append(m.log_prob(action))
        return action.item()

    def report_reward(self, ob_id, reward):
        r"""
        Observers call this function to report rewards.
        """
        self.rewards[ob_id].append(reward)

    def run_episode(self, n_steps=0):
        r"""
        Run one episode. The agent will tell each observer to run n_steps.
        """
        futs = []
        for ob_rref in self.ob_rrefs:
            # make async RPC to kick off an episode on all observers
            futs.append(
                rpc_async(
                    ob_rref.owner(),
                    _call_method,
                    args=(Observer.run_episode, ob_rref, self.agent_rref, n_steps)
                )
            )

        # wait until all observers have finished this episode
        for fut in futs:
            fut.wait()

    def finish_episode(self):
        r"""
        This function is mostly borrowed from the Reinforcement Learning example.
        See https://github.com/pytorch/examples/tree/master/reinforcement_learning
        The main difference is that it joins all probs and rewards from
        different observers into one list, and uses the minimum observer reward
        as the reward of the current episode.
        """

        # join probs and rewards from different observers into lists
        R, probs, rewards = 0, [], []
        for ob_id in self.rewards:
            probs.extend(self.saved_log_probs[ob_id])
            rewards.extend(self.rewards[ob_id])

        # use the minimum observer reward to calculate the running reward
        min_reward = min([sum(self.rewards[ob_id]) for ob_id in self.rewards])
        self.running_reward = 0.05 * min_reward + (1 - 0.05) * self.running_reward

        # clear saved probs and rewards
        for ob_id in self.rewards:
            self.rewards[ob_id] = []
            self.saved_log_probs[ob_id] = []

        policy_loss, returns = [], []
        for r in rewards[::-1]:
            R = r + args.gamma * R
            returns.insert(0, R)
        returns = torch.tensor(returns)
        returns = (returns - returns.mean()) / (returns.std() + self.eps)
        for log_prob, R in zip(probs, returns):
            policy_loss.append(-log_prob * R)
        self.optimizer.zero_grad()
        policy_loss = torch.cat(policy_loss).sum()
        policy_loss.backward()
        self.optimizer.step()
        return min_reward
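
# Overall flow: Agent.run_episode() issues one rpc_async per observer; each
# Observer.run_episode() then calls back into Agent.select_action() and
# Agent.report_reward() for every environment step, and Agent.finish_episode()
# turns the collected log probs and rewards into a single REINFORCE-style
# policy-gradient update.
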

def run_worker(rank, world_size):
    r"""
    This is the entry point for all processes. Rank 0 is the agent; all
    other ranks are observers.
    """
    # MYHOSTLIST is expected to hold a comma-separated list of host names; the
    # first entry (with anything after a '*' stripped) becomes the RPC master.
    host_list = os.getenv('MYHOSTLIST').split(",")
    master = host_list[0].split("*", 1)[0]
    os.environ['MASTER_ADDR'] = master
    os.environ['MASTER_PORT'] = '29500'
    print("Spawned worker on node: {}".format(platform.node()))
    print("Master: {}".format(master))
    if rank == 0:
        # rank 0 is the agent
        #rpc.init_rpc(AGENT_NAME, backend=BackendType.PROCESS_GROUP, rank=rank, world_size=world_size)
        rpc.init_rpc(AGENT_NAME, rank=rank, world_size=world_size)

        agent = Agent(world_size)
        for i_episode in count(1):
            n_steps = int(TOTAL_EPISODE_STEP / (args.world_size - 1))
            agent.run_episode(n_steps=n_steps)
            last_reward = agent.finish_episode()

            if i_episode % args.log_interval == 0:
                print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}'.format(
                    i_episode, last_reward, agent.running_reward))

            if agent.running_reward > agent.reward_threshold - 400:
                print("Solved! Running reward is now {}!".format(agent.running_reward))
                break
    else:
        # other ranks are the observers
        #rpc.init_rpc(OBSERVER_NAME.format(rank), backend=BackendType.PROCESS_GROUP, rank=rank, world_size=world_size)
        rpc.init_rpc(OBSERVER_NAME.format(rank), rank=rank, world_size=world_size)
        # observers passively wait for instructions from the agent

    # block until all RPCs finish, then shut down the RPC framework
    rpc.shutdown()


def main():
    # mp.spawn(
    #     run_worker,
    #     args=(args.world_size, ),
    #     nprocs=args.world_size,
    #     join=True
    # )
    run_worker(int(os.environ['SLURM_PROCID']), args.world_size)


if __name__ == '__main__':
    main()
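
# Launch sketch (an assumption about the surrounding job script, not part of
# the code above): each process reads its rank from SLURM_PROCID and the node
# list from MYHOSTLIST, so one plausible SLURM launch looks like
#
#   export MYHOSTLIST=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | paste -sd, -)
#   srun -n 2 python this_script.py --world-size 2
#
# where this_script.py stands for whatever this file is named and the exact
# MYHOSTLIST format (e.g. any '*' suffix on the first host) depends on how the
# site's job script builds it.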