diff --git a/src/PaddleClas/MANIFEST.in b/src/PaddleClas/MANIFEST.in
new file mode 100644
index 0000000..b0a4f6d
--- /dev/null
+++ b/src/PaddleClas/MANIFEST.in
@@ -0,0 +1,7 @@
+include LICENSE.txt
+include README.md
+include docs/en/whl_en.md
+recursive-include deploy/python predict_cls.py preprocess.py postprocess.py det_preprocess.py
+recursive-include deploy/utils get_image_list.py config.py logger.py predictor.py
+
+recursive-include ppcls/ *.py *.txt
\ No newline at end of file
diff --git a/src/PaddleClas/__init__.py b/src/PaddleClas/__init__.py
new file mode 100644
index 0000000..2128a6c
--- /dev/null
+++ b/src/PaddleClas/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ['PaddleClas']
+from .paddleclas import PaddleClas
+from ppcls.arch.backbone import *
diff --git a/src/PaddleClas/hubconf.py b/src/PaddleClas/hubconf.py
new file mode 100644
index 0000000..b7f7674
--- /dev/null
+++ b/src/PaddleClas/hubconf.py
@@ -0,0 +1,788 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+dependencies = ['paddle']
+
+import paddle
+import os
+import sys
+
+
+class _SysPathG(object):
+ """
+ _SysPathG used to add/clean path for sys.path. Making sure minimal pkgs dependents by skiping parent dirs.
+
+ __enter__
+ add path into sys.path
+ __exit__
+ clean user's sys.path to avoid unexpect behaviors
+ """
+
+ def __init__(self, path):
+ self.path = path
+
+ def __enter__(self, ):
+ sys.path.insert(0, self.path)
+
+ def __exit__(self, type, value, traceback):
+ _p = sys.path.pop(0)
+ assert _p == self.path, 'Make sure sys.path cleaning {} correctly.'.format(
+ self.path)
+
+
+with _SysPathG(os.path.dirname(os.path.abspath(__file__)), ):
+ import ppcls
+ import ppcls.arch.backbone as backbone
+
+ def ppclas_init():
+ if ppcls.utils.logger._logger is None:
+ ppcls.utils.logger.init_logger()
+
+ ppclas_init()
+
+ def _load_pretrained_parameters(model, name):
+ url = 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/{}_pretrained.pdparams'.format(
+ name)
+ path = paddle.utils.download.get_weights_path_from_url(url)
+ model.set_state_dict(paddle.load(path))
+ return model
+
+ def alexnet(pretrained=False, **kwargs):
+ """
+ AlexNet
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `AlexNet` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.AlexNet(**kwargs)
+
+ return model
+
+ def vgg11(pretrained=False, **kwargs):
+ """
+ VGG11
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False`
+ Returns:
+ model: nn.Layer. Specific `VGG11` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.VGG11(**kwargs)
+
+ return model
+
+ def vgg13(pretrained=False, **kwargs):
+ """
+ VGG13
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False`
+ Returns:
+ model: nn.Layer. Specific `VGG13` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.VGG13(**kwargs)
+
+ return model
+
+ def vgg16(pretrained=False, **kwargs):
+ """
+ VGG16
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False`
+ Returns:
+ model: nn.Layer. Specific `VGG16` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.VGG16(**kwargs)
+
+ return model
+
+ def vgg19(pretrained=False, **kwargs):
+ """
+ VGG19
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False`
+ Returns:
+ model: nn.Layer. Specific `VGG19` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.VGG19(**kwargs)
+
+ return model
+
+ def resnet18(pretrained=False, **kwargs):
+ """
+ ResNet18
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ input_image_channel: int=3. The number of input image channels
+ data_format: str='NCHW'. The data format of batch input images, should in ('NCHW', 'NHWC')
+ Returns:
+ model: nn.Layer. Specific `ResNet18` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.ResNet18(**kwargs)
+
+ return model
+
+ def resnet34(pretrained=False, **kwargs):
+ """
+ ResNet34
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ input_image_channel: int=3. The number of input image channels
+ data_format: str='NCHW'. The data format of batch input images, should in ('NCHW', 'NHWC')
+ Returns:
+ model: nn.Layer. Specific `ResNet34` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.ResNet34(**kwargs)
+
+ return model
+
+ def resnet50(pretrained=False, **kwargs):
+ """
+ ResNet50
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ input_image_channel: int=3. The number of input image channels
+ data_format: str='NCHW'. The data format of batch input images, should in ('NCHW', 'NHWC')
+ Returns:
+ model: nn.Layer. Specific `ResNet50` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.ResNet50(**kwargs)
+
+ return model
+
+ def resnet101(pretrained=False, **kwargs):
+ """
+ ResNet101
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ input_image_channel: int=3. The number of input image channels
+ data_format: str='NCHW'. The data format of batch input images, should in ('NCHW', 'NHWC')
+ Returns:
+ model: nn.Layer. Specific `ResNet101` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.ResNet101(**kwargs)
+
+ return model
+
+ def resnet152(pretrained=False, **kwargs):
+ """
+ ResNet152
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ input_image_channel: int=3. The number of input image channels
+ data_format: str='NCHW'. The data format of batch input images, should in ('NCHW', 'NHWC')
+ Returns:
+ model: nn.Layer. Specific `ResNet152` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.ResNet152(**kwargs)
+
+ return model
+
+ def squeezenet1_0(pretrained=False, **kwargs):
+ """
+ SqueezeNet1_0
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `SqueezeNet1_0` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.SqueezeNet1_0(**kwargs)
+
+ return model
+
+ def squeezenet1_1(pretrained=False, **kwargs):
+ """
+ SqueezeNet1_1
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `SqueezeNet1_1` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.SqueezeNet1_1(**kwargs)
+
+ return model
+
+ def densenet121(pretrained=False, **kwargs):
+ """
+ DenseNet121
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ dropout: float=0. Probability of setting units to zero.
+ bn_size: int=4. The number of channals per group
+ Returns:
+ model: nn.Layer. Specific `DenseNet121` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.DenseNet121(**kwargs)
+
+ return model
+
+ def densenet161(pretrained=False, **kwargs):
+ """
+ DenseNet161
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ dropout: float=0. Probability of setting units to zero.
+ bn_size: int=4. The number of channals per group
+ Returns:
+ model: nn.Layer. Specific `DenseNet161` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.DenseNet161(**kwargs)
+
+ return model
+
+ def densenet169(pretrained=False, **kwargs):
+ """
+ DenseNet169
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ dropout: float=0. Probability of setting units to zero.
+ bn_size: int=4. The number of channals per group
+ Returns:
+ model: nn.Layer. Specific `DenseNet169` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.DenseNet169(**kwargs)
+
+ return model
+
+ def densenet201(pretrained=False, **kwargs):
+ """
+ DenseNet201
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ dropout: float=0. Probability of setting units to zero.
+ bn_size: int=4. The number of channals per group
+ Returns:
+ model: nn.Layer. Specific `DenseNet201` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.DenseNet201(**kwargs)
+
+ return model
+
+ def densenet264(pretrained=False, **kwargs):
+ """
+ DenseNet264
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ dropout: float=0. Probability of setting units to zero.
+ bn_size: int=4. The number of channals per group
+ Returns:
+ model: nn.Layer. Specific `DenseNet264` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.DenseNet264(**kwargs)
+
+ return model
+
+ def inceptionv3(pretrained=False, **kwargs):
+ """
+ InceptionV3
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `InceptionV3` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.InceptionV3(**kwargs)
+
+ return model
+
+ def inceptionv4(pretrained=False, **kwargs):
+ """
+ InceptionV4
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `InceptionV4` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.InceptionV4(**kwargs)
+
+ return model
+
+ def googlenet(pretrained=False, **kwargs):
+ """
+ GoogLeNet
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `GoogLeNet` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.GoogLeNet(**kwargs)
+
+ return model
+
+ def shufflenetv2_x0_25(pretrained=False, **kwargs):
+ """
+ ShuffleNetV2_x0_25
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `ShuffleNetV2_x0_25` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.ShuffleNetV2_x0_25(**kwargs)
+
+ return model
+
+ def mobilenetv1(pretrained=False, **kwargs):
+ """
+ MobileNetV1
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `MobileNetV1` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.MobileNetV1(**kwargs)
+
+ return model
+
+ def mobilenetv1_x0_25(pretrained=False, **kwargs):
+ """
+ MobileNetV1_x0_25
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `MobileNetV1_x0_25` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.MobileNetV1_x0_25(**kwargs)
+
+ return model
+
+ def mobilenetv1_x0_5(pretrained=False, **kwargs):
+ """
+ MobileNetV1_x0_5
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `MobileNetV1_x0_5` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.MobileNetV1_x0_5(**kwargs)
+
+ return model
+
+ def mobilenetv1_x0_75(pretrained=False, **kwargs):
+ """
+ MobileNetV1_x0_75
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `MobileNetV1_x0_75` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.MobileNetV1_x0_75(**kwargs)
+
+ return model
+
+ def mobilenetv2_x0_25(pretrained=False, **kwargs):
+ """
+ MobileNetV2_x0_25
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `MobileNetV2_x0_25` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.MobileNetV2_x0_25(**kwargs)
+
+ return model
+
+ def mobilenetv2_x0_5(pretrained=False, **kwargs):
+ """
+ MobileNetV2_x0_5
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `MobileNetV2_x0_5` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.MobileNetV2_x0_5(**kwargs)
+
+ return model
+
+ def mobilenetv2_x0_75(pretrained=False, **kwargs):
+ """
+ MobileNetV2_x0_75
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `MobileNetV2_x0_75` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.MobileNetV2_x0_75(**kwargs)
+
+ return model
+
+ def mobilenetv2_x1_5(pretrained=False, **kwargs):
+ """
+ MobileNetV2_x1_5
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `MobileNetV2_x1_5` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.MobileNetV2_x1_5(**kwargs)
+
+ return model
+
+ def mobilenetv2_x2_0(pretrained=False, **kwargs):
+ """
+ MobileNetV2_x2_0
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `MobileNetV2_x2_0` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.MobileNetV2_x2_0(**kwargs)
+
+ return model
+
+ def mobilenetv3_large_x0_35(pretrained=False, **kwargs):
+ """
+ MobileNetV3_large_x0_35
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `MobileNetV3_large_x0_35` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.MobileNetV3_large_x0_35(**kwargs)
+
+ return model
+
+ def mobilenetv3_large_x0_5(pretrained=False, **kwargs):
+ """
+ MobileNetV3_large_x0_5
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `MobileNetV3_large_x0_5` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.MobileNetV3_large_x0_5(**kwargs)
+
+ return model
+
+ def mobilenetv3_large_x0_75(pretrained=False, **kwargs):
+ """
+ MobileNetV3_large_x0_75
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `MobileNetV3_large_x0_75` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.MobileNetV3_large_x0_75(**kwargs)
+
+ return model
+
+ def mobilenetv3_large_x1_0(pretrained=False, **kwargs):
+ """
+ MobileNetV3_large_x1_0
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `MobileNetV3_large_x1_0` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.MobileNetV3_large_x1_0(**kwargs)
+
+ return model
+
+ def mobilenetv3_large_x1_25(pretrained=False, **kwargs):
+ """
+ MobileNetV3_large_x1_25
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `MobileNetV3_large_x1_25` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.MobileNetV3_large_x1_25(**kwargs)
+
+ return model
+
+ def mobilenetv3_small_x0_35(pretrained=False, **kwargs):
+ """
+ MobileNetV3_small_x0_35
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `MobileNetV3_small_x0_35` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.MobileNetV3_small_x0_35(**kwargs)
+
+ return model
+
+ def mobilenetv3_small_x0_5(pretrained=False, **kwargs):
+ """
+ MobileNetV3_small_x0_5
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `MobileNetV3_small_x0_5` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.MobileNetV3_small_x0_5(**kwargs)
+
+ return model
+
+ def mobilenetv3_small_x0_75(pretrained=False, **kwargs):
+ """
+ MobileNetV3_small_x0_75
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `MobileNetV3_small_x0_75` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.MobileNetV3_small_x0_75(**kwargs)
+
+ return model
+
+ def mobilenetv3_small_x1_0(pretrained=False, **kwargs):
+ """
+ MobileNetV3_small_x1_0
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `MobileNetV3_small_x1_0` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.MobileNetV3_small_x1_0(**kwargs)
+
+ return model
+
+ def mobilenetv3_small_x1_25(pretrained=False, **kwargs):
+ """
+ MobileNetV3_small_x1_25
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `MobileNetV3_small_x1_25` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.MobileNetV3_small_x1_25(**kwargs)
+
+ return model
+
+ def resnext101_32x4d(pretrained=False, **kwargs):
+ """
+ ResNeXt101_32x4d
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `ResNeXt101_32x4d` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.ResNeXt101_32x4d(**kwargs)
+
+ return model
+
+ def resnext101_64x4d(pretrained=False, **kwargs):
+ """
+ ResNeXt101_64x4d
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `ResNeXt101_64x4d` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.ResNeXt101_64x4d(**kwargs)
+
+ return model
+
+ def resnext152_32x4d(pretrained=False, **kwargs):
+ """
+ ResNeXt152_32x4d
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `ResNeXt152_32x4d` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.ResNeXt152_32x4d(**kwargs)
+
+ return model
+
+ def resnext152_64x4d(pretrained=False, **kwargs):
+ """
+ ResNeXt152_64x4d
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `ResNeXt152_64x4d` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.ResNeXt152_64x4d(**kwargs)
+
+ return model
+
+ def resnext50_32x4d(pretrained=False, **kwargs):
+ """
+ ResNeXt50_32x4d
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `ResNeXt50_32x4d` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.ResNeXt50_32x4d(**kwargs)
+
+ return model
+
+ def resnext50_64x4d(pretrained=False, **kwargs):
+ """
+ ResNeXt50_64x4d
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `ResNeXt50_64x4d` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.ResNeXt50_64x4d(**kwargs)
+
+ return model
+
+ def darknet53(pretrained=False, **kwargs):
+ """
+ DarkNet53
+ Args:
+ pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
+ kwargs:
+ class_dim: int=1000. Output dim of last fc layer.
+ Returns:
+ model: nn.Layer. Specific `ResNeXt50_64x4d` model depends on args.
+ """
+ kwargs.update({'pretrained': pretrained})
+ model = backbone.DarkNet53(**kwargs)
+
+ return model
diff --git a/src/PaddleClas/paddleclas.py b/src/PaddleClas/paddleclas.py
new file mode 100644
index 0000000..bfad193
--- /dev/null
+++ b/src/PaddleClas/paddleclas.py
@@ -0,0 +1,572 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+__dir__ = os.path.dirname(__file__)
+sys.path.append(os.path.join(__dir__, ""))
+sys.path.append(os.path.join(__dir__, "deploy"))
+
+from typing import Union, Generator
+import argparse
+import shutil
+import textwrap
+import tarfile
+import requests
+import warnings
+from functools import partial
+from difflib import SequenceMatcher
+
+import cv2
+import numpy as np
+from tqdm import tqdm
+from prettytable import PrettyTable
+
+from deploy.python.predict_cls import ClsPredictor
+from deploy.utils.get_image_list import get_image_list
+from deploy.utils import config
+
+from ppcls.arch.backbone import *
+from ppcls.utils.logger import init_logger
+
+# for building model with loading pretrained weights from backbone
+init_logger()
+
+__all__ = ["PaddleClas"]
+
+BASE_DIR = os.path.expanduser("~/.paddleclas/")
+BASE_INFERENCE_MODEL_DIR = os.path.join(BASE_DIR, "inference_model")
+BASE_IMAGES_DIR = os.path.join(BASE_DIR, "images")
+BASE_DOWNLOAD_URL = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/{}_infer.tar"
+MODEL_SERIES = {
+ "AlexNet": ["AlexNet"],
+ "DarkNet": ["DarkNet53"],
+ "DeiT": [
+ "DeiT_base_distilled_patch16_224", "DeiT_base_distilled_patch16_384",
+ "DeiT_base_patch16_224", "DeiT_base_patch16_384",
+ "DeiT_small_distilled_patch16_224", "DeiT_small_patch16_224",
+ "DeiT_tiny_distilled_patch16_224", "DeiT_tiny_patch16_224"
+ ],
+ "DenseNet": [
+ "DenseNet121", "DenseNet161", "DenseNet169", "DenseNet201",
+ "DenseNet264"
+ ],
+ "DLA": [
+ "DLA46_c", "DLA60x_c", "DLA34", "DLA60", "DLA60x", "DLA102", "DLA102x",
+ "DLA102x2", "DLA169"
+ ],
+ "DPN": ["DPN68", "DPN92", "DPN98", "DPN107", "DPN131"],
+ "EfficientNet": [
+ "EfficientNetB0", "EfficientNetB0_small", "EfficientNetB1",
+ "EfficientNetB2", "EfficientNetB3", "EfficientNetB4", "EfficientNetB5",
+ "EfficientNetB6", "EfficientNetB7"
+ ],
+ "ESNet": ["ESNet_x0_25", "ESNet_x0_5", "ESNet_x0_75", "ESNet_x1_0"],
+ "GhostNet":
+ ["GhostNet_x0_5", "GhostNet_x1_0", "GhostNet_x1_3", "GhostNet_x1_3_ssld"],
+ "HarDNet": ["HarDNet39_ds", "HarDNet68_ds", "HarDNet68", "HarDNet85"],
+ "HRNet": [
+ "HRNet_W18_C", "HRNet_W30_C", "HRNet_W32_C", "HRNet_W40_C",
+ "HRNet_W44_C", "HRNet_W48_C", "HRNet_W64_C", "HRNet_W18_C_ssld",
+ "HRNet_W48_C_ssld"
+ ],
+ "Inception": ["GoogLeNet", "InceptionV3", "InceptionV4"],
+ "MixNet": ["MixNet_S", "MixNet_M", "MixNet_L"],
+ "MobileNetV1": [
+ "MobileNetV1_x0_25", "MobileNetV1_x0_5", "MobileNetV1_x0_75",
+ "MobileNetV1", "MobileNetV1_ssld"
+ ],
+ "MobileNetV2": [
+ "MobileNetV2_x0_25", "MobileNetV2_x0_5", "MobileNetV2_x0_75",
+ "MobileNetV2", "MobileNetV2_x1_5", "MobileNetV2_x2_0",
+ "MobileNetV2_ssld"
+ ],
+ "MobileNetV3": [
+ "MobileNetV3_small_x0_35", "MobileNetV3_small_x0_5",
+ "MobileNetV3_small_x0_75", "MobileNetV3_small_x1_0",
+ "MobileNetV3_small_x1_25", "MobileNetV3_large_x0_35",
+ "MobileNetV3_large_x0_5", "MobileNetV3_large_x0_75",
+ "MobileNetV3_large_x1_0", "MobileNetV3_large_x1_25",
+ "MobileNetV3_small_x1_0_ssld", "MobileNetV3_large_x1_0_ssld"
+ ],
+ "PPLCNet": [
+ "PPLCNet_x0_25", "PPLCNet_x0_35", "PPLCNet_x0_5", "PPLCNet_x0_75",
+ "PPLCNet_x1_0", "PPLCNet_x1_5", "PPLCNet_x2_0", "PPLCNet_x2_5"
+ ],
+ "RedNet": ["RedNet26", "RedNet38", "RedNet50", "RedNet101", "RedNet152"],
+ "RegNet": ["RegNetX_4GF"],
+ "Res2Net": [
+ "Res2Net50_14w_8s", "Res2Net50_26w_4s", "Res2Net50_vd_26w_4s",
+ "Res2Net200_vd_26w_4s", "Res2Net101_vd_26w_4s",
+ "Res2Net50_vd_26w_4s_ssld", "Res2Net101_vd_26w_4s_ssld",
+ "Res2Net200_vd_26w_4s_ssld"
+ ],
+ "ResNeSt": ["ResNeSt50", "ResNeSt50_fast_1s1x64d"],
+ "ResNet": [
+ "ResNet18", "ResNet18_vd", "ResNet34", "ResNet34_vd", "ResNet50",
+ "ResNet50_vc", "ResNet50_vd", "ResNet50_vd_v2", "ResNet101",
+ "ResNet101_vd", "ResNet152", "ResNet152_vd", "ResNet200_vd",
+ "ResNet34_vd_ssld", "ResNet50_vd_ssld", "ResNet50_vd_ssld_v2",
+ "ResNet101_vd_ssld", "Fix_ResNet50_vd_ssld_v2", "ResNet50_ACNet_deploy"
+ ],
+ "ResNeXt": [
+ "ResNeXt50_32x4d", "ResNeXt50_vd_32x4d", "ResNeXt50_64x4d",
+ "ResNeXt50_vd_64x4d", "ResNeXt101_32x4d", "ResNeXt101_vd_32x4d",
+ "ResNeXt101_32x8d_wsl", "ResNeXt101_32x16d_wsl",
+ "ResNeXt101_32x32d_wsl", "ResNeXt101_32x48d_wsl",
+ "Fix_ResNeXt101_32x48d_wsl", "ResNeXt101_64x4d", "ResNeXt101_vd_64x4d",
+ "ResNeXt152_32x4d", "ResNeXt152_vd_32x4d", "ResNeXt152_64x4d",
+ "ResNeXt152_vd_64x4d"
+ ],
+ "ReXNet":
+ ["ReXNet_1_0", "ReXNet_1_3", "ReXNet_1_5", "ReXNet_2_0", "ReXNet_3_0"],
+ "SENet": [
+ "SENet154_vd", "SE_HRNet_W64_C_ssld", "SE_ResNet18_vd",
+ "SE_ResNet34_vd", "SE_ResNet50_vd", "SE_ResNeXt50_32x4d",
+ "SE_ResNeXt50_vd_32x4d", "SE_ResNeXt101_32x4d"
+ ],
+ "ShuffleNetV2": [
+ "ShuffleNetV2_swish", "ShuffleNetV2_x0_25", "ShuffleNetV2_x0_33",
+ "ShuffleNetV2_x0_5", "ShuffleNetV2_x1_0", "ShuffleNetV2_x1_5",
+ "ShuffleNetV2_x2_0"
+ ],
+ "SqueezeNet": ["SqueezeNet1_0", "SqueezeNet1_1"],
+ "SwinTransformer": [
+ "SwinTransformer_large_patch4_window7_224_22kto1k",
+ "SwinTransformer_large_patch4_window12_384_22kto1k",
+ "SwinTransformer_base_patch4_window7_224_22kto1k",
+ "SwinTransformer_base_patch4_window12_384_22kto1k",
+ "SwinTransformer_base_patch4_window12_384",
+ "SwinTransformer_base_patch4_window7_224",
+ "SwinTransformer_small_patch4_window7_224",
+ "SwinTransformer_tiny_patch4_window7_224"
+ ],
+ "Twins": [
+ "pcpvt_small", "pcpvt_base", "pcpvt_large", "alt_gvt_small",
+ "alt_gvt_base", "alt_gvt_large"
+ ],
+ "VGG": ["VGG11", "VGG13", "VGG16", "VGG19"],
+ "VisionTransformer": [
+ "ViT_base_patch16_224", "ViT_base_patch16_384", "ViT_base_patch32_384",
+ "ViT_large_patch16_224", "ViT_large_patch16_384",
+ "ViT_large_patch32_384", "ViT_small_patch16_224"
+ ],
+ "Xception": [
+ "Xception41", "Xception41_deeplab", "Xception65", "Xception65_deeplab",
+ "Xception71"
+ ]
+}
+
+
+class ImageTypeError(Exception):
+ """ImageTypeError.
+ """
+
+ def __init__(self, message=""):
+ super().__init__(message)
+
+
+class InputModelError(Exception):
+ """InputModelError.
+ """
+
+ def __init__(self, message=""):
+ super().__init__(message)
+
+
+def init_config(model_name,
+ inference_model_dir,
+ use_gpu=True,
+ batch_size=1,
+ topk=5,
+ **kwargs):
+ imagenet1k_map_path = os.path.join(
+ os.path.abspath(__dir__), "ppcls/utils/imagenet1k_label_list.txt")
+ cfg = {
+ "Global": {
+ "infer_imgs": kwargs["infer_imgs"]
+ if "infer_imgs" in kwargs else False,
+ "model_name": model_name,
+ "inference_model_dir": inference_model_dir,
+ "batch_size": batch_size,
+ "use_gpu": use_gpu,
+ "enable_mkldnn": kwargs["enable_mkldnn"]
+ if "enable_mkldnn" in kwargs else False,
+ "cpu_num_threads": kwargs["cpu_num_threads"]
+ if "cpu_num_threads" in kwargs else 1,
+ "enable_benchmark": False,
+ "use_fp16": kwargs["use_fp16"] if "use_fp16" in kwargs else False,
+ "ir_optim": True,
+ "use_tensorrt": kwargs["use_tensorrt"]
+ if "use_tensorrt" in kwargs else False,
+ "gpu_mem": kwargs["gpu_mem"] if "gpu_mem" in kwargs else 8000,
+ "enable_profile": False
+ },
+ "PreProcess": {
+ "transform_ops": [{
+ "ResizeImage": {
+ "resize_short": kwargs["resize_short"]
+ if "resize_short" in kwargs else 256
+ }
+ }, {
+ "CropImage": {
+ "size": kwargs["crop_size"]
+ if "crop_size" in kwargs else 224
+ }
+ }, {
+ "NormalizeImage": {
+ "scale": 0.00392157,
+ "mean": [0.485, 0.456, 0.406],
+ "std": [0.229, 0.224, 0.225],
+ "order": ''
+ }
+ }, {
+ "ToCHWImage": None
+ }]
+ },
+ "PostProcess": {
+ "main_indicator": "Topk",
+ "Topk": {
+ "topk": topk,
+ "class_id_map_file": imagenet1k_map_path
+ }
+ }
+ }
+ if "save_dir" in kwargs:
+ if kwargs["save_dir"] is not None:
+ cfg["PostProcess"]["SavePreLabel"] = {
+ "save_dir": kwargs["save_dir"]
+ }
+ if "class_id_map_file" in kwargs:
+ if kwargs["class_id_map_file"] is not None:
+ cfg["PostProcess"]["Topk"]["class_id_map_file"] = kwargs[
+ "class_id_map_file"]
+
+ cfg = config.AttrDict(cfg)
+ config.create_attr_dict(cfg)
+ return cfg
+
+
+def args_cfg():
+ def str2bool(v):
+ return v.lower() in ("true", "t", "1")
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--infer_imgs",
+ type=str,
+ required=True,
+ help="The image(s) to be predicted.")
+ parser.add_argument(
+ "--model_name", type=str, help="The model name to be used.")
+ parser.add_argument(
+ "--inference_model_dir",
+ type=str,
+ help="The directory of model files. Valid when model_name not specifed."
+ )
+ parser.add_argument(
+ "--use_gpu", type=str, default=True, help="Whether use GPU.")
+ parser.add_argument("--gpu_mem", type=int, default=8000, help="")
+ parser.add_argument(
+ "--enable_mkldnn",
+ type=str2bool,
+ default=False,
+ help="Whether use MKLDNN. Valid when use_gpu is False")
+ parser.add_argument("--cpu_num_threads", type=int, default=1, help="")
+ parser.add_argument(
+ "--use_tensorrt", type=str2bool, default=False, help="")
+ parser.add_argument("--use_fp16", type=str2bool, default=False, help="")
+ parser.add_argument(
+ "--batch_size", type=int, default=1, help="Batch size. Default by 1.")
+ parser.add_argument(
+ "--topk",
+ type=int,
+ default=5,
+ help="Return topk score(s) and corresponding results. Default by 5.")
+ parser.add_argument(
+ "--class_id_map_file",
+ type=str,
+ help="The path of file that map class_id and label.")
+ parser.add_argument(
+ "--save_dir",
+ type=str,
+ help="The directory to save prediction results as pre-label.")
+ parser.add_argument(
+ "--resize_short",
+ type=int,
+ default=256,
+ help="Resize according to short size.")
+ parser.add_argument(
+ "--crop_size", type=int, default=224, help="Centor crop size.")
+
+ args = parser.parse_args()
+ return vars(args)
+
+
+def print_info():
+ """Print list of supported models in formatted.
+ """
+ table = PrettyTable(["Series", "Name"])
+ try:
+ sz = os.get_terminal_size()
+ width = sz.columns - 30 if sz.columns > 50 else 10
+ except OSError:
+ width = 100
+ for series in MODEL_SERIES:
+ names = textwrap.fill(" ".join(MODEL_SERIES[series]), width=width)
+ table.add_row([series, names])
+ width = len(str(table).split("\n")[0])
+ print("{}".format("-" * width))
+ print("Models supported by PaddleClas".center(width))
+ print(table)
+ print("Powered by PaddlePaddle!".rjust(width))
+ print("{}".format("-" * width))
+
+
+def get_model_names():
+ """Get the model names list.
+ """
+ model_names = []
+ for series in MODEL_SERIES:
+ model_names += (MODEL_SERIES[series])
+ return model_names
+
+
+def similar_architectures(name="", names=[], thresh=0.1, topk=10):
+ """Find the most similar topk model names.
+ """
+ scores = []
+ for idx, n in enumerate(names):
+ if n.startswith("__"):
+ continue
+ score = SequenceMatcher(None, n.lower(), name.lower()).quick_ratio()
+ if score > thresh:
+ scores.append((idx, score))
+ scores.sort(key=lambda x: x[1], reverse=True)
+ similar_names = [names[s[0]] for s in scores[:min(topk, len(scores))]]
+ return similar_names
+
+
+def download_with_progressbar(url, save_path):
+ """Download from url with progressbar.
+ """
+ if os.path.isfile(save_path):
+ os.remove(save_path)
+ response = requests.get(url, stream=True)
+ total_size_in_bytes = int(response.headers.get("content-length", 0))
+ block_size = 1024 # 1 Kibibyte
+ progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
+ with open(save_path, "wb") as file:
+ for data in response.iter_content(block_size):
+ progress_bar.update(len(data))
+ file.write(data)
+ progress_bar.close()
+ if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes or not os.path.isfile(
+ save_path):
+ raise Exception(
+ f"Something went wrong while downloading file from {url}")
+
+
+def check_model_file(model_name):
+ """Check the model files exist and download and untar when no exist.
+ """
+ storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
+ model_name)
+ url = BASE_DOWNLOAD_URL.format(model_name)
+
+ tar_file_name_list = [
+ "inference.pdiparams", "inference.pdiparams.info", "inference.pdmodel"
+ ]
+ model_file_path = storage_directory("inference.pdmodel")
+ params_file_path = storage_directory("inference.pdiparams")
+ if not os.path.exists(model_file_path) or not os.path.exists(
+ params_file_path):
+ tmp_path = storage_directory(url.split("/")[-1])
+ print(f"download {url} to {tmp_path}")
+ os.makedirs(storage_directory(), exist_ok=True)
+ download_with_progressbar(url, tmp_path)
+ with tarfile.open(tmp_path, "r") as tarObj:
+ for member in tarObj.getmembers():
+ filename = None
+ for tar_file_name in tar_file_name_list:
+ if tar_file_name in member.name:
+ filename = tar_file_name
+ if filename is None:
+ continue
+ file = tarObj.extractfile(member)
+ with open(storage_directory(filename), "wb") as f:
+ f.write(file.read())
+ os.remove(tmp_path)
+ if not os.path.exists(model_file_path) or not os.path.exists(
+ params_file_path):
+ raise Exception(
+ f"Something went wrong while praparing the model[{model_name}] files!"
+ )
+
+ return storage_directory()
+
+
+class PaddleClas(object):
+ """PaddleClas.
+ """
+
+ print_info()
+
+ def __init__(self,
+ model_name: str=None,
+ inference_model_dir: str=None,
+ use_gpu: bool=True,
+ batch_size: int=1,
+ topk: int=5,
+ **kwargs):
+ """Init PaddleClas with config.
+
+ Args:
+ model_name (str, optional): The model name supported by PaddleClas. If specified, override config. Defaults to None.
+ inference_model_dir (str, optional): The directory that contained model file and params file to be used. If specified, override config. Defaults to None.
+ use_gpu (bool, optional): Whether use GPU. If specified, override config. Defaults to True.
+ batch_size (int, optional): The batch size to pridict. If specified, override config. Defaults to 1.
+ topk (int, optional): Return the top k prediction results with the highest score. Defaults to 5.
+ """
+ super().__init__()
+ self._config = init_config(model_name, inference_model_dir, use_gpu,
+ batch_size, topk, **kwargs)
+ self._check_input_model()
+ self.cls_predictor = ClsPredictor(self._config)
+
+ def get_config(self):
+ """Get the config.
+ """
+ return self._config
+
+ def _check_input_model(self):
+ """Check input model name or model files.
+ """
+ candidate_model_names = get_model_names()
+ input_model_name = self._config.Global.get("model_name", None)
+ inference_model_dir = self._config.Global.get("inference_model_dir",
+ None)
+ if input_model_name is not None:
+ similar_names = similar_architectures(input_model_name,
+ candidate_model_names)
+ similar_names_str = ", ".join(similar_names)
+ if input_model_name not in candidate_model_names:
+ err = f"{input_model_name} is not provided by PaddleClas. \nMaybe you want: [{similar_names_str}]. \nIf you want to use your own model, please specify inference_model_dir!"
+ raise InputModelError(err)
+ self._config.Global.inference_model_dir = check_model_file(
+ input_model_name)
+ return
+ elif inference_model_dir is not None:
+ model_file_path = os.path.join(inference_model_dir,
+ "inference.pdmodel")
+ params_file_path = os.path.join(inference_model_dir,
+ "inference.pdiparams")
+ if not os.path.isfile(model_file_path) or not os.path.isfile(
+ params_file_path):
+ err = f"There is no model file or params file in this directory: {inference_model_dir}"
+ raise InputModelError(err)
+ return
+ else:
+ err = f"Please specify the model name supported by PaddleClas or directory contained model files(inference.pdmodel, inference.pdiparams)."
+ raise InputModelError(err)
+ return
+
+ def predict(self, input_data: Union[str, np.array],
+ print_pred: bool=False) -> Generator[list, None, None]:
+ """Predict input_data.
+
+ Args:
+ input_data (Union[str, np.array]):
+ When the type is str, it is the path of image, or the directory containing images, or the URL of image from Internet.
+ When the type is np.array, it is the image data whose channel order is RGB.
+ print_pred (bool, optional): Whether print the prediction result. Defaults to False.
+
+ Raises:
+ ImageTypeError: Illegal input_data.
+
+ Yields:
+ Generator[list, None, None]:
+ The prediction result(s) of input_data by batch_size. For every one image,
+ prediction result(s) is zipped as a dict, that includs topk "class_ids", "scores" and "label_names".
+ The format of batch prediction result(s) is as follow: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...]
+ """
+
+ if isinstance(input_data, np.ndarray):
+ yield self.cls_predictor.predict(input_data)
+ elif isinstance(input_data, str):
+ if input_data.startswith("http") or input_data.startswith("https"):
+ image_storage_dir = partial(os.path.join, BASE_IMAGES_DIR)
+ if not os.path.exists(image_storage_dir()):
+ os.makedirs(image_storage_dir())
+ image_save_path = image_storage_dir("tmp.jpg")
+ download_with_progressbar(input_data, image_save_path)
+ input_data = image_save_path
+ warnings.warn(
+ f"Image to be predicted from Internet: {input_data}, has been saved to: {image_save_path}"
+ )
+ image_list = get_image_list(input_data)
+
+ batch_size = self._config.Global.get("batch_size", 1)
+ topk = self._config.PostProcess.Topk.get('topk', 1)
+
+ img_list = []
+ img_path_list = []
+ cnt = 0
+ for idx, img_path in enumerate(image_list):
+ img = cv2.imread(img_path)
+ if img is None:
+ warnings.warn(
+ f"Image file failed to read and has been skipped. The path: {img_path}"
+ )
+ continue
+ img = img[:, :, ::-1]
+ img_list.append(img)
+ img_path_list.append(img_path)
+ cnt += 1
+
+ if cnt % batch_size == 0 or (idx + 1) == len(image_list):
+ preds = self.cls_predictor.predict(img_list)
+
+ if print_pred and preds:
+ for idx, pred in enumerate(preds):
+ pred_str = ", ".join(
+ [f"{k}: {pred[k]}" for k in pred])
+ print(
+ f"filename: {img_path_list[idx]}, top-{topk}, {pred_str}"
+ )
+
+ img_list = []
+ img_path_list = []
+ yield preds
+ else:
+ err = "Please input legal image! The type of image supported by PaddleClas are: NumPy.ndarray and string of local path or Ineternet URL"
+ raise ImageTypeError(err)
+ return
+
+
+# for CLI
+def main():
+ """Function API used for commad line.
+ """
+ cfg = args_cfg()
+ clas_engine = PaddleClas(**cfg)
+ res = clas_engine.predict(cfg["infer_imgs"], print_pred=True)
+ for _ in res:
+ pass
+ print("Predict complete!")
+ return
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/PaddleClas/requirements.txt b/src/PaddleClas/requirements.txt
new file mode 100644
index 0000000..79f548c
--- /dev/null
+++ b/src/PaddleClas/requirements.txt
@@ -0,0 +1,11 @@
+prettytable
+ujson
+opencv-python==4.4.0.46
+pillow
+tqdm
+PyYAML
+visualdl >= 2.2.0
+scipy
+scikit-learn==0.23.2
+gast==0.3.3
+faiss-cpu==1.7.1.post2
diff --git a/src/PaddleClas/setup.py b/src/PaddleClas/setup.py
new file mode 100644
index 0000000..57045d3
--- /dev/null
+++ b/src/PaddleClas/setup.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from io import open
+from setuptools import setup
+
+with open('requirements.txt', encoding="utf-8-sig") as f:
+ requirements = f.readlines()
+
+
+def readme():
+ with open(
+ 'docs/en/inference_deployment/whl_deploy_en.md',
+ encoding="utf-8-sig") as f:
+ README = f.read()
+ return README
+
+
+setup(
+ name='paddleclas',
+ packages=['paddleclas'],
+ package_dir={'paddleclas': ''},
+ include_package_data=True,
+ entry_points={
+ "console_scripts": ["paddleclas= paddleclas.paddleclas:main"]
+ },
+ version='0.0.0',
+ install_requires=requirements,
+ license='Apache License 2.0',
+ description='Awesome Image Classification toolkits based on PaddlePaddle ',
+ long_description=readme(),
+ long_description_content_type='text/markdown',
+ url='https://github.com/PaddlePaddle/PaddleClas',
+ download_url='https://github.com/PaddlePaddle/PaddleClas.git',
+ keywords=[
+ 'A treasure chest for image classification powered by PaddlePaddle.'
+ ],
+ classifiers=[
+ 'Intended Audience :: Developers',
+ 'Operating System :: OS Independent',
+ 'Natural Language :: Chinese (Simplified)',
+ 'Programming Language :: Python :: 3',
+ 'Programming Language :: Python :: 3.2',
+ 'Programming Language :: Python :: 3.3',
+ 'Programming Language :: Python :: 3.4',
+ 'Programming Language :: Python :: 3.5',
+ 'Programming Language :: Python :: 3.6',
+ 'Programming Language :: Python :: 3.7', 'Topic :: Utilities'
+ ], )
diff --git a/src/Search_2D/.idea/.gitignore b/src/Search_2D/.idea/.gitignore
deleted file mode 100644
index 359bb53..0000000
--- a/src/Search_2D/.idea/.gitignore
+++ /dev/null
@@ -1,3 +0,0 @@
-# 默认忽略的文件
-/shelf/
-/workspace.xml
diff --git a/src/Search_2D/.idea/Search_2D.iml b/src/Search_2D/.idea/Search_2D.iml
deleted file mode 100644
index 8b8c395..0000000
--- a/src/Search_2D/.idea/Search_2D.iml
+++ /dev/null
@@ -1,12 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/src/Search_2D/.idea/inspectionProfiles/Project_Default.xml b/src/Search_2D/.idea/inspectionProfiles/Project_Default.xml
deleted file mode 100644
index 6736707..0000000
--- a/src/Search_2D/.idea/inspectionProfiles/Project_Default.xml
+++ /dev/null
@@ -1,15 +0,0 @@
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/src/Search_2D/.idea/inspectionProfiles/profiles_settings.xml b/src/Search_2D/.idea/inspectionProfiles/profiles_settings.xml
deleted file mode 100644
index 105ce2d..0000000
--- a/src/Search_2D/.idea/inspectionProfiles/profiles_settings.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-
-
-
-
-
-
\ No newline at end of file
diff --git a/src/Search_2D/.idea/misc.xml b/src/Search_2D/.idea/misc.xml
deleted file mode 100644
index d56657a..0000000
--- a/src/Search_2D/.idea/misc.xml
+++ /dev/null
@@ -1,4 +0,0 @@
-
-
-
-
\ No newline at end of file
diff --git a/src/Search_2D/.idea/modules.xml b/src/Search_2D/.idea/modules.xml
deleted file mode 100644
index 01049a7..0000000
--- a/src/Search_2D/.idea/modules.xml
+++ /dev/null
@@ -1,8 +0,0 @@
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/src/Search_2D/ARAstar.py b/src/Search_2D/ARAstar.py
deleted file mode 100644
index c014616..0000000
--- a/src/Search_2D/ARAstar.py
+++ /dev/null
@@ -1,222 +0,0 @@
-"""
-ARA_star 2D (Anytime Repairing A*)
-@author: huiming zhou
-
-@description: local inconsistency: g-value decreased.
-g(s) decreased introduces a local inconsistency between s and its successors.
-
-"""
-
-import os
-import sys
-import math
-
-sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
- "/../../Search_based_Planning/")
-
-from Search_2D import plotting, env
-
-
-class AraStar:
- def __init__(self, s_start, s_goal, e, heuristic_type):
- self.s_start, self.s_goal = s_start, s_goal
- self.heuristic_type = heuristic_type
-
- self.Env = env.Env() # class Env
-
- self.u_set = self.Env.motions # feasible input set
- self.obs = self.Env.obs # position of obstacles
- self.e = e # weight
-
- self.g = dict() # Cost to come
- self.OPEN = dict() # priority queue / OPEN set
- self.CLOSED = set() # CLOSED set
- self.INCONS = {} # INCONSISTENT set
- self.PARENT = dict() # relations
- self.path = [] # planning path
- self.visited = [] # order of visited nodes
-
- def init(self):
- """
- initialize each set.
- """
-
- self.g[self.s_start] = 0.0
- self.g[self.s_goal] = math.inf
- self.OPEN[self.s_start] = self.f_value(self.s_start)
- self.PARENT[self.s_start] = self.s_start
-
- def searching(self):
- self.init()
- self.ImprovePath()
- self.path.append(self.extract_path())
-
- while self.update_e() > 1: # continue condition
- self.e -= 0.4 # increase weight
- self.OPEN.update(self.INCONS)
- self.OPEN = {s: self.f_value(s) for s in self.OPEN} # update f_value of OPEN set
-
- self.INCONS = dict()
- self.CLOSED = set()
- self.ImprovePath() # improve path
- self.path.append(self.extract_path())
-
- return self.path, self.visited
-
- def ImprovePath(self):
- """
- :return: a e'-suboptimal path
- """
-
- visited_each = []
-
- while True:
- s, f_small = self.calc_smallest_f()
-
- if self.f_value(self.s_goal) <= f_small:
- break
-
- self.OPEN.pop(s)
- self.CLOSED.add(s)
-
- for s_n in self.get_neighbor(s):
- if s_n in self.obs:
- continue
-
- new_cost = self.g[s] + self.cost(s, s_n)
-
- if s_n not in self.g or new_cost < self.g[s_n]:
- self.g[s_n] = new_cost
- self.PARENT[s_n] = s
- visited_each.append(s_n)
-
- if s_n not in self.CLOSED:
- self.OPEN[s_n] = self.f_value(s_n)
- else:
- self.INCONS[s_n] = 0.0
-
- self.visited.append(visited_each)
-
- def calc_smallest_f(self):
- """
- :return: node with smallest f_value in OPEN set.
- """
-
- s_small = min(self.OPEN, key=self.OPEN.get)
-
- return s_small, self.OPEN[s_small]
-
- def get_neighbor(self, s):
- """
- find neighbors of state s that not in obstacles.
- :param s: state
- :return: neighbors
- """
-
- return {(s[0] + u[0], s[1] + u[1]) for u in self.u_set}
-
- def update_e(self):
- v = float("inf")
-
- if self.OPEN:
- v = min(self.g[s] + self.h(s) for s in self.OPEN)
- if self.INCONS:
- v = min(v, min(self.g[s] + self.h(s) for s in self.INCONS))
-
- return min(self.e, self.g[self.s_goal] / v)
-
- def f_value(self, x):
- """
- f = g + e * h
- f = cost-to-come + weight * cost-to-go
- :param x: current state
- :return: f_value
- """
-
- return self.g[x] + self.e * self.h(x)
-
- def extract_path(self):
- """
- Extract the path based on the PARENT set.
- :return: The planning path
- """
-
- path = [self.s_goal]
- s = self.s_goal
-
- while True:
- s = self.PARENT[s]
- path.append(s)
-
- if s == self.s_start:
- break
-
- return list(path)
-
- def h(self, s):
- """
- Calculate heuristic.
- :param s: current node (state)
- :return: heuristic function value
- """
-
- heuristic_type = self.heuristic_type # heuristic type
- goal = self.s_goal # goal node
-
- if heuristic_type == "manhattan":
- return abs(goal[0] - s[0]) + abs(goal[1] - s[1])
- else:
- return math.hypot(goal[0] - s[0], goal[1] - s[1])
-
- def cost(self, s_start, s_goal):
- """
- Calculate Cost for this motion
- :param s_start: starting node
- :param s_goal: end node
- :return: Cost for this motion
- :note: Cost function could be more complicate!
- """
-
- if self.is_collision(s_start, s_goal):
- return math.inf
-
- return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
-
- def is_collision(self, s_start, s_end):
- """
- check if the line segment (s_start, s_end) is collision.
- :param s_start: start node
- :param s_end: end node
- :return: True: is collision / False: not collision
- """
-
- if s_start in self.obs or s_end in self.obs:
- return True
-
- if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
- if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
- s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
- s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
- else:
- s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
- s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
-
- if s1 in self.obs or s2 in self.obs:
- return True
-
- return False
-
-
-def main():
- s_start = (5, 5)
- s_goal = (45, 25)
-
- arastar = AraStar(s_start, s_goal, 2.5, "euclidean")
- plot = plotting.Plotting(s_start, s_goal)
-
- path, visited = arastar.searching()
- plot.animation_ara_star(path, visited, "Anytime Repairing A* (ARA*)")
-
-
-if __name__ == '__main__':
- main()
diff --git a/src/Search_2D/Anytime_D_star.py b/src/Search_2D/Anytime_D_star.py
deleted file mode 100644
index cd1d62b..0000000
--- a/src/Search_2D/Anytime_D_star.py
+++ /dev/null
@@ -1,317 +0,0 @@
-"""
-Anytime_D_star 2D
-@author: huiming zhou
-"""
-
-import os
-import sys
-import math
-import matplotlib.pyplot as plt
-
-sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
- "/../../Search_based_Planning/")
-
-from Search_2D import plotting
-from Search_2D import env
-
-
-class ADStar:
- def __init__(self, s_start, s_goal, eps, heuristic_type):
- self.s_start, self.s_goal = s_start, s_goal
- self.heuristic_type = heuristic_type
-
- self.Env = env.Env() # class Env
- self.Plot = plotting.Plotting(s_start, s_goal)
-
- self.u_set = self.Env.motions # feasible input set
- self.obs = self.Env.obs # position of obstacles
- self.x = self.Env.x_range
- self.y = self.Env.y_range
-
- self.g, self.rhs, self.OPEN = {}, {}, {}
-
- for i in range(1, self.Env.x_range - 1):
- for j in range(1, self.Env.y_range - 1):
- self.rhs[(i, j)] = float("inf")
- self.g[(i, j)] = float("inf")
-
- self.rhs[self.s_goal] = 0.0
- self.eps = eps
- self.OPEN[self.s_goal] = self.Key(self.s_goal)
- self.CLOSED, self.INCONS = set(), dict()
-
- self.visited = set()
- self.count = 0
- self.count_env_change = 0
- self.obs_add = set()
- self.obs_remove = set()
- self.title = "Anytime D*: Small changes" # Significant changes
- self.fig = plt.figure()
-
- def run(self):
- self.Plot.plot_grid(self.title)
- self.ComputeOrImprovePath()
- self.plot_visited()
- self.plot_path(self.extract_path())
- self.visited = set()
-
- while True:
- if self.eps <= 1.0:
- break
- self.eps -= 0.5
- self.OPEN.update(self.INCONS)
- for s in self.OPEN:
- self.OPEN[s] = self.Key(s)
- self.CLOSED = set()
- self.ComputeOrImprovePath()
- self.plot_visited()
- self.plot_path(self.extract_path())
- self.visited = set()
- plt.pause(0.5)
-
- self.fig.canvas.mpl_connect('button_press_event', self.on_press)
- plt.show()
-
- def on_press(self, event):
- x, y = event.xdata, event.ydata
- if x < 0 or x > self.x - 1 or y < 0 or y > self.y - 1:
- print("Please choose right area!")
- else:
- self.count_env_change += 1
- x, y = int(x), int(y)
- print("Change position: s =", x, ",", "y =", y)
-
- # for small changes
- if self.title == "Anytime D*: Small changes":
- if (x, y) not in self.obs:
- self.obs.add((x, y))
- self.g[(x, y)] = float("inf")
- self.rhs[(x, y)] = float("inf")
- else:
- self.obs.remove((x, y))
- self.UpdateState((x, y))
-
- self.Plot.update_obs(self.obs)
-
- for sn in self.get_neighbor((x, y)):
- self.UpdateState(sn)
-
- plt.cla()
- self.Plot.plot_grid(self.title)
-
- while True:
- if len(self.INCONS) == 0:
- break
- self.OPEN.update(self.INCONS)
- for s in self.OPEN:
- self.OPEN[s] = self.Key(s)
- self.CLOSED = set()
- self.ComputeOrImprovePath()
- self.plot_visited()
- self.plot_path(self.extract_path())
- # plt.plot(self.title)
- self.visited = set()
-
- if self.eps <= 1.0:
- break
-
- else:
- if (x, y) not in self.obs:
- self.obs.add((x, y))
- self.obs_add.add((x, y))
- plt.plot(x, y, 'sk')
- if (x, y) in self.obs_remove:
- self.obs_remove.remove((x, y))
- else:
- self.obs.remove((x, y))
- self.obs_remove.add((x, y))
- plt.plot(x, y, marker='s', color='white')
- if (x, y) in self.obs_add:
- self.obs_add.remove((x, y))
-
- self.Plot.update_obs(self.obs)
-
- if self.count_env_change >= 15:
- self.count_env_change = 0
- self.eps += 2.0
- for s in self.obs_add:
- self.g[(x, y)] = float("inf")
- self.rhs[(x, y)] = float("inf")
-
- for sn in self.get_neighbor(s):
- self.UpdateState(sn)
-
- for s in self.obs_remove:
- for sn in self.get_neighbor(s):
- self.UpdateState(sn)
- self.UpdateState(s)
-
- plt.cla()
- self.Plot.plot_grid(self.title)
-
- while True:
- if self.eps <= 1.0:
- break
- self.eps -= 0.5
- self.OPEN.update(self.INCONS)
- for s in self.OPEN:
- self.OPEN[s] = self.Key(s)
- self.CLOSED = set()
- self.ComputeOrImprovePath()
- self.plot_visited()
- self.plot_path(self.extract_path())
- plt.title(self.title)
- self.visited = set()
- plt.pause(0.5)
-
- self.fig.canvas.draw_idle()
-
- def ComputeOrImprovePath(self):
- while True:
- s, v = self.TopKey()
- if v >= self.Key(self.s_start) and \
- self.rhs[self.s_start] == self.g[self.s_start]:
- break
-
- self.OPEN.pop(s)
- self.visited.add(s)
-
- if self.g[s] > self.rhs[s]:
- self.g[s] = self.rhs[s]
- self.CLOSED.add(s)
- for sn in self.get_neighbor(s):
- self.UpdateState(sn)
- else:
- self.g[s] = float("inf")
- for sn in self.get_neighbor(s):
- self.UpdateState(sn)
- self.UpdateState(s)
-
- def UpdateState(self, s):
- if s != self.s_goal:
- self.rhs[s] = float("inf")
- for x in self.get_neighbor(s):
- self.rhs[s] = min(self.rhs[s], self.g[x] + self.cost(s, x))
- if s in self.OPEN:
- self.OPEN.pop(s)
-
- if self.g[s] != self.rhs[s]:
- if s not in self.CLOSED:
- self.OPEN[s] = self.Key(s)
- else:
- self.INCONS[s] = 0
-
- def Key(self, s):
- if self.g[s] > self.rhs[s]:
- return [self.rhs[s] + self.eps * self.h(self.s_start, s), self.rhs[s]]
- else:
- return [self.g[s] + self.h(self.s_start, s), self.g[s]]
-
- def TopKey(self):
- """
- :return: return the min key and its value.
- """
-
- s = min(self.OPEN, key=self.OPEN.get)
- return s, self.OPEN[s]
-
- def h(self, s_start, s_goal):
- heuristic_type = self.heuristic_type # heuristic type
-
- if heuristic_type == "manhattan":
- return abs(s_goal[0] - s_start[0]) + abs(s_goal[1] - s_start[1])
- else:
- return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
-
- def cost(self, s_start, s_goal):
- """
- Calculate Cost for this motion
- :param s_start: starting node
- :param s_goal: end node
- :return: Cost for this motion
- :note: Cost function could be more complicate!
- """
-
- if self.is_collision(s_start, s_goal):
- return float("inf")
-
- return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
-
- def is_collision(self, s_start, s_end):
- if s_start in self.obs or s_end in self.obs:
- return True
-
- if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
- if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
- s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
- s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
- else:
- s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
- s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
-
- if s1 in self.obs or s2 in self.obs:
- return True
-
- return False
-
- def get_neighbor(self, s):
- nei_list = set()
- for u in self.u_set:
- s_next = tuple([s[i] + u[i] for i in range(2)])
- if s_next not in self.obs:
- nei_list.add(s_next)
-
- return nei_list
-
- def extract_path(self):
- """
- Extract the path based on the PARENT set.
- :return: The planning path
- """
-
- path = [self.s_start]
- s = self.s_start
-
- for k in range(100):
- g_list = {}
- for x in self.get_neighbor(s):
- if not self.is_collision(s, x):
- g_list[x] = self.g[x]
- s = min(g_list, key=g_list.get)
- path.append(s)
- if s == self.s_goal:
- break
-
- return list(path)
-
- def plot_path(self, path):
- px = [x[0] for x in path]
- py = [x[1] for x in path]
- plt.plot(px, py, linewidth=2)
- plt.plot(self.s_start[0], self.s_start[1], "bs")
- plt.plot(self.s_goal[0], self.s_goal[1], "gs")
-
- def plot_visited(self):
- self.count += 1
-
- color = ['gainsboro', 'lightgray', 'silver', 'darkgray',
- 'bisque', 'navajowhite', 'moccasin', 'wheat',
- 'powderblue', 'skyblue', 'lightskyblue', 'cornflowerblue']
-
- if self.count >= len(color) - 1:
- self.count = 0
-
- for x in self.visited:
- plt.plot(x[0], x[1], marker='s', color=color[self.count])
-
-
-def main():
- s_start = (5, 5)
- s_goal = (45, 25)
-
- dstar = ADStar(s_start, s_goal, 2.5, "euclidean")
- dstar.run()
-
-
-if __name__ == '__main__':
- main()
diff --git a/src/Search_2D/Astar.py b/src/Search_2D/Astar.py
deleted file mode 100644
index adf676b..0000000
--- a/src/Search_2D/Astar.py
+++ /dev/null
@@ -1,225 +0,0 @@
-"""
-A_star 2D
-@author: huiming zhou
-"""
-
-import os
-import sys
-import math
-import heapq
-
-sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
- "/../../Search_based_Planning/")
-
-from Search_2D import plotting, env
-
-
-class AStar:
- """AStar set the cost + heuristics as the priority
- """
- def __init__(self, s_start, s_goal, heuristic_type):
- self.s_start = s_start
- self.s_goal = s_goal
- self.heuristic_type = heuristic_type
-
- self.Env = env.Env() # class Env
-
- self.u_set = self.Env.motions # feasible input set
- self.obs = self.Env.obs # position of obstacles
-
- self.OPEN = [] # priority queue / OPEN set
- self.CLOSED = [] # CLOSED set / VISITED order
- self.PARENT = dict() # recorded parent
- self.g = dict() # cost to come
-
- def searching(self):
- """
- A_star Searching.
- :return: path, visited order
- """
-
- self.PARENT[self.s_start] = self.s_start
- self.g[self.s_start] = 0
- self.g[self.s_goal] = math.inf
- heapq.heappush(self.OPEN,
- (self.f_value(self.s_start), self.s_start))
-
- while self.OPEN:
- _, s = heapq.heappop(self.OPEN)
- self.CLOSED.append(s)
-
- if s == self.s_goal: # stop condition
- break
-
- for s_n in self.get_neighbor(s):
- new_cost = self.g[s] + self.cost(s, s_n)
-
- if s_n not in self.g:
- self.g[s_n] = math.inf
-
- if new_cost < self.g[s_n]: # conditions for updating Cost
- self.g[s_n] = new_cost
- self.PARENT[s_n] = s
- heapq.heappush(self.OPEN, (self.f_value(s_n), s_n))
-
- return self.extract_path(self.PARENT), self.CLOSED
-
- def searching_repeated_astar(self, e):
- """
- repeated A*.
- :param e: weight of A*
- :return: path and visited order
- """
-
- path, visited = [], []
-
- while e >= 1:
- p_k, v_k = self.repeated_searching(self.s_start, self.s_goal, e)
- path.append(p_k)
- visited.append(v_k)
- e -= 0.5
-
- return path, visited
-
- def repeated_searching(self, s_start, s_goal, e):
- """
- run A* with weight e.
- :param s_start: starting state
- :param s_goal: goal state
- :param e: weight of a*
- :return: path and visited order.
- """
-
- g = {s_start: 0, s_goal: float("inf")}
- PARENT = {s_start: s_start}
- OPEN = []
- CLOSED = []
- heapq.heappush(OPEN,
- (g[s_start] + e * self.heuristic(s_start), s_start))
-
- while OPEN:
- _, s = heapq.heappop(OPEN)
- CLOSED.append(s)
-
- if s == s_goal:
- break
-
- for s_n in self.get_neighbor(s):
- new_cost = g[s] + self.cost(s, s_n)
-
- if s_n not in g:
- g[s_n] = math.inf
-
- if new_cost < g[s_n]: # conditions for updating Cost
- g[s_n] = new_cost
- PARENT[s_n] = s
- heapq.heappush(OPEN, (g[s_n] + e * self.heuristic(s_n), s_n))
-
- return self.extract_path(PARENT), CLOSED
-
- def get_neighbor(self, s):
- """
- find neighbors of state s that not in obstacles.
- :param s: state
- :return: neighbors
- """
-
- return [(s[0] + u[0], s[1] + u[1]) for u in self.u_set]
-
- def cost(self, s_start, s_goal):
- """
- Calculate Cost for this motion
- :param s_start: starting node
- :param s_goal: end node
- :return: Cost for this motion
- :note: Cost function could be more complicate!
- """
-
- if self.is_collision(s_start, s_goal):
- return math.inf
-
- return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
-
- def is_collision(self, s_start, s_end):
- """
- check if the line segment (s_start, s_end) is collision.
- :param s_start: start node
- :param s_end: end node
- :return: True: is collision / False: not collision
- """
-
- if s_start in self.obs or s_end in self.obs:
- return True
-
- if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
- if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
- s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
- s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
- else:
- s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
- s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
-
- if s1 in self.obs or s2 in self.obs:
- return True
-
- return False
-
- def f_value(self, s):
- """
- f = g + h. (g: Cost to come, h: heuristic value)
- :param s: current state
- :return: f
- """
-
- return self.g[s] + self.heuristic(s)
-
- def extract_path(self, PARENT):
- """
- Extract the path based on the PARENT set.
- :return: The planning path
- """
-
- path = [self.s_goal]
- s = self.s_goal
-
- while True:
- s = PARENT[s]
- path.append(s)
-
- if s == self.s_start:
- break
-
- return list(path)
-
- def heuristic(self, s):
- """
- Calculate heuristic.
- :param s: current node (state)
- :return: heuristic function value
- """
-
- heuristic_type = self.heuristic_type # heuristic type
- goal = self.s_goal # goal node
-
- if heuristic_type == "manhattan":
- return abs(goal[0] - s[0]) + abs(goal[1] - s[1])
- else:
- return math.hypot(goal[0] - s[0], goal[1] - s[1])
-
-
-def main():
- s_start = (5, 5)
- s_goal = (45, 25)
-
- astar = AStar(s_start, s_goal, "euclidean")
- plot = plotting.Plotting(s_start, s_goal)
-
- path, visited = astar.searching()
- plot.animation(path, visited, "A*") # animation
-
- # path, visited = astar.searching_repeated_astar(2.5) # initial weight e = 2.5
- # plot.animation_ara_star(path, visited, "Repeated A*")
-
-
-if __name__ == '__main__':
- main()
diff --git a/src/Search_2D/Best_First.py b/src/Search_2D/Best_First.py
deleted file mode 100644
index 0c85fba..0000000
--- a/src/Search_2D/Best_First.py
+++ /dev/null
@@ -1,68 +0,0 @@
-"""
-Best-First Searching
-@author: huiming zhou
-"""
-
-import os
-import sys
-import math
-import heapq
-
-sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
- "/../../Search_based_Planning/")
-
-from Search_2D import plotting, env
-from Search_2D.Astar import AStar
-
-
-class BestFirst(AStar):
- """BestFirst set the heuristics as the priority
- """
- def searching(self):
- """
- Breadth-first Searching.
- :return: path, visited order
- """
-
- self.PARENT[self.s_start] = self.s_start
- self.g[self.s_start] = 0
- self.g[self.s_goal] = math.inf
- heapq.heappush(self.OPEN,
- (self.heuristic(self.s_start), self.s_start))
-
- while self.OPEN:
- _, s = heapq.heappop(self.OPEN)
- self.CLOSED.append(s)
-
- if s == self.s_goal:
- break
-
- for s_n in self.get_neighbor(s):
- new_cost = self.g[s] + self.cost(s, s_n)
-
- if s_n not in self.g:
- self.g[s_n] = math.inf
-
- if new_cost < self.g[s_n]: # conditions for updating Cost
- self.g[s_n] = new_cost
- self.PARENT[s_n] = s
-
- # best first set the heuristics as the priority
- heapq.heappush(self.OPEN, (self.heuristic(s_n), s_n))
-
- return self.extract_path(self.PARENT), self.CLOSED
-
-
-def main():
- s_start = (5, 5)
- s_goal = (45, 25)
-
- BF = BestFirst(s_start, s_goal, 'euclidean')
- plot = plotting.Plotting(s_start, s_goal)
-
- path, visited = BF.searching()
- plot.animation(path, visited, "Best-first Searching") # animation
-
-
-if __name__ == '__main__':
- main()
diff --git a/src/Search_2D/Bidirectional_a_star.py b/src/Search_2D/Bidirectional_a_star.py
deleted file mode 100644
index 3580c1a..0000000
--- a/src/Search_2D/Bidirectional_a_star.py
+++ /dev/null
@@ -1,229 +0,0 @@
-"""
-Bidirectional_a_star 2D
-@author: huiming zhou
-"""
-
-import os
-import sys
-import math
-import heapq
-
-sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
- "/../../Search_based_Planning/")
-
-from Search_2D import plotting, env
-
-
-class BidirectionalAStar:
- def __init__(self, s_start, s_goal, heuristic_type):
- self.s_start = s_start
- self.s_goal = s_goal
- self.heuristic_type = heuristic_type
-
- self.Env = env.Env() # class Env
-
- self.u_set = self.Env.motions # feasible input set
- self.obs = self.Env.obs # position of obstacles
-
- self.OPEN_fore = [] # OPEN set for forward searching
- self.OPEN_back = [] # OPEN set for backward searching
- self.CLOSED_fore = [] # CLOSED set for forward
- self.CLOSED_back = [] # CLOSED set for backward
- self.PARENT_fore = dict() # recorded parent for forward
- self.PARENT_back = dict() # recorded parent for backward
- self.g_fore = dict() # cost to come for forward
- self.g_back = dict() # cost to come for backward
-
- def init(self):
- """
- initialize parameters
- """
-
- self.g_fore[self.s_start] = 0.0
- self.g_fore[self.s_goal] = math.inf
- self.g_back[self.s_goal] = 0.0
- self.g_back[self.s_start] = math.inf
- self.PARENT_fore[self.s_start] = self.s_start
- self.PARENT_back[self.s_goal] = self.s_goal
- heapq.heappush(self.OPEN_fore,
- (self.f_value_fore(self.s_start), self.s_start))
- heapq.heappush(self.OPEN_back,
- (self.f_value_back(self.s_goal), self.s_goal))
-
- def searching(self):
- """
- Bidirectional A*
- :return: connected path, visited order of forward, visited order of backward
- """
-
- self.init()
- s_meet = self.s_start
-
- while self.OPEN_fore and self.OPEN_back:
- # solve foreward-search
- _, s_fore = heapq.heappop(self.OPEN_fore)
-
- if s_fore in self.PARENT_back:
- s_meet = s_fore
- break
-
- self.CLOSED_fore.append(s_fore)
-
- for s_n in self.get_neighbor(s_fore):
- new_cost = self.g_fore[s_fore] + self.cost(s_fore, s_n)
-
- if s_n not in self.g_fore:
- self.g_fore[s_n] = math.inf
-
- if new_cost < self.g_fore[s_n]:
- self.g_fore[s_n] = new_cost
- self.PARENT_fore[s_n] = s_fore
- heapq.heappush(self.OPEN_fore,
- (self.f_value_fore(s_n), s_n))
-
- # solve backward-search
- _, s_back = heapq.heappop(self.OPEN_back)
-
- if s_back in self.PARENT_fore:
- s_meet = s_back
- break
-
- self.CLOSED_back.append(s_back)
-
- for s_n in self.get_neighbor(s_back):
- new_cost = self.g_back[s_back] + self.cost(s_back, s_n)
-
- if s_n not in self.g_back:
- self.g_back[s_n] = math.inf
-
- if new_cost < self.g_back[s_n]:
- self.g_back[s_n] = new_cost
- self.PARENT_back[s_n] = s_back
- heapq.heappush(self.OPEN_back,
- (self.f_value_back(s_n), s_n))
-
- return self.extract_path(s_meet), self.CLOSED_fore, self.CLOSED_back
-
- def get_neighbor(self, s):
- """
- find neighbors of state s that not in obstacles.
- :param s: state
- :return: neighbors
- """
-
- return [(s[0] + u[0], s[1] + u[1]) for u in self.u_set]
-
- def extract_path(self, s_meet):
- """
- extract path from start and goal
- :param s_meet: meet point of bi-direction a*
- :return: path
- """
-
- # extract path for foreward part
- path_fore = [s_meet]
- s = s_meet
-
- while True:
- s = self.PARENT_fore[s]
- path_fore.append(s)
- if s == self.s_start:
- break
-
- # extract path for backward part
- path_back = []
- s = s_meet
-
- while True:
- s = self.PARENT_back[s]
- path_back.append(s)
- if s == self.s_goal:
- break
-
- return list(reversed(path_fore)) + list(path_back)
-
- def f_value_fore(self, s):
- """
- forward searching: f = g + h. (g: Cost to come, h: heuristic value)
- :param s: current state
- :return: f
- """
-
- return self.g_fore[s] + self.h(s, self.s_goal)
-
- def f_value_back(self, s):
- """
- backward searching: f = g + h. (g: Cost to come, h: heuristic value)
- :param s: current state
- :return: f
- """
-
- return self.g_back[s] + self.h(s, self.s_start)
-
- def h(self, s, goal):
- """
- Calculate heuristic value.
- :param s: current node (state)
- :param goal: goal node (state)
- :return: heuristic value
- """
-
- heuristic_type = self.heuristic_type
-
- if heuristic_type == "manhattan":
- return abs(goal[0] - s[0]) + abs(goal[1] - s[1])
- else:
- return math.hypot(goal[0] - s[0], goal[1] - s[1])
-
- def cost(self, s_start, s_goal):
- """
- Calculate Cost for this motion
- :param s_start: starting node
- :param s_goal: end node
- :return: Cost for this motion
- :note: Cost function could be more complicate!
- """
-
- if self.is_collision(s_start, s_goal):
- return math.inf
-
- return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
-
- def is_collision(self, s_start, s_end):
- """
- check if the line segment (s_start, s_end) is collision.
- :param s_start: start node
- :param s_end: end node
- :return: True: is collision / False: not collision
- """
-
- if s_start in self.obs or s_end in self.obs:
- return True
-
- if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
- if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
- s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
- s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
- else:
- s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
- s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
-
- if s1 in self.obs or s2 in self.obs:
- return True
-
- return False
-
-
-def main():
- x_start = (5, 5)
- x_goal = (45, 25)
-
- bastar = BidirectionalAStar(x_start, x_goal, "euclidean")
- plot = plotting.Plotting(x_start, x_goal)
-
- path, visited_fore, visited_back = bastar.searching()
- plot.animation_bi_astar(path, visited_fore, visited_back, "Bidirectional-A*") # animation
-
-
-if __name__ == '__main__':
- main()
diff --git a/src/Search_2D/D_star.py b/src/Search_2D/D_star.py
deleted file mode 100644
index 60b6c7e..0000000
--- a/src/Search_2D/D_star.py
+++ /dev/null
@@ -1,304 +0,0 @@
-"""
-D_star 2D
-@author: huiming zhou
-"""
-
-import os
-import sys
-import math
-import matplotlib.pyplot as plt
-
-sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
- "/../../Search_based_Planning/")
-
-from Search_2D import plotting, env
-
-
-class DStar:
- def __init__(self, s_start, s_goal):
- self.s_start, self.s_goal = s_start, s_goal
-
- self.Env = env.Env()
- self.Plot = plotting.Plotting(self.s_start, self.s_goal)
-
- self.u_set = self.Env.motions
- self.obs = self.Env.obs
- self.x = self.Env.x_range
- self.y = self.Env.y_range
-
- self.fig = plt.figure()
-
- self.OPEN = set()
- self.t = dict()
- self.PARENT = dict()
- self.h = dict()
- self.k = dict()
- self.path = []
- self.visited = set()
- self.count = 0
-
- def init(self):
- for i in range(self.Env.x_range):
- for j in range(self.Env.y_range):
- self.t[(i, j)] = 'NEW'
- self.k[(i, j)] = 0.0
- self.h[(i, j)] = float("inf")
- self.PARENT[(i, j)] = None
-
- self.h[self.s_goal] = 0.0
-
- def run(self, s_start, s_end):
- self.init()
- self.insert(s_end, 0)
-
- while True:
- self.process_state()
- if self.t[s_start] == 'CLOSED':
- break
-
- self.path = self.extract_path(s_start, s_end)
- self.Plot.plot_grid("Dynamic A* (D*)")
- self.plot_path(self.path)
- self.fig.canvas.mpl_connect('button_press_event', self.on_press)
- plt.show()
-
- def on_press(self, event):
- x, y = event.xdata, event.ydata
- if x < 0 or x > self.x - 1 or y < 0 or y > self.y - 1:
- print("Please choose right area!")
- else:
- x, y = int(x), int(y)
- if (x, y) not in self.obs:
- print("Add obstacle at: s =", x, ",", "y =", y)
- self.obs.add((x, y))
- self.Plot.update_obs(self.obs)
-
- s = self.s_start
- self.visited = set()
- self.count += 1
-
- while s != self.s_goal:
- if self.is_collision(s, self.PARENT[s]):
- self.modify(s)
- continue
- s = self.PARENT[s]
-
- self.path = self.extract_path(self.s_start, self.s_goal)
-
- plt.cla()
- self.Plot.plot_grid("Dynamic A* (D*)")
- self.plot_visited(self.visited)
- self.plot_path(self.path)
-
- self.fig.canvas.draw_idle()
-
- def extract_path(self, s_start, s_end):
- path = [s_start]
- s = s_start
- while True:
- s = self.PARENT[s]
- path.append(s)
- if s == s_end:
- return path
-
- def process_state(self):
- s = self.min_state() # get node in OPEN set with min k value
- self.visited.add(s)
-
- if s is None:
- return -1 # OPEN set is empty
-
- k_old = self.get_k_min() # record the min k value of this iteration (min path cost)
- self.delete(s) # move state s from OPEN set to CLOSED set
-
- # k_min < h[s] --> s: RAISE state (increased cost)
- if k_old < self.h[s]:
- for s_n in self.get_neighbor(s):
- if self.h[s_n] <= k_old and \
- self.h[s] > self.h[s_n] + self.cost(s_n, s):
-
- # update h_value and choose parent
- self.PARENT[s] = s_n
- self.h[s] = self.h[s_n] + self.cost(s_n, s)
-
- # s: k_min >= h[s] -- > s: LOWER state (cost reductions)
- if k_old == self.h[s]:
- for s_n in self.get_neighbor(s):
- if self.t[s_n] == 'NEW' or \
- (self.PARENT[s_n] == s and self.h[s_n] != self.h[s] + self.cost(s, s_n)) or \
- (self.PARENT[s_n] != s and self.h[s_n] > self.h[s] + self.cost(s, s_n)):
-
- # Condition:
- # 1) t[s_n] == 'NEW': not visited
- # 2) s_n's parent: cost reduction
- # 3) s_n find a better parent
- self.PARENT[s_n] = s
- self.insert(s_n, self.h[s] + self.cost(s, s_n))
- else:
- for s_n in self.get_neighbor(s):
- if self.t[s_n] == 'NEW' or \
- (self.PARENT[s_n] == s and self.h[s_n] != self.h[s] + self.cost(s, s_n)):
-
- # Condition:
- # 1) t[s_n] == 'NEW': not visited
- # 2) s_n's parent: cost reduction
- self.PARENT[s_n] = s
- self.insert(s_n, self.h[s] + self.cost(s, s_n))
- else:
- if self.PARENT[s_n] != s and \
- self.h[s_n] > self.h[s] + self.cost(s, s_n):
-
- # Condition: LOWER happened in OPEN set (s), s should be explored again
- self.insert(s, self.h[s])
- else:
- if self.PARENT[s_n] != s and \
- self.h[s] > self.h[s_n] + self.cost(s_n, s) and \
- self.t[s_n] == 'CLOSED' and \
- self.h[s_n] > k_old:
-
- # Condition: LOWER happened in CLOSED set (s_n), s_n should be explored again
- self.insert(s_n, self.h[s_n])
-
- return self.get_k_min()
-
- def min_state(self):
- """
- choose the node with the minimum k value in OPEN set.
- :return: state
- """
-
- if not self.OPEN:
- return None
-
- return min(self.OPEN, key=lambda x: self.k[x])
-
- def get_k_min(self):
- """
- calc the min k value for nodes in OPEN set.
- :return: k value
- """
-
- if not self.OPEN:
- return -1
-
- return min([self.k[x] for x in self.OPEN])
-
- def insert(self, s, h_new):
- """
- insert node into OPEN set.
- :param s: node
- :param h_new: new or better cost to come value
- """
-
- if self.t[s] == 'NEW':
- self.k[s] = h_new
- elif self.t[s] == 'OPEN':
- self.k[s] = min(self.k[s], h_new)
- elif self.t[s] == 'CLOSED':
- self.k[s] = min(self.h[s], h_new)
-
- self.h[s] = h_new
- self.t[s] = 'OPEN'
- self.OPEN.add(s)
-
- def delete(self, s):
- """
- delete: move state s from OPEN set to CLOSED set.
- :param s: state should be deleted
- """
-
- if self.t[s] == 'OPEN':
- self.t[s] = 'CLOSED'
-
- self.OPEN.remove(s)
-
- def modify(self, s):
- """
- start processing from state s.
- :param s: is a node whose status is RAISE or LOWER.
- """
-
- self.modify_cost(s)
-
- while True:
- k_min = self.process_state()
-
- if k_min >= self.h[s]:
- break
-
- def modify_cost(self, s):
- # if node in CLOSED set, put it into OPEN set.
- # Since cost may be changed between s - s.parent, calc cost(s, s.p) again
-
- if self.t[s] == 'CLOSED':
- self.insert(s, self.h[self.PARENT[s]] + self.cost(s, self.PARENT[s]))
-
- def get_neighbor(self, s):
- nei_list = set()
-
- for u in self.u_set:
- s_next = tuple([s[i] + u[i] for i in range(2)])
- if s_next not in self.obs:
- nei_list.add(s_next)
-
- return nei_list
-
- def cost(self, s_start, s_goal):
- """
- Calculate Cost for this motion
- :param s_start: starting node
- :param s_goal: end node
- :return: Cost for this motion
- :note: Cost function could be more complicate!
- """
-
- if self.is_collision(s_start, s_goal):
- return float("inf")
-
- return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
-
- def is_collision(self, s_start, s_end):
- if s_start in self.obs or s_end in self.obs:
- return True
-
- if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
- if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
- s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
- s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
- else:
- s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
- s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
-
- if s1 in self.obs or s2 in self.obs:
- return True
-
- return False
-
- def plot_path(self, path):
- px = [x[0] for x in path]
- py = [x[1] for x in path]
- plt.plot(px, py, linewidth=2)
- plt.plot(self.s_start[0], self.s_start[1], "bs")
- plt.plot(self.s_goal[0], self.s_goal[1], "gs")
-
- def plot_visited(self, visited):
- color = ['gainsboro', 'lightgray', 'silver', 'darkgray',
- 'bisque', 'navajowhite', 'moccasin', 'wheat',
- 'powderblue', 'skyblue', 'lightskyblue', 'cornflowerblue']
-
- if self.count >= len(color) - 1:
- self.count = 0
-
- for x in visited:
- plt.plot(x[0], x[1], marker='s', color=color[self.count])
-
-
-def main():
- s_start = (5, 5)
- s_goal = (45, 25)
- dstar = DStar(s_start, s_goal)
- dstar.run(s_start, s_goal)
-
-
-if __name__ == '__main__':
- main()
diff --git a/src/Search_2D/D_star_Lite.py b/src/Search_2D/D_star_Lite.py
deleted file mode 100644
index 4996be2..0000000
--- a/src/Search_2D/D_star_Lite.py
+++ /dev/null
@@ -1,239 +0,0 @@
-"""
-D_star_Lite 2D
-@author: huiming zhou
-"""
-
-import os
-import sys
-import math
-import matplotlib.pyplot as plt
-
-sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
- "/../../Search_based_Planning/")
-
-from Search_2D import plotting, env
-
-
-class DStar:
- def __init__(self, s_start, s_goal, heuristic_type):
- self.s_start, self.s_goal = s_start, s_goal
- self.heuristic_type = heuristic_type
-
- self.Env = env.Env() # class Env
- self.Plot = plotting.Plotting(s_start, s_goal)
-
- self.u_set = self.Env.motions # feasible input set
- self.obs = self.Env.obs # position of obstacles
- self.x = self.Env.x_range
- self.y = self.Env.y_range
-
- self.g, self.rhs, self.U = {}, {}, {}
- self.km = 0
-
- for i in range(1, self.Env.x_range - 1):
- for j in range(1, self.Env.y_range - 1):
- self.rhs[(i, j)] = float("inf")
- self.g[(i, j)] = float("inf")
-
- self.rhs[self.s_goal] = 0.0
- self.U[self.s_goal] = self.CalculateKey(self.s_goal)
- self.visited = set()
- self.count = 0
- self.fig = plt.figure()
-
- def run(self):
- self.Plot.plot_grid("D* Lite")
- self.ComputePath()
- self.plot_path(self.extract_path())
- self.fig.canvas.mpl_connect('button_press_event', self.on_press)
- plt.show()
-
- def on_press(self, event):
- x, y = event.xdata, event.ydata
- if x < 0 or x > self.x - 1 or y < 0 or y > self.y - 1:
- print("Please choose right area!")
- else:
- x, y = int(x), int(y)
- print("Change position: s =", x, ",", "y =", y)
-
- s_curr = self.s_start
- s_last = self.s_start
- i = 0
- path = [self.s_start]
-
- while s_curr != self.s_goal:
- s_list = {}
-
- for s in self.get_neighbor(s_curr):
- s_list[s] = self.g[s] + self.cost(s_curr, s)
- s_curr = min(s_list, key=s_list.get)
- path.append(s_curr)
-
- if i < 1:
- self.km += self.h(s_last, s_curr)
- s_last = s_curr
- if (x, y) not in self.obs:
- self.obs.add((x, y))
- plt.plot(x, y, 'sk')
- self.g[(x, y)] = float("inf")
- self.rhs[(x, y)] = float("inf")
- else:
- self.obs.remove((x, y))
- plt.plot(x, y, marker='s', color='white')
- self.UpdateVertex((x, y))
- for s in self.get_neighbor((x, y)):
- self.UpdateVertex(s)
- i += 1
-
- self.count += 1
- self.visited = set()
- self.ComputePath()
-
- self.plot_visited(self.visited)
- self.plot_path(path)
- self.fig.canvas.draw_idle()
-
- def ComputePath(self):
- while True:
- s, v = self.TopKey()
- if v >= self.CalculateKey(self.s_start) and \
- self.rhs[self.s_start] == self.g[self.s_start]:
- break
-
- k_old = v
- self.U.pop(s)
- self.visited.add(s)
-
- if k_old < self.CalculateKey(s):
- self.U[s] = self.CalculateKey(s)
- elif self.g[s] > self.rhs[s]:
- self.g[s] = self.rhs[s]
- for x in self.get_neighbor(s):
- self.UpdateVertex(x)
- else:
- self.g[s] = float("inf")
- self.UpdateVertex(s)
- for x in self.get_neighbor(s):
- self.UpdateVertex(x)
-
- def UpdateVertex(self, s):
- if s != self.s_goal:
- self.rhs[s] = float("inf")
- for x in self.get_neighbor(s):
- self.rhs[s] = min(self.rhs[s], self.g[x] + self.cost(s, x))
- if s in self.U:
- self.U.pop(s)
-
- if self.g[s] != self.rhs[s]:
- self.U[s] = self.CalculateKey(s)
-
- def CalculateKey(self, s):
- return [min(self.g[s], self.rhs[s]) + self.h(self.s_start, s) + self.km,
- min(self.g[s], self.rhs[s])]
-
- def TopKey(self):
- """
- :return: return the min key and its value.
- """
-
- s = min(self.U, key=self.U.get)
- return s, self.U[s]
-
- def h(self, s_start, s_goal):
- heuristic_type = self.heuristic_type # heuristic type
-
- if heuristic_type == "manhattan":
- return abs(s_goal[0] - s_start[0]) + abs(s_goal[1] - s_start[1])
- else:
- return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
-
- def cost(self, s_start, s_goal):
- """
- Calculate Cost for this motion
- :param s_start: starting node
- :param s_goal: end node
- :return: Cost for this motion
- :note: Cost function could be more complicate!
- """
-
- if self.is_collision(s_start, s_goal):
- return float("inf")
-
- return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
-
- def is_collision(self, s_start, s_end):
- if s_start in self.obs or s_end in self.obs:
- return True
-
- if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
- if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
- s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
- s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
- else:
- s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
- s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
-
- if s1 in self.obs or s2 in self.obs:
- return True
-
- return False
-
- def get_neighbor(self, s):
- nei_list = set()
- for u in self.u_set:
- s_next = tuple([s[i] + u[i] for i in range(2)])
- if s_next not in self.obs:
- nei_list.add(s_next)
-
- return nei_list
-
- def extract_path(self):
- """
- Extract the path based on the PARENT set.
- :return: The planning path
- """
-
- path = [self.s_start]
- s = self.s_start
-
- for k in range(100):
- g_list = {}
- for x in self.get_neighbor(s):
- if not self.is_collision(s, x):
- g_list[x] = self.g[x]
- s = min(g_list, key=g_list.get)
- path.append(s)
- if s == self.s_goal:
- break
-
- return list(path)
-
- def plot_path(self, path):
- px = [x[0] for x in path]
- py = [x[1] for x in path]
- plt.plot(px, py, linewidth=2)
- plt.plot(self.s_start[0], self.s_start[1], "bs")
- plt.plot(self.s_goal[0], self.s_goal[1], "gs")
-
- def plot_visited(self, visited):
- color = ['gainsboro', 'lightgray', 'silver', 'darkgray',
- 'bisque', 'navajowhite', 'moccasin', 'wheat',
- 'powderblue', 'skyblue', 'lightskyblue', 'cornflowerblue']
-
- if self.count >= len(color) - 1:
- self.count = 0
-
- for x in visited:
- plt.plot(x[0], x[1], marker='s', color=color[self.count])
-
-
-def main():
- s_start = (5, 5)
- s_goal = (45, 25)
-
- dstar = DStar(s_start, s_goal, "euclidean")
- dstar.run()
-
-
-if __name__ == '__main__':
- main()
diff --git a/src/Search_2D/Dijkstra.py b/src/Search_2D/Dijkstra.py
deleted file mode 100644
index e5e7b68..0000000
--- a/src/Search_2D/Dijkstra.py
+++ /dev/null
@@ -1,69 +0,0 @@
-"""
-Dijkstra 2D
-@author: huiming zhou
-"""
-
-import os
-import sys
-import math
-import heapq
-
-sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
- "/../../Search_based_Planning/")
-
-from Search_2D import plotting, env
-
-from Search_2D.Astar import AStar
-
-
-class Dijkstra(AStar):
- """Dijkstra set the cost as the priority
- """
- def searching(self):
- """
- Breadth-first Searching.
- :return: path, visited order
- """
-
- self.PARENT[self.s_start] = self.s_start
- self.g[self.s_start] = 0
- self.g[self.s_goal] = math.inf
- heapq.heappush(self.OPEN,
- (0, self.s_start))
-
- while self.OPEN:
- _, s = heapq.heappop(self.OPEN)
- self.CLOSED.append(s)
-
- if s == self.s_goal:
- break
-
- for s_n in self.get_neighbor(s):
- new_cost = self.g[s] + self.cost(s, s_n)
-
- if s_n not in self.g:
- self.g[s_n] = math.inf
-
- if new_cost < self.g[s_n]: # conditions for updating Cost
- self.g[s_n] = new_cost
- self.PARENT[s_n] = s
-
- # best first set the heuristics as the priority
- heapq.heappush(self.OPEN, (new_cost, s_n))
-
- return self.extract_path(self.PARENT), self.CLOSED
-
-
-def main():
- s_start = (5, 5)
- s_goal = (45, 25)
-
- dijkstra = Dijkstra(s_start, s_goal, 'None')
- plot = plotting.Plotting(s_start, s_goal)
-
- path, visited = dijkstra.searching()
- plot.animation(path, visited, "Dijkstra's") # animation generate
-
-
-if __name__ == '__main__':
- main()
diff --git a/src/Search_2D/LPAstar.py b/src/Search_2D/LPAstar.py
deleted file mode 100644
index 4fd70ae..0000000
--- a/src/Search_2D/LPAstar.py
+++ /dev/null
@@ -1,256 +0,0 @@
-"""
-LPA_star 2D
-@author: huiming zhou
-"""
-
-import os
-import sys
-import math
-import matplotlib.pyplot as plt
-
-sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
- "/../../Search_based_Planning/")
-
-from Search_2D import plotting, env
-
-
-class LPAStar:
- def __init__(self, s_start, s_goal, heuristic_type):
- self.s_start, self.s_goal = s_start, s_goal
- self.heuristic_type = heuristic_type
-
- self.Env = env.Env()
- self.Plot = plotting.Plotting(self.s_start, self.s_goal)
-
- self.u_set = self.Env.motions
- self.obs = self.Env.obs
- self.x = self.Env.x_range
- self.y = self.Env.y_range
-
- self.g, self.rhs, self.U = {}, {}, {}
-
- for i in range(self.Env.x_range):
- for j in range(self.Env.y_range):
- self.rhs[(i, j)] = float("inf")
- self.g[(i, j)] = float("inf")
-
- self.rhs[self.s_start] = 0
- self.U[self.s_start] = self.CalculateKey(self.s_start)
- self.visited = set()
- self.count = 0
-
- self.fig = plt.figure()
-
- def run(self):
- self.Plot.plot_grid("Lifelong Planning A*")
-
- self.ComputeShortestPath()
- self.plot_path(self.extract_path())
- self.fig.canvas.mpl_connect('button_press_event', self.on_press)
-
- plt.show()
-
- def on_press(self, event):
- x, y = event.xdata, event.ydata
- if x < 0 or x > self.x - 1 or y < 0 or y > self.y - 1:
- print("Please choose right area!")
- else:
- x, y = int(x), int(y)
- print("Change position: s =", x, ",", "y =", y)
-
- self.visited = set()
- self.count += 1
-
- if (x, y) not in self.obs:
- self.obs.add((x, y))
- else:
- self.obs.remove((x, y))
- self.UpdateVertex((x, y))
-
- self.Plot.update_obs(self.obs)
-
- for s_n in self.get_neighbor((x, y)):
- self.UpdateVertex(s_n)
-
- self.ComputeShortestPath()
-
- plt.cla()
- self.Plot.plot_grid("Lifelong Planning A*")
- self.plot_visited(self.visited)
- self.plot_path(self.extract_path())
- self.fig.canvas.draw_idle()
-
- def ComputeShortestPath(self):
- while True:
- s, v = self.TopKey()
-
- if v >= self.CalculateKey(self.s_goal) and \
- self.rhs[self.s_goal] == self.g[self.s_goal]:
- break
-
- self.U.pop(s)
- self.visited.add(s)
-
- if self.g[s] > self.rhs[s]:
-
- # Condition: over-consistent (eg: deleted obstacles)
- # So, rhs[s] decreased -- > rhs[s] < g[s]
- self.g[s] = self.rhs[s]
- else:
-
- # Condition: # under-consistent (eg: added obstacles)
- # So, rhs[s] increased --> rhs[s] > g[s]
- self.g[s] = float("inf")
- self.UpdateVertex(s)
-
- for s_n in self.get_neighbor(s):
- self.UpdateVertex(s_n)
-
- def UpdateVertex(self, s):
- """
- update the status and the current cost to come of state s.
- :param s: state s
- """
-
- if s != self.s_start:
-
- # Condition: cost of parent of s changed
- # Since we do not record the children of a state, we need to enumerate its neighbors
- self.rhs[s] = min(self.g[s_n] + self.cost(s_n, s)
- for s_n in self.get_neighbor(s))
-
- if s in self.U:
- self.U.pop(s)
-
- if self.g[s] != self.rhs[s]:
-
- # Condition: current cost to come is different to that of last time
- # state s should be added into OPEN set (set U)
- self.U[s] = self.CalculateKey(s)
-
- def TopKey(self):
- """
- :return: return the min key and its value.
- """
-
- s = min(self.U, key=self.U.get)
-
- return s, self.U[s]
-
- def CalculateKey(self, s):
-
- return [min(self.g[s], self.rhs[s]) + self.h(s),
- min(self.g[s], self.rhs[s])]
-
- def get_neighbor(self, s):
- """
- find neighbors of state s that not in obstacles.
- :param s: state
- :return: neighbors
- """
-
- s_list = set()
-
- for u in self.u_set:
- s_next = tuple([s[i] + u[i] for i in range(2)])
- if s_next not in self.obs:
- s_list.add(s_next)
-
- return s_list
-
- def h(self, s):
- """
- Calculate heuristic.
- :param s: current node (state)
- :return: heuristic function value
- """
-
- heuristic_type = self.heuristic_type # heuristic type
- goal = self.s_goal # goal node
-
- if heuristic_type == "manhattan":
- return abs(goal[0] - s[0]) + abs(goal[1] - s[1])
- else:
- return math.hypot(goal[0] - s[0], goal[1] - s[1])
-
- def cost(self, s_start, s_goal):
- """
- Calculate Cost for this motion
- :param s_start: starting node
- :param s_goal: end node
- :return: Cost for this motion
- :note: Cost function could be more complicate!
- """
-
- if self.is_collision(s_start, s_goal):
- return float("inf")
-
- return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
-
- def is_collision(self, s_start, s_end):
- if s_start in self.obs or s_end in self.obs:
- return True
-
- if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
- if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
- s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
- s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
- else:
- s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
- s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
-
- if s1 in self.obs or s2 in self.obs:
- return True
-
- return False
-
- def extract_path(self):
- """
- Extract the path based on the PARENT set.
- :return: The planning path
- """
-
- path = [self.s_goal]
- s = self.s_goal
-
- for k in range(100):
- g_list = {}
- for x in self.get_neighbor(s):
- if not self.is_collision(s, x):
- g_list[x] = self.g[x]
- s = min(g_list, key=g_list.get)
- path.append(s)
- if s == self.s_start:
- break
-
- return list(reversed(path))
-
- def plot_path(self, path):
- px = [x[0] for x in path]
- py = [x[1] for x in path]
- plt.plot(px, py, linewidth=2)
- plt.plot(self.s_start[0], self.s_start[1], "bs")
- plt.plot(self.s_goal[0], self.s_goal[1], "gs")
-
- def plot_visited(self, visited):
- color = ['gainsboro', 'lightgray', 'silver', 'darkgray',
- 'bisque', 'navajowhite', 'moccasin', 'wheat',
- 'powderblue', 'skyblue', 'lightskyblue', 'cornflowerblue']
-
- if self.count >= len(color) - 1:
- self.count = 0
-
- for x in visited:
- plt.plot(x[0], x[1], marker='s', color=color[self.count])
-
-
-def main():
- x_start = (5, 5)
- x_goal = (45, 25)
-
- lpastar = LPAStar(x_start, x_goal, "Euclidean")
- lpastar.run()
-
-
-if __name__ == '__main__':
- main()
diff --git a/src/Search_2D/LRTAstar.py b/src/Search_2D/LRTAstar.py
deleted file mode 100644
index 108903b..0000000
--- a/src/Search_2D/LRTAstar.py
+++ /dev/null
@@ -1,230 +0,0 @@
-"""
-LRTA_star 2D (Learning Real-time A*)
-@author: huiming zhou
-"""
-
-import os
-import sys
-import copy
-import math
-
-sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
- "/../../Search_based_Planning/")
-
-from Search_2D import queue, plotting, env
-
-
-class LrtAStarN:
- def __init__(self, s_start, s_goal, N, heuristic_type):
- self.s_start, self.s_goal = s_start, s_goal
- self.heuristic_type = heuristic_type
-
- self.Env = env.Env()
-
- self.u_set = self.Env.motions # feasible input set
- self.obs = self.Env.obs # position of obstacles
-
- self.N = N # number of expand nodes each iteration
- self.visited = [] # order of visited nodes in planning
- self.path = [] # path of each iteration
- self.h_table = {} # h_value table
-
- def init(self):
- """
- initialize the h_value of all nodes in the environment.
- it is a global table.
- """
-
- for i in range(self.Env.x_range):
- for j in range(self.Env.y_range):
- self.h_table[(i, j)] = self.h((i, j))
-
- def searching(self):
- self.init()
- s_start = self.s_start # initialize start node
-
- while True:
- OPEN, CLOSED = self.AStar(s_start, self.N) # OPEN, CLOSED sets in each iteration
-
- if OPEN == "FOUND": # reach the goal node
- self.path.append(CLOSED)
- break
-
- h_value = self.iteration(CLOSED) # h_value table of CLOSED nodes
-
- for x in h_value:
- self.h_table[x] = h_value[x]
-
- s_start, path_k = self.extract_path_in_CLOSE(s_start, h_value) # x_init -> expected node in OPEN set
- self.path.append(path_k)
-
- def extract_path_in_CLOSE(self, s_start, h_value):
- path = [s_start]
- s = s_start
-
- while True:
- h_list = {}
-
- for s_n in self.get_neighbor(s):
- if s_n in h_value:
- h_list[s_n] = h_value[s_n]
- else:
- h_list[s_n] = self.h_table[s_n]
-
- s_key = min(h_list, key=h_list.get) # move to the smallest node with min h_value
- path.append(s_key) # generate path
- s = s_key # use end of this iteration as the start of next
-
- if s_key not in h_value: # reach the expected node in OPEN set
- return s_key, path
-
- def iteration(self, CLOSED):
- h_value = {}
-
- for s in CLOSED:
- h_value[s] = float("inf") # initialize h_value of CLOSED nodes
-
- while True:
- h_value_rec = copy.deepcopy(h_value)
- for s in CLOSED:
- h_list = []
- for s_n in self.get_neighbor(s):
- if s_n not in CLOSED:
- h_list.append(self.cost(s, s_n) + self.h_table[s_n])
- else:
- h_list.append(self.cost(s, s_n) + h_value[s_n])
- h_value[s] = min(h_list) # update h_value of current node
-
- if h_value == h_value_rec: # h_value table converged
- return h_value
-
- def AStar(self, x_start, N):
- OPEN = queue.QueuePrior() # OPEN set
- OPEN.put(x_start, self.h(x_start))
- CLOSED = [] # CLOSED set
- g_table = {x_start: 0, self.s_goal: float("inf")} # Cost to come
- PARENT = {x_start: x_start} # relations
- count = 0 # counter
-
- while not OPEN.empty():
- count += 1
- s = OPEN.get()
- CLOSED.append(s)
-
- if s == self.s_goal: # reach the goal node
- self.visited.append(CLOSED)
- return "FOUND", self.extract_path(x_start, PARENT)
-
- for s_n in self.get_neighbor(s):
- if s_n not in CLOSED:
- new_cost = g_table[s] + self.cost(s, s_n)
- if s_n not in g_table:
- g_table[s_n] = float("inf")
- if new_cost < g_table[s_n]: # conditions for updating Cost
- g_table[s_n] = new_cost
- PARENT[s_n] = s
- OPEN.put(s_n, g_table[s_n] + self.h_table[s_n])
-
- if count == N: # expand needed CLOSED nodes
- break
-
- self.visited.append(CLOSED) # visited nodes in each iteration
-
- return OPEN, CLOSED
-
- def get_neighbor(self, s):
- """
- find neighbors of state s that not in obstacles.
- :param s: state
- :return: neighbors
- """
-
- s_list = []
-
- for u in self.u_set:
- s_next = tuple([s[i] + u[i] for i in range(2)])
- if s_next not in self.obs:
- s_list.append(s_next)
-
- return s_list
-
- def extract_path(self, x_start, parent):
- """
- Extract the path based on the relationship of nodes.
-
- :return: The planning path
- """
-
- path_back = [self.s_goal]
- x_current = self.s_goal
-
- while True:
- x_current = parent[x_current]
- path_back.append(x_current)
-
- if x_current == x_start:
- break
-
- return list(reversed(path_back))
-
- def h(self, s):
- """
- Calculate heuristic.
- :param s: current node (state)
- :return: heuristic function value
- """
-
- heuristic_type = self.heuristic_type # heuristic type
- goal = self.s_goal # goal node
-
- if heuristic_type == "manhattan":
- return abs(goal[0] - s[0]) + abs(goal[1] - s[1])
- else:
- return math.hypot(goal[0] - s[0], goal[1] - s[1])
-
- def cost(self, s_start, s_goal):
- """
- Calculate Cost for this motion
- :param s_start: starting node
- :param s_goal: end node
- :return: Cost for this motion
- :note: Cost function could be more complicate!
- """
-
- if self.is_collision(s_start, s_goal):
- return float("inf")
-
- return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
-
- def is_collision(self, s_start, s_end):
- if s_start in self.obs or s_end in self.obs:
- return True
-
- if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
- if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
- s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
- s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
- else:
- s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
- s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
-
- if s1 in self.obs or s2 in self.obs:
- return True
-
- return False
-
-
-def main():
- s_start = (10, 5)
- s_goal = (45, 25)
-
- lrta = LrtAStarN(s_start, s_goal, 250, "euclidean")
- plot = plotting.Plotting(s_start, s_goal)
-
- lrta.searching()
- plot.animation_lrta(lrta.path, lrta.visited,
- "Learning Real-time A* (LRTA*)")
-
-
-if __name__ == '__main__':
- main()
diff --git a/src/Search_2D/RTAAStar.py b/src/Search_2D/RTAAStar.py
deleted file mode 100644
index de0a615..0000000
--- a/src/Search_2D/RTAAStar.py
+++ /dev/null
@@ -1,237 +0,0 @@
-"""
-RTAAstar 2D (Real-time Adaptive A*)
-@author: huiming zhou
-"""
-
-import os
-import sys
-import copy
-import math
-
-sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
- "/../../Search_based_Planning/")
-
-from Search_2D import queue, plotting, env
-
-
-class RTAAStar:
- def __init__(self, s_start, s_goal, N, heuristic_type):
- self.s_start, self.s_goal = s_start, s_goal
- self.heuristic_type = heuristic_type
-
- self.Env = env.Env()
-
- self.u_set = self.Env.motions # feasible input set
- self.obs = self.Env.obs # position of obstacles
-
- self.N = N # number of expand nodes each iteration
- self.visited = [] # order of visited nodes in planning
- self.path = [] # path of each iteration
- self.h_table = {} # h_value table
-
- def init(self):
- """
- initialize the h_value of all nodes in the environment.
- it is a global table.
- """
-
- for i in range(self.Env.x_range):
- for j in range(self.Env.y_range):
- self.h_table[(i, j)] = self.h((i, j))
-
- def searching(self):
- self.init()
- s_start = self.s_start # initialize start node
-
- while True:
- OPEN, CLOSED, g_table, PARENT = \
- self.Astar(s_start, self.N)
-
- if OPEN == "FOUND": # reach the goal node
- self.path.append(CLOSED)
- break
-
- s_next, h_value = self.cal_h_value(OPEN, CLOSED, g_table, PARENT)
-
- for x in h_value:
- self.h_table[x] = h_value[x]
-
- s_start, path_k = self.extract_path_in_CLOSE(s_start, s_next, h_value)
- self.path.append(path_k)
-
- def cal_h_value(self, OPEN, CLOSED, g_table, PARENT):
- v_open = {}
- h_value = {}
- for (_, x) in OPEN.enumerate():
- v_open[x] = g_table[PARENT[x]] + 1 + self.h_table[x]
- s_open = min(v_open, key=v_open.get)
- f_min = v_open[s_open]
- for x in CLOSED:
- h_value[x] = f_min - g_table[x]
-
- return s_open, h_value
-
- def iteration(self, CLOSED):
- h_value = {}
-
- for s in CLOSED:
- h_value[s] = float("inf") # initialize h_value of CLOSED nodes
-
- while True:
- h_value_rec = copy.deepcopy(h_value)
- for s in CLOSED:
- h_list = []
- for s_n in self.get_neighbor(s):
- if s_n not in CLOSED:
- h_list.append(self.cost(s, s_n) + self.h_table[s_n])
- else:
- h_list.append(self.cost(s, s_n) + h_value[s_n])
- h_value[s] = min(h_list) # update h_value of current node
-
- if h_value == h_value_rec: # h_value table converged
- return h_value
-
- def Astar(self, x_start, N):
- OPEN = queue.QueuePrior() # OPEN set
- OPEN.put(x_start, self.h_table[x_start])
- CLOSED = [] # CLOSED set
- g_table = {x_start: 0, self.s_goal: float("inf")} # Cost to come
- PARENT = {x_start: x_start} # relations
- count = 0 # counter
-
- while not OPEN.empty():
- count += 1
- s = OPEN.get()
- CLOSED.append(s)
-
- if s == self.s_goal: # reach the goal node
- self.visited.append(CLOSED)
- return "FOUND", self.extract_path(x_start, PARENT), [], []
-
- for s_n in self.get_neighbor(s):
- if s_n not in CLOSED:
- new_cost = g_table[s] + self.cost(s, s_n)
- if s_n not in g_table:
- g_table[s_n] = float("inf")
- if new_cost < g_table[s_n]: # conditions for updating Cost
- g_table[s_n] = new_cost
- PARENT[s_n] = s
- OPEN.put(s_n, g_table[s_n] + self.h_table[s_n])
-
- if count == N: # expand needed CLOSED nodes
- break
-
- self.visited.append(CLOSED) # visited nodes in each iteration
-
- return OPEN, CLOSED, g_table, PARENT
-
- def get_neighbor(self, s):
- """
- find neighbors of state s that not in obstacles.
- :param s: state
- :return: neighbors
- """
-
- s_list = set()
-
- for u in self.u_set:
- s_next = tuple([s[i] + u[i] for i in range(2)])
- if s_next not in self.obs:
- s_list.add(s_next)
-
- return s_list
-
- def extract_path_in_CLOSE(self, s_end, s_start, h_value):
- path = [s_start]
- s = s_start
-
- while True:
- h_list = {}
- for s_n in self.get_neighbor(s):
- if s_n in h_value:
- h_list[s_n] = h_value[s_n]
- s_key = max(h_list, key=h_list.get) # move to the smallest node with min h_value
- path.append(s_key) # generate path
- s = s_key # use end of this iteration as the start of next
-
- if s_key == s_end: # reach the expected node in OPEN set
- return s_start, list(reversed(path))
-
- def extract_path(self, x_start, parent):
- """
- Extract the path based on the relationship of nodes.
- :return: The planning path
- """
-
- path = [self.s_goal]
- s = self.s_goal
-
- while True:
- s = parent[s]
- path.append(s)
- if s == x_start:
- break
-
- return list(reversed(path))
-
- def h(self, s):
- """
- Calculate heuristic.
- :param s: current node (state)
- :return: heuristic function value
- """
-
- heuristic_type = self.heuristic_type # heuristic type
- goal = self.s_goal # goal node
-
- if heuristic_type == "manhattan":
- return abs(goal[0] - s[0]) + abs(goal[1] - s[1])
- else:
- return math.hypot(goal[0] - s[0], goal[1] - s[1])
-
- def cost(self, s_start, s_goal):
- """
- Calculate Cost for this motion
- :param s_start: starting node
- :param s_goal: end node
- :return: Cost for this motion
- :note: Cost function could be more complicate!
- """
-
- if self.is_collision(s_start, s_goal):
- return float("inf")
-
- return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
-
- def is_collision(self, s_start, s_end):
- if s_start in self.obs or s_end in self.obs:
- return True
-
- if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
- if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
- s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
- s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
- else:
- s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
- s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
-
- if s1 in self.obs or s2 in self.obs:
- return True
-
- return False
-
-
-def main():
- s_start = (10, 5)
- s_goal = (45, 25)
-
- rtaa = RTAAStar(s_start, s_goal, 240, "euclidean")
- plot = plotting.Plotting(s_start, s_goal)
-
- rtaa.searching()
- plot.animation_lrta(rtaa.path, rtaa.visited,
- "Real-time Adaptive A* (RTAA*)")
-
-
-if __name__ == '__main__':
- main()
diff --git a/src/Search_2D/__pycache__/Astar.cpython-38.pyc b/src/Search_2D/__pycache__/Astar.cpython-38.pyc
deleted file mode 100644
index 7d90ba3..0000000
Binary files a/src/Search_2D/__pycache__/Astar.cpython-38.pyc and /dev/null differ
diff --git a/src/Search_2D/__pycache__/env.cpython-37.pyc b/src/Search_2D/__pycache__/env.cpython-37.pyc
deleted file mode 100644
index 945aa4d..0000000
Binary files a/src/Search_2D/__pycache__/env.cpython-37.pyc and /dev/null differ
diff --git a/src/Search_2D/__pycache__/env.cpython-38.pyc b/src/Search_2D/__pycache__/env.cpython-38.pyc
deleted file mode 100644
index e45c75b..0000000
Binary files a/src/Search_2D/__pycache__/env.cpython-38.pyc and /dev/null differ
diff --git a/src/Search_2D/__pycache__/env.cpython-39.pyc b/src/Search_2D/__pycache__/env.cpython-39.pyc
deleted file mode 100644
index 4776345..0000000
Binary files a/src/Search_2D/__pycache__/env.cpython-39.pyc and /dev/null differ
diff --git a/src/Search_2D/__pycache__/plotting.cpython-37.pyc b/src/Search_2D/__pycache__/plotting.cpython-37.pyc
deleted file mode 100644
index 8a41db2..0000000
Binary files a/src/Search_2D/__pycache__/plotting.cpython-37.pyc and /dev/null differ
diff --git a/src/Search_2D/__pycache__/plotting.cpython-38.pyc b/src/Search_2D/__pycache__/plotting.cpython-38.pyc
deleted file mode 100644
index 5e8cff3..0000000
Binary files a/src/Search_2D/__pycache__/plotting.cpython-38.pyc and /dev/null differ
diff --git a/src/Search_2D/__pycache__/plotting.cpython-39.pyc b/src/Search_2D/__pycache__/plotting.cpython-39.pyc
deleted file mode 100644
index c381ea6..0000000
Binary files a/src/Search_2D/__pycache__/plotting.cpython-39.pyc and /dev/null differ
diff --git a/src/Search_2D/__pycache__/queue.cpython-37.pyc b/src/Search_2D/__pycache__/queue.cpython-37.pyc
deleted file mode 100644
index 6c5f684..0000000
Binary files a/src/Search_2D/__pycache__/queue.cpython-37.pyc and /dev/null differ
diff --git a/src/Search_2D/__pycache__/queue.cpython-38.pyc b/src/Search_2D/__pycache__/queue.cpython-38.pyc
deleted file mode 100644
index 69c46c6..0000000
Binary files a/src/Search_2D/__pycache__/queue.cpython-38.pyc and /dev/null differ
diff --git a/src/Search_2D/bfs.py b/src/Search_2D/bfs.py
deleted file mode 100644
index 881e7ff..0000000
--- a/src/Search_2D/bfs.py
+++ /dev/null
@@ -1,69 +0,0 @@
-"""
-Breadth-first Searching_2D (BFS)
-@author: huiming zhou
-"""
-
-import os
-import sys
-from collections import deque
-
-sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
- "/../../Search_based_Planning/")
-
-from Search_2D import plotting, env
-from Search_2D.Astar import AStar
-import math
-import heapq
-
-class BFS(AStar):
- """BFS add the new visited node in the end of the openset
- """
- def searching(self):
- """
- Breadth-first Searching.
- :return: path, visited order
- """
-
- self.PARENT[self.s_start] = self.s_start
- self.g[self.s_start] = 0
- self.g[self.s_goal] = math.inf
- heapq.heappush(self.OPEN,
- (0, self.s_start))
-
- while self.OPEN:
- _, s = heapq.heappop(self.OPEN)
- self.CLOSED.append(s)
-
- if s == self.s_goal:
- break
-
- for s_n in self.get_neighbor(s):
- new_cost = self.g[s] + self.cost(s, s_n)
-
- if s_n not in self.g:
- self.g[s_n] = math.inf
-
- if new_cost < self.g[s_n]: # conditions for updating Cost
- self.g[s_n] = new_cost
- self.PARENT[s_n] = s
-
- # bfs, add new node to the end of the openset
- prior = self.OPEN[-1][0]+1 if len(self.OPEN)>0 else 0
- heapq.heappush(self.OPEN, (prior, s_n))
-
- return self.extract_path(self.PARENT), self.CLOSED
-
-
-def main():
- s_start = (5, 5)
- s_goal = (45, 25)
-
- bfs = BFS(s_start, s_goal, 'None')
- plot = plotting.Plotting(s_start, s_goal)
-
- path, visited = bfs.searching()
- plot.animation(path, visited, "Breadth-first Searching (BFS)")
-
-
-if __name__ == '__main__':
- main()
diff --git a/src/Search_2D/dfs.py b/src/Search_2D/dfs.py
deleted file mode 100644
index 3b30b03..0000000
--- a/src/Search_2D/dfs.py
+++ /dev/null
@@ -1,65 +0,0 @@
-
-import os
-import sys
-import math
-import heapq
-
-sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
- "/../../Search_based_Planning/")
-
-from Search_2D import plotting, env
-from Search_2D.Astar import AStar
-
-class DFS(AStar):
- """DFS add the new visited node in the front of the openset
- """
- def searching(self):
- """
- Breadth-first Searching.
- :return: path, visited order
- """
-
- self.PARENT[self.s_start] = self.s_start
- self.g[self.s_start] = 0
- self.g[self.s_goal] = math.inf
- heapq.heappush(self.OPEN,
- (0, self.s_start))
-
- while self.OPEN:
- _, s = heapq.heappop(self.OPEN)
- self.CLOSED.append(s)
-
- if s == self.s_goal:
- break
-
- for s_n in self.get_neighbor(s):
- new_cost = self.g[s] + self.cost(s, s_n)
-
- if s_n not in self.g:
- self.g[s_n] = math.inf
-
- if new_cost < self.g[s_n]: # conditions for updating Cost
- self.g[s_n] = new_cost
- self.PARENT[s_n] = s
-
- # dfs, add new node to the front of the openset
- prior = self.OPEN[0][0]-1 if len(self.OPEN)>0 else 0
- heapq.heappush(self.OPEN, (prior, s_n))
-
- return self.extract_path(self.PARENT), self.CLOSED
-
-
-def main():
- s_start = (5, 5)
- s_goal = (45, 25)
-
- dfs = DFS(s_start, s_goal, 'None')
- plot = plotting.Plotting(s_start, s_goal)
-
- path, visited = dfs.searching()
- visited = list(dict.fromkeys(visited))
- plot.animation(path, visited, "Depth-first Searching (DFS)") # animation
-
-
-if __name__ == '__main__':
- main()
diff --git a/src/Search_2D/env.py b/src/Search_2D/env.py
deleted file mode 100644
index 9523c98..0000000
--- a/src/Search_2D/env.py
+++ /dev/null
@@ -1,52 +0,0 @@
-"""
-Env 2D
-@author: huiming zhou
-"""
-
-
-class Env:
- def __init__(self):
- self.x_range = 51 # size of background
- self.y_range = 31
- self.motions = [(-1, 0), (-1, 1), (0, 1), (1, 1),
- (1, 0), (1, -1), (0, -1), (-1, -1)]
- self.obs = self.obs_map()
-
- def update_obs(self, obs):
- self.obs = obs
-
- def obs_map(self):
- """
- Initialize obstacles' positions
- :return: map of obstacles
- """
-
- x = self.x_range #51
- y = self.y_range #31
- obs = set()
- #画上下边框
- for i in range(x):
- obs.add((i, 0))
- for i in range(x):
- obs.add((i, y - 1))
- #画左右边框
- for i in range(y):
- obs.add((0, i))
- for i in range(y):
- obs.add((x - 1, i))
-
- for i in range(2, 21):
- obs.add((i, 15))
- for i in range(15):
- obs.add((20, i))
-
- for i in range(15, 30):
- obs.add((30, i))
- for i in range(16):
- obs.add((40, i))
-
- return obs
-
-# if __name__ == '__main__':
-# a = Env()
-# print(a.obs)
\ No newline at end of file
diff --git a/src/Search_2D/plotting.py b/src/Search_2D/plotting.py
deleted file mode 100644
index 1cf98a3..0000000
--- a/src/Search_2D/plotting.py
+++ /dev/null
@@ -1,165 +0,0 @@
-"""
-Plot tools 2D
-@author: huiming zhou
-"""
-
-import os
-import sys
-import matplotlib.pyplot as plt
-
-sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
- "/../../Search_based_Planning/")
-
-from Search_2D import env
-
-
-class Plotting:
- def __init__(self, xI, xG):
- self.xI, self.xG = xI, xG
- self.env = env.Env()
- self.obs = self.env.obs_map()
-
- def update_obs(self, obs):
- self.obs = obs
-
- def animation(self, path, visited, name):
- self.plot_grid(name)
- self.plot_visited(visited)
- self.plot_path(path)
- plt.show()
-
- def animation_lrta(self, path, visited, name):
- self.plot_grid(name)
- cl = self.color_list_2()
- path_combine = []
-
- for k in range(len(path)):
- self.plot_visited(visited[k], cl[k])
- plt.pause(0.2)
- self.plot_path(path[k])
- path_combine += path[k]
- plt.pause(0.2)
- if self.xI in path_combine:
- path_combine.remove(self.xI)
- self.plot_path(path_combine)
- plt.show()
-
- def animation_ara_star(self, path, visited, name):
- self.plot_grid(name)
- cl_v, cl_p = self.color_list()
-
- for k in range(len(path)):
- self.plot_visited(visited[k], cl_v[k])
- self.plot_path(path[k], cl_p[k], True)
- plt.pause(0.5)
-
- plt.show()
-
- def animation_bi_astar(self, path, v_fore, v_back, name):
- self.plot_grid(name)
- self.plot_visited_bi(v_fore, v_back)
- self.plot_path(path)
- plt.show()
-
- def plot_grid(self, name):
- obs_x = [x[0] for x in self.obs]
- obs_y = [x[1] for x in self.obs]
-
- plt.plot(self.xI[0], self.xI[1], "bs")
- plt.plot(self.xG[0], self.xG[1], "gs")
- plt.plot(obs_x, obs_y, "sk")
- plt.title(name)
- plt.axis("equal")
-
- def plot_visited(self, visited, cl='gray'):
- if self.xI in visited:
- visited.remove(self.xI)
-
- if self.xG in visited:
- visited.remove(self.xG)
-
- count = 0
-
- for x in visited:
- count += 1
- plt.plot(x[0], x[1], color=cl, marker='o')
- plt.gcf().canvas.mpl_connect('key_release_event',
- lambda event: [exit(0) if event.key == 'escape' else None])
-
- if count < len(visited) / 3:
- length = 20
- elif count < len(visited) * 2 / 3:
- length = 30
- else:
- length = 40
- #
- # length = 15
-
- if count % length == 0:
- plt.pause(0.001)
- plt.pause(0.01)
-
- def plot_path(self, path, cl='r', flag=False):
- path_x = [path[i][0] for i in range(len(path))]
- path_y = [path[i][1] for i in range(len(path))]
-
- if not flag:
- plt.plot(path_x, path_y, linewidth='3', color='r')
- else:
- plt.plot(path_x, path_y, linewidth='3', color=cl)
-
- plt.plot(self.xI[0], self.xI[1], "bs")
- plt.plot(self.xG[0], self.xG[1], "gs")
-
- plt.pause(0.01)
-
- def plot_visited_bi(self, v_fore, v_back):
- if self.xI in v_fore:
- v_fore.remove(self.xI)
-
- if self.xG in v_back:
- v_back.remove(self.xG)
-
- len_fore, len_back = len(v_fore), len(v_back)
-
- for k in range(max(len_fore, len_back)):
- if k < len_fore:
- plt.plot(v_fore[k][0], v_fore[k][1], linewidth='3', color='gray', marker='o')
- if k < len_back:
- plt.plot(v_back[k][0], v_back[k][1], linewidth='3', color='cornflowerblue', marker='o')
-
- plt.gcf().canvas.mpl_connect('key_release_event',
- lambda event: [exit(0) if event.key == 'escape' else None])
-
- if k % 10 == 0:
- plt.pause(0.001)
- plt.pause(0.01)
-
- @staticmethod
- def color_list():
- cl_v = ['silver',
- 'wheat',
- 'lightskyblue',
- 'royalblue',
- 'slategray']
- cl_p = ['gray',
- 'orange',
- 'deepskyblue',
- 'red',
- 'm']
- return cl_v, cl_p
-
- @staticmethod
- def color_list_2():
- cl = ['silver',
- 'steelblue',
- 'dimgray',
- 'cornflowerblue',
- 'dodgerblue',
- 'royalblue',
- 'plum',
- 'mediumslateblue',
- 'mediumpurple',
- 'blueviolet',
- ]
- return cl
diff --git a/src/Search_2D/queue.py b/src/Search_2D/queue.py
deleted file mode 100644
index 51703ae..0000000
--- a/src/Search_2D/queue.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import collections
-import heapq
-
-
-class QueueFIFO:
- """
- Class: QueueFIFO
- Description: QueueFIFO is designed for First-in-First-out rule.
- """
-
- def __init__(self):
- self.queue = collections.deque()
-
- def empty(self):
- return len(self.queue) == 0
-
- def put(self, node):
- self.queue.append(node) # enter from back
-
- def get(self):
- return self.queue.popleft() # leave from front
-
-
-class QueueLIFO:
- """
- Class: QueueLIFO
- Description: QueueLIFO is designed for Last-in-First-out rule.
- """
-
- def __init__(self):
- self.queue = collections.deque()
-
- def empty(self):
- return len(self.queue) == 0
-
- def put(self, node):
- self.queue.append(node) # enter from back
-
- def get(self):
- return self.queue.pop() # leave from back
-
-
-class QueuePrior:
- """
- Class: QueuePrior
- Description: QueuePrior reorders elements using value [priority]
- """
-
- def __init__(self):
- self.queue = []
-
- def empty(self):
- return len(self.queue) == 0
-
- def put(self, item, priority):
- heapq.heappush(self.queue, (priority, item)) # reorder s using priority
-
- def get(self):
- return heapq.heappop(self.queue)[1] # pop out the smallest item
-
- def enumerate(self):
- return self.queue