diff --git a/default_plugins.xml b/default_plugins.xml index bfd5022..620815d 100644 --- a/default_plugins.xml +++ b/default_plugins.xml @@ -58,6 +58,12 @@ This is a increment filter which works on a stream of ints. + + + This is a increment filter which updates a stream of ints in place. + + diff --git a/include/filters/filter_base.hpp b/include/filters/filter_base.hpp index dfbb338..47dca23 100644 --- a/include/filters/filter_base.hpp +++ b/include/filters/filter_base.hpp @@ -312,6 +312,29 @@ class FilterBase rclcpp::node_interfaces::NodeLoggingInterface::SharedPtr logging_interface_; }; +/** + * \brief Optional base class for filters that can update data in place. + */ +template +class InPlaceFilter : public FilterBase +{ +public: + /** + * \brief Update data in place. + * \param data A reference to the data to be filtered. + */ + virtual bool update(T & data) = 0; + + /** + * \brief Compatibility implementation for the standard FilterBase API. + */ + bool update(const T & data_in, T & data_out) override + { + data_out = data_in; + return update(data_out); + } +}; + template class MultiChannelFilterBase : public FilterBase diff --git a/include/filters/filter_chain.hpp b/include/filters/filter_chain.hpp index 4269b46..efe2eaf 100644 --- a/include/filters/filter_chain.hpp +++ b/include/filters/filter_chain.hpp @@ -194,34 +194,67 @@ class FilterChain if (!configured_) { throw std::runtime_error("The update cannot be called without configuring the filter chain!"); } - bool result; - size_t list_size = reference_pointers_.size(); - if (list_size == 0) { + + if (reference_pointers_.empty()) { data_out = data_in; - result = true; - } else if (list_size == 1) { - result = reference_pointers_[0]->update(data_in, data_out); - } else if (list_size == 2) { - result = reference_pointers_[0]->update(data_in, buffer0_); - if (result == false) {return false;} // don't keep processing on failure - result = result && reference_pointers_[1]->update(buffer0_, data_out); - } else { - result = reference_pointers_[0]->update(data_in, buffer0_); // first copy in - for (size_t i = 1; i < reference_pointers_.size() - 1 && result; ++i) { - // all but first and last (never called if size=2) - if (i % 2 == 1) { - result = result && reference_pointers_[i]->update(buffer0_, buffer1_); - } else { - result = result && reference_pointers_[i]->update(buffer1_, buffer0_); + return true; + } + + // The first filter cannot update in place because the chain input is const. + bool result = reference_pointers_[0]->update(data_in, data_out); + if (!result) { + return false; + } + + T * current = &data_out; + T * next = &buffer0_; + for (size_t i = 1; i < reference_pointers_.size(); ++i) { + auto in_place_filter = + dynamic_cast *>(reference_pointers_[i].get()); + if (in_place_filter) { + result = in_place_filter->update(*current); + } else { + // Fall back to the existing two-buffer path for filters without in-place support. + result = reference_pointers_[i]->update(*current, *next); + if (result) { + std::swap(current, next); } } - if (list_size % 2 == 1) { // odd number last deposit was in buffer1 - result = result && reference_pointers_.back()->update(buffer1_, data_out); + if (!result) { + return false; + } + } + + if (current != &data_out) { + data_out = std::move(*current); + } + return true; + } + + /** + * \brief process data in place through each of the filters added sequentially + */ + bool update(T & data) + { + if (!configured_) { + throw std::runtime_error("The update cannot be called without configuring the filter chain!"); + } + + for (auto & filter : reference_pointers_) { + auto in_place_filter = dynamic_cast *>(filter.get()); + if (in_place_filter) { + if (!in_place_filter->update(data)) { + return false; + } } else { - result = result && reference_pointers_.back()->update(buffer0_, data_out); + // Filters without in-place support still work by using an internal buffer. + if (!filter->update(data, buffer0_)) { + return false; + } + data = std::move(buffer0_); } } - return result; + return true; } /** @@ -306,6 +339,18 @@ class FilterChain return reference_pointers_.size(); } + /** + * \brief Check whether every configured filter supports in-place updates. + */ + bool can_update_fully_in_place() const + { + return std::all_of( + reference_pointers_.begin(), reference_pointers_.end(), + [](const auto & filter) { + return dynamic_cast *>(filter.get()) != nullptr; + }); + } + rcl_interfaces::msg::SetParametersResult reconfigureCB(std::vector parameters) { auto result = rcl_interfaces::msg::SetParametersResult(); diff --git a/include/filters/increment.hpp b/include/filters/increment.hpp index 5e9da27..92363dd 100644 --- a/include/filters/increment.hpp +++ b/include/filters/increment.hpp @@ -88,6 +88,54 @@ bool IncrementFilter::update(const T & data_in, T & data_out) return true; } +/** + * \brief A increment filter which updates data in place. + */ +template +class InPlaceIncrementFilter : public InPlaceFilter +{ +public: + /** + * \brief Construct the filter with the expected width and height + */ + InPlaceIncrementFilter(); + + /** + * \brief Destructor to clean up + */ + ~InPlaceIncrementFilter() override; + + bool configure() override; + + /** + * \brief Update the filter in place + */ + bool update(T & data) override; +}; + +template +InPlaceIncrementFilter::InPlaceIncrementFilter() +{ +} + +template +InPlaceIncrementFilter::~InPlaceIncrementFilter() +{ +} + +template +bool InPlaceIncrementFilter::configure() +{ + return true; +} + +template +bool InPlaceIncrementFilter::update(T & data) +{ + ++data; + return true; +} + /** * \brief A increment filter which works on arrays. */ diff --git a/src/increment.cpp b/src/increment.cpp index bc311c3..98ed361 100644 --- a/src/increment.cpp +++ b/src/increment.cpp @@ -32,6 +32,7 @@ PLUGINLIB_EXPORT_CLASS(filters::IncrementFilter, filters::FilterBase) +PLUGINLIB_EXPORT_CLASS(filters::InPlaceIncrementFilter, filters::FilterBase) PLUGINLIB_EXPORT_CLASS( filters::MultiChannelIncrementFilter, filters::MultiChannelFilterBase) diff --git a/test/test_chain.cpp b/test/test_chain.cpp index 409b338..6dfd41a 100644 --- a/test/test_chain.cpp +++ b/test/test_chain.cpp @@ -378,6 +378,76 @@ TEST_F(ChainTest, TenIncrementChains) { EXPECT_EQ(11, v1a); } +TEST_F(ChainTest, InPlaceFallbackUpdate) { + filters::FilterChain chain("int"); + + std::vector overrides; + overrides.emplace_back("InPlaceFallback.filter1.name", std::string("increment1")); + overrides.emplace_back( + "InPlaceFallback.filter1.type", std::string("filters/InPlaceIncrementFilterInt")); + auto node = make_node_with_params(overrides); + + ASSERT_TRUE( + chain.configure( + "InPlaceFallback", node->get_node_logging_interface(), + node->get_node_parameters_interface())); + EXPECT_TRUE(chain.can_update_fully_in_place()); + + int v1 = 1; + int v1a = 9; + EXPECT_TRUE(chain.update(v1, v1a)); + EXPECT_EQ(1, v1); + EXPECT_EQ(2, v1a); +} + +TEST_F(ChainTest, InPlaceUpdate) { + filters::FilterChain chain("int"); + + std::vector overrides; + overrides.emplace_back("InPlaceIncrements.filter1.name", std::string("increment1")); + overrides.emplace_back( + "InPlaceIncrements.filter1.type", std::string("filters/InPlaceIncrementFilterInt")); + overrides.emplace_back("InPlaceIncrements.filter2.name", std::string("increment2")); + overrides.emplace_back( + "InPlaceIncrements.filter2.type", std::string("filters/InPlaceIncrementFilterInt")); + auto node = make_node_with_params(overrides); + + ASSERT_TRUE( + chain.configure( + "InPlaceIncrements", node->get_node_logging_interface(), + node->get_node_parameters_interface())); + EXPECT_TRUE(chain.can_update_fully_in_place()); + + int v1 = 1; + EXPECT_TRUE(chain.update(v1)); + EXPECT_EQ(3, v1); +} + +TEST_F(ChainTest, MixedInPlaceUpdate) { + filters::FilterChain chain("int"); + + std::vector overrides; + overrides.emplace_back("MixedIncrements.filter1.name", std::string("increment1")); + overrides.emplace_back( + "MixedIncrements.filter1.type", std::string("filters/InPlaceIncrementFilterInt")); + overrides.emplace_back("MixedIncrements.filter2.name", std::string("increment2")); + overrides.emplace_back("MixedIncrements.filter2.type", std::string("filters/IncrementFilterInt")); + overrides.emplace_back("MixedIncrements.filter3.name", std::string("increment3")); + overrides.emplace_back( + "MixedIncrements.filter3.type", std::string("filters/InPlaceIncrementFilterInt")); + auto node = make_node_with_params(overrides); + + ASSERT_TRUE( + chain.configure( + "MixedIncrements", node->get_node_logging_interface(), + node->get_node_parameters_interface())); + EXPECT_FALSE(chain.can_update_fully_in_place()); + + int v1 = 1; + EXPECT_TRUE(chain.update(v1)); + EXPECT_EQ(4, v1); +} + TEST_F(ChainTest, TenMultiChannelIncrementChains) { filters::MultiChannelFilterChain chain("int"); std::vector v1; @@ -467,6 +537,7 @@ TEST_F(ChainTest, TestChainLength) { chain.configure( "ZeroFilters", node->get_node_logging_interface(), node->get_node_parameters_interface())); EXPECT_EQ(chain.get_length(), 0); + EXPECT_TRUE(chain.can_update_fully_in_place()); chain.clear(); ASSERT_TRUE(