forked from pz4kybsvg/Conception
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
986 lines
41 KiB
986 lines
41 KiB
2 years ago
|
# -*- coding: utf-8 -*-
|
||
|
|
||
|
import copy
|
||
|
from types import SimpleNamespace
|
||
|
import unittest
|
||
|
import warnings
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
from pydrake.autodiffutils import AutoDiffXd
|
||
|
from pydrake.common.test_utilities.deprecation import catch_drake_warnings
|
||
|
from pydrake.common.value import AbstractValue
|
||
|
from pydrake.symbolic import Expression
|
||
|
from pydrake.systems.analysis import (
|
||
|
Simulator,
|
||
|
)
|
||
|
from pydrake.systems.framework import (
|
||
|
AbstractParameterIndex,
|
||
|
AbstractStateIndex,
|
||
|
BasicVector, BasicVector_,
|
||
|
CacheEntry,
|
||
|
CacheEntryValue,
|
||
|
CacheIndex,
|
||
|
Context,
|
||
|
ContinuousStateIndex,
|
||
|
DependencyTicket,
|
||
|
Diagram,
|
||
|
DiagramBuilder,
|
||
|
DiscreteStateIndex,
|
||
|
DiscreteValues,
|
||
|
EventStatus,
|
||
|
InputPortIndex,
|
||
|
LeafSystem, LeafSystem_,
|
||
|
NumericParameterIndex,
|
||
|
PortDataType,
|
||
|
PublishEvent,
|
||
|
State,
|
||
|
System,
|
||
|
TriggerType,
|
||
|
UnrestrictedUpdateEvent,
|
||
|
ValueProducer,
|
||
|
VectorSystem,
|
||
|
WitnessFunctionDirection,
|
||
|
kUseDefaultName,
|
||
|
)
|
||
|
from pydrake.systems.primitives import (
|
||
|
Adder,
|
||
|
ZeroOrderHold,
|
||
|
)
|
||
|
|
||
|
from pydrake.systems.test.test_util import (
|
||
|
call_leaf_system_overrides,
|
||
|
call_vector_system_overrides,
|
||
|
)
|
||
|
|
||
|
from pydrake.common.test_utilities import numpy_compare
|
||
|
|
||
|
|
||
|
def noop(*args, **kwargs):
    """Accept any arguments and do nothing.

    Stands in for a callback that an interface requires but that is not
    useful for testing.
    """
    return None
|
||
|
|
||
|
|
||
|
class CustomAdder(LeafSystem):
    """Reimplements `Adder` in Python on top of `LeafSystem`.

    Declares `num_inputs` vector input ports of width `size` (named
    "input0", "input1", ...) and one vector output port named "sum".
    """

    def __init__(self, num_inputs, size):
        LeafSystem.__init__(self)
        for port_number in range(num_inputs):
            port_name = "input{}".format(port_number)
            self.DeclareVectorInputPort(port_name, size)
        self.DeclareVectorOutputPort("sum", size, self._calc_sum)

    def _calc_sum(self, context, sum_data):
        # NOTE: This will NOT work if the scalar type is AutoDiff or symbolic,
        # since they are not stored densely.
        # `total` is the mutable value obtained from the output vector; it is
        # zeroed and then accumulated into in place (`[:] =` / `+=`).
        total = sum_data.get_mutable_value()
        total[:] = 0
        for port_index in range(context.num_input_ports()):
            port_vector = self.EvalVectorInput(
                context=context, port_index=port_index)
            total += port_vector.get_value()
|
||
|
|
||
|
|
||
|
# TODO(eric.cousineau): Make this class work with custom scalar types once
|
||
|
# referencing with custom dtypes lands.
|
||
|
# WARNING: At present, dtype=object matrices are NOT well supported, and may
|
||
|
# produce unexpected results (e.g. references not actually being respected).
|
||
|
|
||
|
|
||
|
class CustomVectorSystem(VectorSystem):
    """`VectorSystem` subclass that records which overrides were invoked.

    Dimensions: 1 input, 2 states, 3 outputs. The names of the overrides that
    ran are appended, in call order, to `self.has_called`.
    """

    def __init__(self, is_discrete):
        # One input, three outputs.
        VectorSystem.__init__(self, 1, 3)
        self._is_discrete = is_discrete
        # VectorSystem only supports pure Continuous or pure Discrete state;
        # either way two state variables are declared.
        declare_state = (
            self.DeclareDiscreteState if self._is_discrete
            else self.DeclareContinuousState)
        declare_state(2)
        # Record calls for testing.
        self.has_called = []

    def DoCalcVectorOutput(self, context, u, x, y):
        # Output is the input followed by the state.
        y[:] = np.hstack((u, x))
        self.has_called.append("output")

    def DoCalcVectorTimeDerivatives(self, context, u, x, x_dot):
        x_dot[:] = x + u
        self.has_called.append("continuous")

    def DoCalcVectorDiscreteVariableUpdates(self, context, u, x, x_n):
        x_n[:] = x + 2 * u
        self.has_called.append("discrete")
|
||
|
|
||
|
|
||
|
# Wraps `Adder`.
|
||
|
class CustomDiagram(Diagram):
    """Diagram wrapping a single `Adder`, exporting all of its ports."""

    # N.B. The CustomDiagram is used to unit test the DiagramBuilder.BuildInto
    # method. For pydrake users, this is not a good example. The best way in
    # pydrake to create a Diagram is DiagramBuilder.Build (as seen in the test
    # case named test_adder_simulation).

    def __init__(self, num_inputs, size):
        Diagram.__init__(self)
        wiring = DiagramBuilder()
        wrapped_adder = Adder(num_inputs, size)
        wiring.AddSystem(wrapped_adder)
        # Export the single output first, then every input in port order.
        wiring.ExportOutput(wrapped_adder.get_output_port(0))
        for port_index in range(num_inputs):
            wiring.ExportInput(wrapped_adder.get_input_port(port_index))
        wiring.BuildInto(self)
