Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions rocketpy/control/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(
self,
interactive_objects,
controller_function,
sampling_rate,
sampling_rate=None,
initial_observed_variables=None,
name="Controller",
):
Expand Down Expand Up @@ -71,12 +71,14 @@ def __init__(
objects as needed. The function return statement can be used to save
relevant information in the `observed_variables` list.

.. note:: The function will be called according to the sampling rate
specified.
sampling_rate : float
.. note:: The function will be called according to the sampling
rate specified. If `sampling_rate` is None, the controller
function is called at every solver step of the simulation.
sampling_rate : float, optional
The sampling rate of the controller function in Hertz (Hz). This
means that the controller function will be called every
`1/sampling_rate` seconds.
`1/sampling_rate` seconds. If None, it is treated as a
continuous controller and called at every solver step.
initial_observed_variables : list, optional
A list of the initial values of the variables that the controller
function returns. This list is used to initialize the
Expand Down
18 changes: 17 additions & 1 deletion rocketpy/simulation/flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,16 @@ def __simulate(self, verbose):
self.y_sol = phase.solver.y
if verbose:
print(f"Current Simulation Time: {self.t:3.4f} s", end="\r")

if self._continuous_controllers:
for controller in self._continuous_controllers:
controller(
self.t,
self.y_sol,
self._controller_state_history,
self.sensors,
self.env,
)
self._controller_state_history.append(list(self.y_sol))
Comment on lines +780 to +789
if self.__check_simulation_events(phase, phase_index, node_index):
break # Stop if simulation termination event occurred

Expand Down Expand Up @@ -1537,6 +1546,7 @@ def __init_solver_monitors(self):

self.t_initial = self.initial_solution[0]
self.solution.append(self.initial_solution)
self._controller_state_history = [self.initial_solution[1:]]
self.t = self.solution[-1][0]
self.y_sol = self.solution[-1][1:]

Expand Down Expand Up @@ -1576,6 +1586,9 @@ def __init_equations_of_motion(self):
def __init_controllers(self):
"""Initialize controllers and sensors"""
self._controllers = self.rocket._controllers[:]
self._continuous_controllers = [
c for c in self._controllers if c.sampling_rate is None
]
self.sensors = self.rocket.sensors.get_components()

# reset controllable object to initial state (only airbrakes for now)
Expand Down Expand Up @@ -4488,6 +4501,9 @@ def add_parachutes(self, parachutes, t_init, t_end):

def add_controllers(self, controllers, t_init, t_end):
for controller in controllers:
# Skip node creation for continuous controllers
if controller.sampling_rate is None:
continue
# Calculate start of sampling time nodes
controller_time_step = 1 / controller.sampling_rate
controller_node_list = [
Expand Down
28 changes: 27 additions & 1 deletion tests/unit/simulation/test_flight_time_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
TimeNode.
"""

# from rocketpy.rocket import Parachute, _Controller
from rocketpy.control.controller import _Controller


def test_time_nodes_init(flight_calisto):
Expand Down Expand Up @@ -49,6 +49,32 @@ def test_time_nodes_add_node(flight_calisto):
# TODO: implement this test


def test_time_nodes_add_controllers_skips_continuous_controllers(flight_calisto):
"""Ensure only discrete controllers create time nodes."""
# Arrange
discrete_controller = _Controller(
interactive_objects=[],
controller_function=lambda t, sr, sv, sh, ov, io: None,
sampling_rate=10,
name="Discrete",
)
continuous_controller = _Controller(
interactive_objects=[],
controller_function=lambda t, sr, sv, sh, ov, io: None,
sampling_rate=None,
name="Continuous",
)
time_nodes = flight_calisto.TimeNodes()

# Act
time_nodes.add_controllers([discrete_controller, continuous_controller], 0, 1)

# Assert
assert len(time_nodes) == 11
assert all(node._controllers == [discrete_controller] for node in time_nodes)
assert all(continuous_controller not in node._controllers for node in time_nodes)


def test_time_nodes_sort(flight_calisto):
time_nodes = flight_calisto.TimeNodes()
time_nodes.add_node(3.0, [], [], [])
Expand Down
Loading