mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
using the CoRL2017 experiment suite for CARLA_CIL
This commit is contained in:
@@ -29,6 +29,7 @@ class CILAlgorithmParameters(AlgorithmParameters):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.collect_new_data = False
|
||||
self.state_key_with_the_class_index = 'high_level_command'
|
||||
|
||||
|
||||
class CILNetworkParameters(NetworkParameters):
|
||||
@@ -63,7 +64,7 @@ class CILAgent(ImitationAgent):
|
||||
self.current_high_level_control = 0
|
||||
|
||||
def choose_action(self, curr_state):
|
||||
self.current_high_level_control = curr_state['high_level_command']
|
||||
self.current_high_level_control = curr_state[self.ap.algorithm.state_key_with_the_class_index]
|
||||
return super().choose_action(curr_state)
|
||||
|
||||
def extract_action_values(self, prediction):
|
||||
@@ -74,7 +75,7 @@ class CILAgent(ImitationAgent):
|
||||
|
||||
target_values = self.networks['main'].online_network.predict({**batch.states(network_keys)})
|
||||
|
||||
branch_to_update = batch.states(['high_level_command'])['high_level_command']
|
||||
branch_to_update = batch.states([self.ap.algorithm.state_key_with_the_class_index])[self.ap.algorithm.state_key_with_the_class_index]
|
||||
for idx, branch in enumerate(branch_to_update):
|
||||
target_values[branch][idx] = batch.actions()[idx]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user