diff --git a/src/stan/callbacks/file_stream_writer.hpp b/src/stan/callbacks/file_stream_writer.hpp new file mode 100644 index 00000000000..e182321f831 --- /dev/null +++ b/src/stan/callbacks/file_stream_writer.hpp @@ -0,0 +1,117 @@ +#ifndef STAN_CALLBACKS_FILE_STREAM_WRITER_HPP +#define STAN_CALLBACKS_FILE_STREAM_WRITER_HPP + +#include +#include +#include +#include + +namespace stan { +namespace callbacks { + +/** + * file_stream_writer is an implementation + * of writer that writes to a file. + */ +class file_stream_writer final : public writer { + public: + /** + * Constructs a file stream writer with an output stream + * and an optional prefix for comments. + * + * @param[in, out] A unique pointer to a type inheriting from `std::ostream` + * @param[in] comment_prefix string to stream before each comment line. + * Default is "". + */ + explicit file_stream_writer(std::unique_ptr&& output, + const std::string& comment_prefix = "") + : output_(std::move(output)), comment_prefix_(comment_prefix) {} + + file_stream_writer(); + file_stream_writer(file_stream_writer& other) = delete; + file_stream_writer(file_stream_writer&& other) + : output_(std::move(other.output_)), + comment_prefix_(std::move(other.comment_prefix_)) {} + /** + * Virtual destructor + */ + virtual ~file_stream_writer() {} + + /** + * Writes a set of names on a single line in csv format followed + * by a newline. + * + * Note: the names are not escaped. + * + * @param[in] names Names in a std::vector + */ + void operator()(const std::vector& names) { + write_vector(names); + } + /** + * Get the underlying stream + */ + auto& get_stream() { return *output_; } + + /** + * Writes a set of values in csv format followed by a newline. + * + * Note: the precision of the output is determined by the settings + * of the stream on construction. + * + * @param[in] state Values in a std::vector + */ + void operator()(const std::vector& state) { write_vector(state); } + + /** + * Writes the comment_prefix to the stream followed by a newline. + */ + void operator()() { *output_ << comment_prefix_ << std::endl; } + + /** + * Writes the comment_prefix then the message followed by a newline. + * + * @param[in] message A string + */ + void operator()(const std::string& message) { + *output_ << comment_prefix_ << message << std::endl; + } + + private: + /** + * Output stream + */ + std::unique_ptr output_; + + /** + * Comment prefix to use when printing comments: strings and blank lines + */ + std::string comment_prefix_; + + /** + * Writes a set of values in csv format followed by a newline. + * + * Note: the precision of the output is determined by the settings + * of the stream on construction. + * + * @param[in] v Values in a std::vector + */ + template + void write_vector(const std::vector& v) { + if (v.empty()) + return; + + typename std::vector::const_iterator last = v.end(); + --last; + + for (typename std::vector::const_iterator it = v.begin(); it != last; + ++it) + *output_ << *it << ","; + *output_ << v.back() << std::endl; + } +}; + +} // namespace callbacks +} // namespace stan + +#endif diff --git a/src/stan/callbacks/stream_logger.hpp b/src/stan/callbacks/stream_logger.hpp index 6e6cca07e64..677add7c707 100644 --- a/src/stan/callbacks/stream_logger.hpp +++ b/src/stan/callbacks/stream_logger.hpp @@ -14,7 +14,7 @@ namespace callbacks { * logger that writes messages to separate * std::stringstream outputs. */ -class stream_logger : public logger { +class stream_logger final : public logger { private: std::ostream& debug_; std::ostream& info_; diff --git a/src/stan/callbacks/tee_writer.hpp b/src/stan/callbacks/tee_writer.hpp index 8eca4af61a7..c491832e9ec 100644 --- a/src/stan/callbacks/tee_writer.hpp +++ b/src/stan/callbacks/tee_writer.hpp @@ -16,7 +16,7 @@ namespace callbacks { * For any call to this writer, it will tee the call to both writers * provided in the constructor. */ -class tee_writer : public writer { +class tee_writer final : public writer { public: /** * Constructor accepting two writers. diff --git a/src/stan/mcmc/hmc/hamiltonians/softabs_metric.hpp b/src/stan/mcmc/hmc/hamiltonians/softabs_metric.hpp index 10f24dffe48..da2e392c90b 100644 --- a/src/stan/mcmc/hmc/hamiltonians/softabs_metric.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/softabs_metric.hpp @@ -174,26 +174,26 @@ class softabs_metric : public base_hamiltonian { // Threshold below which a power series // approximation of the softabs function is used - static double lower_softabs_thresh; + static constexpr double lower_softabs_thresh = 1e-4; // Threshold above which an asymptotic // approximation of the softabs function is used - static double upper_softabs_thresh; + static constexpr double upper_softabs_thresh = 18; // Threshold below which an exact derivative is // used in the Jacobian calculation instead of // finite differencing - static double jacobian_thresh; + static constexpr double jacobian_thresh = 1e-10; }; template -double softabs_metric::lower_softabs_thresh = 1e-4; +constexpr double softabs_metric::lower_softabs_thresh; template -double softabs_metric::upper_softabs_thresh = 18; +constexpr double softabs_metric::upper_softabs_thresh; template -double softabs_metric::jacobian_thresh = 1e-10; +constexpr double softabs_metric::jacobian_thresh; } // namespace mcmc } // namespace stan #endif diff --git a/src/stan/services/util/generate_transitions.hpp b/src/stan/services/util/generate_transitions.hpp index 1516c6f0f99..8ad93706c6a 100644 --- a/src/stan/services/util/generate_transitions.hpp +++ b/src/stan/services/util/generate_transitions.hpp @@ -44,7 +44,7 @@ void generate_transitions(stan::mcmc::base_mcmc& sampler, int num_iterations, util::mcmc_writer& mcmc_writer, stan::mcmc::sample& init_s, Model& model, RNG& base_rng, callbacks::interrupt& callback, - callbacks::logger& logger) { + callbacks::logger& logger, size_t n_chain = 0) { for (int m = 0; m < num_iterations; ++m) { callback(); @@ -52,6 +52,9 @@ void generate_transitions(stan::mcmc::base_mcmc& sampler, int num_iterations, && (start + m + 1 == finish || m == 0 || (m + 1) % refresh == 0)) { int it_print_width = std::ceil(std::log10(static_cast(finish))); std::stringstream message; + if (n_chain > 0) { + message << "Chain [" << (n_chain + 1) << "]"; + } message << "Iteration: "; message << std::setw(it_print_width) << m + 1 + start << " / " << finish; message << " [" << std::setw(3) diff --git a/src/stan/services/util/run_adaptive_sampler.hpp b/src/stan/services/util/run_adaptive_sampler.hpp index ba0135158dd..dc2105b290a 100644 --- a/src/stan/services/util/run_adaptive_sampler.hpp +++ b/src/stan/services/util/run_adaptive_sampler.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -34,7 +35,7 @@ namespace util { * @param[in,out] sample_writer writer for draws * @param[in,out] diagnostic_writer writer for diagnostic information */ -template +template void run_adaptive_sampler(Sampler& sampler, Model& model, std::vector& cont_vector, int num_warmup, int num_samples, int num_thin, int refresh, @@ -87,6 +88,83 @@ void run_adaptive_sampler(Sampler& sampler, Model& model, / 1000.0; writer.write_timing(warm_delta_t, sample_delta_t); } + +template +void run_adaptive_sampler(std::vector& samplers, Model& model, + std::vector>& cont_vectors, + int num_warmup, int num_samples, int num_thin, + int refresh, bool save_warmup, std::vector& rngs, + callbacks::interrupt& interrupt, + callbacks::logger& logger, + std::vector& sample_writers, + std::vector& diagnostic_writers, + size_t n_chain) { + if (n_chain == 0) { + run_adaptive_sampler(samplers[0], model, cont_vectors[0], num_warmup, + num_samples, num_thin, refresh, save_warmup, rngs[0], + interrupt, logger, sample_writers[0], + diagnostic_writers[0]); + } + tbb::parallel_for( + tbb::blocked_range(0, n_chain, 1), + [num_warmup, num_samples, num_thin, refresh, save_warmup, &samplers, + &model, &rngs, &interrupt, &logger, &sample_writers, &cont_vectors, + &diagnostic_writers](const tbb::blocked_range& r) { + for (size_t i = r.begin(); i != r.end(); ++i) { + auto&& sampler = samplers[i]; + sampler.engage_adaptation(); + Eigen::Map cont_params(cont_vectors[i].data(), + cont_vectors[i].size()); + try { + sampler.z().q = cont_params; + sampler.init_stepsize(logger); + } catch (const std::exception& e) { + logger.info("Exception initializing step size."); + logger.info(e.what()); + return; + } + stan::mcmc::sample samp(cont_params, 0, 0); + auto&& sample_writer = sample_writers[i]; + auto writer = services::util::mcmc_writer( + sample_writer, diagnostic_writers[i], logger); + + // Headers + writer.write_sample_names(samp, sampler, model); + writer.write_diagnostic_names(samp, sampler, model); + + const auto start_warm = std::chrono::steady_clock::now(); + util::generate_transitions(sampler, num_warmup, 0, + num_warmup + num_samples, num_thin, + refresh, save_warmup, true, writer, samp, + model, rngs[i], interrupt, logger, i); + const auto end_warm = std::chrono::steady_clock::now(); + auto warm_delta + = std::chrono::duration_cast( + end_warm - start_warm) + .count() + / 1000.0; + sampler.disengage_adaptation(); + writer.write_adapt_finish(sampler); + sampler.write_sampler_state(sample_writer); + + const auto start_sample = std::chrono::steady_clock::now(); + util::generate_transitions(sampler, num_samples, num_warmup, + num_warmup + num_samples, num_thin, + refresh, true, false, writer, samp, model, + rngs[i], interrupt, logger, i); + const auto end_sample = std::chrono::steady_clock::now(); + auto sample_delta + = std::chrono::duration_cast( + end_sample - start_sample) + .count() + / 1000.0; + writer.write_timing(warm_delta, sample_delta); + } + }, + tbb::simple_partitioner()); +} + } // namespace util } // namespace services } // namespace stan diff --git a/src/test/unit/callbacks/file_stream_writer_test.cpp b/src/test/unit/callbacks/file_stream_writer_test.cpp new file mode 100644 index 00000000000..b3a64e4dcca --- /dev/null +++ b/src/test/unit/callbacks/file_stream_writer_test.cpp @@ -0,0 +1,50 @@ +#include +#include +#include + +class StanInterfaceCallbacksStreamWriter : public ::testing::Test { + public: + StanInterfaceCallbacksStreamWriter() + : writer(std::make_unique(std::stringstream{})) {} + + void SetUp() { + static_cast(writer.get_stream()).str(std::string()); + static_cast(writer.get_stream()).clear(); + } + void TearDown() {} + + stan::callbacks::file_stream_writer writer; +}; + +TEST_F(StanInterfaceCallbacksStreamWriter, double_vector) { + const int N = 5; + std::vector x; + for (int n = 0; n < N; ++n) + x.push_back(n); + + EXPECT_NO_THROW(writer(x)); + EXPECT_EQ("0,1,2,3,4\n", + static_cast(writer.get_stream()).str()); +} + +TEST_F(StanInterfaceCallbacksStreamWriter, string_vector) { + const int N = 5; + std::vector x; + for (int n = 0; n < N; ++n) + x.push_back(boost::lexical_cast(n)); + + EXPECT_NO_THROW(writer(x)); + EXPECT_EQ("0,1,2,3,4\n", + static_cast(writer.get_stream()).str()); +} + +TEST_F(StanInterfaceCallbacksStreamWriter, null) { + EXPECT_NO_THROW(writer()); + EXPECT_EQ("\n", static_cast(writer.get_stream()).str()); +} + +TEST_F(StanInterfaceCallbacksStreamWriter, string) { + EXPECT_NO_THROW(writer("message")); + EXPECT_EQ("message\n", + static_cast(writer.get_stream()).str()); +} diff --git a/src/test/unit/services/instrumented_callbacks.hpp b/src/test/unit/services/instrumented_callbacks.hpp index 5df61c17f74..797959ef0ee 100644 --- a/src/test/unit/services/instrumented_callbacks.hpp +++ b/src/test/unit/services/instrumented_callbacks.hpp @@ -9,7 +9,8 @@ #include #include #include - +#include +#include namespace stan { namespace test { namespace unit { @@ -27,7 +28,7 @@ class instrumented_interrupt : public stan::callbacks::interrupt { unsigned int call_count() { return counter_; } private: - unsigned int counter_; + std::atomic counter_; }; /** @@ -96,55 +97,54 @@ class instrumented_writer : public stan::callbacks::writer { unsigned int call_count() { unsigned int n = 0; - for (std::map::iterator it = counter_.begin(); - it != counter_.end(); ++it) - n += it->second; + for (auto& it : counter_) + n += it.second; return n; } unsigned int call_count(std::string s) { return counter_[s]; } - std::vector > string_double_values() { + std::vector> string_double_values() { return string_double; }; - std::vector > string_int_values() { + std::vector> string_int_values() { return string_int; }; - std::vector > string_string_values() { + std::vector> string_string_values() { return string_string; }; - std::vector > > + std::vector>> string_pdouble_int_values() { return string_pdouble_int; }; - std::vector > + std::vector> string_pdouble_int_int_values() { return string_pdouble_int_int; }; - std::vector > vector_string_values() { + std::vector> vector_string_values() { return vector_string; }; - std::vector > vector_double_values() { + std::vector> vector_double_values() { return vector_double; }; std::vector string_values() { return string; }; private: - std::map counter_; - std::vector > string_double; - std::vector > string_int; - std::vector > string_string; - std::vector > > string_pdouble_int; - std::vector > string_pdouble_int_int; - std::vector > vector_string; - std::vector > vector_double; + std::map> counter_; + std::vector> string_double; + std::vector> string_int; + std::vector> string_string; + std::vector>> string_pdouble_int; + std::vector> string_pdouble_int_int; + std::vector> vector_string; + std::vector> vector_double; std::vector string; }; @@ -156,35 +156,53 @@ class instrumented_writer : public stan::callbacks::writer { */ class instrumented_logger : public stan::callbacks::logger { public: + std::mutex logger_guard; instrumented_logger() {} - void debug(const std::string& message) { debug_.push_back(message); } + void debug(const std::string& message) { + std::lock_guard guard(logger_guard); + debug_.push_back(message); + } void debug(const std::stringstream& message) { + std::lock_guard guard(logger_guard); debug_.push_back(message.str()); } - void info(const std::string& message) { info_.push_back(message); } + void info(const std::string& message) { + std::lock_guard guard(logger_guard); + info_.push_back(message); + } void info(const std::stringstream& message) { + std::lock_guard guard(logger_guard); info_.push_back(message.str()); } - void warn(const std::string& message) { warn_.push_back(message); } + void warn(const std::string& message) { + std::lock_guard guard(logger_guard); + warn_.push_back(message); + } void warn(const std::stringstream& message) { + std::lock_guard guard(logger_guard); warn_.push_back(message.str()); } - void error(const std::string& message) { error_.push_back(message); } + void error(const std::string& message) { + std::lock_guard guard(logger_guard); + error_.push_back(message); + } void error(const std::stringstream& message) { + std::lock_guard guard(logger_guard); error_.push_back(message.str()); } void fatal(const std::string& message) { fatal_.push_back(message); } void fatal(const std::stringstream& message) { + std::lock_guard guard(logger_guard); fatal_.push_back(message.str()); } diff --git a/src/test/unit/services/util/run_adaptive_sampler_parallel_test.cpp b/src/test/unit/services/util/run_adaptive_sampler_parallel_test.cpp new file mode 100644 index 00000000000..5d483864b50 --- /dev/null +++ b/src/test/unit/services/util/run_adaptive_sampler_parallel_test.cpp @@ -0,0 +1,232 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +auto&& blah = stan::math::init_threadpool_tbb(); + +static constexpr size_t num_chains = 5; + +class ServicesUtil : public testing::Test { + using model_t = stan::mcmc::adapt_unit_e_nuts; + + public: + ServicesUtil() + : model(context, 0, &model_log), + rng(num_chains), + cont_vector(num_chains, std::vector{0, 0}), + sampler(), + num_warmup(0), + num_samples(0), + num_thin(1), + refresh(0), + n_chain(num_chains), + save_warmup(false) { + rng.clear(); + for (int i = 0; i < num_chains; ++i) { + rng[i] = std::move(stan::services::util::create_rng(0, 1)); + sampler.push_back(model_t(model, rng[i])); + sample_writer.push_back(stan::test::unit::instrumented_writer{}); + diagnostic_writer.push_back(stan::test::unit::instrumented_writer{}); + } + } + + std::stringstream model_log; + stan::io::empty_var_context context; + stan_model model; + std::vector> cont_vector; + std::vector rng; + stan::test::unit::instrumented_interrupt interrupt; + std::vector sample_writer, + diagnostic_writer; + stan::test::unit::instrumented_logger logger; + std::vector sampler; + int num_warmup, num_samples, num_thin, refresh, n_chain; + bool save_warmup; +}; + +TEST_F(ServicesUtil, all_zero) { + stan::services::util::run_adaptive_sampler( + sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh, + save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer, + num_chains); + EXPECT_EQ(0, interrupt.call_count()); + + EXPECT_EQ((3 + 2) * num_chains, logger.call_count()) + << "Writes the elapsed time"; + EXPECT_EQ(logger.call_count(), logger.call_count_info()) + << "No other calls to logger"; + + EXPECT_EQ(8, sample_writer[0].call_count()); + EXPECT_EQ(1, sample_writer[0].call_count("vector_string")) << "header line"; + EXPECT_EQ(2 + 3, sample_writer[0].call_count("string")) + << "adaptation info + elapsed time"; + EXPECT_EQ(2, sample_writer[0].call_count("empty")) << "blank lines"; + + EXPECT_EQ(6, diagnostic_writer[0].call_count()); + EXPECT_EQ(1, diagnostic_writer[0].call_count("vector_string")) + << "header line"; + EXPECT_EQ(3, diagnostic_writer[0].call_count("string")) << "elapsed time"; + EXPECT_EQ(2, diagnostic_writer[0].call_count("empty")) << "blank lines"; +} + +TEST_F(ServicesUtil, num_warmup_no_save) { + num_warmup = 1000; + stan::services::util::run_adaptive_sampler( + sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh, + save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer, + num_chains); + EXPECT_EQ(num_warmup * num_chains, interrupt.call_count()); + + EXPECT_EQ((3 + 2) * num_chains, logger.call_count()) + << "Writes the elapsed time"; + EXPECT_EQ(logger.call_count(), logger.call_count_info()) + << "No other calls to logger"; + + EXPECT_EQ(8, sample_writer[0].call_count()); + EXPECT_EQ(1, sample_writer[0].call_count("vector_string")) << "header line"; + EXPECT_EQ(2 + 3, sample_writer[0].call_count("string")) + << "adaptation info + elapsed time"; + EXPECT_EQ(2, sample_writer[0].call_count("empty")) << "blank lines"; + + EXPECT_EQ(6, diagnostic_writer[0].call_count()); + EXPECT_EQ(1, diagnostic_writer[0].call_count("vector_string")) + << "header line"; + EXPECT_EQ(3, diagnostic_writer[0].call_count("string")) << "elapsed time"; + EXPECT_EQ(2, diagnostic_writer[0].call_count("empty")) << "blank lines"; +} + +TEST_F(ServicesUtil, num_warmup_save) { + num_warmup = 1000; + save_warmup = true; + stan::services::util::run_adaptive_sampler( + sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh, + save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer, + num_chains); + EXPECT_EQ((num_warmup)*num_chains, interrupt.call_count()); + + EXPECT_EQ((3 + 2) * num_chains, logger.call_count()) + << "Writes the elapsed time"; + EXPECT_EQ(logger.call_count(), logger.call_count_info()) + << "No other calls to logger"; + + EXPECT_EQ(num_warmup + 8, sample_writer[0].call_count()); + EXPECT_EQ(1, sample_writer[0].call_count("vector_string")) << "header line"; + EXPECT_EQ(2 + 3, sample_writer[0].call_count("string")) + << "adaptation info + elapsed time"; + EXPECT_EQ(2, sample_writer[0].call_count("empty")) << "blank lines"; + EXPECT_EQ(num_warmup, sample_writer[0].call_count("vector_double")) + << "warmup draws"; + + EXPECT_EQ(num_warmup + 6, diagnostic_writer[0].call_count()); + EXPECT_EQ(1, diagnostic_writer[0].call_count("vector_string")) + << "header line"; + EXPECT_EQ(3, diagnostic_writer[0].call_count("string")) << "elapsed time"; + EXPECT_EQ(2, diagnostic_writer[0].call_count("empty")) << "blank lines"; + EXPECT_EQ(num_warmup, diagnostic_writer[0].call_count("vector_double")) + << "warmup draws"; +} + +TEST_F(ServicesUtil, num_samples) { + num_samples = 1000; + stan::services::util::run_adaptive_sampler( + sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh, + save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer, + num_chains); + EXPECT_EQ(num_samples * num_chains, interrupt.call_count()); + + EXPECT_EQ((3 + 2) * num_chains, logger.call_count()) + << "Writes the elapsed time"; + EXPECT_EQ(logger.call_count(), logger.call_count_info()) + << "No other calls to logger"; + + EXPECT_EQ(num_samples + 8, sample_writer[0].call_count()); + EXPECT_EQ(1, sample_writer[0].call_count("vector_string")) << "header line"; + EXPECT_EQ(2 + 3, sample_writer[0].call_count("string")) + << "adaptation info + elapsed time"; + EXPECT_EQ(2, sample_writer[0].call_count("empty")) << "blank lines"; + EXPECT_EQ(num_samples, sample_writer[0].call_count("vector_double")) + << "num_samples draws"; + + EXPECT_EQ(num_samples + 6, diagnostic_writer[0].call_count()); + EXPECT_EQ(1, diagnostic_writer[0].call_count("vector_string")) + << "header line"; + EXPECT_EQ(3, diagnostic_writer[0].call_count("string")) << "elapsed time"; + EXPECT_EQ(2, diagnostic_writer[0].call_count("empty")) << "blank lines"; + EXPECT_EQ(num_samples, sample_writer[0].call_count("vector_double")) + << "num_samples draws"; +} + +TEST_F(ServicesUtil, num_warmup_save_num_samples_num_thin) { + num_warmup = 500; + save_warmup = true; + num_samples = 500; + num_thin = 10; + stan::services::util::run_adaptive_sampler( + sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh, + save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer, + num_chains); + EXPECT_EQ((num_warmup + num_samples) * num_chains, interrupt.call_count()); + + EXPECT_EQ((3 + 2) * num_chains, logger.call_count()) + << "Writes the elapsed time"; + EXPECT_EQ(logger.call_count(), logger.call_count_info()) + << "No other calls to logger"; + + EXPECT_EQ((num_warmup + num_samples) / num_thin + 8, + sample_writer[0].call_count()); + EXPECT_EQ(1, sample_writer[0].call_count("vector_string")) << "header line"; + EXPECT_EQ(2 + 3, sample_writer[0].call_count("string")) << "elapsed time"; + EXPECT_EQ(2, sample_writer[0].call_count("empty")) << "blank lines"; + EXPECT_EQ((num_warmup + num_samples) / num_thin, + sample_writer[0].call_count("vector_double")) + << "thinned warmup and draws"; + + EXPECT_EQ((num_warmup + num_samples) / num_thin + 6, + diagnostic_writer[0].call_count()); + EXPECT_EQ(1, diagnostic_writer[0].call_count("vector_string")) + << "header line"; + EXPECT_EQ(3, diagnostic_writer[0].call_count("string")) << "elapsed time"; + EXPECT_EQ(2, diagnostic_writer[0].call_count("empty")) << "blank lines"; + EXPECT_EQ((num_warmup + num_samples) / num_thin, + diagnostic_writer[0].call_count("vector_double")) + << "thinned warmup and draws"; +} + +TEST_F(ServicesUtil, num_warmup_num_samples_refresh) { + num_warmup = 500; + num_samples = 500; + refresh = 10; + stan::services::util::run_adaptive_sampler( + sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh, + save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer, + num_chains); + EXPECT_EQ((num_warmup + num_samples) * num_chains, interrupt.call_count()); + + EXPECT_EQ((num_warmup * num_chains + num_samples * num_chains) / refresh + + (2 + 3 + 2) * num_chains, + logger.call_count()) + << "Writes 1 to start warmup, 1 to start post-warmup, and " + << "(num_warmup + num_samples) / refresh, then the elapsed time"; + EXPECT_EQ(logger.call_count(), logger.call_count_info()) + << "No other calls to logger"; + + EXPECT_EQ(num_samples + 8, sample_writer[0].call_count()); + EXPECT_EQ(1, sample_writer[0].call_count("vector_string")) << "header line"; + EXPECT_EQ(2 + 3, sample_writer[0].call_count("string")) << "elapsed time"; + EXPECT_EQ(2, sample_writer[0].call_count("empty")) << "blank lines"; + EXPECT_EQ(num_samples, sample_writer[0].call_count("vector_double")) + << "draws"; + + EXPECT_EQ(num_samples + 6, diagnostic_writer[0].call_count()); + EXPECT_EQ(1, diagnostic_writer[0].call_count("vector_string")) + << "header line"; + EXPECT_EQ(3, diagnostic_writer[0].call_count("string")) << "elapsed time"; + EXPECT_EQ(2, diagnostic_writer[0].call_count("empty")) << "blank lines"; + EXPECT_EQ(num_samples, diagnostic_writer[0].call_count("vector_double")) + << "draws"; +}