|
||
|
|
||
|
|
||
|
class TestCustom(unittest.TestCase):
|
||
|
def _create_adder_system(self):
    """Return a new `CustomAdder` constructed with arguments (2, 3)."""
    return CustomAdder(2, 3)
|
||
|
|
||
|
def _fix_adder_inputs(self, system, context):
    """Fix input port 0 to [1, 2, 3] and input port 1 to [4, 5, 6].

    Asserts first that `context` reports exactly two input ports.
    """
    self.assertEqual(context.num_input_ports(), 2)
    fixed_values = ([1, 2, 3], [4, 5, 6])
    for port_index, port_value in enumerate(fixed_values):
        system.get_input_port(port_index).FixValue(context, port_value)
|
||
|
|
||
|
def test_diagram_adder(self):
    # Checks the reported type name and the port counts / sizes of a
    # `CustomDiagram` built with two inputs of width three.
    diagram = CustomDiagram(2, 3)
    expected_type = f"{__name__}.CustomDiagram"
    self.assertEqual(diagram.GetSystemType(), expected_type)

    self.assertEqual(diagram.num_input_ports(), 2)
    first_input = diagram.get_input_port(0)
    self.assertEqual(first_input.size(), 3)

    self.assertEqual(diagram.num_output_ports(), 1)
    only_output = diagram.get_output_port(0)
    self.assertEqual(only_output.size(), 3)
|
||
|
|
||
|
def test_adder_execution(self):
    # Evaluates the custom adder once, directly through CalcOutput, and
    # checks the element-wise sum of the two fixed inputs.
    adder = self._create_adder_system()
    self.assertEqual(adder.GetSystemType(), f"{__name__}.CustomAdder")

    adder_context = adder.CreateDefaultContext()
    self.assertEqual(adder_context.num_output_ports(), 1)
    self._fix_adder_inputs(adder, adder_context)

    adder_output = adder.AllocateOutput()
    self.assertEqual(adder_output.num_ports(), 1)
    adder.CalcOutput(adder_context, adder_output)

    summed = adder_output.get_vector_data(0).get_value()
    self.assertTrue(np.allclose([5, 7, 9], summed))
|
||
|
|
||
|
def test_adder_simulation(self):
    # Simulates the custom adder feeding a zero-order hold, then reads the
    # sum back out of the hold's discrete state.
    builder = DiagramBuilder()
    custom_adder = builder.AddSystem(self._create_adder_system())
    custom_adder.set_name("custom_adder")
    # Add ZOH so we can easily extract state.
    hold = builder.AddSystem(ZeroOrderHold(0.1, 3))
    hold.set_name("zoh")

    builder.ExportInput(custom_adder.get_input_port(0))
    builder.ExportInput(custom_adder.get_input_port(1))
    builder.Connect(
        custom_adder.get_output_port(0), hold.get_input_port(0))
    diagram = builder.Build()
    diagram_context = diagram.CreateDefaultContext()
    self._fix_adder_inputs(diagram, diagram_context)

    simulator = Simulator(diagram, diagram_context)
    simulator.Initialize()
    simulator.AdvanceTo(1)

    # Ensure that we have the outputs we want.
    hold_context = diagram.GetMutableSubsystemContext(
        hold, diagram_context)
    held_value = hold_context.get_discrete_state_vector().get_value()
    self.assertTrue(np.allclose([5, 7, 9], held_value))
|
||
|
|
||
|
def test_leaf_system_well_known_tickets(self):
    # Every listed `LeafSystem` attribute must be callable with no arguments
    # and return a `DependencyTicket`.
    ticket_getter_names = (
        "accuracy_ticket",
        "all_input_ports_ticket",
        "all_parameters_ticket",
        "all_sources_except_input_ports_ticket",
        "all_sources_ticket",
        "all_state_ticket",
        "configuration_ticket",
        "ke_ticket",
        "kinematics_ticket",
        "nothing_ticket",
        "pa_ticket",
        "pc_ticket",
        "pe_ticket",
        "pn_ticket",
        "pnc_ticket",
        "q_ticket",
        "time_ticket",
        "v_ticket",
        "xa_ticket",
        "xc_ticket",
        "xcdot_ticket",
        "xd_ticket",
        "z_ticket",
    )
    for getter_name in ticket_getter_names:
        func = getattr(LeafSystem, getter_name)
        # The getter itself is passed as the failure message.
        self.assertIsInstance(func(), DependencyTicket, func)
