# -*- coding: utf-8 -*-
"""
:class:`~pycroscopy.analysis.fitter.Fitter` - Abstract class that provides the
framework for building application-specific children classes
Created on Thu Aug 15 11:48:53 2019
@author: Suhas Somnath
"""
from __future__ import division, print_function, absolute_import, \
unicode_literals
import numpy as np
from warnings import warn
import joblib
from scipy.optimize import least_squares
from pyUSID.processing.comp_utils import recommend_cpu_cores
from pyUSID.processing.process import Process
from pyUSID.io.usi_data import USIDataset
# TODO: All reading, holding operations should use Dask arrays
class Fitter(Process):
    """
    Abstract class that provides the framework for building
    application-specific children classes that fit a parametric model to
    every position of a USID main dataset.
    """

    def __init__(self, h5_main, proc_name, variables=None, **kwargs):
        """
        Creates a new instance of the abstract Fitter class

        Parameters
        ----------
        h5_main : h5py.Dataset or pyUSID.io.USIDataset object
            Main dataset whose one or more dimensions will be reduced
        proc_name : str or unicode
            Name of the child process
        variables : str or list, optional
            List of spectroscopic dimension names that will be reduced
        h5_target_group : h5py.Group, optional. Default = None
            Location where to look for existing results and to place newly
            computed results. Use this kwarg if the results need to be written
            to a different HDF5 file. By default, this value is set to the
            parent group containing `h5_main`
        kwargs : dict
            Keyword arguments that will be passed on to
            pyUSID.processing.process.Process

        Raises
        ------
        TypeError
            If ``variables`` is neither a string nor a list / tuple of strings
        ValueError
            If any name in ``variables`` is not a spectroscopic dimension
            of ``h5_main``
        """
        super(Fitter, self).__init__(h5_main, proc_name, **kwargs)

        # Validate other arguments / kwargs here:
        if variables is not None:
            if isinstance(variables, str):
                variables = [variables]
            if not isinstance(variables, (list, tuple)):
                # BUG FIX: implicit string concatenation was missing a space
                # ("tupleof strings")
                raise TypeError('variables should be a string / list or '
                                'tuple of strings. Provided object was of '
                                'type: {}'.format(type(variables)))
            if not all([dim in self.h5_main.spec_dim_labels
                        for dim in variables]):
                # BUG FIX: format arguments were swapped - report the
                # requested dimensions first, then what is available
                raise ValueError('Provided dataset does not appear to have '
                                 'the spectroscopic dimension(s): {} that '
                                 'need to be fitted. Available dimensions: '
                                 '{}'.format(variables,
                                             self.h5_main.spec_dim_labels))

        # Variables specific to Fitter:
        self._guess = None      # in-memory chunk of guess results
        self._fit = None        # in-memory chunk of fit results
        self._is_guess = True   # True while computing Guess, False for Fit
        self._h5_guess = None   # HDF5 dataset holding the Guess
        self._h5_fit = None     # HDF5 dataset holding the Fit
        self.__set_up_called = False

        # Variables from Process - default to the Guess configuration:
        self.compute = self.set_up_guess
        self._unit_computation = super(Fitter, self)._unit_computation
        self._create_results_datasets = self._create_guess_datasets
        self._map_function = None
