import logging
import os
import numpy as np
from scipy.optimize import leastsq, least_squares
from hexrd import constants as cnst
from hexrd import matrixutil as mutil
from hexrd.transforms import xfcapi
from .. import grains as grainutil
# NOTE(review): getLogger() with no name returns the ROOT logger, so the
# setLevel call below changes the level for the whole process on import.
logger = logging.getLogger()
logger.setLevel('INFO')

# grains
# Default refinement flags for ONE grain's 12 parameters, in the order used by
# generate_parameter_names: xi[0:3], tvec_c[0:3], vinv_s[0:6].
# Refined: all of xi, tvec_c[0] and tvec_c[2].  Held fixed: tvec_c[1] and
# all six vinv_s components.
grain_flags_DFLT = np.array(
    [True] * 3 + [True, False, True] + [False] * 6,
    dtype=bool,
)

# Angular tolerance (radians) around eta = +/-90 deg used when screening
# reflections in parse_reflection_tables.
# For HEDM calibration this may become a user parameter.
ext_eta_tol = np.radians(5.)
def calibrate_instrument_from_sx(
        instr, grain_params, bmat, xyo_det, hkls_idx,
        param_flags=None, grain_flags=None,
        ome_period=None,
        xtol=cnst.sqrt_epsf, ftol=cnst.sqrt_epsf,
        factor=10., sim_only=False, use_robust_lsq=False):
    """
    Refine instrument (and optionally grain) parameters against measured
    single-crystal spot data.

    Parameters
    ----------
    instr : instrument object
        Supplies ``calibration_flags``, ``calibration_parameters``,
        ``detectors`` and ``update_from_parameter_list``.  It is updated in
        place with the refined parameters.
    grain_params : array_like
        Grain parameter vector(s), 12 per grain.  Promoted to 2-d, one row
        per grain.
    bmat : array_like
        Reciprocal lattice matrix, forwarded to ``sxcal_obj_func``.
    xyo_det : dict
        Per-panel (keyed like ``instr.detectors``) lists, one entry per grain,
        of measured ``(x, y, omega)`` rows.  NOTE: when ``ome_period`` is
        given, the omega column of these arrays is remapped IN PLACE.
    hkls_idx : dict
        Per-panel lists, one entry per grain, of the hkls matching
        ``xyo_det``.
    param_flags : array_like of bool, optional
        Instrument refinement flags.  Defaults to ``instr.calibration_flags``.
        When given, they are also assigned to ``instr.calibration_flags``.
    grain_flags : array_like of bool, optional
        Grain refinement flags, 12 per grain.  Defaults to
        ``grain_flags_DFLT`` tiled over all grains.
    ome_period : array_like, optional
        Omega period used to remap measured omegas (and forwarded to the
        objective).
    xtol, ftol : float
        Convergence tolerances forwarded to the solver.
    factor : float
        Initial step-bound factor for ``scipy.optimize.leastsq``.  It is
        unused by the robust solver.
    sim_only : bool
        If True, skip the fit and return the simulation at the input
        parameters.
    use_robust_lsq : bool
        If True, use ``scipy.optimize.least_squares`` with a ``soft_l1``
        loss.  Otherwise use ``scipy.optimize.leastsq``.

    Returns
    -------
    dict
        If ``sim_only``: per-panel lists of simulated ``(x, y, omega)``
        arrays.
    tuple
        Otherwise ``(fit_params, resd, sim_final)``: the full refined
        parameter vector, the final residual vector, and the simulation at
        the refined parameters.

    Raises
    ------
    RuntimeError
        If the solver does not report success.
    """
    grain_params = np.atleast_2d(grain_params)
    ngrains = len(grain_params)
    pnames = generate_parameter_names(instr, grain_params)

    # reset parameter flags for instrument as specified
    if param_flags is None:
        param_flags = instr.calibration_flags
    else:
        # will throw an AssertionError if wrong length
        instr.calibration_flags = param_flags

    # re-map omegas if need be (modifies the caller's xyo_det in place)
    if ome_period is not None:
        for det_key in instr.detectors:
            for ig in range(ngrains):
                xyo_det[det_key][ig][:, 2] = xfcapi.mapAngle(
                    xyo_det[det_key][ig][:, 2],
                    ome_period
                )

    # Full parameter vector: the instrument parameters (7 global,
    # 6 per panel, plus any distortion parameters), followed by
    # 12 parameters per grain.
    if grain_flags is None:
        grain_flags = np.tile(grain_flags_DFLT, ngrains)
    plist_full = np.concatenate(
        [instr.calibration_parameters, np.hstack(grain_params)]
    )
    plist_init = np.copy(plist_full)  # snapshot for the before/after report

    # Force a boolean mask.  If the flags are integer-typed (e.g. lists of
    # 0/1), hstack yields ints, and the indexing below would silently become
    # integer (fancy) indexing.  sxcal_obj_func builds its mask the same way.
    refine_flags = np.asarray(
        np.hstack([param_flags, grain_flags]), dtype=bool
    )
    plist_fit = plist_full[refine_flags]

    # extra positional arguments of sxcal_obj_func after plist_fit
    fit_args = (plist_full,
                param_flags, grain_flags,
                instr, xyo_det, hkls_idx,
                bmat, ome_period)

    if sim_only:
        return sxcal_obj_func(plist_fit, *fit_args, sim_only=True)

    logger.info("Set up to refine:")
    for i in np.flatnonzero(refine_flags):
        logger.info("\t%s = %1.7e", pnames[i], plist_full[i])

    # run optimization
    if use_robust_lsq:
        result = least_squares(
            sxcal_obj_func, plist_fit, args=fit_args,
            xtol=xtol, ftol=ftol,
            loss='soft_l1', method='trf'
        )
        x = result.x
        resd = result.fun
        mesg = result.message
        ierr = result.status
    else:
        x, _cov_x, infodict, mesg, ierr = leastsq(
            sxcal_obj_func, plist_fit, args=fit_args,
            factor=factor, xtol=xtol, ftol=ftol,
            full_output=1
        )
        resd = infodict['fvec']

    # both scipy solvers document status codes 1-4 as success
    if ierr not in (1, 2, 3, 4):
        raise RuntimeError(f"solution not found: {ierr=}")
    logger.info(f"optimization finished successfully with {ierr=}")
    logger.info(mesg)

    # plist_full is overwritten by every objective evaluation, so build the
    # result explicitly from the solution x rather than aliasing that buffer
    fit_params = np.copy(plist_full)
    fit_params[refine_flags] = x

    # run simulation with optimized results
    sim_final = sxcal_obj_func(x, *fit_args, sim_only=True)

    # leave the instrument in the refined state
    instr.update_from_parameter_list(fit_params)

    # report final
    logger.info("Optimization Results:")
    for i in np.flatnonzero(refine_flags):
        logger.info("\t%s = %1.7e --> %1.7e",
                    pnames[i], plist_init[i], fit_params[i])

    return fit_params, resd, sim_final