|
||
|
|
||
|
def test_leaf_system_per_item_tickets(self):
    # Declares one item of each kind on a bare LeafSystem, then checks that
    # each per-item ticket getter returns a `DependencyTicket` for index 0.
    dut = LeafSystem()
    dut.DeclareAbstractParameter(model_value=AbstractValue.Make(1))
    dut.DeclareAbstractState(model_value=AbstractValue.Make(1))
    dut.DeclareDiscreteState(1)
    dut.DeclareVectorInputPort("u0", BasicVector(1))
    u1_port = dut.DeclareVectorInputPort("u1", 2)
    self.assertEqual(u1_port.size(), 2)
    dut.DeclareNumericParameter(model_vector=BasicVector(1))

    getter_and_index = (
        (dut.abstract_parameter_ticket, AbstractParameterIndex(0)),
        (dut.abstract_state_ticket, AbstractStateIndex(0)),
        (dut.cache_entry_ticket, CacheIndex(0)),
        (dut.discrete_state_ticket, DiscreteStateIndex(0)),
        (dut.input_port_ticket, InputPortIndex(0)),
        (dut.numeric_parameter_ticket, NumericParameterIndex(0)),
    )
    for func, arg in getter_and_index:
        self.assertIsInstance(func(arg), DependencyTicket, func)
|
||
|
|
||
|
def test_cache_entry(self):
    """Checks the existence of CacheEntry-related bindings."""

    # Cover DeclareCacheEntry.
    owner = LeafSystem()
    prototype = AbstractValue.Make(SimpleNamespace())

    def calc_cache(context, abstract_value):
        # Marks the cached namespace so the test can see the calc ran.
        scratch = abstract_value.get_mutable_value()
        self.assertIsInstance(scratch, SimpleNamespace)
        scratch.updated = True

    producer = ValueProducer(allocate=prototype.Clone, calc=calc_cache)
    entry = owner.DeclareCacheEntry(
        description="scratch",
        value_producer=producer,
        prerequisites_of_calc={owner.nothing_ticket()})
    self.assertIsInstance(entry, CacheEntry)

    # Cover CacheEntry and get_cache_entry.
    self.assertIsInstance(entry.prerequisites(), set)
    entry_index = entry.cache_index()
    self.assertIsInstance(entry_index, CacheIndex)
    self.assertIsInstance(entry.ticket(), DependencyTicket)
    self.assertIs(owner.get_cache_entry(entry_index), entry)

    # Cover CacheEntryValue.
    # WARNING: This is not the suggested workflow for proper bindings. See
    # below for proper workflow using .Eval().
    context = owner.CreateDefaultContext()
    mutable_entry_value = entry.get_mutable_cache_entry_value(context)
    self.assertIsInstance(mutable_entry_value, CacheEntryValue)
    data = mutable_entry_value.GetMutableValueOrThrow()
    self.assertIsInstance(data, SimpleNamespace)
    # This has not yet been updated.
    self.assertFalse(hasattr(data, "updated"))
    # Const flavor access.
    const_entry_value = entry.get_cache_entry_value(context)
    self.assertIs(const_entry_value, mutable_entry_value)
    # Const flavor is out of date.
    with self.assertRaises(RuntimeError) as raised:
        const_entry_value.GetValueOrThrow()
    self.assertIn(
        "the current value is out of date", str(raised.exception))

    # Now properly update the cache entry.
    # Using .Eval() is the best workflow to follow.
    evaluated = entry.Eval(context)
    # Ensure we didn't clone.
    self.assertIs(data, evaluated)
    # Mutated!
    self.assertTrue(data.updated)
    # Check abstract access.
    self.assertIs(entry.EvalAbstract(context).get_value(), data)
    # Now check const aliasing.
    const_data = const_entry_value.GetValueOrThrow()
    self.assertIs(const_data, data)
|
||
|
|
||
|
def test_leaf_system_issue13792(self):
    """
    Ensures that users get a better error when forgetting to explicitly
    call the C++ superclass's __init__.
    """

    class Oops(LeafSystem):
        def __init__(self):
            # Deliberately omits LeafSystem.__init__(self).
            pass

    expected_message = "LeafSystem.*__init__"
    with self.assertRaisesRegex(TypeError, expected_message):
        Oops()
