# Source code for hexrd.fitgrains

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 22 19:04:10 2017

@author: bernier2
"""
import os
import logging
import multiprocessing
import numpy as np
import timeit
import warnings

from hexrd import instrument
from hexrd.transforms import xfcapi
from hexrd.fitting import fitGrain, objFuncFitGrain, gFlag_ref

logger = logging.getLogger(__name__)


# multiprocessing fit funcs
def fit_grain_FF_init(params):
    """
    Install the fitting parameters as the module global ``paramMP``.

    fit_grains calls this directly for serial fits and passes it as the
    initializer of its multiprocessing pool, so that fit_grain_FF_reduced
    can read the parameters in every worker process.

    Parameters
    ----------
    params : dict
        The dictionary of fitting parameters; see fit_grain_FF_reduced for
        the keys it consumes.

    Returns
    -------
    None.
    """
    module_namespace = globals()
    module_namespace['paramMP'] = params


def fit_grain_FF_cleanup():
    """
    Tear down the global fitting parameters installed by fit_grain_FF_init.

    Safe to call when no parameters are installed (e.g. a second time, or
    from a ``finally`` clause after a failed setup): in that case it does
    nothing instead of raising NameError.
    """
    globals().pop('paramMP', None)


def _lookup_grain(grains_table, grain_id):
    """Return the row of `grains_table` whose ID (column 0) is `grain_id`.

    Falls back to positional indexing (the historical lookup) when no row
    carries that ID.
    """
    matches = np.flatnonzero(grains_table[:, 0] == grain_id)
    if matches.size > 0:
        return grains_table[matches[0]]
    return grains_table[grain_id]


def _overlap_mask(analysis_dirname, det_key, grain_id, refl_ids):
    """Mask of `refl_ids` flagged as overlapped for `grain_id`, or None.

    Reads ``<analysis_dirname>/<det_key>/overlap_table.npz`` if present.
    For every table in the archive, rows whose column 0 equals `grain_id`
    mark the reflection id in column 1 as overlapped.

    Returns None (overlap information is ignored, best effort) when the
    file cannot be read (IOError) or a table does not have the expected
    2-D layout (IndexError).
    """
    ot_path = os.path.join(analysis_dirname, det_key, 'overlap_table.npz')
    overlaps = np.zeros(len(refl_ids), dtype=bool)
    try:
        # context manager closes the underlying zip file handle
        with np.load(ot_path) as ot:
            for key in ot.keys():
                for this_table in ot[key]:
                    rows = this_table[:, 0] == grain_id
                    if np.any(rows):
                        mark_these = np.array(this_table[rows, 1], dtype=int)
                        # isin tolerates ids that are absent from refl_ids
                        overlaps |= np.isin(refl_ids, mark_these)
    except (IOError, IndexError):
        return None
    return overlaps


def _cull_reflections(instr, results, grain_id, analysis_dirname):
    """Keep only valid, unsaturated, non-overlapped reflections per panel.

    Returns
    -------
    culled_results : dict
        Same keys as `results`; values are lists of the retained records.
    num_refl_tot : int
        Number of predicted reflections over all panels.
    num_refl_valid : int
        Number of those with a non-negative peak id.  Saturated and
        overlapped spots are still counted here.
    """
    culled_results = dict.fromkeys(results)
    num_refl_tot = 0
    num_refl_valid = 0
    for det_key in culled_results:
        panel = instr.detectors[det_key]

        # per-spot record layout:
        #   peak_id, hkl_id, hkl, sum_int, max_int,
        #   pred_angs, meas_angs, meas_xy
        presults = results[det_key]
        nrefl = len(presults)

        refl_ids = np.empty(nrefl)
        max_int = np.empty(nrefl)
        for i, spot_data in enumerate(presults):
            refl_ids[i] = spot_data[0]
            max_int[i] = spot_data[4]

        # a negative peak id marks a predicted but unobserved reflection
        valid_refl_ids = refl_ids >= 0

        # find unsaturated spots on this panel
        unsat_spots = np.ones(nrefl, dtype=bool)
        if panel.saturation_level is not None:
            unsat_spots[valid_refl_ids] = \
                max_int[valid_refl_ids] < panel.saturation_level

        idx = np.logical_and(valid_refl_ids, unsat_spots)

        # if an overlap table has been written, use it
        overlaps = _overlap_mask(analysis_dirname, det_key, grain_id, refl_ids)
        if overlaps is not None:
            idx &= ~overlaps

        culled_results[det_key] = [presults[i] for i in np.flatnonzero(idx)]
        num_refl_tot += nrefl
        num_refl_valid += int(np.count_nonzero(valid_refl_ids))
    return culled_results, num_refl_tot, num_refl_valid


def _refit_cull(instr, culled_results, xyo_det_fit_dict, refit, ome_step):
    """Drop reflections whose measured centroid strays from the prediction.

    A reflection is kept when its measured (x, y) lie within
    ``refit[0]`` pixels (using each panel's own pixel size) and its
    measured omega within ``refit[1]`` omega steps of the predicted values.

    Returns the new culled-results dict and its total reflection count.
    """
    culled_results_r = dict.fromkeys(culled_results)
    num_refl_valid = 0
    for det_key in culled_results_r:
        presults = culled_results[det_key]
        if not presults:
            culled_results_r[det_key] = []
            continue

        # use THIS detector's pixel size, not one left over from a loop
        panel = instr.detectors[det_key]

        # measured (x, y, omega): meas_xy and the last meas_angs component
        xyo_det = np.atleast_2d(
            np.vstack([np.r_[x[7], x[6][-1]] for x in presults])
        )
        xyo_det_fit = xyo_det_fit_dict[det_key]

        xpix_tol = refit[0]*panel.pixel_size_col
        ypix_tol = refit[0]*panel.pixel_size_row
        fome_tol = refit[1]*ome_step

        x_diff = np.abs(xyo_det[:, 0] - xyo_det_fit['calc_xy'][:, 0])
        y_diff = np.abs(xyo_det[:, 1] - xyo_det_fit['calc_xy'][:, 1])
        ome_diff = np.degrees(
            xfcapi.angularDifference(xyo_det[:, 2], xyo_det_fit['calc_omes'])
        )

        idx_new = np.logical_and(
            x_diff <= xpix_tol,
            np.logical_and(y_diff <= ypix_tol, ome_diff <= fome_tol)
        )
        culled_results_r[det_key] = [
            presults[i] for i in np.flatnonzero(idx_new)
        ]
        num_refl_valid += int(np.count_nonzero(idx_new))
    return culled_results_r, num_refl_valid


def fit_grain_FF_reduced(grain_id):
    """
    Perform non-linear least-squares fit for the specified grain.

    Inputs are read from the module-level ``paramMP`` dict installed by
    fit_grain_FF_init.  For each (tth, eta, ome) tolerance triple the
    instrument's spots are pulled, culled (invalid, saturated and
    overlapped reflections removed) and the grain is refit.  If ``refit``
    is set, reflections whose measured centroids stray too far from the
    fitted prediction are then dropped and the grain is fit once more.

    Parameters
    ----------
    grain_id : int
        The grain id, matched against column 0 of ``grains_table``.

    Returns
    -------
    grain_id : int
        The grain id.
    completeness : float
        Fraction of predicted reflections that were observed (non-negative
        peak id); saturated and overlapped spots count as observed.
    chisq : float
        Figure of merit describing the sum of squared residuals for each
        Bragg reflection in the form (x, y, omega) normalized by the total
        number of degrees of freedom.  ``np.inf`` (with a RuntimeWarning)
        if there were 12 or fewer valid reflections to fit.
    grain_params : array_like
        The optimized grain parameters
        [<orientation [3]>, <centroid [3]> <inverse stretch [6]>].

    Raises
    ------
    RuntimeError
        If no reflections are predicted on any panel.

    Notes
    -----
    ``paramMP`` keys read: grains_table, plane_data, instrument,
    imgser_dict, tth_tol, eta_tol, ome_tol, npdiv, refit, threshold,
    eta_ranges, ome_period, analysis_dirname, spots_filename.
    """
    grains_table = paramMP['grains_table']
    plane_data = paramMP['plane_data']
    instr = paramMP['instrument']
    imgser_dict = paramMP['imgser_dict']
    tth_tol = paramMP['tth_tol']
    eta_tol = paramMP['eta_tol']
    ome_tol = paramMP['ome_tol']
    npdiv = paramMP['npdiv']
    refit = paramMP['refit']
    threshold = paramMP['threshold']
    eta_ranges = paramMP['eta_ranges']
    ome_period = paramMP['ome_period']
    analysis_dirname = paramMP['analysis_dirname']
    prefix = paramMP['spots_filename']
    spots_filename = None if prefix is None else prefix % grain_id

    grain = _lookup_grain(grains_table, grain_id)
    grain_params = grain[3:15]

    for tth_t, eta_t, ome_t in zip(tth_tol, eta_tol, ome_tol):
        _, results = instr.pull_spots(
            plane_data, grain_params,
            imgser_dict,
            tth_tol=tth_t,
            eta_tol=eta_t,
            ome_tol=ome_t,
            npdiv=npdiv, threshold=threshold,
            eta_ranges=eta_ranges,
            ome_period=ome_period,
            dirname=analysis_dirname, filename=spots_filename,
            return_spot_list=False,
            quiet=True, check_only=False, interp='nearest')

        # ======= DETERMINE VALID REFLECTIONS =======
        culled_results, num_refl_tot, num_refl_valid = _cull_reflections(
            instr, results, grain_id, analysis_dirname
        )

        # CAVEAT: completeness from pullspots only; incl saturated and
        # overlaps <JVB 2015-12-15>
        if num_refl_tot == 0:
            raise RuntimeError(
                "simulated number of reflections is 0; "
                "check instrument config or grain parameters"
            )
        completeness = num_refl_valid / float(num_refl_tot)

        # ======= DO LEASTSQ FIT =======
        if num_refl_valid <= 12:
            # not enough reflections to fit... exit
            warnings.warn(
                f'Not enough valid reflections ({num_refl_valid}) to fit, '
                f'exiting',
                RuntimeWarning
            )
            return grain_id, completeness, np.inf, grain_params

        grain_params = fitGrain(
            grain_params, instr, culled_results,
            plane_data.latVecOps['B'], plane_data.wavelength
        )
        # TODO: do this while evaluating fit???
        chisq = objFuncFitGrain(
            grain_params[gFlag_ref], grain_params, gFlag_ref,
            instr,
            culled_results,
            plane_data.latVecOps['B'], plane_data.wavelength,
            ome_period,
            simOnly=False, return_value_flag=2)

    if refit is not None:
        # calculated x, y, ome from the previous solution (a dict per panel)
        xyo_det_fit_dict = objFuncFitGrain(
            grain_params[gFlag_ref], grain_params, gFlag_ref,
            instr,
            culled_results,
            plane_data.latVecOps['B'], plane_data.wavelength,
            ome_period,
            simOnly=True, return_value_flag=2)

        # omega step size; all image series are assumed to share it
        ims = next(iter(imgser_dict.values()))
        ome_step = sum(np.r_[-1, 1]*ims.metadata['omega'][0, :])

        culled_results_r, num_refl_valid = _refit_cull(
            instr, culled_results, xyo_det_fit_dict, refit, ome_step
        )

        # only execute fit if left with enough reflections
        if num_refl_valid > 12:
            grain_params = fitGrain(
                grain_params, instr, culled_results_r,
                plane_data.latVecOps['B'], plane_data.wavelength
            )
            chisq = objFuncFitGrain(
                grain_params[gFlag_ref],
                grain_params, gFlag_ref,
                instr,
                culled_results_r,
                plane_data.latVecOps['B'], plane_data.wavelength,
                ome_period,
                simOnly=False, return_value_flag=2)

    return grain_id, completeness, chisq, grain_params


def fit_grains(cfg,
               grains_table,
               show_progress=False,
               ids_to_refine=None,
               write_spots_files=True,
               check_if_canceled_func=None):
    """
    Performs optimization of grain parameters.

    Operates on a single HEDM config block.

    Parameters
    ----------
    cfg : config object
        HEDM configuration.  Reads ``image_series``, ``instrument.hedm``,
        ``find_orientations.eta.range``, ``multiprocessing``,
        ``fit_grains.{threshold, tolerance.{tth, eta, omega}, npdiv,
        refit}``, ``material.plane_data`` and ``analysis_dir``.
    grains_table : array_like
        2-D table with one row per grain; column 0 is the grain id.
    show_progress : bool, optional
        Currently unused; accepted for backward compatibility.
    ids_to_refine : index array, optional
        Rows of `grains_table` to fit.  If None, all rows are fit.
    write_spots_files : bool, optional
        If True, spots are written to ``spots_%05d.out`` per grain.
    check_if_canceled_func : callable, optional
        ``check_if_canceled_func() -> bool``.  If it returns `True`,
        fit_grains is canceled by terminating the multiprocessing
        processes.  If it is set, multiprocessing is used even for one
        grain or one process, so that the fit is cancelable.

    Returns
    -------
    list or None
        One ``(grain_id, completeness, chisq, grain_params)`` tuple per
        grain (an empty list if there are no grains to fit), or None if
        the fit was canceled.
    """
    # grab imageseries dict
    imsd = cfg.image_series

    # grab instrument
    instr = cfg.instrument.hedm

    # grab eta ranges and ome_period
    eta_ranges = np.radians(cfg.find_orientations.eta.range)

    # handle omega period
    # !!! we assume all detector ims have the same ome ranges, so any will do!
    oims = next(iter(imsd.values()))
    ome_period = np.radians(oims.omega[0, 0] + np.r_[0., 360.])

    # number of processes
    ncpus = cfg.multiprocessing

    # threshold for fitting
    threshold = cfg.fit_grains.threshold

    if ids_to_refine is not None:
        grains_table = np.atleast_2d(grains_table[ids_to_refine, :])

    spots_filename = "spots_%05d.out" if write_spots_files else None
    params = dict(
        grains_table=grains_table,
        plane_data=cfg.material.plane_data,
        instrument=instr,
        imgser_dict=imsd,
        tth_tol=cfg.fit_grains.tolerance.tth,
        eta_tol=cfg.fit_grains.tolerance.eta,
        ome_tol=cfg.fit_grains.tolerance.omega,
        npdiv=cfg.fit_grains.npdiv,
        refit=cfg.fit_grains.refit,
        threshold=threshold,
        eta_ranges=eta_ranges,
        ome_period=ome_period,
        analysis_dirname=cfg.analysis_dir,
        spots_filename=spots_filename)

    grain_ids = np.array(grains_table[:, 0], dtype=int)
    if grain_ids.size == 0:
        # Nothing to fit.  The serial path already yielded [] here, while
        # the pool path would have failed creating Pool(0).
        return []

    # =====================================================================
    # EXECUTE MP FIT
    # =====================================================================

    # DO FIT!
    if (len(grains_table) == 1 or ncpus == 1) and not check_if_canceled_func:
        logger.info("\tstarting serial fit")
        start = timeit.default_timer()
        fit_grain_FF_init(params)
        try:
            fit_results = list(map(fit_grain_FF_reduced, grain_ids))
        finally:
            # drop the global parameters even if a fit raises
            fit_grain_FF_cleanup()
        elapsed = timeit.default_timer() - start
    else:
        nproc = min(ncpus, len(grains_table))
        chunksize = max(1, len(grains_table)//ncpus)
        logger.info("\tstarting fit on %d processes with chunksize %d",
                    nproc, chunksize)
        start = timeit.default_timer()
        # the context manager terminates the workers if anything raises
        with multiprocessing.Pool(
            nproc,
            fit_grain_FF_init,
            (params, )
        ) as pool:
            async_result = pool.map_async(
                fit_grain_FF_reduced,
                grain_ids,
                chunksize=chunksize
            )
            while not async_result.ready():
                if check_if_canceled_func and check_if_canceled_func():
                    pool.terminate()
                    logger.info('Fit grains canceled.')
                    # Perform an early return if we need to cancel.
                    return None

                async_result.wait(0.25)

            fit_results = async_result.get()
            pool.close()
            pool.join()
        elapsed = timeit.default_timer() - start
    logger.info("fitting took %f seconds", elapsed)
    return fit_results