def generate_parameter_names(instr, grain_params):
    """
    Build display labels for every entry of the calibration parameter vector.

    Labels are right-justified to 24 characters and ordered as follows:
    7 global instrument entries; per detector, 3 tilt and 3 tvec entries,
    plus one entry per distortion parameter if the panel has distortion;
    then 12 entries per grain (xi[0:3], tvec_c[0:3], vinv_s[0:6]).
    NOTE(review): callers index these labels by position in the full
    parameter vector, so this order must match instr.calibration_parameters.

    Parameters
    ----------
    instr : instrument object
        Only ``instr.detectors`` (name -> panel) and each panel's
        ``distortion`` (``None`` or an object with ``params``) are read.
    grain_params : array_like
        Only the number of rows after ``np.atleast_2d`` is used.

    Returns
    -------
    list of str
    """
    def _label(text):
        # fixed width so labels line up in the log output
        return '{:>24s}'.format(text)

    global_keys = ['beam energy', 'beam azimuth', 'beam polar', 'chi']
    global_keys += ['tvec_s[%d]' % k for k in range(3)]
    names = [_label(key) for key in global_keys]

    for name, det in instr.detectors.items():
        names.extend(_label('%s tilt[%d]' % (name, k)) for k in range(3))
        names.extend(_label('%s tvec[%d]' % (name, k)) for k in range(3))
        if det.distortion is not None:
            n_dist = len(det.distortion.params)
            names.extend(
                _label('%s dparam[%d]' % (name, k)) for k in range(n_dist)
            )

    n_grains = len(np.atleast_2d(grain_params))
    for g in range(n_grains):
        names.extend(_label('grain %d xi[%d]' % (g, k)) for k in range(3))
        names.extend(_label('grain %d tvec_c[%d]' % (g, k)) for k in range(3))
        names.extend(_label('grain %d vinv_s[%d]' % (g, k)) for k in range(6))
    return names
