1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

Create a dataset using an agent (#306)

Generate a dataset using an agent (allowing the user to choose between this and a random dataset)
This commit is contained in:
Gal Leibovich
2019-05-28 09:34:49 +03:00
committed by GitHub
parent 342b7184bc
commit 9e9c4fd332
26 changed files with 351 additions and 111 deletions

View File

@@ -16,7 +16,7 @@
import os
import pickle
from typing import Union
from typing import Union, List
import numpy as np
@@ -40,6 +40,7 @@ class NECNetworkParameters(NetworkParameters):
self.middleware_parameters = FCMiddlewareParameters()
self.heads_parameters = [DNDQHeadParameters()]
self.optimizer_type = 'Adam'
self.should_get_softmax_probabilities = False
class NECAlgorithmParameters(AlgorithmParameters):
@@ -166,11 +167,25 @@ class NECAgent(ValueOptimizationAgent):
return super().act()
def get_all_q_values_for_states(self, states: StateType):
def get_all_q_values_for_states(self, states: StateType, additional_outputs: List = None):
# we need to store the state embeddings regardless of whether the action is random or not
return self.get_prediction(states)
return self.get_prediction_and_update_embeddings(states)
def get_prediction(self, states):
def get_all_q_values_for_states_and_softmax_probabilities(self, states: StateType):
# get the actions q values and the state embedding
embedding, actions_q_values, softmax_probabilities = self.networks['main'].online_network.predict(
self.prepare_batch_for_inference(states, 'main'),
outputs=[self.networks['main'].online_network.state_embedding,
self.networks['main'].online_network.output_heads[0].output,
self.networks['main'].online_network.output_heads[0].softmax]
)
if self.phase != RunPhase.TEST:
# store the state embedding for inserting it to the DND later
self.current_episode_state_embeddings.append(embedding.squeeze())
actions_q_values = actions_q_values[0][0]
return actions_q_values, softmax_probabilities
def get_prediction_and_update_embeddings(self, states):
# get the actions q values and the state embedding
embedding, actions_q_values = self.networks['main'].online_network.predict(
self.prepare_batch_for_inference(states, 'main'),