From cd2829b596787430bcab03754d740a792db6fa86 Mon Sep 17 00:00:00 2001 From: arose13 Date: Mon, 8 Jul 2019 17:05:40 -0400 Subject: [PATCH 1/5] patsy like formula creation --- pygam/__init__.py | 5 ++- pygam/terms.py | 79 +++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 81 insertions(+), 3 deletions(-) diff --git a/pygam/__init__.py b/pygam/__init__.py index 912d3fd8..e6f332d7 100644 --- a/pygam/__init__.py +++ b/pygam/__init__.py @@ -17,8 +17,11 @@ from pygam.terms import f from pygam.terms import te from pygam.terms import intercept +from pygam.terms import from_formula __all__ = ['GAM', 'LinearGAM', 'LogisticGAM', 'GammaGAM', 'PoissonGAM', - 'InvGaussGAM', 'ExpectileGAM', 'l', 's', 'f', 'te', 'intercept'] + 'InvGaussGAM', 'ExpectileGAM', + 'l', 's', 'f', 'te', 'intercept', + 'from_formula'] __version__ = '0.8.0' diff --git a/pygam/terms.py b/pygam/terms.py index 117584b4..dc279086 100644 --- a/pygam/terms.py +++ b/pygam/terms.py @@ -1820,13 +1820,88 @@ def te(*args, **kwargs): intercept = Intercept() + +def from_formula(formula, df, coerce=True) -> TermList: + """ + Pass a (patsy / R like) formula and data frame and returns a terms object that matches + If only a name is given a spline is assumed + :param formula: + :param df: + :param coerce: Whether to try to convert any invalid characters in the dataframe's column names to underscores `_` + :return: + """ + import re + + def regex_contains(pattern, string): + return re.compile(pattern).search(string) is not None + + # Required input validation + if '~' not in formula: + raise AssertionError('Formulas should look like `y ~ x + a + l(b)') + + invalid_chars = '+-()' + are_bad_cols = [bool(set(invalid_chars).intersection(set(col_name))) for col_name in df.columns] + if any(are_bad_cols) and coerce is False: + raise AssertionError( + f'`df` columns names cannot have {invalid_chars} in their names. Try setting `coerce=True`' + ) + elif any(are_bad_cols) and coerce: + # I know this can be optimised since I know where the bad cols are + new_column_names = [] + for term_name in df.columns.tolist(): + for to_replace in invalid_chars: + term_name = term_name.replace(to_replace, '_') # type: str + new_column_names.append(term_name) + df.columns = new_column_names + + target_name, terms = formula.split('~') + target_name, terms = target_name.strip(), [term.strip() for term in terms.split('+')] + print(f'target name: {target_name}') + print(terms) + + if len(terms) == 0: + AssertionError(f'Check input formula {formula}') + + # Check for the simplest of all possible formulas. Early terminate here. + linear_term_pattern = r'l\(.*?\)|L\(.*?\)' + factor_term_pattern = r'c\(.*?\)|C\(.*?\)' + spline_term_pattern = r's\(.*?\)|S\(.*?\)' + + if terms[0] == '*': + term_list = intercept + for i, term_name in enumerate(df.columns): + if target_name in term_name: + continue + term_list += s(i) + return term_list + else: + term_list = intercept + for term in terms: # type: str + if regex_contains(linear_term_pattern, term): + print(f'{term} -> linear term') + term = re.sub(r'(l\()|(L\()|\)', '', term) + term_list += l(df.columns.tolist().index(term)) + elif regex_contains(factor_term_pattern, term): + print(f'{term} -> factor term') + term = re.sub(r'(c\()|(C\()|\)', '', term) + term_list += f(df.columns.tolist().index(term)) + elif regex_contains(spline_term_pattern, term): + print(f'{term} -> spline term') + term = re.sub(r'(s\()|(S\()|\)', '', term) + term_list += s(df.columns.tolist().index(term)) + else: + print(f'{term} -> assumed spline term') + term_list += s(df.columns.tolist().index(term)) + return term_list + + # copy docs for minimal_, class_ in zip([l, s, f, te], [LinearTerm, SplineTerm, FactorTerm, TensorTerm]): minimal_.__doc__ = class_.__init__.__doc__ + minimal_.__doc__ -TERMS = {'term' : Term, - 'intercept_term' : Intercept, +TERMS = {'term': Term, + 'intercept_term': Intercept, 'linear_term': LinearTerm, 'spline_term': SplineTerm, 'factor_term': FactorTerm, From 833a3b5ccc3180442677a1d236f6fad532a1fbad Mon Sep 17 00:00:00 2001 From: arose13 Date: Wed, 10 Jul 2019 11:06:59 -0400 Subject: [PATCH 2/5] trying to pass CI tests --- pygam/terms.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/pygam/terms.py b/pygam/terms.py index dc279086..6288631b 100644 --- a/pygam/terms.py +++ b/pygam/terms.py @@ -1821,13 +1821,14 @@ def te(*args, **kwargs): intercept = Intercept() -def from_formula(formula, df, coerce=True) -> TermList: +def from_formula(formula, df, coerce=True, verbose=False) -> TermList: """ Pass a (patsy / R like) formula and data frame and returns a terms object that matches If only a name is given a spline is assumed :param formula: :param df: :param coerce: Whether to try to convert any invalid characters in the dataframe's column names to underscores `_` + :param verbose: Whether to generate outputs about the processing :return: """ import re @@ -1843,7 +1844,9 @@ def regex_contains(pattern, string): are_bad_cols = [bool(set(invalid_chars).intersection(set(col_name))) for col_name in df.columns] if any(are_bad_cols) and coerce is False: raise AssertionError( - f'`df` columns names cannot have {invalid_chars} in their names. Try setting `coerce=True`' + '`df` columns names cannot have {invalid_chars} in their names. Try setting `coerce=True`'.format( + invalid_chars=invalid_chars + ) ) elif any(are_bad_cols) and coerce: # I know this can be optimised since I know where the bad cols are @@ -1856,8 +1859,9 @@ def regex_contains(pattern, string): target_name, terms = formula.split('~') target_name, terms = target_name.strip(), [term.strip() for term in terms.split('+')] - print(f'target name: {target_name}') - print(terms) + if verbose: + print('target name: {target_name}'.format(target_name=target_name)) + print(terms) if len(terms) == 0: AssertionError(f'Check input formula {formula}') @@ -1878,19 +1882,23 @@ def regex_contains(pattern, string): term_list = intercept for term in terms: # type: str if regex_contains(linear_term_pattern, term): - print(f'{term} -> linear term') + if verbose: + print('{} -> linear term'.format(term)) term = re.sub(r'(l\()|(L\()|\)', '', term) term_list += l(df.columns.tolist().index(term)) elif regex_contains(factor_term_pattern, term): - print(f'{term} -> factor term') + if verbose: + print('{} -> factor term'.format(term)) term = re.sub(r'(c\()|(C\()|\)', '', term) term_list += f(df.columns.tolist().index(term)) elif regex_contains(spline_term_pattern, term): - print(f'{term} -> spline term') + if verbose: + print('{} -> spline term'.format(term)) term = re.sub(r'(s\()|(S\()|\)', '', term) term_list += s(df.columns.tolist().index(term)) else: - print(f'{term} -> assumed spline term') + if verbose: + print('{} -> assumed spline term'.format(term)) term_list += s(df.columns.tolist().index(term)) return term_list From 0f555c28836d3d0e2cfcf5f810de810717c9ecb2 Mon Sep 17 00:00:00 2001 From: arose13 Date: Wed, 10 Jul 2019 11:22:53 -0400 Subject: [PATCH 3/5] removed Python3 only return type hint --- pygam/terms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pygam/terms.py b/pygam/terms.py index 6288631b..966f06ae 100644 --- a/pygam/terms.py +++ b/pygam/terms.py @@ -1821,7 +1821,7 @@ def te(*args, **kwargs): intercept = Intercept() -def from_formula(formula, df, coerce=True, verbose=False) -> TermList: +def from_formula(formula, df, coerce=True, verbose=False): """ Pass a (patsy / R like) formula and data frame and returns a terms object that matches If only a name is given a spline is assumed From 99bc3a28f1720eb02190c5f38e2cdbe9487d75ec Mon Sep 17 00:00:00 2001 From: arose13 Date: Wed, 10 Jul 2019 11:33:40 -0400 Subject: [PATCH 4/5] removed Python3 f strings --- pygam/terms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pygam/terms.py b/pygam/terms.py index 966f06ae..c1c40e40 100644 --- a/pygam/terms.py +++ b/pygam/terms.py @@ -1860,11 +1860,11 @@ def regex_contains(pattern, string): target_name, terms = formula.split('~') target_name, terms = target_name.strip(), [term.strip() for term in terms.split('+')] if verbose: - print('target name: {target_name}'.format(target_name=target_name)) + print('target name: {}'.format(target_name)) print(terms) if len(terms) == 0: - AssertionError(f'Check input formula {formula}') + AssertionError('Check input formula {}'.format(formula)) # Check for the simplest of all possible formulas. Early terminate here. linear_term_pattern = r'l\(.*?\)|L\(.*?\)' From 680294b07998c56bb94a398f3054957b2d8b6411 Mon Sep 17 00:00:00 2001 From: arose13 Date: Wed, 10 Jul 2019 12:36:03 -0400 Subject: [PATCH 5/5] New tests for the from_formula function --- pygam/terms.py | 5 +++-- pygam/tests/test_terms.py | 24 ++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/pygam/terms.py b/pygam/terms.py index c1c40e40..9b873ffa 100644 --- a/pygam/terms.py +++ b/pygam/terms.py @@ -1863,8 +1863,9 @@ def regex_contains(pattern, string): print('target name: {}'.format(target_name)) print(terms) - if len(terms) == 0: - AssertionError('Check input formula {}'.format(formula)) + if len(terms) == 0 or (len(terms) == 1 and next(iter(terms), '') == ''): + # Bad formula + raise AssertionError('Check input formula {}'.format(formula)) # Check for the simplest of all possible formulas. Early terminate here. linear_term_pattern = r'l\(.*?\)|L\(.*?\)' diff --git a/pygam/tests/test_terms.py b/pygam/tests/test_terms.py index e4a5c5a4..632d5a5e 100644 --- a/pygam/tests/test_terms.py +++ b/pygam/tests/test_terms.py @@ -3,6 +3,7 @@ from copy import deepcopy import numpy as np +import pandas as pd import pytest from pygam import * @@ -15,6 +16,29 @@ def chicago_gam(chicago_X_y): gam = PoissonGAM(terms=s(0, n_splines=200) + te(3, 1) + s(2)).fit(X, y) return gam +def test_from_formula_bad_formula(): + """Formulas must look like patsy formulas + """ + dummy_df = pd.DataFrame(columns=['SystolicBP', 'Smoke', 'Overwt']) + + for formula_i in ['Smoke + Overwt', 'SystolicBP ~']: + with pytest.raises(AssertionError): + from_formula(formula_i, dummy_df) + + assert from_formula('SystolicBP ~ Smoke + l(Overwt)', dummy_df) is not None + + +def test_from_formula_bad_cols_names(): + """Make sure all bad columns are either detected and properly coerced + """ + bad_df = pd.DataFrame(columns=['Systolic-BP', 'is_smoker', 'Over+wt']) + + with pytest.raises(AssertionError): + from_formula('Systolic-BP ~ *', bad_df, coerce=False) + + assert from_formula('Systolic_BP ~ *', bad_df) + + def test_wrong_length(): """iterable params must all match lengths """