|
||
|
|
||
|
def test_all_leaf_system_overrides(self):
    # Exercises the LeafSystem virtual overrides and event-declaration
    # bindings in three passes: via `call_leaf_system_overrides`, via
    # explicit calls on the system, and via a `Simulator` run.
    test = self

    class TrivialSystem(LeafSystem):
        def __init__(self):
            LeafSystem.__init__(self)
            # One boolean per callback; each starts False and is set to
            # True by the corresponding handler below.
            for flag_name in (
                    "called_publish",
                    "called_continuous",
                    "called_discrete",
                    "called_initialize",
                    "called_per_step",
                    "called_periodic",
                    "called_initialize_publish",
                    "called_initialize_discrete",
                    "called_initialize_unrestricted",
                    "called_periodic_publish",
                    "called_periodic_discrete",
                    "called_periodic_unrestricted",
                    "called_per_step_publish",
                    "called_per_step_discrete",
                    "called_per_step_unrestricted",
                    "called_forced_publish",
                    "called_forced_discrete",
                    "called_forced_unrestricted",
                    "called_getwitness",
                    "called_witness",
                    "called_guard",
                    "called_reset",
                    "called_system_reset"):
                setattr(self, flag_name, False)

            # Ensure we have desired overloads.
            self.DeclarePeriodicPublishNoHandler(1.0)
            self.DeclarePeriodicPublishNoHandler(1.0, 0)
            self.DeclarePeriodicPublishNoHandler(
                period_sec=1.0, offset_sec=0)

            # Initialization events.
            self.DeclareInitializationPublishEvent(
                publish=self._on_initialize_publish)
            self.DeclareInitializationDiscreteUpdateEvent(
                update=self._on_initialize_discrete)
            self.DeclareInitializationUnrestrictedUpdateEvent(
                update=self._on_initialize_unrestricted)
            initialize_event = PublishEvent(
                trigger_type=TriggerType.kInitialization,
                callback=self._on_initialize)
            self.DeclareInitializationEvent(event=initialize_event)

            # Periodic events with dedicated handlers.
            self.DeclarePeriodicDiscreteUpdateNoHandler(
                period_sec=1.0, offset_sec=0.)
            self.DeclarePeriodicPublishEvent(
                period_sec=1.0,
                offset_sec=0,
                publish=self._on_periodic_publish)
            self.DeclarePeriodicDiscreteUpdateEvent(
                period_sec=1.0,
                offset_sec=0,
                update=self._on_periodic_discrete)
            self.DeclarePeriodicUnrestrictedUpdateEvent(
                period_sec=1.0,
                offset_sec=0,
                update=self._on_periodic_unrestricted)

            # Per-step events.
            self.DeclarePerStepPublishEvent(
                publish=self._on_per_step_publish)
            self.DeclarePerStepDiscreteUpdateEvent(
                update=self._on_per_step_discrete)
            self.DeclarePerStepUnrestrictedUpdateEvent(
                update=self._on_per_step_unrestricted)
            per_step_event = PublishEvent(
                trigger_type=TriggerType.kPerStep,
                callback=self._on_per_step)
            self.DeclarePerStepEvent(event=per_step_event)

            # Forced events.
            self.DeclareForcedPublishEvent(
                publish=self._on_forced_publish)
            self.DeclareForcedDiscreteUpdateEvent(
                update=self._on_forced_discrete)
            self.DeclareForcedUnrestrictedUpdateEvent(
                update=self._on_forced_unrestricted)

            # Generic periodic event carrying an explicit PublishEvent.
            periodic_event = PublishEvent(
                trigger_type=TriggerType.kPeriodic,
                callback=self._on_periodic)
            self.DeclarePeriodicEvent(
                period_sec=1.0,
                offset_sec=0.0,
                event=periodic_event)

            self.DeclareContinuousState(2)
            self.DeclareDiscreteState(1)
            # Ensure that we have inputs / outputs to call direct
            # feedthrough.
            self.DeclareInputPort(
                kUseDefaultName, PortDataType.kVectorValued, 1)
            self.DeclareVectorInputPort(
                name="test_input", model_vector=BasicVector(1),
                random_type=None)
            self.DeclareVectorOutputPort(
                "noop", BasicVector(1), noop,
                prerequisites_of_calc={self.nothing_ticket()})
            self.DeclareVectorOutputPort(
                "noop2", 1, noop,
                prerequisites_of_calc={self.nothing_ticket()})

            self.witness = self.MakeWitnessFunction(
                "witness", WitnessFunctionDirection.kCrossesZero,
                self._witness)
            # Test bindings for both callback function signatures.
            reset_event = UnrestrictedUpdateEvent(self._reset)
            self.reset_witness = self.MakeWitnessFunction(
                "reset", WitnessFunctionDirection.kCrossesZero,
                self._guard, reset_event)
            system_reset_event = UnrestrictedUpdateEvent(
                system_callback=self._system_reset)
            self.system_reset_witness = self.MakeWitnessFunction(
                "system reset", WitnessFunctionDirection.kCrossesZero,
                self._guard, system_reset_event)

        def DoPublish(self, context, events):
            # Call base method to ensure we do not get recursion.
            LeafSystem.DoPublish(self, context, events)
            # N.B. We do not test for a singular call to `DoPublish`
            # (checking `assertFalse(self.called_publish)` first) because
            # the initialization event declared above will call both its
            # callback and this event when invoked via
            # `Simulator::Initialize` from `call_leaf_system_overrides`,
            # even when we explicitly say not to publish at initialize.
            self.called_publish = True

        def DoCalcTimeDerivatives(self, context, derivatives):
            # Note: Don't call base method here; it would abort because
            # derivatives.size() != 0.
            test.assertEqual(derivatives.get_vector().size(), 2)
            self.called_continuous = True

        def DoCalcDiscreteVariableUpdates(
                self, context, events, discrete_state):
            # Call base method to ensure we do not get recursion.
            LeafSystem.DoCalcDiscreteVariableUpdates(
                self, context, events, discrete_state)
            self.called_discrete = True

        def DoGetWitnessFunctions(self, context):
            self.called_getwitness = True
            return [
                self.witness,
                self.reset_witness,
                self.system_reset_witness,
            ]

        # Handlers that receive (context, event). The initialize and
        # periodic ones also assert that they have not run before.

        def _on_initialize(self, context, event):
            test.assertIsInstance(context, Context)
            test.assertIsInstance(event, PublishEvent)
            test.assertFalse(self.called_initialize)
            self.called_initialize = True

        def _on_per_step(self, context, event):
            test.assertIsInstance(context, Context)
            test.assertIsInstance(event, PublishEvent)
            self.called_per_step = True

        def _on_periodic(self, context, event):
            test.assertIsInstance(context, Context)
            test.assertIsInstance(event, PublishEvent)
            test.assertFalse(self.called_periodic)
            self.called_periodic = True

        # Handlers bound through the dedicated Declare*Event overloads;
        # each returns EventStatus.Succeeded().

        def _on_initialize_publish(self, context):
            test.assertIsInstance(context, Context)
            test.assertFalse(self.called_initialize_publish)
            self.called_initialize_publish = True
            return EventStatus.Succeeded()

        def _on_initialize_discrete(self, context, discrete_state):
            test.assertIsInstance(context, Context)
            test.assertIsInstance(discrete_state, DiscreteValues)
            test.assertFalse(self.called_initialize_discrete)
            self.called_initialize_discrete = True
            return EventStatus.Succeeded()

        def _on_initialize_unrestricted(self, context, state):
            test.assertIsInstance(context, Context)
            test.assertIsInstance(state, State)
            test.assertFalse(self.called_initialize_unrestricted)
            self.called_initialize_unrestricted = True
            return EventStatus.Succeeded()

        def _on_periodic_publish(self, context):
            test.assertIsInstance(context, Context)
            test.assertFalse(self.called_periodic_publish)
            self.called_periodic_publish = True
            return EventStatus.Succeeded()

        def _on_periodic_discrete(self, context, discrete_state):
            test.assertIsInstance(context, Context)
            test.assertIsInstance(discrete_state, DiscreteValues)
            test.assertFalse(self.called_periodic_discrete)
            self.called_periodic_discrete = True
            return EventStatus.Succeeded()

        def _on_periodic_unrestricted(self, context, state):
            test.assertIsInstance(context, Context)
            test.assertIsInstance(state, State)
            test.assertFalse(self.called_periodic_unrestricted)
            self.called_periodic_unrestricted = True
            return EventStatus.Succeeded()

        def _on_per_step_publish(self, context):
            test.assertIsInstance(context, Context)
            self.called_per_step_publish = True
            return EventStatus.Succeeded()

        def _on_per_step_discrete(self, context, discrete_state):
            test.assertIsInstance(context, Context)
            test.assertIsInstance(discrete_state, DiscreteValues)
            self.called_per_step_discrete = True
            return EventStatus.Succeeded()

        def _on_per_step_unrestricted(self, context, state):
            test.assertIsInstance(context, Context)
            test.assertIsInstance(state, State)
            self.called_per_step_unrestricted = True
            return EventStatus.Succeeded()

        def _on_forced_publish(self, context):
            test.assertIsInstance(context, Context)
            test.assertFalse(self.called_forced_publish)
            self.called_forced_publish = True
            return EventStatus.Succeeded()

        def _on_forced_discrete(self, context, discrete_state):
            test.assertIsInstance(context, Context)
            test.assertIsInstance(discrete_state, DiscreteValues)
            test.assertFalse(self.called_forced_discrete)
            self.called_forced_discrete = True
            return EventStatus.Succeeded()

        def _on_forced_unrestricted(self, context, state):
            test.assertIsInstance(context, Context)
            test.assertIsInstance(state, State)
            test.assertFalse(self.called_forced_unrestricted)
            self.called_forced_unrestricted = True
            return EventStatus.Succeeded()

        # Witness function callbacks.

        def _witness(self, context):
            test.assertIsInstance(context, Context)
            self.called_witness = True
            return 1.0

        def _guard(self, context):
            test.assertIsInstance(context, Context)
            self.called_guard = True
            # Changes sign when the context time passes 0.5.
            return context.get_time() - 0.5

        def _reset(self, context, event, state):
            test.assertIsInstance(context, Context)
            test.assertIsInstance(event, UnrestrictedUpdateEvent)
            test.assertIsInstance(state, State)
            self.called_reset = True

        def _system_reset(self, system, context, event, state):
            test.assertIsInstance(system, System)
            test.assertIsInstance(context, Context)
            test.assertIsInstance(event, UnrestrictedUpdateEvent)
            test.assertIsInstance(state, State)
            self.called_system_reset = True

    # Pass 1: dispatch through call_leaf_system_overrides.
    system = TrivialSystem()
    self.assertFalse(system.called_publish)
    self.assertFalse(system.called_continuous)
    self.assertFalse(system.called_discrete)
    self.assertFalse(system.called_initialize)
    results = call_leaf_system_overrides(system)
    self.assertTrue(system.called_publish)
    self.assertFalse(results["has_direct_feedthrough"])
    self.assertTrue(system.called_continuous)
    self.assertTrue(system.called_discrete)
    self.assertTrue(system.called_initialize)
    self.assertEqual(results["discrete_next_t"], 1.0)

    self.assertFalse(system.HasAnyDirectFeedthrough())
    self.assertFalse(system.HasDirectFeedthrough(output_port=0))
    self.assertFalse(
        system.HasDirectFeedthrough(input_port=0, output_port=0))

    # Pass 2: test explicit calls.
    system = TrivialSystem()
    context = system.CreateDefaultContext()
    system.ForcedPublish(context=context)
    self.assertTrue(system.called_publish)
    self.assertTrue(system.called_forced_publish)

    context_update = context.Clone()
    system.CalcTimeDerivatives(
        context=context,
        derivatives=context_update.get_mutable_continuous_state())
    self.assertTrue(system.called_continuous)

    system.called_continuous = False
    residual = system.AllocateImplicitTimeDerivativesResidual()
    system.CalcImplicitTimeDerivativesResidual(
        context=context,
        proposed_derivatives=context_update.get_continuous_state(),
        residual=residual)
    np.testing.assert_allclose(residual, 0, 1e-14)
    self.assertTrue(system.called_continuous)
    # Overload without the `residual` argument: the return value is checked.
    returned_residual = system.CalcImplicitTimeDerivativesResidual(
        context=context,
        proposed_derivatives=context_update.get_continuous_state())
    np.testing.assert_allclose(returned_residual, 0, 1e-14)

    witnesses = system.GetWitnessFunctions(context)
    self.assertEqual(len(witnesses), 3)

    system.CalcForcedDiscreteVariableUpdate(
        context=context,
        discrete_state=context_update.get_mutable_discrete_state())
    self.assertTrue(system.called_discrete)
    self.assertTrue(system.called_forced_discrete)

    system.CalcForcedUnrestrictedUpdate(
        context=context,
        state=context_update.get_mutable_state())
    self.assertTrue(system.called_forced_unrestricted)

    # Pass 3: test per-step, periodic, and witness call backs
    system = TrivialSystem()
    simulator = Simulator(system)
    simulator.get_mutable_context().SetAccuracy(0.1)
    # Stepping to 0.99 so that we get exactly one periodic event.
    simulator.AdvanceTo(0.99)
    for flag_name in (
            "called_per_step",
            "called_periodic",
            "called_initialize_publish",
            "called_initialize_discrete",
            "called_initialize_unrestricted",
            "called_periodic_publish",
            "called_periodic_discrete",
            "called_periodic_unrestricted",
            "called_per_step_publish",
            "called_per_step_discrete",
            "called_per_step_unrestricted",
            "called_getwitness",
            "called_witness",
            "called_guard",
            "called_reset",
            "called_system_reset"):
        self.assertTrue(getattr(system, flag_name))
