diff --git a/pygam/__init__.py b/pygam/__init__.py
index 912d3fd8..e6f332d7 100644
--- a/pygam/__init__.py
+++ b/pygam/__init__.py
@@ -17,8 +17,11 @@
 from pygam.terms import f
 from pygam.terms import te
 from pygam.terms import intercept
+from pygam.terms import from_formula
 
 __all__ = ['GAM', 'LinearGAM', 'LogisticGAM', 'GammaGAM', 'PoissonGAM',
-           'InvGaussGAM', 'ExpectileGAM', 'l', 's', 'f', 'te', 'intercept']
+           'InvGaussGAM', 'ExpectileGAM',
+           'l', 's', 'f', 'te', 'intercept',
+           'from_formula']
 
 __version__ = '0.8.0'
diff --git a/pygam/terms.py b/pygam/terms.py
index 117584b4..9b873ffa 100644
--- a/pygam/terms.py
+++ b/pygam/terms.py
@@ -1820,13 +1820,131 @@ def te(*args, **kwargs):
 
 intercept = Intercept()
 
+
+def from_formula(formula, df, coerce=True, verbose=False):
+    """
+    Build a terms object from a (patsy / R like) formula and a data frame.
+
+    Each right-hand-side term is mapped to the position of the matching
+    column in ``df.columns``:
+
+    - ``l(name)`` / ``L(name)`` -> linear term
+    - ``c(name)`` / ``C(name)`` -> factor term
+    - ``s(name)`` / ``S(name)`` -> spline term
+    - ``name`` -> spline term (assumed)
+    - ``*`` -> one spline term per column, skipping the target column
+
+    The returned terms always start with the intercept.
+
+    Parameters
+    ----------
+    formula : str
+        Formula such as ``'y ~ x + a + l(b)'`` or ``'y ~ *'``.
+    df : pandas.DataFrame
+        Data frame whose column names the formula refers to. It is not
+        modified.
+    coerce : bool, optional
+        Whether to replace the characters ``+-()`` in the column names (and
+        in the target name) with underscores ``_`` before matching. If
+        False, such column names raise an error instead.
+    verbose : bool, optional
+        Whether to print information about the processing.
+
+    Returns
+    -------
+    terms
+        The intercept followed by the requested terms.
+
+    Raises
+    ------
+    AssertionError
+        If the formula is malformed, or if a column name contains one of
+        ``+-()`` while ``coerce`` is False.
+    ValueError
+        If a term names a column that is not in ``df``.
+
+    Notes
+    -----
+    Term indices are positions in ``df.columns``, target column included.
+    """
+    import re
+
+    # Required input validation
+    if formula.count('~') != 1:
+        raise AssertionError('Formulas should look like `y ~ x + a + l(b)`')
+
+    target_name, _, rhs = formula.partition('~')
+    target_name = target_name.strip()
+
+    invalid_chars = '+-()'
+    bad_chars = re.compile('[{}]'.format(re.escape(invalid_chars)))
+    # Work on a copy of the names so the caller's data frame is never mutated
+    column_names = list(df.columns)
+    if any(bad_chars.search(name) for name in column_names):
+        if not coerce:
+            raise AssertionError(
+                '`df` columns names cannot have {invalid_chars} in their names. Try setting `coerce=True`'.format(
+                    invalid_chars=invalid_chars
+                )
+            )
+        column_names = [bad_chars.sub('_', name) for name in column_names]
+        # Coerce the target the same way so it still matches its column
+        target_name = bad_chars.sub('_', target_name)
+
+    terms = [term.strip() for term in rhs.split('+')]
+    if verbose:
+        print('target name: {}'.format(target_name))
+        print(terms)
+
+    if not all(terms):
+        # Bad formula: empty right-hand side, or a dangling / doubled `+`
+        raise AssertionError('Check input formula {}'.format(formula))
+
+    def column_index(name):
+        # Positional index of `name`, with a readable error for unknown columns
+        if name not in column_names:
+            raise ValueError('`{}` is not a column of `df`'.format(name))
+        return column_names.index(name)
+
+    # Simplest possible formula: a spline on every column but the target
+    if terms[0] == '*':
+        term_list = intercept
+        for i, column_name in enumerate(column_names):
+            # Exact match: a substring test would also drop e.g. `year` for `y`
+            if column_name == target_name:
+                continue
+            term_list += s(i)
+        return term_list
+
+    # `kind(name)` selects the term type; a bare name is assumed to be a spline
+    builders = {
+        'l': (l, 'linear term'),
+        'c': (f, 'factor term'),
+        's': (s, 'spline term'),
+    }
+    wrapped_term = re.compile(r'^([lcs])\((.*)\)$', re.IGNORECASE)
+
+    term_list = intercept
+    for term in terms:  # type: str
+        match = wrapped_term.match(term)
+        if match:
+            build, label = builders[match.group(1).lower()]
+            name = match.group(2).strip()
+        else:
+            build, label, name = s, 'assumed spline term', term
+        if verbose:
+            print('{} -> {}'.format(term, label))
+        term_list += build(column_index(name))
+    return term_list
+
+
 # copy docs
 for minimal_, class_ in zip([l, s, f, te], [LinearTerm, SplineTerm, FactorTerm, TensorTerm]):
     minimal_.__doc__ = class_.__init__.__doc__ + minimal_.__doc__
 
 
-TERMS = {'term' : Term,
-         'intercept_term' : Intercept,
+TERMS = {'term': Term,
+         'intercept_term': Intercept,
          'linear_term': LinearTerm,
          'spline_term': SplineTerm,
          'factor_term': FactorTerm,
diff --git a/pygam/tests/test_terms.py b/pygam/tests/test_terms.py
index e4a5c5a4..632d5a5e 100644
--- a/pygam/tests/test_terms.py
+++ b/pygam/tests/test_terms.py
@@ -3,6 +3,7 @@
 from copy import deepcopy
 
 import numpy as np
+import pandas as pd
 import pytest
 
 from pygam import *
@@ -15,6 +16,29 @@ def chicago_gam(chicago_X_y):
     gam = PoissonGAM(terms=s(0, n_splines=200) + te(3, 1) + s(2)).fit(X, y)
     return gam
 
+def test_from_formula_bad_formula():
+    """Formulas must look like patsy formulas
+    """
+    dummy_df = pd.DataFrame(columns=['SystolicBP', 'Smoke', 'Overwt'])
+
+    for formula_i in ['Smoke + Overwt', 'SystolicBP ~']:
+        with pytest.raises(AssertionError):
+            from_formula(formula_i, dummy_df)
+
+    assert from_formula('SystolicBP ~ Smoke + l(Overwt)', dummy_df) is not None
+
+
+def test_from_formula_bad_cols_names():
+    """Make sure all bad columns are either detected and properly coerced
+    """
+    bad_df = pd.DataFrame(columns=['Systolic-BP', 'is_smoker', 'Over+wt'])
+
+    with pytest.raises(AssertionError):
+        from_formula('Systolic-BP ~ *', bad_df, coerce=False)
+
+    assert from_formula('Systolic_BP ~ *', bad_df)
+
+
 def test_wrong_length():
     """iterable params must all match lengths
     """