# Source code for hexrd.fitting.calibration.multigrain

import logging
import os

import numpy as np
from scipy.optimize import leastsq, least_squares

from hexrd import constants as cnst
from hexrd import matrixutil as mutil
from hexrd import rotations
from hexrd.transforms import xfcapi

from .. import grains as grainutil

# Module logger.  A named logger is used instead of the root logger so that
# importing this module does not change the log level of the whole
# application; this module's own INFO messages are still emitted.
logger = logging.getLogger(__name__)
logger.setLevel('INFO')

# grains
# Default refinement flags for the 12 per-grain parameters, ordered as
# xi[0:3], tvec_c[0:3], vinv_s[0:6] (see generate_parameter_names):
# refine xi[0:3], tvec_c[0] and tvec_c[2]; hold tvec_c[1] and vinv_s fixed.
grain_flags_DFLT = np.array(
    [1, 1, 1,
     1, 0, 1,
     0, 0, 0, 0, 0, 0],
    dtype=bool
)

# Angular tolerance (radians) about eta = +/-90 deg that
# parse_reflection_tables uses when screening reflections near the
# rotation axis.
ext_eta_tol = np.radians(5.)  # for HEDM cal, may make this a user param


def calibrate_instrument_from_sx(
        instr, grain_params, bmat, xyo_det, hkls_idx,
        param_flags=None, grain_flags=None,
        ome_period=None,
        xtol=cnst.sqrt_epsf, ftol=cnst.sqrt_epsf,
        factor=10., sim_only=False, use_robust_lsq=False):
    """
    Refine instrument (and optionally grain) parameters against spot data.

    Arguments xyo_det, hkls_idx are DICTs over panels: for each detector
    key they hold one entry per grain, indexed as ``xyo_det[det_key][ig]``
    and ``hkls_idx[det_key][ig]``.

    Parameters
    ----------
    instr : instrument object
        Provides ``calibration_parameters``, ``calibration_flags``,
        ``detectors`` and ``update_from_parameter_list``.  It is updated in
        place with the fitted parameter vector.
    grain_params : array_like
        One row of 12 parameters per grain (a single 1-d row is accepted).
    bmat : array_like
        Passed through to :func:`sxcal_obj_func`.
    xyo_det : dict
        Measured spot data per panel/grain; column 2 is treated as omega.
        NOTE: when ``ome_period`` is given, column 2 is re-mapped IN PLACE.
    hkls_idx : dict
        Miller indices per panel/grain, passed to :func:`sxcal_obj_func`.
    param_flags : array_like of bool, optional
        Refinement flags for the instrument parameters.  If given they are
        also assigned to ``instr.calibration_flags``; otherwise the
        instrument's current flags are used.
    grain_flags : array_like of bool, optional
        Refinement flags for the stacked grain parameters (12 per grain).
        Defaults to ``grain_flags_DFLT`` tiled over the grains.
    ome_period : array_like, optional
        If given, measured omegas are re-mapped with ``rotations.mapAngle``.
    xtol, ftol : float
        Convergence tolerances forwarded to the optimizer.
    factor : float
        Initial step bound forwarded to ``scipy.optimize.leastsq``
        (not used when ``use_robust_lsq`` is True).
    sim_only : bool
        If True, skip the fit and return the ``sim_only=True`` output of
        :func:`sxcal_obj_func` for the current parameters.
    use_robust_lsq : bool
        If True use ``scipy.optimize.least_squares`` with a ``soft_l1``
        loss; otherwise use ``scipy.optimize.leastsq``.

    Returns
    -------
    fit_params : numpy.ndarray
        Full parameter vector (instrument followed by grains) with the
        refined entries replaced by their fitted values.
    resd : numpy.ndarray
        Residual vector reported by the optimizer.
    sim_final : dict
        Simulated spots for the fitted parameters.

    Raises
    ------
    RuntimeError
        If the optimizer status is not one of 1, 2, 3, 4.
    """
    grain_params = np.atleast_2d(grain_params)
    ngrains = len(grain_params)
    pnames = generate_parameter_names(instr, grain_params)

    # reset parameter flags for instrument as specified
    if param_flags is None:
        param_flags = instr.calibration_flags
    else:
        # will throw an AssertionError if wrong length
        instr.calibration_flags = param_flags

    # re-map omegas if need be (modifies the caller's xyo_det in place)
    if ome_period is not None:
        for det_key in instr.detectors:
            for ig in range(ngrains):
                xyo_det[det_key][ig][:, 2] = rotations.mapAngle(
                    xyo_det[det_key][ig][:, 2],
                    ome_period
                )

    # first grab the instrument parameters
    # 7 global
    # 6*num_panels for the detectors
    # num_panels*ndp in case of distortion
    plist_full = instr.calibration_parameters

    # now handle grains
    # reset parameter flags for grains as specified
    if grain_flags is None:
        grain_flags = np.tile(grain_flags_DFLT, ngrains)

    plist_full = np.concatenate(
        [plist_full, np.hstack(grain_params)]
    )
    plf_copy = np.copy(plist_full)  # initial values, for the final report

    # concatenate refinement flags; force a boolean mask so that 0/1
    # integer flags select parameters instead of being interpreted as
    # integer (fancy) indices -- sxcal_obj_func applies the same conversion
    refine_flags = np.asarray(
        np.hstack([param_flags, grain_flags]), dtype=bool
    )
    plist_fit = plist_full[refine_flags]
    fit_args = (plist_full,
                param_flags, grain_flags,
                instr, xyo_det, hkls_idx,
                bmat, ome_period)

    if sim_only:
        return sxcal_obj_func(
            plist_fit, plist_full,
            param_flags, grain_flags,
            instr, xyo_det, hkls_idx,
            bmat, ome_period,
            sim_only=True)

    logger.info("Set up to refine:")
    for i in np.where(refine_flags)[0]:
        logger.info("\t%s = %1.7e" % (pnames[i], plist_full[i]))

    # run optimization
    if use_robust_lsq:
        result = least_squares(
            sxcal_obj_func, plist_fit, args=fit_args,
            xtol=xtol, ftol=ftol,
            loss='soft_l1', method='trf'
        )
        x = result.x
        resd = result.fun
        mesg = result.message
        ierr = result.status
    else:
        # do least squares problem
        x, _cov_x, infodict, mesg, ierr = leastsq(
            sxcal_obj_func, plist_fit, args=fit_args,
            factor=factor, xtol=xtol, ftol=ftol,
            full_output=1
        )
        resd = infodict['fvec']

    if ierr not in [1, 2, 3, 4]:
        raise RuntimeError(f"solution not found: {ierr=}")
    logger.info(f"optimization fininshed successfully with {ierr=}")
    logger.info(mesg)

    # ??? output message handling?
    # NOTE: fit_params aliases plist_full, which sxcal_obj_func also
    # overwrites in place during every evaluation.
    fit_params = plist_full
    fit_params[refine_flags] = x

    # run simulation with optimized results
    sim_final = sxcal_obj_func(
        x, plist_full,
        param_flags, grain_flags,
        instr, xyo_det, hkls_idx,
        bmat, ome_period,
        sim_only=True)

    # ??? reset instrument here?
    instr.update_from_parameter_list(fit_params)

    # report final
    logger.info("Optimization Reults:")
    for i in np.where(refine_flags)[0]:
        logger.info("\t%s = %1.7e --> %1.7e"
                    % (pnames[i], plf_copy[i], fit_params[i]))

    return fit_params, resd, sim_final