|
||
|
|
||
|
def test_event_handler_returns_none(self):
    """Checks that a Python event handler callback function is allowed to
    (implicitly) return None, instead of an EventStatus. Because of all the
    setup boilerplate, we only test one specific event type and assume that
    the other event types (which are implemented similarly) will likewise
    behave the same.
    """

    class PublishReturnsNoneSystem(LeafSystem):
        def __init__(self):
            LeafSystem.__init__(self)
            self.called_periodic_publish = False
            self.DeclarePeriodicPublishEvent(
                period_sec=1.0,
                offset_sec=0.0,
                publish=self._on_periodic_publish,
            )

        def _on_periodic_publish(self, context):
            # There is no `return` statement here; Python implicitly treats
            # this like a `return None`.
            self.called_periodic_publish = True

    dut = PublishReturnsNoneSystem()
    sim = Simulator(dut)
    sim.AdvanceTo(0.25)
    self.assertTrue(dut.called_periodic_publish)
|
||
|
|
||
|
def test_state_output_port_declarations(self):
    """Checks that DeclareStateOutputPort is bound."""
    dut = LeafSystem()

    # Continuous state with two variables.
    continuous_index = dut.DeclareContinuousState(2)
    continuous_port = dut.DeclareStateOutputPort(
        name="xc", state_index=continuous_index)
    self.assertEqual(continuous_port.size(), 2)

    # Discrete state with three variables.
    discrete_index = dut.DeclareDiscreteState(3)
    discrete_port = dut.DeclareStateOutputPort(
        name="xd", state_index=discrete_index)
    self.assertEqual(discrete_port.size(), 3)

    # Abstract state; only the port name is checked here.
    abstract_index = dut.DeclareAbstractState(AbstractValue.Make(1))
    abstract_port = dut.DeclareStateOutputPort(
        name="xa", state_index=abstract_index)
    self.assertEqual(abstract_port.get_name(), "xa")
