You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
54 lines
1.2 KiB
54 lines
1.2 KiB
import abc
|
|
import numpy as np
|
|
|
|
|
|
class Policy(object):
    """Base class for all policies, has an abstract method predict().

    Holds the attributes shared by every policy (trainable flag, phase,
    model handles, device, last state, time step, environment) and declares
    ``configure`` and ``predict`` for subclasses to override.

    NOTE(review): ``configure`` and ``predict`` are decorated with
    ``abc.abstractmethod``, but this class derives from ``object`` rather than
    ``abc.ABC`` / ``abc.ABCMeta``, so the decorators are NOT enforced:
    ``Policy`` itself can be instantiated, and a subclass that does not
    override them inherits the stubs below, which return ``None``. Switching
    to ``abc.ABC`` would enforce the contract but could break existing
    callers -- confirm before changing.
    """

    def __init__(self):
        """Initialise every shared attribute to its default (False / None)."""
        # Whether this policy has learnable parameters; False by default.
        self.trainable = False
        # Current phase; only assigned through set_phase().
        self.phase = None
        # Model handles returned by get_model() / get_modeldl(); not set here,
        # presumably assigned by subclasses -- confirm.
        self.model = None
        self.modeldl = None
        # Device handle; only assigned through set_device().
        self.device = None
        self.last_state = None
        self.time_step = None
        # if agent is assumed to know the dynamics of real world
        self.env = None

    @abc.abstractmethod
    def configure(self, config):
        """Configure the policy from ``config``; subclasses override this.

        The base implementation does nothing and returns ``None``. See the
        class docstring: the abstractmethod decorator is not enforced.
        """
        return

    def set_phase(self, phase):
        """Store ``phase`` on the policy."""
        self.phase = phase

    def set_device(self, device):
        """Store ``device`` on the policy."""
        self.device = device

    def set_env(self, env):
        """Store ``env`` so the policy can query the environment directly."""
        self.env = env

    def get_model(self):
        """Return the object stored in ``self.model`` (``None`` by default)."""
        return self.model

    def get_modeldl(self):
        """Return the object stored in ``self.modeldl`` (``None`` by default)."""
        return self.modeldl

    @abc.abstractmethod
    def predict(self, state):
        """Policy takes state as input and output an action.

        The base implementation returns ``None``; subclasses override this.
        See the class docstring: the abstractmethod decorator is not enforced.
        """
        return

    @staticmethod
    def reach_destination(state):
        """Return whether the agent in ``state`` is within its radius of the goal.

        Args:
            state: object with a ``self_state`` attribute exposing ``px``,
                ``py``, ``gx``, ``gy`` and ``radius``.

        Returns:
            ``True`` if the Euclidean distance between ``(px, py)`` and
            ``(gx, gy)`` is strictly less than ``radius``, else ``False``.
        """
        self_state = state.self_state
        distance = np.linalg.norm(
            (self_state.py - self_state.gy, self_state.px - self_state.gx)
        )
        # bool() keeps returning a plain Python bool rather than numpy.bool_.
        return bool(distance < self_state.radius)