def generate_parameter_names(instr, grain_params):
    """
    Return display labels for every entry of the full calibration vector.

    Each label is right-justified to a width of at least 24 characters.
    The order is:

    * 7 global instrument parameters;
    * for each panel in ``instr.detectors`` (iteration order), three tilt
      and three tvec entries, followed by one entry per distortion
      parameter when the panel has a distortion;
    * 12 entries per grain: xi[0:3], tvec_c[0:3], vinv_s[0:6].

    The number of grains is the number of rows of
    ``np.atleast_2d(grain_params)``; the grain values themselves are not
    read.
    """
    def _label(text):
        return '{:>24s}'.format(text)

    global_labels = ('beam energy', 'beam azimuth', 'beam polar', 'chi',
                     'tvec_s[0]', 'tvec_s[1]', 'tvec_s[2]')
    names = [_label(text) for text in global_labels]

    for key, pnl in instr.detectors.items():
        names.extend(
            _label(f'{key} {vec}[{k}]')
            for vec in ('tilt', 'tvec')
            for k in range(3)
        )
        if pnl.distortion is not None:
            names.extend(
                _label(f'{key} dparam[{k}]')
                for k in range(len(pnl.distortion.params))
            )

    grain_layout = (('xi', 3), ('tvec_c', 3), ('vinv_s', 6))
    for ig in range(len(np.atleast_2d(grain_params))):
        names.extend(
            _label(f'grain {ig} {field}[{k}]')
            for field, width in grain_layout
            for k in range(width)
        )

    return names