|
||
|
|
||
|
def test_vector_system_overrides(self):
    # Runs `CustomVectorSystem` in continuous and discrete mode through
    # `call_vector_system_overrides` and checks the call order, the resulting
    # state and the output.
    #
    # The branch selections below use real conditional expressions rather
    # than the `cond and a or b` idiom, which silently picks `b` whenever
    # `a` happens to be falsy.
    dt = 0.5
    for is_discrete in [False, True]:
        system = CustomVectorSystem(is_discrete)
        self.assertEqual(
            system.GetSystemType(), f"{__name__}.CustomVectorSystem")
        context = system.CreateDefaultContext()

        u = np.array([1.])
        system.get_input_port(0).FixValue(context, u)

        # Dispatch virtual calls from C++.
        output = call_vector_system_overrides(
            system, context, is_discrete, dt)
        self.assertTrue(system.HasAnyDirectFeedthrough())

        # Check call order.
        update_type = "discrete" if is_discrete else "continuous"
        self.assertEqual(
            system.has_called,
            [update_type, "output"])

        # Check values.
        state = context.get_state()
        if is_discrete:
            state_container = state.get_discrete_state()
        else:
            state_container = state.get_continuous_state()
        x = state_container.get_vector().get_value()

        x0 = [0., 0.]
        # Scale applied to `u` in the expected state: 2 for the discrete
        # update, 1*dt for the continuous one.
        c = 2 if is_discrete else 1*dt
        x_expected = x0 + c*u
        self.assertTrue(np.allclose(x, x_expected))

        # Check output.
        y_expected = np.hstack([u, x])
        y = output.get_vector_data(0).get_value()
        self.assertTrue(np.allclose(y, y_expected))
