Merge pull request '修改源码目录' (#12) from develop into master
commit
9b8d8cf369
@ -0,0 +1,7 @@
|
|||||||
|
include LICENSE.txt
|
||||||
|
include README.md
|
||||||
|
include docs/en/whl_en.md
|
||||||
|
recursive-include deploy/python predict_cls.py preprocess.py postprocess.py det_preprocess.py
|
||||||
|
recursive-include deploy/utils get_image_list.py config.py logger.py predictor.py
|
||||||
|
|
||||||
|
recursive-include ppcls/ *.py *.txt
|
@ -0,0 +1,17 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
__all__ = ['PaddleClas']
|
||||||
|
from .paddleclas import PaddleClas
|
||||||
|
from ppcls.arch.backbone import *
|
@ -0,0 +1,788 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
dependencies = ['paddle']
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
class _SysPathG(object):
|
||||||
|
"""
|
||||||
|
_SysPathG used to add/clean path for sys.path. Making sure minimal pkgs dependents by skiping parent dirs.
|
||||||
|
|
||||||
|
__enter__
|
||||||
|
add path into sys.path
|
||||||
|
__exit__
|
||||||
|
clean user's sys.path to avoid unexpect behaviors
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, path):
|
||||||
|
self.path = path
|
||||||
|
|
||||||
|
def __enter__(self, ):
|
||||||
|
sys.path.insert(0, self.path)
|
||||||
|
|
||||||
|
def __exit__(self, type, value, traceback):
|
||||||
|
_p = sys.path.pop(0)
|
||||||
|
assert _p == self.path, 'Make sure sys.path cleaning {} correctly.'.format(
|
||||||
|
self.path)
|
||||||
|
|
||||||
|
|
||||||
|
with _SysPathG(os.path.dirname(os.path.abspath(__file__)), ):
|
||||||
|
import ppcls
|
||||||
|
import ppcls.arch.backbone as backbone
|
||||||
|
|
||||||
|
def ppclas_init():
|
||||||
|
if ppcls.utils.logger._logger is None:
|
||||||
|
ppcls.utils.logger.init_logger()
|
||||||
|
|
||||||
|
ppclas_init()
|
||||||
|
|
||||||
|
def _load_pretrained_parameters(model, name):
|
||||||
|
url = 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/{}_pretrained.pdparams'.format(
|
||||||
|
name)
|
||||||
|
path = paddle.utils.download.get_weights_path_from_url(url)
|
||||||
|
model.set_state_dict(paddle.load(path))
|
||||||
|
return model
|
||||||
|
|
||||||
|
def alexnet(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
AlexNet
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `AlexNet` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.AlexNet(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def vgg11(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
VGG11
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False`
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `VGG11` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.VGG11(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def vgg13(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
VGG13
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False`
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `VGG13` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.VGG13(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def vgg16(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
VGG16
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False`
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `VGG16` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.VGG16(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def vgg19(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
VGG19
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False`
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `VGG19` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.VGG19(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def resnet18(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
ResNet18
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
input_image_channel: int=3. The number of input image channels
|
||||||
|
data_format: str='NCHW'. The data format of batch input images, should in ('NCHW', 'NHWC')
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `ResNet18` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.ResNet18(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def resnet34(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
ResNet34
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
input_image_channel: int=3. The number of input image channels
|
||||||
|
data_format: str='NCHW'. The data format of batch input images, should in ('NCHW', 'NHWC')
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `ResNet34` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.ResNet34(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def resnet50(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
ResNet50
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
input_image_channel: int=3. The number of input image channels
|
||||||
|
data_format: str='NCHW'. The data format of batch input images, should in ('NCHW', 'NHWC')
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `ResNet50` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.ResNet50(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def resnet101(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
ResNet101
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
input_image_channel: int=3. The number of input image channels
|
||||||
|
data_format: str='NCHW'. The data format of batch input images, should in ('NCHW', 'NHWC')
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `ResNet101` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.ResNet101(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def resnet152(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
ResNet152
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
input_image_channel: int=3. The number of input image channels
|
||||||
|
data_format: str='NCHW'. The data format of batch input images, should in ('NCHW', 'NHWC')
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `ResNet152` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.ResNet152(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def squeezenet1_0(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
SqueezeNet1_0
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `SqueezeNet1_0` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.SqueezeNet1_0(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def squeezenet1_1(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
SqueezeNet1_1
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `SqueezeNet1_1` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.SqueezeNet1_1(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def densenet121(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
DenseNet121
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
dropout: float=0. Probability of setting units to zero.
|
||||||
|
bn_size: int=4. The number of channals per group
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `DenseNet121` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.DenseNet121(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def densenet161(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
DenseNet161
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
dropout: float=0. Probability of setting units to zero.
|
||||||
|
bn_size: int=4. The number of channals per group
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `DenseNet161` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.DenseNet161(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def densenet169(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
DenseNet169
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
dropout: float=0. Probability of setting units to zero.
|
||||||
|
bn_size: int=4. The number of channals per group
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `DenseNet169` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.DenseNet169(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def densenet201(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
DenseNet201
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
dropout: float=0. Probability of setting units to zero.
|
||||||
|
bn_size: int=4. The number of channals per group
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `DenseNet201` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.DenseNet201(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def densenet264(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
DenseNet264
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
dropout: float=0. Probability of setting units to zero.
|
||||||
|
bn_size: int=4. The number of channals per group
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `DenseNet264` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.DenseNet264(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def inceptionv3(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
InceptionV3
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `InceptionV3` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.InceptionV3(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def inceptionv4(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
InceptionV4
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `InceptionV4` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.InceptionV4(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def googlenet(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
GoogLeNet
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `GoogLeNet` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.GoogLeNet(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def shufflenetv2_x0_25(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
ShuffleNetV2_x0_25
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `ShuffleNetV2_x0_25` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.ShuffleNetV2_x0_25(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def mobilenetv1(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
MobileNetV1
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `MobileNetV1` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.MobileNetV1(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def mobilenetv1_x0_25(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
MobileNetV1_x0_25
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `MobileNetV1_x0_25` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.MobileNetV1_x0_25(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def mobilenetv1_x0_5(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
MobileNetV1_x0_5
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `MobileNetV1_x0_5` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.MobileNetV1_x0_5(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def mobilenetv1_x0_75(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
MobileNetV1_x0_75
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `MobileNetV1_x0_75` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.MobileNetV1_x0_75(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def mobilenetv2_x0_25(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
MobileNetV2_x0_25
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `MobileNetV2_x0_25` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.MobileNetV2_x0_25(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def mobilenetv2_x0_5(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
MobileNetV2_x0_5
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `MobileNetV2_x0_5` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.MobileNetV2_x0_5(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def mobilenetv2_x0_75(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
MobileNetV2_x0_75
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `MobileNetV2_x0_75` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.MobileNetV2_x0_75(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def mobilenetv2_x1_5(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
MobileNetV2_x1_5
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `MobileNetV2_x1_5` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.MobileNetV2_x1_5(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def mobilenetv2_x2_0(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
MobileNetV2_x2_0
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `MobileNetV2_x2_0` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.MobileNetV2_x2_0(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def mobilenetv3_large_x0_35(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
MobileNetV3_large_x0_35
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `MobileNetV3_large_x0_35` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.MobileNetV3_large_x0_35(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def mobilenetv3_large_x0_5(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
MobileNetV3_large_x0_5
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `MobileNetV3_large_x0_5` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.MobileNetV3_large_x0_5(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def mobilenetv3_large_x0_75(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
MobileNetV3_large_x0_75
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `MobileNetV3_large_x0_75` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.MobileNetV3_large_x0_75(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def mobilenetv3_large_x1_0(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
MobileNetV3_large_x1_0
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `MobileNetV3_large_x1_0` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.MobileNetV3_large_x1_0(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def mobilenetv3_large_x1_25(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
MobileNetV3_large_x1_25
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `MobileNetV3_large_x1_25` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.MobileNetV3_large_x1_25(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def mobilenetv3_small_x0_35(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
MobileNetV3_small_x0_35
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `MobileNetV3_small_x0_35` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.MobileNetV3_small_x0_35(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def mobilenetv3_small_x0_5(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
MobileNetV3_small_x0_5
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `MobileNetV3_small_x0_5` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.MobileNetV3_small_x0_5(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def mobilenetv3_small_x0_75(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
MobileNetV3_small_x0_75
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `MobileNetV3_small_x0_75` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.MobileNetV3_small_x0_75(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def mobilenetv3_small_x1_0(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
MobileNetV3_small_x1_0
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `MobileNetV3_small_x1_0` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.MobileNetV3_small_x1_0(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def mobilenetv3_small_x1_25(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
MobileNetV3_small_x1_25
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `MobileNetV3_small_x1_25` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.MobileNetV3_small_x1_25(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def resnext101_32x4d(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
ResNeXt101_32x4d
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `ResNeXt101_32x4d` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.ResNeXt101_32x4d(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def resnext101_64x4d(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
ResNeXt101_64x4d
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `ResNeXt101_64x4d` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.ResNeXt101_64x4d(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def resnext152_32x4d(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
ResNeXt152_32x4d
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `ResNeXt152_32x4d` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.ResNeXt152_32x4d(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def resnext152_64x4d(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
ResNeXt152_64x4d
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `ResNeXt152_64x4d` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.ResNeXt152_64x4d(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def resnext50_32x4d(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
ResNeXt50_32x4d
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `ResNeXt50_32x4d` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.ResNeXt50_32x4d(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def resnext50_64x4d(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
ResNeXt50_64x4d
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `ResNeXt50_64x4d` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.ResNeXt50_64x4d(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def darknet53(pretrained=False, **kwargs):
|
||||||
|
"""
|
||||||
|
DarkNet53
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
kwargs:
|
||||||
|
class_dim: int=1000. Output dim of last fc layer.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `ResNeXt50_64x4d` model depends on args.
|
||||||
|
"""
|
||||||
|
kwargs.update({'pretrained': pretrained})
|
||||||
|
model = backbone.DarkNet53(**kwargs)
|
||||||
|
|
||||||
|
return model
|
@ -0,0 +1,572 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
__dir__ = os.path.dirname(__file__)
|
||||||
|
sys.path.append(os.path.join(__dir__, ""))
|
||||||
|
sys.path.append(os.path.join(__dir__, "deploy"))
|
||||||
|
|
||||||
|
from typing import Union, Generator
|
||||||
|
import argparse
|
||||||
|
import shutil
|
||||||
|
import textwrap
|
||||||
|
import tarfile
|
||||||
|
import requests
|
||||||
|
import warnings
|
||||||
|
from functools import partial
|
||||||
|
from difflib import SequenceMatcher
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
from prettytable import PrettyTable
|
||||||
|
|
||||||
|
from deploy.python.predict_cls import ClsPredictor
|
||||||
|
from deploy.utils.get_image_list import get_image_list
|
||||||
|
from deploy.utils import config
|
||||||
|
|
||||||
|
from ppcls.arch.backbone import *
|
||||||
|
from ppcls.utils.logger import init_logger
|
||||||
|
|
||||||
|
# for building model with loading pretrained weights from backbone
|
||||||
|
init_logger()
|
||||||
|
|
||||||
|
__all__ = ["PaddleClas"]
|
||||||
|
|
||||||
|
BASE_DIR = os.path.expanduser("~/.paddleclas/")
|
||||||
|
BASE_INFERENCE_MODEL_DIR = os.path.join(BASE_DIR, "inference_model")
|
||||||
|
BASE_IMAGES_DIR = os.path.join(BASE_DIR, "images")
|
||||||
|
BASE_DOWNLOAD_URL = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/{}_infer.tar"
|
||||||
|
MODEL_SERIES = {
|
||||||
|
"AlexNet": ["AlexNet"],
|
||||||
|
"DarkNet": ["DarkNet53"],
|
||||||
|
"DeiT": [
|
||||||
|
"DeiT_base_distilled_patch16_224", "DeiT_base_distilled_patch16_384",
|
||||||
|
"DeiT_base_patch16_224", "DeiT_base_patch16_384",
|
||||||
|
"DeiT_small_distilled_patch16_224", "DeiT_small_patch16_224",
|
||||||
|
"DeiT_tiny_distilled_patch16_224", "DeiT_tiny_patch16_224"
|
||||||
|
],
|
||||||
|
"DenseNet": [
|
||||||
|
"DenseNet121", "DenseNet161", "DenseNet169", "DenseNet201",
|
||||||
|
"DenseNet264"
|
||||||
|
],
|
||||||
|
"DLA": [
|
||||||
|
"DLA46_c", "DLA60x_c", "DLA34", "DLA60", "DLA60x", "DLA102", "DLA102x",
|
||||||
|
"DLA102x2", "DLA169"
|
||||||
|
],
|
||||||
|
"DPN": ["DPN68", "DPN92", "DPN98", "DPN107", "DPN131"],
|
||||||
|
"EfficientNet": [
|
||||||
|
"EfficientNetB0", "EfficientNetB0_small", "EfficientNetB1",
|
||||||
|
"EfficientNetB2", "EfficientNetB3", "EfficientNetB4", "EfficientNetB5",
|
||||||
|
"EfficientNetB6", "EfficientNetB7"
|
||||||
|
],
|
||||||
|
"ESNet": ["ESNet_x0_25", "ESNet_x0_5", "ESNet_x0_75", "ESNet_x1_0"],
|
||||||
|
"GhostNet":
|
||||||
|
["GhostNet_x0_5", "GhostNet_x1_0", "GhostNet_x1_3", "GhostNet_x1_3_ssld"],
|
||||||
|
"HarDNet": ["HarDNet39_ds", "HarDNet68_ds", "HarDNet68", "HarDNet85"],
|
||||||
|
"HRNet": [
|
||||||
|
"HRNet_W18_C", "HRNet_W30_C", "HRNet_W32_C", "HRNet_W40_C",
|
||||||
|
"HRNet_W44_C", "HRNet_W48_C", "HRNet_W64_C", "HRNet_W18_C_ssld",
|
||||||
|
"HRNet_W48_C_ssld"
|
||||||
|
],
|
||||||
|
"Inception": ["GoogLeNet", "InceptionV3", "InceptionV4"],
|
||||||
|
"MixNet": ["MixNet_S", "MixNet_M", "MixNet_L"],
|
||||||
|
"MobileNetV1": [
|
||||||
|
"MobileNetV1_x0_25", "MobileNetV1_x0_5", "MobileNetV1_x0_75",
|
||||||
|
"MobileNetV1", "MobileNetV1_ssld"
|
||||||
|
],
|
||||||
|
"MobileNetV2": [
|
||||||
|
"MobileNetV2_x0_25", "MobileNetV2_x0_5", "MobileNetV2_x0_75",
|
||||||
|
"MobileNetV2", "MobileNetV2_x1_5", "MobileNetV2_x2_0",
|
||||||
|
"MobileNetV2_ssld"
|
||||||
|
],
|
||||||
|
"MobileNetV3": [
|
||||||
|
"MobileNetV3_small_x0_35", "MobileNetV3_small_x0_5",
|
||||||
|
"MobileNetV3_small_x0_75", "MobileNetV3_small_x1_0",
|
||||||
|
"MobileNetV3_small_x1_25", "MobileNetV3_large_x0_35",
|
||||||
|
"MobileNetV3_large_x0_5", "MobileNetV3_large_x0_75",
|
||||||
|
"MobileNetV3_large_x1_0", "MobileNetV3_large_x1_25",
|
||||||
|
"MobileNetV3_small_x1_0_ssld", "MobileNetV3_large_x1_0_ssld"
|
||||||
|
],
|
||||||
|
"PPLCNet": [
|
||||||
|
"PPLCNet_x0_25", "PPLCNet_x0_35", "PPLCNet_x0_5", "PPLCNet_x0_75",
|
||||||
|
"PPLCNet_x1_0", "PPLCNet_x1_5", "PPLCNet_x2_0", "PPLCNet_x2_5"
|
||||||
|
],
|
||||||
|
"RedNet": ["RedNet26", "RedNet38", "RedNet50", "RedNet101", "RedNet152"],
|
||||||
|
"RegNet": ["RegNetX_4GF"],
|
||||||
|
"Res2Net": [
|
||||||
|
"Res2Net50_14w_8s", "Res2Net50_26w_4s", "Res2Net50_vd_26w_4s",
|
||||||
|
"Res2Net200_vd_26w_4s", "Res2Net101_vd_26w_4s",
|
||||||
|
"Res2Net50_vd_26w_4s_ssld", "Res2Net101_vd_26w_4s_ssld",
|
||||||
|
"Res2Net200_vd_26w_4s_ssld"
|
||||||
|
],
|
||||||
|
"ResNeSt": ["ResNeSt50", "ResNeSt50_fast_1s1x64d"],
|
||||||
|
"ResNet": [
|
||||||
|
"ResNet18", "ResNet18_vd", "ResNet34", "ResNet34_vd", "ResNet50",
|
||||||
|
"ResNet50_vc", "ResNet50_vd", "ResNet50_vd_v2", "ResNet101",
|
||||||
|
"ResNet101_vd", "ResNet152", "ResNet152_vd", "ResNet200_vd",
|
||||||
|
"ResNet34_vd_ssld", "ResNet50_vd_ssld", "ResNet50_vd_ssld_v2",
|
||||||
|
"ResNet101_vd_ssld", "Fix_ResNet50_vd_ssld_v2", "ResNet50_ACNet_deploy"
|
||||||
|
],
|
||||||
|
"ResNeXt": [
|
||||||
|
"ResNeXt50_32x4d", "ResNeXt50_vd_32x4d", "ResNeXt50_64x4d",
|
||||||
|
"ResNeXt50_vd_64x4d", "ResNeXt101_32x4d", "ResNeXt101_vd_32x4d",
|
||||||
|
"ResNeXt101_32x8d_wsl", "ResNeXt101_32x16d_wsl",
|
||||||
|
"ResNeXt101_32x32d_wsl", "ResNeXt101_32x48d_wsl",
|
||||||
|
"Fix_ResNeXt101_32x48d_wsl", "ResNeXt101_64x4d", "ResNeXt101_vd_64x4d",
|
||||||
|
"ResNeXt152_32x4d", "ResNeXt152_vd_32x4d", "ResNeXt152_64x4d",
|
||||||
|
"ResNeXt152_vd_64x4d"
|
||||||
|
],
|
||||||
|
"ReXNet":
|
||||||
|
["ReXNet_1_0", "ReXNet_1_3", "ReXNet_1_5", "ReXNet_2_0", "ReXNet_3_0"],
|
||||||
|
"SENet": [
|
||||||
|
"SENet154_vd", "SE_HRNet_W64_C_ssld", "SE_ResNet18_vd",
|
||||||
|
"SE_ResNet34_vd", "SE_ResNet50_vd", "SE_ResNeXt50_32x4d",
|
||||||
|
"SE_ResNeXt50_vd_32x4d", "SE_ResNeXt101_32x4d"
|
||||||
|
],
|
||||||
|
"ShuffleNetV2": [
|
||||||
|
"ShuffleNetV2_swish", "ShuffleNetV2_x0_25", "ShuffleNetV2_x0_33",
|
||||||
|
"ShuffleNetV2_x0_5", "ShuffleNetV2_x1_0", "ShuffleNetV2_x1_5",
|
||||||
|
"ShuffleNetV2_x2_0"
|
||||||
|
],
|
||||||
|
"SqueezeNet": ["SqueezeNet1_0", "SqueezeNet1_1"],
|
||||||
|
"SwinTransformer": [
|
||||||
|
"SwinTransformer_large_patch4_window7_224_22kto1k",
|
||||||
|
"SwinTransformer_large_patch4_window12_384_22kto1k",
|
||||||
|
"SwinTransformer_base_patch4_window7_224_22kto1k",
|
||||||
|
"SwinTransformer_base_patch4_window12_384_22kto1k",
|
||||||
|
"SwinTransformer_base_patch4_window12_384",
|
||||||
|
"SwinTransformer_base_patch4_window7_224",
|
||||||
|
"SwinTransformer_small_patch4_window7_224",
|
||||||
|
"SwinTransformer_tiny_patch4_window7_224"
|
||||||
|
],
|
||||||
|
"Twins": [
|
||||||
|
"pcpvt_small", "pcpvt_base", "pcpvt_large", "alt_gvt_small",
|
||||||
|
"alt_gvt_base", "alt_gvt_large"
|
||||||
|
],
|
||||||
|
"VGG": ["VGG11", "VGG13", "VGG16", "VGG19"],
|
||||||
|
"VisionTransformer": [
|
||||||
|
"ViT_base_patch16_224", "ViT_base_patch16_384", "ViT_base_patch32_384",
|
||||||
|
"ViT_large_patch16_224", "ViT_large_patch16_384",
|
||||||
|
"ViT_large_patch32_384", "ViT_small_patch16_224"
|
||||||
|
],
|
||||||
|
"Xception": [
|
||||||
|
"Xception41", "Xception41_deeplab", "Xception65", "Xception65_deeplab",
|
||||||
|
"Xception71"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ImageTypeError(Exception):
|
||||||
|
"""ImageTypeError.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, message=""):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class InputModelError(Exception):
|
||||||
|
"""InputModelError.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, message=""):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
def init_config(model_name,
|
||||||
|
inference_model_dir,
|
||||||
|
use_gpu=True,
|
||||||
|
batch_size=1,
|
||||||
|
topk=5,
|
||||||
|
**kwargs):
|
||||||
|
imagenet1k_map_path = os.path.join(
|
||||||
|
os.path.abspath(__dir__), "ppcls/utils/imagenet1k_label_list.txt")
|
||||||
|
cfg = {
|
||||||
|
"Global": {
|
||||||
|
"infer_imgs": kwargs["infer_imgs"]
|
||||||
|
if "infer_imgs" in kwargs else False,
|
||||||
|
"model_name": model_name,
|
||||||
|
"inference_model_dir": inference_model_dir,
|
||||||
|
"batch_size": batch_size,
|
||||||
|
"use_gpu": use_gpu,
|
||||||
|
"enable_mkldnn": kwargs["enable_mkldnn"]
|
||||||
|
if "enable_mkldnn" in kwargs else False,
|
||||||
|
"cpu_num_threads": kwargs["cpu_num_threads"]
|
||||||
|
if "cpu_num_threads" in kwargs else 1,
|
||||||
|
"enable_benchmark": False,
|
||||||
|
"use_fp16": kwargs["use_fp16"] if "use_fp16" in kwargs else False,
|
||||||
|
"ir_optim": True,
|
||||||
|
"use_tensorrt": kwargs["use_tensorrt"]
|
||||||
|
if "use_tensorrt" in kwargs else False,
|
||||||
|
"gpu_mem": kwargs["gpu_mem"] if "gpu_mem" in kwargs else 8000,
|
||||||
|
"enable_profile": False
|
||||||
|
},
|
||||||
|
"PreProcess": {
|
||||||
|
"transform_ops": [{
|
||||||
|
"ResizeImage": {
|
||||||
|
"resize_short": kwargs["resize_short"]
|
||||||
|
if "resize_short" in kwargs else 256
|
||||||
|
}
|
||||||
|
}, {
|
||||||
|
"CropImage": {
|
||||||
|
"size": kwargs["crop_size"]
|
||||||
|
if "crop_size" in kwargs else 224
|
||||||
|
}
|
||||||
|
}, {
|
||||||
|
"NormalizeImage": {
|
||||||
|
"scale": 0.00392157,
|
||||||
|
"mean": [0.485, 0.456, 0.406],
|
||||||
|
"std": [0.229, 0.224, 0.225],
|
||||||
|
"order": ''
|
||||||
|
}
|
||||||
|
}, {
|
||||||
|
"ToCHWImage": None
|
||||||
|
}]
|
||||||
|
},
|
||||||
|
"PostProcess": {
|
||||||
|
"main_indicator": "Topk",
|
||||||
|
"Topk": {
|
||||||
|
"topk": topk,
|
||||||
|
"class_id_map_file": imagenet1k_map_path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if "save_dir" in kwargs:
|
||||||
|
if kwargs["save_dir"] is not None:
|
||||||
|
cfg["PostProcess"]["SavePreLabel"] = {
|
||||||
|
"save_dir": kwargs["save_dir"]
|
||||||
|
}
|
||||||
|
if "class_id_map_file" in kwargs:
|
||||||
|
if kwargs["class_id_map_file"] is not None:
|
||||||
|
cfg["PostProcess"]["Topk"]["class_id_map_file"] = kwargs[
|
||||||
|
"class_id_map_file"]
|
||||||
|
|
||||||
|
cfg = config.AttrDict(cfg)
|
||||||
|
config.create_attr_dict(cfg)
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
def args_cfg():
|
||||||
|
def str2bool(v):
|
||||||
|
return v.lower() in ("true", "t", "1")
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--infer_imgs",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The image(s) to be predicted.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_name", type=str, help="The model name to be used.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--inference_model_dir",
|
||||||
|
type=str,
|
||||||
|
help="The directory of model files. Valid when model_name not specifed."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_gpu", type=str, default=True, help="Whether use GPU.")
|
||||||
|
parser.add_argument("--gpu_mem", type=int, default=8000, help="")
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable_mkldnn",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Whether use MKLDNN. Valid when use_gpu is False")
|
||||||
|
parser.add_argument("--cpu_num_threads", type=int, default=1, help="")
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_tensorrt", type=str2bool, default=False, help="")
|
||||||
|
parser.add_argument("--use_fp16", type=str2bool, default=False, help="")
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size", type=int, default=1, help="Batch size. Default by 1.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--topk",
|
||||||
|
type=int,
|
||||||
|
default=5,
|
||||||
|
help="Return topk score(s) and corresponding results. Default by 5.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--class_id_map_file",
|
||||||
|
type=str,
|
||||||
|
help="The path of file that map class_id and label.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_dir",
|
||||||
|
type=str,
|
||||||
|
help="The directory to save prediction results as pre-label.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--resize_short",
|
||||||
|
type=int,
|
||||||
|
default=256,
|
||||||
|
help="Resize according to short size.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--crop_size", type=int, default=224, help="Centor crop size.")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
return vars(args)
|
||||||
|
|
||||||
|
|
||||||
|
def print_info():
|
||||||
|
"""Print list of supported models in formatted.
|
||||||
|
"""
|
||||||
|
table = PrettyTable(["Series", "Name"])
|
||||||
|
try:
|
||||||
|
sz = os.get_terminal_size()
|
||||||
|
width = sz.columns - 30 if sz.columns > 50 else 10
|
||||||
|
except OSError:
|
||||||
|
width = 100
|
||||||
|
for series in MODEL_SERIES:
|
||||||
|
names = textwrap.fill(" ".join(MODEL_SERIES[series]), width=width)
|
||||||
|
table.add_row([series, names])
|
||||||
|
width = len(str(table).split("\n")[0])
|
||||||
|
print("{}".format("-" * width))
|
||||||
|
print("Models supported by PaddleClas".center(width))
|
||||||
|
print(table)
|
||||||
|
print("Powered by PaddlePaddle!".rjust(width))
|
||||||
|
print("{}".format("-" * width))
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_names():
|
||||||
|
"""Get the model names list.
|
||||||
|
"""
|
||||||
|
model_names = []
|
||||||
|
for series in MODEL_SERIES:
|
||||||
|
model_names += (MODEL_SERIES[series])
|
||||||
|
return model_names
|
||||||
|
|
||||||
|
|
||||||
|
def similar_architectures(name="", names=[], thresh=0.1, topk=10):
|
||||||
|
"""Find the most similar topk model names.
|
||||||
|
"""
|
||||||
|
scores = []
|
||||||
|
for idx, n in enumerate(names):
|
||||||
|
if n.startswith("__"):
|
||||||
|
continue
|
||||||
|
score = SequenceMatcher(None, n.lower(), name.lower()).quick_ratio()
|
||||||
|
if score > thresh:
|
||||||
|
scores.append((idx, score))
|
||||||
|
scores.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
similar_names = [names[s[0]] for s in scores[:min(topk, len(scores))]]
|
||||||
|
return similar_names
|
||||||
|
|
||||||
|
|
||||||
|
def download_with_progressbar(url, save_path):
|
||||||
|
"""Download from url with progressbar.
|
||||||
|
"""
|
||||||
|
if os.path.isfile(save_path):
|
||||||
|
os.remove(save_path)
|
||||||
|
response = requests.get(url, stream=True)
|
||||||
|
total_size_in_bytes = int(response.headers.get("content-length", 0))
|
||||||
|
block_size = 1024 # 1 Kibibyte
|
||||||
|
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
||||||
|
with open(save_path, "wb") as file:
|
||||||
|
for data in response.iter_content(block_size):
|
||||||
|
progress_bar.update(len(data))
|
||||||
|
file.write(data)
|
||||||
|
progress_bar.close()
|
||||||
|
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes or not os.path.isfile(
|
||||||
|
save_path):
|
||||||
|
raise Exception(
|
||||||
|
f"Something went wrong while downloading file from {url}")
|
||||||
|
|
||||||
|
|
||||||
|
def check_model_file(model_name):
|
||||||
|
"""Check the model files exist and download and untar when no exist.
|
||||||
|
"""
|
||||||
|
storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
|
||||||
|
model_name)
|
||||||
|
url = BASE_DOWNLOAD_URL.format(model_name)
|
||||||
|
|
||||||
|
tar_file_name_list = [
|
||||||
|
"inference.pdiparams", "inference.pdiparams.info", "inference.pdmodel"
|
||||||
|
]
|
||||||
|
model_file_path = storage_directory("inference.pdmodel")
|
||||||
|
params_file_path = storage_directory("inference.pdiparams")
|
||||||
|
if not os.path.exists(model_file_path) or not os.path.exists(
|
||||||
|
params_file_path):
|
||||||
|
tmp_path = storage_directory(url.split("/")[-1])
|
||||||
|
print(f"download {url} to {tmp_path}")
|
||||||
|
os.makedirs(storage_directory(), exist_ok=True)
|
||||||
|
download_with_progressbar(url, tmp_path)
|
||||||
|
with tarfile.open(tmp_path, "r") as tarObj:
|
||||||
|
for member in tarObj.getmembers():
|
||||||
|
filename = None
|
||||||
|
for tar_file_name in tar_file_name_list:
|
||||||
|
if tar_file_name in member.name:
|
||||||
|
filename = tar_file_name
|
||||||
|
if filename is None:
|
||||||
|
continue
|
||||||
|
file = tarObj.extractfile(member)
|
||||||
|
with open(storage_directory(filename), "wb") as f:
|
||||||
|
f.write(file.read())
|
||||||
|
os.remove(tmp_path)
|
||||||
|
if not os.path.exists(model_file_path) or not os.path.exists(
|
||||||
|
params_file_path):
|
||||||
|
raise Exception(
|
||||||
|
f"Something went wrong while praparing the model[{model_name}] files!"
|
||||||
|
)
|
||||||
|
|
||||||
|
return storage_directory()
|
||||||
|
|
||||||
|
|
||||||
|
class PaddleClas(object):
|
||||||
|
"""PaddleClas.
|
||||||
|
"""
|
||||||
|
|
||||||
|
print_info()
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
model_name: str=None,
|
||||||
|
inference_model_dir: str=None,
|
||||||
|
use_gpu: bool=True,
|
||||||
|
batch_size: int=1,
|
||||||
|
topk: int=5,
|
||||||
|
**kwargs):
|
||||||
|
"""Init PaddleClas with config.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name (str, optional): The model name supported by PaddleClas. If specified, override config. Defaults to None.
|
||||||
|
inference_model_dir (str, optional): The directory that contained model file and params file to be used. If specified, override config. Defaults to None.
|
||||||
|
use_gpu (bool, optional): Whether use GPU. If specified, override config. Defaults to True.
|
||||||
|
batch_size (int, optional): The batch size to pridict. If specified, override config. Defaults to 1.
|
||||||
|
topk (int, optional): Return the top k prediction results with the highest score. Defaults to 5.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self._config = init_config(model_name, inference_model_dir, use_gpu,
|
||||||
|
batch_size, topk, **kwargs)
|
||||||
|
self._check_input_model()
|
||||||
|
self.cls_predictor = ClsPredictor(self._config)
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
"""Get the config.
|
||||||
|
"""
|
||||||
|
return self._config
|
||||||
|
|
||||||
|
def _check_input_model(self):
|
||||||
|
"""Check input model name or model files.
|
||||||
|
"""
|
||||||
|
candidate_model_names = get_model_names()
|
||||||
|
input_model_name = self._config.Global.get("model_name", None)
|
||||||
|
inference_model_dir = self._config.Global.get("inference_model_dir",
|
||||||
|
None)
|
||||||
|
if input_model_name is not None:
|
||||||
|
similar_names = similar_architectures(input_model_name,
|
||||||
|
candidate_model_names)
|
||||||
|
similar_names_str = ", ".join(similar_names)
|
||||||
|
if input_model_name not in candidate_model_names:
|
||||||
|
err = f"{input_model_name} is not provided by PaddleClas. \nMaybe you want: [{similar_names_str}]. \nIf you want to use your own model, please specify inference_model_dir!"
|
||||||
|
raise InputModelError(err)
|
||||||
|
self._config.Global.inference_model_dir = check_model_file(
|
||||||
|
input_model_name)
|
||||||
|
return
|
||||||
|
elif inference_model_dir is not None:
|
||||||
|
model_file_path = os.path.join(inference_model_dir,
|
||||||
|
"inference.pdmodel")
|
||||||
|
params_file_path = os.path.join(inference_model_dir,
|
||||||
|
"inference.pdiparams")
|
||||||
|
if not os.path.isfile(model_file_path) or not os.path.isfile(
|
||||||
|
params_file_path):
|
||||||
|
err = f"There is no model file or params file in this directory: {inference_model_dir}"
|
||||||
|
raise InputModelError(err)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
err = f"Please specify the model name supported by PaddleClas or directory contained model files(inference.pdmodel, inference.pdiparams)."
|
||||||
|
raise InputModelError(err)
|
||||||
|
return
|
||||||
|
|
||||||
|
def predict(self, input_data: Union[str, np.array],
|
||||||
|
print_pred: bool=False) -> Generator[list, None, None]:
|
||||||
|
"""Predict input_data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_data (Union[str, np.array]):
|
||||||
|
When the type is str, it is the path of image, or the directory containing images, or the URL of image from Internet.
|
||||||
|
When the type is np.array, it is the image data whose channel order is RGB.
|
||||||
|
print_pred (bool, optional): Whether print the prediction result. Defaults to False.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImageTypeError: Illegal input_data.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Generator[list, None, None]:
|
||||||
|
The prediction result(s) of input_data by batch_size. For every one image,
|
||||||
|
prediction result(s) is zipped as a dict, that includs topk "class_ids", "scores" and "label_names".
|
||||||
|
The format of batch prediction result(s) is as follow: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...]
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(input_data, np.ndarray):
|
||||||
|
yield self.cls_predictor.predict(input_data)
|
||||||
|
elif isinstance(input_data, str):
|
||||||
|
if input_data.startswith("http") or input_data.startswith("https"):
|
||||||
|
image_storage_dir = partial(os.path.join, BASE_IMAGES_DIR)
|
||||||
|
if not os.path.exists(image_storage_dir()):
|
||||||
|
os.makedirs(image_storage_dir())
|
||||||
|
image_save_path = image_storage_dir("tmp.jpg")
|
||||||
|
download_with_progressbar(input_data, image_save_path)
|
||||||
|
input_data = image_save_path
|
||||||
|
warnings.warn(
|
||||||
|
f"Image to be predicted from Internet: {input_data}, has been saved to: {image_save_path}"
|
||||||
|
)
|
||||||
|
image_list = get_image_list(input_data)
|
||||||
|
|
||||||
|
batch_size = self._config.Global.get("batch_size", 1)
|
||||||
|
topk = self._config.PostProcess.Topk.get('topk', 1)
|
||||||
|
|
||||||
|
img_list = []
|
||||||
|
img_path_list = []
|
||||||
|
cnt = 0
|
||||||
|
for idx, img_path in enumerate(image_list):
|
||||||
|
img = cv2.imread(img_path)
|
||||||
|
if img is None:
|
||||||
|
warnings.warn(
|
||||||
|
f"Image file failed to read and has been skipped. The path: {img_path}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
img = img[:, :, ::-1]
|
||||||
|
img_list.append(img)
|
||||||
|
img_path_list.append(img_path)
|
||||||
|
cnt += 1
|
||||||
|
|
||||||
|
if cnt % batch_size == 0 or (idx + 1) == len(image_list):
|
||||||
|
preds = self.cls_predictor.predict(img_list)
|
||||||
|
|
||||||
|
if print_pred and preds:
|
||||||
|
for idx, pred in enumerate(preds):
|
||||||
|
pred_str = ", ".join(
|
||||||
|
[f"{k}: {pred[k]}" for k in pred])
|
||||||
|
print(
|
||||||
|
f"filename: {img_path_list[idx]}, top-{topk}, {pred_str}"
|
||||||
|
)
|
||||||
|
|
||||||
|
img_list = []
|
||||||
|
img_path_list = []
|
||||||
|
yield preds
|
||||||
|
else:
|
||||||
|
err = "Please input legal image! The type of image supported by PaddleClas are: NumPy.ndarray and string of local path or Ineternet URL"
|
||||||
|
raise ImageTypeError(err)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
# for CLI
|
||||||
|
def main():
|
||||||
|
"""Function API used for commad line.
|
||||||
|
"""
|
||||||
|
cfg = args_cfg()
|
||||||
|
clas_engine = PaddleClas(**cfg)
|
||||||
|
res = clas_engine.predict(cfg["infer_imgs"], print_pred=True)
|
||||||
|
for _ in res:
|
||||||
|
pass
|
||||||
|
print("Predict complete!")
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -0,0 +1,11 @@
|
|||||||
|
prettytable
|
||||||
|
ujson
|
||||||
|
opencv-python==4.4.0.46
|
||||||
|
pillow
|
||||||
|
tqdm
|
||||||
|
PyYAML
|
||||||
|
visualdl >= 2.2.0
|
||||||
|
scipy
|
||||||
|
scikit-learn==0.23.2
|
||||||
|
gast==0.3.3
|
||||||
|
faiss-cpu==1.7.1.post2
|
@ -0,0 +1,60 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from io import open
|
||||||
|
from setuptools import setup
|
||||||
|
|
||||||
|
with open('requirements.txt', encoding="utf-8-sig") as f:
|
||||||
|
requirements = f.readlines()
|
||||||
|
|
||||||
|
|
||||||
|
def readme():
|
||||||
|
with open(
|
||||||
|
'docs/en/inference_deployment/whl_deploy_en.md',
|
||||||
|
encoding="utf-8-sig") as f:
|
||||||
|
README = f.read()
|
||||||
|
return README
|
||||||
|
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name='paddleclas',
|
||||||
|
packages=['paddleclas'],
|
||||||
|
package_dir={'paddleclas': ''},
|
||||||
|
include_package_data=True,
|
||||||
|
entry_points={
|
||||||
|
"console_scripts": ["paddleclas= paddleclas.paddleclas:main"]
|
||||||
|
},
|
||||||
|
version='0.0.0',
|
||||||
|
install_requires=requirements,
|
||||||
|
license='Apache License 2.0',
|
||||||
|
description='Awesome Image Classification toolkits based on PaddlePaddle ',
|
||||||
|
long_description=readme(),
|
||||||
|
long_description_content_type='text/markdown',
|
||||||
|
url='https://github.com/PaddlePaddle/PaddleClas',
|
||||||
|
download_url='https://github.com/PaddlePaddle/PaddleClas.git',
|
||||||
|
keywords=[
|
||||||
|
'A treasure chest for image classification powered by PaddlePaddle.'
|
||||||
|
],
|
||||||
|
classifiers=[
|
||||||
|
'Intended Audience :: Developers',
|
||||||
|
'Operating System :: OS Independent',
|
||||||
|
'Natural Language :: Chinese (Simplified)',
|
||||||
|
'Programming Language :: Python :: 3',
|
||||||
|
'Programming Language :: Python :: 3.2',
|
||||||
|
'Programming Language :: Python :: 3.3',
|
||||||
|
'Programming Language :: Python :: 3.4',
|
||||||
|
'Programming Language :: Python :: 3.5',
|
||||||
|
'Programming Language :: Python :: 3.6',
|
||||||
|
'Programming Language :: Python :: 3.7', 'Topic :: Utilities'
|
||||||
|
], )
|
@ -1,3 +0,0 @@
|
|||||||
# 默认忽略的文件
|
|
||||||
/shelf/
|
|
||||||
/workspace.xml
|
|
@ -1,12 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<module type="PYTHON_MODULE" version="4">
|
|
||||||
<component name="NewModuleRootManager">
|
|
||||||
<content url="file://$MODULE_DIR$" />
|
|
||||||
<orderEntry type="inheritedJdk" />
|
|
||||||
<orderEntry type="sourceFolder" forTests="false" />
|
|
||||||
</component>
|
|
||||||
<component name="PyDocumentationSettings">
|
|
||||||
<option name="format" value="PLAIN" />
|
|
||||||
<option name="myDocStringFormat" value="Plain" />
|
|
||||||
</component>
|
|
||||||
</module>
|
|
@ -1,15 +0,0 @@
|
|||||||
<component name="InspectionProjectProfileManager">
|
|
||||||
<profile version="1.0">
|
|
||||||
<option name="myName" value="Project Default" />
|
|
||||||
<inspection_tool class="PyCompatibilityInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
|
||||||
<option name="ourVersions">
|
|
||||||
<value>
|
|
||||||
<list size="2">
|
|
||||||
<item index="0" class="java.lang.String" itemvalue="3.7" />
|
|
||||||
<item index="1" class="java.lang.String" itemvalue="3.8" />
|
|
||||||
</list>
|
|
||||||
</value>
|
|
||||||
</option>
|
|
||||||
</inspection_tool>
|
|
||||||
</profile>
|
|
||||||
</component>
|
|
@ -1,6 +0,0 @@
|
|||||||
<component name="InspectionProjectProfileManager">
|
|
||||||
<settings>
|
|
||||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
|
||||||
<version value="1.0" />
|
|
||||||
</settings>
|
|
||||||
</component>
|
|
@ -1,4 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project version="4">
|
|
||||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9" project-jdk-type="Python SDK" />
|
|
||||||
</project>
|
|
@ -1,8 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project version="4">
|
|
||||||
<component name="ProjectModuleManager">
|
|
||||||
<modules>
|
|
||||||
<module fileurl="file://$PROJECT_DIR$/.idea/Search_2D.iml" filepath="$PROJECT_DIR$/.idea/Search_2D.iml" />
|
|
||||||
</modules>
|
|
||||||
</component>
|
|
||||||
</project>
|
|
@ -1,222 +0,0 @@
|
|||||||
"""
|
|
||||||
ARA_star 2D (Anytime Repairing A*)
|
|
||||||
@author: huiming zhou
|
|
||||||
|
|
||||||
@description: local inconsistency: g-value decreased.
|
|
||||||
g(s) decreased introduces a local inconsistency between s and its successors.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import math
|
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
|
|
||||||
"/../../Search_based_Planning/")
|
|
||||||
|
|
||||||
from Search_2D import plotting, env
|
|
||||||
|
|
||||||
|
|
||||||
class AraStar:
|
|
||||||
def __init__(self, s_start, s_goal, e, heuristic_type):
|
|
||||||
self.s_start, self.s_goal = s_start, s_goal
|
|
||||||
self.heuristic_type = heuristic_type
|
|
||||||
|
|
||||||
self.Env = env.Env() # class Env
|
|
||||||
|
|
||||||
self.u_set = self.Env.motions # feasible input set
|
|
||||||
self.obs = self.Env.obs # position of obstacles
|
|
||||||
self.e = e # weight
|
|
||||||
|
|
||||||
self.g = dict() # Cost to come
|
|
||||||
self.OPEN = dict() # priority queue / OPEN set
|
|
||||||
self.CLOSED = set() # CLOSED set
|
|
||||||
self.INCONS = {} # INCONSISTENT set
|
|
||||||
self.PARENT = dict() # relations
|
|
||||||
self.path = [] # planning path
|
|
||||||
self.visited = [] # order of visited nodes
|
|
||||||
|
|
||||||
def init(self):
|
|
||||||
"""
|
|
||||||
initialize each set.
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.g[self.s_start] = 0.0
|
|
||||||
self.g[self.s_goal] = math.inf
|
|
||||||
self.OPEN[self.s_start] = self.f_value(self.s_start)
|
|
||||||
self.PARENT[self.s_start] = self.s_start
|
|
||||||
|
|
||||||
def searching(self):
|
|
||||||
self.init()
|
|
||||||
self.ImprovePath()
|
|
||||||
self.path.append(self.extract_path())
|
|
||||||
|
|
||||||
while self.update_e() > 1: # continue condition
|
|
||||||
self.e -= 0.4 # increase weight
|
|
||||||
self.OPEN.update(self.INCONS)
|
|
||||||
self.OPEN = {s: self.f_value(s) for s in self.OPEN} # update f_value of OPEN set
|
|
||||||
|
|
||||||
self.INCONS = dict()
|
|
||||||
self.CLOSED = set()
|
|
||||||
self.ImprovePath() # improve path
|
|
||||||
self.path.append(self.extract_path())
|
|
||||||
|
|
||||||
return self.path, self.visited
|
|
||||||
|
|
||||||
def ImprovePath(self):
|
|
||||||
"""
|
|
||||||
:return: a e'-suboptimal path
|
|
||||||
"""
|
|
||||||
|
|
||||||
visited_each = []
|
|
||||||
|
|
||||||
while True:
|
|
||||||
s, f_small = self.calc_smallest_f()
|
|
||||||
|
|
||||||
if self.f_value(self.s_goal) <= f_small:
|
|
||||||
break
|
|
||||||
|
|
||||||
self.OPEN.pop(s)
|
|
||||||
self.CLOSED.add(s)
|
|
||||||
|
|
||||||
for s_n in self.get_neighbor(s):
|
|
||||||
if s_n in self.obs:
|
|
||||||
continue
|
|
||||||
|
|
||||||
new_cost = self.g[s] + self.cost(s, s_n)
|
|
||||||
|
|
||||||
if s_n not in self.g or new_cost < self.g[s_n]:
|
|
||||||
self.g[s_n] = new_cost
|
|
||||||
self.PARENT[s_n] = s
|
|
||||||
visited_each.append(s_n)
|
|
||||||
|
|
||||||
if s_n not in self.CLOSED:
|
|
||||||
self.OPEN[s_n] = self.f_value(s_n)
|
|
||||||
else:
|
|
||||||
self.INCONS[s_n] = 0.0
|
|
||||||
|
|
||||||
self.visited.append(visited_each)
|
|
||||||
|
|
||||||
def calc_smallest_f(self):
|
|
||||||
"""
|
|
||||||
:return: node with smallest f_value in OPEN set.
|
|
||||||
"""
|
|
||||||
|
|
||||||
s_small = min(self.OPEN, key=self.OPEN.get)
|
|
||||||
|
|
||||||
return s_small, self.OPEN[s_small]
|
|
||||||
|
|
||||||
def get_neighbor(self, s):
|
|
||||||
"""
|
|
||||||
find neighbors of state s that not in obstacles.
|
|
||||||
:param s: state
|
|
||||||
:return: neighbors
|
|
||||||
"""
|
|
||||||
|
|
||||||
return {(s[0] + u[0], s[1] + u[1]) for u in self.u_set}
|
|
||||||
|
|
||||||
def update_e(self):
|
|
||||||
v = float("inf")
|
|
||||||
|
|
||||||
if self.OPEN:
|
|
||||||
v = min(self.g[s] + self.h(s) for s in self.OPEN)
|
|
||||||
if self.INCONS:
|
|
||||||
v = min(v, min(self.g[s] + self.h(s) for s in self.INCONS))
|
|
||||||
|
|
||||||
return min(self.e, self.g[self.s_goal] / v)
|
|
||||||
|
|
||||||
def f_value(self, x):
|
|
||||||
"""
|
|
||||||
f = g + e * h
|
|
||||||
f = cost-to-come + weight * cost-to-go
|
|
||||||
:param x: current state
|
|
||||||
:return: f_value
|
|
||||||
"""
|
|
||||||
|
|
||||||
return self.g[x] + self.e * self.h(x)
|
|
||||||
|
|
||||||
def extract_path(self):
|
|
||||||
"""
|
|
||||||
Extract the path based on the PARENT set.
|
|
||||||
:return: The planning path
|
|
||||||
"""
|
|
||||||
|
|
||||||
path = [self.s_goal]
|
|
||||||
s = self.s_goal
|
|
||||||
|
|
||||||
while True:
|
|
||||||
s = self.PARENT[s]
|
|
||||||
path.append(s)
|
|
||||||
|
|
||||||
if s == self.s_start:
|
|
||||||
break
|
|
||||||
|
|
||||||
return list(path)
|
|
||||||
|
|
||||||
def h(self, s):
|
|
||||||
"""
|
|
||||||
Calculate heuristic.
|
|
||||||
:param s: current node (state)
|
|
||||||
:return: heuristic function value
|
|
||||||
"""
|
|
||||||
|
|
||||||
heuristic_type = self.heuristic_type # heuristic type
|
|
||||||
goal = self.s_goal # goal node
|
|
||||||
|
|
||||||
if heuristic_type == "manhattan":
|
|
||||||
return abs(goal[0] - s[0]) + abs(goal[1] - s[1])
|
|
||||||
else:
|
|
||||||
return math.hypot(goal[0] - s[0], goal[1] - s[1])
|
|
||||||
|
|
||||||
def cost(self, s_start, s_goal):
|
|
||||||
"""
|
|
||||||
Calculate Cost for this motion
|
|
||||||
:param s_start: starting node
|
|
||||||
:param s_goal: end node
|
|
||||||
:return: Cost for this motion
|
|
||||||
:note: Cost function could be more complicate!
|
|
||||||
"""
|
|
||||||
|
|
||||||
if self.is_collision(s_start, s_goal):
|
|
||||||
return math.inf
|
|
||||||
|
|
||||||
return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
|
|
||||||
|
|
||||||
def is_collision(self, s_start, s_end):
|
|
||||||
"""
|
|
||||||
check if the line segment (s_start, s_end) is collision.
|
|
||||||
:param s_start: start node
|
|
||||||
:param s_end: end node
|
|
||||||
:return: True: is collision / False: not collision
|
|
||||||
"""
|
|
||||||
|
|
||||||
if s_start in self.obs or s_end in self.obs:
|
|
||||||
return True
|
|
||||||
|
|
||||||
if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
|
|
||||||
if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
|
|
||||||
s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
|
|
||||||
s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
|
|
||||||
else:
|
|
||||||
s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
|
|
||||||
s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
|
|
||||||
|
|
||||||
if s1 in self.obs or s2 in self.obs:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
s_start = (5, 5)
|
|
||||||
s_goal = (45, 25)
|
|
||||||
|
|
||||||
arastar = AraStar(s_start, s_goal, 2.5, "euclidean")
|
|
||||||
plot = plotting.Plotting(s_start, s_goal)
|
|
||||||
|
|
||||||
path, visited = arastar.searching()
|
|
||||||
plot.animation_ara_star(path, visited, "Anytime Repairing A* (ARA*)")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
@ -1,317 +0,0 @@
|
|||||||
"""
|
|
||||||
Anytime_D_star 2D
|
|
||||||
@author: huiming zhou
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import math
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
|
|
||||||
"/../../Search_based_Planning/")
|
|
||||||
|
|
||||||
from Search_2D import plotting
|
|
||||||
from Search_2D import env
|
|
||||||
|
|
||||||
|
|
||||||
class ADStar:
|
|
||||||
def __init__(self, s_start, s_goal, eps, heuristic_type):
|
|
||||||
self.s_start, self.s_goal = s_start, s_goal
|
|
||||||
self.heuristic_type = heuristic_type
|
|
||||||
|
|
||||||
self.Env = env.Env() # class Env
|
|
||||||
self.Plot = plotting.Plotting(s_start, s_goal)
|
|
||||||
|
|
||||||
self.u_set = self.Env.motions # feasible input set
|
|
||||||
self.obs = self.Env.obs # position of obstacles
|
|
||||||
self.x = self.Env.x_range
|
|
||||||
self.y = self.Env.y_range
|
|
||||||
|
|
||||||
self.g, self.rhs, self.OPEN = {}, {}, {}
|
|
||||||
|
|
||||||
for i in range(1, self.Env.x_range - 1):
|
|
||||||
for j in range(1, self.Env.y_range - 1):
|
|
||||||
self.rhs[(i, j)] = float("inf")
|
|
||||||
self.g[(i, j)] = float("inf")
|
|
||||||
|
|
||||||
self.rhs[self.s_goal] = 0.0
|
|
||||||
self.eps = eps
|
|
||||||
self.OPEN[self.s_goal] = self.Key(self.s_goal)
|
|
||||||
self.CLOSED, self.INCONS = set(), dict()
|
|
||||||
|
|
||||||
self.visited = set()
|
|
||||||
self.count = 0
|
|
||||||
self.count_env_change = 0
|
|
||||||
self.obs_add = set()
|
|
||||||
self.obs_remove = set()
|
|
||||||
self.title = "Anytime D*: Small changes" # Significant changes
|
|
||||||
self.fig = plt.figure()
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
self.Plot.plot_grid(self.title)
|
|
||||||
self.ComputeOrImprovePath()
|
|
||||||
self.plot_visited()
|
|
||||||
self.plot_path(self.extract_path())
|
|
||||||
self.visited = set()
|
|
||||||
|
|
||||||
while True:
|
|
||||||
if self.eps <= 1.0:
|
|
||||||
break
|
|
||||||
self.eps -= 0.5
|
|
||||||
self.OPEN.update(self.INCONS)
|
|
||||||
for s in self.OPEN:
|
|
||||||
self.OPEN[s] = self.Key(s)
|
|
||||||
self.CLOSED = set()
|
|
||||||
self.ComputeOrImprovePath()
|
|
||||||
self.plot_visited()
|
|
||||||
self.plot_path(self.extract_path())
|
|
||||||
self.visited = set()
|
|
||||||
plt.pause(0.5)
|
|
||||||
|
|
||||||
self.fig.canvas.mpl_connect('button_press_event', self.on_press)
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
def on_press(self, event):
|
|
||||||
x, y = event.xdata, event.ydata
|
|
||||||
if x < 0 or x > self.x - 1 or y < 0 or y > self.y - 1:
|
|
||||||
print("Please choose right area!")
|
|
||||||
else:
|
|
||||||
self.count_env_change += 1
|
|
||||||
x, y = int(x), int(y)
|
|
||||||
print("Change position: s =", x, ",", "y =", y)
|
|
||||||
|
|
||||||
# for small changes
|
|
||||||
if self.title == "Anytime D*: Small changes":
|
|
||||||
if (x, y) not in self.obs:
|
|
||||||
self.obs.add((x, y))
|
|
||||||
self.g[(x, y)] = float("inf")
|
|
||||||
self.rhs[(x, y)] = float("inf")
|
|
||||||
else:
|
|
||||||
self.obs.remove((x, y))
|
|
||||||
self.UpdateState((x, y))
|
|
||||||
|
|
||||||
self.Plot.update_obs(self.obs)
|
|
||||||
|
|
||||||
for sn in self.get_neighbor((x, y)):
|
|
||||||
self.UpdateState(sn)
|
|
||||||
|
|
||||||
plt.cla()
|
|
||||||
self.Plot.plot_grid(self.title)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
if len(self.INCONS) == 0:
|
|
||||||
break
|
|
||||||
self.OPEN.update(self.INCONS)
|
|
||||||
for s in self.OPEN:
|
|
||||||
self.OPEN[s] = self.Key(s)
|
|
||||||
self.CLOSED = set()
|
|
||||||
self.ComputeOrImprovePath()
|
|
||||||
self.plot_visited()
|
|
||||||
self.plot_path(self.extract_path())
|
|
||||||
# plt.plot(self.title)
|
|
||||||
self.visited = set()
|
|
||||||
|
|
||||||
if self.eps <= 1.0:
|
|
||||||
break
|
|
||||||
|
|
||||||
else:
|
|
||||||
if (x, y) not in self.obs:
|
|
||||||
self.obs.add((x, y))
|
|
||||||
self.obs_add.add((x, y))
|
|
||||||
plt.plot(x, y, 'sk')
|
|
||||||
if (x, y) in self.obs_remove:
|
|
||||||
self.obs_remove.remove((x, y))
|
|
||||||
else:
|
|
||||||
self.obs.remove((x, y))
|
|
||||||
self.obs_remove.add((x, y))
|
|
||||||
plt.plot(x, y, marker='s', color='white')
|
|
||||||
if (x, y) in self.obs_add:
|
|
||||||
self.obs_add.remove((x, y))
|
|
||||||
|
|
||||||
self.Plot.update_obs(self.obs)
|
|
||||||
|
|
||||||
if self.count_env_change >= 15:
|
|
||||||
self.count_env_change = 0
|
|
||||||
self.eps += 2.0
|
|
||||||
for s in self.obs_add:
|
|
||||||
self.g[(x, y)] = float("inf")
|
|
||||||
self.rhs[(x, y)] = float("inf")
|
|
||||||
|
|
||||||
for sn in self.get_neighbor(s):
|
|
||||||
self.UpdateState(sn)
|
|
||||||
|
|
||||||
for s in self.obs_remove:
|
|
||||||
for sn in self.get_neighbor(s):
|
|
||||||
self.UpdateState(sn)
|
|
||||||
self.UpdateState(s)
|
|
||||||
|
|
||||||
plt.cla()
|
|
||||||
self.Plot.plot_grid(self.title)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
if self.eps <= 1.0:
|
|
||||||
break
|
|
||||||
self.eps -= 0.5
|
|
||||||
self.OPEN.update(self.INCONS)
|
|
||||||
for s in self.OPEN:
|
|
||||||
self.OPEN[s] = self.Key(s)
|
|
||||||
self.CLOSED = set()
|
|
||||||
self.ComputeOrImprovePath()
|
|
||||||
self.plot_visited()
|
|
||||||
self.plot_path(self.extract_path())
|
|
||||||
plt.title(self.title)
|
|
||||||
self.visited = set()
|
|
||||||
plt.pause(0.5)
|
|
||||||
|
|
||||||
self.fig.canvas.draw_idle()
|
|
||||||
|
|
||||||
def ComputeOrImprovePath(self):
|
|
||||||
while True:
|
|
||||||
s, v = self.TopKey()
|
|
||||||
if v >= self.Key(self.s_start) and \
|
|
||||||
self.rhs[self.s_start] == self.g[self.s_start]:
|
|
||||||
break
|
|
||||||
|
|
||||||
self.OPEN.pop(s)
|
|
||||||
self.visited.add(s)
|
|
||||||
|
|
||||||
if self.g[s] > self.rhs[s]:
|
|
||||||
self.g[s] = self.rhs[s]
|
|
||||||
self.CLOSED.add(s)
|
|
||||||
for sn in self.get_neighbor(s):
|
|
||||||
self.UpdateState(sn)
|
|
||||||
else:
|
|
||||||
self.g[s] = float("inf")
|
|
||||||
for sn in self.get_neighbor(s):
|
|
||||||
self.UpdateState(sn)
|
|
||||||
self.UpdateState(s)
|
|
||||||
|
|
||||||
def UpdateState(self, s):
|
|
||||||
if s != self.s_goal:
|
|
||||||
self.rhs[s] = float("inf")
|
|
||||||
for x in self.get_neighbor(s):
|
|
||||||
self.rhs[s] = min(self.rhs[s], self.g[x] + self.cost(s, x))
|
|
||||||
if s in self.OPEN:
|
|
||||||
self.OPEN.pop(s)
|
|
||||||
|
|
||||||
if self.g[s] != self.rhs[s]:
|
|
||||||
if s not in self.CLOSED:
|
|
||||||
self.OPEN[s] = self.Key(s)
|
|
||||||
else:
|
|
||||||
self.INCONS[s] = 0
|
|
||||||
|
|
||||||
def Key(self, s):
|
|
||||||
if self.g[s] > self.rhs[s]:
|
|
||||||
return [self.rhs[s] + self.eps * self.h(self.s_start, s), self.rhs[s]]
|
|
||||||
else:
|
|
||||||
return [self.g[s] + self.h(self.s_start, s), self.g[s]]
|
|
||||||
|
|
||||||
def TopKey(self):
|
|
||||||
"""
|
|
||||||
:return: return the min key and its value.
|
|
||||||
"""
|
|
||||||
|
|
||||||
s = min(self.OPEN, key=self.OPEN.get)
|
|
||||||
return s, self.OPEN[s]
|
|
||||||
|
|
||||||
def h(self, s_start, s_goal):
|
|
||||||
heuristic_type = self.heuristic_type # heuristic type
|
|
||||||
|
|
||||||
if heuristic_type == "manhattan":
|
|
||||||
return abs(s_goal[0] - s_start[0]) + abs(s_goal[1] - s_start[1])
|
|
||||||
else:
|
|
||||||
return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
|
|
||||||
|
|
||||||
def cost(self, s_start, s_goal):
|
|
||||||
"""
|
|
||||||
Calculate Cost for this motion
|
|
||||||
:param s_start: starting node
|
|
||||||
:param s_goal: end node
|
|
||||||
:return: Cost for this motion
|
|
||||||
:note: Cost function could be more complicate!
|
|
||||||
"""
|
|
||||||
|
|
||||||
if self.is_collision(s_start, s_goal):
|
|
||||||
return float("inf")
|
|
||||||
|
|
||||||
return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
|
|
||||||
|
|
||||||
def is_collision(self, s_start, s_end):
|
|
||||||
if s_start in self.obs or s_end in self.obs:
|
|
||||||
return True
|
|
||||||
|
|
||||||
if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
|
|
||||||
if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
|
|
||||||
s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
|
|
||||||
s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
|
|
||||||
else:
|
|
||||||
s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
|
|
||||||
s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
|
|
||||||
|
|
||||||
if s1 in self.obs or s2 in self.obs:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_neighbor(self, s):
|
|
||||||
nei_list = set()
|
|
||||||
for u in self.u_set:
|
|
||||||
s_next = tuple([s[i] + u[i] for i in range(2)])
|
|
||||||
if s_next not in self.obs:
|
|
||||||
nei_list.add(s_next)
|
|
||||||
|
|
||||||
return nei_list
|
|
||||||
|
|
||||||
def extract_path(self):
|
|
||||||
"""
|
|
||||||
Extract the path based on the PARENT set.
|
|
||||||
:return: The planning path
|
|
||||||
"""
|
|
||||||
|
|
||||||
path = [self.s_start]
|
|
||||||
s = self.s_start
|
|
||||||
|
|
||||||
for k in range(100):
|
|
||||||
g_list = {}
|
|
||||||
for x in self.get_neighbor(s):
|
|
||||||
if not self.is_collision(s, x):
|
|
||||||
g_list[x] = self.g[x]
|
|
||||||
s = min(g_list, key=g_list.get)
|
|
||||||
path.append(s)
|
|
||||||
if s == self.s_goal:
|
|
||||||
break
|
|
||||||
|
|
||||||
return list(path)
|
|
||||||
|
|
||||||
def plot_path(self, path):
|
|
||||||
px = [x[0] for x in path]
|
|
||||||
py = [x[1] for x in path]
|
|
||||||
plt.plot(px, py, linewidth=2)
|
|
||||||
plt.plot(self.s_start[0], self.s_start[1], "bs")
|
|
||||||
plt.plot(self.s_goal[0], self.s_goal[1], "gs")
|
|
||||||
|
|
||||||
def plot_visited(self):
|
|
||||||
self.count += 1
|
|
||||||
|
|
||||||
color = ['gainsboro', 'lightgray', 'silver', 'darkgray',
|
|
||||||
'bisque', 'navajowhite', 'moccasin', 'wheat',
|
|
||||||
'powderblue', 'skyblue', 'lightskyblue', 'cornflowerblue']
|
|
||||||
|
|
||||||
if self.count >= len(color) - 1:
|
|
||||||
self.count = 0
|
|
||||||
|
|
||||||
for x in self.visited:
|
|
||||||
plt.plot(x[0], x[1], marker='s', color=color[self.count])
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
s_start = (5, 5)
|
|
||||||
s_goal = (45, 25)
|
|
||||||
|
|
||||||
dstar = ADStar(s_start, s_goal, 2.5, "euclidean")
|
|
||||||
dstar.run()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
@ -1,225 +0,0 @@
|
|||||||
"""
|
|
||||||
A_star 2D
|
|
||||||
@author: huiming zhou
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import math
|
|
||||||
import heapq
|
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
|
|
||||||
"/../../Search_based_Planning/")
|
|
||||||
|
|
||||||
from Search_2D import plotting, env
|
|
||||||
|
|
||||||
|
|
||||||
class AStar:
|
|
||||||
"""AStar set the cost + heuristics as the priority
|
|
||||||
"""
|
|
||||||
def __init__(self, s_start, s_goal, heuristic_type):
|
|
||||||
self.s_start = s_start
|
|
||||||
self.s_goal = s_goal
|
|
||||||
self.heuristic_type = heuristic_type
|
|
||||||
|
|
||||||
self.Env = env.Env() # class Env
|
|
||||||
|
|
||||||
self.u_set = self.Env.motions # feasible input set
|
|
||||||
self.obs = self.Env.obs # position of obstacles
|
|
||||||
|
|
||||||
self.OPEN = [] # priority queue / OPEN set
|
|
||||||
self.CLOSED = [] # CLOSED set / VISITED order
|
|
||||||
self.PARENT = dict() # recorded parent
|
|
||||||
self.g = dict() # cost to come
|
|
||||||
|
|
||||||
def searching(self):
|
|
||||||
"""
|
|
||||||
A_star Searching.
|
|
||||||
:return: path, visited order
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.PARENT[self.s_start] = self.s_start
|
|
||||||
self.g[self.s_start] = 0
|
|
||||||
self.g[self.s_goal] = math.inf
|
|
||||||
heapq.heappush(self.OPEN,
|
|
||||||
(self.f_value(self.s_start), self.s_start))
|
|
||||||
|
|
||||||
while self.OPEN:
|
|
||||||
_, s = heapq.heappop(self.OPEN)
|
|
||||||
self.CLOSED.append(s)
|
|
||||||
|
|
||||||
if s == self.s_goal: # stop condition
|
|
||||||
break
|
|
||||||
|
|
||||||
for s_n in self.get_neighbor(s):
|
|
||||||
new_cost = self.g[s] + self.cost(s, s_n)
|
|
||||||
|
|
||||||
if s_n not in self.g:
|
|
||||||
self.g[s_n] = math.inf
|
|
||||||
|
|
||||||
if new_cost < self.g[s_n]: # conditions for updating Cost
|
|
||||||
self.g[s_n] = new_cost
|
|
||||||
self.PARENT[s_n] = s
|
|
||||||
heapq.heappush(self.OPEN, (self.f_value(s_n), s_n))
|
|
||||||
|
|
||||||
return self.extract_path(self.PARENT), self.CLOSED
|
|
||||||
|
|
||||||
def searching_repeated_astar(self, e):
|
|
||||||
"""
|
|
||||||
repeated A*.
|
|
||||||
:param e: weight of A*
|
|
||||||
:return: path and visited order
|
|
||||||
"""
|
|
||||||
|
|
||||||
path, visited = [], []
|
|
||||||
|
|
||||||
while e >= 1:
|
|
||||||
p_k, v_k = self.repeated_searching(self.s_start, self.s_goal, e)
|
|
||||||
path.append(p_k)
|
|
||||||
visited.append(v_k)
|
|
||||||
e -= 0.5
|
|
||||||
|
|
||||||
return path, visited
|
|
||||||
|
|
||||||
def repeated_searching(self, s_start, s_goal, e):
|
|
||||||
"""
|
|
||||||
run A* with weight e.
|
|
||||||
:param s_start: starting state
|
|
||||||
:param s_goal: goal state
|
|
||||||
:param e: weight of a*
|
|
||||||
:return: path and visited order.
|
|
||||||
"""
|
|
||||||
|
|
||||||
g = {s_start: 0, s_goal: float("inf")}
|
|
||||||
PARENT = {s_start: s_start}
|
|
||||||
OPEN = []
|
|
||||||
CLOSED = []
|
|
||||||
heapq.heappush(OPEN,
|
|
||||||
(g[s_start] + e * self.heuristic(s_start), s_start))
|
|
||||||
|
|
||||||
while OPEN:
|
|
||||||
_, s = heapq.heappop(OPEN)
|
|
||||||
CLOSED.append(s)
|
|
||||||
|
|
||||||
if s == s_goal:
|
|
||||||
break
|
|
||||||
|
|
||||||
for s_n in self.get_neighbor(s):
|
|
||||||
new_cost = g[s] + self.cost(s, s_n)
|
|
||||||
|
|
||||||
if s_n not in g:
|
|
||||||
g[s_n] = math.inf
|
|
||||||
|
|
||||||
if new_cost < g[s_n]: # conditions for updating Cost
|
|
||||||
g[s_n] = new_cost
|
|
||||||
PARENT[s_n] = s
|
|
||||||
heapq.heappush(OPEN, (g[s_n] + e * self.heuristic(s_n), s_n))
|
|
||||||
|
|
||||||
return self.extract_path(PARENT), CLOSED
|
|
||||||
|
|
||||||
def get_neighbor(self, s):
|
|
||||||
"""
|
|
||||||
find neighbors of state s that not in obstacles.
|
|
||||||
:param s: state
|
|
||||||
:return: neighbors
|
|
||||||
"""
|
|
||||||
|
|
||||||
return [(s[0] + u[0], s[1] + u[1]) for u in self.u_set]
|
|
||||||
|
|
||||||
def cost(self, s_start, s_goal):
|
|
||||||
"""
|
|
||||||
Calculate Cost for this motion
|
|
||||||
:param s_start: starting node
|
|
||||||
:param s_goal: end node
|
|
||||||
:return: Cost for this motion
|
|
||||||
:note: Cost function could be more complicate!
|
|
||||||
"""
|
|
||||||
|
|
||||||
if self.is_collision(s_start, s_goal):
|
|
||||||
return math.inf
|
|
||||||
|
|
||||||
return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
|
|
||||||
|
|
||||||
def is_collision(self, s_start, s_end):
|
|
||||||
"""
|
|
||||||
check if the line segment (s_start, s_end) is collision.
|
|
||||||
:param s_start: start node
|
|
||||||
:param s_end: end node
|
|
||||||
:return: True: is collision / False: not collision
|
|
||||||
"""
|
|
||||||
|
|
||||||
if s_start in self.obs or s_end in self.obs:
|
|
||||||
return True
|
|
||||||
|
|
||||||
if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
|
|
||||||
if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
|
|
||||||
s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
|
|
||||||
s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
|
|
||||||
else:
|
|
||||||
s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
|
|
||||||
s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
|
|
||||||
|
|
||||||
if s1 in self.obs or s2 in self.obs:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def f_value(self, s):
|
|
||||||
"""
|
|
||||||
f = g + h. (g: Cost to come, h: heuristic value)
|
|
||||||
:param s: current state
|
|
||||||
:return: f
|
|
||||||
"""
|
|
||||||
|
|
||||||
return self.g[s] + self.heuristic(s)
|
|
||||||
|
|
||||||
def extract_path(self, PARENT):
|
|
||||||
"""
|
|
||||||
Extract the path based on the PARENT set.
|
|
||||||
:return: The planning path
|
|
||||||
"""
|
|
||||||
|
|
||||||
path = [self.s_goal]
|
|
||||||
s = self.s_goal
|
|
||||||
|
|
||||||
while True:
|
|
||||||
s = PARENT[s]
|
|
||||||
path.append(s)
|
|
||||||
|
|
||||||
if s == self.s_start:
|
|
||||||
break
|
|
||||||
|
|
||||||
return list(path)
|
|
||||||
|
|
||||||
def heuristic(self, s):
|
|
||||||
"""
|
|
||||||
Calculate heuristic.
|
|
||||||
:param s: current node (state)
|
|
||||||
:return: heuristic function value
|
|
||||||
"""
|
|
||||||
|
|
||||||
heuristic_type = self.heuristic_type # heuristic type
|
|
||||||
goal = self.s_goal # goal node
|
|
||||||
|
|
||||||
if heuristic_type == "manhattan":
|
|
||||||
return abs(goal[0] - s[0]) + abs(goal[1] - s[1])
|
|
||||||
else:
|
|
||||||
return math.hypot(goal[0] - s[0], goal[1] - s[1])
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
s_start = (5, 5)
|
|
||||||
s_goal = (45, 25)
|
|
||||||
|
|
||||||
astar = AStar(s_start, s_goal, "euclidean")
|
|
||||||
plot = plotting.Plotting(s_start, s_goal)
|
|
||||||
|
|
||||||
path, visited = astar.searching()
|
|
||||||
plot.animation(path, visited, "A*") # animation
|
|
||||||
|
|
||||||
# path, visited = astar.searching_repeated_astar(2.5) # initial weight e = 2.5
|
|
||||||
# plot.animation_ara_star(path, visited, "Repeated A*")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
@ -1,68 +0,0 @@
|
|||||||
"""
|
|
||||||
Best-First Searching
|
|
||||||
@author: huiming zhou
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import math
|
|
||||||
import heapq
|
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
|
|
||||||
"/../../Search_based_Planning/")
|
|
||||||
|
|
||||||
from Search_2D import plotting, env
|
|
||||||
from Search_2D.Astar import AStar
|
|
||||||
|
|
||||||
|
|
||||||
class BestFirst(AStar):
|
|
||||||
"""BestFirst set the heuristics as the priority
|
|
||||||
"""
|
|
||||||
def searching(self):
|
|
||||||
"""
|
|
||||||
Breadth-first Searching.
|
|
||||||
:return: path, visited order
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.PARENT[self.s_start] = self.s_start
|
|
||||||
self.g[self.s_start] = 0
|
|
||||||
self.g[self.s_goal] = math.inf
|
|
||||||
heapq.heappush(self.OPEN,
|
|
||||||
(self.heuristic(self.s_start), self.s_start))
|
|
||||||
|
|
||||||
while self.OPEN:
|
|
||||||
_, s = heapq.heappop(self.OPEN)
|
|
||||||
self.CLOSED.append(s)
|
|
||||||
|
|
||||||
if s == self.s_goal:
|
|
||||||
break
|
|
||||||
|
|
||||||
for s_n in self.get_neighbor(s):
|
|
||||||
new_cost = self.g[s] + self.cost(s, s_n)
|
|
||||||
|
|
||||||
if s_n not in self.g:
|
|
||||||
self.g[s_n] = math.inf
|
|
||||||
|
|
||||||
if new_cost < self.g[s_n]: # conditions for updating Cost
|
|
||||||
self.g[s_n] = new_cost
|
|
||||||
self.PARENT[s_n] = s
|
|
||||||
|
|
||||||
# best first set the heuristics as the priority
|
|
||||||
heapq.heappush(self.OPEN, (self.heuristic(s_n), s_n))
|
|
||||||
|
|
||||||
return self.extract_path(self.PARENT), self.CLOSED
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
s_start = (5, 5)
|
|
||||||
s_goal = (45, 25)
|
|
||||||
|
|
||||||
BF = BestFirst(s_start, s_goal, 'euclidean')
|
|
||||||
plot = plotting.Plotting(s_start, s_goal)
|
|
||||||
|
|
||||||
path, visited = BF.searching()
|
|
||||||
plot.animation(path, visited, "Best-first Searching") # animation
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
@ -1,229 +0,0 @@
|
|||||||
"""
|
|
||||||
Bidirectional_a_star 2D
|
|
||||||
@author: huiming zhou
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import math
|
|
||||||
import heapq
|
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
|
|
||||||
"/../../Search_based_Planning/")
|
|
||||||
|
|
||||||
from Search_2D import plotting, env
|
|
||||||
|
|
||||||
|
|
||||||
class BidirectionalAStar:
|
|
||||||
def __init__(self, s_start, s_goal, heuristic_type):
|
|
||||||
self.s_start = s_start
|
|
||||||
self.s_goal = s_goal
|
|
||||||
self.heuristic_type = heuristic_type
|
|
||||||
|
|
||||||
self.Env = env.Env() # class Env
|
|
||||||
|
|
||||||
self.u_set = self.Env.motions # feasible input set
|
|
||||||
self.obs = self.Env.obs # position of obstacles
|
|
||||||
|
|
||||||
self.OPEN_fore = [] # OPEN set for forward searching
|
|
||||||
self.OPEN_back = [] # OPEN set for backward searching
|
|
||||||
self.CLOSED_fore = [] # CLOSED set for forward
|
|
||||||
self.CLOSED_back = [] # CLOSED set for backward
|
|
||||||
self.PARENT_fore = dict() # recorded parent for forward
|
|
||||||
self.PARENT_back = dict() # recorded parent for backward
|
|
||||||
self.g_fore = dict() # cost to come for forward
|
|
||||||
self.g_back = dict() # cost to come for backward
|
|
||||||
|
|
||||||
def init(self):
|
|
||||||
"""
|
|
||||||
initialize parameters
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.g_fore[self.s_start] = 0.0
|
|
||||||
self.g_fore[self.s_goal] = math.inf
|
|
||||||
self.g_back[self.s_goal] = 0.0
|
|
||||||
self.g_back[self.s_start] = math.inf
|
|
||||||
self.PARENT_fore[self.s_start] = self.s_start
|
|
||||||
self.PARENT_back[self.s_goal] = self.s_goal
|
|
||||||
heapq.heappush(self.OPEN_fore,
|
|
||||||
(self.f_value_fore(self.s_start), self.s_start))
|
|
||||||
heapq.heappush(self.OPEN_back,
|
|
||||||
(self.f_value_back(self.s_goal), self.s_goal))
|
|
||||||
|
|
||||||
def searching(self):
|
|
||||||
"""
|
|
||||||
Bidirectional A*
|
|
||||||
:return: connected path, visited order of forward, visited order of backward
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.init()
|
|
||||||
s_meet = self.s_start
|
|
||||||
|
|
||||||
while self.OPEN_fore and self.OPEN_back:
|
|
||||||
# solve foreward-search
|
|
||||||
_, s_fore = heapq.heappop(self.OPEN_fore)
|
|
||||||
|
|
||||||
if s_fore in self.PARENT_back:
|
|
||||||
s_meet = s_fore
|
|
||||||
break
|
|
||||||
|
|
||||||
self.CLOSED_fore.append(s_fore)
|
|
||||||
|
|
||||||
for s_n in self.get_neighbor(s_fore):
|
|
||||||
new_cost = self.g_fore[s_fore] + self.cost(s_fore, s_n)
|
|
||||||
|
|
||||||
if s_n not in self.g_fore:
|
|
||||||
self.g_fore[s_n] = math.inf
|
|
||||||
|
|
||||||
if new_cost < self.g_fore[s_n]:
|
|
||||||
self.g_fore[s_n] = new_cost
|
|
||||||
self.PARENT_fore[s_n] = s_fore
|
|
||||||
heapq.heappush(self.OPEN_fore,
|
|
||||||
(self.f_value_fore(s_n), s_n))
|
|
||||||
|
|
||||||
# solve backward-search
|
|
||||||
_, s_back = heapq.heappop(self.OPEN_back)
|
|
||||||
|
|
||||||
if s_back in self.PARENT_fore:
|
|
||||||
s_meet = s_back
|
|
||||||
break
|
|
||||||
|
|
||||||
self.CLOSED_back.append(s_back)
|
|
||||||
|
|
||||||
for s_n in self.get_neighbor(s_back):
|
|
||||||
new_cost = self.g_back[s_back] + self.cost(s_back, s_n)
|
|
||||||
|
|
||||||
if s_n not in self.g_back:
|
|
||||||
self.g_back[s_n] = math.inf
|
|
||||||
|
|
||||||
if new_cost < self.g_back[s_n]:
|
|
||||||
self.g_back[s_n] = new_cost
|
|
||||||
self.PARENT_back[s_n] = s_back
|
|
||||||
heapq.heappush(self.OPEN_back,
|
|
||||||
(self.f_value_back(s_n), s_n))
|
|
||||||
|
|
||||||
return self.extract_path(s_meet), self.CLOSED_fore, self.CLOSED_back
|
|
||||||
|
|
||||||
def get_neighbor(self, s):
|
|
||||||
"""
|
|
||||||
find neighbors of state s that not in obstacles.
|
|
||||||
:param s: state
|
|
||||||
:return: neighbors
|
|
||||||
"""
|
|
||||||
|
|
||||||
return [(s[0] + u[0], s[1] + u[1]) for u in self.u_set]
|
|
||||||
|
|
||||||
def extract_path(self, s_meet):
|
|
||||||
"""
|
|
||||||
extract path from start and goal
|
|
||||||
:param s_meet: meet point of bi-direction a*
|
|
||||||
:return: path
|
|
||||||
"""
|
|
||||||
|
|
||||||
# extract path for foreward part
|
|
||||||
path_fore = [s_meet]
|
|
||||||
s = s_meet
|
|
||||||
|
|
||||||
while True:
|
|
||||||
s = self.PARENT_fore[s]
|
|
||||||
path_fore.append(s)
|
|
||||||
if s == self.s_start:
|
|
||||||
break
|
|
||||||
|
|
||||||
# extract path for backward part
|
|
||||||
path_back = []
|
|
||||||
s = s_meet
|
|
||||||
|
|
||||||
while True:
|
|
||||||
s = self.PARENT_back[s]
|
|
||||||
path_back.append(s)
|
|
||||||
if s == self.s_goal:
|
|
||||||
break
|
|
||||||
|
|
||||||
return list(reversed(path_fore)) + list(path_back)
|
|
||||||
|
|
||||||
def f_value_fore(self, s):
|
|
||||||
"""
|
|
||||||
forward searching: f = g + h. (g: Cost to come, h: heuristic value)
|
|
||||||
:param s: current state
|
|
||||||
:return: f
|
|
||||||
"""
|
|
||||||
|
|
||||||
return self.g_fore[s] + self.h(s, self.s_goal)
|
|
||||||
|
|
||||||
def f_value_back(self, s):
|
|
||||||
"""
|
|
||||||
backward searching: f = g + h. (g: Cost to come, h: heuristic value)
|
|
||||||
:param s: current state
|
|
||||||
:return: f
|
|
||||||
"""
|
|
||||||
|
|
||||||
return self.g_back[s] + self.h(s, self.s_start)
|
|
||||||
|
|
||||||
def h(self, s, goal):
|
|
||||||
"""
|
|
||||||
Calculate heuristic value.
|
|
||||||
:param s: current node (state)
|
|
||||||
:param goal: goal node (state)
|
|
||||||
:return: heuristic value
|
|
||||||
"""
|
|
||||||
|
|
||||||
heuristic_type = self.heuristic_type
|
|
||||||
|
|
||||||
if heuristic_type == "manhattan":
|
|
||||||
return abs(goal[0] - s[0]) + abs(goal[1] - s[1])
|
|
||||||
else:
|
|
||||||
return math.hypot(goal[0] - s[0], goal[1] - s[1])
|
|
||||||
|
|
||||||
def cost(self, s_start, s_goal):
|
|
||||||
"""
|
|
||||||
Calculate Cost for this motion
|
|
||||||
:param s_start: starting node
|
|
||||||
:param s_goal: end node
|
|
||||||
:return: Cost for this motion
|
|
||||||
:note: Cost function could be more complicate!
|
|
||||||
"""
|
|
||||||
|
|
||||||
if self.is_collision(s_start, s_goal):
|
|
||||||
return math.inf
|
|
||||||
|
|
||||||
return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
|
|
||||||
|
|
||||||
def is_collision(self, s_start, s_end):
|
|
||||||
"""
|
|
||||||
check if the line segment (s_start, s_end) is collision.
|
|
||||||
:param s_start: start node
|
|
||||||
:param s_end: end node
|
|
||||||
:return: True: is collision / False: not collision
|
|
||||||
"""
|
|
||||||
|
|
||||||
if s_start in self.obs or s_end in self.obs:
|
|
||||||
return True
|
|
||||||
|
|
||||||
if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
|
|
||||||
if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
|
|
||||||
s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
|
|
||||||
s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
|
|
||||||
else:
|
|
||||||
s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
|
|
||||||
s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
|
|
||||||
|
|
||||||
if s1 in self.obs or s2 in self.obs:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
x_start = (5, 5)
|
|
||||||
x_goal = (45, 25)
|
|
||||||
|
|
||||||
bastar = BidirectionalAStar(x_start, x_goal, "euclidean")
|
|
||||||
plot = plotting.Plotting(x_start, x_goal)
|
|
||||||
|
|
||||||
path, visited_fore, visited_back = bastar.searching()
|
|
||||||
plot.animation_bi_astar(path, visited_fore, visited_back, "Bidirectional-A*") # animation
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
@ -1,304 +0,0 @@
|
|||||||
"""
|
|
||||||
D_star 2D
|
|
||||||
@author: huiming zhou
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import math
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
|
|
||||||
"/../../Search_based_Planning/")
|
|
||||||
|
|
||||||
from Search_2D import plotting, env
|
|
||||||
|
|
||||||
|
|
||||||
class DStar:
|
|
||||||
def __init__(self, s_start, s_goal):
|
|
||||||
self.s_start, self.s_goal = s_start, s_goal
|
|
||||||
|
|
||||||
self.Env = env.Env()
|
|
||||||
self.Plot = plotting.Plotting(self.s_start, self.s_goal)
|
|
||||||
|
|
||||||
self.u_set = self.Env.motions
|
|
||||||
self.obs = self.Env.obs
|
|
||||||
self.x = self.Env.x_range
|
|
||||||
self.y = self.Env.y_range
|
|
||||||
|
|
||||||
self.fig = plt.figure()
|
|
||||||
|
|
||||||
self.OPEN = set()
|
|
||||||
self.t = dict()
|
|
||||||
self.PARENT = dict()
|
|
||||||
self.h = dict()
|
|
||||||
self.k = dict()
|
|
||||||
self.path = []
|
|
||||||
self.visited = set()
|
|
||||||
self.count = 0
|
|
||||||
|
|
||||||
def init(self):
|
|
||||||
for i in range(self.Env.x_range):
|
|
||||||
for j in range(self.Env.y_range):
|
|
||||||
self.t[(i, j)] = 'NEW'
|
|
||||||
self.k[(i, j)] = 0.0
|
|
||||||
self.h[(i, j)] = float("inf")
|
|
||||||
self.PARENT[(i, j)] = None
|
|
||||||
|
|
||||||
self.h[self.s_goal] = 0.0
|
|
||||||
|
|
||||||
def run(self, s_start, s_end):
|
|
||||||
self.init()
|
|
||||||
self.insert(s_end, 0)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
self.process_state()
|
|
||||||
if self.t[s_start] == 'CLOSED':
|
|
||||||
break
|
|
||||||
|
|
||||||
self.path = self.extract_path(s_start, s_end)
|
|
||||||
self.Plot.plot_grid("Dynamic A* (D*)")
|
|
||||||
self.plot_path(self.path)
|
|
||||||
self.fig.canvas.mpl_connect('button_press_event', self.on_press)
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
def on_press(self, event):
|
|
||||||
x, y = event.xdata, event.ydata
|
|
||||||
if x < 0 or x > self.x - 1 or y < 0 or y > self.y - 1:
|
|
||||||
print("Please choose right area!")
|
|
||||||
else:
|
|
||||||
x, y = int(x), int(y)
|
|
||||||
if (x, y) not in self.obs:
|
|
||||||
print("Add obstacle at: s =", x, ",", "y =", y)
|
|
||||||
self.obs.add((x, y))
|
|
||||||
self.Plot.update_obs(self.obs)
|
|
||||||
|
|
||||||
s = self.s_start
|
|
||||||
self.visited = set()
|
|
||||||
self.count += 1
|
|
||||||
|
|
||||||
while s != self.s_goal:
|
|
||||||
if self.is_collision(s, self.PARENT[s]):
|
|
||||||
self.modify(s)
|
|
||||||
continue
|
|
||||||
s = self.PARENT[s]
|
|
||||||
|
|
||||||
self.path = self.extract_path(self.s_start, self.s_goal)
|
|
||||||
|
|
||||||
plt.cla()
|
|
||||||
self.Plot.plot_grid("Dynamic A* (D*)")
|
|
||||||
self.plot_visited(self.visited)
|
|
||||||
self.plot_path(self.path)
|
|
||||||
|
|
||||||
self.fig.canvas.draw_idle()
|
|
||||||
|
|
||||||
def extract_path(self, s_start, s_end):
|
|
||||||
path = [s_start]
|
|
||||||
s = s_start
|
|
||||||
while True:
|
|
||||||
s = self.PARENT[s]
|
|
||||||
path.append(s)
|
|
||||||
if s == s_end:
|
|
||||||
return path
|
|
||||||
|
|
||||||
def process_state(self):
|
|
||||||
s = self.min_state() # get node in OPEN set with min k value
|
|
||||||
self.visited.add(s)
|
|
||||||
|
|
||||||
if s is None:
|
|
||||||
return -1 # OPEN set is empty
|
|
||||||
|
|
||||||
k_old = self.get_k_min() # record the min k value of this iteration (min path cost)
|
|
||||||
self.delete(s) # move state s from OPEN set to CLOSED set
|
|
||||||
|
|
||||||
# k_min < h[s] --> s: RAISE state (increased cost)
|
|
||||||
if k_old < self.h[s]:
|
|
||||||
for s_n in self.get_neighbor(s):
|
|
||||||
if self.h[s_n] <= k_old and \
|
|
||||||
self.h[s] > self.h[s_n] + self.cost(s_n, s):
|
|
||||||
|
|
||||||
# update h_value and choose parent
|
|
||||||
self.PARENT[s] = s_n
|
|
||||||
self.h[s] = self.h[s_n] + self.cost(s_n, s)
|
|
||||||
|
|
||||||
# s: k_min >= h[s] -- > s: LOWER state (cost reductions)
|
|
||||||
if k_old == self.h[s]:
|
|
||||||
for s_n in self.get_neighbor(s):
|
|
||||||
if self.t[s_n] == 'NEW' or \
|
|
||||||
(self.PARENT[s_n] == s and self.h[s_n] != self.h[s] + self.cost(s, s_n)) or \
|
|
||||||
(self.PARENT[s_n] != s and self.h[s_n] > self.h[s] + self.cost(s, s_n)):
|
|
||||||
|
|
||||||
# Condition:
|
|
||||||
# 1) t[s_n] == 'NEW': not visited
|
|
||||||
# 2) s_n's parent: cost reduction
|
|
||||||
# 3) s_n find a better parent
|
|
||||||
self.PARENT[s_n] = s
|
|
||||||
self.insert(s_n, self.h[s] + self.cost(s, s_n))
|
|
||||||
else:
|
|
||||||
for s_n in self.get_neighbor(s):
|
|
||||||
if self.t[s_n] == 'NEW' or \
|
|
||||||
(self.PARENT[s_n] == s and self.h[s_n] != self.h[s] + self.cost(s, s_n)):
|
|
||||||
|
|
||||||
# Condition:
|
|
||||||
# 1) t[s_n] == 'NEW': not visited
|
|
||||||
# 2) s_n's parent: cost reduction
|
|
||||||
self.PARENT[s_n] = s
|
|
||||||
self.insert(s_n, self.h[s] + self.cost(s, s_n))
|
|
||||||
else:
|
|
||||||
if self.PARENT[s_n] != s and \
|
|
||||||
self.h[s_n] > self.h[s] + self.cost(s, s_n):
|
|
||||||
|
|
||||||
# Condition: LOWER happened in OPEN set (s), s should be explored again
|
|
||||||
self.insert(s, self.h[s])
|
|
||||||
else:
|
|
||||||
if self.PARENT[s_n] != s and \
|
|
||||||
self.h[s] > self.h[s_n] + self.cost(s_n, s) and \
|
|
||||||
self.t[s_n] == 'CLOSED' and \
|
|
||||||
self.h[s_n] > k_old:
|
|
||||||
|
|
||||||
# Condition: LOWER happened in CLOSED set (s_n), s_n should be explored again
|
|
||||||
self.insert(s_n, self.h[s_n])
|
|
||||||
|
|
||||||
return self.get_k_min()
|
|
||||||
|
|
||||||
def min_state(self):
|
|
||||||
"""
|
|
||||||
choose the node with the minimum k value in OPEN set.
|
|
||||||
:return: state
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not self.OPEN:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return min(self.OPEN, key=lambda x: self.k[x])
|
|
||||||
|
|
||||||
def get_k_min(self):
|
|
||||||
"""
|
|
||||||
calc the min k value for nodes in OPEN set.
|
|
||||||
:return: k value
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not self.OPEN:
|
|
||||||
return -1
|
|
||||||
|
|
||||||
return min([self.k[x] for x in self.OPEN])
|
|
||||||
|
|
||||||
def insert(self, s, h_new):
|
|
||||||
"""
|
|
||||||
insert node into OPEN set.
|
|
||||||
:param s: node
|
|
||||||
:param h_new: new or better cost to come value
|
|
||||||
"""
|
|
||||||
|
|
||||||
if self.t[s] == 'NEW':
|
|
||||||
self.k[s] = h_new
|
|
||||||
elif self.t[s] == 'OPEN':
|
|
||||||
self.k[s] = min(self.k[s], h_new)
|
|
||||||
elif self.t[s] == 'CLOSED':
|
|
||||||
self.k[s] = min(self.h[s], h_new)
|
|
||||||
|
|
||||||
self.h[s] = h_new
|
|
||||||
self.t[s] = 'OPEN'
|
|
||||||
self.OPEN.add(s)
|
|
||||||
|
|
||||||
def delete(self, s):
|
|
||||||
"""
|
|
||||||
delete: move state s from OPEN set to CLOSED set.
|
|
||||||
:param s: state should be deleted
|
|
||||||
"""
|
|
||||||
|
|
||||||
if self.t[s] == 'OPEN':
|
|
||||||
self.t[s] = 'CLOSED'
|
|
||||||
|
|
||||||
self.OPEN.remove(s)
|
|
||||||
|
|
||||||
def modify(self, s):
|
|
||||||
"""
|
|
||||||
start processing from state s.
|
|
||||||
:param s: is a node whose status is RAISE or LOWER.
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.modify_cost(s)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
k_min = self.process_state()
|
|
||||||
|
|
||||||
if k_min >= self.h[s]:
|
|
||||||
break
|
|
||||||
|
|
||||||
def modify_cost(self, s):
|
|
||||||
# if node in CLOSED set, put it into OPEN set.
|
|
||||||
# Since cost may be changed between s - s.parent, calc cost(s, s.p) again
|
|
||||||
|
|
||||||
if self.t[s] == 'CLOSED':
|
|
||||||
self.insert(s, self.h[self.PARENT[s]] + self.cost(s, self.PARENT[s]))
|
|
||||||
|
|
||||||
def get_neighbor(self, s):
|
|
||||||
nei_list = set()
|
|
||||||
|
|
||||||
for u in self.u_set:
|
|
||||||
s_next = tuple([s[i] + u[i] for i in range(2)])
|
|
||||||
if s_next not in self.obs:
|
|
||||||
nei_list.add(s_next)
|
|
||||||
|
|
||||||
return nei_list
|
|
||||||
|
|
||||||
def cost(self, s_start, s_goal):
|
|
||||||
"""
|
|
||||||
Calculate Cost for this motion
|
|
||||||
:param s_start: starting node
|
|
||||||
:param s_goal: end node
|
|
||||||
:return: Cost for this motion
|
|
||||||
:note: Cost function could be more complicate!
|
|
||||||
"""
|
|
||||||
|
|
||||||
if self.is_collision(s_start, s_goal):
|
|
||||||
return float("inf")
|
|
||||||
|
|
||||||
return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
|
|
||||||
|
|
||||||
def is_collision(self, s_start, s_end):
|
|
||||||
if s_start in self.obs or s_end in self.obs:
|
|
||||||
return True
|
|
||||||
|
|
||||||
if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
|
|
||||||
if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
|
|
||||||
s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
|
|
||||||
s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
|
|
||||||
else:
|
|
||||||
s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
|
|
||||||
s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
|
|
||||||
|
|
||||||
if s1 in self.obs or s2 in self.obs:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def plot_path(self, path):
|
|
||||||
px = [x[0] for x in path]
|
|
||||||
py = [x[1] for x in path]
|
|
||||||
plt.plot(px, py, linewidth=2)
|
|
||||||
plt.plot(self.s_start[0], self.s_start[1], "bs")
|
|
||||||
plt.plot(self.s_goal[0], self.s_goal[1], "gs")
|
|
||||||
|
|
||||||
def plot_visited(self, visited):
|
|
||||||
color = ['gainsboro', 'lightgray', 'silver', 'darkgray',
|
|
||||||
'bisque', 'navajowhite', 'moccasin', 'wheat',
|
|
||||||
'powderblue', 'skyblue', 'lightskyblue', 'cornflowerblue']
|
|
||||||
|
|
||||||
if self.count >= len(color) - 1:
|
|
||||||
self.count = 0
|
|
||||||
|
|
||||||
for x in visited:
|
|
||||||
plt.plot(x[0], x[1], marker='s', color=color[self.count])
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
s_start = (5, 5)
|
|
||||||
s_goal = (45, 25)
|
|
||||||
dstar = DStar(s_start, s_goal)
|
|
||||||
dstar.run(s_start, s_goal)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
@ -1,239 +0,0 @@
|
|||||||
"""
|
|
||||||
D_star_Lite 2D
|
|
||||||
@author: huiming zhou
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import math
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
|
|
||||||
"/../../Search_based_Planning/")
|
|
||||||
|
|
||||||
from Search_2D import plotting, env
|
|
||||||
|
|
||||||
|
|
||||||
class DStar:
|
|
||||||
def __init__(self, s_start, s_goal, heuristic_type):
|
|
||||||
self.s_start, self.s_goal = s_start, s_goal
|
|
||||||
self.heuristic_type = heuristic_type
|
|
||||||
|
|
||||||
self.Env = env.Env() # class Env
|
|
||||||
self.Plot = plotting.Plotting(s_start, s_goal)
|
|
||||||
|
|
||||||
self.u_set = self.Env.motions # feasible input set
|
|
||||||
self.obs = self.Env.obs # position of obstacles
|
|
||||||
self.x = self.Env.x_range
|
|
||||||
self.y = self.Env.y_range
|
|
||||||
|
|
||||||
self.g, self.rhs, self.U = {}, {}, {}
|
|
||||||
self.km = 0
|
|
||||||
|
|
||||||
for i in range(1, self.Env.x_range - 1):
|
|
||||||
for j in range(1, self.Env.y_range - 1):
|
|
||||||
self.rhs[(i, j)] = float("inf")
|
|
||||||
self.g[(i, j)] = float("inf")
|
|
||||||
|
|
||||||
self.rhs[self.s_goal] = 0.0
|
|
||||||
self.U[self.s_goal] = self.CalculateKey(self.s_goal)
|
|
||||||
self.visited = set()
|
|
||||||
self.count = 0
|
|
||||||
self.fig = plt.figure()
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
self.Plot.plot_grid("D* Lite")
|
|
||||||
self.ComputePath()
|
|
||||||
self.plot_path(self.extract_path())
|
|
||||||
self.fig.canvas.mpl_connect('button_press_event', self.on_press)
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
def on_press(self, event):
|
|
||||||
x, y = event.xdata, event.ydata
|
|
||||||
if x < 0 or x > self.x - 1 or y < 0 or y > self.y - 1:
|
|
||||||
print("Please choose right area!")
|
|
||||||
else:
|
|
||||||
x, y = int(x), int(y)
|
|
||||||
print("Change position: s =", x, ",", "y =", y)
|
|
||||||
|
|
||||||
s_curr = self.s_start
|
|
||||||
s_last = self.s_start
|
|
||||||
i = 0
|
|
||||||
path = [self.s_start]
|
|
||||||
|
|
||||||
while s_curr != self.s_goal:
|
|
||||||
s_list = {}
|
|
||||||
|
|
||||||
for s in self.get_neighbor(s_curr):
|
|
||||||
s_list[s] = self.g[s] + self.cost(s_curr, s)
|
|
||||||
s_curr = min(s_list, key=s_list.get)
|
|
||||||
path.append(s_curr)
|
|
||||||
|
|
||||||
if i < 1:
|
|
||||||
self.km += self.h(s_last, s_curr)
|
|
||||||
s_last = s_curr
|
|
||||||
if (x, y) not in self.obs:
|
|
||||||
self.obs.add((x, y))
|
|
||||||
plt.plot(x, y, 'sk')
|
|
||||||
self.g[(x, y)] = float("inf")
|
|
||||||
self.rhs[(x, y)] = float("inf")
|
|
||||||
else:
|
|
||||||
self.obs.remove((x, y))
|
|
||||||
plt.plot(x, y, marker='s', color='white')
|
|
||||||
self.UpdateVertex((x, y))
|
|
||||||
for s in self.get_neighbor((x, y)):
|
|
||||||
self.UpdateVertex(s)
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
self.count += 1
|
|
||||||
self.visited = set()
|
|
||||||
self.ComputePath()
|
|
||||||
|
|
||||||
self.plot_visited(self.visited)
|
|
||||||
self.plot_path(path)
|
|
||||||
self.fig.canvas.draw_idle()
|
|
||||||
|
|
||||||
def ComputePath(self):
|
|
||||||
while True:
|
|
||||||
s, v = self.TopKey()
|
|
||||||
if v >= self.CalculateKey(self.s_start) and \
|
|
||||||
self.rhs[self.s_start] == self.g[self.s_start]:
|
|
||||||
break
|
|
||||||
|
|
||||||
k_old = v
|
|
||||||
self.U.pop(s)
|
|
||||||
self.visited.add(s)
|
|
||||||
|
|
||||||
if k_old < self.CalculateKey(s):
|
|
||||||
self.U[s] = self.CalculateKey(s)
|
|
||||||
elif self.g[s] > self.rhs[s]:
|
|
||||||
self.g[s] = self.rhs[s]
|
|
||||||
for x in self.get_neighbor(s):
|
|
||||||
self.UpdateVertex(x)
|
|
||||||
else:
|
|
||||||
self.g[s] = float("inf")
|
|
||||||
self.UpdateVertex(s)
|
|
||||||
for x in self.get_neighbor(s):
|
|
||||||
self.UpdateVertex(x)
|
|
||||||
|
|
||||||
def UpdateVertex(self, s):
|
|
||||||
if s != self.s_goal:
|
|
||||||
self.rhs[s] = float("inf")
|
|
||||||
for x in self.get_neighbor(s):
|
|
||||||
self.rhs[s] = min(self.rhs[s], self.g[x] + self.cost(s, x))
|
|
||||||
if s in self.U:
|
|
||||||
self.U.pop(s)
|
|
||||||
|
|
||||||
if self.g[s] != self.rhs[s]:
|
|
||||||
self.U[s] = self.CalculateKey(s)
|
|
||||||
|
|
||||||
def CalculateKey(self, s):
|
|
||||||
return [min(self.g[s], self.rhs[s]) + self.h(self.s_start, s) + self.km,
|
|
||||||
min(self.g[s], self.rhs[s])]
|
|
||||||
|
|
||||||
def TopKey(self):
|
|
||||||
"""
|
|
||||||
:return: return the min key and its value.
|
|
||||||
"""
|
|
||||||
|
|
||||||
s = min(self.U, key=self.U.get)
|
|
||||||
return s, self.U[s]
|
|
||||||
|
|
||||||
def h(self, s_start, s_goal):
|
|
||||||
heuristic_type = self.heuristic_type # heuristic type
|
|
||||||
|
|
||||||
if heuristic_type == "manhattan":
|
|
||||||
return abs(s_goal[0] - s_start[0]) + abs(s_goal[1] - s_start[1])
|
|
||||||
else:
|
|
||||||
return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
|
|
||||||
|
|
||||||
def cost(self, s_start, s_goal):
|
|
||||||
"""
|
|
||||||
Calculate Cost for this motion
|
|
||||||
:param s_start: starting node
|
|
||||||
:param s_goal: end node
|
|
||||||
:return: Cost for this motion
|
|
||||||
:note: Cost function could be more complicate!
|
|
||||||
"""
|
|
||||||
|
|
||||||
if self.is_collision(s_start, s_goal):
|
|
||||||
return float("inf")
|
|
||||||
|
|
||||||
return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
|
|
||||||
|
|
||||||
def is_collision(self, s_start, s_end):
|
|
||||||
if s_start in self.obs or s_end in self.obs:
|
|
||||||
return True
|
|
||||||
|
|
||||||
if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
|
|
||||||
if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
|
|
||||||
s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
|
|
||||||
s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
|
|
||||||
else:
|
|
||||||
s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
|
|
||||||
s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
|
|
||||||
|
|
||||||
if s1 in self.obs or s2 in self.obs:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_neighbor(self, s):
|
|
||||||
nei_list = set()
|
|
||||||
for u in self.u_set:
|
|
||||||
s_next = tuple([s[i] + u[i] for i in range(2)])
|
|
||||||
if s_next not in self.obs:
|
|
||||||
nei_list.add(s_next)
|
|
||||||
|
|
||||||
return nei_list
|
|
||||||
|
|
||||||
def extract_path(self):
|
|
||||||
"""
|
|
||||||
Extract the path based on the PARENT set.
|
|
||||||
:return: The planning path
|
|
||||||
"""
|
|
||||||
|
|
||||||
path = [self.s_start]
|
|
||||||
s = self.s_start
|
|
||||||
|
|
||||||
for k in range(100):
|
|
||||||
g_list = {}
|
|
||||||
for x in self.get_neighbor(s):
|
|
||||||
if not self.is_collision(s, x):
|
|
||||||
g_list[x] = self.g[x]
|
|
||||||
s = min(g_list, key=g_list.get)
|
|
||||||
path.append(s)
|
|
||||||
if s == self.s_goal:
|
|
||||||
break
|
|
||||||
|
|
||||||
return list(path)
|
|
||||||
|
|
||||||
def plot_path(self, path):
|
|
||||||
px = [x[0] for x in path]
|
|
||||||
py = [x[1] for x in path]
|
|
||||||
plt.plot(px, py, linewidth=2)
|
|
||||||
plt.plot(self.s_start[0], self.s_start[1], "bs")
|
|
||||||
plt.plot(self.s_goal[0], self.s_goal[1], "gs")
|
|
||||||
|
|
||||||
def plot_visited(self, visited):
|
|
||||||
color = ['gainsboro', 'lightgray', 'silver', 'darkgray',
|
|
||||||
'bisque', 'navajowhite', 'moccasin', 'wheat',
|
|
||||||
'powderblue', 'skyblue', 'lightskyblue', 'cornflowerblue']
|
|
||||||
|
|
||||||
if self.count >= len(color) - 1:
|
|
||||||
self.count = 0
|
|
||||||
|
|
||||||
for x in visited:
|
|
||||||
plt.plot(x[0], x[1], marker='s', color=color[self.count])
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
s_start = (5, 5)
|
|
||||||
s_goal = (45, 25)
|
|
||||||
|
|
||||||
dstar = DStar(s_start, s_goal, "euclidean")
|
|
||||||
dstar.run()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
@ -1,69 +0,0 @@
|
|||||||
"""
|
|
||||||
Dijkstra 2D
|
|
||||||
@author: huiming zhou
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import math
|
|
||||||
import heapq
|
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
|
|
||||||
"/../../Search_based_Planning/")
|
|
||||||
|
|
||||||
from Search_2D import plotting, env
|
|
||||||
|
|
||||||
from Search_2D.Astar import AStar
|
|
||||||
|
|
||||||
|
|
||||||
class Dijkstra(AStar):
|
|
||||||
"""Dijkstra set the cost as the priority
|
|
||||||
"""
|
|
||||||
def searching(self):
|
|
||||||
"""
|
|
||||||
Breadth-first Searching.
|
|
||||||
:return: path, visited order
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.PARENT[self.s_start] = self.s_start
|
|
||||||
self.g[self.s_start] = 0
|
|
||||||
self.g[self.s_goal] = math.inf
|
|
||||||
heapq.heappush(self.OPEN,
|
|
||||||
(0, self.s_start))
|
|
||||||
|
|
||||||
while self.OPEN:
|
|
||||||
_, s = heapq.heappop(self.OPEN)
|
|
||||||
self.CLOSED.append(s)
|
|
||||||
|
|
||||||
if s == self.s_goal:
|
|
||||||
break
|
|
||||||
|
|
||||||
for s_n in self.get_neighbor(s):
|
|
||||||
new_cost = self.g[s] + self.cost(s, s_n)
|
|
||||||
|
|
||||||
if s_n not in self.g:
|
|
||||||
self.g[s_n] = math.inf
|
|
||||||
|
|
||||||
if new_cost < self.g[s_n]: # conditions for updating Cost
|
|
||||||
self.g[s_n] = new_cost
|
|
||||||
self.PARENT[s_n] = s
|
|
||||||
|
|
||||||
# best first set the heuristics as the priority
|
|
||||||
heapq.heappush(self.OPEN, (new_cost, s_n))
|
|
||||||
|
|
||||||
return self.extract_path(self.PARENT), self.CLOSED
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
s_start = (5, 5)
|
|
||||||
s_goal = (45, 25)
|
|
||||||
|
|
||||||
dijkstra = Dijkstra(s_start, s_goal, 'None')
|
|
||||||
plot = plotting.Plotting(s_start, s_goal)
|
|
||||||
|
|
||||||
path, visited = dijkstra.searching()
|
|
||||||
plot.animation(path, visited, "Dijkstra's") # animation generate
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
@ -1,256 +0,0 @@
|
|||||||
"""
|
|
||||||
LPA_star 2D
|
|
||||||
@author: huiming zhou
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import math
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
|
|
||||||
"/../../Search_based_Planning/")
|
|
||||||
|
|
||||||
from Search_2D import plotting, env
|
|
||||||
|
|
||||||
|
|
||||||
class LPAStar:
|
|
||||||
def __init__(self, s_start, s_goal, heuristic_type):
|
|
||||||
self.s_start, self.s_goal = s_start, s_goal
|
|
||||||
self.heuristic_type = heuristic_type
|
|
||||||
|
|
||||||
self.Env = env.Env()
|
|
||||||
self.Plot = plotting.Plotting(self.s_start, self.s_goal)
|
|
||||||
|
|
||||||
self.u_set = self.Env.motions
|
|
||||||
self.obs = self.Env.obs
|
|
||||||
self.x = self.Env.x_range
|
|
||||||
self.y = self.Env.y_range
|
|
||||||
|
|
||||||
self.g, self.rhs, self.U = {}, {}, {}
|
|
||||||
|
|
||||||
for i in range(self.Env.x_range):
|
|
||||||
for j in range(self.Env.y_range):
|
|
||||||
self.rhs[(i, j)] = float("inf")
|
|
||||||
self.g[(i, j)] = float("inf")
|
|
||||||
|
|
||||||
self.rhs[self.s_start] = 0
|
|
||||||
self.U[self.s_start] = self.CalculateKey(self.s_start)
|
|
||||||
self.visited = set()
|
|
||||||
self.count = 0
|
|
||||||
|
|
||||||
self.fig = plt.figure()
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
self.Plot.plot_grid("Lifelong Planning A*")
|
|
||||||
|
|
||||||
self.ComputeShortestPath()
|
|
||||||
self.plot_path(self.extract_path())
|
|
||||||
self.fig.canvas.mpl_connect('button_press_event', self.on_press)
|
|
||||||
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
def on_press(self, event):
|
|
||||||
x, y = event.xdata, event.ydata
|
|
||||||
if x < 0 or x > self.x - 1 or y < 0 or y > self.y - 1:
|
|
||||||
print("Please choose right area!")
|
|
||||||
else:
|
|
||||||
x, y = int(x), int(y)
|
|
||||||
print("Change position: s =", x, ",", "y =", y)
|
|
||||||
|
|
||||||
self.visited = set()
|
|
||||||
self.count += 1
|
|
||||||
|
|
||||||
if (x, y) not in self.obs:
|
|
||||||
self.obs.add((x, y))
|
|
||||||
else:
|
|
||||||
self.obs.remove((x, y))
|
|
||||||
self.UpdateVertex((x, y))
|
|
||||||
|
|
||||||
self.Plot.update_obs(self.obs)
|
|
||||||
|
|
||||||
for s_n in self.get_neighbor((x, y)):
|
|
||||||
self.UpdateVertex(s_n)
|
|
||||||
|
|
||||||
self.ComputeShortestPath()
|
|
||||||
|
|
||||||
plt.cla()
|
|
||||||
self.Plot.plot_grid("Lifelong Planning A*")
|
|
||||||
self.plot_visited(self.visited)
|
|
||||||
self.plot_path(self.extract_path())
|
|
||||||
self.fig.canvas.draw_idle()
|
|
||||||
|
|
||||||
def ComputeShortestPath(self):
|
|
||||||
while True:
|
|
||||||
s, v = self.TopKey()
|
|
||||||
|
|
||||||
if v >= self.CalculateKey(self.s_goal) and \
|
|
||||||
self.rhs[self.s_goal] == self.g[self.s_goal]:
|
|
||||||
break
|
|
||||||
|
|
||||||
self.U.pop(s)
|
|
||||||
self.visited.add(s)
|
|
||||||
|
|
||||||
if self.g[s] > self.rhs[s]:
|
|
||||||
|
|
||||||
# Condition: over-consistent (eg: deleted obstacles)
|
|
||||||
# So, rhs[s] decreased -- > rhs[s] < g[s]
|
|
||||||
self.g[s] = self.rhs[s]
|
|
||||||
else:
|
|
||||||
|
|
||||||
# Condition: # under-consistent (eg: added obstacles)
|
|
||||||
# So, rhs[s] increased --> rhs[s] > g[s]
|
|
||||||
self.g[s] = float("inf")
|
|
||||||
self.UpdateVertex(s)
|
|
||||||
|
|
||||||
for s_n in self.get_neighbor(s):
|
|
||||||
self.UpdateVertex(s_n)
|
|
||||||
|
|
||||||
def UpdateVertex(self, s):
|
|
||||||
"""
|
|
||||||
update the status and the current cost to come of state s.
|
|
||||||
:param s: state s
|
|
||||||
"""
|
|
||||||
|
|
||||||
if s != self.s_start:
|
|
||||||
|
|
||||||
# Condition: cost of parent of s changed
|
|
||||||
# Since we do not record the children of a state, we need to enumerate its neighbors
|
|
||||||
self.rhs[s] = min(self.g[s_n] + self.cost(s_n, s)
|
|
||||||
for s_n in self.get_neighbor(s))
|
|
||||||
|
|
||||||
if s in self.U:
|
|
||||||
self.U.pop(s)
|
|
||||||
|
|
||||||
if self.g[s] != self.rhs[s]:
|
|
||||||
|
|
||||||
# Condition: current cost to come is different to that of last time
|
|
||||||
# state s should be added into OPEN set (set U)
|
|
||||||
self.U[s] = self.CalculateKey(s)
|
|
||||||
|
|
||||||
def TopKey(self):
|
|
||||||
"""
|
|
||||||
:return: return the min key and its value.
|
|
||||||
"""
|
|
||||||
|
|
||||||
s = min(self.U, key=self.U.get)
|
|
||||||
|
|
||||||
return s, self.U[s]
|
|
||||||
|
|
||||||
def CalculateKey(self, s):
|
|
||||||
|
|
||||||
return [min(self.g[s], self.rhs[s]) + self.h(s),
|
|
||||||
min(self.g[s], self.rhs[s])]
|
|
||||||
|
|
||||||
def get_neighbor(self, s):
|
|
||||||
"""
|
|
||||||
find neighbors of state s that not in obstacles.
|
|
||||||
:param s: state
|
|
||||||
:return: neighbors
|
|
||||||
"""
|
|
||||||
|
|
||||||
s_list = set()
|
|
||||||
|
|
||||||
for u in self.u_set:
|
|
||||||
s_next = tuple([s[i] + u[i] for i in range(2)])
|
|
||||||
if s_next not in self.obs:
|
|
||||||
s_list.add(s_next)
|
|
||||||
|
|
||||||
return s_list
|
|
||||||
|
|
||||||
def h(self, s):
|
|
||||||
"""
|
|
||||||
Calculate heuristic.
|
|
||||||
:param s: current node (state)
|
|
||||||
:return: heuristic function value
|
|
||||||
"""
|
|
||||||
|
|
||||||
heuristic_type = self.heuristic_type # heuristic type
|
|
||||||
goal = self.s_goal # goal node
|
|
||||||
|
|
||||||
if heuristic_type == "manhattan":
|
|
||||||
return abs(goal[0] - s[0]) + abs(goal[1] - s[1])
|
|
||||||
else:
|
|
||||||
return math.hypot(goal[0] - s[0], goal[1] - s[1])
|
|
||||||
|
|
||||||
def cost(self, s_start, s_goal):
|
|
||||||
"""
|
|
||||||
Calculate Cost for this motion
|
|
||||||
:param s_start: starting node
|
|
||||||
:param s_goal: end node
|
|
||||||
:return: Cost for this motion
|
|
||||||
:note: Cost function could be more complicate!
|
|
||||||
"""
|
|
||||||
|
|
||||||
if self.is_collision(s_start, s_goal):
|
|
||||||
return float("inf")
|
|
||||||
|
|
||||||
return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
|
|
||||||
|
|
||||||
def is_collision(self, s_start, s_end):
|
|
||||||
if s_start in self.obs or s_end in self.obs:
|
|
||||||
return True
|
|
||||||
|
|
||||||
if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
|
|
||||||
if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
|
|
||||||
s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
|
|
||||||
s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
|
|
||||||
else:
|
|
||||||
s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
|
|
||||||
s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
|
|
||||||
|
|
||||||
if s1 in self.obs or s2 in self.obs:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def extract_path(self):
|
|
||||||
"""
|
|
||||||
Extract the path based on the PARENT set.
|
|
||||||
:return: The planning path
|
|
||||||
"""
|
|
||||||
|
|
||||||
path = [self.s_goal]
|
|
||||||
s = self.s_goal
|
|
||||||
|
|
||||||
for k in range(100):
|
|
||||||
g_list = {}
|
|
||||||
for x in self.get_neighbor(s):
|
|
||||||
if not self.is_collision(s, x):
|
|
||||||
g_list[x] = self.g[x]
|
|
||||||
s = min(g_list, key=g_list.get)
|
|
||||||
path.append(s)
|
|
||||||
if s == self.s_start:
|
|
||||||
break
|
|
||||||
|
|
||||||
return list(reversed(path))
|
|
||||||
|
|
||||||
def plot_path(self, path):
|
|
||||||
px = [x[0] for x in path]
|
|
||||||
py = [x[1] for x in path]
|
|
||||||
plt.plot(px, py, linewidth=2)
|
|
||||||
plt.plot(self.s_start[0], self.s_start[1], "bs")
|
|
||||||
plt.plot(self.s_goal[0], self.s_goal[1], "gs")
|
|
||||||
|
|
||||||
def plot_visited(self, visited):
|
|
||||||
color = ['gainsboro', 'lightgray', 'silver', 'darkgray',
|
|
||||||
'bisque', 'navajowhite', 'moccasin', 'wheat',
|
|
||||||
'powderblue', 'skyblue', 'lightskyblue', 'cornflowerblue']
|
|
||||||
|
|
||||||
if self.count >= len(color) - 1:
|
|
||||||
self.count = 0
|
|
||||||
|
|
||||||
for x in visited:
|
|
||||||
plt.plot(x[0], x[1], marker='s', color=color[self.count])
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
x_start = (5, 5)
|
|
||||||
x_goal = (45, 25)
|
|
||||||
|
|
||||||
lpastar = LPAStar(x_start, x_goal, "Euclidean")
|
|
||||||
lpastar.run()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
@ -1,230 +0,0 @@
|
|||||||
"""
|
|
||||||
LRTA_star 2D (Learning Real-time A*)
|
|
||||||
@author: huiming zhou
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import copy
|
|
||||||
import math
|
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
|
|
||||||
"/../../Search_based_Planning/")
|
|
||||||
|
|
||||||
from Search_2D import queue, plotting, env
|
|
||||||
|
|
||||||
|
|
||||||
class LrtAStarN:
|
|
||||||
def __init__(self, s_start, s_goal, N, heuristic_type):
|
|
||||||
self.s_start, self.s_goal = s_start, s_goal
|
|
||||||
self.heuristic_type = heuristic_type
|
|
||||||
|
|
||||||
self.Env = env.Env()
|
|
||||||
|
|
||||||
self.u_set = self.Env.motions # feasible input set
|
|
||||||
self.obs = self.Env.obs # position of obstacles
|
|
||||||
|
|
||||||
self.N = N # number of expand nodes each iteration
|
|
||||||
self.visited = [] # order of visited nodes in planning
|
|
||||||
self.path = [] # path of each iteration
|
|
||||||
self.h_table = {} # h_value table
|
|
||||||
|
|
||||||
def init(self):
|
|
||||||
"""
|
|
||||||
initialize the h_value of all nodes in the environment.
|
|
||||||
it is a global table.
|
|
||||||
"""
|
|
||||||
|
|
||||||
for i in range(self.Env.x_range):
|
|
||||||
for j in range(self.Env.y_range):
|
|
||||||
self.h_table[(i, j)] = self.h((i, j))
|
|
||||||
|
|
||||||
def searching(self):
|
|
||||||
self.init()
|
|
||||||
s_start = self.s_start # initialize start node
|
|
||||||
|
|
||||||
while True:
|
|
||||||
OPEN, CLOSED = self.AStar(s_start, self.N) # OPEN, CLOSED sets in each iteration
|
|
||||||
|
|
||||||
if OPEN == "FOUND": # reach the goal node
|
|
||||||
self.path.append(CLOSED)
|
|
||||||
break
|
|
||||||
|
|
||||||
h_value = self.iteration(CLOSED) # h_value table of CLOSED nodes
|
|
||||||
|
|
||||||
for x in h_value:
|
|
||||||
self.h_table[x] = h_value[x]
|
|
||||||
|
|
||||||
s_start, path_k = self.extract_path_in_CLOSE(s_start, h_value) # x_init -> expected node in OPEN set
|
|
||||||
self.path.append(path_k)
|
|
||||||
|
|
||||||
def extract_path_in_CLOSE(self, s_start, h_value):
|
|
||||||
path = [s_start]
|
|
||||||
s = s_start
|
|
||||||
|
|
||||||
while True:
|
|
||||||
h_list = {}
|
|
||||||
|
|
||||||
for s_n in self.get_neighbor(s):
|
|
||||||
if s_n in h_value:
|
|
||||||
h_list[s_n] = h_value[s_n]
|
|
||||||
else:
|
|
||||||
h_list[s_n] = self.h_table[s_n]
|
|
||||||
|
|
||||||
s_key = min(h_list, key=h_list.get) # move to the smallest node with min h_value
|
|
||||||
path.append(s_key) # generate path
|
|
||||||
s = s_key # use end of this iteration as the start of next
|
|
||||||
|
|
||||||
if s_key not in h_value: # reach the expected node in OPEN set
|
|
||||||
return s_key, path
|
|
||||||
|
|
||||||
def iteration(self, CLOSED):
|
|
||||||
h_value = {}
|
|
||||||
|
|
||||||
for s in CLOSED:
|
|
||||||
h_value[s] = float("inf") # initialize h_value of CLOSED nodes
|
|
||||||
|
|
||||||
while True:
|
|
||||||
h_value_rec = copy.deepcopy(h_value)
|
|
||||||
for s in CLOSED:
|
|
||||||
h_list = []
|
|
||||||
for s_n in self.get_neighbor(s):
|
|
||||||
if s_n not in CLOSED:
|
|
||||||
h_list.append(self.cost(s, s_n) + self.h_table[s_n])
|
|
||||||
else:
|
|
||||||
h_list.append(self.cost(s, s_n) + h_value[s_n])
|
|
||||||
h_value[s] = min(h_list) # update h_value of current node
|
|
||||||
|
|
||||||
if h_value == h_value_rec: # h_value table converged
|
|
||||||
return h_value
|
|
||||||
|
|
||||||
def AStar(self, x_start, N):
|
|
||||||
OPEN = queue.QueuePrior() # OPEN set
|
|
||||||
OPEN.put(x_start, self.h(x_start))
|
|
||||||
CLOSED = [] # CLOSED set
|
|
||||||
g_table = {x_start: 0, self.s_goal: float("inf")} # Cost to come
|
|
||||||
PARENT = {x_start: x_start} # relations
|
|
||||||
count = 0 # counter
|
|
||||||
|
|
||||||
while not OPEN.empty():
|
|
||||||
count += 1
|
|
||||||
s = OPEN.get()
|
|
||||||
CLOSED.append(s)
|
|
||||||
|
|
||||||
if s == self.s_goal: # reach the goal node
|
|
||||||
self.visited.append(CLOSED)
|
|
||||||
return "FOUND", self.extract_path(x_start, PARENT)
|
|
||||||
|
|
||||||
for s_n in self.get_neighbor(s):
|
|
||||||
if s_n not in CLOSED:
|
|
||||||
new_cost = g_table[s] + self.cost(s, s_n)
|
|
||||||
if s_n not in g_table:
|
|
||||||
g_table[s_n] = float("inf")
|
|
||||||
if new_cost < g_table[s_n]: # conditions for updating Cost
|
|
||||||
g_table[s_n] = new_cost
|
|
||||||
PARENT[s_n] = s
|
|
||||||
OPEN.put(s_n, g_table[s_n] + self.h_table[s_n])
|
|
||||||
|
|
||||||
if count == N: # expand needed CLOSED nodes
|
|
||||||
break
|
|
||||||
|
|
||||||
self.visited.append(CLOSED) # visited nodes in each iteration
|
|
||||||
|
|
||||||
return OPEN, CLOSED
|
|
||||||
|
|
||||||
def get_neighbor(self, s):
|
|
||||||
"""
|
|
||||||
find neighbors of state s that not in obstacles.
|
|
||||||
:param s: state
|
|
||||||
:return: neighbors
|
|
||||||
"""
|
|
||||||
|
|
||||||
s_list = []
|
|
||||||
|
|
||||||
for u in self.u_set:
|
|
||||||
s_next = tuple([s[i] + u[i] for i in range(2)])
|
|
||||||
if s_next not in self.obs:
|
|
||||||
s_list.append(s_next)
|
|
||||||
|
|
||||||
return s_list
|
|
||||||
|
|
||||||
def extract_path(self, x_start, parent):
|
|
||||||
"""
|
|
||||||
Extract the path based on the relationship of nodes.
|
|
||||||
|
|
||||||
:return: The planning path
|
|
||||||
"""
|
|
||||||
|
|
||||||
path_back = [self.s_goal]
|
|
||||||
x_current = self.s_goal
|
|
||||||
|
|
||||||
while True:
|
|
||||||
x_current = parent[x_current]
|
|
||||||
path_back.append(x_current)
|
|
||||||
|
|
||||||
if x_current == x_start:
|
|
||||||
break
|
|
||||||
|
|
||||||
return list(reversed(path_back))
|
|
||||||
|
|
||||||
def h(self, s):
|
|
||||||
"""
|
|
||||||
Calculate heuristic.
|
|
||||||
:param s: current node (state)
|
|
||||||
:return: heuristic function value
|
|
||||||
"""
|
|
||||||
|
|
||||||
heuristic_type = self.heuristic_type # heuristic type
|
|
||||||
goal = self.s_goal # goal node
|
|
||||||
|
|
||||||
if heuristic_type == "manhattan":
|
|
||||||
return abs(goal[0] - s[0]) + abs(goal[1] - s[1])
|
|
||||||
else:
|
|
||||||
return math.hypot(goal[0] - s[0], goal[1] - s[1])
|
|
||||||
|
|
||||||
def cost(self, s_start, s_goal):
|
|
||||||
"""
|
|
||||||
Calculate Cost for this motion
|
|
||||||
:param s_start: starting node
|
|
||||||
:param s_goal: end node
|
|
||||||
:return: Cost for this motion
|
|
||||||
:note: Cost function could be more complicate!
|
|
||||||
"""
|
|
||||||
|
|
||||||
if self.is_collision(s_start, s_goal):
|
|
||||||
return float("inf")
|
|
||||||
|
|
||||||
return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
|
|
||||||
|
|
||||||
def is_collision(self, s_start, s_end):
|
|
||||||
if s_start in self.obs or s_end in self.obs:
|
|
||||||
return True
|
|
||||||
|
|
||||||
if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
|
|
||||||
if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
|
|
||||||
s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
|
|
||||||
s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
|
|
||||||
else:
|
|
||||||
s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
|
|
||||||
s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
|
|
||||||
|
|
||||||
if s1 in self.obs or s2 in self.obs:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
s_start = (10, 5)
|
|
||||||
s_goal = (45, 25)
|
|
||||||
|
|
||||||
lrta = LrtAStarN(s_start, s_goal, 250, "euclidean")
|
|
||||||
plot = plotting.Plotting(s_start, s_goal)
|
|
||||||
|
|
||||||
lrta.searching()
|
|
||||||
plot.animation_lrta(lrta.path, lrta.visited,
|
|
||||||
"Learning Real-time A* (LRTA*)")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
@ -1,237 +0,0 @@
|
|||||||
"""
|
|
||||||
RTAAstar 2D (Real-time Adaptive A*)
|
|
||||||
@author: huiming zhou
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import copy
|
|
||||||
import math
|
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
|
|
||||||
"/../../Search_based_Planning/")
|
|
||||||
|
|
||||||
from Search_2D import queue, plotting, env
|
|
||||||
|
|
||||||
|
|
||||||
class RTAAStar:
|
|
||||||
def __init__(self, s_start, s_goal, N, heuristic_type):
|
|
||||||
self.s_start, self.s_goal = s_start, s_goal
|
|
||||||
self.heuristic_type = heuristic_type
|
|
||||||
|
|
||||||
self.Env = env.Env()
|
|
||||||
|
|
||||||
self.u_set = self.Env.motions # feasible input set
|
|
||||||
self.obs = self.Env.obs # position of obstacles
|
|
||||||
|
|
||||||
self.N = N # number of expand nodes each iteration
|
|
||||||
self.visited = [] # order of visited nodes in planning
|
|
||||||
self.path = [] # path of each iteration
|
|
||||||
self.h_table = {} # h_value table
|
|
||||||
|
|
||||||
def init(self):
|
|
||||||
"""
|
|
||||||
initialize the h_value of all nodes in the environment.
|
|
||||||
it is a global table.
|
|
||||||
"""
|
|
||||||
|
|
||||||
for i in range(self.Env.x_range):
|
|
||||||
for j in range(self.Env.y_range):
|
|
||||||
self.h_table[(i, j)] = self.h((i, j))
|
|
||||||
|
|
||||||
def searching(self):
|
|
||||||
self.init()
|
|
||||||
s_start = self.s_start # initialize start node
|
|
||||||
|
|
||||||
while True:
|
|
||||||
OPEN, CLOSED, g_table, PARENT = \
|
|
||||||
self.Astar(s_start, self.N)
|
|
||||||
|
|
||||||
if OPEN == "FOUND": # reach the goal node
|
|
||||||
self.path.append(CLOSED)
|
|
||||||
break
|
|
||||||
|
|
||||||
s_next, h_value = self.cal_h_value(OPEN, CLOSED, g_table, PARENT)
|
|
||||||
|
|
||||||
for x in h_value:
|
|
||||||
self.h_table[x] = h_value[x]
|
|
||||||
|
|
||||||
s_start, path_k = self.extract_path_in_CLOSE(s_start, s_next, h_value)
|
|
||||||
self.path.append(path_k)
|
|
||||||
|
|
||||||
def cal_h_value(self, OPEN, CLOSED, g_table, PARENT):
|
|
||||||
v_open = {}
|
|
||||||
h_value = {}
|
|
||||||
for (_, x) in OPEN.enumerate():
|
|
||||||
v_open[x] = g_table[PARENT[x]] + 1 + self.h_table[x]
|
|
||||||
s_open = min(v_open, key=v_open.get)
|
|
||||||
f_min = v_open[s_open]
|
|
||||||
for x in CLOSED:
|
|
||||||
h_value[x] = f_min - g_table[x]
|
|
||||||
|
|
||||||
return s_open, h_value
|
|
||||||
|
|
||||||
def iteration(self, CLOSED):
|
|
||||||
h_value = {}
|
|
||||||
|
|
||||||
for s in CLOSED:
|
|
||||||
h_value[s] = float("inf") # initialize h_value of CLOSED nodes
|
|
||||||
|
|
||||||
while True:
|
|
||||||
h_value_rec = copy.deepcopy(h_value)
|
|
||||||
for s in CLOSED:
|
|
||||||
h_list = []
|
|
||||||
for s_n in self.get_neighbor(s):
|
|
||||||
if s_n not in CLOSED:
|
|
||||||
h_list.append(self.cost(s, s_n) + self.h_table[s_n])
|
|
||||||
else:
|
|
||||||
h_list.append(self.cost(s, s_n) + h_value[s_n])
|
|
||||||
h_value[s] = min(h_list) # update h_value of current node
|
|
||||||
|
|
||||||
if h_value == h_value_rec: # h_value table converged
|
|
||||||
return h_value
|
|
||||||
|
|
||||||
def Astar(self, x_start, N):
|
|
||||||
OPEN = queue.QueuePrior() # OPEN set
|
|
||||||
OPEN.put(x_start, self.h_table[x_start])
|
|
||||||
CLOSED = [] # CLOSED set
|
|
||||||
g_table = {x_start: 0, self.s_goal: float("inf")} # Cost to come
|
|
||||||
PARENT = {x_start: x_start} # relations
|
|
||||||
count = 0 # counter
|
|
||||||
|
|
||||||
while not OPEN.empty():
|
|
||||||
count += 1
|
|
||||||
s = OPEN.get()
|
|
||||||
CLOSED.append(s)
|
|
||||||
|
|
||||||
if s == self.s_goal: # reach the goal node
|
|
||||||
self.visited.append(CLOSED)
|
|
||||||
return "FOUND", self.extract_path(x_start, PARENT), [], []
|
|
||||||
|
|
||||||
for s_n in self.get_neighbor(s):
|
|
||||||
if s_n not in CLOSED:
|
|
||||||
new_cost = g_table[s] + self.cost(s, s_n)
|
|
||||||
if s_n not in g_table:
|
|
||||||
g_table[s_n] = float("inf")
|
|
||||||
if new_cost < g_table[s_n]: # conditions for updating Cost
|
|
||||||
g_table[s_n] = new_cost
|
|
||||||
PARENT[s_n] = s
|
|
||||||
OPEN.put(s_n, g_table[s_n] + self.h_table[s_n])
|
|
||||||
|
|
||||||
if count == N: # expand needed CLOSED nodes
|
|
||||||
break
|
|
||||||
|
|
||||||
self.visited.append(CLOSED) # visited nodes in each iteration
|
|
||||||
|
|
||||||
return OPEN, CLOSED, g_table, PARENT
|
|
||||||
|
|
||||||
def get_neighbor(self, s):
|
|
||||||
"""
|
|
||||||
find neighbors of state s that not in obstacles.
|
|
||||||
:param s: state
|
|
||||||
:return: neighbors
|
|
||||||
"""
|
|
||||||
|
|
||||||
s_list = set()
|
|
||||||
|
|
||||||
for u in self.u_set:
|
|
||||||
s_next = tuple([s[i] + u[i] for i in range(2)])
|
|
||||||
if s_next not in self.obs:
|
|
||||||
s_list.add(s_next)
|
|
||||||
|
|
||||||
return s_list
|
|
||||||
|
|
||||||
def extract_path_in_CLOSE(self, s_end, s_start, h_value):
|
|
||||||
path = [s_start]
|
|
||||||
s = s_start
|
|
||||||
|
|
||||||
while True:
|
|
||||||
h_list = {}
|
|
||||||
for s_n in self.get_neighbor(s):
|
|
||||||
if s_n in h_value:
|
|
||||||
h_list[s_n] = h_value[s_n]
|
|
||||||
s_key = max(h_list, key=h_list.get) # move to the smallest node with min h_value
|
|
||||||
path.append(s_key) # generate path
|
|
||||||
s = s_key # use end of this iteration as the start of next
|
|
||||||
|
|
||||||
if s_key == s_end: # reach the expected node in OPEN set
|
|
||||||
return s_start, list(reversed(path))
|
|
||||||
|
|
||||||
def extract_path(self, x_start, parent):
|
|
||||||
"""
|
|
||||||
Extract the path based on the relationship of nodes.
|
|
||||||
:return: The planning path
|
|
||||||
"""
|
|
||||||
|
|
||||||
path = [self.s_goal]
|
|
||||||
s = self.s_goal
|
|
||||||
|
|
||||||
while True:
|
|
||||||
s = parent[s]
|
|
||||||
path.append(s)
|
|
||||||
if s == x_start:
|
|
||||||
break
|
|
||||||
|
|
||||||
return list(reversed(path))
|
|
||||||
|
|
||||||
def h(self, s):
|
|
||||||
"""
|
|
||||||
Calculate heuristic.
|
|
||||||
:param s: current node (state)
|
|
||||||
:return: heuristic function value
|
|
||||||
"""
|
|
||||||
|
|
||||||
heuristic_type = self.heuristic_type # heuristic type
|
|
||||||
goal = self.s_goal # goal node
|
|
||||||
|
|
||||||
if heuristic_type == "manhattan":
|
|
||||||
return abs(goal[0] - s[0]) + abs(goal[1] - s[1])
|
|
||||||
else:
|
|
||||||
return math.hypot(goal[0] - s[0], goal[1] - s[1])
|
|
||||||
|
|
||||||
def cost(self, s_start, s_goal):
|
|
||||||
"""
|
|
||||||
Calculate Cost for this motion
|
|
||||||
:param s_start: starting node
|
|
||||||
:param s_goal: end node
|
|
||||||
:return: Cost for this motion
|
|
||||||
:note: Cost function could be more complicate!
|
|
||||||
"""
|
|
||||||
|
|
||||||
if self.is_collision(s_start, s_goal):
|
|
||||||
return float("inf")
|
|
||||||
|
|
||||||
return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
|
|
||||||
|
|
||||||
def is_collision(self, s_start, s_end):
|
|
||||||
if s_start in self.obs or s_end in self.obs:
|
|
||||||
return True
|
|
||||||
|
|
||||||
if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
|
|
||||||
if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
|
|
||||||
s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
|
|
||||||
s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
|
|
||||||
else:
|
|
||||||
s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
|
|
||||||
s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
|
|
||||||
|
|
||||||
if s1 in self.obs or s2 in self.obs:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
s_start = (10, 5)
|
|
||||||
s_goal = (45, 25)
|
|
||||||
|
|
||||||
rtaa = RTAAStar(s_start, s_goal, 240, "euclidean")
|
|
||||||
plot = plotting.Plotting(s_start, s_goal)
|
|
||||||
|
|
||||||
rtaa.searching()
|
|
||||||
plot.animation_lrta(rtaa.path, rtaa.visited,
|
|
||||||
"Real-time Adaptive A* (RTAA*)")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,69 +0,0 @@
|
|||||||
"""
|
|
||||||
Breadth-first Searching_2D (BFS)
|
|
||||||
@author: huiming zhou
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from collections import deque
|
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
|
|
||||||
"/../../Search_based_Planning/")
|
|
||||||
|
|
||||||
from Search_2D import plotting, env
|
|
||||||
from Search_2D.Astar import AStar
|
|
||||||
import math
|
|
||||||
import heapq
|
|
||||||
|
|
||||||
class BFS(AStar):
|
|
||||||
"""BFS add the new visited node in the end of the openset
|
|
||||||
"""
|
|
||||||
def searching(self):
|
|
||||||
"""
|
|
||||||
Breadth-first Searching.
|
|
||||||
:return: path, visited order
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.PARENT[self.s_start] = self.s_start
|
|
||||||
self.g[self.s_start] = 0
|
|
||||||
self.g[self.s_goal] = math.inf
|
|
||||||
heapq.heappush(self.OPEN,
|
|
||||||
(0, self.s_start))
|
|
||||||
|
|
||||||
while self.OPEN:
|
|
||||||
_, s = heapq.heappop(self.OPEN)
|
|
||||||
self.CLOSED.append(s)
|
|
||||||
|
|
||||||
if s == self.s_goal:
|
|
||||||
break
|
|
||||||
|
|
||||||
for s_n in self.get_neighbor(s):
|
|
||||||
new_cost = self.g[s] + self.cost(s, s_n)
|
|
||||||
|
|
||||||
if s_n not in self.g:
|
|
||||||
self.g[s_n] = math.inf
|
|
||||||
|
|
||||||
if new_cost < self.g[s_n]: # conditions for updating Cost
|
|
||||||
self.g[s_n] = new_cost
|
|
||||||
self.PARENT[s_n] = s
|
|
||||||
|
|
||||||
# bfs, add new node to the end of the openset
|
|
||||||
prior = self.OPEN[-1][0]+1 if len(self.OPEN)>0 else 0
|
|
||||||
heapq.heappush(self.OPEN, (prior, s_n))
|
|
||||||
|
|
||||||
return self.extract_path(self.PARENT), self.CLOSED
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
s_start = (5, 5)
|
|
||||||
s_goal = (45, 25)
|
|
||||||
|
|
||||||
bfs = BFS(s_start, s_goal, 'None')
|
|
||||||
plot = plotting.Plotting(s_start, s_goal)
|
|
||||||
|
|
||||||
path, visited = bfs.searching()
|
|
||||||
plot.animation(path, visited, "Breadth-first Searching (BFS)")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
@ -1,65 +0,0 @@
|
|||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import math
|
|
||||||
import heapq
|
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
|
|
||||||
"/../../Search_based_Planning/")
|
|
||||||
|
|
||||||
from Search_2D import plotting, env
|
|
||||||
from Search_2D.Astar import AStar
|
|
||||||
|
|
||||||
class DFS(AStar):
|
|
||||||
"""DFS add the new visited node in the front of the openset
|
|
||||||
"""
|
|
||||||
def searching(self):
|
|
||||||
"""
|
|
||||||
Breadth-first Searching.
|
|
||||||
:return: path, visited order
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.PARENT[self.s_start] = self.s_start
|
|
||||||
self.g[self.s_start] = 0
|
|
||||||
self.g[self.s_goal] = math.inf
|
|
||||||
heapq.heappush(self.OPEN,
|
|
||||||
(0, self.s_start))
|
|
||||||
|
|
||||||
while self.OPEN:
|
|
||||||
_, s = heapq.heappop(self.OPEN)
|
|
||||||
self.CLOSED.append(s)
|
|
||||||
|
|
||||||
if s == self.s_goal:
|
|
||||||
break
|
|
||||||
|
|
||||||
for s_n in self.get_neighbor(s):
|
|
||||||
new_cost = self.g[s] + self.cost(s, s_n)
|
|
||||||
|
|
||||||
if s_n not in self.g:
|
|
||||||
self.g[s_n] = math.inf
|
|
||||||
|
|
||||||
if new_cost < self.g[s_n]: # conditions for updating Cost
|
|
||||||
self.g[s_n] = new_cost
|
|
||||||
self.PARENT[s_n] = s
|
|
||||||
|
|
||||||
# dfs, add new node to the front of the openset
|
|
||||||
prior = self.OPEN[0][0]-1 if len(self.OPEN)>0 else 0
|
|
||||||
heapq.heappush(self.OPEN, (prior, s_n))
|
|
||||||
|
|
||||||
return self.extract_path(self.PARENT), self.CLOSED
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
s_start = (5, 5)
|
|
||||||
s_goal = (45, 25)
|
|
||||||
|
|
||||||
dfs = DFS(s_start, s_goal, 'None')
|
|
||||||
plot = plotting.Plotting(s_start, s_goal)
|
|
||||||
|
|
||||||
path, visited = dfs.searching()
|
|
||||||
visited = list(dict.fromkeys(visited))
|
|
||||||
plot.animation(path, visited, "Depth-first Searching (DFS)") # animation
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
@ -1,52 +0,0 @@
|
|||||||
"""
|
|
||||||
Env 2D
|
|
||||||
@author: huiming zhou
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class Env:
|
|
||||||
def __init__(self):
|
|
||||||
self.x_range = 51 # size of background
|
|
||||||
self.y_range = 31
|
|
||||||
self.motions = [(-1, 0), (-1, 1), (0, 1), (1, 1),
|
|
||||||
(1, 0), (1, -1), (0, -1), (-1, -1)]
|
|
||||||
self.obs = self.obs_map()
|
|
||||||
|
|
||||||
def update_obs(self, obs):
|
|
||||||
self.obs = obs
|
|
||||||
|
|
||||||
def obs_map(self):
|
|
||||||
"""
|
|
||||||
Initialize obstacles' positions
|
|
||||||
:return: map of obstacles
|
|
||||||
"""
|
|
||||||
|
|
||||||
x = self.x_range #51
|
|
||||||
y = self.y_range #31
|
|
||||||
obs = set()
|
|
||||||
#画上下边框
|
|
||||||
for i in range(x):
|
|
||||||
obs.add((i, 0))
|
|
||||||
for i in range(x):
|
|
||||||
obs.add((i, y - 1))
|
|
||||||
#画左右边框
|
|
||||||
for i in range(y):
|
|
||||||
obs.add((0, i))
|
|
||||||
for i in range(y):
|
|
||||||
obs.add((x - 1, i))
|
|
||||||
|
|
||||||
for i in range(2, 21):
|
|
||||||
obs.add((i, 15))
|
|
||||||
for i in range(15):
|
|
||||||
obs.add((20, i))
|
|
||||||
|
|
||||||
for i in range(15, 30):
|
|
||||||
obs.add((30, i))
|
|
||||||
for i in range(16):
|
|
||||||
obs.add((40, i))
|
|
||||||
|
|
||||||
return obs
|
|
||||||
|
|
||||||
# if __name__ == '__main__':
|
|
||||||
# a = Env()
|
|
||||||
# print(a.obs)
|
|
@ -1,165 +0,0 @@
|
|||||||
"""
|
|
||||||
Plot tools 2D
|
|
||||||
@author: huiming zhou
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
|
|
||||||
"/../../Search_based_Planning/")
|
|
||||||
|
|
||||||
from Search_2D import env
|
|
||||||
|
|
||||||
|
|
||||||
class Plotting:
|
|
||||||
def __init__(self, xI, xG):
|
|
||||||
self.xI, self.xG = xI, xG
|
|
||||||
self.env = env.Env()
|
|
||||||
self.obs = self.env.obs_map()
|
|
||||||
|
|
||||||
def update_obs(self, obs):
|
|
||||||
self.obs = obs
|
|
||||||
|
|
||||||
def animation(self, path, visited, name):
|
|
||||||
self.plot_grid(name)
|
|
||||||
self.plot_visited(visited)
|
|
||||||
self.plot_path(path)
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
def animation_lrta(self, path, visited, name):
|
|
||||||
self.plot_grid(name)
|
|
||||||
cl = self.color_list_2()
|
|
||||||
path_combine = []
|
|
||||||
|
|
||||||
for k in range(len(path)):
|
|
||||||
self.plot_visited(visited[k], cl[k])
|
|
||||||
plt.pause(0.2)
|
|
||||||
self.plot_path(path[k])
|
|
||||||
path_combine += path[k]
|
|
||||||
plt.pause(0.2)
|
|
||||||
if self.xI in path_combine:
|
|
||||||
path_combine.remove(self.xI)
|
|
||||||
self.plot_path(path_combine)
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
def animation_ara_star(self, path, visited, name):
|
|
||||||
self.plot_grid(name)
|
|
||||||
cl_v, cl_p = self.color_list()
|
|
||||||
|
|
||||||
for k in range(len(path)):
|
|
||||||
self.plot_visited(visited[k], cl_v[k])
|
|
||||||
self.plot_path(path[k], cl_p[k], True)
|
|
||||||
plt.pause(0.5)
|
|
||||||
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
def animation_bi_astar(self, path, v_fore, v_back, name):
|
|
||||||
self.plot_grid(name)
|
|
||||||
self.plot_visited_bi(v_fore, v_back)
|
|
||||||
self.plot_path(path)
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
def plot_grid(self, name):
|
|
||||||
obs_x = [x[0] for x in self.obs]
|
|
||||||
obs_y = [x[1] for x in self.obs]
|
|
||||||
|
|
||||||
plt.plot(self.xI[0], self.xI[1], "bs")
|
|
||||||
plt.plot(self.xG[0], self.xG[1], "gs")
|
|
||||||
plt.plot(obs_x, obs_y, "sk")
|
|
||||||
plt.title(name)
|
|
||||||
plt.axis("equal")
|
|
||||||
|
|
||||||
def plot_visited(self, visited, cl='gray'):
|
|
||||||
if self.xI in visited:
|
|
||||||
visited.remove(self.xI)
|
|
||||||
|
|
||||||
if self.xG in visited:
|
|
||||||
visited.remove(self.xG)
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
|
|
||||||
for x in visited:
|
|
||||||
count += 1
|
|
||||||
plt.plot(x[0], x[1], color=cl, marker='o')
|
|
||||||
plt.gcf().canvas.mpl_connect('key_release_event',
|
|
||||||
lambda event: [exit(0) if event.key == 'escape' else None])
|
|
||||||
|
|
||||||
if count < len(visited) / 3:
|
|
||||||
length = 20
|
|
||||||
elif count < len(visited) * 2 / 3:
|
|
||||||
length = 30
|
|
||||||
else:
|
|
||||||
length = 40
|
|
||||||
#
|
|
||||||
# length = 15
|
|
||||||
|
|
||||||
if count % length == 0:
|
|
||||||
plt.pause(0.001)
|
|
||||||
plt.pause(0.01)
|
|
||||||
|
|
||||||
def plot_path(self, path, cl='r', flag=False):
|
|
||||||
path_x = [path[i][0] for i in range(len(path))]
|
|
||||||
path_y = [path[i][1] for i in range(len(path))]
|
|
||||||
|
|
||||||
if not flag:
|
|
||||||
plt.plot(path_x, path_y, linewidth='3', color='r')
|
|
||||||
else:
|
|
||||||
plt.plot(path_x, path_y, linewidth='3', color=cl)
|
|
||||||
|
|
||||||
plt.plot(self.xI[0], self.xI[1], "bs")
|
|
||||||
plt.plot(self.xG[0], self.xG[1], "gs")
|
|
||||||
|
|
||||||
plt.pause(0.01)
|
|
||||||
|
|
||||||
def plot_visited_bi(self, v_fore, v_back):
|
|
||||||
if self.xI in v_fore:
|
|
||||||
v_fore.remove(self.xI)
|
|
||||||
|
|
||||||
if self.xG in v_back:
|
|
||||||
v_back.remove(self.xG)
|
|
||||||
|
|
||||||
len_fore, len_back = len(v_fore), len(v_back)
|
|
||||||
|
|
||||||
for k in range(max(len_fore, len_back)):
|
|
||||||
if k < len_fore:
|
|
||||||
plt.plot(v_fore[k][0], v_fore[k][1], linewidth='3', color='gray', marker='o')
|
|
||||||
if k < len_back:
|
|
||||||
plt.plot(v_back[k][0], v_back[k][1], linewidth='3', color='cornflowerblue', marker='o')
|
|
||||||
|
|
||||||
plt.gcf().canvas.mpl_connect('key_release_event',
|
|
||||||
lambda event: [exit(0) if event.key == 'escape' else None])
|
|
||||||
|
|
||||||
if k % 10 == 0:
|
|
||||||
plt.pause(0.001)
|
|
||||||
plt.pause(0.01)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def color_list():
|
|
||||||
cl_v = ['silver',
|
|
||||||
'wheat',
|
|
||||||
'lightskyblue',
|
|
||||||
'royalblue',
|
|
||||||
'slategray']
|
|
||||||
cl_p = ['gray',
|
|
||||||
'orange',
|
|
||||||
'deepskyblue',
|
|
||||||
'red',
|
|
||||||
'm']
|
|
||||||
return cl_v, cl_p
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def color_list_2():
|
|
||||||
cl = ['silver',
|
|
||||||
'steelblue',
|
|
||||||
'dimgray',
|
|
||||||
'cornflowerblue',
|
|
||||||
'dodgerblue',
|
|
||||||
'royalblue',
|
|
||||||
'plum',
|
|
||||||
'mediumslateblue',
|
|
||||||
'mediumpurple',
|
|
||||||
'blueviolet',
|
|
||||||
]
|
|
||||||
return cl
|
|
@ -1,62 +0,0 @@
|
|||||||
import collections
|
|
||||||
import heapq
|
|
||||||
|
|
||||||
|
|
||||||
class QueueFIFO:
|
|
||||||
"""
|
|
||||||
Class: QueueFIFO
|
|
||||||
Description: QueueFIFO is designed for First-in-First-out rule.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.queue = collections.deque()
|
|
||||||
|
|
||||||
def empty(self):
|
|
||||||
return len(self.queue) == 0
|
|
||||||
|
|
||||||
def put(self, node):
|
|
||||||
self.queue.append(node) # enter from back
|
|
||||||
|
|
||||||
def get(self):
|
|
||||||
return self.queue.popleft() # leave from front
|
|
||||||
|
|
||||||
|
|
||||||
class QueueLIFO:
|
|
||||||
"""
|
|
||||||
Class: QueueLIFO
|
|
||||||
Description: QueueLIFO is designed for Last-in-First-out rule.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.queue = collections.deque()
|
|
||||||
|
|
||||||
def empty(self):
|
|
||||||
return len(self.queue) == 0
|
|
||||||
|
|
||||||
def put(self, node):
|
|
||||||
self.queue.append(node) # enter from back
|
|
||||||
|
|
||||||
def get(self):
|
|
||||||
return self.queue.pop() # leave from back
|
|
||||||
|
|
||||||
|
|
||||||
class QueuePrior:
|
|
||||||
"""
|
|
||||||
Class: QueuePrior
|
|
||||||
Description: QueuePrior reorders elements using value [priority]
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.queue = []
|
|
||||||
|
|
||||||
def empty(self):
|
|
||||||
return len(self.queue) == 0
|
|
||||||
|
|
||||||
def put(self, item, priority):
|
|
||||||
heapq.heappush(self.queue, (priority, item)) # reorder s using priority
|
|
||||||
|
|
||||||
def get(self):
|
|
||||||
return heapq.heappop(self.queue)[1] # pop out the smallest item
|
|
||||||
|
|
||||||
def enumerate(self):
|
|
||||||
return self.queue
|
|
Loading…
Reference in new issue