def sxcal_obj_func(plist_fit, plist_full,
                   param_flags, grain_flags,
                   instr, xyo_det, hkls_idx,
                   bmat, ome_period,
                   sim_only=False, return_value_flag=None):
    """
    Objective function for single-crystal instrument calibration.

    Scatters the refined subset ``plist_fit`` into ``plist_full`` (in place),
    updates ``instr`` from it, then predicts detector positions and omegas
    for every panel/grain reflection.

    Parameters
    ----------
    plist_fit : array_like
        Current values of the parameters being refined.
    plist_full : numpy.ndarray
        Full parameter vector: instrument parameters followed by 12 values
        per grain.  MODIFIED IN PLACE.
    param_flags, grain_flags : array_like
        Instrument and grain refinement flags.  They are stacked and coerced
        to a boolean mask over ``plist_full``.
    instr : instrument object
        Updated in place via ``update_from_parameter_list``.
    xyo_det, hkls_idx : dict
        Per-panel lists (one entry per grain) of measured ``(x, y, omega)``
        rows and the matching hkls.
    bmat : array_like
        Reciprocal lattice matrix; ``bmat @ hkls.T`` gives crystal-frame
        G-vectors.
    ome_period : array_like or None
        Forwarded to ``grainutil.matchOmegas`` as ``omePeriod``.
    sim_only : bool
        If True, return the simulated data instead of residuals.
    return_value_flag : {None, 1, 2}
        Optional reduction of the residual vector:
        1 -> sum of absolute residuals;
        2 -> sum of squared residuals divided by
        ``(npts - len(plist_fit) - 1)``, or by 1 if that is zero.

    Returns
    -------
    dict
        If ``sim_only``: per-panel lists of ``(n, 3)`` arrays of calculated
        ``(x, y, omega)``.
    numpy.ndarray or float
        Otherwise the flattened ``(dx, dy, domega)`` residuals (calculated
        minus measured), or their reduction per ``return_value_flag``.

    Raises
    ------
    RuntimeError
        If the grain part of ``plist_full`` is not a multiple of 12 long.
    """
    npi = len(instr.calibration_parameters)
    NP_GRN = 12

    # stack flags and force bool repr
    refine_flags = np.array(
        np.hstack([param_flags, grain_flags]),
        dtype=bool)

    # fill out full parameter list (no scaling for now) and push it into
    # the instrument
    plist_full[refine_flags] = plist_fit
    instr.update_from_parameter_list(plist_full)

    wavelength = instr.beam_wavelength
    bvec = instr.beam_vector
    chi = instr.chi
    tvec_s = instr.tvec

    # Grain parameters are appended after the instrument parameters and are
    # assumed to all be the same length (12 per grain).
    grain_params = plist_full[npi:]
    if np.mod(len(grain_params), NP_GRN) != 0:
        raise RuntimeError("parameter list length is not consistent")
    ngrains = len(grain_params) // NP_GRN
    grain_params = grain_params.reshape((ngrains, NP_GRN))

    xy_unwarped = {}
    meas_omes = {}
    calc_omes = {}
    calc_xy = {}
    npts_tot = 0
    for det_key, panel in instr.detectors.items():
        rmat_d = panel.rmat
        tvec_d = panel.tvec
        xy_unwarped[det_key] = []
        meas_omes[det_key] = []
        calc_omes[det_key] = []
        calc_xy[det_key] = []
        for ig, grain in enumerate(grain_params):
            ghkls = hkls_idx[det_key][ig]
            xyo = xyo_det[det_key][ig]
            npts_tot += len(xyo)

            xy_meas = xyo[:, :2]
            if panel.distortion is not None:  # do unwarping
                xy_meas = panel.distortion.apply(xy_meas)
            xy_unwarped[det_key].append(xy_meas)
            meas_omes[det_key].append(xyo[:, 2])

            # transform G-vectors:
            # 1) convert inv. stretch tensor from MV notation in to 3x3
            # 2) take reciprocal lattice vectors from CRYSTAL to SAMPLE frame
            # 3) apply stretch tensor
            # 4) normalize reciprocal lattice vectors in SAMPLE frame
            # 5) transform unit reciprocal lattice vectors back to CRYSTAL
            #    frame
            rmat_c = xfcapi.makeRotMatOfExpMap(grain[:3])
            tvec_c = grain[3:6]
            vinv_s = grain[6:]
            gvec_c = np.dot(bmat, ghkls.T)
            vmat_s = mutil.vecMVToSymm(vinv_s)
            ghat_s = mutil.unitVector(np.dot(vmat_s, np.dot(rmat_c, gvec_c)))
            ghat_c = np.dot(rmat_c.T, ghat_s)

            _, calc_omes_tmp = grainutil.matchOmegas(
                xyo, ghkls.T,
                chi, rmat_c, bmat, wavelength,
                vInv=vinv_s,
                beamVec=bvec,
                omePeriod=ome_period)

            rmat_s_arr = xfcapi.make_sample_rmat(
                chi, np.ascontiguousarray(calc_omes_tmp)
            )
            calc_xy_tmp = xfcapi.gvec_to_xy(
                ghat_c.T, rmat_d, rmat_s_arr, rmat_c,
                tvec_d, tvec_s, tvec_c
            )
            if np.any(np.isnan(calc_xy_tmp)):
                logger.warning("infeasible parameters: may want to scale back "
                               "finite difference step size")

            calc_omes[det_key].append(calc_omes_tmp)
            calc_xy[det_key].append(calc_xy_tmp)

    if sim_only:
        # one (n, 3) array of calculated (x, y, omega) per panel and grain
        return {
            det_key: [
                np.vstack(
                    [calc_xy[det_key][ig].T, calc_omes[det_key][ig]]
                ).T
                for ig in range(ngrains)
            ]
            for det_key in calc_xy
        }

    # all four dicts were filled in the same panel order, so rows line up
    meas_xy_all = np.vstack([np.vstack(v) for v in xy_unwarped.values()])
    calc_xy_all = np.vstack([np.vstack(v) for v in calc_xy.values()])
    meas_omes_all = np.hstack([np.hstack(v) for v in meas_omes.values()])
    calc_omes_all = np.hstack([np.hstack(v) for v in calc_omes.values()])

    diff_vecs_xy = calc_xy_all - meas_xy_all
    diff_ome = xfcapi.angularDifference(calc_omes_all, meas_omes_all)
    retval = np.hstack(
        [diff_vecs_xy,
         diff_ome.reshape(npts_tot, 1)]
    ).flatten()

    if return_value_flag == 1:
        retval = np.sum(np.abs(retval))
    elif return_value_flag == 2:
        # normalize by (points - refined parameters - 1); fall back to 1
        # rather than dividing by zero
        denom = npts_tot - len(plist_fit) - 1.
        nu_fac = 1. / denom if denom != 0 else 1.
        retval = nu_fac * np.sum(retval**2)
    return retval