def sxcal_obj_func(plist_fit, plist_full,
                   param_flags, grain_flags,
                   instr, xyo_det, hkls_idx,
                   bmat, ome_period,
                   sim_only=False, return_value_flag=None):
    """
    Objective function for the single-crystal (multigrain) calibration.

    The refined values ``plist_fit`` are scattered IN PLACE into
    ``plist_full`` at the positions flagged by ``param_flags`` and
    ``grain_flags``.  The instrument is then updated from ``plist_full``
    (also in place), and spot positions and omegas are predicted for every
    panel and grain.

    ``plist_full`` holds the instrument calibration parameters followed by
    12 parameters per grain: ``[0:3]`` orientation exponential map,
    ``[3:6]`` ``tvec_c`` and ``[6:12]`` inverse stretch ``vinv_s`` in MV
    notation.

    Parameters
    ----------
    plist_fit : array_like
        Values of the refined parameters only.
    plist_full : numpy.ndarray
        Full parameter vector; overwritten in place.
    param_flags, grain_flags : array_like
        Refinement flags; stacked and coerced to a boolean mask.
    instr : instrument object
        Updated in place via ``update_from_parameter_list``.
    xyo_det, hkls_idx : dict
        Per panel, one entry per grain of measured (x, y, omega) data and
        the corresponding Miller indices.
    bmat : array_like
        Reciprocal-lattice matrix applied to the Miller indices.
    ome_period : array_like or None
        Forwarded to ``grainutil.matchOmegas``.
    sim_only : bool
        If True, return the simulated spots instead of residuals.
    return_value_flag : {None, 1, 2}
        Selects the reduction of the residuals (ignored if ``sim_only``).

    Returns
    -------
    dict or numpy.ndarray or scalar
        * ``sim_only=True``: dict mapping each detector key to a list (one
          per grain) of arrays whose columns are the predicted x, y and
          omega.
        * ``return_value_flag == 1``: sum of absolute residuals.
        * ``return_value_flag == 2``: sum of squared residuals multiplied
          by ``1 / (npts - len(plist_fit) - 1)``, or by 1 if that
          denominator is zero.
        * otherwise: the flattened residual vector, ``[dx, dy, d_omega]``
          per measured point.  xy residuals are calc - meas; the omega
          residual is ``rotations.angularDifference(calc, meas)``.

    Raises
    ------
    RuntimeError
        If the grain part of ``plist_full`` is not a multiple of 12 long.
    """
    npi = len(instr.calibration_parameters)
    NP_GRN = 12

    # stack flags and force bool repr
    refine_flags = np.array(
        np.hstack([param_flags, grain_flags]),
        dtype=bool)

    # fill out full parameter list
    # !!! no scaling for now
    plist_full[refine_flags] = plist_fit

    # instrument update
    instr.update_from_parameter_list(plist_full)

    # assign some useful params
    wavelength = instr.beam_wavelength
    bvec = instr.beam_vector
    chi = instr.chi
    tvec_s = instr.tvec

    # right now just stuck on the end and assumed
    # to all be the same length... FIX THIS
    xy_unwarped = {}
    meas_omes = {}
    calc_omes = {}
    calc_xy = {}

    # grain params
    grain_params = plist_full[npi:]
    if np.mod(len(grain_params), NP_GRN) != 0:
        raise RuntimeError("parameter list length is not consistent")
    ngrains = len(grain_params) // NP_GRN
    grain_params = grain_params.reshape((ngrains, NP_GRN))

    # loop over panels
    npts_tot = 0
    for det_key, panel in instr.detectors.items():
        rmat_d = panel.rmat
        tvec_d = panel.tvec

        xy_unwarped[det_key] = []
        meas_omes[det_key] = []
        calc_omes[det_key] = []
        calc_xy[det_key] = []

        for ig, grain in enumerate(grain_params):
            ghkls = hkls_idx[det_key][ig]
            xyo = xyo_det[det_key][ig]

            npts_tot += len(xyo)

            xy_unwarped[det_key].append(xyo[:, :2])
            meas_omes[det_key].append(xyo[:, 2])
            if panel.distortion is not None:  # do unwarping
                xy_unwarped[det_key][ig] = panel.distortion.apply(
                    xy_unwarped[det_key][ig]
                )

            # transform G-vectors:
            # 1) convert inv. stretch tensor from MV notation in to 3x3
            # 2) take reciprocal lattice vectors from CRYSTAL to SAMPLE frame
            # 3) apply stretch tensor
            # 4) normalize reciprocal lattice vectors in SAMPLE frame
            # 5) transform unit reciprocal lattice vectors back to CRYSTAL
            #    frame
            rmat_c = xfcapi.make_rmat_of_expmap(grain[:3])
            tvec_c = grain[3:6]
            vinv_s = grain[6:]
            gvec_c = np.dot(bmat, ghkls.T)
            vmat_s = mutil.vecMVToSymm(vinv_s)
            ghat_s = mutil.unitVector(np.dot(vmat_s, np.dot(rmat_c, gvec_c)))
            ghat_c = np.dot(rmat_c.T, ghat_s)

            # only the calculated omegas are needed here
            _, calc_omes_tmp = grainutil.matchOmegas(
                xyo, ghkls.T,
                chi, rmat_c, bmat, wavelength,
                vInv=vinv_s,
                beamVec=bvec,
                omePeriod=ome_period)

            rmat_s_arr = xfcapi.make_sample_rmat(
                chi, np.ascontiguousarray(calc_omes_tmp)
            )
            calc_xy_tmp = xfcapi.gvec_to_xy(
                ghat_c.T, rmat_d, rmat_s_arr, rmat_c,
                tvec_d, tvec_s, tvec_c
            )
            if np.any(np.isnan(calc_xy_tmp)):
                logger.warning("infeasible parameters: may want to scale back "
                               "finite difference step size")

            calc_omes[det_key].append(calc_omes_tmp)
            calc_xy[det_key].append(calc_xy_tmp)

    # return values
    if sim_only:
        retval = {}
        for det_key in calc_xy.keys():
            # ??? calc_xy is always 2-d
            retval[det_key] = []
            for ig in range(ngrains):
                retval[det_key].append(
                    np.vstack(
                        [calc_xy[det_key][ig].T, calc_omes[det_key][ig]]
                    ).T
                )
    else:
        meas_xy_all = []
        calc_xy_all = []
        meas_omes_all = []
        calc_omes_all = []
        for det_key in xy_unwarped.keys():
            meas_xy_all.append(np.vstack(xy_unwarped[det_key]))
            calc_xy_all.append(np.vstack(calc_xy[det_key]))
            meas_omes_all.append(np.hstack(meas_omes[det_key]))
            calc_omes_all.append(np.hstack(calc_omes[det_key]))
        meas_xy_all = np.vstack(meas_xy_all)
        calc_xy_all = np.vstack(calc_xy_all)
        meas_omes_all = np.hstack(meas_omes_all)
        calc_omes_all = np.hstack(calc_omes_all)

        diff_vecs_xy = calc_xy_all - meas_xy_all
        diff_ome = rotations.angularDifference(calc_omes_all, meas_omes_all)
        # row-major flatten -> [dx, dy, d_omega] for each point in turn
        retval = np.hstack(
            [diff_vecs_xy,
             diff_ome.reshape(npts_tot, 1)]
        ).flatten()
        if return_value_flag == 1:
            retval = sum(abs(retval))
        elif return_value_flag == 2:
            # Degrees-of-freedom scaling.  The zero-denominator guard must
            # not be overwritten by an unguarded recomputation, which would
            # raise ZeroDivisionError in exactly the case it protects.
            denom = npts_tot - len(plist_fit) - 1.
            nu_fac = 1. / denom if denom != 0 else 1.
            retval = nu_fac * sum(retval**2)
    return retval
