diff --git a/.flake8 b/.flake8 index ab31112..8f9ea55 100644 --- a/.flake8 +++ b/.flake8 @@ -9,10 +9,10 @@ ignore = W504, # missing whitespace around arithmetic operator E226, - # Import sorting - I201 - I100 exclude= .git, + .github, + .venv, venv, + tests/test_data, diff --git a/README.md b/README.md index 3d61d4f..3d66263 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ [![codecov](https://codecov.io/gh/CardiacModelling/syncropatch_export/graph/badge.svg?token=HOL0FrpGqs)](https://codecov.io/gh/CardiacModelling/syncropatch_export) This repository contains a python package and scripts for processing data outputted from Nanion SynroPatch 384. + With this package you can export each sweep of each protocol for each well as individual files (.csv). Meta-data describing the protocol, and variables such as membrance capacitance (Cm), Rseries and Rseal can be exported. @@ -41,3 +42,39 @@ Then you can run the tests. ``` python3 -m unittest ``` + +## Usage example + +...TODO + + +## Development + +Commits should be merged in via pull requests. + +Tests are written using the standard [unittest](https://docs.python.org/3.13/library/unittest.html) framework. + +Online testing, style-checking, and coverage testing is set up using GitHub actions. +Coverage testing is handled via [Codecov](https://about.codecov.io/). + +Style testing is done with `flake8`. For example, to test with 4 subprocesses use +``` +flake8 -j4 +``` +Import sorting can be checked with `isort`: +``` +isort --verbose --check-only --diff syncropatch_export tests setup.py +``` + +Documentation is implemented using [Sphinx](https://www.sphinx-doc.org/). +To compile locally, first install the required dependencies +``` +pip install -e .'[docs]' +``` +and then use Make +``` +cd docs +make html +``` + + diff --git a/docs/.gitignore b/docs/.gitignore new file mode 100644 index 0000000..a007fea --- /dev/null +++ b/docs/.gitignore @@ -0,0 +1 @@ +build/* diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..e9bd66d --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +SPHINXPROJ = syncropatch_export +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..481a7ec --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,36 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build +set SPHINXPROJ=syncropatch_export + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% + +:end +popd diff --git a/docs/source/_static/placeholder b/docs/source/_static/placeholder new file mode 100644 index 0000000..8566aa9 --- /dev/null +++ b/docs/source/_static/placeholder @@ -0,0 +1 @@ +Images etc. can be placed here diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 0000000..d708f16 --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,163 @@ +# -*- coding: utf-8 -*- +# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import syncropatch_export + + +# -- General configuration ---------------------------------------------------- + +# If your documentation needs a minimal Sphinx version, state it here. +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.doctest', + 'sphinx.ext.mathjax', + 'sphinx.ext.napoleon', +] + +# Autodoc defaults +autodoc_default_options = { + 'members': None, + # 'inherited-members': None, +} + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +# source_suffix = ['.rst', '.md'] +source_suffix = '.rst' + +# The master toctree document. +master_doc = 'index' + +# General information about the project. +project = u'Syncropatch export' +# copyright = syncropatch_export.COPYRIGHT + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +version = syncropatch_export.__version__ +# The full version, including alpha/beta/rc tags. +release = syncropatch_export.__version__ + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = 'en' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This patterns also effect to html_static_path and html_extra_path +exclude_patterns = [] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'default' + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = False + + +# -- Options for HTML output ---------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +html_theme = 'alabaster' + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +# html_theme_options = {} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +# Add any paths that contain custom themes here, relative to this directory. +html_theme_path = ['_templates'] + +# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. +html_show_sphinx = False + +# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. +html_show_copyright = False + + +# -- Options for HTMLHelp output ------------------------------------------ + +# Output file base name for HTML help builder. +htmlhelp_basename = 'SyncropatchExpertDoc' + + +# -- Options for LaTeX output --------------------------------------------- + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, author, documentclass[howto/manual]). +latex_documents = [( + 'index', 'syncropatch_export.tex', u'Syncropatch Export Documentation', + u'Mixed', 'manual' +)] + + +# -- Options for manual page output --------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [( + 'index', 'syncropatch_export', u'Syncropatch Export Documentation', + [u'Syncropatch Export Team'], 1 +)] + +# -- Options for Texinfo output ------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + (master_doc, 'syncropatch_export', u'Syncropatch Export Documentation', + 'Syncropatch Export Team', 'Syncropatch Export', + 'Exports Nanion Syncropatch data to CSV.', + 'Miscellaneous'), +] + diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 0000000..f673e5c --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,15 @@ +***************** +Table of contents +***************** + +.. module:: syncropatch_export + +This module contains methods to export data from the Nanion SynroPatch 384. + +.. toctree:: + :maxdepth: 2 + + self + trace + voltage_protocols + diff --git a/docs/source/trace.rst b/docs/source/trace.rst new file mode 100644 index 0000000..083b087 --- /dev/null +++ b/docs/source/trace.rst @@ -0,0 +1,10 @@ +.. currentmodule:: syncropatch_export.trace + +************** +Accessing data +************** + +Data is accessed and exported via the :class:`Trace` class. + +.. autoclass:: Trace + diff --git a/docs/source/voltage_protocols.rst b/docs/source/voltage_protocols.rst new file mode 100644 index 0000000..0bab766 --- /dev/null +++ b/docs/source/voltage_protocols.rst @@ -0,0 +1,11 @@ +.. currentmodule:: syncropatch_export.voltage_protocols + +*************************** +Accessing voltage protocols +*************************** + +Voltage protocols can be parsed from the JSON data using the +:class:`VoltageProtocol` class. + +.. autoclass:: VoltageProtocol + diff --git a/setup.py b/setup.py index 1d1727b..af05e42 100644 --- a/setup.py +++ b/setup.py @@ -4,14 +4,17 @@ with open('README.md') as f: readme = f.read() -# Load version number -# with open('version.txt', 'r') as f: -# version = f.read() -# Go! +# Load version number +import os # isort:skip +import sys # isort:skip +sys.path.append(os.path.abspath('syncropatch_export')) +from _version import __version__ as version # noqa isort:skip +sys.path.pop() +del os, sys -version = '0.0.1' +# Go! setup( # Module name (lowercase) name='syncropatch_export', @@ -57,5 +60,8 @@ 'mock>=3.0.5', # For mocking command line args etc. 'codecov>=2.1.3', ], + 'docs': [ + 'sphinx>=1.7.4', + ], }, ) diff --git a/syncropatch_export/__init__.py b/syncropatch_export/__init__.py index e69de29..53ae12e 100644 --- a/syncropatch_export/__init__.py +++ b/syncropatch_export/__init__.py @@ -0,0 +1,5 @@ +# +# Syncropatch, main module +# +from ._version import __version__ # noqa + diff --git a/syncropatch_export/_version.py b/syncropatch_export/_version.py new file mode 100644 index 0000000..106c18b --- /dev/null +++ b/syncropatch_export/_version.py @@ -0,0 +1,4 @@ +# +# Syncropatch_export version number +# +__version__ = '0.0.1' diff --git a/syncropatch_export/trace.py b/syncropatch_export/trace.py index 380c8e9..cb144a6 100644 --- a/syncropatch_export/trace.py +++ b/syncropatch_export/trace.py @@ -9,20 +9,33 @@ class Trace: - """ Defines a Trace object from the output of a Nanion experiment. + """ + Reads a Nanion experiment and provides access to the data and meta data it + contains. + + To create a :class:`Trace`, a directory name should be passed in, along + with the name of a JSON file within that directory containing the meta + data. Data can then be accessed using :meth:`get_all_traces` (to obtain a + ``dict`` mapping well names onto a 2-d numpy array containing the sampled + currents (in pA) for all sweeps. + + Well are + + - @params - filepath: path pointing to folder containing .json and .dat files (str) - json_file: specific filename of json file (str) + Args: + filepath (str): A path pointing to folder containing both ``.json`` and + ``.dat`` files. + json_file (str): The name of a JSON file within ``path``, from which + meta data will be read. """ def __init__(self, filepath, json_file: str): # store file paths self.filepath = filepath - if json_file[-5:] == '.json': - self.json_file = json_file - else: - self.json_file = json_file + ".json" + self.json_file = json_file + if not json_file.endswith('.json'): + self.json_file += '.json' # load json file with open(os.path.join(self.filepath, self.json_file)) as f: @@ -41,6 +54,13 @@ def __init__(self, filepath, json_file: str): self.MeasurementLayout = TraceHeader['MeasurementLayout'] self.FileInformation = TraceHeader['FileInformation'] + # Create (hardcoded) list-of-list of well names: + # [['A01', 'B01', ..., 'P01'], + # ['A02', 'B02', ..., 'P02'], + # ... + # ['A24', 'B24', ..., 'P24']] + # So a list of 24 lists with 16 entries each + # self.WELL_ID = np.array([ [lab + str(i).zfill(2) for lab in string.ascii_uppercase[:16]] for i in range(1, 25)]) @@ -59,64 +79,77 @@ def __init__(self, filepath, json_file: str): self.voltage_protocol = self.get_voltage_protocol() - def get_voltage_protocol(self, holding_potential=-80.0): - """Extract information about the voltage protocol from the json file - - returns: a VoltageProtocol object - + def get_voltage_protocol(self): """ + Extract information about the voltage protocol from the JSON file. - voltage_protocol = VoltageProtocol.from_json( + Returns: + The :class:`VoltageProtocol` used in this experiment. + """ + return VoltageProtocol.from_json( self.meta['ExperimentConditions']['VoltageProtocol'], self.meta['ExperimentConditions']['VMembrane_mV'] ) - return voltage_protocol - def get_voltage_protocol_json(self): """ - Returns the voltage protocol as a JSON object + Returns an unparsed JSON object representing the first segment of the + voltage protocol. """ + # TODO Why only the first row? return self.meta['ExperimentConditions']['VoltageProtocol'][0] - def get_protocol_description(self, holding_potential=-80.0): - """Get the protocol as a numpy array describing the voltages and - durations for each section + def get_protocol_description(self): + """ + Returns the protocol as an ``np.numpy`` with an entry for each segment. - returns: np.array where each row contains the start time, end time, - initial voltage, and final voltage + Returns: + A numpy ``array`` where each row contains the start time, end time, + initial voltage, and final voltage of a ramp or step segment. """ return self.get_voltage_protocol().get_all_sections() def get_voltage(self): - ''' - Returns the voltage stimulus from Nanion .json file - ''' - return np.array(self.TimeScaling['Stimulus']).astype(np.float64)\ - * 1e3 + """ + Returns an array containing voltages (in mV) for each sampled point in + the traces. + """ + return np.array(self.TimeScaling['Stimulus']).astype(np.float64) * 1e3 def get_times(self): - ''' - Returns the time steps from Nanion .json file - ''' + """ + Returns the sampled times (in ms). + """ return np.array(self.TimeScaling['TR_Time']) * 1e3 def get_all_traces(self, leakcorrect=False): - ''' + """ + Returns data for all wells and all sweeps (equivalent to calling + :meth:`get_trace_sweeps()` without any arguments). + + Current is returned in pA. + + By default, the data is returned without leak correction, but the leak + corrected data can be obtained by setting ``leakcorrect=True``. - Params: - leakcorrect: Bool. Set to true if using onboard leak correction + Args: + sweeps (int): The number of sweeps to return. + leakcorrect (bool): Used to choose corrected or uncorrected data. - Returns: all raw current traces from .dat files + Returns: + A dictionary mapping well names (e.g. "A01") onto 2-d numpy arrays + of shape ``n_sweeps, n_times`` where ``n_sweeps`` is the number of + sweeps and ``n_times`` is the number of sampled points. - ''' + """ return self.get_trace_sweeps(leakcorrect=leakcorrect) def get_trace_file(self, sweeps): - ''' - Returns the trace file index of the file for a given set of sweeps - ''' + """ + Returns the trace file index + of the file for a given set of sweeps + """ OUT_file_idx = [] OUT_idx_i = [] for actSweep in sweeps: @@ -133,9 +166,24 @@ def get_trace_file(self, sweeps): return OUT_file_idx, OUT_idx_i def get_trace_sweeps(self, sweeps=None, leakcorrect=False): - ''' - Returns a subset of sweeps defined by the input 'sweeps' - ''' + """ + Returns the first ``sweeps`` sweeps, for all wells. + + Current is returned in pA. + + By default, the data is returned without leak correction, but the leak + corrected data can be obtained by setting ``leakcorrect=True``. + + Args: + sweeps (list): A list of sweep indexes to return, e.g. ``[0, 1, 2]``. + leakcorrect (bool): Used to choose corrected or uncorrected data. + + Returns: + A dictionary mapping well names (e.g. "A01") onto 2-d numpy arrays + of shape ``n_sweeps, n_times`` where ``n_sweeps`` is the number of + sweeps and ``n_times`` is the number of sampled points. + + """ # initialise output out_dict = {} @@ -143,18 +191,22 @@ def get_trace_sweeps(self, sweeps=None, leakcorrect=False): for ijWell in iCol: out_dict[ijWell] = [] + # No sweeps selected? Then return full set if sweeps is None: - #  Sometimes NofSweeps seems to be incorrect sweeps = list(range(self.NofSweeps)) - - # check `getsweep` input is something sensible - if len(sweeps) > self.NofSweeps: - raise ValueError('Required #sweeps > total #sweeps.') - - # convert negative values to positive - for i, sweep in enumerate(sweeps): - if sweep < 0: - sweeps[i] = self.NofSweeps + sweep + else: + # Allow negative values to index later sweeps + sweeps = [self.NofSweeps + x if x < 0 else x for x in sweeps] + # Check all sweeps exist + if max(sweeps) >= self.NofSweeps: + raise ValueError( + f'Invalid sweep selection: sweep {max(sweeps)} requested,' + f' but only {self.NofSweeps} available.') + if min(sweeps) < 0: + raise ValueError( + f'Invalid sweep selection: sweep' + f' {min(sweeps) - self.NofSweeps} requested, but only' + f' {self.NofSweeps} available.') trace_file_idxs, idx_is = self.get_trace_file(sweeps) @@ -188,9 +240,7 @@ def get_trace_sweeps(self, sweeps=None, leakcorrect=False): # convert to double in pA iColTraces = trace[idx_i:idx_f] * self.I2DScale[i] * 1e12 - iColWells = self.WELL_ID[i] - - for j, ijWell in enumerate(iColWells): + for j, ijWell in enumerate(self.WELL_ID[i]): if leakcorrect: leakoffset = 1 else: @@ -221,13 +271,15 @@ def get_trace_sweeps(self, sweeps=None, leakcorrect=False): return out_dict def get_onboard_QC_values(self, sweeps=None): - '''Read quality control values Rseal, Cslow (Cm), and Rseries from a Nanion .json file + """ + Return the quality control values Rseal, Cslow (Cm), and Rseries. - returns: A dictionary where the keys are the well e.g. 'A01' and the - values are the values used for onboard QC i.e., the seal resistance, - cell capacitance and the series resistance. + Returns: + A dict mapping well names ('A01' up to 'P24') to tuples + ``(R_seal, Cm, R_series)`` containing the seal resistance, membrane + capacitance, and series resistance. - ''' + """ # load QC values RSeal = np.array(self.meta['QCData']['RSeal']) @@ -262,11 +314,12 @@ def get_onboard_QC_values(self, sweeps=None): return out_dict def get_onboard_QC_df(self, sweeps=None): - """Create a Pandas DataFrame which lists the Rseries, memebrane - capacitance and Rseries for each well and sweep. + """ + Create a Pandas DataFrame containing the seal resistance, membrane + capacitance, and series resistance for each well and sweep. - @Returns A pandas.DataFrame describing the onboard QC estimates for - each well, sweep + Returns: + A ``pandas.DataFrame`` with the onboard QC estimates. """ @@ -288,3 +341,4 @@ def get_onboard_QC_df(self, sweeps=None): df_rows.append(df_row) return pd.DataFrame.from_records(df_rows) + diff --git a/syncropatch_export/voltage_protocols.py b/syncropatch_export/voltage_protocols.py index bdb0265..fe6c59f 100644 --- a/syncropatch_export/voltage_protocols.py +++ b/syncropatch_export/voltage_protocols.py @@ -1,83 +1,112 @@ import numpy as np -class VoltageProtocol(): - def from_json(json_protocol, holding_potential): - """ Converts a protocol (from the json file) into a np.array +class VoltageProtocol: + """ + Represent a voltage step and ramp protocol. + Each protocol is represented as + + 1. A list of segment starts, ends, initial voltages, and final voltages + 2. A holding potential + + To create a :class:`VoltageProtocol`, use either + :meth:`VoltageProtocol.from_json` or + `meth:`VoltageProtocol.from_voltage_trace`. + """ + + def __init__(self, desc, holding_potential, copy_data=True): + self._desc = np.copy(desc) if copy_data else desc + self.holding_potential = holding_potential + + @classmethod + def from_json(cls, json_protocol, holding_potential): """ + Reads a protocol from a JSON file. + + Args: + json_protocol (list): A list or other sequence containing the + ``VoltageProtocol`` section from the JSON file. + holding_potential (float): The holding potential + """ output_sections = [] for section in json_protocol: tstart = float(section['SegmentStart_ms']) tdur = float(section['Duration ms']) vstart = float(section['VoltageStart']) vend = float(section['VoltageEnd']) + output_sections.append((tstart, tstart + tdur, vstart, vend)) + return cls(np.array(output_sections), holding_potential, False) - output_sections.append(np.array((tstart, tstart + tdur, - vstart, vend))) - - return VoltageProtocol(np.array(output_sections), - holding_potential=holding_potential) - - def from_voltage_trace(voltage_trace, times, holding_potential=-80.0): + @classmethod + def from_voltage_trace(cls, voltage_trace, times, holding_potential=-80.0): + """ + Creates an approximate voltage protocol from a time series ``(times, + voltage_trace)``. + """ threshold = 1e-3 # Find gradient changes diff2 = np.abs(np.diff(voltage_trace, n=2)) - - windows = np.argwhere(diff2 > threshold).flatten() - window_locs = np.unique(windows) - window_locs = np.array([val for val in window_locs if val + 1 - not in window_locs]) + 1 - - windows = zip([0] + list(window_locs), list(window_locs) - + [len(voltage_trace) - 1]) + window_locs = np.unique(np.argwhere(diff2 > threshold).flatten()) + window_locs = 1 + np.array([ + val for val in window_locs if val + 1 not in window_locs]) + windows = zip([0] + list(window_locs), + list(window_locs) + [len(voltage_trace) - 1]) lst = [] for start, end in windows: start_t = times[start] end_t = times[end] - - ramp = voltage_trace[end - 1] != voltage_trace[start] - v_start = voltage_trace[start] - - if ramp: + if voltage_trace[end - 1] != voltage_trace[start]: + # Ramp grad = (voltage_trace[end - 1] - voltage_trace[start]) / \ (times[end - 1] - times[start]) v_end = v_start + grad * (end_t - start_t) else: + # Step v_end = voltage_trace[end - 1] lst.append(np.array([start_t, end_t, v_start, v_end])) - desc = np.vstack(lst) - return VoltageProtocol(desc, holding_potential) - - def __init__(self, desc, holding_potential): - self._desc = desc - self.holding_potential = holding_potential + return cls(np.vstack(lst), holding_potential, False) def get_holding_potential(self): + """ Returns this protocol's holding potential. """ return self.holding_potential def get_step_start_times(self): + """ Returns a list of all segment start times. """ return [line[0] for line in self._desc] def get_ramps(self): + """ + Returns all segments that are ramps. + + Each segment is represented as ``(start time, end time, start voltage, + end voltage)``. + """ return [line for line in self._desc if line[2] != line[3]] def get_all_sections(self): - """ Return a np.array describing the protocol. - - returns: an np.array where the ith row is the start-time, - end-time, start-voltage and end-voltage for the ith section of the protocol - + """ + Return an ``np.array`` describing the protocol, where each row in the + array contains the the start time, end time, initial voltage and final + voltage for a segment of the protocol. """ return np.array(self._desc) def export_txt(self, fname): + """ + Writes a partial textual representation of this protocol to a file. + + The created file will have a header line, followed by one line per + segment. Segments are represented as "Type" (Set or Ramp), "Voltage" + (the final voltage of a segment), and "Duration". + """ + output_lines = ['Type \t Voltage \t Duration'] desc = self.get_all_sections() @@ -93,7 +122,7 @@ def export_txt(self, fname): if round: vend = np.round(vend) - output_lines.append(f"{_type}\t{vend}\t{dur}") + output_lines.append(f'{_type}\t{vend}\t{dur}') with open(fname, 'w') as fout: for line in output_lines: diff --git a/tests/test_trace_class.py b/tests/test_trace_class.py old mode 100644 new mode 100755 index 0b4901b..bee8289 --- a/tests/test_trace_class.py +++ b/tests/test_trace_class.py @@ -1,32 +1,31 @@ +#!/usr/bin/env python import json import os +import tempfile import unittest -import matplotlib.pyplot as plt import numpy as np import pandas as pd -from syncropatch_export.trace import Trace as tr +from syncropatch_export.trace import Trace from syncropatch_export.voltage_protocols import VoltageProtocol class TestTraceClass(unittest.TestCase): - def setUp(self): - filepath = os.path.join('tests', 'test_data', '13112023_MW2_FF', - 'staircaseramp (2)_2kHz_15.01.07') - json_file = "staircaseramp (2)_2kHz_15.01.07" - - self.output_dir = os.path.join('test_output', 'test_trace_class') + """ + Tests both the Trace and VoltageProtocol classes. + """ - if not os.path.exists(self.output_dir): - os.makedirs(self.output_dir) - self.test_trace = tr(filepath, json_file) + def setUp(self): + f = 'staircaseramp (2)_2kHz_15.01.07' + self.trace = Trace( + os.path.join('tests', 'test_data', '13112023_MW2_FF', f), f) def test_protocol_descriptions(self): - voltages = self.test_trace.get_voltage() - times = self.test_trace.get_times() + voltages = self.trace.get_voltage() + times = self.trace.get_times() - protocol_from_json = self.test_trace.get_voltage_protocol() + protocol_from_json = self.trace.get_voltage_protocol() holding_potential = protocol_from_json.get_holding_potential() protocol_desc = VoltageProtocol.from_voltage_trace(voltages, times, holding_potential) @@ -42,48 +41,65 @@ def test_protocol_descriptions(self): self.assertLess(t_error, 1e-2) self.assertLess(v_error, 1e-4) - def test_protocol_export(self): - protocol = self.test_trace.get_voltage_protocol() - protocol.export_txt(os.path.join(self.output_dir, 'protocol.txt')) - json_protocol = self.test_trace.get_voltage_protocol_json() + def test_get_protocol_description(self): + a = np.array(self.trace.get_protocol_description()) + b = np.array(self.trace.get_voltage_protocol().get_all_sections()) + self.assertEqual(a.shape, b.shape) + self.assertTrue(np.all(a == b)) - with open(os.path.join(self.output_dir, 'protocol.json'), 'w') as fin: - json.dump(json_protocol, fin) + def test_protocol_export(self): + with tempfile.TemporaryDirectory() as d: + protocol = self.trace.get_voltage_protocol() + protocol.export_txt(os.path.join(d, 'protocol.txt')) + json_protocol = self.trace.get_voltage_protocol_json() + with open(os.path.join(d, 'protocol.json'), 'w') as fin: + json.dump(json_protocol, fin) def test_protocol_timeseries(self): - voltages = self.test_trace.get_voltage() - times = self.test_trace.get_times() - - voltage_protocol = self.test_trace.get_voltage_protocol() + voltages = self.trace.get_voltage() + times = self.trace.get_times() + voltage_protocol = self.trace.get_voltage_protocol() def voltage_func(t): for tstart, tend, vstart, vend in voltage_protocol.get_all_sections(): if t >= tstart and t < tend: if vstart != vend: - return vstart + (vend - vstart) * (t - tstart)/(tend - tstart) + return vstart + (vend - vstart) * (t - tstart) / (tend - tstart) else: return vstart - - return voltage_protocol.get_holding_potential() + return voltage_protocol.get_holding_potential() # pragma: no cover for t, v in zip(times, voltages): self.assertLess(voltage_func(t) - v, 1e-3) + def test_protocol_get_step_start_times(self): + a = list(self.trace.get_voltage_protocol().get_step_start_times()) + b = [0, 250, 300, 696, 896, 1896, 2396, 3396, 3896, 4396, 4896, 5396, + 5896, 6396, 6896, 7396, 7896, 8396, 8896, 9396, 9896, 10396, + 10896, 11396, 11896, 12396, 12896, 13896, 14396, 14406, 14502, + 14892] + self.assertEqual(a, b) + + def test_protocol_get_ramps(self): + a = np.array(self.trace.get_voltage_protocol().get_ramps()) + b = np.array([[300, 696, -120, -80], [14406, 14502, -70, -110]]) + self.assertEqual(a.shape, b.shape) + self.assertTrue(np.all(a == b)) + def test_get_QC(self): - tr = self.test_trace - QC_values = tr.get_onboard_QC_values() + QC_values = self.trace.get_onboard_QC_values() self.assertGreater(len(QC_values), 0) - df = tr.get_onboard_QC_df() + df = self.trace.get_onboard_QC_df() self.assertGreater(df.shape[0], 0) self.assertGreater(df.shape[1], 0) def test_get_traces(self): - tr = self.test_trace - v = tr.get_voltage() - ts = tr.get_times() - all_traces = tr.get_all_traces(leakcorrect=True) - all_traces = tr.get_all_traces() + v = self.trace.get_voltage() + ts = self.trace.get_times() + all_traces = self.trace.get_all_traces(leakcorrect=True) + all_traces = self.trace.get_all_traces() + # TODO: Check the output, numerically, by comparing a few points self.assertTrue(np.all(np.isfinite(v))) self.assertTrue(np.all(np.isfinite(ts))) @@ -91,27 +107,45 @@ def test_get_traces(self): for well, trace in all_traces.items(): self.assertTrue(np.all(np.isfinite(trace))) - if self.output_dir: - # plot test output + # Test complex sweep selection + a = self.trace.get_trace_sweeps([-1, -2]) + b = self.trace.get_trace_sweeps([1, 0]) + self.assertEqual(len(a), len(b)) + self.assertTrue(np.all(a['A01'] == b['A01'])) + + # Test asking for non-existent sweeps + self.assertRaisesRegex(ValueError, 'Invalid sweep selection', + self.trace.get_trace_sweeps, [2]) + self.assertRaisesRegex(ValueError, 'Invalid sweep selection', + self.trace.get_trace_sweeps, [-3]) + + ''' + # plot test output + if False: + d = 'test_output' + if not os.path.exists(d): + os.makedirs(d) + + import matplotlib.pyplot as plt fig, (ax1, ax2) = plt.subplots(2, 1) - ax1.set_title("Example Sweeps") - some_sweeps = tr.get_trace_sweeps([0])['A01'] + ax1.set_title('Example Sweeps') + some_sweeps = self.trace.get_trace_sweeps([0])['A01'] ax1.plot(ts, np.transpose(some_sweeps), color='grey', alpha=0.5) ax1.set_ylabel('Current') ax1.set_xlabel('Time') - ax2.set_title("Voltage Protocol") + ax2.set_title('Voltage Protocol') ax2.plot(ts, v) ax2.set_ylabel('Voltage') ax2.set_xlabel('Time') plt.tight_layout() - plt.savefig(os.path.join(self.output_dir, - 'example_trace')) + plt.savefig(os.path.join(d, 'example_trace')) plt.close(fig) + ''' def test_qc_df(self): - dfs = [self.test_trace.get_onboard_QC_df(sweeps=[0]), - self.test_trace.get_onboard_QC_df(sweeps=None)] + dfs = [self.trace.get_onboard_QC_df(sweeps=[0]), + self.trace.get_onboard_QC_df(sweeps=None)] for res in dfs: # Check res is a pd.DataFrame self.assertIsInstance(res, pd.DataFrame) @@ -126,3 +160,6 @@ def test_qc_df(self): # Check restricting number of sweeps returns less data self.assertLess(dfs[0].shape[0], dfs[1].shape[0]) + +if __name__ == '__main__': + unittest.main() # pragma: no cover