diff --git a/README.md b/README.md index af0b575b7..62a038963 100644 --- a/README.md +++ b/README.md @@ -93,32 +93,39 @@ The coroutine below shows how to use it ```cpp -auto -receiver(std::shared_ptr conn) -> net::awaitable +auto receiver(std::shared_ptr conn) -> asio::awaitable { - request req; - req.push("SUBSCRIBE", "channel"); - - flat_tree resp; + generic_flat_response resp; conn->set_receive_response(resp); - // Loop while reconnection is enabled - while (conn->will_reconnect()) { - - // Reconnect to channels. - co_await conn->async_exec(req); + // Subscribe to the channel 'mychannel'. You can add any number of channels here. + request req; + req.subscribe({"mychannel"}); + co_await conn->async_exec(req); + + // You're now subscribed to 'mychannel'. Pushes sent over this channel will be stored + // in resp. If the connection encounters a network error and reconnects to the server, + // it will automatically subscribe to 'mychannel' again. This is transparent to the user. + + // Loop to read Redis push messages. + for (error_code ec;;) { + // Wait for pushes + co_await conn->async_receive2(asio::redirect_error(ec)); + + // Check for errors and cancellations + if (ec && (ec != asio::experimental::error::channel_cancelled || !conn->will_reconnect())) { + std::cerr << "Error during receive2: " << ec << std::endl; + break; + } - // Loop reading Redis pushes. - for (error_code ec;;) { - co_await conn->async_receive2(resp, redirect_error(ec)); - if (ec) - break; // Connection lost, break so we can reconnect to channels. + // The response must be consumed without suspending the + // coroutine i.e. without the use of async operations. + for (auto const& elem : resp.value().get_view()) + std::cout << elem.value << "\n"; - // Use the response resp in some way and then clear it. - ... + std::cout << std::endl; - resp.clear(); - } + resp.value().clear(); } } ``` diff --git a/doc/modules/ROOT/pages/index.adoc b/doc/modules/ROOT/pages/index.adoc index 36b83d493..b7194d1b9 100644 --- a/doc/modules/ROOT/pages/index.adoc +++ b/doc/modules/ROOT/pages/index.adoc @@ -104,32 +104,39 @@ The coroutine below shows how to use it [source,cpp] ---- -auto -receiver(std::shared_ptr conn) -> net::awaitable +auto receiver(std::shared_ptr conn) -> asio::awaitable { - request req; - req.push("SUBSCRIBE", "channel"); - - flat_tree resp; + generic_flat_response resp; conn->set_receive_response(resp); - // Loop while reconnection is enabled - while (conn->will_reconnect()) { - - // Reconnect to channels. - co_await conn->async_exec(req); + // Subscribe to the channel 'mychannel'. You can add any number of channels here. + request req; + req.subscribe({"mychannel"}); + co_await conn->async_exec(req); + + // You're now subscribed to 'mychannel'. Pushes sent over this channel will be stored + // in resp. If the connection encounters a network error and reconnects to the server, + // it will automatically subscribe to 'mychannel' again. This is transparent to the user. + + // Loop to read Redis push messages. + for (error_code ec;;) { + // Wait for pushes + co_await conn->async_receive2(asio::redirect_error(ec)); + + // Check for errors and cancellations + if (ec && (ec != asio::experimental::error::channel_cancelled || !conn->will_reconnect())) { + std::cerr << "Error during receive2: " << ec << std::endl; + break; + } - // Loop reading Redis pushes. - for (error_code ec;;) { - co_await conn->async_receive2(resp, redirect_error(ec)); - if (ec) - break; // Connection lost, break so we can reconnect to channels. + // The response must be consumed without suspending the + // coroutine i.e. without the use of async operations. + for (auto const& elem : resp.value().get_view()) + std::cout << elem.value << "\n"; - // Use the response here and then clear it. - ... + std::cout << std::endl; - resp.clear(); - } + resp.value().clear(); } } ---- diff --git a/doc/modules/ROOT/pages/requests_responses.adoc b/doc/modules/ROOT/pages/requests_responses.adoc index 7f40fed11..69728b63c 100644 --- a/doc/modules/ROOT/pages/requests_responses.adoc +++ b/doc/modules/ROOT/pages/requests_responses.adoc @@ -184,7 +184,7 @@ must **NOT** be included in the response tuple. For example, the following reque ---- request req; req.push("PING"); -req.push("SUBSCRIBE", "channel"); +req.subscribe({"channel"}); req.push("QUIT"); ---- diff --git a/example/cpp20_chat_room.cpp b/example/cpp20_chat_room.cpp index 562225f56..4dd8180d7 100644 --- a/example/cpp20_chat_room.cpp +++ b/example/cpp20_chat_room.cpp @@ -14,6 +14,7 @@ #include #include +#include #include #include @@ -40,31 +41,44 @@ using namespace std::chrono_literals; // Chat over Redis pubsub. To test, run this program from multiple // terminals and type messages to stdin. +namespace { + +auto rethrow_on_error = [](std::exception_ptr exc) { + if (exc) + std::rethrow_exception(exc); +}; + auto receiver(std::shared_ptr conn) -> awaitable { - request req; - req.push("SUBSCRIBE", "channel"); - + // Set the receive response, so pushes are stored in resp generic_flat_response resp; conn->set_receive_response(resp); - while (conn->will_reconnect()) { - // Subscribe to channels. - co_await conn->async_exec(req); + // Subscribe to the channel 'channel'. Using request::subscribe() + // (instead of request::push()) makes the connection re-subscribe + // to 'channel' whenever it re-connects to the server. + request req; + req.subscribe({"channel"}); + co_await conn->async_exec(req); - // Loop reading Redis push messages. - for (error_code ec;;) { - co_await conn->async_receive2(redirect_error(ec)); - if (ec) - break; // Connection lost, break so we can reconnect to channels. + for (error_code ec;;) { + // Wait for pushes + co_await conn->async_receive2(asio::redirect_error(ec)); - for (auto const& elem: resp.value().get_view()) - std::cout << elem.value << "\n"; + // Check for errors and cancellations + if (ec && (ec != asio::experimental::error::channel_cancelled || !conn->will_reconnect())) { + std::cerr << "Error during receive2: " << ec << std::endl; + break; + } - std::cout << std::endl; + // The response must be consumed without suspending the + // coroutine i.e. without the use of async operations. + for (auto const& elem : resp.value().get_view()) + std::cout << elem.value << "\n"; - resp.value().clear(); - } + std::cout << std::endl; + + resp.value().clear(); } } @@ -81,6 +95,8 @@ auto publisher(std::shared_ptr in, std::shared_ptr awaitable { @@ -88,8 +104,8 @@ auto co_main(config cfg) -> awaitable auto conn = std::make_shared(ex); auto stream = std::make_shared(ex, ::dup(STDIN_FILENO)); - co_spawn(ex, receiver(conn), detached); - co_spawn(ex, publisher(stream, conn), detached); + co_spawn(ex, receiver(conn), rethrow_on_error); + co_spawn(ex, publisher(stream, conn), rethrow_on_error); conn->async_run(cfg, consign(detached, conn)); signal_set sig_set{ex, SIGINT, SIGTERM}; diff --git a/example/cpp20_subscriber.cpp b/example/cpp20_subscriber.cpp index 1f0dd86fe..b3eb9cdd4 100644 --- a/example/cpp20_subscriber.cpp +++ b/example/cpp20_subscriber.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -32,7 +33,7 @@ using asio::signal_set; * To test send messages with redis-cli * * $ redis-cli -3 - * 127.0.0.1:6379> PUBLISH channel some-message + * 127.0.0.1:6379> PUBLISH mychannel some-message * (integer) 3 * 127.0.0.1:6379> * @@ -46,33 +47,39 @@ using asio::signal_set; // Receives server pushes. auto receiver(std::shared_ptr conn) -> asio::awaitable { - request req; - req.push("SUBSCRIBE", "channel"); - generic_flat_response resp; conn->set_receive_response(resp); - // Loop while reconnection is enabled - while (conn->will_reconnect()) { - // Reconnect to the channels. - co_await conn->async_exec(req); - - // Loop to read Redis push messages. - for (error_code ec;;) { - // Wait for pushes - co_await conn->async_receive2(asio::redirect_error(ec)); - if (ec) - break; // Connection lost, break so we can reconnect to channels. + // Subscribe to the channel 'mychannel'. You can add any number of channels here. + request req; + req.subscribe({"mychannel"}); + co_await conn->async_exec(req); + + // You're now subscribed to 'mychannel'. Pushes sent over this channel will be stored + // in resp. If the connection encounters a network error and reconnects to the server, + // it will automatically subscribe to 'mychannel' again. This is transparent to the user. + // You need to use specialized request::subscribe() function (instead of request::push) + // to enable this behavior. + + // Loop to read Redis push messages. + for (error_code ec;;) { + // Wait for pushes + co_await conn->async_receive2(asio::redirect_error(ec)); + + // Check for errors and cancellations + if (ec && (ec != asio::experimental::error::channel_cancelled || !conn->will_reconnect())) { + std::cerr << "Error during receive2: " << ec << std::endl; + break; + } - // The response must be consumed without suspending the - // coroutine i.e. without the use of async operations. - for (auto const& elem: resp.value().get_view()) - std::cout << elem.value << "\n"; + // The response must be consumed without suspending the + // coroutine i.e. without the use of async operations. + for (auto const& elem : resp.value().get_view()) + std::cout << elem.value << "\n"; - std::cout << std::endl; + std::cout << std::endl; - resp.value().clear(); - } + resp.value().clear(); } } diff --git a/include/boost/redis/connection.hpp b/include/boost/redis/connection.hpp index 3e8298f39..1fcf2d6b1 100644 --- a/include/boost/redis/connection.hpp +++ b/include/boost/redis/connection.hpp @@ -108,7 +108,10 @@ struct connection_impl { { while (true) { // Invoke the state machine - auto act = fsm_.resume(obj_->is_open(), self.get_cancellation_state().cancelled()); + auto act = fsm_.resume( + obj_->is_open(), + obj_->st_, + self.get_cancellation_state().cancelled()); // Do what the FSM said switch (act.type()) { @@ -203,7 +206,7 @@ struct connection_impl { }); return asio::async_compose( - exec_op{this, notifier, exec_fsm(st_.mpx, std::move(info))}, + exec_op{this, notifier, exec_fsm(std::move(info))}, token, writer_cv_); } diff --git a/include/boost/redis/detail/connection_state.hpp b/include/boost/redis/detail/connection_state.hpp index b4d2e1a01..75ddf1772 100644 --- a/include/boost/redis/detail/connection_state.hpp +++ b/include/boost/redis/detail/connection_state.hpp @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -47,7 +48,9 @@ struct connection_state { config cfg{}; multiplexer mpx{}; std::string diagnostic{}; // Used by the setup request and Sentinel + request setup_req{}; request ping_req{}; + subscription_tracker tracker{}; // Sentinel stuff lazy_random_engine eng{}; diff --git a/include/boost/redis/detail/exec_fsm.hpp b/include/boost/redis/detail/exec_fsm.hpp index 3727d07c3..7dc2f8c12 100644 --- a/include/boost/redis/detail/exec_fsm.hpp +++ b/include/boost/redis/detail/exec_fsm.hpp @@ -21,6 +21,8 @@ namespace boost::redis::detail { +struct connection_state; + // What should we do next? enum class exec_action_type { @@ -54,16 +56,17 @@ class exec_action { class exec_fsm { int resume_point_{0}; - multiplexer* mpx_{nullptr}; std::shared_ptr elem_; public: - exec_fsm(multiplexer& mpx, std::shared_ptr elem) noexcept - : mpx_(&mpx) - , elem_(std::move(elem)) + exec_fsm(std::shared_ptr elem) noexcept + : elem_(std::move(elem)) { } - exec_action resume(bool connection_is_open, asio::cancellation_type_t cancel_state); + exec_action resume( + bool connection_is_open, + connection_state& st, + asio::cancellation_type_t cancel_state); }; } // namespace boost::redis::detail diff --git a/include/boost/redis/detail/subscription_tracker.hpp b/include/boost/redis/detail/subscription_tracker.hpp new file mode 100644 index 000000000..865d4113e --- /dev/null +++ b/include/boost/redis/detail/subscription_tracker.hpp @@ -0,0 +1,35 @@ +// +// Copyright (c) 2025 Marcelo Zimbres Silva (mzimbres@gmail.com), +// Ruben Perez Hidalgo (rubenperez038 at gmail dot com) +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// + +#ifndef BOOST_REDIS_SUBSCRIPTION_TRACKER_HPP +#define BOOST_REDIS_SUBSCRIPTION_TRACKER_HPP + +#include +#include + +namespace boost::redis { + +class request; + +namespace detail { + +class subscription_tracker { + std::set channels_; + std::set pchannels_; + +public: + subscription_tracker() = default; + void clear(); + void commit_changes(const request& req); + void compose_subscribe_request(request& to) const; +}; + +} // namespace detail +} // namespace boost::redis + +#endif diff --git a/include/boost/redis/impl/exec_fsm.ipp b/include/boost/redis/impl/exec_fsm.ipp index b7be8c608..3898d1e18 100644 --- a/include/boost/redis/impl/exec_fsm.ipp +++ b/include/boost/redis/impl/exec_fsm.ipp @@ -9,6 +9,7 @@ #ifndef BOOST_REDIS_EXEC_FSM_IPP #define BOOST_REDIS_EXEC_FSM_IPP +#include #include #include #include @@ -28,7 +29,10 @@ inline bool is_total_cancel(asio::cancellation_type_t type) return !!(type & asio::cancellation_type_t::total); } -exec_action exec_fsm::resume(bool connection_is_open, asio::cancellation_type_t cancel_state) +exec_action exec_fsm::resume( + bool connection_is_open, + connection_state& st, + asio::cancellation_type_t cancel_state) { switch (resume_point_) { BOOST_REDIS_CORO_INITIAL @@ -47,7 +51,7 @@ exec_action exec_fsm::resume(bool connection_is_open, asio::cancellation_type_t BOOST_REDIS_YIELD(resume_point_, 2, exec_action_type::setup_cancellation) // Add the request to the multiplexer - mpx_->add(elem_); + st.mpx.add(elem_); // Notify the writer task that there is work to do. If the task is not // listening (e.g. it's already writing or the connection is not healthy), @@ -61,8 +65,14 @@ exec_action exec_fsm::resume(bool connection_is_open, asio::cancellation_type_t // If the request has completed (with error or not), we're done if (elem_->is_done()) { + // If the request completed successfully and we were configured to do so, + // record the changes applied to the pubsub state + if (!elem_->get_error()) + st.tracker.commit_changes(elem_->get_request()); + + // Deallocate memory before finalizing exec_action act{elem_->get_error(), elem_->get_read_size()}; - elem_.reset(); // Deallocate memory before finalizing + elem_.reset(); return act; } @@ -71,7 +81,7 @@ exec_action exec_fsm::resume(bool connection_is_open, asio::cancellation_type_t if ( (is_total_cancel(cancel_state) && elem_->is_waiting()) || is_partial_or_terminal_cancel(cancel_state)) { - mpx_->cancel(elem_); + st.mpx.cancel(elem_); elem_.reset(); // Deallocate memory before finalizing return exec_action{asio::error::operation_aborted}; } diff --git a/include/boost/redis/impl/request.ipp b/include/boost/redis/impl/request.ipp index 741c3b5d3..72653897b 100644 --- a/include/boost/redis/impl/request.ipp +++ b/include/boost/redis/impl/request.ipp @@ -5,7 +5,9 @@ */ #include +#include +#include #include namespace boost::redis::detail { @@ -34,7 +36,32 @@ request make_hello_request() void boost::redis::request::append(const request& other) { + // Remember the old payload size, to update offsets + std::size_t old_offset = payload_.size(); + + // Add the payload payload_ += other.payload_; commands_ += other.commands_; expected_responses_ += other.expected_responses_; + + // Add the pubsub changes. Offsets need to be updated + pubsub_changes_.reserve(pubsub_changes_.size() + other.pubsub_changes_.size()); + for (const auto& change : other.pubsub_changes_) { + pubsub_changes_.push_back({ + change.type, + change.channel_offset + old_offset, + change.channel_size, + }); + } +} + +void boost::redis::request::add_pubsub_arg(detail::pubsub_change_type type, std::string_view value) +{ + // Add the argument + resp3::add_bulk(payload_, value); + + // Track the change. + // The final \r\n adds 2 bytes + std::size_t offset = payload_.size() - value.size() - 2u; + pubsub_changes_.push_back({type, offset, value.size()}); } diff --git a/include/boost/redis/impl/run_fsm.ipp b/include/boost/redis/impl/run_fsm.ipp index 845f9d202..d477d549d 100644 --- a/include/boost/redis/impl/run_fsm.ipp +++ b/include/boost/redis/impl/run_fsm.ipp @@ -101,10 +101,10 @@ run_action run_fsm::resume( return stored_ec_; } - // Compose the setup request. This only depends on the config, so it can be done just once - compose_setup_request(st.cfg); + // Clear any remainder from previous runs + st.tracker.clear(); - // Compose the PING request. Same as above + // Compose the PING request. This only depends on the config, so it can be done just once compose_ping_request(st.cfg, st.ping_req); if (use_sentinel(st.cfg)) { @@ -159,10 +159,11 @@ run_action run_fsm::resume( // Initialization st.mpx.reset(); st.diagnostic.clear(); + compose_setup_request(st.cfg, st.tracker, st.setup_req); // Add the setup request to the multiplexer - if (st.cfg.setup.get_commands() != 0u) { - auto elm = make_elem(st.cfg.setup, make_any_adapter_impl(setup_adapter{st})); + if (st.setup_req.get_commands() != 0u) { + auto elm = make_elem(st.setup_req, make_any_adapter_impl(setup_adapter{st})); elm->set_done_callback([&elem_ref = *elm, &st] { on_setup_done(elem_ref, st); }); diff --git a/include/boost/redis/impl/setup_request_utils.hpp b/include/boost/redis/impl/setup_request_utils.hpp index c220d98e9..2fb0c6033 100644 --- a/include/boost/redis/impl/setup_request_utils.hpp +++ b/include/boost/redis/impl/setup_request_utils.hpp @@ -9,6 +9,7 @@ #include #include +#include #include #include // use_sentinel #include @@ -22,14 +23,25 @@ namespace boost::redis::detail { // Modifies config::setup to make a request suitable to be sent // to the server using async_exec -inline void compose_setup_request(config& cfg) +inline void compose_setup_request( + const config& cfg, + const subscription_tracker& pubsub_st, + request& req) { - auto& req = cfg.setup; + // Clear any previous contents + req.clear(); - if (!cfg.use_setup) { + // Set the appropriate flags + request_access::set_priority(req, true); + req.get_config().cancel_if_unresponded = true; + req.get_config().cancel_on_connection_lost = true; + + if (cfg.use_setup) { + // We should use the provided request as-is + req.append(cfg.setup); + } else { // We're not using the setup request as-is, but should compose one based on // the values passed by the user - req.clear(); // Which parts of the command should we send? // Don't send AUTH if the user is the default and the password is empty. @@ -59,12 +71,8 @@ inline void compose_setup_request(config& cfg) if (use_sentinel(cfg)) req.push("ROLE"); - // In any case, the setup request should have the priority - // flag set so it's executed before any other request. - // The setup request should never be retried. - request_access::set_priority(req, true); - req.get_config().cancel_if_unresponded = true; - req.get_config().cancel_on_connection_lost = true; + // Add any subscription commands require to restore the PubSub state + pubsub_st.compose_subscribe_request(req); } class setup_adapter { @@ -83,7 +91,8 @@ class setup_adapter { // When using Sentinel, we add a ROLE command at the end. // We need to ensure that this instance is a master. - if (use_sentinel(st_->cfg) && response_idx_ == st_->cfg.setup.get_expected_responses() - 1u) { + // ROLE may be followed by subscribe requests, but these don't expect any response. + if (use_sentinel(st_->cfg) && response_idx_ == st_->setup_req.get_expected_responses() - 1u) { // ROLE's response should be an array of at least 1 element if (nd.depth == 0u) { if (nd.data_type != resp3::type::array) diff --git a/include/boost/redis/impl/subscription_tracker.ipp b/include/boost/redis/impl/subscription_tracker.ipp new file mode 100644 index 000000000..ec23675fb --- /dev/null +++ b/include/boost/redis/impl/subscription_tracker.ipp @@ -0,0 +1,44 @@ +// +// Copyright (c) 2025 Marcelo Zimbres Silva (mzimbres@gmail.com), +// Ruben Perez Hidalgo (rubenperez038 at gmail dot com) +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// + +#include +#include + +#include + +#include + +namespace boost::redis::detail { + +void subscription_tracker::clear() +{ + channels_.clear(); + pchannels_.clear(); +} + +void subscription_tracker::commit_changes(const request& req) +{ + for (const auto& ch : request_access::pubsub_changes(req)) { + std::string channel{req.payload().substr(ch.channel_offset, ch.channel_size)}; + switch (ch.type) { + case pubsub_change_type::subscribe: channels_.insert(std::move(channel)); break; + case pubsub_change_type::unsubscribe: channels_.erase(std::move(channel)); break; + case pubsub_change_type::psubscribe: pchannels_.insert(std::move(channel)); break; + case pubsub_change_type::punsubscribe: pchannels_.erase(std::move(channel)); break; + default: BOOST_ASSERT(false); + } + } +} + +void subscription_tracker::compose_subscribe_request(request& to) const +{ + to.push_range("SUBSCRIBE", channels_); + to.push_range("PSUBSCRIBE", pchannels_); +} + +} // namespace boost::redis::detail diff --git a/include/boost/redis/request.hpp b/include/boost/redis/request.hpp index fc3160c06..d3104a785 100644 --- a/include/boost/redis/request.hpp +++ b/include/boost/redis/request.hpp @@ -10,8 +10,12 @@ #include #include +#include #include +#include #include +#include +#include // NOTE: For some commands like hset it would be a good idea to assert // the value type is a pair. @@ -21,6 +25,21 @@ namespace boost::redis { namespace detail { auto has_response(std::string_view cmd) -> bool; struct request_access; + +enum class pubsub_change_type +{ + subscribe, + unsubscribe, + psubscribe, + punsubscribe, +}; + +struct pubsub_change { + pubsub_change_type type; + std::size_t channel_offset; + std::size_t channel_size; +}; + } // namespace detail /** @brief Represents a Redis request. @@ -123,6 +142,7 @@ class request { void clear() { payload_.clear(); + pubsub_changes_.clear(); commands_ = 0; expected_responses_ = 0; has_hello_priority_ = false; @@ -257,17 +277,17 @@ class request { * of arguments and don't have a key. For example: * * @code - * std::set channels - * { "channel1" , "channel2" , "channel3" }; + * std::set keys + * { "key1" , "key2" , "key3" }; * * request req; - * req.push("SUBSCRIBE", channels.cbegin(), channels.cend()); + * req.push("MGET", keys.begin(), keys.end()); * @endcode * * This will generate the following command: * * @code - * SUBSCRIBE channel1 channel2 channel3 + * MGET key1 key2 key3 * @endcode * * *If the passed range is empty, no command is added* and this @@ -412,6 +432,298 @@ class request { */ void append(const request& other); + /** + * @brief Appends a SUBSCRIBE command to the end of the request. + * + * If `channels` contains `{"ch1", "ch2"}`, the resulting command + * is `SUBSCRIBE ch1 ch2`. + * + * Subscriptions created using this function are tracked + * to enable PubSub state restoration. After successfully executing + * the request, the connection will store any newly subscribed channels and patterns. + * Every time a reconnection happens, + * a suitable `SUBSCRIBE`/`PSUBSCRIBE` command is issued automatically, + * to restore the subscriptions that were active before the reconnection. + * + * PubSub store restoration only happens when using @ref subscribe, + * @ref unsubscribe, @ref psubscribe or @ref punsubscribe. + * Subscription commands added by @ref push or @ref push_range are not tracked. + */ + void subscribe(std::initializer_list channels) + { + subscribe(channels.begin(), channels.end()); + } + + /** + * @brief Appends a SUBSCRIBE command to the end of the request. + * + * If `channels` contains `["ch1", "ch2"]`, the resulting command + * is `SUBSCRIBE ch1 ch2`. + * + * Subscriptions created using this function are tracked + * to enable PubSub state restoration. After successfully executing + * the request, the connection will store any newly subscribed channels and patterns. + * Every time a reconnection happens, + * a suitable `SUBSCRIBE`/`PSUBSCRIBE` command is issued automatically, + * to restore the subscriptions that were active before the reconnection. + * + * PubSub store restoration only happens when using @ref subscribe, + * @ref unsubscribe, @ref psubscribe or @ref punsubscribe. + * Subscription commands added by @ref push or @ref push_range are not tracked. + */ + template + void subscribe(Range&& channels, decltype(std::cbegin(channels))* = nullptr) + { + subscribe(std::cbegin(channels), std::cend(channels)); + } + + /** + * @brief Appends a SUBSCRIBE command to the end of the request. + * + * [`channels_begin`, `channels_end`) should point to a valid + * range of elements convertible to `std::string_view`. + * If the range contains `["ch1", "ch2"]`, the resulting command + * is `SUBSCRIBE ch1 ch2`. + * + * Subscriptions created using this function are tracked + * to enable PubSub state restoration. After successfully executing + * the request, the connection will store any newly subscribed channels and patterns. + * Every time a reconnection happens, + * a suitable `SUBSCRIBE`/`PSUBSCRIBE` command is issued automatically, + * to restore the subscriptions that were active before the reconnection. + * + * PubSub store restoration only happens when using @ref subscribe, + * @ref unsubscribe, @ref psubscribe or @ref punsubscribe. + * Subscription commands added by @ref push or @ref push_range are not tracked. + */ + template + void subscribe(ForwardIt channels_begin, ForwardIt channels_end) + { + push_pubsub("SUBSCRIBE", detail::pubsub_change_type::subscribe, channels_begin, channels_end); + } + + /** + * @brief Appends an UNSUBSCRIBE command to the end of the request. + * + * If `channels` contains `{"ch1", "ch2"}`, the resulting command + * is `UNSUBSCRIBE ch1 ch2`. + * + * Subscriptions removed using this function are tracked + * to enable PubSub state restoration. After successfully executing + * the request, the connection will store any newly subscribed channels and patterns. + * Every time a reconnection happens, + * a suitable `SUBSCRIBE`/`PSUBSCRIBE` command is issued automatically, + * to restore the subscriptions that were active before the reconnection. + * + * PubSub store restoration only happens when using @ref subscribe, + * @ref unsubscribe, @ref psubscribe or @ref punsubscribe. + * Subscription commands added by @ref push or @ref push_range are not tracked. + */ + void unsubscribe(std::initializer_list channels) + { + unsubscribe(channels.begin(), channels.end()); + } + + /** + * @brief Appends an UNSUBSCRIBE command to the end of the request. + * + * If `channels` contains `["ch1", "ch2"]`, the resulting command + * is `UNSUBSCRIBE ch1 ch2`. + * + * Subscriptions removed using this function are tracked + * to enable PubSub state restoration. After successfully executing + * the request, the connection will store any newly subscribed channels and patterns. + * Every time a reconnection happens, + * a suitable `SUBSCRIBE`/`PSUBSCRIBE` command is issued automatically, + * to restore the subscriptions that were active before the reconnection. + * + * PubSub store restoration only happens when using @ref subscribe, + * @ref unsubscribe, @ref psubscribe or @ref punsubscribe. + * Subscription commands added by @ref push or @ref push_range are not tracked. + */ + template + void unsubscribe(Range&& channels, decltype(std::cbegin(channels))* = nullptr) + { + unsubscribe(std::cbegin(channels), std::cend(channels)); + } + + /** + * @brief Appends an UNSUBSCRIBE command to the end of the request. + * + * [`channels_begin`, `channels_end`) should point to a valid + * range of elements convertible to `std::string_view`. + * If the range contains `["ch1", "ch2"]`, the resulting command + * is `UNSUBSCRIBE ch1 ch2`. + * + * Subscriptions removed using this function are tracked + * to enable PubSub state restoration. After successfully executing + * the request, the connection will store any newly subscribed channels and patterns. + * Every time a reconnection happens, + * a suitable `SUBSCRIBE`/`PSUBSCRIBE` command is issued automatically, + * to restore the subscriptions that were active before the reconnection. + * + * PubSub store restoration only happens when using @ref subscribe, + * @ref unsubscribe, @ref psubscribe or @ref punsubscribe. + * Subscription commands added by @ref push or @ref push_range are not tracked. + */ + template + void unsubscribe(ForwardIt channels_begin, ForwardIt channels_end) + { + push_pubsub( + "UNSUBSCRIBE", + detail::pubsub_change_type::unsubscribe, + channels_begin, + channels_end); + } + + /** + * @brief Appends a PSUBSCRIBE command to the end of the request. + * + * If `patterns` contains `{"news.*", "events.*"}`, the resulting command + * is `PSUBSCRIBE news.* events.*`. + * + * Subscriptions created using this function are tracked + * to enable PubSub state restoration. After successfully executing + * the request, the connection will store any newly subscribed channels and patterns. + * Every time a reconnection happens, + * a suitable `SUBSCRIBE`/`PSUBSCRIBE` command is issued automatically, + * to restore the subscriptions that were active before the reconnection. + * + * PubSub store restoration only happens when using @ref subscribe, + * @ref unsubscribe, @ref psubscribe or @ref punsubscribe. + * Subscription commands added by @ref push or @ref push_range are not tracked. + */ + void psubscribe(std::initializer_list patterns) + { + psubscribe(patterns.begin(), patterns.end()); + } + + /** + * @brief Appends a PSUBSCRIBE command to the end of the request. + * + * If `patterns` contains `["news.*", "events.*"]`, the resulting command + * is `PSUBSCRIBE news.* events.*`. + * + * Subscriptions created using this function are tracked + * to enable PubSub state restoration. After successfully executing + * the request, the connection will store any newly subscribed channels and patterns. + * Every time a reconnection happens, + * a suitable `SUBSCRIBE`/`PSUBSCRIBE` command is issued automatically, + * to restore the subscriptions that were active before the reconnection. + * + * PubSub store restoration only happens when using @ref subscribe, + * @ref unsubscribe, @ref psubscribe or @ref punsubscribe. + * Subscription commands added by @ref push or @ref push_range are not tracked. + */ + template + void psubscribe(Range&& patterns, decltype(std::cbegin(patterns))* = nullptr) + { + psubscribe(std::cbegin(patterns), std::cend(patterns)); + } + + /** + * @brief Appends a PSUBSCRIBE command to the end of the request. + * + * [`patterns_begin`, `patterns_end`) should point to a valid + * range of elements convertible to `std::string_view`. + * If the range contains `["news.*", "events.*"]`, the resulting command + * is `PSUBSCRIBE news.* events.*`. + * + * Subscriptions created using this function are tracked + * to enable PubSub state restoration. After successfully executing + * the request, the connection will store any newly subscribed channels and patterns. + * Every time a reconnection happens, + * a suitable `SUBSCRIBE`/`PSUBSCRIBE` command is issued automatically, + * to restore the subscriptions that were active before the reconnection. + * + * PubSub store restoration only happens when using @ref subscribe, + * @ref unsubscribe, @ref psubscribe or @ref punsubscribe. + * Subscription commands added by @ref push or @ref push_range are not tracked. + */ + template + void psubscribe(ForwardIt patterns_begin, ForwardIt patterns_end) + { + push_pubsub( + "PSUBSCRIBE", + detail::pubsub_change_type::psubscribe, + patterns_begin, + patterns_end); + } + + /** + * @brief Appends a PUNSUBSCRIBE command to the end of the request. + * + * If `patterns` contains `{"news.*", "events.*"}`, the resulting command + * is `PUNSUBSCRIBE news.* events.*`. + * + * Subscriptions removed using this function are tracked + * to enable PubSub state restoration. After successfully executing + * the request, the connection will store any newly subscribed channels and patterns. + * Every time a reconnection happens, + * a suitable `SUBSCRIBE`/`PSUBSCRIBE` command is issued automatically, + * to restore the subscriptions that were active before the reconnection. + * + * PubSub store restoration only happens when using @ref subscribe, + * @ref unsubscribe, @ref psubscribe or @ref punsubscribe. + * Subscription commands added by @ref push or @ref push_range are not tracked. + */ + void punsubscribe(std::initializer_list patterns) + { + punsubscribe(patterns.begin(), patterns.end()); + } + + /** + * @brief Appends a PUNSUBSCRIBE command to the end of the request. + * + * If `patterns` contains `["news.*", "events.*"]`, the resulting command + * is `PUNSUBSCRIBE news.* events.*`. + * + * Subscriptions removed using this function are tracked + * to enable PubSub state restoration. After successfully executing + * the request, the connection will store any newly subscribed channels and patterns. + * Every time a reconnection happens, + * a suitable `SUBSCRIBE`/`PSUBSCRIBE` command is issued automatically, + * to restore the subscriptions that were active before the reconnection. + * + * PubSub store restoration only happens when using @ref subscribe, + * @ref unsubscribe, @ref psubscribe or @ref punsubscribe. + * Subscription commands added by @ref push or @ref push_range are not tracked. + */ + template + void punsubscribe(Range&& patterns, decltype(std::cbegin(patterns))* = nullptr) + { + punsubscribe(std::cbegin(patterns), std::cend(patterns)); + } + + /** + * @brief Appends a PUNSUBSCRIBE command to the end of the request. + * + * [`patterns_begin`, `patterns_end`) should point to a valid + * range of elements convertible to `std::string_view`. + * If the range contains `["news.*", "events.*"]`, the resulting command + * is `PUNSUBSCRIBE news.* events.*`. + * + * Subscriptions removed using this function are tracked + * to enable PubSub state restoration. After successfully executing + * the request, the connection will store any newly subscribed channels and patterns. + * Every time a reconnection happens, + * a suitable `SUBSCRIBE`/`PSUBSCRIBE` command is issued automatically, + * to restore the subscriptions that were active before the reconnection. + * + * PubSub store restoration only happens when using @ref subscribe, + * @ref unsubscribe, @ref psubscribe or @ref punsubscribe. + * Subscription commands added by @ref push or @ref push_range are not tracked. + */ + template + void punsubscribe(ForwardIt patterns_begin, ForwardIt patterns_end) + { + push_pubsub( + "PUNSUBSCRIBE", + detail::pubsub_change_type::punsubscribe, + patterns_begin, + patterns_end); + } + private: void check_cmd(std::string_view cmd) { @@ -429,6 +741,35 @@ class request { std::size_t commands_ = 0; std::size_t expected_responses_ = 0; bool has_hello_priority_ = false; + std::vector pubsub_changes_{}; + + void add_pubsub_arg(detail::pubsub_change_type type, std::string_view value); + + template + void push_pubsub( + std::string_view cmd, + detail::pubsub_change_type type, + ForwardIt channels_begin, + ForwardIt channels_end) + { + static_assert( + std::is_convertible_v< + typename std::iterator_traits::value_type, + std::string_view>, + "subscribe, psubscribe, unsubscribe and punsubscribe should be passed ranges of elements " + "convertible to std::string_view"); + if (channels_begin == channels_end) + return; + + auto const distance = std::distance(channels_begin, channels_end); + resp3::add_header(payload_, resp3::type::array, 1 + distance); + resp3::add_bulk(payload_, cmd); + + for (; channels_begin != channels_end; ++channels_begin) + add_pubsub_arg(type, *channels_begin); + + ++commands_; // these commands don't have a response + } friend struct detail::request_access; }; @@ -438,6 +779,10 @@ namespace detail { struct request_access { inline static void set_priority(request& r, bool value) { r.has_hello_priority_ = value; } inline static bool has_priority(const request& r) { return r.has_hello_priority_; } + inline static const std::vector& pubsub_changes(const request& r) + { + return r.pubsub_changes_; + } }; // Creates a HELLO 3 request diff --git a/include/boost/redis/src.hpp b/include/boost/redis/src.hpp index eb48981b9..647e643de 100644 --- a/include/boost/redis/src.hpp +++ b/include/boost/redis/src.hpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -18,8 +19,8 @@ #include #include #include +#include #include -#include #include #include #include diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 40a1ea05d..9c6a97e0e 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -55,6 +55,7 @@ make_test(test_update_sentinel_list) make_test(test_flat_tree) make_test(test_generic_flat_response) make_test(test_read_buffer) +make_test(test_subscription_tracker) # Tests that require a real Redis server make_test(test_conn_quit) diff --git a/test/Jamfile b/test/Jamfile index 3aa72460f..ccd47060c 100644 --- a/test/Jamfile +++ b/test/Jamfile @@ -72,6 +72,7 @@ local tests = test_flat_tree test_generic_flat_response test_read_buffer + test_subscription_tracker ; # Build and run the tests diff --git a/test/test_compose_setup_request.cpp b/test/test_compose_setup_request.cpp index 2f91a9470..566799502 100644 --- a/test/test_compose_setup_request.cpp +++ b/test/test_compose_setup_request.cpp @@ -6,220 +6,231 @@ // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // -#include #include +#include #include #include #include #include -#include #include +#include #include -#include +#include +#include + +using namespace boost::redis; namespace asio = boost::asio; -namespace redis = boost::redis; -using redis::detail::compose_setup_request; +using detail::compose_setup_request; +using detail::subscription_tracker; using boost::system::error_code; namespace { +struct fixture { + subscription_tracker tracker; + request out; + config cfg; + + void run(std::string_view expected_payload, boost::source_location loc = BOOST_CURRENT_LOCATION) + { + out.push("PING", "leftover"); // verify that we clear the request + + compose_setup_request(cfg, tracker, out); + + if (!BOOST_TEST_EQ(out.payload(), expected_payload)) + std::cerr << "Called from " << loc << std::endl; + + if (!BOOST_TEST(out.has_hello_priority())) + std::cerr << "Called from " << loc << std::endl; + + if (!BOOST_TEST(out.get_config().cancel_if_unresponded)) + std::cerr << "Called from " << loc << std::endl; + + if (!BOOST_TEST(out.get_config().cancel_on_connection_lost)) + std::cerr << "Called from " << loc << std::endl; + } +}; + void test_hello() { - redis::config cfg; - cfg.clientname = ""; - - compose_setup_request(cfg); + fixture fix; + fix.cfg.clientname = ""; - std::string_view const expected = "*2\r\n$5\r\nHELLO\r\n$1\r\n3\r\n"; - BOOST_TEST_EQ(cfg.setup.payload(), expected); - BOOST_TEST(cfg.setup.has_hello_priority()); - BOOST_TEST(cfg.setup.get_config().cancel_if_unresponded); - BOOST_TEST(cfg.setup.get_config().cancel_on_connection_lost); + fix.run("*2\r\n$5\r\nHELLO\r\n$1\r\n3\r\n"); } void test_select() { - redis::config cfg; - cfg.clientname = ""; - cfg.database_index = 10; + fixture fix; + fix.cfg.clientname = ""; + fix.cfg.database_index = 10; - compose_setup_request(cfg); - - std::string_view const expected = + fix.run( "*2\r\n$5\r\nHELLO\r\n$1\r\n3\r\n" - "*2\r\n$6\r\nSELECT\r\n$2\r\n10\r\n"; - BOOST_TEST_EQ(cfg.setup.payload(), expected); - BOOST_TEST(cfg.setup.has_hello_priority()); - BOOST_TEST(cfg.setup.get_config().cancel_if_unresponded); - BOOST_TEST(cfg.setup.get_config().cancel_on_connection_lost); + "*2\r\n$6\r\nSELECT\r\n$2\r\n10\r\n"); } void test_clientname() { - redis::config cfg; - - compose_setup_request(cfg); + fixture fix; - std::string_view const - expected = "*4\r\n$5\r\nHELLO\r\n$1\r\n3\r\n$7\r\nSETNAME\r\n$11\r\nBoost.Redis\r\n"; - BOOST_TEST_EQ(cfg.setup.payload(), expected); - BOOST_TEST(cfg.setup.has_hello_priority()); - BOOST_TEST(cfg.setup.get_config().cancel_if_unresponded); - BOOST_TEST(cfg.setup.get_config().cancel_on_connection_lost); + fix.run("*4\r\n$5\r\nHELLO\r\n$1\r\n3\r\n$7\r\nSETNAME\r\n$11\r\nBoost.Redis\r\n"); } void test_auth() { - redis::config cfg; - cfg.clientname = ""; - cfg.username = "foo"; - cfg.password = "bar"; - - compose_setup_request(cfg); - - std::string_view const - expected = "*5\r\n$5\r\nHELLO\r\n$1\r\n3\r\n$4\r\nAUTH\r\n$3\r\nfoo\r\n$3\r\nbar\r\n"; - BOOST_TEST_EQ(cfg.setup.payload(), expected); - BOOST_TEST(cfg.setup.has_hello_priority()); - BOOST_TEST(cfg.setup.get_config().cancel_if_unresponded); - BOOST_TEST(cfg.setup.get_config().cancel_on_connection_lost); + fixture fix; + fix.cfg.clientname = ""; + fix.cfg.username = "foo"; + fix.cfg.password = "bar"; + + fix.run("*5\r\n$5\r\nHELLO\r\n$1\r\n3\r\n$4\r\nAUTH\r\n$3\r\nfoo\r\n$3\r\nbar\r\n"); } void test_auth_empty_password() { - redis::config cfg; - cfg.clientname = ""; - cfg.username = "foo"; - - compose_setup_request(cfg); - - std::string_view const - expected = "*5\r\n$5\r\nHELLO\r\n$1\r\n3\r\n$4\r\nAUTH\r\n$3\r\nfoo\r\n$0\r\n\r\n"; - BOOST_TEST_EQ(cfg.setup.payload(), expected); - BOOST_TEST(cfg.setup.has_hello_priority()); - BOOST_TEST(cfg.setup.get_config().cancel_if_unresponded); - BOOST_TEST(cfg.setup.get_config().cancel_on_connection_lost); + fixture fix; + fix.cfg.clientname = ""; + fix.cfg.username = "foo"; + + fix.run("*5\r\n$5\r\nHELLO\r\n$1\r\n3\r\n$4\r\nAUTH\r\n$3\r\nfoo\r\n$0\r\n\r\n"); } void test_auth_setname() { - redis::config cfg; - cfg.clientname = "mytest"; - cfg.username = "foo"; - cfg.password = "bar"; - - compose_setup_request(cfg); + fixture fix; + fix.cfg.clientname = "mytest"; + fix.cfg.username = "foo"; + fix.cfg.password = "bar"; - std::string_view const expected = + fix.run( "*7\r\n$5\r\nHELLO\r\n$1\r\n3\r\n$4\r\nAUTH\r\n$3\r\nfoo\r\n$3\r\nbar\r\n$7\r\nSETNAME\r\n$" - "6\r\nmytest\r\n"; - BOOST_TEST_EQ(cfg.setup.payload(), expected); - BOOST_TEST(cfg.setup.has_hello_priority()); - BOOST_TEST(cfg.setup.get_config().cancel_if_unresponded); - BOOST_TEST(cfg.setup.get_config().cancel_on_connection_lost); + "6\r\nmytest\r\n"); } void test_use_setup() { - redis::config cfg; - cfg.clientname = "mytest"; - cfg.username = "foo"; - cfg.password = "bar"; - cfg.database_index = 4; - cfg.use_setup = true; - cfg.setup.push("SELECT", 8); - - compose_setup_request(cfg); - - std::string_view const expected = + fixture fix; + fix.cfg.clientname = "mytest"; + fix.cfg.username = "foo"; + fix.cfg.password = "bar"; + fix.cfg.database_index = 4; + fix.cfg.use_setup = true; + fix.cfg.setup.push("SELECT", 8); + + fix.run( "*2\r\n$5\r\nHELLO\r\n$1\r\n3\r\n" - "*2\r\n$6\r\nSELECT\r\n$1\r\n8\r\n"; - BOOST_TEST_EQ(cfg.setup.payload(), expected); - BOOST_TEST(cfg.setup.has_hello_priority()); - BOOST_TEST(cfg.setup.get_config().cancel_if_unresponded); - BOOST_TEST(cfg.setup.get_config().cancel_on_connection_lost); + "*2\r\n$6\r\nSELECT\r\n$1\r\n8\r\n"); } // Regression check: we set the priority flag void test_use_setup_no_hello() { - redis::config cfg; - cfg.use_setup = true; - cfg.setup.clear(); - cfg.setup.push("SELECT", 8); - - compose_setup_request(cfg); - - std::string_view const expected = "*2\r\n$6\r\nSELECT\r\n$1\r\n8\r\n"; - BOOST_TEST_EQ(cfg.setup.payload(), expected); - BOOST_TEST(cfg.setup.has_hello_priority()); - BOOST_TEST(cfg.setup.get_config().cancel_if_unresponded); - BOOST_TEST(cfg.setup.get_config().cancel_on_connection_lost); + fixture fix; + fix.cfg.use_setup = true; + fix.cfg.setup.clear(); + fix.cfg.setup.push("SELECT", 8); + + fix.run("*2\r\n$6\r\nSELECT\r\n$1\r\n8\r\n"); } // Regression check: we set the relevant cancellation flags in the request void test_use_setup_flags() { - redis::config cfg; - cfg.use_setup = true; - cfg.setup.clear(); - cfg.setup.push("SELECT", 8); - cfg.setup.get_config().cancel_if_unresponded = false; - cfg.setup.get_config().cancel_on_connection_lost = false; - - compose_setup_request(cfg); - - std::string_view const expected = "*2\r\n$6\r\nSELECT\r\n$1\r\n8\r\n"; - BOOST_TEST_EQ(cfg.setup.payload(), expected); - BOOST_TEST(cfg.setup.has_hello_priority()); - BOOST_TEST(cfg.setup.get_config().cancel_if_unresponded); - BOOST_TEST(cfg.setup.get_config().cancel_on_connection_lost); + fixture fix; + fix.cfg.use_setup = true; + fix.cfg.setup.clear(); + fix.cfg.setup.push("SELECT", 8); + fix.cfg.setup.get_config().cancel_if_unresponded = false; + fix.cfg.setup.get_config().cancel_on_connection_lost = false; + + fix.run("*2\r\n$6\r\nSELECT\r\n$1\r\n8\r\n"); +} + +// If we have tracked subscriptions, these are added at the end +void test_tracked_subscriptions() +{ + fixture fix; + fix.cfg.clientname = ""; + + // Populate the tracker + request sub_req; + sub_req.subscribe({"ch1", "ch2"}); + fix.tracker.commit_changes(sub_req); + + fix.run( + "*2\r\n$5\r\nHELLO\r\n$1\r\n3\r\n" + "*3\r\n$9\r\nSUBSCRIBE\r\n$3\r\nch1\r\n$3\r\nch2\r\n"); +} + +void test_tracked_subscriptions_use_setup() +{ + fixture fix; + fix.cfg.use_setup = true; + fix.cfg.setup.clear(); + fix.cfg.setup.push("PING", "value"); + + // Populate the tracker + request sub_req; + sub_req.subscribe({"ch1", "ch2"}); + fix.tracker.commit_changes(sub_req); + + fix.run( + "*2\r\n$4\r\nPING\r\n$5\r\nvalue\r\n" + "*3\r\n$9\r\nSUBSCRIBE\r\n$3\r\nch1\r\n$3\r\nch2\r\n"); } // When using Sentinel, a ROLE command is added. This works -// both with the old HELLO and new setup strategies. +// both with the old HELLO and new setup strategies, and with tracked subscriptions void test_sentinel_auth() { - redis::config cfg; - cfg.sentinel.addresses = { + fixture fix; + fix.cfg.sentinel.addresses = { {"localhost", "26379"} }; - cfg.clientname = ""; - cfg.username = "foo"; - cfg.password = "bar"; - - compose_setup_request(cfg); + fix.cfg.clientname = ""; + fix.cfg.username = "foo"; + fix.cfg.password = "bar"; - std::string_view const expected = + fix.run( "*5\r\n$5\r\nHELLO\r\n$1\r\n3\r\n$4\r\nAUTH\r\n$3\r\nfoo\r\n$3\r\nbar\r\n" - "*1\r\n$4\r\nROLE\r\n"; - BOOST_TEST_EQ(cfg.setup.payload(), expected); - BOOST_TEST(cfg.setup.has_hello_priority()); - BOOST_TEST(cfg.setup.get_config().cancel_if_unresponded); - BOOST_TEST(cfg.setup.get_config().cancel_on_connection_lost); + "*1\r\n$4\r\nROLE\r\n"); } void test_sentinel_use_setup() { - redis::config cfg; - cfg.sentinel.addresses = { + fixture fix; + fix.cfg.sentinel.addresses = { {"localhost", "26379"} }; - cfg.use_setup = true; - cfg.setup.push("SELECT", 42); + fix.cfg.use_setup = true; + fix.cfg.setup.push("SELECT", 42); - compose_setup_request(cfg); - - std::string_view const expected = + fix.run( "*2\r\n$5\r\nHELLO\r\n$1\r\n3\r\n" "*2\r\n$6\r\nSELECT\r\n$2\r\n42\r\n" - "*1\r\n$4\r\nROLE\r\n"; - BOOST_TEST_EQ(cfg.setup.payload(), expected); - BOOST_TEST(cfg.setup.has_hello_priority()); - BOOST_TEST(cfg.setup.get_config().cancel_if_unresponded); - BOOST_TEST(cfg.setup.get_config().cancel_on_connection_lost); + "*1\r\n$4\r\nROLE\r\n"); +} + +void test_sentinel_tracked_subscriptions() +{ + fixture fix; + fix.cfg.clientname = ""; + fix.cfg.sentinel.addresses = { + {"localhost", "26379"} + }; + + // Populate the tracker + request sub_req; + sub_req.subscribe({"ch1", "ch2"}); + fix.tracker.commit_changes(sub_req); + + fix.run( + "*2\r\n$5\r\nHELLO\r\n$1\r\n3\r\n" + "*1\r\n$4\r\nROLE\r\n" + "*3\r\n$9\r\nSUBSCRIBE\r\n$3\r\nch1\r\n$3\r\nch2\r\n"); } } // namespace @@ -235,8 +246,11 @@ int main() test_use_setup(); test_use_setup_no_hello(); test_use_setup_flags(); + test_tracked_subscriptions(); + test_tracked_subscriptions_use_setup(); test_sentinel_auth(); test_sentinel_use_setup(); + test_sentinel_tracked_subscriptions(); return boost::report_errors(); } \ No newline at end of file diff --git a/test/test_conn_echo_stress.cpp b/test/test_conn_echo_stress.cpp index 48157b0be..e1455b2aa 100644 --- a/test/test_conn_echo_stress.cpp +++ b/test/test_conn_echo_stress.cpp @@ -55,11 +55,7 @@ std::ostream& operator<<(std::ostream& os, usage const& u) namespace { -auto -receiver( - connection& conn, - flat_tree& resp, - std::size_t expected) -> net::awaitable +auto receiver(connection& conn, flat_tree& resp, std::size_t expected) -> net::awaitable { std::size_t push_counter = 0; while (push_counter != expected) { @@ -135,7 +131,7 @@ BOOST_AUTO_TEST_CASE(echo_stress) // Subscribe, then launch the coroutines request req; - req.push("SUBSCRIBE", "channel"); + req.subscribe({"channel"}); conn.async_exec(req, ignore, [&](error_code ec, std::size_t) { subscribe_finished = true; BOOST_TEST(ec == error_code()); @@ -150,13 +146,11 @@ BOOST_AUTO_TEST_CASE(echo_stress) BOOST_TEST(subscribe_finished); // Print statistics - std::cout - << "-------------------\n" - << "Usage data: \n" - << conn.get_usage() << "\n" - << "-------------------\n" - << "Reallocations: " << resp.get_reallocs() - << std::endl; + std::cout << "-------------------\n" + << "Usage data: \n" + << conn.get_usage() << "\n" + << "-------------------\n" + << "Reallocations: " << resp.get_reallocs() << std::endl; } } // namespace diff --git a/test/test_conn_push2.cpp b/test/test_conn_push2.cpp index 9e3d7f623..9cf638e92 100644 --- a/test/test_conn_push2.cpp +++ b/test/test_conn_push2.cpp @@ -7,12 +7,15 @@ #include #include #include +#include #include #include #include +#include #include +#include #define BOOST_TEST_MODULE conn_push #include @@ -35,6 +38,8 @@ using boost::redis::ignore; using boost::redis::ignore_t; using boost::system::error_code; using boost::redis::logger; +using boost::redis::resp3::node_view; +using boost::redis::resp3::type; using namespace std::chrono_literals; namespace { @@ -389,4 +394,157 @@ BOOST_AUTO_TEST_CASE(test_unsubscribe) BOOST_TEST(run_finished); } +class test_pubsub_state_restoration_ { + net::io_context ioc; + connection conn{ioc}; + request req; + response resp_str; + flat_tree resp_push; + bool exec_finished = false; + + void check_subscriptions() + { + // Checks for the expected subscriptions and patterns after restoration + std::set seen_channels, seen_patterns; + for (auto it = resp_push.get_view().begin(); it != resp_push.get_view().end();) { + // The root element should be a push + BOOST_TEST_REQUIRE(it->data_type == type::push); + BOOST_TEST_REQUIRE(it->aggregate_size >= 2u); + BOOST_TEST_REQUIRE((++it != resp_push.get_view().end())); + + // The next element should be the message type + std::string_view msg_type = it->value; + BOOST_TEST_REQUIRE((++it != resp_push.get_view().end())); + + // The next element is the channel or pattern + if (msg_type == "subscribe") + seen_channels.insert(it->value); + else if (msg_type == "psubscribe") + seen_patterns.insert(it->value); + + // Skip the rest of the nodes + while (it != resp_push.get_view().end() && it->depth != 0u) + ++it; + } + + const std::string_view expected_channels[] = {"ch1", "ch3", "ch5"}; + const std::string_view expected_patterns[] = {"ch1*", "ch3*", "ch4*", "ch8*"}; + + BOOST_TEST(seen_channels == expected_channels, boost::test_tools::per_element()); + BOOST_TEST(seen_patterns == expected_patterns, boost::test_tools::per_element()); + } + + void sub1() + { + // Subscribe to some channels and patterns + req.clear(); + req.subscribe({"ch1", "ch2", "ch3"}); // active: 1, 2, 3 + req.psubscribe({"ch1*", "ch2*", "ch3*", "ch4*"}); // active: 1, 2, 3, 4 + conn.async_exec(req, ignore, [this](error_code ec, std::size_t) { + BOOST_TEST(ec == error_code()); + unsub(); + }); + } + + void unsub() + { + // Unsubscribe from some channels and patterns. + // Unsubscribing from a channel/pattern that we weren't subscribed to is OK. + req.clear(); + req.unsubscribe({"ch2", "ch1", "ch5"}); // active: 3 + req.punsubscribe({"ch2*", "ch4*", "ch9*"}); // active: 1, 3 + conn.async_exec(req, ignore, [this](error_code ec, std::size_t) { + BOOST_TEST(ec == error_code()); + sub2(); + }); + } + + void sub2() + { + // Subscribe to other channels/patterns. + // Re-subscribing to channels/patterns we unsubscribed from is OK. + // Subscribing to the same channel/pattern twice is OK. + req.clear(); + req.subscribe({"ch1", "ch3", "ch5"}); // active: 1, 3, 5 + req.psubscribe({"ch3*", "ch4*", "ch8*"}); // active: 1, 3, 4, 8 + + // Subscriptions created by push() don't survive reconnection + req.push("SUBSCRIBE", "ch10"); // active: 1, 3, 5, 10 + req.push("PSUBSCRIBE", "ch10*"); // active: 1, 3, 4, 8, 10 + + // Validate that we're subscribed to what we expect + req.push("CLIENT", "INFO"); + + conn.async_exec(req, resp_str, [this](error_code ec, std::size_t) { + BOOST_TEST(ec == error_code()); + + // We are subscribed to 4 channels and 5 patterns + BOOST_TEST(std::get<0>(resp_str).has_value()); + BOOST_TEST(find_client_info(std::get<0>(resp_str).value(), "sub") == "4"); + BOOST_TEST(find_client_info(std::get<0>(resp_str).value(), "psub") == "5"); + + resp_push.clear(); + + quit(); + }); + } + + void quit() + { + req.clear(); + req.push("QUIT"); + + conn.async_exec(req, ignore, [this](error_code, std::size_t) { + // we don't know if this request will complete successfully or not + client_info(); + }); + } + + void client_info() + { + req.clear(); + req.push("CLIENT", "INFO"); + req.get_config().cancel_if_unresponded = false; + + conn.async_exec(req, resp_str, [this](error_code ec, std::size_t) { + BOOST_TEST(ec == error_code()); + + // We are subscribed to 3 channels and 4 patterns (1 of each didn't survive reconnection) + BOOST_TEST(std::get<0>(resp_str).has_value()); + BOOST_TEST(find_client_info(std::get<0>(resp_str).value(), "sub") == "3"); + BOOST_TEST(find_client_info(std::get<0>(resp_str).value(), "psub") == "4"); + + // We have received pushes confirming it + check_subscriptions(); + + exec_finished = true; + conn.cancel(); + }); + } + +public: + void run() + { + conn.set_receive_response(resp_push); + + // Start the request chain + sub1(); + + // Start running + bool run_finished = false; + conn.async_run(make_test_config(), [&run_finished](error_code ec) { + BOOST_TEST(ec == net::error::operation_aborted); + run_finished = true; + }); + + ioc.run_for(test_timeout); + + // Done + BOOST_TEST(exec_finished); + BOOST_TEST(run_finished); + } +}; + +BOOST_AUTO_TEST_CASE(test_pubsub_state_restoration) { test_pubsub_state_restoration_().run(); } + } // namespace diff --git a/test/test_conn_sentinel.cpp b/test/test_conn_sentinel.cpp index 7caec74ec..f18c1de08 100644 --- a/test/test_conn_sentinel.cpp +++ b/test/test_conn_sentinel.cpp @@ -96,7 +96,7 @@ void test_receive() // Subscribe to a channel. This produces a push message on itself request req; - req.push("SUBSCRIBE", "sentinel_channel"); + req.subscribe({"sentinel_channel"}); bool exec_finished = false, receive_finished = false, run_finished = false; diff --git a/test/test_exec_fsm.cpp b/test/test_exec_fsm.cpp index f23853556..16301554d 100644 --- a/test/test_exec_fsm.cpp +++ b/test/test_exec_fsm.cpp @@ -6,6 +6,7 @@ // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // +#include #include #include #include @@ -30,6 +31,7 @@ using detail::multiplexer; using detail::exec_action_type; using detail::consume_result; using detail::exec_action; +using detail::connection_state; using boost::system::error_code; using boost::asio::cancellation_type_t; @@ -102,11 +104,23 @@ struct elem_and_request { std::shared_ptr elm; std::weak_ptr weak_elm; // check that we free memory - elem_and_request(request::config cfg = {}) - : req(cfg) + static request make_request(request::config cfg) { + request req{cfg}; + // Empty requests are not valid. The request needs to be populated before creating the element req.push("get", "mykey"); + + return req; + } + + elem_and_request(request::config cfg = {}) + : elem_and_request(make_request(cfg)) + { } + + elem_and_request(request input_req) + : req(std::move(input_req)) + { elm = std::make_shared(req, any_adapter{}); elm->set_done_callback([this] { @@ -121,35 +135,35 @@ struct elem_and_request { void test_success() { // Setup - multiplexer mpx; + connection_state st; elem_and_request input; - exec_fsm fsm(mpx, std::move(input.elm)); + exec_fsm fsm(std::move(input.elm)); error_code ec; // Initiate - auto act = fsm.resume(true, cancellation_type_t::none); + auto act = fsm.resume(true, st, cancellation_type_t::none); BOOST_TEST_EQ(act, exec_action_type::setup_cancellation); - act = fsm.resume(true, cancellation_type_t::none); + act = fsm.resume(true, st, cancellation_type_t::none); BOOST_TEST_EQ(act, exec_action_type::notify_writer); // We should now wait for a response - act = fsm.resume(true, cancellation_type_t::none); + act = fsm.resume(true, st, cancellation_type_t::none); BOOST_TEST_EQ(act, exec_action_type::wait_for_response); // Simulate a successful write - BOOST_TEST_EQ(mpx.prepare_write(), 1u); // one request was placed in the packet to write - BOOST_TEST(mpx.commit_write(mpx.get_write_buffer().size())); + BOOST_TEST_EQ(st.mpx.prepare_write(), 1u); // one request was placed in the packet to write + BOOST_TEST(st.mpx.commit_write(st.mpx.get_write_buffer().size())); // Simulate a successful read - read(mpx, "$5\r\nhello\r\n"); - auto req_status = mpx.consume(ec); + read(st.mpx, "$5\r\nhello\r\n"); + auto req_status = st.mpx.consume(ec); BOOST_TEST_EQ(ec, error_code()); BOOST_TEST_EQ(req_status.first, consume_result::got_response); BOOST_TEST_EQ(req_status.second, 11u); // the entire buffer was consumed BOOST_TEST_EQ(input.done_calls, 1u); // This will awaken the exec operation, and should complete the operation - act = fsm.resume(true, cancellation_type_t::none); + act = fsm.resume(true, st, cancellation_type_t::none); BOOST_TEST_EQ(act, exec_action(error_code(), 11u)); // All memory should have been freed by now @@ -160,37 +174,37 @@ void test_success() void test_parse_error() { // Setup - multiplexer mpx; + connection_state st; elem_and_request input; - exec_fsm fsm(mpx, std::move(input.elm)); + exec_fsm fsm(std::move(input.elm)); error_code ec; // Initiate - auto act = fsm.resume(true, cancellation_type_t::none); + auto act = fsm.resume(true, st, cancellation_type_t::none); BOOST_TEST_EQ(act, exec_action_type::setup_cancellation); - act = fsm.resume(true, cancellation_type_t::none); + act = fsm.resume(true, st, cancellation_type_t::none); BOOST_TEST_EQ(act, exec_action_type::notify_writer); // We should now wait for a response - act = fsm.resume(true, cancellation_type_t::none); + act = fsm.resume(true, st, cancellation_type_t::none); BOOST_TEST_EQ(act, exec_action_type::wait_for_response); // Simulate a successful write - BOOST_TEST_EQ(mpx.prepare_write(), 1u); // one request was placed in the packet to write - BOOST_TEST(mpx.commit_write(mpx.get_write_buffer().size())); + BOOST_TEST_EQ(st.mpx.prepare_write(), 1u); // one request was placed in the packet to write + BOOST_TEST(st.mpx.commit_write(st.mpx.get_write_buffer().size())); // Simulate a read that will trigger an error. // The second field should be a number (rather than the empty string). // Note that although part of the buffer was consumed, the multiplexer // currently throws this information away. - read(mpx, "*2\r\n$5\r\nhello\r\n:\r\n"); - auto req_status = mpx.consume(ec); + read(st.mpx, "*2\r\n$5\r\nhello\r\n:\r\n"); + auto req_status = st.mpx.consume(ec); BOOST_TEST_EQ(ec, error::empty_field); BOOST_TEST_EQ(req_status.second, 15u); BOOST_TEST_EQ(input.done_calls, 1u); // This will awaken the exec operation, and should complete the operation - act = fsm.resume(true, cancellation_type_t::none); + act = fsm.resume(true, st, cancellation_type_t::none); BOOST_TEST_EQ(act, exec_action(error::empty_field, 0u)); // All memory should have been freed by now @@ -201,17 +215,17 @@ void test_parse_error() void test_cancel_if_not_connected() { // Setup - multiplexer mpx; + connection_state st; request::config cfg; cfg.cancel_if_not_connected = true; elem_and_request input(cfg); - exec_fsm fsm(mpx, std::move(input.elm)); + exec_fsm fsm(std::move(input.elm)); // Initiate. We're not connected, so the request gets cancelled - auto act = fsm.resume(false, cancellation_type_t::none); + auto act = fsm.resume(false, st, cancellation_type_t::none); BOOST_TEST_EQ(act, exec_action_type::immediate); - act = fsm.resume(false, cancellation_type_t::none); + act = fsm.resume(false, st, cancellation_type_t::none); BOOST_TEST_EQ(act, exec_action(error::not_connected)); // We didn't leave memory behind @@ -222,35 +236,35 @@ void test_cancel_if_not_connected() void test_not_connected() { // Setup - multiplexer mpx; + connection_state st; elem_and_request input; - exec_fsm fsm(mpx, std::move(input.elm)); + exec_fsm fsm(std::move(input.elm)); error_code ec; // Initiate - auto act = fsm.resume(false, cancellation_type_t::none); + auto act = fsm.resume(false, st, cancellation_type_t::none); BOOST_TEST_EQ(act, exec_action_type::setup_cancellation); - act = fsm.resume(true, cancellation_type_t::none); + act = fsm.resume(true, st, cancellation_type_t::none); BOOST_TEST_EQ(act, exec_action_type::notify_writer); // We should now wait for a response - act = fsm.resume(true, cancellation_type_t::none); + act = fsm.resume(true, st, cancellation_type_t::none); BOOST_TEST_EQ(act, exec_action_type::wait_for_response); // Simulate a successful write - BOOST_TEST_EQ(mpx.prepare_write(), 1u); // one request was placed in the packet to write - BOOST_TEST(mpx.commit_write(mpx.get_write_buffer().size())); + BOOST_TEST_EQ(st.mpx.prepare_write(), 1u); // one request was placed in the packet to write + BOOST_TEST(st.mpx.commit_write(st.mpx.get_write_buffer().size())); // Simulate a successful read - read(mpx, "$5\r\nhello\r\n"); - auto req_status = mpx.consume(ec); + read(st.mpx, "$5\r\nhello\r\n"); + auto req_status = st.mpx.consume(ec); BOOST_TEST_EQ(ec, error_code()); BOOST_TEST_EQ(req_status.first, consume_result::got_response); BOOST_TEST_EQ(req_status.second, 11u); // the entire buffer was consumed BOOST_TEST_EQ(input.done_calls, 1u); // This will awaken the exec operation, and should complete the operation - act = fsm.resume(true, cancellation_type_t::none); + act = fsm.resume(true, st, cancellation_type_t::none); BOOST_TEST_EQ(act, exec_action(error_code(), 11u)); // All memory should have been freed by now @@ -277,24 +291,24 @@ void test_cancel_waiting() for (const auto& tc : test_cases) { // Setup - multiplexer mpx; + connection_state st; elem_and_request input, input2; - exec_fsm fsm(mpx, std::move(input.elm)); + exec_fsm fsm(std::move(input.elm)); // Another request enters the multiplexer, so it's busy when we start - mpx.add(input2.elm); - BOOST_TEST_EQ_MSG(mpx.prepare_write(), 1u, tc.name); + st.mpx.add(input2.elm); + BOOST_TEST_EQ_MSG(st.mpx.prepare_write(), 1u, tc.name); // Initiate and wait - auto act = fsm.resume(true, cancellation_type_t::none); + auto act = fsm.resume(true, st, cancellation_type_t::none); BOOST_TEST_EQ_MSG(act, exec_action_type::setup_cancellation, tc.name); - act = fsm.resume(true, cancellation_type_t::none); + act = fsm.resume(true, st, cancellation_type_t::none); BOOST_TEST_EQ_MSG(act, exec_action_type::notify_writer, tc.name); - act = fsm.resume(true, cancellation_type_t::none); + act = fsm.resume(true, st, cancellation_type_t::none); BOOST_TEST_EQ_MSG(act, exec_action_type::wait_for_response, tc.name); // We get notified because the request got cancelled - act = fsm.resume(true, tc.type); + act = fsm.resume(true, st, tc.type); BOOST_TEST_EQ_MSG(act, exec_action(asio::error::operation_aborted), tc.name); BOOST_TEST_EQ_MSG(input.weak_elm.expired(), true, tc.name); // we didn't leave memory behind } @@ -314,32 +328,32 @@ void test_cancel_notwaiting_terminal_partial() for (const auto& tc : test_cases) { // Setup - multiplexer mpx; + connection_state st; auto input = std::make_unique(); - exec_fsm fsm(mpx, std::move(input->elm)); + exec_fsm fsm(std::move(input->elm)); // Initiate - auto act = fsm.resume(false, cancellation_type_t::none); + auto act = fsm.resume(false, st, cancellation_type_t::none); BOOST_TEST_EQ_MSG(act, exec_action_type::setup_cancellation, tc.name); - act = fsm.resume(true, cancellation_type_t::none); + act = fsm.resume(true, st, cancellation_type_t::none); BOOST_TEST_EQ_MSG(act, exec_action_type::notify_writer, tc.name); - act = fsm.resume(true, cancellation_type_t::none); + act = fsm.resume(true, st, cancellation_type_t::none); BOOST_TEST_EQ_MSG(act, exec_action_type::wait_for_response, tc.name); // The multiplexer starts writing the request - BOOST_TEST_EQ_MSG(mpx.prepare_write(), 1u, tc.name); - BOOST_TEST_EQ_MSG(mpx.commit_write(mpx.get_write_buffer().size()), true, tc.name); + BOOST_TEST_EQ_MSG(st.mpx.prepare_write(), 1u, tc.name); + BOOST_TEST_EQ_MSG(st.mpx.commit_write(st.mpx.get_write_buffer().size()), true, tc.name); // A cancellation arrives - act = fsm.resume(true, tc.type); + act = fsm.resume(true, st, tc.type); BOOST_TEST_EQ(act, exec_action(asio::error::operation_aborted)); input.reset(); // Verify we don't access the request or response after completion error_code ec; // When the response to this request arrives, it gets ignored - read(mpx, "-ERR wrong command\r\n"); - auto res = mpx.consume(ec); + read(st.mpx, "-ERR wrong command\r\n"); + auto res = st.mpx.consume(ec); BOOST_TEST_EQ_MSG(ec, error_code(), tc.name); BOOST_TEST_EQ_MSG(res.first, consume_result::got_response, tc.name); @@ -352,44 +366,122 @@ void test_cancel_notwaiting_terminal_partial() void test_cancel_notwaiting_total() { // Setup - multiplexer mpx; + connection_state st; elem_and_request input; - exec_fsm fsm(mpx, std::move(input.elm)); + exec_fsm fsm(std::move(input.elm)); error_code ec; // Initiate - auto act = fsm.resume(true, cancellation_type_t::none); + auto act = fsm.resume(true, st, cancellation_type_t::none); BOOST_TEST_EQ(act, exec_action_type::setup_cancellation); - act = fsm.resume(true, cancellation_type_t::none); + act = fsm.resume(true, st, cancellation_type_t::none); BOOST_TEST_EQ(act, exec_action_type::notify_writer); - act = fsm.resume(true, cancellation_type_t::none); + act = fsm.resume(true, st, cancellation_type_t::none); BOOST_TEST_EQ(act, exec_action_type::wait_for_response); // Simulate a successful write - BOOST_TEST_EQ(mpx.prepare_write(), 1u); - BOOST_TEST(mpx.commit_write(mpx.get_write_buffer().size())); + BOOST_TEST_EQ(st.mpx.prepare_write(), 1u); + BOOST_TEST(st.mpx.commit_write(st.mpx.get_write_buffer().size())); // We got requested a cancellation here, but we can't honor it - act = fsm.resume(true, asio::cancellation_type_t::total); + act = fsm.resume(true, st, asio::cancellation_type_t::total); BOOST_TEST_EQ(act, exec_action_type::wait_for_response); // Simulate a successful read - read(mpx, "$5\r\nhello\r\n"); - auto req_status = mpx.consume(ec); + read(st.mpx, "$5\r\nhello\r\n"); + auto req_status = st.mpx.consume(ec); BOOST_TEST_EQ(ec, error_code()); BOOST_TEST_EQ(req_status.first, consume_result::got_response); BOOST_TEST_EQ(req_status.second, 11u); // the entire buffer was consumed BOOST_TEST_EQ(input.done_calls, 1u); // This will awaken the exec operation, and should complete the operation - act = fsm.resume(true, cancellation_type_t::none); + act = fsm.resume(true, st, cancellation_type_t::none); BOOST_TEST_EQ(act, exec_action(error_code(), 11u)); // All memory should have been freed by now BOOST_TEST_EQ(input.weak_elm.expired(), true); } +// If a request completes successfully and contained pubsub changes, these are committed +void test_subscription_tracking_success() +{ + // Setup + request req; + req.subscribe({"ch1", "ch2"}); + connection_state st; + elem_and_request input{std::move(req)}; + exec_fsm fsm(std::move(input.elm)); + + // Initiate + auto act = fsm.resume(true, st, cancellation_type_t::none); + BOOST_TEST_EQ(act, exec_action_type::setup_cancellation); + act = fsm.resume(true, st, cancellation_type_t::none); + BOOST_TEST_EQ(act, exec_action_type::notify_writer); + + // We should now wait for a response + act = fsm.resume(true, st, cancellation_type_t::none); + BOOST_TEST_EQ(act, exec_action_type::wait_for_response); + + // Simulate a successful write + BOOST_TEST_EQ(st.mpx.prepare_write(), 1u); // one request was placed in the packet to write + BOOST_TEST(st.mpx.commit_write(st.mpx.get_write_buffer().size())); + + // The request doesn't have a response, so this will + // awaken the exec operation, and should complete the operation + act = fsm.resume(true, st, cancellation_type_t::none); + BOOST_TEST_EQ(act, exec_action(error_code(), 0u)); + + // All memory should have been freed by now + BOOST_TEST(input.weak_elm.expired()); + + // The subscription has been added to the tracker + request tracker_req; + st.tracker.compose_subscribe_request(tracker_req); + + request expected_req; + expected_req.push("SUBSCRIBE", "ch1", "ch2"); + BOOST_TEST_EQ(tracker_req.payload(), expected_req.payload()); +} + +// If the request errors, tracked subscriptions are not committed +void test_subscription_tracking_error() +{ + // Setup + request req; + req.subscribe({"ch1", "ch2"}); + connection_state st; + elem_and_request input{std::move(req)}; + exec_fsm fsm(std::move(input.elm)); + + // Initiate + auto act = fsm.resume(true, st, cancellation_type_t::none); + BOOST_TEST_EQ(act, exec_action_type::setup_cancellation); + act = fsm.resume(true, st, cancellation_type_t::none); + BOOST_TEST_EQ(act, exec_action_type::notify_writer); + + // We should now wait for a response + act = fsm.resume(true, st, cancellation_type_t::none); + BOOST_TEST_EQ(act, exec_action_type::wait_for_response); + + // Simulate a write error, which would trigger a reconnection + BOOST_TEST_EQ(st.mpx.prepare_write(), 1u); // one request was placed in the packet to write + st.mpx.cancel_on_conn_lost(); + + // This awakens the request + act = fsm.resume(true, st, cancellation_type_t::none); + BOOST_TEST_EQ(act, exec_action(asio::error::operation_aborted, 0u)); + + // All memory should have been freed by now + BOOST_TEST(input.weak_elm.expired()); + + // The subscription has not been added to the tracker + request tracker_req; + st.tracker.compose_subscribe_request(tracker_req); + BOOST_TEST_EQ(tracker_req.payload(), ""); +} + } // namespace int main() @@ -401,6 +493,8 @@ int main() test_cancel_waiting(); test_cancel_notwaiting_terminal_partial(); test_cancel_notwaiting_total(); + test_subscription_tracking_success(); + test_subscription_tracking_error(); return boost::report_errors(); } diff --git a/test/test_request.cpp b/test/test_request.cpp index 63b59ca3c..9777e790c 100644 --- a/test/test_request.cpp +++ b/test/test_request.cpp @@ -6,16 +6,72 @@ #include +#include #include +#include +#include +#include +#include #include +#include +#include #include #include +#include -using boost::redis::request; +using namespace boost::redis; +using detail::pubsub_change; +using detail::pubsub_change_type; namespace { +// --- Utilities to check subscription tracking --- +const char* to_string(pubsub_change_type type) +{ + switch (type) { + case pubsub_change_type::subscribe: return "subscribe"; + case pubsub_change_type::unsubscribe: return "unsubscribe"; + case pubsub_change_type::psubscribe: return "psubscribe"; + case pubsub_change_type::punsubscribe: return "punsubscribe"; + default: return ""; + } +} + +// Like pubsub_change, but using a string instead of an offset +struct pubsub_change_str { + pubsub_change_type type; + std::string_view value; + + friend bool operator==(const pubsub_change_str& lhs, const pubsub_change_str& rhs) + { + return lhs.type == rhs.type && lhs.value == rhs.value; + } + + friend std::ostream& operator<<(std::ostream& os, const pubsub_change_str& value) + { + return os << "{ " << to_string(value.type) << ", " << value.value << " }"; + } +}; + +void check_pubsub_changes( + const request& req, + boost::span expected, + boost::source_location loc = BOOST_CURRENT_LOCATION) +{ + // Convert from offsets to strings + std::vector actual; + for (const auto& change : detail::request_access::pubsub_changes(req)) { + actual.push_back( + {change.type, req.payload().substr(change.channel_offset, change.channel_size)}); + } + + // Check + if (!BOOST_TEST_ALL_EQ(actual.begin(), actual.end(), expected.begin(), expected.end())) + std::cerr << "Called from " << loc << std::endl; +} + +// --- Generic functions to add commands --- void test_push_no_args() { request req1; @@ -38,6 +94,26 @@ void test_push_multiple_args() BOOST_TEST_EQ(req.payload(), res); } +// Subscription commands added with push are not tracked +void test_push_pubsub() +{ + request req; + req.push("SUBSCRIBE", "ch1"); + req.push("UNSUBSCRIBE", "ch2"); + req.push("PSUBSCRIBE", "ch3*"); + req.push("PUNSUBSCRIBE", "ch4*"); + + char const* res = + "*2\r\n$9\r\nSUBSCRIBE\r\n$3\r\nch1\r\n" + "*2\r\n$11\r\nUNSUBSCRIBE\r\n$3\r\nch2\r\n" + "*2\r\n$10\r\nPSUBSCRIBE\r\n$4\r\nch3*\r\n" + "*2\r\n$12\r\nPUNSUBSCRIBE\r\n$4\r\nch4*\r\n"; + BOOST_TEST_EQ(req.payload(), res); + BOOST_TEST_EQ(req.get_expected_responses(), 0u); + check_pubsub_changes(req, {}); +} + +// --- push_range --- void test_push_range() { std::map in{ @@ -58,7 +134,340 @@ void test_push_range() BOOST_TEST_EQ(req2.payload(), expected); } -// Append +// Subscription commands added with push_range are not tracked +void test_push_range_pubsub() +{ + const std::vector channels1{"ch1", "ch2"}, channels2{"ch3"}, patterns1{"ch3*"}, + patterns2{"ch4*"}; + request req; + req.push_range("SUBSCRIBE", channels1); + req.push_range("UNSUBSCRIBE", channels2); + req.push_range("PSUBSCRIBE", patterns1); + req.push_range("PUNSUBSCRIBE", patterns2); + + char const* res = + "*3\r\n$9\r\nSUBSCRIBE\r\n$3\r\nch1\r\n$3\r\nch2\r\n" + "*2\r\n$11\r\nUNSUBSCRIBE\r\n$3\r\nch3\r\n" + "*2\r\n$10\r\nPSUBSCRIBE\r\n$4\r\nch3*\r\n" + "*2\r\n$12\r\nPUNSUBSCRIBE\r\n$4\r\nch4*\r\n"; + BOOST_TEST_EQ(req.payload(), res); + BOOST_TEST_EQ(req.get_expected_responses(), 0u); + check_pubsub_changes(req, {}); +} + +// --- subscribe --- +// Most of the tests build the same request using different overloads. +// This fixture makes checking easier +struct subscribe_fixture { + request req; + + void check_impl( + std::string_view expected_payload, + pubsub_change_type expected_type, + boost::source_location loc = BOOST_CURRENT_LOCATION) + { + if (!BOOST_TEST_EQ(req.payload(), expected_payload)) + std::cerr << "Called from " << loc << std::endl; + + if (!BOOST_TEST_EQ(req.get_commands(), 1u)) + std::cerr << "Called from " << loc << std::endl; + + if (!BOOST_TEST_EQ(req.get_expected_responses(), 0u)) + std::cerr << "Called from " << loc << std::endl; + + const pubsub_change_str expected_changes[] = { + {expected_type, "ch1"}, + {expected_type, "ch2"}, + }; + check_pubsub_changes(req, expected_changes, loc); + } + + void check_subscribe(boost::source_location loc = BOOST_CURRENT_LOCATION) + { + check_impl( + "*3\r\n$9\r\nSUBSCRIBE\r\n$3\r\nch1\r\n$3\r\nch2\r\n", + pubsub_change_type::subscribe, + loc); + } + + void check_unsubscribe(boost::source_location loc = BOOST_CURRENT_LOCATION) + { + check_impl( + "*3\r\n$11\r\nUNSUBSCRIBE\r\n$3\r\nch1\r\n$3\r\nch2\r\n", + pubsub_change_type::unsubscribe, + loc); + } + + void check_psubscribe(boost::source_location loc = BOOST_CURRENT_LOCATION) + { + check_impl( + "*3\r\n$10\r\nPSUBSCRIBE\r\n$3\r\nch1\r\n$3\r\nch2\r\n", + pubsub_change_type::psubscribe, + loc); + } + + void check_punsubscribe(boost::source_location loc = BOOST_CURRENT_LOCATION) + { + check_impl( + "*3\r\n$12\r\nPUNSUBSCRIBE\r\n$3\r\nch1\r\n$3\r\nch2\r\n", + pubsub_change_type::punsubscribe, + loc); + } +}; + +void test_subscribe_iterators() +{ + subscribe_fixture fix; + const std::forward_list channels{"ch1", "ch2"}; + + fix.req.subscribe(channels.begin(), channels.end()); + + fix.check_subscribe(); +} + +// Like push_range, if the range is empty, this is a no-op +void test_subscribe_iterators_empty() +{ + const std::forward_list channels; + request req; + + req.subscribe(channels.begin(), channels.end()); + + BOOST_TEST_EQ(req.payload(), ""); + BOOST_TEST_EQ(req.get_commands(), 0u); + BOOST_TEST_EQ(req.get_expected_responses(), 0u); + check_pubsub_changes(req, {}); +} + +// Iterators whose value_type is convertible to std::string_view work +void test_subscribe_iterators_convertible_string_view() +{ + subscribe_fixture fix; + const std::vector channels{"ch1", "ch2"}; + + fix.req.subscribe(channels.begin(), channels.end()); + + fix.check_subscribe(); +} + +// The range overload just dispatches to the iterator one +void test_subscribe_range() +{ + subscribe_fixture fix; + const std::vector channels{"ch1", "ch2"}; + + fix.req.subscribe(channels); + + fix.check_subscribe(); +} + +// The initializer_list overload just dispatches to the iterator one +void test_subscribe_initializer_list() +{ + subscribe_fixture fix; + + fix.req.subscribe({"ch1", "ch2"}); + + fix.check_subscribe(); +} + +// --- unsubscribe --- +void test_unsubscribe_iterators() +{ + subscribe_fixture fix; + const std::forward_list channels{"ch1", "ch2"}; + + fix.req.unsubscribe(channels.begin(), channels.end()); + + fix.check_unsubscribe(); +} + +// Like push_range, if the range is empty, this is a no-op +void test_unsubscribe_iterators_empty() +{ + const std::forward_list channels; + request req; + + req.unsubscribe(channels.begin(), channels.end()); + + BOOST_TEST_EQ(req.payload(), ""); + BOOST_TEST_EQ(req.get_commands(), 0u); + BOOST_TEST_EQ(req.get_expected_responses(), 0u); + check_pubsub_changes(req, {}); +} + +// Iterators whose value_type is convertible to std::string_view work +void test_unsubscribe_iterators_convertible_string_view() +{ + subscribe_fixture fix; + const std::vector channels{"ch1", "ch2"}; + + fix.req.unsubscribe(channels.begin(), channels.end()); + + fix.check_unsubscribe(); +} + +// The range overload just dispatches to the iterator one +void test_unsubscribe_range() +{ + subscribe_fixture fix; + const std::vector channels{"ch1", "ch2"}; + + fix.req.unsubscribe(channels); + + fix.check_unsubscribe(); +} + +// The initializer_list overload just dispatches to the iterator one +void test_unsubscribe_initializer_list() +{ + subscribe_fixture fix; + + fix.req.unsubscribe({"ch1", "ch2"}); + + fix.check_unsubscribe(); +} + +// --- psubscribe --- +void test_psubscribe_iterators() +{ + subscribe_fixture fix; + const std::forward_list channels{"ch1", "ch2"}; + + fix.req.psubscribe(channels.begin(), channels.end()); + + fix.check_psubscribe(); +} + +// Like push_range, if the range is empty, this is a no-op +void test_psubscribe_iterators_empty() +{ + const std::forward_list channels; + request req; + + req.psubscribe(channels.begin(), channels.end()); + + BOOST_TEST_EQ(req.payload(), ""); + BOOST_TEST_EQ(req.get_commands(), 0u); + BOOST_TEST_EQ(req.get_expected_responses(), 0u); + check_pubsub_changes(req, {}); +} + +// Iterators whose value_type is convertible to std::string_view work +void test_psubscribe_iterators_convertible_string_view() +{ + subscribe_fixture fix; + const std::vector channels{"ch1", "ch2"}; + + fix.req.psubscribe(channels.begin(), channels.end()); + + fix.check_psubscribe(); +} + +// The range overload just dispatches to the iterator one +void test_psubscribe_range() +{ + subscribe_fixture fix; + const std::vector channels{"ch1", "ch2"}; + + fix.req.psubscribe(channels); + + fix.check_psubscribe(); +} + +// The initializer_list overload just dispatches to the iterator one +void test_psubscribe_initializer_list() +{ + subscribe_fixture fix; + + fix.req.psubscribe({"ch1", "ch2"}); + + fix.check_psubscribe(); +} + +// --- punsubscribe --- +void test_punsubscribe_iterators() +{ + subscribe_fixture fix; + const std::forward_list channels{"ch1", "ch2"}; + + fix.req.punsubscribe(channels.begin(), channels.end()); + + fix.check_punsubscribe(); +} + +// Like push_range, if the range is empty, this is a no-op +void test_punsubscribe_iterators_empty() +{ + const std::forward_list channels; + request req; + + req.punsubscribe(channels.begin(), channels.end()); + + BOOST_TEST_EQ(req.payload(), ""); + BOOST_TEST_EQ(req.get_commands(), 0u); + BOOST_TEST_EQ(req.get_expected_responses(), 0u); + check_pubsub_changes(req, {}); +} + +// Iterators whose value_type is convertible to std::string_view work +void test_punsubscribe_iterators_convertible_string_view() +{ + subscribe_fixture fix; + const std::vector channels{"ch1", "ch2"}; + + fix.req.punsubscribe(channels.begin(), channels.end()); + + fix.check_punsubscribe(); +} + +// The range overload just dispatches to the iterator one +void test_punsubscribe_range() +{ + subscribe_fixture fix; + const std::vector channels{"ch1", "ch2"}; + + fix.req.punsubscribe(channels); + + fix.check_punsubscribe(); +} + +// The initializer_list overload just dispatches to the iterator one +void test_punsubscribe_initializer_list() +{ + subscribe_fixture fix; + + fix.req.punsubscribe({"ch1", "ch2"}); + + fix.check_punsubscribe(); +} + +// Mixing regular commands and pubsub commands is OK +void test_mix_pubsub_regular() +{ + request req; + req.push("PING"); + req.subscribe({"ch1", "ch2"}); + req.push("GET", "key"); + req.punsubscribe({"ch4*"}); + + constexpr std::string_view expected = + "*1\r\n$4\r\nPING\r\n" + "*3\r\n$9\r\nSUBSCRIBE\r\n$3\r\nch1\r\n$3\r\nch2\r\n" + "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n" + "*2\r\n$12\r\nPUNSUBSCRIBE\r\n$4\r\nch4*\r\n"; + BOOST_TEST_EQ(req.payload(), expected); + BOOST_TEST_EQ(req.get_commands(), 4u); + BOOST_TEST_EQ(req.get_expected_responses(), 2u); + constexpr pubsub_change_str expected_changes[] = { + {pubsub_change_type::subscribe, "ch1" }, + {pubsub_change_type::subscribe, "ch2" }, + {pubsub_change_type::punsubscribe, "ch4*"}, + }; + check_pubsub_changes(req, expected_changes); +} + +// --- append --- void test_append() { request req1; @@ -77,6 +486,7 @@ void test_append() BOOST_TEST_EQ(req1.payload(), expected); BOOST_TEST_EQ(req1.get_commands(), 3u); BOOST_TEST_EQ(req1.get_expected_responses(), 3u); + check_pubsub_changes(req1, {}); } // Commands without responses are handled correctly @@ -98,6 +508,7 @@ void test_append_no_response() BOOST_TEST_EQ(req1.payload(), expected); BOOST_TEST_EQ(req1.get_commands(), 3u); BOOST_TEST_EQ(req1.get_expected_responses(), 2u); + check_pubsub_changes(req1, {}); } // Flags are not modified by append @@ -140,6 +551,7 @@ void test_append_target_empty() BOOST_TEST_EQ(req1.payload(), expected); BOOST_TEST_EQ(req1.get_commands(), 1u); BOOST_TEST_EQ(req1.get_expected_responses(), 1u); + check_pubsub_changes(req1, {}); } void test_append_source_empty() @@ -155,6 +567,7 @@ void test_append_source_empty() BOOST_TEST_EQ(req1.payload(), expected); BOOST_TEST_EQ(req1.get_commands(), 1u); BOOST_TEST_EQ(req1.get_expected_responses(), 1u); + check_pubsub_changes(req1, {}); } void test_append_both_empty() @@ -167,6 +580,89 @@ void test_append_both_empty() BOOST_TEST_EQ(req1.payload(), ""); BOOST_TEST_EQ(req1.get_commands(), 0u); BOOST_TEST_EQ(req1.get_expected_responses(), 0u); + check_pubsub_changes(req1, {}); +} + +// Append correctly handles requests with pubsub changes +void test_append_pubsub() +{ + request req1; + req1.subscribe({"ch1"}); + + auto req2 = std::make_unique(); + req2->unsubscribe({"ch2"}); + req2->psubscribe({"really_very_long_pattern_name*"}); + + req1.append(*req2); + req2.reset(); // make sure we don't leave dangling pointers + + constexpr std::string_view expected = + "*2\r\n$9\r\nSUBSCRIBE\r\n$3\r\nch1\r\n" + "*2\r\n$11\r\nUNSUBSCRIBE\r\n$3\r\nch2\r\n" + "*2\r\n$10\r\nPSUBSCRIBE\r\n$30\r\nreally_very_long_pattern_name*\r\n"; + BOOST_TEST_EQ(req1.payload(), expected); + const pubsub_change_str expected_changes[] = { + {pubsub_change_type::subscribe, "ch1" }, + {pubsub_change_type::unsubscribe, "ch2" }, + {pubsub_change_type::psubscribe, "really_very_long_pattern_name*"}, + }; + check_pubsub_changes(req1, expected_changes); +} + +// If the target is empty and the source has pubsub changes, that's OK +void test_append_pubsub_target_empty() +{ + request req1; + + request req2; + req2.punsubscribe({"ch2"}); + + req1.append(req2); + + constexpr std::string_view expected = "*2\r\n$12\r\nPUNSUBSCRIBE\r\n$3\r\nch2\r\n"; + BOOST_TEST_EQ(req1.payload(), expected); + const pubsub_change_str expected_changes[] = { + {pubsub_change_type::punsubscribe, "ch2"}, + }; + check_pubsub_changes(req1, expected_changes); +} + +// --- clear --- +void test_clear() +{ + // Create request with some commands and some pubsub changes + request req; + req.push("PING", "value"); + req.push("GET", "key"); + req.subscribe({"ch1", "ch2"}); + req.punsubscribe({"ch3*"}); + + // Clear removes the payload, the commands and the pubsub changes + req.clear(); + BOOST_TEST_EQ(req.payload(), ""); + BOOST_TEST_EQ(req.get_commands(), 0u); + BOOST_TEST_EQ(req.get_expected_responses(), 0u); + check_pubsub_changes(req, {}); + + // Clearing again does nothing + req.clear(); + BOOST_TEST_EQ(req.payload(), ""); + BOOST_TEST_EQ(req.get_commands(), 0u); + BOOST_TEST_EQ(req.get_expected_responses(), 0u); + check_pubsub_changes(req, {}); +} + +// Clearing an empty request doesn't cause trouble +void test_clear_empty() +{ + request req; + + req.clear(); + + BOOST_TEST_EQ(req.payload(), ""); + BOOST_TEST_EQ(req.get_commands(), 0u); + BOOST_TEST_EQ(req.get_expected_responses(), 0u); + check_pubsub_changes(req, {}); } } // namespace @@ -176,7 +672,36 @@ int main() test_push_no_args(); test_push_int(); test_push_multiple_args(); + test_push_pubsub(); + test_push_range(); + test_push_range_pubsub(); + + test_subscribe_iterators(); + test_subscribe_iterators_empty(); + test_subscribe_iterators_convertible_string_view(); + test_subscribe_range(); + test_subscribe_initializer_list(); + + test_unsubscribe_iterators(); + test_unsubscribe_iterators_empty(); + test_unsubscribe_iterators_convertible_string_view(); + test_unsubscribe_range(); + test_unsubscribe_initializer_list(); + + test_psubscribe_iterators(); + test_psubscribe_iterators_empty(); + test_psubscribe_iterators_convertible_string_view(); + test_psubscribe_range(); + test_psubscribe_initializer_list(); + + test_punsubscribe_iterators(); + test_punsubscribe_iterators_empty(); + test_punsubscribe_iterators_convertible_string_view(); + test_punsubscribe_range(); + test_punsubscribe_initializer_list(); + + test_mix_pubsub_regular(); test_append(); test_append_no_response(); @@ -184,6 +709,11 @@ int main() test_append_target_empty(); test_append_source_empty(); test_append_both_empty(); + test_append_pubsub(); + test_append_pubsub_target_empty(); + + test_clear(); + test_clear_empty(); return boost::report_errors(); } diff --git a/test/test_run_fsm.cpp b/test/test_run_fsm.cpp index ef5706610..53f413f6f 100644 --- a/test/test_run_fsm.cpp +++ b/test/test_run_fsm.cpp @@ -565,7 +565,7 @@ void test_setup_ping_requests() const std::string_view expected_setup = "*5\r\n$5\r\nHELLO\r\n$1\r\n3\r\n$4\r\nAUTH\r\n$3\r\nfoo\r\n$3\r\nbar\r\n"; BOOST_TEST_EQ(fix.st.ping_req.payload(), expected_ping); - BOOST_TEST_EQ(fix.st.cfg.setup.payload(), expected_setup); + BOOST_TEST_EQ(fix.st.setup_req.payload(), expected_setup); // Reconnect act = fix.fsm.resume(fix.st, error::empty_field, cancellation_type_t::none); @@ -579,7 +579,7 @@ void test_setup_ping_requests() // The requests haven't been modified BOOST_TEST_EQ(fix.st.ping_req.payload(), expected_ping); - BOOST_TEST_EQ(fix.st.cfg.setup.payload(), expected_setup); + BOOST_TEST_EQ(fix.st.setup_req.payload(), expected_setup); } // We correctly send and log the setup request diff --git a/test/test_setup_adapter.cpp b/test/test_setup_adapter.cpp index 694700eb2..fab4ce95e 100644 --- a/test/test_setup_adapter.cpp +++ b/test/test_setup_adapter.cpp @@ -30,7 +30,7 @@ void test_success() connection_state st; st.cfg.use_setup = true; st.cfg.setup.push("SELECT", 2); - compose_setup_request(st.cfg); + compose_setup_request(st.cfg, st.tracker, st.setup_req); setup_adapter adapter{st}; // Response to HELLO @@ -55,7 +55,7 @@ void test_simple_error() // Setup connection_state st; st.cfg.use_setup = true; - compose_setup_request(st.cfg); + compose_setup_request(st.cfg, st.tracker, st.setup_req); setup_adapter adapter{st}; // Response to HELLO contains an error @@ -73,7 +73,7 @@ void test_blob_error() connection_state st; st.cfg.use_setup = true; st.cfg.setup.push("SELECT", 1); - compose_setup_request(st.cfg); + compose_setup_request(st.cfg, st.tracker, st.setup_req); setup_adapter adapter{st}; // Response to HELLO @@ -97,7 +97,7 @@ void test_null() // Setup connection_state st; st.cfg.use_setup = true; - compose_setup_request(st.cfg); + compose_setup_request(st.cfg, st.tracker, st.setup_req); setup_adapter adapter{st}; // Response to HELLO @@ -129,7 +129,7 @@ void test_sentinel_master() st.cfg.sentinel.addresses = { {"localhost", "26379"} }; - compose_setup_request(st.cfg); + compose_setup_request(st.cfg, st.tracker, st.setup_req); setup_adapter adapter{st}; // Response to HELLO @@ -164,7 +164,7 @@ void test_sentinel_replica() {"localhost", "26379"} }; st.cfg.sentinel.server_role = role::replica; - compose_setup_request(st.cfg); + compose_setup_request(st.cfg, st.tracker, st.setup_req); setup_adapter adapter{st}; // Response to HELLO @@ -193,7 +193,7 @@ void test_sentinel_role_check_failed_master() st.cfg.sentinel.addresses = { {"localhost", "26379"} }; - compose_setup_request(st.cfg); + compose_setup_request(st.cfg, st.tracker, st.setup_req); setup_adapter adapter{st}; // Response to HELLO @@ -222,7 +222,7 @@ void test_sentinel_role_check_failed_replica() {"localhost", "26379"} }; st.cfg.sentinel.server_role = role::replica; - compose_setup_request(st.cfg); + compose_setup_request(st.cfg, st.tracker, st.setup_req); setup_adapter adapter{st}; // Response to HELLO @@ -252,7 +252,7 @@ void test_sentinel_role_error_node() st.cfg.sentinel.addresses = { {"localhost", "26379"} }; - compose_setup_request(st.cfg); + compose_setup_request(st.cfg, st.tracker, st.setup_req); setup_adapter adapter{st}; // Response to ROLE @@ -273,7 +273,7 @@ void test_sentinel_role_not_array() st.cfg.sentinel.addresses = { {"localhost", "26379"} }; - compose_setup_request(st.cfg); + compose_setup_request(st.cfg, st.tracker, st.setup_req); setup_adapter adapter{st}; // Response to ROLE @@ -294,7 +294,7 @@ void test_sentinel_role_empty_array() st.cfg.sentinel.addresses = { {"localhost", "26379"} }; - compose_setup_request(st.cfg); + compose_setup_request(st.cfg, st.tracker, st.setup_req); setup_adapter adapter{st}; // Response to ROLE @@ -315,7 +315,7 @@ void test_sentinel_role_first_element_not_string() st.cfg.sentinel.addresses = { {"localhost", "26379"} }; - compose_setup_request(st.cfg); + compose_setup_request(st.cfg, st.tracker, st.setup_req); setup_adapter adapter{st}; // Response to ROLE diff --git a/test/test_subscription_tracker.cpp b/test/test_subscription_tracker.cpp new file mode 100644 index 000000000..7f5d5aefe --- /dev/null +++ b/test/test_subscription_tracker.cpp @@ -0,0 +1,276 @@ +// +// Copyright (c) 2025-2026 Marcelo Zimbres Silva (mzimbres@gmail.com), +// Ruben Perez Hidalgo (rubenperez038 at gmail dot com) +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// + +#include +#include + +#include + +using namespace boost::redis; +using detail::subscription_tracker; + +namespace { + +// State originated by SUBSCRIBE commands, only +void test_subscribe() +{ + subscription_tracker tracker; + request req1, req2, req_output, req_expected; + + // Add some changes to the tracker + req1.subscribe({"channel_a", "channel_b"}); + tracker.commit_changes(req1); + + req2.subscribe({"channel_c"}); + tracker.commit_changes(req2); + + // Check that we generate the correct response + tracker.compose_subscribe_request(req_output); + req_expected.push("SUBSCRIBE", "channel_a", "channel_b", "channel_c"); + BOOST_TEST_EQ(req_output.payload(), req_expected.payload()); +} + +// State originated by PSUBSCRIBE commands, only +void test_psubscribe() +{ + subscription_tracker tracker; + request req1, req2, req_output, req_expected; + + // Add some changes to the tracker + req1.psubscribe({"channel_b*", "channel_c*"}); + tracker.commit_changes(req1); + + req2.psubscribe({"channel_a*"}); + tracker.commit_changes(req2); + + // Check that we generate the correct response + tracker.compose_subscribe_request(req_output); + req_expected.push("PSUBSCRIBE", "channel_a*", "channel_b*", "channel_c*"); // we sort them + BOOST_TEST_EQ(req_output.payload(), req_expected.payload()); +} + +// We can mix SUBSCRIBE and PSUBSCRIBE operations +void test_subscribe_psubscribe() +{ + subscription_tracker tracker; + request req1, req2, req_output, req_expected; + + // Add some changes to the tracker + req1.psubscribe({"channel_a*", "channel_b*"}); + req1.subscribe({"ch1"}); + tracker.commit_changes(req1); + + req2.subscribe({"ch2"}); + req2.psubscribe({"channel_c*"}); + tracker.commit_changes(req2); + + // Check that we generate the correct response + tracker.compose_subscribe_request(req_output); + req_expected.push("SUBSCRIBE", "ch1", "ch2"); + req_expected.push("PSUBSCRIBE", "channel_a*", "channel_b*", "channel_c*"); + BOOST_TEST_EQ(req_output.payload(), req_expected.payload()); +} + +// We can have subscribe and psubscribe commands with the same argument +void test_subscribe_psubscribe_same_arg() +{ + subscription_tracker tracker; + request req, req_output, req_expected; + + req.subscribe({"ch1"}); + req.psubscribe({"ch1"}); + tracker.commit_changes(req); + + tracker.compose_subscribe_request(req_output); + req_expected.push("SUBSCRIBE", "ch1"); + req_expected.push("PSUBSCRIBE", "ch1"); + BOOST_TEST_EQ(req_output.payload(), req_expected.payload()); +} + +// An unsubscribe/punsubscribe balances a matching subscribe +void test_unsubscribe() +{ + subscription_tracker tracker; + request req1, req2, req_output, req_expected; + + // Add some changes to the tracker + req1.subscribe({"ch1", "ch2"}); + req1.psubscribe({"ch1*", "ch2*"}); + tracker.commit_changes(req1); + + // Unsubscribe from some channels + req2.punsubscribe({"ch2*"}); + req2.unsubscribe({"ch1"}); + tracker.commit_changes(req2); + + // Check that we generate the correct response + tracker.compose_subscribe_request(req_output); + req_expected.push("SUBSCRIBE", "ch2"); + req_expected.push("PSUBSCRIBE", "ch1*"); + BOOST_TEST_EQ(req_output.payload(), req_expected.payload()); +} + +// After an unsubscribe, we can subscribe again +void test_resubscribe() +{ + subscription_tracker tracker; + request req, req_output, req_expected; + + // Subscribe to some channels + req.subscribe({"ch1", "ch2"}); + req.psubscribe({"ch1*", "ch2*"}); + tracker.commit_changes(req); + + // Unsubscribe from some channels + req.clear(); + req.punsubscribe({"ch2*"}); + req.unsubscribe({"ch1"}); + tracker.commit_changes(req); + + // Subscribe again + req.clear(); + req.subscribe({"ch1"}); + req.psubscribe({"ch2*"}); + tracker.commit_changes(req); + + // Check that we generate the correct response + tracker.compose_subscribe_request(req_output); + req_expected.push("SUBSCRIBE", "ch1", "ch2"); + req_expected.push("PSUBSCRIBE", "ch1*", "ch2*"); + BOOST_TEST_EQ(req_output.payload(), req_expected.payload()); +} + +// Subscribing twice is not a problem +void test_subscribe_twice() +{ + subscription_tracker tracker; + request req, req_output, req_expected; + + // Subscribe to some channels + req.subscribe({"ch1", "ch2"}); + req.psubscribe({"ch1*", "ch2*"}); + tracker.commit_changes(req); + + // Subscribe to the same channels again + req.clear(); + req.subscribe({"ch2"}); + req.psubscribe({"ch1*"}); + tracker.commit_changes(req); + + // Check that we generate the correct response + tracker.compose_subscribe_request(req_output); + req_expected.push("SUBSCRIBE", "ch1", "ch2"); + req_expected.push("PSUBSCRIBE", "ch1*", "ch2*"); + BOOST_TEST_EQ(req_output.payload(), req_expected.payload()); +} + +// Unsubscribing from channels we haven't subscribed to is not a problem +void test_lone_unsubscribe() +{ + subscription_tracker tracker; + request req, req_output, req_expected; + + // Subscribe to some channels + req.subscribe({"ch1", "ch2"}); + req.psubscribe({"ch1*", "ch2*"}); + tracker.commit_changes(req); + + // Unsubscribe from channels we haven't subscribed to + req.clear(); + req.unsubscribe({"other"}); + req.punsubscribe({"other*"}); + tracker.commit_changes(req); + + // Check that we generate the correct response + tracker.compose_subscribe_request(req_output); + req_expected.push("SUBSCRIBE", "ch1", "ch2"); + req_expected.push("PSUBSCRIBE", "ch1*", "ch2*"); + BOOST_TEST_EQ(req_output.payload(), req_expected.payload()); +} + +// A state with no changes is not a problem +void test_empty() +{ + subscription_tracker tracker; + request req_output; + + tracker.compose_subscribe_request(req_output); + BOOST_TEST_EQ(req_output.payload(), ""); +} + +// If the output request is not empty, the commands are added to it, rather than replaced +void test_output_request_not_empty() +{ + subscription_tracker tracker; + request req, req_output, req_expected; + + // Subscribe to some channels + req.subscribe({"ch1", "ch2"}); + req.psubscribe({"ch1*", "ch2*"}); + tracker.commit_changes(req); + + // Compose the output request + req_output.push("PING", "hello"); + tracker.compose_subscribe_request(req_output); + + // Check that we generate the correct response + req_expected.push("PING", "hello"); + req_expected.push("SUBSCRIBE", "ch1", "ch2"); + req_expected.push("PSUBSCRIBE", "ch1*", "ch2*"); + BOOST_TEST_EQ(req_output.payload(), req_expected.payload()); +} + +// Clear removes everything from the state +void test_clear() +{ + subscription_tracker tracker; + request req, req_output, req_expected; + + // Subscribe to some channels + req.subscribe({"ch1", "ch2"}); + req.psubscribe({"ch1*", "ch2*"}); + tracker.commit_changes(req); + + // Clear + tracker.clear(); + + // Nothing should be generated + tracker.compose_subscribe_request(req_output); + BOOST_TEST_EQ(req_output.payload(), ""); + + // We can reuse the tracker by now committing some more changes + req.clear(); + req.subscribe({"ch5"}); + req.psubscribe({"ch6*"}); + tracker.commit_changes(req); + + // Check that we generate the correct response + tracker.compose_subscribe_request(req_output); + req_expected.push("SUBSCRIBE", "ch5"); + req_expected.push("PSUBSCRIBE", "ch6*"); + BOOST_TEST_EQ(req_output.payload(), req_expected.payload()); +} + +} // namespace + +int main() +{ + test_subscribe(); + test_psubscribe(); + test_subscribe_psubscribe(); + test_subscribe_psubscribe_same_arg(); + test_unsubscribe(); + test_resubscribe(); + test_subscribe_twice(); + test_lone_unsubscribe(); + test_empty(); + test_output_request_not_empty(); + test_clear(); + + return boost::report_errors(); +}