def _read_guess_chunk(self):
"""
Returns a chunk of guess dataset corresponding to the same pixels of
the main dataset.
"""
curr_pixels = self._get_pixels_in_current_batch()
self._guess = self._h5_guess[curr_pixels, :]
if self.verbose and self.mpi_rank == 0:
print('Guess of shape: {}'.format(self._guess.shape))
def _write_results_chunk(self):
"""
Writes the guess or fit results into appropriate HDF5 datasets.
"""
if self._is_guess:
targ_dset = self._h5_guess
source_dset = self._guess
else:
targ_dset = self._h5_fit
source_dset = self._fit
curr_pixels = self._get_pixels_in_current_batch()
if self.verbose and self.mpi_rank == 0:
print('Writing data of shape: {} and dtype: {} to position range: '
'{} in HDF5 dataset:{}'.format(source_dset.shape,
source_dset.dtype,
[curr_pixels[0],curr_pixels[-1]],
targ_dset))
targ_dset[curr_pixels, :] = source_dset
def _create_guess_datasets(self):
"""
Model specific call that will create the h5 group, empty guess dataset,
corresponding spectroscopic datasets and also link the guess dataset
to the spectroscopic datasets.
"""
raise NotImplementedError('Please override the _create_guess_datasets '
'specific to your model')
def _create_fit_datasets(self):
"""
Model specific call that will create the (empty) fit dataset, and
link the fit dataset to the spectroscopic datasets.
"""
raise NotImplementedError('Please override the _create_fit_datasets '
'specific to your model')
def _get_existing_datasets(self):
"""
Gets existing Guess, Fit, status datasets, from the HDF5 group.
All other domain-specific datasets should be loaded in the classes that
extend this class
"""
self._h5_guess = USIDataset(self.h5_results_grp['Guess'])
try:
self._h5_status_dset = self.h5_results_grp[self._status_dset_name]
except KeyError:
warn('status dataset not created yet')
self._h5_status_dset = None
try:
self._h5_fit = self.h5_results_grp['Fit']
self._h5_fit = USIDataset(self._h5_fit)
except KeyError:
self._h5_fit = None
if not self._is_guess:
self._create_fit_datasets()
[docs]
def do_guess(self, *args, override=False, **kwargs):
"""
Computes the Guess
Parameters
----------
args : list, optional
List of arguments
override : bool, optional
If True, computes a fresh guess even if existing Guess was found
Else, returns existing Guess dataset. Default = False
kwargs : dict, optional
Keyword arguments
Returns
-------
USIDataset
HDF5 dataset with the Guesses computed
"""
if not self.__set_up_called:
raise ValueError('Please call set_up_guess() before calling '
'do_guess()')
self.h5_results_grp = super(Fitter, self).compute(override=override)
# to be on the safe side, expect setup again
self.__set_up_called = False
return USIDataset(self.h5_results_grp['Guess'])
[docs]
def do_fit(self, *args, override=False, **kwargs):
"""
Computes the Fit
Parameters
----------
args : list, optional
List of arguments
override : bool, optional
If True, computes a fresh guess even if existing Fit was found
Else, returns existing Fit dataset. Default = False
kwargs : dict, optional
Keyword arguments
Returns
-------
USIDataset
HDF5 dataset with the Fit computed
"""
if not self.__set_up_called:
raise ValueError('Please call set_up_guess() before calling '
'do_guess()')
"""
Either delete or reset 'last_pixel' attribute to 0
This value will be used for filling in the status dataset.
"""
self.h5_results_grp.attrs['last_pixel'] = 0
self.h5_results_grp = super(Fitter, self).compute(override=override)
# to be on the safe side, expect setup again
self.__set_up_called = False
return USIDataset(self.h5_results_grp['Fit'])
def _reformat_results(self, results, strategy='wavelet_peaks'):
"""
Model specific restructuring / reformatting of the parallel compute
results
Parameters
----------
results : list or array-like
Results to be formatted for writing
strategy : str
The strategy used in the fit. Determines how the results will be
reformatted, if multiple strategies for guess / fit are available
Returns
-------
results : numpy.ndarray
Formatted array that is ready to be writen to the HDF5 file
"""
return np.array(results)
[docs]
def set_up_guess(self, h5_partial_guess=None):
"""
Performs necessary book-keeping before do_guess can be called
Parameters
----------
h5_partial_guess: h5py.Dataset or pyUSID.io.USIDataset, optional
HDF5 dataset containing partial Guess. Not implemented
"""
# TODO: h5_partial_guess needs to be utilized
if h5_partial_guess is not None:
raise NotImplementedError('Provided h5_partial_guess cannot be '
'used yet. Ask developer to implement')
# Set up the parms dict so everything necessary for checking previous
# guess / fit is ready
self._is_guess = True
self._status_dset_name = 'completed_guess_positions'
ret_vals = self._check_for_duplicates()
self.duplicate_h5_groups, self.partial_h5_groups = ret_vals
if self.verbose and self.mpi_rank == 0:
print('Groups with Guess in:\nCompleted: {}\nPartial:{}'.format(
self.duplicate_h5_groups, self.partial_h5_groups))
self._unit_computation = super(Fitter, self)._unit_computation
self._create_results_datasets = self._create_guess_datasets
self.compute = self.do_guess
self.__set_up_called = True
[docs]
def set_up_fit(self, h5_partial_fit=None, h5_guess=None):
"""
Performs necessary book-keeping before do_fit can be called
Parameters
----------
h5_partial_fit: h5py.Dataset or pyUSID.io.USIDataset, optional
HDF5 dataset containing partial Fit. Not implemented
h5_guess: h5py.Dataset or pyUSID.io.USIDataset, optional
HDF5 dataset containing completed Guess. Not implemented
"""
# TODO: h5_partial_guess needs to be utilized
if h5_partial_fit is not None or h5_guess is not None:
raise NotImplementedError('Provided h5_partial_fit cannot be '
'used yet. Ask developer to implement')
self._is_guess = False
self._map_function = None
self._unit_computation = None
self._create_results_datasets = self._create_fit_datasets
# Case 1: Fit already complete or partially complete.
# This is similar to a partial process. Leave as is
self._status_dset_name = 'completed_fit_positions'
ret_vals = self._check_for_duplicates()
self.duplicate_h5_groups, self.partial_h5_groups = ret_vals
if self.verbose and self.mpi_rank == 0:
print('Checking on partial / completed fit datasets')
print(
'Completed results groups:\n{}\nPartial results groups:\n'
'{}'.format(self.duplicate_h5_groups, self.partial_h5_groups))
# Case 2: Fit neither partial / completed. Search for guess.
# Most popular scenario:
if len(self.duplicate_h5_groups) == 0 and len(
self.partial_h5_groups) == 0:
if self.verbose and self.mpi_rank == 0:
print('No fit datasets found. Looking for Guess datasets')
# Change status dataset name back to guess to check for status
# on guesses:
self._status_dset_name = 'completed_guess_positions'
# Note that check_for_duplicates() will be against fit's parm_dict.
# So make a backup of that
fit_parms = self.parms_dict.copy()
# Set parms_dict to an empty dict so that we can accept any Guess
# dataset:
self.parms_dict = dict()
ret_vals = self._check_for_duplicates()
guess_complete_h5_grps, guess_partial_h5_grps = ret_vals
if self.verbose and self.mpi_rank == 0:
print(
'Guess datasets search resulted in:\nCompleted: {}\n'
'Partial:{}'.format(guess_complete_h5_grps,
guess_partial_h5_grps))
# Now put back the original parms_dict:
self.parms_dict.update(fit_parms)
# Case 2.1: At least guess is completed:
if len(guess_complete_h5_grps) > 0:
# Just set the last group as the current results group
self.h5_results_grp = guess_complete_h5_grps[-1]
if self.verbose and self.mpi_rank == 0:
print('Guess found! Using Guess in:\n{}'.format(
self.h5_results_grp))
# It will grab the older status default unless we set the
# status dataset back to fit
self._status_dset_name = 'completed_fit_positions'
# Get handles to the guess dataset. Nothing else will be found
self._get_existing_datasets()
elif len(guess_complete_h5_grps) == 0 and len(
guess_partial_h5_grps) > 0:
FileNotFoundError(
'Guess not yet completed. Please complete guess first')
return
else:
FileNotFoundError(
'No Guess found. Please complete guess first')
return
# We want compute to call our own manual unit computation function:
self._unit_computation = self._unit_compute_fit
self.compute = self.do_fit
self.__set_up_called = True
def _unit_compute_fit(self, obj_func, obj_func_args=[],
solver_options={'jac': 'cs'}):
"""
Performs least-squares fitting on self.data using self.guess for
initial conditions.
Results of the computation are captured in self._results
Parameters
----------
obj_func : callable
Objective function to minimize on
obj_func_args : list
Arguments required by obj_func following the guess parameters
(which should be the first argument)
solver_options : dict, optional
Keyword arguments passed onto scipy.optimize.least_squares
"""
# At this point data has been read in. Read in the guess as well:
self._read_guess_chunk()
if self.verbose and self.mpi_rank == 0:
print('_unit_compute_fit got:\nobj_func: {}\nobj_func_args: {}\n'
'solver_options: {}'.format(obj_func, obj_func_args,
solver_options))
# TODO: Generalize this bit. Use Parallel compute instead!
if self.mpi_size > 1:
if self.verbose:
print('Rank {}: About to start serial computation'
'.'.format(self.mpi_rank))
self._results = list()
for pulse_resp, pulse_guess in zip(self.data, self._guess):
curr_results = least_squares(obj_func, pulse_guess,
args=[pulse_resp] + obj_func_args,
**solver_options)
self._results.append(curr_results)
else:
cores = recommend_cpu_cores(self.data.shape[0],
verbose=self.verbose)
if self.verbose:
print('Starting parallel fitting with {} cores'.format(cores))
values = [joblib.delayed(least_squares)(obj_func, pulse_guess,
args=[pulse_resp] + obj_func_args,
**solver_options) for
pulse_resp, pulse_guess in zip(self.data, self._guess)]
self._results = joblib.Parallel(n_jobs=cores)(values)
if self.verbose and self.mpi_rank == 0:
print(
'Finished computing fits on {} objects. Results of length: {}'
'.'.format(self.data.shape[0], len(self._results)))
# What least_squares returns is an object that needs to be extracted
# to get the coefficients. This is handled by the write function