def parse_reflection_tables(cfg, instr, grain_ids, refit_idx=None):
    """
    Make spot dictionaries from the per-panel, per-grain spot tables.

    For every panel in ``instr.detectors`` and every id in ``grain_ids``,
    the table ``<cfg.analysis_dir>/<det_key>/spots_<grain_id:05d>.out`` is
    loaded and a subset of its rows is selected.

    Without ``refit_idx``, a row is kept when all three hold:
    - column 0 >= 0 (indexed);
    - column 6 < ``panel.saturation_level``;
    - column 8 (an eta angle) is more than ``ext_eta_tol`` away from BOTH
      +90 and -90 degrees.

    With ``refit_idx``, the caller-supplied per-panel, per-grain index is
    used instead.

    Parameters
    ----------
    cfg : config object
        Only ``cfg.analysis_dir`` is read.
    instr : instrument object
        Only ``instr.detectors`` (name -> panel with ``saturation_level``)
        is read.
    grain_ids : iterable of int
        Grain ids used to build the spot-table filenames.
    refit_idx : dict, optional
        Per-panel lists (one entry per grain) of row selectors (boolean masks
        or integer indices) overriding the default filters.

    Returns
    -------
    hkls, xyo_det, idx_0 : dict
        Keyed by panel name, each holding one entry per grain:
        - the selected hkls (columns 2:5);
        - the selected ``(x, y, omega)`` rows (the last two columns plus
          column 12);
        - the row selector that was used.
    """
    hkls = {}
    xyo_det = {}
    idx_0 = {}
    for det_key, panel in instr.detectors.items():
        hkls[det_key] = []
        xyo_det[det_key] = []
        idx_0[det_key] = []
        for ig, grain_id in enumerate(grain_ids):
            spots_filename = os.path.join(
                cfg.analysis_dir, det_key, 'spots_%05d.out' % grain_id
            )

            # load pull_spots output table
            gtable = np.loadtxt(spots_filename, ndmin=2)
            if len(gtable) == 0:
                # placeholder row of NaNs: every comparison below is False,
                # so nothing is selected (assumes 17-column tables)
                gtable = np.nan*np.ones((1, 17))

            # apply conditions for accepting valid data
            valid_reflections = gtable[:, 0] >= 0  # is indexed
            not_saturated = gtable[:, 6] < panel.saturation_level

            # Throw away extreme etas: keep a reflection only if it is far
            # from BOTH +90 and -90 degrees.  Combining the two tests with OR
            # would accept every eta whenever ext_eta_tol < 90 deg.
            p90 = xfcapi.angularDifference(gtable[:, 8], cnst.piby2)
            m90 = xfcapi.angularDifference(gtable[:, 8], -cnst.piby2)
            accept_etas = np.logical_and(p90 > ext_eta_tol,
                                         m90 > ext_eta_tol)

            n_valid = np.count_nonzero(valid_reflections)
            n_valid_unsat = np.count_nonzero(
                np.logical_and(valid_reflections, not_saturated)
            )
            logger.info(f"panel '{det_key}', grain {grain_id}")
            logger.info(f"{n_valid} of {len(gtable)} "
                        "reflections are indexed")
            logger.info(f"{n_valid_unsat} of {n_valid}"
                        " valid reflections are below"
                        f" saturation threshold of {panel.saturation_level}")
            logger.info(f"{np.count_nonzero(accept_etas)} of {len(gtable)}"
                        " reflections are greater than"
                        f" {np.degrees(ext_eta_tol)} from the rotation axis")

            # valid reflections index
            if refit_idx is None:
                idx = np.logical_and(
                    valid_reflections,
                    np.logical_and(not_saturated, accept_etas)
                )
            else:
                idx = refit_idx[det_key][ig]
                logger.info(f"input reflection specify {sum(idx)} of "
                            f"{len(gtable)} total valid reflections")
            idx_0[det_key].append(idx)

            hkls[det_key].append(gtable[idx, 2:5])
            # reshape(-1, 1) works for boolean masks and integer indices
            meas_omes = gtable[idx, 12].reshape(-1, 1)
            xyo_det[det_key].append(np.hstack([gtable[idx, -2:], meas_omes]))
    return hkls, xyo_det, idx_0