mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
Create a dataset using an agent (#306)
Generate a dataset using an agent (allows selecting between an agent-generated dataset and a random dataset)
This commit is contained in:
@@ -16,7 +16,7 @@
|
||||
|
||||
import os
|
||||
import pickle
|
||||
from typing import Union
|
||||
from typing import Union, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -40,6 +40,7 @@ class NECNetworkParameters(NetworkParameters):
|
||||
self.middleware_parameters = FCMiddlewareParameters()
|
||||
self.heads_parameters = [DNDQHeadParameters()]
|
||||
self.optimizer_type = 'Adam'
|
||||
self.should_get_softmax_probabilities = False
|
||||
|
||||
|
||||
class NECAlgorithmParameters(AlgorithmParameters):
|
||||
@@ -166,11 +167,25 @@ class NECAgent(ValueOptimizationAgent):
|
||||
|
||||
return super().act()
|
||||
|
||||
def get_all_q_values_for_states(self, states: StateType):
|
||||
def get_all_q_values_for_states(self, states: StateType, additional_outputs: List = None):
|
||||
# we need to store the state embeddings regardless if the action is random or not
|
||||
return self.get_prediction(states)
|
||||
return self.get_prediction_and_update_embeddings(states)
|
||||
|
||||
def get_prediction(self, states):
|
||||
def get_all_q_values_for_states_and_softmax_probabilities(self, states: StateType):
    """
    Run the online 'main' network on the given states, fetching the state embedding together with
    the output and softmax tensors of its first output head.

    When the agent is not in the TEST phase, the squeezed state embedding is appended to
    self.current_episode_state_embeddings as a side effect.

    :param states: the states to run inference on (passed through prepare_batch_for_inference)
    :return: a tuple of (the [0][0] element of the head output, the softmax output as returned by predict)
    """
    online_network = self.networks['main'].online_network
    inference_batch = self.prepare_batch_for_inference(states, 'main')
    q_head = online_network.output_heads[0]

    # ask for the embedding, the q values and the softmax probabilities in a single forward pass
    state_embedding, q_values, softmax_probabilities = online_network.predict(
        inference_batch,
        outputs=[online_network.state_embedding, q_head.output, q_head.softmax]
    )

    if self.phase != RunPhase.TEST:
        # keep the state embedding so that it can be inserted into the DND later
        self.current_episode_state_embeddings.append(state_embedding.squeeze())

    return q_values[0][0], softmax_probabilities
|
||||
|
||||
def get_prediction_and_update_embeddings(self, states):
|
||||
# get the actions q values and the state embedding
|
||||
embedding, actions_q_values = self.networks['main'].online_network.predict(
|
||||
self.prepare_batch_for_inference(states, 'main'),
|
||||
|
||||
Reference in New Issue
Block a user