diff --git a/discretesampling/base/random.py b/discretesampling/base/random.py index 7e95aec..fc0862b 100644 --- a/discretesampling/base/random.py +++ b/discretesampling/base/random.py @@ -1,5 +1,6 @@ from random import randint import random +import numpy as np class Random(object): @@ -35,6 +36,31 @@ def eval(self): return random.choices(population=self.population, weights=self.weights, cum_weights=self.cum_weights, k=self.k) +class Dice(object): + def __init__(self, probabilities, outcomes): + assert len(outcomes) == len(probabilities), "Invalid PMF specified, x and p" +\ + " of different lengths" + probabilities = np.array(probabilities) + tolerance = np.sqrt(np.finfo(np.float64).eps) + assert abs(1 - sum(probabilities)) < tolerance, "Invalid PMF specified," +\ + " sum of probabilities !~= 1.0" + assert all(probabilities >= 0.), "Invalid PMF specified, all probabilities" +\ + " must be > 0" + self.probabilities = probabilities + self.outcomes = outcomes + self.pmf = probabilities + self.cmf = np.cumsum(probabilities) + self.randomiser = Random() + + def eval(self): + q = self.randomiser.eval() + x = self.outcomes[np.argmax(self.cmf >= q)] + + while callable(x): + x = x() + return x + + def set_seed(seed): """ :param seed: random seed diff --git a/discretesampling/base/types.py b/discretesampling/base/types.py index 546f288..b32cba8 100644 --- a/discretesampling/base/types.py +++ b/discretesampling/base/types.py @@ -28,19 +28,8 @@ def getOptimalLKernelType(self): class DiscreteVariableProposal: - def __init__(self, values, probs): - # Check dims and probs are valid - assert len(values) == len(probs), "Invalid PMF specified, x and p" +\ - " of different lengths" - probs = np.array(probs) - tolerance = np.sqrt(np.finfo(np.float64).eps) - assert abs(1 - sum(probs)) < tolerance, "Invalid PMF specified," +\ - " sum of probabilities !~= 1.0" - assert all(probs > 0), "Invalid PMF specified, all probabilities" +\ - " must be > 0" - self.x = values - self.pmf = probs - self.cmf = np.cumsum(probs) + def __init__(self, moves_dice): + self.moves_dice = moves_dice @classmethod def norm(self, x): @@ -54,8 +43,7 @@ def heuristic(self, x, y): return True def sample(self): - q = random.random() # random unif(0,1) - return self.x[np.argmax(self.cmf >= q)] + return self.moves_dice.eval() def eval(self, y): try: diff --git a/discretesampling/domain/decision_tree/tree_distribution.py b/discretesampling/domain/decision_tree/tree_distribution.py index 4be97ac..2fd2bbf 100644 --- a/discretesampling/domain/decision_tree/tree_distribution.py +++ b/discretesampling/domain/decision_tree/tree_distribution.py @@ -1,11 +1,14 @@ import numpy as np from math import log, inf import copy -from ...base.random import Random +from ...base.random import Dice from ...base import types class TreeProposal(types.DiscreteVariableProposal): + moves_prob = [0.4, 0.1, 0.1, 0.4] + moves = ["prune", "swap", "change", "grow"] # noqa + def __init__(self, tree): self.X_train = tree.X_train self.y_train = tree.y_train @@ -21,40 +24,25 @@ def norm(self, tree): def heuristic(self, x, y): return y < x or abs(x-y) < 2 - def sample(self): - # initialise the probabilities of each move - moves = ["prune", "swap", "change", "grow"] # noqa - moves_prob = [0.4, 0.1, 0.1, 0.4] + def get_moves_prob(self): if len(self.tree.tree) == 1: moves_prob = [0.0, 0.0, 0.5, 0.5] - moves_probabilities = np.cumsum(moves_prob) - random_number = Random().eval() - newTree = copy.deepcopy(self.tree) - if random_number < moves_probabilities[0]: - # prune - newTree = newTree.prune() - - elif random_number < moves_probabilities[1]: - # swap - newTree = newTree.swap() - - elif random_number < moves_probabilities[2]: - # change - newTree = newTree.change() - else: - # grow - newTree = newTree.grow() + moves_prob = self.moves_prob + return moves_prob - return newTree + def sample(self): + # initialise the probabilities of each move + moves_prob = self.get_moves_prob() + newTree = copy.deepcopy(self.tree) + moves_dice = Dice(moves_prob, [newTree.prune, newTree.swap, newTree.change, newTree.grow]) + + return moves_dice.eval() def eval(self, sampledTree): initialTree = self.tree - moves_prob = [0.4, 0.1, 0.1, 0.4] logprobability = -inf - if len(initialTree.tree) == 1: - moves_prob = [0.0, 0.0, 0.5, 0.5] - + moves_prob = self.get_moves_prob() nodes_differences = [i for i in sampledTree.tree + initialTree.tree if i not in sampledTree.tree or i not in initialTree.tree]