From c651341b4fb48a7b8c89d43be129ce0fb93a5adc Mon Sep 17 00:00:00 2001 From: Bryan Rumsey Date: Tue, 11 Apr 2023 16:13:37 -0400 Subject: [PATCH] Fixed the epsilon selectors so they correctly handle the rounds. --- sciope/inference/smc_abc.py | 6 +++--- .../epsilonselectors/absolute_epsilon_selector.py | 10 +++++----- .../epsilonselectors/relative_epsilon_selector.py | 14 ++++++++------ 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/sciope/inference/smc_abc.py b/sciope/inference/smc_abc.py index 87a86dc..db6964e 100644 --- a/sciope/inference/smc_abc.py +++ b/sciope/inference/smc_abc.py @@ -181,10 +181,10 @@ def infer(self, num_samples, batch_size, abc_history.append(abc_results) # SMC iterations - round = 1 + abc_round = 2 # Indicates the round getting ready to start while not terminate: - tol, relative, terminate = eps_selector.get_epsilon(round, abc_history) + tol, relative, terminate = eps_selector.get_epsilon(abc_round, abc_history) print("Starting epsilon = {}".format(tol)) if self.use_logger: @@ -228,7 +228,7 @@ def infer(self, num_samples, batch_size, if self.parameters is not None: abc_results = InferenceRound.build_from_inference_round(abc_results, list(self.parameters.keys())) abc_history.append(abc_results) - round += 1 + abc_round += 1 except KeyboardInterrupt: if self.parameters is None: diff --git a/sciope/utilities/epsilonselectors/absolute_epsilon_selector.py b/sciope/utilities/epsilonselectors/absolute_epsilon_selector.py index c907ffc..009e570 100644 --- a/sciope/utilities/epsilonselectors/absolute_epsilon_selector.py +++ b/sciope/utilities/epsilonselectors/absolute_epsilon_selector.py @@ -34,7 +34,7 @@ def __init__(self, epsilon_sequence): assert (len(epsilon_sequence) > 0) self.epsilon_sequence = epsilon_sequence - self.last_round = len(self.epsilon_sequence) - 1 + self.last_round = len(self.epsilon_sequence) def get_initial_epsilon(self): """Gets the first epsilon in the sequence. @@ -48,14 +48,14 @@ def get_initial_epsilon(self): has_more : bool Whether there are more epsilons after this one """ - return self.epsilon_sequence[0], False, len(self.epsilon_sequence) == 1 + return self.epsilon_sequence[0], False, self.last_round == 1 - def get_epsilon(self, round, abc_history): + def get_epsilon(self, abc_round, abc_history): """Returns the n-th epsilon in the seqeunce. Parameters ---------- - round : int + abc_round : int the round to get the epsilon for abc_history : type A list of dictionaries with keys `accepted_samples` and `distances` @@ -70,4 +70,4 @@ def get_epsilon(self, round, abc_history): terminate : bool Whether to stop after this epsilon """ - return self.epsilon_sequence[round], False, round == self.last_round + return self.epsilon_sequence[abc_round - 1], False, abc_round >= self.last_round diff --git a/sciope/utilities/epsilonselectors/relative_epsilon_selector.py b/sciope/utilities/epsilonselectors/relative_epsilon_selector.py index 0ff4207..5284539 100644 --- a/sciope/utilities/epsilonselectors/relative_epsilon_selector.py +++ b/sciope/utilities/epsilonselectors/relative_epsilon_selector.py @@ -35,6 +35,8 @@ def __init__(self, epsilon_percentile, max_rounds=None): max_rounds : int The maximum number of rounds before stopping. If None, doesn't end. """ + if max_rounds == 0: + raise ValueError("max_rounds must be greater than 0.") self.epsilon_percentile = epsilon_percentile self.max_rounds = max_rounds @@ -52,14 +54,14 @@ def get_initial_epsilon(self): has_more : bool Whether there are more epsilons after this one """ - return self.epsilon_percentile, True, self.max_rounds == 0 + return self.epsilon_percentile, True, self.max_rounds == 1 - def get_epsilon(self, round, abc_history): + def get_epsilon(self, abc_round, abc_history): """Returns the new epsilon based on the n-th round. Parameters ---------- - round : int + abc_round : int the n-th round of the sequence abc_history : type A list of dictionaries with keys `accepted_samples` and `distances` @@ -74,8 +76,8 @@ def get_epsilon(self, round, abc_history): terminate : bool Whether to stop after this epsilon """ - if round > len(abc_history): + if abc_round > len(abc_history): epsilon = np.percentile(abc_history[-1]['distances'], self.epsilon_percentile) else: - epsilon = np.percentile(abc_history[round - 1]['distances'], self.epsilon_percentile) - return epsilon, False, self.max_rounds and round + 1 == self.max_rounds + epsilon = np.percentile(abc_history[abc_round - 1]['distances'], self.epsilon_percentile) + return epsilon, False, self.max_rounds and abc_round >= self.max_rounds