Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@ jobs:
steps:
- uses: actions/checkout@v4

- uses: codecov/codecov-action@v3

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
Expand All @@ -84,3 +82,5 @@ jobs:
- name: Test with pytest
run: |
pytest --cov=pygam

- uses: codecov/codecov-action@v3
26 changes: 16 additions & 10 deletions pygam/pygam.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from collections import OrderedDict, defaultdict
from copy import deepcopy
from itertools import product

import numpy as np
import scipy as sp
Expand Down Expand Up @@ -74,7 +75,6 @@
check_X_y,
check_y,
cholesky,
combine,
flatten,
isiterable,
load_diagonal,
Expand Down Expand Up @@ -1969,6 +1969,8 @@ def gridsearch(
admissible_params = list(self.get_params()) + self._plural
params = []
grids = []

grid_size = 1
for param, grid in list(param_grids.items()):
# check param exists
if param not in (admissible_params):
Expand Down Expand Up @@ -2001,20 +2003,21 @@ def gridsearch(
if cartesian:
if len(grid) != target_len:
raise ValueError(msg)
grid = combine(*grid)

if not all([len(subgrid) == target_len for subgrid in grid]):
raise ValueError(msg)
# we should consider each element in `grid` its own dimension
grid_size *= np.prod([len(g) for g in grid])
grid = product(*grid)
else:
if not all([len(subgrid) == target_len for subgrid in grid]):
raise ValueError(msg)
grid_size *= len(grid)
else:
grid_size *= len(grid)

# save param name and grid
params.append(param)
grids.append(grid)

# build a list of dicts of candidate model params
param_grid_list = []
for candidate in combine(*grids):
param_grid_list.append(dict(zip(params, candidate)))

# set up data collection
best_model = None # keep the best model
best_score = np.inf
Expand All @@ -2039,7 +2042,10 @@ def pbar(x):
return x

# loop through candidate model params
for param_grid in pbar(param_grid_list):
for grid in pbar(product(*grids), max_value=grid_size):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think it might be beneficial to parallelize this for loop? I can create an issue/PR if that might be interesting.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @ankurankan I think that would be great!
Do you think it will run faster?

I have never been very good art parallelizing numerical python routines. My parallel routines always seem to run just as fast/slow as my serial ones...

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can give it a quick try and share some benchmarks. Let's see if that improves anything. Have created an issue: #406

# build dict of candidate model params
param_grid = dict(zip(params, grid))

try:
# try fitting
# define new model
Expand Down
27 changes: 0 additions & 27 deletions pygam/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,33 +787,6 @@ def ylogydu(y, u):
return out


def combine(*args):
"""
Tool to perform tree search via recursion
useful for developing the grid in a grid search.

Parameters
----------
args : list of lists

Returns
-------
list of all the combinations of the elements in the input lists
"""
if hasattr(args, "__iter__") and (len(args) > 1):
subtree = combine(*args[:-1])
tree = []
for leaf in subtree:
for node in args[-1]:
if hasattr(leaf, "__iter__"):
tree.append(leaf + [node])
else:
tree.append([leaf] + [node])
return tree
else:
return [[arg] for arg in args[0]]


def isiterable(obj, reject_string=True):
"""Convenience tool to detect if something is iterable.
in python3, strings count as iterables to we have the option to exclude them.
Expand Down