Parallel Run Adaptive Sampler #3028
New file `stan/callbacks/file_stream_writer.hpp`:

```cpp
#ifndef STAN_CALLBACKS_FILE_STREAM_WRITER_HPP
#define STAN_CALLBACKS_FILE_STREAM_WRITER_HPP

#include <stan/callbacks/writer.hpp>
#include <memory>
#include <ostream>
#include <string>
#include <vector>

namespace stan {
namespace callbacks {

/**
 * <code>file_stream_writer</code> is an implementation
 * of <code>writer</code> that writes to a file.
 */
class file_stream_writer final : public writer {
 public:
  /**
   * Constructs a file stream writer with an output stream
   * and an optional prefix for comments.
   *
   * @param[in, out] output A unique pointer to a type inheriting from
   *  `std::ostream`
   * @param[in] comment_prefix string to stream before each comment line.
   *  Default is "".
   */
  explicit file_stream_writer(std::unique_ptr<std::ostream>&& output,
                              const std::string& comment_prefix = "")
      : output_(std::move(output)), comment_prefix_(comment_prefix) {}

  file_stream_writer();
  file_stream_writer(file_stream_writer& other) = delete;
  file_stream_writer(file_stream_writer&& other)
      : output_(std::move(other.output_)),
        comment_prefix_(std::move(other.comment_prefix_)) {}

  /**
   * Virtual destructor
   */
  virtual ~file_stream_writer() {}

  /**
   * Writes a set of names on a single line in csv format followed
   * by a newline.
   *
   * Note: the names are not escaped.
   *
   * @param[in] names Names in a std::vector
   */
  void operator()(const std::vector<std::string>& names) {
    write_vector(names);
  }

  /**
   * Get the underlying stream
   */
  auto& get_stream() { return *output_; }

  /**
   * Writes a set of values in csv format followed by a newline.
   *
   * Note: the precision of the output is determined by the settings
   * of the stream on construction.
   *
   * @param[in] state Values in a std::vector
   */
  void operator()(const std::vector<double>& state) { write_vector(state); }

  /**
   * Writes the comment_prefix to the stream followed by a newline.
   */
  void operator()() { *output_ << comment_prefix_ << std::endl; }

  /**
   * Writes the comment_prefix then the message followed by a newline.
   *
   * @param[in] message A string
   */
  void operator()(const std::string& message) {
    *output_ << comment_prefix_ << message << std::endl;
  }

 private:
  /**
   * Output stream
   */
  std::unique_ptr<std::ostream> output_;

  /**
   * Comment prefix to use when printing comments: strings and blank lines
   */
  std::string comment_prefix_;

  /**
   * Writes a set of values in csv format followed by a newline.
   *
   * Note: the precision of the output is determined by the settings
   * of the stream on construction.
   *
   * @param[in] v Values in a std::vector
   */
  template <class T>
  void write_vector(const std::vector<T>& v) {
    if (v.empty())
      return;

    typename std::vector<T>::const_iterator last = v.end();
    --last;

    for (typename std::vector<T>::const_iterator it = v.begin(); it != last;
         ++it)
      *output_ << *it << ",";
    *output_ << v.back() << std::endl;
  }
};

}  // namespace callbacks
}  // namespace stan

#endif
```
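Not part of the diff: a minimal usage sketch of the new writer, assuming the class as shown above (the file name `output.csv` and the `"# "` prefix are just illustrative choices).

```cpp
#include <stan/callbacks/file_stream_writer.hpp>

#include <fstream>
#include <memory>
#include <string>
#include <vector>

int main() {
  // The writer takes ownership of the stream it is handed.
  auto out = std::make_unique<std::ofstream>("output.csv");
  stan::callbacks::file_stream_writer writer(std::move(out), "# ");

  writer(std::vector<std::string>{"lp__", "theta"});  // header row: lp__,theta
  writer(std::vector<double>{-7.3, 0.25});            // one draw:   -7.3,0.25
  writer("Adaptation terminated");                    // comment:    # Adaptation terminated
  writer();                                           // bare comment prefix line
  return 0;
}
```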
Changes to `stan/services/util/generate_transitions.hpp`:

```diff
@@ -44,14 +44,17 @@ void generate_transitions(stan::mcmc::base_mcmc& sampler, int num_iterations,
                           util::mcmc_writer& mcmc_writer,
                           stan::mcmc::sample& init_s, Model& model,
                           RNG& base_rng, callbacks::interrupt& callback,
-                          callbacks::logger& logger) {
+                          callbacks::logger& logger, size_t n_chain = 0) {
   for (int m = 0; m < num_iterations; ++m) {
     callback();
 
     if (refresh > 0
         && (start + m + 1 == finish || m == 0 || (m + 1) % refresh == 0)) {
       int it_print_width = std::ceil(std::log10(static_cast<double>(finish)));
       std::stringstream message;
+      if (n_chain > 0) {
+        message << "Chain [" << (n_chain + 1) << "]";
+      }
```
**Member:** This makes sense, but we can run it by the interfaces.

**Collaborator (Author):** Yeah, anyone that is depending on the current behavior shouldn't be affected by this, so I think it's fine.
```diff
       message << "Iteration: ";
       message << std::setw(it_print_width) << m + 1 + start << " / " << finish;
       message << " [" << std::setw(3)
```
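As a quick illustration of what the new `n_chain` prefix produces, here is a standalone sketch mirroring the formatting code in the diff above; the iteration numbers are made up.

```cpp
#include <cmath>
#include <cstddef>
#include <iomanip>
#include <iostream>
#include <sstream>

int main() {
  const int m = 99, start = 0, finish = 2000;  // illustrative values
  const std::size_t n_chain = 1;               // zero-based id of the second chain
  const int it_print_width = std::ceil(std::log10(static_cast<double>(finish)));
  std::stringstream message;
  if (n_chain > 0) {
    message << "Chain [" << (n_chain + 1) << "]";
  }
  message << "Iteration: ";
  message << std::setw(it_print_width) << m + 1 + start << " / " << finish;
  // Prints: Chain [2]Iteration:  100 / 2000
  std::cout << message.str() << std::endl;
  return 0;
}
```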
Changes to `stan/services/util/run_adaptive_sampler.hpp`:

```diff
@@ -5,6 +5,7 @@
 #include <stan/callbacks/writer.hpp>
 #include <stan/services/util/generate_transitions.hpp>
 #include <stan/services/util/mcmc_writer.hpp>
+#include <tbb/parallel_for.h>
 #include <chrono>
 #include <vector>
@@ -34,7 +35,7 @@ namespace util {
  * @param[in,out] sample_writer writer for draws
  * @param[in,out] diagnostic_writer writer for diagnostic information
  */
-template <class Sampler, class Model, class RNG>
+template <typename Sampler, typename Model, typename RNG>
 void run_adaptive_sampler(Sampler& sampler, Model& model,
                           std::vector<double>& cont_vector, int num_warmup,
                           int num_samples, int num_thin, int refresh,
@@ -87,6 +88,83 @@ void run_adaptive_sampler(Sampler& sampler, Model& model,
                 / 1000.0;
   writer.write_timing(warm_delta_t, sample_delta_t);
 }
+
+template <class Sampler, class Model, class RNG, typename SampT,
+          typename DiagnoseT>
+void run_adaptive_sampler(std::vector<Sampler>& samplers, Model& model,
+                          std::vector<std::vector<double>>& cont_vectors,
+                          int num_warmup, int num_samples, int num_thin,
+                          int refresh, bool save_warmup, std::vector<RNG>& rngs,
+                          callbacks::interrupt& interrupt,
+                          callbacks::logger& logger,
+                          std::vector<SampT>& sample_writers,
+                          std::vector<DiagnoseT>& diagnostic_writers,
+                          size_t n_chain) {
```
**Member:** So the model, interrupt, and logger bits need to be thread safe. Is every function in the model (other than constructors/destructors) thread safe? I think this threading stuff will mess up the profiles, so we have to figure out what to do there.

For the interrupts and loggers, I see that the implementations in the test had to change, but shouldn't there be changes in the main source too? Or were these already thread safe?

I see that interrupt is for the interfaces to capture keystrokes. Are the interrupt signals generated at the interfaces and transmitted in here, or are they generated in the samplers and transmitted back to the interfaces?

**Collaborator (Author):** So the logger goes to

The interrupts, at least in cmdstan, are actually just the base impl. They are called in

Yep! I've looked the model class over. The model is essentially immutable after construction, so it's safe to pass around.

Oooh, hm, I don't know how this will affect profiling. @rok-cesnovar, looking at the profiling gen code, I think we might need a mutex in the profile constructor here:

```cpp
profile(std::string name, profile_map& profiles)
    : key_({name, std::this_thread::get_id()}) {
  {
    std::lock_guard<std::mutex> map_lock(profile_mutex);
    // I don't think search is thread safe
    profile_map::iterator p = profiles.find(key_);
    if (p == profiles.end()) {
      // I don't think insertion is thread safe either
      profiles[key_] = profile_info();
    }
    // rest business as usual
  }
}
```
**Contributor:** Is it an option to run the profiler in only one thread, or to avoid the need for a mutex by using multiple model copies? We do not want a mutex unless we have to.

**Member:** I think we want separate profiles for each thread if we want profiling to act like it currently does (a separate profile for each chain). Is Stan data actually copied into the model, or is it references? I guess if it's references then we could actually have lots of models? But maybe that is undesirable.

**Member:** Do we need a mutex? The thread id is part of the key in the storage map, so each thread is guaranteed to only ever access its "own" profiles. We could switch to using `tbb::concurrent_hash_map`, though I think that is unnecessary.
**Collaborator (Author):** The profiler itself runs in one thread; the problem is that in the above constructor the find and insertion touch the shared `profile_map`. IMO I don't think a mutex is that bad here; I think we can actually put it in the `if`.

One solution here is to prefill the `profile_map` keys. Since each key is determined by a unique name and thread combination, no two threads can cause a race condition by writing to the same entry of the map, so we only need to worry about the map reallocating for new keys. For an example model like

```stan
model {
  profile("ayyy") {
    theta ~ beta(1, 1); // uniform prior on interval 0,1
  }
  profile("yooo") {
    y ~ bernoulli(theta);
  }
}
```

we can have a function in the model, or something like that, to fill up the keys before we do any other insertions. Then we don't need to worry about data reallocation since all the keys are already there.

But still, I think a mutex here is not that bad. It's short lived and only causes locks whenever we are doing profiling (which IMO is not a usage pattern aimed at peak performance anyway). I think what I would like is to just put a mutex in the `if` for now, and then later we can do the lock-free pattern.
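For illustration only, a rough sketch of the prefill idea described above. The stub types and the helper name are hypothetical stand-ins, not Stan's actual profiling API, and this assumes the set of worker thread ids is known before any parallel work starts.

```cpp
#include <map>
#include <string>
#include <thread>
#include <utility>
#include <vector>

// Hypothetical stand-ins for Stan's profiling types.
struct profile_info_stub {
  double fwd_time = 0;
  double rev_time = 0;
};
using profile_key_stub = std::pair<std::string, std::thread::id>;
using profile_map_stub = std::map<profile_key_stub, profile_info_stub>;

// Insert one entry per (profile name, thread id) up front: afterwards no
// thread ever inserts, so the tree is never rebalanced concurrently and each
// thread only touches its own pre-existing entries.
void prefill_profile_keys(profile_map_stub& profiles,
                          const std::vector<std::string>& names,
                          const std::vector<std::thread::id>& ids) {
  for (const auto& name : names)
    for (const auto& id : ids)
      profiles.emplace(profile_key_stub{name, id}, profile_info_stub{});
}

int main() {
  profile_map_stub profiles;
  prefill_profile_keys(profiles, {"ayyy", "yooo"},
                       {std::this_thread::get_id()});
  return profiles.size() == 2 ? 0 : 1;
}
```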
**Collaborator (Author):** Oh shoot, I typed that out before Rok posted his reply. Yeah, from there it looks like it is safe to insert, since each key is distinct for each thread. I'll look into this more to make sure it's fine. If I can't sort out anything conclusive, then I think moving to a `tbb::concurrent_hash_map` is the fallback.

**Contributor:** Global variables are problematic for obvious reasons. The TBB data structures should be a good alternative... How often is this code being run over? I mean, is it performance critical? EDIT: If the std map operations needed are supposed to be safe in this usage, then no need to switch to the TBB thing, of course.

**Member:** Insert into the map happens the first time a profile is registered, so once in the lifetime of a model. After that it's just read access to get the reference to the profile info, and those are separated for separate threads, so no race conditions there. If locking the insert in a mutex is something that would make us sleep more calmly, that's not a big price to pay. It does seem it's not needed.

**Collaborator (Author):** Yeah, I don't think it's needed. Let's not worry about this for now; we should notice it with a pretty simple test once the cmdstan stuff is up.
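For reference, a minimal, self-contained sketch of the `tbb::concurrent_hash_map` alternative mentioned above. It is illustrative only: it uses a plain string key and an integer counter rather than Stan's actual `profile_key`/`profile_info` types.

```cpp
#include <tbb/concurrent_hash_map.h>

#include <iostream>
#include <string>
#include <thread>
#include <vector>

using counter_map = tbb::concurrent_hash_map<std::string, int>;

// Insert-if-absent and update under the element lock held by the accessor;
// no external mutex is needed even when many threads hit the same key.
void bump(counter_map& counters, const std::string& name) {
  counter_map::accessor a;
  if (counters.insert(a, name)) {
    a->second = 0;  // freshly inserted entry
  }
  a->second += 1;
}

int main() {
  counter_map counters;
  std::vector<std::thread> workers;
  for (int t = 0; t < 4; ++t) {
    workers.emplace_back([&counters] {
      for (int i = 0; i < 1000; ++i) {
        bump(counters, "ayyy");
      }
    });
  }
  for (auto& w : workers) {
    w.join();
  }
  counter_map::const_accessor ca;
  if (counters.find(ca, "ayyy")) {
    std::cout << "ayyy: " << ca->second << std::endl;  // expect 4000
  }
  return 0;
}
```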
```diff
+  if (n_chain == 0) {
+    run_adaptive_sampler(samplers[0], model, cont_vectors[0], num_warmup,
+                         num_samples, num_thin, refresh, save_warmup, rngs[0],
+                         interrupt, logger, sample_writers[0],
+                         diagnostic_writers[0]);
+  }
+  tbb::parallel_for(
+      tbb::blocked_range<size_t>(0, n_chain, 1),
+      [num_warmup, num_samples, num_thin, refresh, save_warmup, &samplers,
+       &model, &rngs, &interrupt, &logger, &sample_writers, &cont_vectors,
+       &diagnostic_writers](const tbb::blocked_range<size_t>& r) {
+        for (size_t i = r.begin(); i != r.end(); ++i) {
```
**Member:** For now, can we just call the existing single-chain `run_adaptive_sampler` here? Long term these could behave differently (if we did custom adaptation or something), but they seem the same for now? Or is there some difference I'm missing?

**Collaborator (Author):** So, I didn't realize this until I started working on #3033, lol. Yeah, I think that would work? Let me try to do that in #3033. Though I will say I kind of would like to keep this code, if only as an example of how to do this.

**Collaborator (Author):** Alright, so I just ran this locally and it gives back good answers. If we are going to cut this code out, then what I can actually do is remove the changes to `run_adaptive_sampler` in #3033, just do the simpler parallel loop, and then close this PR.
```diff
+          auto&& sampler = samplers[i];
+          sampler.engage_adaptation();
+          Eigen::Map<Eigen::VectorXd> cont_params(cont_vectors[i].data(),
+                                                  cont_vectors[i].size());
+          try {
+            sampler.z().q = cont_params;
+            sampler.init_stepsize(logger);
+          } catch (const std::exception& e) {
+            logger.info("Exception initializing step size.");
+            logger.info(e.what());
+            return;
+          }
+          stan::mcmc::sample samp(cont_params, 0, 0);
+          auto&& sample_writer = sample_writers[i];
+          auto writer = services::util::mcmc_writer(
+              sample_writer, diagnostic_writers[i], logger);
+
+          // Headers
+          writer.write_sample_names(samp, sampler, model);
+          writer.write_diagnostic_names(samp, sampler, model);
+
+          const auto start_warm = std::chrono::steady_clock::now();
+          util::generate_transitions(sampler, num_warmup, 0,
+                                     num_warmup + num_samples, num_thin,
+                                     refresh, save_warmup, true, writer, samp,
+                                     model, rngs[i], interrupt, logger, i);
+          const auto end_warm = std::chrono::steady_clock::now();
+          auto warm_delta
+              = std::chrono::duration_cast<std::chrono::milliseconds>(
+                    end_warm - start_warm)
+                    .count()
+                / 1000.0;
+          sampler.disengage_adaptation();
+          writer.write_adapt_finish(sampler);
+          sampler.write_sampler_state(sample_writer);
+
+          const auto start_sample = std::chrono::steady_clock::now();
+          util::generate_transitions(sampler, num_samples, num_warmup,
+                                     num_warmup + num_samples, num_thin,
+                                     refresh, true, false, writer, samp, model,
+                                     rngs[i], interrupt, logger, i);
+          const auto end_sample = std::chrono::steady_clock::now();
+          auto sample_delta
+              = std::chrono::duration_cast<std::chrono::milliseconds>(
+                    end_sample - start_sample)
+                    .count()
+                / 1000.0;
+          writer.write_timing(warm_delta, sample_delta);
+        }
+      },
+      tbb::simple_partitioner());
+}
+
 }  // namespace util
 }  // namespace services
 }  // namespace stan
```
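Not part of the diff: for readers unfamiliar with the TBB calls used above, here is a minimal, self-contained sketch of the same per-chain `tbb::parallel_for` pattern (grain size 1 plus `tbb::simple_partitioner`, so each chain index becomes its own task). The loop body is just a placeholder, not the sampler code.

```cpp
#include <tbb/blocked_range.h>
#include <tbb/parallel_for.h>

#include <cstdio>
#include <vector>

int main() {
  const std::size_t n_chain = 4;
  std::vector<int> draws_per_chain(n_chain, 0);
  // blocked_range(0, n_chain, 1) with simple_partitioner hands out one chain
  // index per task, so chains run concurrently on the TBB worker threads.
  tbb::parallel_for(
      tbb::blocked_range<std::size_t>(0, n_chain, 1),
      [&draws_per_chain](const tbb::blocked_range<std::size_t>& r) {
        for (std::size_t i = r.begin(); i != r.end(); ++i) {
          // In the real overload this is where samplers[i], rngs[i], and the
          // per-chain writers are used; here we just record some work.
          draws_per_chain[i] = 1000;
        }
      },
      tbb::simple_partitioner());
  for (std::size_t i = 0; i < n_chain; ++i)
    std::printf("chain %zu: %d draws\n", i, draws_per_chain[i]);
  return 0;
}
```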
New unit test for `file_stream_writer`:

```cpp
#include <gtest/gtest.h>
#include <boost/lexical_cast.hpp>
#include <stan/callbacks/file_stream_writer.hpp>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

class StanInterfaceCallbacksStreamWriter : public ::testing::Test {
 public:
  StanInterfaceCallbacksStreamWriter()
      : writer(std::make_unique<std::stringstream>(std::stringstream{})) {}

  void SetUp() {
    static_cast<std::stringstream&>(writer.get_stream()).str(std::string());
    static_cast<std::stringstream&>(writer.get_stream()).clear();
  }
  void TearDown() {}

  stan::callbacks::file_stream_writer writer;
};

TEST_F(StanInterfaceCallbacksStreamWriter, double_vector) {
  const int N = 5;
  std::vector<double> x;
  for (int n = 0; n < N; ++n)
    x.push_back(n);

  EXPECT_NO_THROW(writer(x));
  EXPECT_EQ("0,1,2,3,4\n",
            static_cast<std::stringstream&>(writer.get_stream()).str());
}

TEST_F(StanInterfaceCallbacksStreamWriter, string_vector) {
  const int N = 5;
  std::vector<std::string> x;
  for (int n = 0; n < N; ++n)
    x.push_back(boost::lexical_cast<std::string>(n));

  EXPECT_NO_THROW(writer(x));
  EXPECT_EQ("0,1,2,3,4\n",
            static_cast<std::stringstream&>(writer.get_stream()).str());
}

TEST_F(StanInterfaceCallbacksStreamWriter, null) {
  EXPECT_NO_THROW(writer());
  EXPECT_EQ("\n", static_cast<std::stringstream&>(writer.get_stream()).str());
}

TEST_F(StanInterfaceCallbacksStreamWriter, string) {
  EXPECT_NO_THROW(writer("message"));
  EXPECT_EQ("message\n",
            static_cast<std::stringstream&>(writer.get_stream()).str());
}
```
What do these two lines mean? Because `output_` is a unique ptr, we don't allow copy constructors, only move constructors, or something?
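Concretely, an illustrative sketch (not part of the PR): because `file_stream_writer` owns its stream through a `std::unique_ptr`, copying is deleted but moving transfers ownership of the stream.

```cpp
#include <stan/callbacks/file_stream_writer.hpp>

#include <memory>
#include <sstream>
#include <utility>

int main() {
  stan::callbacks::file_stream_writer a(std::make_unique<std::stringstream>());
  // stan::callbacks::file_stream_writer b(a);          // does not compile: copy is deleted
  stan::callbacks::file_stream_writer b(std::move(a));  // ok: the unique_ptr (and stream) move
  b("hello");                                           // writes "hello\n" to the moved stream
  return 0;
}
```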
Oh okay, I looked back at what you previously wrote:

> It seems cleaner to replace stream_writer with something that behaves better. Presumably the other use cases of stream_writer can still work.

I'd like to replace `stream_writer`, but I was having trouble with backwards compatibility. I think once cmdstan uses the `file_stream_writer` we can remove `stream_writer`.

Is it difficult because it's one of these cross-repo things where we'd need to change cmdstan and stan at the same time?

Yeah, that's the main issue.