diff --git a/fred/api.py b/fred/api.py index 7a0af5a..ecd9ffe 100644 --- a/fred/api.py +++ b/fred/api.py @@ -15,68 +15,68 @@ def key(api_key): # Category ##################### -def category(**kwargs): +def category(session=None, **kwargs): """Get a category.""" if 'series' in kwargs: kwargs.pop('series') path = 'series' else: path = None - return Fred().category(path, **kwargs) + return Fred(session=session).category(path, **kwargs) -def categories(identifier, **kwargs): +def categories(identifier, session=None, **kwargs): """Just in case someone misspells the method.""" kwargs['category_id'] = identifier - return category(**kwargs) + return category(session=session, **kwargs) -def children(category_id=None, **kwargs): +def children(category_id=None, session=None, **kwargs): """Get child categories for a specified parent category.""" kwargs['category_id'] = category_id - return Fred().category('children', **kwargs) + return Fred(session=session).category('children', **kwargs) -def related(identifier, **kwargs): +def related(identifier, session=None, **kwargs): """Get related categories for a specified category.""" kwargs['category_id'] = identifier - return Fred().category('related', **kwargs) + return Fred(session=session).category('related', **kwargs) -def category_series(identifier, **kwargs): +def category_series(identifier, session=None, **kwargs): """Get the series in a category.""" kwargs['category_id'] = identifier - return Fred().category('series', **kwargs) + return Fred(session=session).category('series', **kwargs) ##################### # Releases ##################### -def release(release_id, **kwargs): +def release(release_id, session=None, **kwargs): """Get the release of economic data.""" kwargs['release_id'] = release_id - return Fred().release(**kwargs) + return Fred(session=session).release(**kwargs) -def releases(release_id=None, **kwargs): +def releases(release_id=None, session=None, **kwargs): """Get all releases of economic data.""" if not 'id' in kwargs and release_id is not None: kwargs['release_id'] = release_id return Fred().release(**kwargs) - return Fred().releases(**kwargs) + return Fred(session=session).releases(**kwargs) -def dates(**kwargs): +def dates(session=None, **kwargs): """Get release dates for economic data.""" - return Fred().releases('dates', **kwargs) + return Fred(session=session).releases('dates', **kwargs) ##################### # Series ##################### -def series(identifier=None, **kwargs): +def series(identifier=None, session=None, **kwargs): """Get an economic data series.""" if identifier: kwargs['series_id'] = identifier @@ -88,40 +88,40 @@ def series(identifier=None, **kwargs): path = 'release' else: path = None - return Fred().series(path, **kwargs) + return Fred(session=session).series(path, **kwargs) -def observations(identifier, **kwargs): +def observations(identifier, session=None, **kwargs): """Get an economic data series.""" kwargs['series_id'] = identifier - return Fred().series('observations', **kwargs) + return Fred(session=session).series('observations', **kwargs) -def search(text, **kwargs): +def search(text, session=None, **kwargs): """Get economic data series that match keywords.""" kwargs['search_text'] = text - return Fred().series('search', **kwargs) + return Fred(session=session).series('search', **kwargs) -def updates(**kwargs): +def updates(session=None, **kwargs): """Get economic data series sorted in descending order.""" - return Fred().series('updates', **kwargs) + return Fred(session=session).series('updates', **kwargs) -def vintage(identifier, **kwargs): +def vintage(identifier, session=None, **kwargs): """ Get the dates in history when a series' data values were revised or new data values were released. """ kwargs['series_id'] = identifier - return Fred().series('vintagedates', **kwargs) + return Fred(session=session).series('vintagedates', **kwargs) ##################### # Sources ##################### -def source(source_id=None, **kwargs): +def source(source_id=None, session=None, **kwargs): """Get a source of economic data.""" if source_id is not None: kwargs['source_id'] = source_id @@ -133,11 +133,11 @@ def source(source_id=None, **kwargs): path = 'releases' else: path = None - return Fred().source(path, **kwargs) + return Fred(session=session).source(path, **kwargs) -def sources(source_id=None, **kwargs): +def sources(source_id=None, session=None, **kwargs): """Get the sources of economic data.""" if source_id or 'id' in kwargs: return source(source_id, **kwargs) - return Fred().sources(**kwargs) + return Fred(session=session).sources(**kwargs) diff --git a/fred/core.py b/fred/core.py index b6a23e8..bf043b3 100644 --- a/fred/core.py +++ b/fred/core.py @@ -21,14 +21,20 @@ class Fred(object): """An easy-to-use Python wrapper over the St. Louis FRED API.""" - def __init__(self, api_key='', xml_output=False): + def __init__(self, api_key='', xml_output=False, session=None): if 'FRED_API_KEY' in os.environ: self.api_key = os.environ['FRED_API_KEY'] else: self.api_key = api_key + self.session = self._init_session(session) self.xml = xml_output self.endpoint = 'https://api.stlouisfed.org/fred/' + def _init_session(self, session): + if session is None: + session = requests.Session() + return session + def _create_path(self, *args): """Create the URL path with the Fred endpoint and given arguments.""" args = filter(None, args) @@ -40,7 +46,7 @@ def get(self, *args, **kwargs): location = args[0] params = self._get_keywords(location, kwargs) url = self._create_path(*args) - request = requests.get(url, params=params) + request = self.session.get(url, params=params) content = request.content self._request = request return self._output(content)