Source code for synospec.util.util

"""
Miscellaneous package utilities.

.. include:: ../include/links.rst
"""

from itertools import chain, combinations

from IPython import embed 

import numpy


def all_subclasses(cls):
    """
    Collect all the subclasses of the provided class.

    The search follows the inheritance tree down from ``cls`` through
    every level of derived classes. Intermediate base classes are
    included in the returned set, but ``cls`` itself is not.

    Thanks to:
    https://stackoverflow.com/questions/3862310/how-to-find-all-the-subclasses-of-a-class-given-its-name

    Args:
        cls (object):
            The base class.

    Returns:
        :obj:`set`: The unique set of derived classes, including any
        intermediate base classes in the inheritance thread.
    """
    found = set()
    # Explicit worklist instead of recursion: each class is expanded only
    # once, even if it is reachable via several inheritance paths
    # (diamond-shaped hierarchies).
    pending = list(cls.__subclasses__())
    while pending:
        sub = pending.pop()
        if sub in found:
            continue
        found.add(sub)
        pending.extend(sub.__subclasses__())
    return found
def string_table(tbl, delimeter='print', has_header=True):
    """
    Format a 2D table of strings with equally spaced columns.

    Every column is left-justified and padded to the width of its
    longest entry. Depending on ``delimeter``, the result is either a
    plain table (optionally with a line of hyphens under the first row)
    or an ``rst`` simple table bounded by lines of ``=``.

    Args:
        tbl (`numpy.ndarray`_):
            2D array (or array-like, e.g. a list of equal-length lists)
            of string representations of the data to print.
        delimeter (:obj:`str`, optional):
            Table style. Use ``'print'`` for a plain table where, if
            ``has_header`` is True, a simple line of hyphens separates
            the first row from the column data. Anything else results
            in ``rst`` style table formatting. (The parameter name
            retains its original spelling for backward compatibility.)
        has_header (:obj:`bool`, optional):
            The first row in ``tbl`` contains the column headers.

    Returns:
        :obj:`str`: Single long string with the data table; each table
        line, including the last, is terminated by a newline.

    Raises:
        ValueError:
            Raised if ``tbl`` is not two-dimensional or has no rows.
    """
    tbl = numpy.asarray(tbl)
    if tbl.ndim != 2 or tbl.shape[0] == 0:
        raise ValueError('Table must be a 2D array with at least one row.')

    # Width of each column is set by its longest entry.
    col_width = [max(len(cell) for cell in column) for column in tbl.T]

    def _format_row(cells):
        # Left-justify each cell to its column width.
        return ' '.join(cell.ljust(width) for cell, width in zip(cells, col_width))

    rows = [_format_row(row) for row in tbl]

    if delimeter == 'print':
        if has_header:
            # Line of hyphens spanning the full width of the header row.
            rows.insert(1, '-'*len(rows[0]))
        return '\n'.join(rows) + '\n'

    # For an rst table: the border opens and closes the table, and is
    # repeated under the first row when that row is a header.
    border = ' '.join('='*width for width in col_width)
    lines = [border, rows[0]]
    if has_header:
        lines.append(border)
    lines.extend(rows[1:])
    lines.append(border)
    return '\n'.join(lines) + '\n'