mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
RL in Large Discrete Action Spaces - Wolpertinger Agent (#394)
* Currently this is specific to the case of discretizing a continuous action space. Can easily be adapted to other case by feeding the kNN otherwise, and removing the usage of a discretizing output action filter
This commit is contained in:
@@ -756,6 +756,9 @@
|
||||
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">!=</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="p">,</span> <span class="n">EpisodicExperienceReplay</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">override_episode_rewards_with_the_last_transition_reward</span><span class="p">:</span>
|
||||
<span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode_buffer</span><span class="o">.</span><span class="n">transitions</span><span class="p">:</span>
|
||||
<span class="n">t</span><span class="o">.</span><span class="n">reward</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode_buffer</span><span class="o">.</span><span class="n">transitions</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">reward</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">'store_episode'</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode_buffer</span><span class="p">)</span>
|
||||
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">store_transitions_only_when_episodes_are_terminated</span><span class="p">:</span>
|
||||
<span class="k">for</span> <span class="n">transition</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode_buffer</span><span class="o">.</span><span class="n">transitions</span><span class="p">:</span>
|
||||
@@ -910,7 +913,8 @@
|
||||
<span class="c1"># update counters</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">training_iteration</span> <span class="o">+=</span> <span class="mi">1</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="n">batch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">update_internal_state</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">deep_copy</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||||
<span class="n">update_internal_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">update_pre_network_filters_state_on_train</span>
|
||||
<span class="n">batch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">update_internal_state</span><span class="o">=</span><span class="n">update_internal_state</span><span class="p">,</span> <span class="n">deep_copy</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># if the batch returned empty then there are not enough samples in the replay buffer -> skip</span>
|
||||
<span class="c1"># training step</span>
|
||||
@@ -1020,7 +1024,8 @@
|
||||
<span class="c1"># informed action</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="c1"># before choosing an action, first use the pre_network_filter to filter out the current state</span>
|
||||
<span class="n">update_filter_internal_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span>
|
||||
<span class="n">update_filter_internal_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">update_pre_network_filters_state_on_inference</span> <span class="ow">and</span> \
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span>
|
||||
<span class="n">curr_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">run_pre_network_filter_for_inference</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">curr_state</span><span class="p">,</span> <span class="n">update_filter_internal_state</span><span class="p">)</span>
|
||||
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
@@ -1048,6 +1053,10 @@
|
||||
<span class="sd"> :return: The filtered state</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="n">dummy_env_response</span> <span class="o">=</span> <span class="n">EnvResponse</span><span class="p">(</span><span class="n">next_state</span><span class="o">=</span><span class="n">state</span><span class="p">,</span> <span class="n">reward</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">game_over</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># TODO actually we only want to run the observation filters. No point in running the reward filters as the</span>
|
||||
<span class="c1"># filtered reward is being ignored anyway (and it might unncecessarily affect the reward filters' internal</span>
|
||||
<span class="c1"># state).</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">dummy_env_response</span><span class="p">,</span>
|
||||
<span class="n">update_internal_state</span><span class="o">=</span><span class="n">update_filter_internal_state</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">next_state</span></div>
|
||||
|
||||
@@ -1177,7 +1186,7 @@
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Allows setting a directive for the agent to follow. This is useful in hierarchy structures, where the agent</span>
|
||||
<span class="sd"> has another master agent that is controlling it. In such cases, the master agent can define the goals for the</span>
|
||||
<span class="sd"> slave agent, define it's observation, possible actions, etc. The directive type is defined by the agent</span>
|
||||
<span class="sd"> slave agent, define its observation, possible actions, etc. The directive type is defined by the agent</span>
|
||||
<span class="sd"> in-action-space.</span>
|
||||
|
||||
<span class="sd"> :param action: The action that should be set as the directive</span>
|
||||
|
||||
@@ -295,7 +295,9 @@
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">optimization_epochs</span> <span class="o">=</span> <span class="mi">10</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">normalization_stats</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">clipping_decay_schedule</span> <span class="o">=</span> <span class="n">ConstantSchedule</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">act_for_full_episodes</span> <span class="o">=</span> <span class="kc">True</span></div>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">act_for_full_episodes</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">update_pre_network_filters_state_on_train</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">update_pre_network_filters_state_on_inference</span> <span class="o">=</span> <span class="kc">False</span></div>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">ClippedPPOAgentParameters</span><span class="p">(</span><span class="n">AgentParameters</span><span class="p">):</span>
|
||||
@@ -486,7 +488,9 @@
|
||||
<span class="n">network</span><span class="o">.</span><span class="n">set_is_training</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>
|
||||
|
||||
<span class="n">dataset</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">transitions</span>
|
||||
<span class="n">dataset</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">deep_copy</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||||
<span class="n">update_internal_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">update_pre_network_filters_state_on_train</span>
|
||||
<span class="n">dataset</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">deep_copy</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||||
<span class="n">update_internal_state</span><span class="o">=</span><span class="n">update_internal_state</span><span class="p">)</span>
|
||||
<span class="n">batch</span> <span class="o">=</span> <span class="n">Batch</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">training_step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">num_consecutive_training_steps</span><span class="p">):</span>
|
||||
@@ -512,7 +516,9 @@
|
||||
|
||||
<span class="k">def</span> <span class="nf">run_pre_network_filter_for_inference</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">StateType</span><span class="p">,</span> <span class="n">update_internal_state</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
|
||||
<span class="n">dummy_env_response</span> <span class="o">=</span> <span class="n">EnvResponse</span><span class="p">(</span><span class="n">next_state</span><span class="o">=</span><span class="n">state</span><span class="p">,</span> <span class="n">reward</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">game_over</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">dummy_env_response</span><span class="p">,</span> <span class="n">update_internal_state</span><span class="o">=</span><span class="kc">False</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">next_state</span>
|
||||
<span class="n">update_internal_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">update_pre_network_filters_state_on_inference</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span>
|
||||
<span class="n">dummy_env_response</span><span class="p">,</span> <span class="n">update_internal_state</span><span class="o">=</span><span class="n">update_internal_state</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">next_state</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">choose_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">curr_state</span><span class="p">):</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">clipping_decay_schedule</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
|
||||
|
||||
356
docs/_modules/rl_coach/agents/wolpertinger_agent.html
Normal file
356
docs/_modules/rl_coach/agents/wolpertinger_agent.html
Normal file
@@ -0,0 +1,356 @@
|
||||
|
||||
|
||||
<!DOCTYPE html>
|
||||
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
|
||||
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
|
||||
<title>rl_coach.agents.wolpertinger_agent — Reinforcement Learning Coach 0.12.0 documentation</title>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<script type="text/javascript" src="../../../_static/js/modernizr.min.js"></script>
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/jquery.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/underscore.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/doctools.js"></script>
|
||||
<script type="text/javascript" src="../../../_static/language_data.js"></script>
|
||||
<script async="async" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
|
||||
|
||||
<script type="text/javascript" src="../../../_static/js/theme.js"></script>
|
||||
|
||||
|
||||
|
||||
|
||||
<link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../../_static/css/custom.css" type="text/css" />
|
||||
<link rel="index" title="Index" href="../../../genindex.html" />
|
||||
<link rel="search" title="Search" href="../../../search.html" />
|
||||
<link href="../../../_static/css/custom.css" rel="stylesheet" type="text/css">
|
||||
|
||||
</head>
|
||||
|
||||
<body class="wy-body-for-nav">
|
||||
|
||||
|
||||
<div class="wy-grid-for-nav">
|
||||
|
||||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||||
<div class="wy-side-scroll">
|
||||
<div class="wy-side-nav-search" >
|
||||
|
||||
|
||||
|
||||
<a href="../../../index.html" class="icon icon-home"> Reinforcement Learning Coach
|
||||
|
||||
|
||||
|
||||
|
||||
<img src="../../../_static/dark_logo.png" class="logo" alt="Logo"/>
|
||||
|
||||
</a>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="search">
|
||||
<form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
|
||||
<input type="text" name="q" placeholder="Search docs" />
|
||||
<input type="hidden" name="check_keywords" value="yes" />
|
||||
<input type="hidden" name="area" value="default" />
|
||||
</form>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Intro</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../usage.html">Usage</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../dist_usage.html">Usage - Distributed Coach</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../features/index.html">Features</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../selecting_an_algorithm.html">Selecting an Algorithm</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../dashboard.html">Coach Dashboard</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Design</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../design/control_flow.html">Control Flow</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../design/network.html">Network Design</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../design/horizontal_scaling.html">Distributed Coach - Horizontal Scale-Out</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Contributing</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_agent.html">Adding a New Agent</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../contributing/add_env.html">Adding a New Environment</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Components</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/agents/index.html">Agents</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/architectures/index.html">Architectures</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/data_stores/index.html">Data Stores</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/environments/index.html">Environments</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/exploration_policies/index.html">Exploration Policies</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/filters/index.html">Filters</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/memories/index.html">Memories</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/memory_backends/index.html">Memory Backends</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/orchestrators/index.html">Orchestrators</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/core_types.html">Core Types</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/spaces.html">Spaces</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../components/additional_parameters.html">Additional Parameters</a></li>
|
||||
</ul>
|
||||
|
||||
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
||||
|
||||
|
||||
<nav class="wy-nav-top" aria-label="top navigation">
|
||||
|
||||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||||
<a href="../../../index.html">Reinforcement Learning Coach</a>
|
||||
|
||||
</nav>
|
||||
|
||||
|
||||
<div class="wy-nav-content">
|
||||
|
||||
<div class="rst-content">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="navigation" aria-label="breadcrumbs navigation">
|
||||
|
||||
<ul class="wy-breadcrumbs">
|
||||
|
||||
<li><a href="../../../index.html">Docs</a> »</li>
|
||||
|
||||
<li><a href="../../index.html">Module code</a> »</li>
|
||||
|
||||
<li>rl_coach.agents.wolpertinger_agent</li>
|
||||
|
||||
|
||||
<li class="wy-breadcrumbs-aside">
|
||||
|
||||
</li>
|
||||
|
||||
</ul>
|
||||
|
||||
|
||||
<hr/>
|
||||
</div>
|
||||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||||
<div itemprop="articleBody">
|
||||
|
||||
<h1>Source code for rl_coach.agents.wolpertinger_agent</h1><div class="highlight"><pre>
|
||||
<span></span><span class="c1">#</span>
|
||||
<span class="c1"># Copyright (c) 2019 Intel Corporation </span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Licensed under the Apache License, Version 2.0 (the "License");</span>
|
||||
<span class="c1"># you may not use this file except in compliance with the License.</span>
|
||||
<span class="c1"># You may obtain a copy of the License at</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
|
||||
<span class="c1"># distributed under the License is distributed on an "AS IS" BASIS,</span>
|
||||
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
|
||||
<span class="c1"># See the License for the specific language governing permissions and</span>
|
||||
<span class="c1"># limitations under the License.</span>
|
||||
<span class="c1">#</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">copy</span>
|
||||
<span class="kn">from</span> <span class="nn">typing</span> <span class="k">import</span> <span class="n">Union</span>
|
||||
<span class="kn">from</span> <span class="nn">collections</span> <span class="k">import</span> <span class="n">OrderedDict</span>
|
||||
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.agents.ddpg_agent</span> <span class="k">import</span> <span class="n">DDPGAlgorithmParameters</span><span class="p">,</span> <span class="n">DDPGActorNetworkParameters</span><span class="p">,</span> \
|
||||
<span class="n">DDPGCriticNetworkParameters</span><span class="p">,</span> <span class="n">DDPGAgent</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.base_parameters</span> <span class="k">import</span> <span class="n">AgentParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.core_types</span> <span class="k">import</span> <span class="n">ActionInfo</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.exploration_policies.additive_noise</span> <span class="k">import</span> <span class="n">AdditiveNoiseParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.memories.episodic.episodic_experience_replay</span> <span class="k">import</span> <span class="n">EpisodicExperienceReplayParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.memories.non_episodic.differentiable_neural_dictionary</span> <span class="k">import</span> <span class="n">AnnoyDictionary</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.spaces</span> <span class="k">import</span> <span class="n">DiscreteActionSpace</span><span class="p">,</span> <span class="n">BoxActionSpace</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.architectures.head_parameters</span> <span class="k">import</span> <span class="n">WolpertingerActorHeadParameters</span>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">WolpertingerCriticNetworkParameters</span><span class="p">(</span><span class="n">DDPGCriticNetworkParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">use_batchnorm</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">use_batchnorm</span><span class="o">=</span><span class="n">use_batchnorm</span><span class="p">)</span>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">WolpertingerActorNetworkParameters</span><span class="p">(</span><span class="n">DDPGActorNetworkParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">use_batchnorm</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">heads_parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">WolpertingerActorHeadParameters</span><span class="p">(</span><span class="n">batchnorm</span><span class="o">=</span><span class="n">use_batchnorm</span><span class="p">)]</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="WolpertingerAlgorithmParameters"><a class="viewcode-back" href="../../../components/agents/policy_optimization/wolpertinger.html#rl_coach.agents.wolpertinger_agent.WolpertingerAlgorithmParameters">[docs]</a><span class="k">class</span> <span class="nc">WolpertingerAlgorithmParameters</span><span class="p">(</span><span class="n">DDPGAlgorithmParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">action_embedding_width</span> <span class="o">=</span> <span class="mi">1</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">k</span> <span class="o">=</span> <span class="mi">1</span></div>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">WolpertingerAgentParameters</span><span class="p">(</span><span class="n">AgentParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">use_batchnorm</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
|
||||
<span class="n">exploration_params</span> <span class="o">=</span> <span class="n">AdditiveNoiseParameters</span><span class="p">()</span>
|
||||
<span class="n">exploration_params</span><span class="o">.</span><span class="n">noise_as_percentage_from_action_space</span> <span class="o">=</span> <span class="kc">False</span>
|
||||
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">algorithm</span><span class="o">=</span><span class="n">WolpertingerAlgorithmParameters</span><span class="p">(),</span>
|
||||
<span class="n">exploration</span><span class="o">=</span><span class="n">exploration_params</span><span class="p">,</span>
|
||||
<span class="n">memory</span><span class="o">=</span><span class="n">EpisodicExperienceReplayParameters</span><span class="p">(),</span>
|
||||
<span class="n">networks</span><span class="o">=</span><span class="n">OrderedDict</span><span class="p">(</span>
|
||||
<span class="p">[(</span><span class="s2">"actor"</span><span class="p">,</span> <span class="n">WolpertingerActorNetworkParameters</span><span class="p">(</span><span class="n">use_batchnorm</span><span class="o">=</span><span class="n">use_batchnorm</span><span class="p">)),</span>
|
||||
<span class="p">(</span><span class="s2">"critic"</span><span class="p">,</span> <span class="n">WolpertingerCriticNetworkParameters</span><span class="p">(</span><span class="n">use_batchnorm</span><span class="o">=</span><span class="n">use_batchnorm</span><span class="p">))]))</span>
|
||||
|
||||
<span class="nd">@property</span>
|
||||
<span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="s1">'rl_coach.agents.wolpertinger_agent:WolpertingerAgent'</span>
|
||||
|
||||
|
||||
<span class="c1"># Deep Reinforcement Learning in Large Discrete Action Spaces - https://arxiv.org/pdf/1512.07679.pdf</span>
|
||||
<span class="k">class</span> <span class="nc">WolpertingerAgent</span><span class="p">(</span><span class="n">DDPGAgent</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s1">'LevelManager'</span><span class="p">,</span> <span class="s1">'CompositeAgent'</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">agent_parameters</span><span class="p">,</span> <span class="n">parent</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">learn_from_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
|
||||
<span class="c1"># replay buffer holds the actions in the discrete manner, as the agent is expected to act with discrete actions</span>
|
||||
<span class="c1"># with the BoxDiscretization output filter. But DDPG needs to work on continuous actions, thus converting to</span>
|
||||
<span class="c1"># continuous actions. This is actually a duplicate since this filtering is also done before applying actions on</span>
|
||||
<span class="c1"># the environment. So might want to somehow reuse that conversion. Maybe can hold this information in the info</span>
|
||||
<span class="c1"># dictionary of the transition.</span>
|
||||
|
||||
<span class="n">output_action_filter</span> <span class="o">=</span> \
|
||||
<span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">output_filter</span><span class="o">.</span><span class="n">action_filters</span><span class="o">.</span><span class="n">values</span><span class="p">())[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="n">continuous_actions</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="k">for</span> <span class="n">action</span> <span class="ow">in</span> <span class="n">batch</span><span class="o">.</span><span class="n">actions</span><span class="p">():</span>
|
||||
<span class="n">continuous_actions</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">output_action_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">action</span><span class="p">))</span>
|
||||
<span class="n">batch</span><span class="o">.</span><span class="n">_actions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">continuous_actions</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
|
||||
|
||||
<span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">learn_from_batch</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">train</span><span class="p">()</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">choose_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">curr_state</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">action</span><span class="p">,</span> <span class="n">DiscreteActionSpace</span><span class="p">):</span>
|
||||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"WolpertingerAgent works only for discrete control problems"</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># convert to batch so we can run it through the network</span>
|
||||
<span class="n">tf_input_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_batch_for_inference</span><span class="p">(</span><span class="n">curr_state</span><span class="p">,</span> <span class="s1">'actor'</span><span class="p">)</span>
|
||||
<span class="n">actor_network</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'actor'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span>
|
||||
<span class="n">critic_network</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">networks</span><span class="p">[</span><span class="s1">'critic'</span><span class="p">]</span><span class="o">.</span><span class="n">online_network</span>
|
||||
<span class="n">proto_action</span> <span class="o">=</span> <span class="n">actor_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">tf_input_state</span><span class="p">)</span>
|
||||
<span class="n">proto_action</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">exploration_policy</span><span class="o">.</span><span class="n">get_action</span><span class="p">(</span><span class="n">proto_action</span><span class="p">),</span> <span class="mi">0</span><span class="p">)</span>
|
||||
|
||||
<span class="n">nn_action_embeddings</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">knn_tree</span><span class="o">.</span><span class="n">query</span><span class="p">(</span><span class="n">keys</span><span class="o">=</span><span class="n">proto_action</span><span class="p">,</span> <span class="n">k</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">k</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># now move the actions through the critic and choose the one with the highest q value</span>
|
||||
<span class="n">critic_inputs</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">tf_input_state</span><span class="p">)</span>
|
||||
<span class="n">critic_inputs</span><span class="p">[</span><span class="s1">'observation'</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">critic_inputs</span><span class="p">[</span><span class="s1">'observation'</span><span class="p">],</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">k</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
|
||||
<span class="n">critic_inputs</span><span class="p">[</span><span class="s1">'action'</span><span class="p">]</span> <span class="o">=</span> <span class="n">nn_action_embeddings</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="n">q_values</span> <span class="o">=</span> <span class="n">critic_network</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">critic_inputs</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="n">action</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">indices</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">q_values</span><span class="p">)])</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">action_signal</span><span class="o">.</span><span class="n">add_sample</span><span class="p">(</span><span class="n">action</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">ActionInfo</span><span class="p">(</span><span class="n">action</span><span class="o">=</span><span class="n">action</span><span class="p">,</span> <span class="n">action_value</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">init_environment_dependent_modules</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">init_environment_dependent_modules</span><span class="p">()</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">knn_tree</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_initialized_knn</span><span class="p">()</span>
|
||||
|
||||
<span class="c1"># TODO - ideally the knn should not be defined here, but somehow be defined by the user in the preset</span>
|
||||
<span class="k">def</span> <span class="nf">get_initialized_knn</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="n">num_actions</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">action</span><span class="o">.</span><span class="n">actions</span><span class="p">)</span>
|
||||
<span class="n">action_max_abs_range</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">action</span><span class="o">.</span><span class="n">filtered_action_space</span><span class="o">.</span><span class="n">max_abs_range</span> <span class="k">if</span> \
|
||||
<span class="p">(</span><span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">action</span><span class="p">,</span> <span class="s1">'filtered_action_space'</span><span class="p">)</span> <span class="ow">and</span>
|
||||
<span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">spaces</span><span class="o">.</span><span class="n">action</span><span class="o">.</span><span class="n">filtered_action_space</span><span class="p">,</span> <span class="n">BoxActionSpace</span><span class="p">))</span> \
|
||||
<span class="k">else</span> <span class="mf">1.0</span>
|
||||
<span class="n">keys</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">((</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">num_actions</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">num_actions</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">-</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">action_max_abs_range</span>
|
||||
<span class="n">values</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">num_actions</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span>
|
||||
<span class="n">knn_tree</span> <span class="o">=</span> <span class="n">AnnoyDictionary</span><span class="p">(</span><span class="n">dict_size</span><span class="o">=</span><span class="n">num_actions</span><span class="p">,</span> <span class="n">key_width</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">action_embedding_width</span><span class="p">)</span>
|
||||
<span class="n">knn_tree</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">keys</span><span class="p">,</span> <span class="n">values</span><span class="p">,</span> <span class="n">force_rebuild_tree</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">knn_tree</span>
|
||||
|
||||
</pre></div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<footer>
|
||||
|
||||
|
||||
<hr/>
|
||||
|
||||
<div role="contentinfo">
|
||||
<p>
|
||||
© Copyright 2018-2019, Intel AI Lab
|
||||
|
||||
</p>
|
||||
</div>
|
||||
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||||
|
||||
</footer>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
<script type="text/javascript">
|
||||
jQuery(function () {
|
||||
SphinxRtdTheme.Navigation.enable(true);
|
||||
});
|
||||
</script>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
</body>
|
||||
</html>
|
||||
@@ -396,6 +396,14 @@
|
||||
<span class="c1"># Support for parameter noise</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">supports_parameter_noise</span> <span class="o">=</span> <span class="kc">False</span>
|
||||
|
||||
<span class="c1"># Override, in retrospective, all the episode rewards with the last reward in the episode</span>
|
||||
<span class="c1"># (sometimes useful for sparse, end of the episode, rewards problems)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">override_episode_rewards_with_the_last_transition_reward</span> <span class="o">=</span> <span class="kc">False</span>
|
||||
|
||||
<span class="c1"># Filters - TODO consider creating a FilterParameters class and initialize the filters with it</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">update_pre_network_filters_state_on_train</span> <span class="o">=</span> <span class="kc">False</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">update_pre_network_filters_state_on_inference</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="PresetValidationParameters"><a class="viewcode-back" href="../../components/additional_parameters.html#rl_coach.base_parameters.PresetValidationParameters">[docs]</a><span class="k">class</span> <span class="nc">PresetValidationParameters</span><span class="p">(</span><span class="n">Parameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
||||
|
||||
@@ -298,6 +298,12 @@
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_steps</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">num_steps</span><span class="p">)</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">__truediv__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">other</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">other</span><span class="p">,</span> <span class="n">EnvironmentSteps</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_steps</span> <span class="o">/</span> <span class="n">other</span><span class="o">.</span><span class="n">num_steps</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__truediv__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">other</span><span class="p">)</span>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">Time</span><span class="p">(</span><span class="n">StepMethod</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_steps</span><span class="p">):</span>
|
||||
|
||||
@@ -200,15 +200,17 @@
|
||||
|
||||
<span class="kn">import</span> <span class="nn">uuid</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.data_stores.data_store</span> <span class="k">import</span> <span class="n">DataStore</span><span class="p">,</span> <span class="n">DataStoreParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.data_stores.data_store</span> <span class="k">import</span> <span class="n">DataStoreParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.data_stores.checkpoint_data_store</span> <span class="k">import</span> <span class="n">CheckpointDataStore</span>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">NFSDataStoreParameters</span><span class="p">(</span><span class="n">DataStoreParameters</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ds_params</span><span class="p">,</span> <span class="n">deployed</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">server</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">path</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ds_params</span><span class="p">,</span> <span class="n">deployed</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">server</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">path</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">checkpoint_dir</span><span class="p">:</span> <span class="nb">str</span><span class="o">=</span><span class="s2">""</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">ds_params</span><span class="o">.</span><span class="n">store_type</span><span class="p">,</span> <span class="n">ds_params</span><span class="o">.</span><span class="n">orchestrator_type</span><span class="p">,</span> <span class="n">ds_params</span><span class="o">.</span><span class="n">orchestrator_params</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">namespace</span> <span class="o">=</span> <span class="s2">"default"</span>
|
||||
<span class="k">if</span> <span class="s2">"namespace"</span> <span class="ow">in</span> <span class="n">ds_params</span><span class="o">.</span><span class="n">orchestrator_params</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">namespace</span> <span class="o">=</span> <span class="n">ds_params</span><span class="o">.</span><span class="n">orchestrator_params</span><span class="p">[</span><span class="s2">"namespace"</span><span class="p">]</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_dir</span> <span class="o">=</span> <span class="n">checkpoint_dir</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">pvc_name</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">pv_name</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
@@ -221,7 +223,7 @@
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">path</span> <span class="o">=</span> <span class="n">path</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="NFSDataStore"><a class="viewcode-back" href="../../../components/data_stores/index.html#rl_coach.data_stores.nfs_data_store.NFSDataStore">[docs]</a><span class="k">class</span> <span class="nc">NFSDataStore</span><span class="p">(</span><span class="n">DataStore</span><span class="p">):</span>
|
||||
<div class="viewcode-block" id="NFSDataStore"><a class="viewcode-back" href="../../../components/data_stores/index.html#rl_coach.data_stores.nfs_data_store.NFSDataStore">[docs]</a><span class="k">class</span> <span class="nc">NFSDataStore</span><span class="p">(</span><span class="n">CheckpointDataStore</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> An implementation of data store which uses NFS for storing policy checkpoints when using Coach in distributed mode.</span>
|
||||
<span class="sd"> The policy checkpoints are written by the trainer and read by the rollout worker.</span>
|
||||
|
||||
@@ -198,7 +198,8 @@
|
||||
<span class="c1">#</span>
|
||||
|
||||
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.data_stores.data_store</span> <span class="k">import</span> <span class="n">DataStore</span><span class="p">,</span> <span class="n">DataStoreParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.data_stores.data_store</span> <span class="k">import</span> <span class="n">DataStoreParameters</span>
|
||||
<span class="kn">from</span> <span class="nn">rl_coach.data_stores.checkpoint_data_store</span> <span class="k">import</span> <span class="n">CheckpointDataStore</span>
|
||||
<span class="kn">from</span> <span class="nn">minio</span> <span class="k">import</span> <span class="n">Minio</span>
|
||||
<span class="kn">from</span> <span class="nn">minio.error</span> <span class="k">import</span> <span class="n">ResponseError</span>
|
||||
<span class="kn">from</span> <span class="nn">configparser</span> <span class="k">import</span> <span class="n">ConfigParser</span><span class="p">,</span> <span class="n">Error</span>
|
||||
@@ -222,7 +223,7 @@
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">expt_dir</span> <span class="o">=</span> <span class="n">expt_dir</span>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="S3DataStore"><a class="viewcode-back" href="../../../components/data_stores/index.html#rl_coach.data_stores.s3_data_store.S3DataStore">[docs]</a><span class="k">class</span> <span class="nc">S3DataStore</span><span class="p">(</span><span class="n">DataStore</span><span class="p">):</span>
|
||||
<div class="viewcode-block" id="S3DataStore"><a class="viewcode-back" href="../../../components/data_stores/index.html#rl_coach.data_stores.s3_data_store.S3DataStore">[docs]</a><span class="k">class</span> <span class="nc">S3DataStore</span><span class="p">(</span><span class="n">CheckpointDataStore</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> An implementation of the data store using S3 for storing policy checkpoints when using Coach in distributed mode.</span>
|
||||
<span class="sd"> The policy checkpoints are written by the trainer and read by the rollout worker.</span>
|
||||
|
||||
@@ -245,7 +245,9 @@
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">evaluation_noise</span> <span class="o">=</span> <span class="n">evaluation_noise</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">noise_as_percentage_from_action_space</span> <span class="o">=</span> <span class="n">noise_as_percentage_from_action_space</span>
|
||||
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">action_space</span><span class="p">,</span> <span class="n">BoxActionSpace</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">action_space</span><span class="p">,</span> <span class="n">BoxActionSpace</span><span class="p">)</span> <span class="ow">and</span> \
|
||||
<span class="p">(</span><span class="nb">hasattr</span><span class="p">(</span><span class="n">action_space</span><span class="p">,</span> <span class="s1">'filtered_action_space'</span><span class="p">)</span> <span class="ow">and</span> <span class="ow">not</span>
|
||||
<span class="nb">isinstance</span><span class="p">(</span><span class="n">action_space</span><span class="o">.</span><span class="n">filtered_action_space</span><span class="p">,</span> <span class="n">BoxActionSpace</span><span class="p">)):</span>
|
||||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Additive noise exploration works only for continuous controls."</span>
|
||||
<span class="s2">"The given action space is of type: </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">action_space</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">))</span>
|
||||
|
||||
|
||||
@@ -298,7 +298,10 @@
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> :param action_space: the action space used by the environment</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">action_space</span><span class="p">,</span> <span class="n">BoxActionSpace</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">action_space</span><span class="p">,</span> <span class="n">GoalsSpace</span><span class="p">)</span>
|
||||
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">action_space</span><span class="p">,</span> <span class="n">BoxActionSpace</span><span class="p">)</span> <span class="ow">or</span> \
|
||||
<span class="p">(</span><span class="nb">hasattr</span><span class="p">(</span><span class="n">action_space</span><span class="p">,</span> <span class="s1">'filtered_action_space'</span><span class="p">)</span> <span class="ow">and</span>
|
||||
<span class="nb">isinstance</span><span class="p">(</span><span class="n">action_space</span><span class="o">.</span><span class="n">filtered_action_space</span><span class="p">,</span> <span class="n">BoxActionSpace</span><span class="p">))</span> <span class="ow">or</span> \
|
||||
<span class="nb">isinstance</span><span class="p">(</span><span class="n">action_space</span><span class="p">,</span> <span class="n">GoalsSpace</span><span class="p">)</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">action_space</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
|
||||
|
||||
@@ -271,9 +271,6 @@
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">action_values_std</span> <span class="o">=</span> <span class="n">current_noise</span>
|
||||
|
||||
<span class="c1"># scale the noise to the action space range</span>
|
||||
<span class="n">action_values_std</span> <span class="o">=</span> <span class="n">current_noise</span> <span class="o">*</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">high</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">low</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># extract the mean values</span>
|
||||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">action_values</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
|
||||
<span class="c1"># the action values are expected to be a list with the action mean and optionally the action stdev</span>
|
||||
|
||||
@@ -231,7 +231,8 @@
|
||||
|
||||
<span class="k">def</span> <span class="nf">get_unfiltered_action_space</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">output_action_space</span><span class="p">:</span> <span class="n">ActionSpace</span><span class="p">)</span> <span class="o">-></span> <span class="n">DiscreteActionSpace</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">output_action_space</span> <span class="o">=</span> <span class="n">output_action_space</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">input_action_space</span> <span class="o">=</span> <span class="n">DiscreteActionSpace</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">target_actions</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">descriptions</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">input_action_space</span> <span class="o">=</span> <span class="n">DiscreteActionSpace</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">target_actions</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">descriptions</span><span class="p">,</span>
|
||||
<span class="n">filtered_action_space</span><span class="o">=</span><span class="n">output_action_space</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_action_space</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">filter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">action</span><span class="p">:</span> <span class="n">ActionType</span><span class="p">)</span> <span class="o">-></span> <span class="n">ActionType</span><span class="p">:</span>
|
||||
|
||||
@@ -261,11 +261,18 @@
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">if</span> <span class="s1">'namespace'</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">orchestrator_params</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">orchestrator_params</span><span class="p">[</span><span class="s1">'namespace'</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"default"</span>
|
||||
<span class="kn">from</span> <span class="nn">kubernetes</span> <span class="k">import</span> <span class="n">client</span>
|
||||
<span class="kn">from</span> <span class="nn">kubernetes</span> <span class="k">import</span> <span class="n">client</span><span class="p">,</span> <span class="n">config</span>
|
||||
|
||||
<span class="n">container</span> <span class="o">=</span> <span class="n">client</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
|
||||
<span class="n">name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">redis_server_name</span><span class="p">,</span>
|
||||
<span class="n">image</span><span class="o">=</span><span class="s1">'redis:4-alpine'</span><span class="p">,</span>
|
||||
<span class="n">resources</span><span class="o">=</span><span class="n">client</span><span class="o">.</span><span class="n">V1ResourceRequirements</span><span class="p">(</span>
|
||||
<span class="n">limits</span><span class="o">=</span><span class="p">{</span>
|
||||
<span class="s2">"cpu"</span><span class="p">:</span> <span class="s2">"8"</span><span class="p">,</span>
|
||||
<span class="s2">"memory"</span><span class="p">:</span> <span class="s2">"4Gi"</span>
|
||||
<span class="c1"># "nvidia.com/gpu": "0",</span>
|
||||
<span class="p">}</span>
|
||||
<span class="p">),</span>
|
||||
<span class="p">)</span>
|
||||
<span class="n">template</span> <span class="o">=</span> <span class="n">client</span><span class="o">.</span><span class="n">V1PodTemplateSpec</span><span class="p">(</span>
|
||||
<span class="n">metadata</span><span class="o">=</span><span class="n">client</span><span class="o">.</span><span class="n">V1ObjectMeta</span><span class="p">(</span><span class="n">labels</span><span class="o">=</span><span class="p">{</span><span class="s1">'app'</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">redis_server_name</span><span class="p">}),</span>
|
||||
@@ -288,8 +295,10 @@
|
||||
<span class="n">spec</span><span class="o">=</span><span class="n">deployment_spec</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
<span class="n">config</span><span class="o">.</span><span class="n">load_kube_config</span><span class="p">()</span>
|
||||
<span class="n">api_client</span> <span class="o">=</span> <span class="n">client</span><span class="o">.</span><span class="n">AppsV1Api</span><span class="p">()</span>
|
||||
<span class="k">try</span><span class="p">:</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">orchestrator_params</span><span class="p">)</span>
|
||||
<span class="n">api_client</span><span class="o">.</span><span class="n">create_namespaced_deployment</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">orchestrator_params</span><span class="p">[</span><span class="s1">'namespace'</span><span class="p">],</span> <span class="n">deployment</span><span class="p">)</span>
|
||||
<span class="k">except</span> <span class="n">client</span><span class="o">.</span><span class="n">rest</span><span class="o">.</span><span class="n">ApiException</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="s2">"Got exception: </span><span class="si">%s</span><span class="se">\n</span><span class="s2"> while creating redis-server"</span><span class="p">,</span> <span class="n">e</span><span class="p">)</span>
|
||||
|
||||
@@ -240,7 +240,7 @@
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">built_capacity</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">add</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">keys</span><span class="p">,</span> <span class="n">values</span><span class="p">,</span> <span class="n">additional_data</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">add</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">keys</span><span class="p">,</span> <span class="n">values</span><span class="p">,</span> <span class="n">additional_data</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">force_rebuild_tree</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="n">additional_data</span><span class="p">:</span>
|
||||
<span class="n">additional_data</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">keys</span><span class="p">)</span>
|
||||
|
||||
@@ -279,7 +279,7 @@
|
||||
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffered_indices</span><span class="p">)</span> <span class="o">>=</span> <span class="bp">self</span><span class="o">.</span><span class="n">min_update_size</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">min_update_size</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">initial_update_size</span><span class="p">,</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">curr_size</span> <span class="o">*</span> <span class="mf">0.02</span><span class="p">))</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_rebuild_index</span><span class="p">()</span>
|
||||
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">rebuild_on_every_update</span><span class="p">:</span>
|
||||
<span class="k">elif</span> <span class="n">force_rebuild_tree</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">rebuild_on_every_update</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_rebuild_index</span><span class="p">()</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">current_timestamp</span> <span class="o">+=</span> <span class="mi">1</span>
|
||||
|
||||
@@ -307,6 +307,11 @@
|
||||
<span class="sd"> """</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">memory_backend</span><span class="o">.</span><span class="n">deploy</span><span class="p">()</span>
|
||||
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">"redis"</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">data_store</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">redis_address</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory_backend</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">redis_address</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">data_store</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">redis_port</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory_backend</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">redis_port</span>
|
||||
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">data_store</span><span class="o">.</span><span class="n">deploy</span><span class="p">():</span>
|
||||
<span class="k">return</span> <span class="kc">False</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">"nfs"</span><span class="p">:</span>
|
||||
@@ -329,6 +334,8 @@
|
||||
<span class="n">trainer_params</span><span class="o">.</span><span class="n">command</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'--data_store_params'</span><span class="p">,</span> <span class="n">json</span><span class="o">.</span><span class="n">dumps</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)]</span>
|
||||
<span class="n">name</span> <span class="o">=</span> <span class="s2">"</span><span class="si">{}</span><span class="s2">-</span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">trainer_params</span><span class="o">.</span><span class="n">run_type</span><span class="p">,</span> <span class="n">uuid</span><span class="o">.</span><span class="n">uuid4</span><span class="p">())</span>
|
||||
|
||||
<span class="c1"># TODO: instead of defining each container and template spec from scratch, loaded default</span>
|
||||
<span class="c1"># configuration and modify them as necessary depending on the store type</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">"nfs"</span><span class="p">:</span>
|
||||
<span class="n">container</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
|
||||
<span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
|
||||
@@ -354,7 +361,7 @@
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'Never'</span>
|
||||
<span class="p">),</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">"s3"</span><span class="p">:</span>
|
||||
<span class="n">container</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
|
||||
<span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
|
||||
<span class="n">image</span><span class="o">=</span><span class="n">trainer_params</span><span class="o">.</span><span class="n">image</span><span class="p">,</span>
|
||||
@@ -373,6 +380,34 @@
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'Never'</span>
|
||||
<span class="p">),</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">"redis"</span><span class="p">:</span>
|
||||
<span class="n">container</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
|
||||
<span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
|
||||
<span class="n">image</span><span class="o">=</span><span class="n">trainer_params</span><span class="o">.</span><span class="n">image</span><span class="p">,</span>
|
||||
<span class="n">command</span><span class="o">=</span><span class="n">trainer_params</span><span class="o">.</span><span class="n">command</span><span class="p">,</span>
|
||||
<span class="n">args</span><span class="o">=</span><span class="n">trainer_params</span><span class="o">.</span><span class="n">arguments</span><span class="p">,</span>
|
||||
<span class="n">image_pull_policy</span><span class="o">=</span><span class="s1">'Always'</span><span class="p">,</span>
|
||||
<span class="n">stdin</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||||
<span class="n">tty</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||||
<span class="n">resources</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1ResourceRequirements</span><span class="p">(</span>
|
||||
<span class="n">limits</span><span class="o">=</span><span class="p">{</span>
|
||||
<span class="s2">"cpu"</span><span class="p">:</span> <span class="s2">"40"</span><span class="p">,</span>
|
||||
<span class="s2">"memory"</span><span class="p">:</span> <span class="s2">"4Gi"</span><span class="p">,</span>
|
||||
<span class="s2">"nvidia.com/gpu"</span><span class="p">:</span> <span class="s2">"1"</span><span class="p">,</span>
|
||||
<span class="p">}</span>
|
||||
<span class="p">),</span>
|
||||
<span class="p">)</span>
|
||||
<span class="n">template</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1PodTemplateSpec</span><span class="p">(</span>
|
||||
<span class="n">metadata</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1ObjectMeta</span><span class="p">(</span><span class="n">labels</span><span class="o">=</span><span class="p">{</span><span class="s1">'app'</span><span class="p">:</span> <span class="n">name</span><span class="p">}),</span>
|
||||
<span class="n">spec</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1PodSpec</span><span class="p">(</span>
|
||||
<span class="n">containers</span><span class="o">=</span><span class="p">[</span><span class="n">container</span><span class="p">],</span>
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'Never'</span>
|
||||
<span class="p">),</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"unexpected store_type </span><span class="si">{}</span><span class="s2">. expected 's3', 'nfs', 'redis'"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span>
|
||||
<span class="p">))</span>
|
||||
|
||||
<span class="n">job_spec</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1JobSpec</span><span class="p">(</span>
|
||||
<span class="n">completions</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||||
@@ -404,12 +439,17 @@
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="n">worker_params</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="kc">False</span>
|
||||
|
||||
<span class="c1"># At this point, the memory backend and data store have been deployed and in the process,</span>
|
||||
<span class="c1"># these parameters have been updated to include things like the hostname and port the</span>
|
||||
<span class="c1"># service can be found at.</span>
|
||||
<span class="n">worker_params</span><span class="o">.</span><span class="n">command</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'--memory_backend_params'</span><span class="p">,</span> <span class="n">json</span><span class="o">.</span><span class="n">dumps</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">memory_backend_parameters</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)]</span>
|
||||
<span class="n">worker_params</span><span class="o">.</span><span class="n">command</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'--data_store_params'</span><span class="p">,</span> <span class="n">json</span><span class="o">.</span><span class="n">dumps</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)]</span>
|
||||
<span class="n">worker_params</span><span class="o">.</span><span class="n">command</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'--num_workers'</span><span class="p">,</span> <span class="s1">'</span><span class="si">{}</span><span class="s1">'</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">worker_params</span><span class="o">.</span><span class="n">num_replicas</span><span class="p">)]</span>
|
||||
|
||||
<span class="n">name</span> <span class="o">=</span> <span class="s2">"</span><span class="si">{}</span><span class="s2">-</span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">worker_params</span><span class="o">.</span><span class="n">run_type</span><span class="p">,</span> <span class="n">uuid</span><span class="o">.</span><span class="n">uuid4</span><span class="p">())</span>
|
||||
|
||||
<span class="c1"># TODO: instead of defining each container and template spec from scratch, loaded default</span>
|
||||
<span class="c1"># configuration and modify them as necessary depending on the store type</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">"nfs"</span><span class="p">:</span>
|
||||
<span class="n">container</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
|
||||
<span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
|
||||
@@ -435,7 +475,7 @@
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'Never'</span>
|
||||
<span class="p">),</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">"s3"</span><span class="p">:</span>
|
||||
<span class="n">container</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
|
||||
<span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
|
||||
<span class="n">image</span><span class="o">=</span><span class="n">worker_params</span><span class="o">.</span><span class="n">image</span><span class="p">,</span>
|
||||
@@ -454,6 +494,32 @@
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'Never'</span>
|
||||
<span class="p">)</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span> <span class="o">==</span> <span class="s2">"redis"</span><span class="p">:</span>
|
||||
<span class="n">container</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1Container</span><span class="p">(</span>
|
||||
<span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
|
||||
<span class="n">image</span><span class="o">=</span><span class="n">worker_params</span><span class="o">.</span><span class="n">image</span><span class="p">,</span>
|
||||
<span class="n">command</span><span class="o">=</span><span class="n">worker_params</span><span class="o">.</span><span class="n">command</span><span class="p">,</span>
|
||||
<span class="n">args</span><span class="o">=</span><span class="n">worker_params</span><span class="o">.</span><span class="n">arguments</span><span class="p">,</span>
|
||||
<span class="n">image_pull_policy</span><span class="o">=</span><span class="s1">'Always'</span><span class="p">,</span>
|
||||
<span class="n">stdin</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||||
<span class="n">tty</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||||
<span class="n">resources</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1ResourceRequirements</span><span class="p">(</span>
|
||||
<span class="n">limits</span><span class="o">=</span><span class="p">{</span>
|
||||
<span class="s2">"cpu"</span><span class="p">:</span> <span class="s2">"8"</span><span class="p">,</span>
|
||||
<span class="s2">"memory"</span><span class="p">:</span> <span class="s2">"4Gi"</span><span class="p">,</span>
|
||||
<span class="c1"># "nvidia.com/gpu": "0",</span>
|
||||
<span class="p">}</span>
|
||||
<span class="p">),</span>
|
||||
<span class="p">)</span>
|
||||
<span class="n">template</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1PodTemplateSpec</span><span class="p">(</span>
|
||||
<span class="n">metadata</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1ObjectMeta</span><span class="p">(</span><span class="n">labels</span><span class="o">=</span><span class="p">{</span><span class="s1">'app'</span><span class="p">:</span> <span class="n">name</span><span class="p">}),</span>
|
||||
<span class="n">spec</span><span class="o">=</span><span class="n">k8sclient</span><span class="o">.</span><span class="n">V1PodSpec</span><span class="p">(</span>
|
||||
<span class="n">containers</span><span class="o">=</span><span class="p">[</span><span class="n">container</span><span class="p">],</span>
|
||||
<span class="n">restart_policy</span><span class="o">=</span><span class="s1">'Never'</span>
|
||||
<span class="p">)</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">'unexpected store type </span><span class="si">{}</span><span class="s1">'</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">data_store_params</span><span class="o">.</span><span class="n">store_type</span><span class="p">))</span>
|
||||
|
||||
<span class="n">job_spec</span> <span class="o">=</span> <span class="n">k8sclient</span><span class="o">.</span><span class="n">V1JobSpec</span><span class="p">(</span>
|
||||
<span class="n">completions</span><span class="o">=</span><span class="n">worker_params</span><span class="o">.</span><span class="n">num_replicas</span><span class="p">,</span>
|
||||
|
||||
@@ -568,7 +568,8 @@
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> A discrete action space with action indices as actions</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_actions</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">descriptions</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Dict</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">default_action</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_actions</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">descriptions</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Dict</span><span class="p">]</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">default_action</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">filtered_action_space</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">low</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">num_actions</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">descriptions</span><span class="o">=</span><span class="n">descriptions</span><span class="p">)</span>
|
||||
<span class="c1"># the number of actions is mapped to high</span>
|
||||
|
||||
@@ -578,6 +579,9 @@
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">default_action</span> <span class="o">=</span> <span class="n">default_action</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">filtered_action_space</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">filtered_action_space</span> <span class="o">=</span> <span class="n">filtered_action_space</span>
|
||||
|
||||
<span class="nd">@property</span>
|
||||
<span class="k">def</span> <span class="nf">actions</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">ActionType</span><span class="p">]:</span>
|
||||
<span class="k">return</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">high</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
|
||||
|
||||
Reference in New Issue
Block a user