def parse_reflection_tables(cfg, instr, grain_ids, refit_idx=None):
    """
    Make spot dictionaries from the per-grain spots tables.

    For every detector panel in ``instr`` and every grain id in
    ``grain_ids``, load ``<cfg.analysis_dir>/<det_key>/spots_<id>.out``
    (id zero-padded to 5 digits) and select the reflections to use.

    Parameters
    ----------
    cfg : object
        Configuration; only ``cfg.analysis_dir`` is read here.
    instr : instrument object
        ``instr.detectors`` is iterated as a mapping of panel key -> panel;
        each panel's ``saturation_level`` is read.
    grain_ids : sequence of int
        Grain ids used to build the spots file names.
    refit_idx : dict, optional
        If given, ``refit_idx[det_key][ig]`` is used directly as the
        reflection selector for each panel/grain.  Otherwise the default
        indexed / unsaturated / eta filters below are applied.

    Returns
    -------
    hkls : dict
        det_key -> list (one per grain) of the selected table columns 2:5.
    xyo_det : dict
        det_key -> list (one per grain) of arrays whose columns are the
        table's last two columns followed by column 12.
    idx_0 : dict
        det_key -> list (one per grain) of the selectors that were used.
    """
    hkls = {}
    xyo_det = {}
    idx_0 = {}
    for det_key, panel in instr.detectors.items():
        hkls[det_key] = []
        xyo_det[det_key] = []
        idx_0[det_key] = []

        for ig, grain_id in enumerate(grain_ids):
            spots_filename = os.path.join(
                cfg.analysis_dir, det_key, 'spots_%05d.out' % grain_id
            )

            # load pull_spots output table
            gtable = np.loadtxt(spots_filename, ndmin=2)
            if len(gtable) == 0:
                # single placeholder row of NaNs: every acceptance test
                # below evaluates False for it, so nothing is selected
                gtable = np.nan*np.ones((1, 17))

            # apply conditions for accepting valid data
            valid_reflections = gtable[:, 0] >= 0  # is indexed
            # NOTE(review): column 6 presumably holds peak intensity
            not_saturated = gtable[:, 6] < panel.saturation_level
            # Throw away extreme etas: keep only reflections more than
            # ext_eta_tol away from BOTH +90 and -90 deg (i.e. from the
            # rotation axis).  Assuming angularDifference returns the
            # unsigned separation (as the '> tol' tests imply), a
            # logical_or here would accept every finite eta, because no
            # angle is within the tolerance of both +90 and -90.
            p90 = rotations.angularDifference(gtable[:, 8], cnst.piby2)
            m90 = rotations.angularDifference(gtable[:, 8], -cnst.piby2)
            accept_etas = np.logical_and(p90 > ext_eta_tol,
                                         m90 > ext_eta_tol)
            logger.info(f"panel '{det_key}', grain {grain_id}")
            logger.info(f"{sum(valid_reflections)} of {len(gtable)} "
                        "reflections are indexed")
            logger.info(f"{sum(not_saturated)} of {sum(valid_reflections)}"
                        " valid reflections be are below" +
                        f" saturation threshold of {panel.saturation_level}")
            logger.info(f"{sum(accept_etas)} of {len(gtable)}"
                        " reflections be are greater than " +
                        f" {np.degrees(ext_eta_tol)} from the rotation axis")

            # valid reflections index
            if refit_idx is None:
                idx = np.logical_and(
                    valid_reflections,
                    np.logical_and(not_saturated, accept_etas)
                )
                idx_0[det_key].append(idx)
            else:
                idx = refit_idx[det_key][ig]
                idx_0[det_key].append(idx)
                logger.info(f"input reflection specify {sum(idx)} of "
                            f"{len(gtable)} total valid reflections")

            hkls[det_key].append(gtable[idx, 2:5])
            # reshape(-1, 1) is correct for boolean masks and for integer
            # index arrays alike (sum(idx) is only right for masks)
            meas_omes = gtable[idx, 12].reshape(-1, 1)
            xyo_det[det_key].append(np.hstack([gtable[idx, -2:], meas_omes]))

    return hkls, xyo_det, idx_0