|
||
|
|
||
|
def test_context_api(self):
    # Capture miscellaneous functions not yet tested.
    model_value = AbstractValue.Make("Hello")
    model_vector = BasicVector([1., 2.])

    class TrivialSystem(LeafSystem):
        def __init__(self):
            LeafSystem.__init__(self)
            self.DeclareContinuousState(1)
            self.DeclareDiscreteState(2)
            self.DeclareAbstractState(model_value=model_value.Clone())
            self.DeclareAbstractParameter(model_value=model_value.Clone())
            self.DeclareNumericParameter(model_vector=model_vector.Clone())

    system = TrivialSystem()
    context = system.CreateDefaultContext()

    # Const and mutable accessors must return the very same object.
    self.assertIs(context.get_state(), context.get_mutable_state())
    self.assertEqual(context.num_continuous_states(), 1)
    self.assertIs(
        context.get_continuous_state_vector(),
        context.get_mutable_continuous_state_vector())
    self.assertEqual(context.num_discrete_state_groups(), 1)
    self.assertIs(
        context.get_discrete_state_vector(),
        context.get_mutable_discrete_state_vector())
    self.assertIs(
        context.get_discrete_state(0),
        context.get_discrete_state_vector())
    self.assertIs(
        context.get_discrete_state(0),
        context.get_discrete_state().get_vector(0))
    self.assertIs(
        context.get_mutable_discrete_state(0),
        context.get_mutable_discrete_state_vector())
    self.assertIs(
        context.get_mutable_discrete_state(0),
        context.get_mutable_discrete_state().get_vector(0))
    self.assertEqual(context.num_abstract_states(), 1)
    self.assertIs(
        context.get_abstract_state(),
        context.get_mutable_abstract_state())
    self.assertIs(
        context.get_abstract_state(0),
        context.get_mutable_abstract_state(0))
    self.assertEqual(
        context.get_abstract_state(0).get_value(),
        model_value.get_value())

    # Check state API.
    state = context.get_mutable_state()
    self.assertIs(
        state.get_mutable_discrete_state(index=0),
        state.get_mutable_discrete_state().get_vector(index=0))
    self.assertIs(
        state.get_abstract_state(index=0),
        state.get_abstract_state().get_value(index=0))
    self.assertIs(
        state.get_mutable_abstract_state(index=0),
        state.get_mutable_abstract_state().get_value(index=0))

    # Check abstract state API (also test AbstractValues).
    values = context.get_abstract_state()
    self.assertEqual(values.size(), 1)
    self.assertEqual(
        values.get_value(0).get_value(), model_value.get_value())
    self.assertEqual(
        values.get_mutable_value(0).get_value(), model_value.get_value())
    values.SetFrom(values.Clone())

    # Check parameter accessors.
    self.assertEqual(system.num_abstract_parameters(), 1)
    self.assertEqual(
        context.get_abstract_parameter(index=0).get_value(),
        model_value.get_value())
    self.assertEqual(system.num_numeric_parameter_groups(), 1)
    np.testing.assert_equal(
        context.get_numeric_parameter(index=0).get_value(),
        model_vector.get_value())

    # Check diagram context accessors.
    builder = DiagramBuilder()
    builder.AddSystem(system)
    diagram = builder.Build()
    context = diagram.CreateDefaultContext()
    # Existence check.
    self.assertIsNotNone(
        diagram.GetMutableSubsystemState(system, context))
    subcontext = diagram.GetMutableSubsystemContext(
        subsystem=system, context=context)
    self.assertIsNotNone(subcontext)
    self.assertIs(
        diagram.GetSubsystemContext(subsystem=system, context=context),
        subcontext)
    subcontext2 = system.GetMyMutableContextFromRoot(root_context=context)
    self.assertIsNotNone(subcontext2)
    self.assertIs(subcontext2, subcontext)
    self.assertIs(
        system.GetMyContextFromRoot(root_context=context),
        subcontext2)
|
||
|
|
||
|
def test_continuous_state_api(self):
    # Declares continuous state through four different call forms of
    # `DeclareContinuousState` and checks that each one yields a context,
    # an allocated derivatives object, and evaluated derivatives of size 6.
    # N.B. Since this has trivial operations, we can test all scalar types.
    for T in [float, AutoDiffXd, Expression]:

        class TrivialSystem(LeafSystem_[T]):
            def __init__(self, index):
                LeafSystem_[T].__init__(self)
                # Partition of the state; the sizes sum to 6.
                partition = dict(num_q=2, num_v=1, num_z=3)
                total_size = sum(partition.values())
                # `index` selects which call form is exercised.
                if index == 0:
                    # Total size only.
                    self.DeclareContinuousState(
                        num_state_variables=total_size)
                elif index == 1:
                    # Explicit (q, v, z) partition.
                    self.DeclareContinuousState(**partition)
                elif index == 2:
                    # Model vector only.
                    self.DeclareContinuousState(
                        BasicVector_[T](total_size))
                elif index == 3:
                    # Model vector together with the partition.
                    self.DeclareContinuousState(
                        BasicVector_[T](total_size), **partition)

            def DoCalcTimeDerivatives(self, context, derivatives):
                # Trivial dynamics: all derivatives are zero.
                derivatives.get_mutable_vector().SetZero()

        for index in range(4):
            system = TrivialSystem(index)
            context = system.CreateDefaultContext()
            state_vector = context.get_continuous_state_vector()
            self.assertEqual(state_vector.size(), 6)
            allocated = system.AllocateTimeDerivatives()
            self.assertEqual(allocated.size(), 6)
            evaluated = system.EvalTimeDerivatives(context=context)
            self.assertEqual(evaluated.size(), 6)
|
||
|
|
||
|
def test_discrete_state_api(self):
    # Declares one discrete state group through three different call forms
    # of `DeclareDiscreteState` and checks that each one has size 3.
    # N.B. Since this has trivial operations, we can test all scalar types.
    for T in [float, AutoDiffXd, Expression]:

        class TrivialSystem(LeafSystem_[T]):
            def __init__(self, index):
                LeafSystem_[T].__init__(self)
                group_size = 3
                # `index` selects which call form is exercised.
                if index == 0:
                    # Size only.
                    self.DeclareDiscreteState(
                        num_state_variables=group_size)
                elif index == 1:
                    # Explicit list of three values.
                    self.DeclareDiscreteState([1, 2, 3])
                elif index == 2:
                    # Model vector.
                    self.DeclareDiscreteState(
                        BasicVector_[T](group_size))

        for index in range(3):
            system = TrivialSystem(index)
            context = system.CreateDefaultContext()
            discrete_group = context.get_discrete_state(0)
            self.assertEqual(discrete_group.size(), 3)
            allocated = system.AllocateDiscreteVariables()
            self.assertEqual(allocated.size(), 3)
