# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
#
# 54 lines
# 1.2 KiB

import abc
import numpy as np
class Policy(object):
    """
    Base class for all policies.

    Subclasses are expected to override configure() and predict(). Both are
    marked with ``abc.abstractmethod``, but because this class does not use
    ``abc.ABCMeta`` the marker is not enforced: Policy itself can still be
    instantiated and the base implementations simply return None.
    """

    def __init__(self):
        """
        Initialise every attribute to its unset default.
        """
        self.trainable = False
        self.phase = None
        self.model = None
        self.modeldl = None
        self.device = None
        self.last_state = None
        self.time_step = None
        # if agent is assumed to know the dynamics of real world
        self.env = None

    @abc.abstractmethod
    def configure(self, config):
        """
        Configure the policy from `config`; the base implementation does
        nothing and returns None.
        """
        return

    def set_phase(self, phase):
        """
        Store `phase` on the policy.
        """
        self.phase = phase

    def set_device(self, device):
        """
        Store `device` on the policy.
        """
        self.device = device

    def set_env(self, env):
        """
        Store `env` on the policy.
        """
        self.env = env

    def get_model(self):
        """
        Return the object currently stored in `self.model` (None by default).
        """
        return self.model

    def get_modeldl(self):
        """
        Return the object currently stored in `self.modeldl` (None by default).
        """
        return self.modeldl

    @abc.abstractmethod
    def predict(self, state):
        """
        Policy takes state as input and output an action

        The base implementation does nothing and returns None.
        """
        return

    @staticmethod
    def reach_destination(state):
        """
        Tell whether the agent described by `state` is within its own radius
        of its target point.

        `state.self_state` must expose the numeric attributes px, py, gx, gy
        and radius (presumably current position, goal position and agent
        radius -- confirm against the state class).

        Returns True when the Euclidean distance between (px, py) and
        (gx, gy) is strictly smaller than radius, otherwise False.
        """
        agent = state.self_state
        offset = (agent.py - agent.gy, agent.px - agent.gx)
        # bool() keeps the return type a plain Python bool rather than
        # numpy.bool_, matching the literal True/False callers may test with `is`.
        return bool(np.linalg.norm(offset) < agent.radius)