// NOTE: Captured from a repository web view of
// Conception/drake-master/solvers/solver_options.cc (253 lines, 8.4 KiB);
// the surrounding page UI text was removed so that this file compiles.

#include "drake/solvers/solver_options.h"
#include <map>
#include <sstream>
#include <utility>
#include <fmt/format.h>
#include "drake/common/never_destroyed.h"
namespace drake {
namespace solvers {
// A shorthand for our member field type, for options typed as T's, as in
// MapMap[SolverId][string] => T. This is the shape of the solver-specific
// tables solver_options_double_ / _int_ / _str_ (see GetOptionsHelper and
// MergeHelper below, which operate on it).
template <typename T>
using MapMap = std::unordered_map<SolverId, std::unordered_map<std::string, T>>;
// A shorthand for our member field type: the table of solver-agnostic
// options, keyed by CommonSolverOption (the type of common_solver_options_).
using CommonMap =
std::unordered_map<CommonSolverOption, SolverOptions::OptionValue>;
// Records a double-valued option for `solver_id`, creating that solver's
// double table on first use and overwriting any previous value for the key.
void SolverOptions::SetOption(const SolverId& solver_id,
                              const std::string& solver_option,
                              double option_value) {
  auto& double_options = solver_options_double_[solver_id];
  double_options[solver_option] = option_value;
}
// Records an int-valued option for `solver_id`, creating that solver's int
// table on first use and overwriting any previous value for the key.
void SolverOptions::SetOption(const SolverId& solver_id,
                              const std::string& solver_option,
                              int option_value) {
  auto& int_options = solver_options_int_[solver_id];
  int_options[solver_option] = option_value;
}
// Records a string-valued option for `solver_id`, creating that solver's
// string table on first use and overwriting any previous value for the key.
void SolverOptions::SetOption(const SolverId& solver_id,
                              const std::string& solver_option,
                              const std::string& option_value) {
  auto& string_options = solver_options_str_[solver_id];
  string_options[solver_option] = option_value;
}
// Stores a solver-agnostic option after validating that `value` holds the
// alternative that `key` requires:
//  - kPrintToConsole: an int that is exactly 0 or 1.
//  - kPrintFileName: a std::string.
// Throws std::runtime_error on a type or range mismatch; nothing is stored in
// that case. Any previously stored value for `key` is overwritten.
void SolverOptions::SetOption(CommonSolverOption key, OptionValue value) {
  switch (key) {
    case CommonSolverOption::kPrintToConsole: {
      const bool is_int = std::holds_alternative<int>(value);
      if (!is_int) {
        throw std::runtime_error(fmt::format(
            "SolverOptions::SetOption support {} only with int value.", key));
      }
      // get_print_to_console() reads this back as a bool, so only 0 and 1 are
      // accepted.
      const int flag = std::get<int>(value);
      const bool is_binary = (flag == 0) || (flag == 1);
      if (!is_binary) {
        throw std::runtime_error(
            fmt::format("{} expects value either 0 or 1", key));
      }
      common_solver_options_[key] = std::move(value);
      return;
    }
    case CommonSolverOption::kPrintFileName: {
      const bool is_string = std::holds_alternative<std::string>(value);
      if (!is_string) {
        throw std::runtime_error(fmt::format(
            "SolverOptions::SetOption support {} only with std::string value.",
            key));
      }
      common_solver_options_[key] = std::move(value);
      return;
    }
  }
  // Deliberately no `default:` above so the compiler flags unhandled enumerators.
  DRAKE_UNREACHABLE();
}
namespace {
// Returns the option table stored for `solver_id` in `options`. If the solver
// has no entry, returns a reference to a function-local, never-destroyed empty
// table instead; that fallback remains valid for the life of the program.
template <typename T>
const std::unordered_map<std::string, T>& GetOptionsHelper(
    const SolverId& solver_id, const MapMap<T>& options) {
  static never_destroyed<std::unordered_map<std::string, T>> empty;
  const auto found = options.find(solver_id);
  if (found == options.end()) {
    return empty.access();
  }
  return found->second;
}
}  // namespace
const std::unordered_map<std::string, double>& SolverOptions::GetOptionsDouble(
const SolverId& solver_id) const {
return GetOptionsHelper(solver_id, solver_options_double_);
}
const std::unordered_map<std::string, int>& SolverOptions::GetOptionsInt(
const SolverId& solver_id) const {
return GetOptionsHelper(solver_id, solver_options_int_);
}
const std::unordered_map<std::string, std::string>&
SolverOptions::GetOptionsStr(const SolverId& solver_id) const {
return GetOptionsHelper(solver_id, solver_options_str_);
}
// Returns the kPrintFileName common option, or an empty string if unset.
std::string SolverOptions::get_print_file_name() const {
  // SetOption() guarantees the stored alternative is a std::string, so the
  // std::get below cannot throw.
  const auto found =
      common_solver_options_.find(CommonSolverOption::kPrintFileName);
  if (found == common_solver_options_.end()) {
    return std::string{};
  }
  return std::get<std::string>(found->second);
}
// Returns true iff the kPrintToConsole common option is set to a nonzero
// value; returns false when the option is unset.
bool SolverOptions::get_print_to_console() const {
  // SetOption() guarantees the stored alternative is an int in {0, 1}.
  const auto found =
      common_solver_options_.find(CommonSolverOption::kPrintToConsole);
  return found != common_solver_options_.end() &&
         std::get<int>(found->second) != 0;
}
std::unordered_set<SolverId> SolverOptions::GetSolverIds() const {
std::unordered_set<SolverId> result;
for (const auto& pair : solver_options_double_) {
result.insert(pair.first);
}
for (const auto& pair : solver_options_int_) {
result.insert(pair.first);
}
for (const auto& pair : solver_options_str_) {
result.insert(pair.first);
}
return result;
}
namespace {
template <typename T>
void MergeHelper(const MapMap<T>& other, MapMap<T>* self) {
for (const auto& other_id_keyvals : other) {
const SolverId& id = other_id_keyvals.first;
std::unordered_map<std::string, T>& self_keyvals = (*self)[id];
for (const auto& other_keyval : other_id_keyvals.second) {
// This is a no-op when the key already exists.
self_keyvals.insert(other_keyval);
}
}
}
void MergeHelper(const CommonMap& other, CommonMap* self) {
for (const auto& other_keyval : other) {
// This is a no-op when the key already exists.
self->insert(other_keyval);
}
}
} // namespace
void SolverOptions::Merge(const SolverOptions& other) {
MergeHelper(other.solver_options_double_, &solver_options_double_);
MergeHelper(other.solver_options_int_, &solver_options_int_);
MergeHelper(other.solver_options_str_, &solver_options_str_);
MergeHelper(other.common_solver_options_, &common_solver_options_);
}
// Two SolverOptions are equal when all four option tables compare equal.
bool SolverOptions::operator==(const SolverOptions& other) const {
  if (solver_options_double_ != other.solver_options_double_) {
    return false;
  }
  if (solver_options_int_ != other.solver_options_int_) {
    return false;
  }
  if (solver_options_str_ != other.solver_options_str_) {
    return false;
  }
  return common_solver_options_ == other.common_solver_options_;
}
// Defined as the negation of operator== so the two can never disagree.
bool SolverOptions::operator!=(const SolverOptions& other) const {
  return !operator==(other);
}
namespace {
// For every option in `keyvals`, writes the entry
//   "<solver name>:<option key>" -> "<formatted value>"
// into `*pairs`, replacing any existing entry with the same label.
template <typename T>
void Summarize(const SolverId& id,
               const std::unordered_map<std::string, T>& keyvals,
               std::map<std::string, std::string>* pairs) {
  std::map<std::string, std::string>& summary = *pairs;
  for (const auto& [option_key, option_value] : keyvals) {
    const std::string label = fmt::format("{}:{}", id.name(), option_key);
    summary[label] = fmt::format("{}", option_value);
  }
}
}  // namespace
// Streams a human-readable summary of `x`:
//   "{SolverOptions, <label>=<value>, <label>=<value>, ...}"
// where solver-specific labels are "<solver name>:<option key>" and common
// option labels are "CommonSolverOption::<key>". Entries are emitted in sorted
// label order so the output is deterministic despite the unordered storage.
// Prints "{SolverOptions empty}" only when no option of any kind is set.
std::ostream& operator<<(std::ostream& os, const SolverOptions& x) {
  os << "{SolverOptions";
  const auto ids = x.GetSolverIds();
  const auto& common_options = x.common_solver_options();
  // GetSolverIds() ignores common options, so both must be checked; otherwise
  // an object holding only common options would be reported as "empty".
  if (ids.empty() && common_options.empty()) {
    os << " empty";
  } else {
    // Map keyed on the label so that our output is deterministic.
    std::map<std::string, std::string> pairs;
    for (const auto& id : ids) {
      Summarize(id, x.GetOptionsDouble(id), &pairs);
      Summarize(id, x.GetOptionsInt(id), &pairs);
      Summarize(id, x.GetOptionsStr(id), &pairs);
    }
    for (const auto& keyval : common_options) {
      const CommonSolverOption& key = keyval.first;
      // Format whichever alternative the OptionValue currently holds.
      std::visit(
          [key, &pairs](const auto& val_x) {
            pairs[fmt::format("CommonSolverOption::{}", key)] =
                fmt::format("{}", val_x);
          },
          keyval.second);
    }
    for (const auto& pair : pairs) {
      os << ", " << pair.first << "=" << pair.second;
    }
  }
  os << "}";
  return os;
}
// Returns exactly what operator<< would stream for `x`, as a string. Reusing
// the stream operator keeps the two representations from diverging.
std::string to_string(const SolverOptions& x) {
  std::ostringstream out;
  out << x;
  std::string text = out.str();
  return text;
}
namespace {
// Verifies that every key of `key_vals` appears in `allowable_keys`. On the
// first disallowed key encountered (in `key_vals` iteration order), throws
// std::invalid_argument naming that key and `solver_name`.
template <typename T>
void CheckOptionKeysForSolverHelper(
    const std::unordered_map<std::string, T>& key_vals,
    const std::unordered_set<std::string>& allowable_keys,
    const std::string& solver_name) {
  for (const auto& entry : key_vals) {
    const std::string& key = entry.first;
    const bool is_allowed = allowable_keys.find(key) != allowable_keys.end();
    if (is_allowed) {
      continue;
    }
    throw std::invalid_argument(key +
                                " is not allowed in the SolverOptions for " +
                                solver_name + ".");
  }
}
}  // namespace
// Throws std::invalid_argument if `solver_id` has any option whose key is not
// listed in the allow-list for its value type. The double table is checked
// first, then int, then string, so an error from an earlier table wins.
void SolverOptions::CheckOptionKeysForSolver(
    const SolverId& solver_id,
    const std::unordered_set<std::string>& double_keys,
    const std::unordered_set<std::string>& int_keys,
    const std::unordered_set<std::string>& str_keys) const {
  const std::string& solver_name = solver_id.name();
  CheckOptionKeysForSolverHelper(GetOptionsDouble(solver_id), double_keys,
                                 solver_name);
  CheckOptionKeysForSolverHelper(GetOptionsInt(solver_id), int_keys,
                                 solver_name);
  CheckOptionKeysForSolverHelper(GetOptionsStr(solver_id), str_keys,
                                 solver_name);
}
} // namespace solvers
} // namespace drake