|
||
|
|
||
|
def test_abstract_io_port(self):
    # Exercises an abstract-valued input port and output port declared on a
    # Python `LeafSystem_[T]`. Port values are `(name, scalar)` tuples: the
    # input is fixed to `expected_input_value`, and the output calc checks
    # it and writes `expected_output_value`.

    def assert_value_equal(a, b):
        # Compares two `(name, scalar)` tuples. The scalar goes through
        # `numpy_compare` rather than `==` so that every scalar type under
        # test is compared the same way.
        a_name, a_value = a
        b_name, b_value = b
        self.assertEqual(a_name, b_name)
        numpy_compare.assert_equal(a_value, b_value)

    # N.B. Since this has trivial operations, we can test all scalar types.
    for T in [float, AutoDiffXd, Expression]:
        default_value = ("default", T(0.))
        expected_input_value = ("input", T(np.pi))
        expected_output_value = ("output", 2*T(np.pi))

        class CustomAbstractSystem(LeafSystem_[T]):
            def __init__(self):
                LeafSystem_[T].__init__(self)
                self.input_port = self.DeclareAbstractInputPort(
                    "in", AbstractValue.Make(default_value))
                # The allocator builds the output from `default_value`;
                # the calc callback then overwrites it.
                self.output_port = self.DeclareAbstractOutputPort(
                    "out",
                    lambda: AbstractValue.Make(default_value),
                    self.DoCalcAbstractOutput,
                    prerequisites_of_calc={self.input_port.ticket()})

            def DoCalcAbstractOutput(self, context, y_data):
                input_value = self.EvalAbstractInput(
                    context=context, port_index=0).get_value()
                # The test fixes the input port to `expected_input_value`
                # before calling `CalcOutput`, so that is what must be
                # read back here.
                assert_value_equal(input_value, expected_input_value)
                y_data.set_value(expected_output_value)
                assert_value_equal(
                    y_data.get_value(), expected_output_value)

        system = CustomAbstractSystem()
        context = system.CreateDefaultContext()

        self.assertEqual(context.num_input_ports(), 1)
        system.get_input_port(0).FixValue(context, expected_input_value)
        output = system.AllocateOutput()
        self.assertEqual(output.num_ports(), 1)
        system.CalcOutput(context, output)
        value = output.get_data(0)
        # Use the same scalar-aware comparison as inside the calc callback
        # instead of relying on tuple `==`.
        assert_value_equal(value.get_value(), expected_output_value)
|
||
|
|
||
|
def assert_equal_but_not_aliased(self, a, b):
    """Asserts that `a` and `b` compare equal but are distinct objects."""
    # Equality is checked first, so a value mismatch is reported before an
    # aliasing problem.
    self.assertEqual(first=a, second=b)
    self.assertIsNot(a, b)
|
||
|
|
||
|
def test_context_and_value_object_set_from(self):
    """
    Shows how `Value[object]` behaves in a context, especially in
    connection to `Context.SetTimeStateAndParametersFrom()`.

    Helps to highlight failure mode illustrated in #18653.
    """
    arbitrary_object = {"key": "value"}

    def calc_cache_noop(context, abstract_value):
        # Leaves whatever the allocator produced untouched.
        pass

    class SystemWithCacheAndState(LeafSystem):
        def __init__(self):
            super().__init__()
            prototype = AbstractValue.Make(arbitrary_object)
            # The same prototype backs both the abstract state and the
            # cache entry's allocator.
            self.state_index = self.DeclareAbstractState(prototype)
            producer = ValueProducer(
                allocate=prototype.Clone,
                calc=calc_cache_noop,
            )
            self.cache_entry = self.DeclareCacheEntry(
                description="test",
                value_producer=producer,
            )

        def eval_state(self, context):
            abstract_state = context.get_abstract_state(self.state_index)
            return abstract_state.get_value()

    system = SystemWithCacheAndState()
    context = system.CreateDefaultContext()
    context_init = context.Clone()

    # Both the cache value and the state start out as copies of
    # `arbitrary_object`, not references to it.
    cache = system.cache_entry.Eval(context)
    self.assert_equal_but_not_aliased(cache, arbitrary_object)
    state = system.eval_state(context)
    self.assert_equal_but_not_aliased(state, arbitrary_object)

    def check_set_from():
        # Rebinds the enclosing `cache` / `state` so that a second call
        # compares against the objects seen by the first call.
        nonlocal cache, state
        context.SetTimeStateAndParametersFrom(context_init)
        # Ensure that we have cloned the object.
        previous_state, state = state, system.eval_state(context)
        self.assert_equal_but_not_aliased(state, previous_state)
        # Warning: Cache objects are not cloned!
        previous_cache, cache = cache, system.cache_entry.Eval(context)
        self.assertIs(cache, previous_cache)

    # Check twice. Per #18653, if we did not implement
    # Value[object].SetFrom() correctly, this would fail the second time.
    for _ in range(2):
        check_set_from()
|