Source code for matador.magres.referencer

# coding: utf-8
# Distributed under the terms of the MIT license.

from typing import Dict, List, Optional, Tuple
import warnings

import numpy as np
import statsmodels.api as sm

from matador.plotting.plotting import plotting_function
from matador.crystal import Crystal

__all__ = ("MagresReferencer",)

[docs]class MagresReferencer: """Class for referencing computed NMR chemical shielding tensors with experimental data on chemical shifts, and related plotting. Attributes: fit_model: The underlying statsmodel model. fit_results: The results of the statsmodel fit. fit_intercept: The intercept of the fit. fit_gradient: The gradient of the fit. fit_rsquared: The R^2 value of the fit. structures_exp: A dictionary of experimental structures, keyed by label shifts_exp: A dictionary of measured shifts with the same keys as `structures_exp`. species: The species of interest. structures: An optional list of theoretical structures with computed shieldings to be referenced. warn_tolerance: The maximum deviation from the ideal "-1" gradient above which to warn the user during the fit. """ fit_model: sm.regression.linear_model.WLS = None fit_results: sm.regression.linear_model.RegressionResults = None fit_intercept: float fit_gradient: float fit_rsquared: float structures_exp: Dict[str, Crystal] shifts_exp: List[Dict[str, List[float]]] species: str structures: Optional[List[Crystal]] warn_tolerance: float = 0.1 def __init__( self, structures_exp: Dict[str, Crystal], shifts_exp: List[Dict[str, List[float]]], species: str, structures: Optional[List[Crystal]] = None, warn_tolerance: float = 0.1, ): self.structures_exp = structures_exp self.shifts_exp = shifts_exp self.species = species self.structures = structures self.warn_tolerance = warn_tolerance self._calc_shifts = [] self._expt_shifts = [] self._fit_weights = [] self._fit_structures = [] self._fitted = False for formula in self.shifts_exp: if self.species not in self.shifts_exp[formula]: continue if formula not in self.structures_exp: warnings.warn(f"Missing {formula} in reference structures.") self.match_exp_structure_shifts( self.structures_exp[formula], self.shifts_exp[formula][self.species] ) self.print_fit_summary() if self.structures is not None: self.set_shifts_from_fit(self.structures)
[docs] def match_exp_structure_shifts(self, structure, shifts): """For a model structure and a set of experimental shifts, match each site in the structure to the closest shift value, allowing for multiplicity. Sets the weights to be used for each site based on the number of unique shifts/local environments in the structure. Parameters: structure: The model structure for each exp. shift. shifts: An array of measured chemical shifts. """ relevant_sites = [site for site in structure if site.species == self.species] calc_shifts = sorted( [site["chemical_shielding_iso"] for site in relevant_sites] ) _shifts = shifts if ( len(_shifts) <= len(relevant_sites) and len(relevant_sites) % len(_shifts) == 0 ): multiplier = len(relevant_sites) // len(_shifts) _shifts = [shift for cell in [_shifts] * multiplier for shift in cell] _weights = [1 / multiplier for shift in _shifts] else: raise RuntimeError( f"Incompatible shift sizes: {len(relevant_sites)} (theor.) vs {len(_shifts)} (expt.), " "please provide commensurate cells." ) _shifts = sorted(_shifts, reverse=True) self._calc_shifts.extend(calc_shifts) self._expt_shifts.extend(_shifts) self._fit_weights.extend(_weights) self._fit_structures.extend(len(_shifts) * [structure.formula_tex]) return _shifts, _weights, calc_shifts
[docs] def set_shifts_from_fit(self, structures): """Set the chemical shifts of the given structures to the predicted values. Parameters: structures: A list of structures with computed shieldings. """ for _, struc in enumerate(structures): for _, site in enumerate(struc): if site.species == self.species: site["chemical_shift_iso"] = self.predict( site["chemical_shielding_iso"] )[0]
[docs] def fit(self): """Construct a statsmodels weighted least squares model between experimental shifts and computed shieldings. Sets the following attributes: fit_model: The underlying statsmodel model. fit_results: The output of the statsmodel fit. fit_intercept: The intercept of the fit. fit_gradient: The gradient of the fit. fit_rsquared: The R^2 value of the fit. """ _fit_shifts = sm.add_constant(self._calc_shifts) self.fit_model = sm.regression.linear_model.WLS( self._expt_shifts, _fit_shifts, weights=self._fit_weights ) self.fit_results = self.fit_intercept = self.fit_results.params[0] self.fit_gradient = self.fit_results.params[1] self.fit_rsquared = self.fit_results.rsquared self._fitted = True if abs(self.fit_gradient + 1.0) > self.warn_tolerance: warnings.warn( f"{self.__class__.__name__} fit gradient was {self.fit_gradient:.2f}, " f"outside of tolerated range -1.0 ± {self.warn_tolerance:.2f}" )
[docs] def predict(self, shieldings) -> Tuple[np.array, np.array]: """Compute the predicted chemical shifts (and errors) for given chemical shieldings, based on the linear fit. Returns: A tuple of shift values and associated errors. """ _shieldings = np.asarray(shieldings) return ( self.fit_gradient * _shieldings + self.fit_intercept, self.fit_results.bse[1] * _shieldings + self.fit_results.bse[0], )
[docs] def print_fit_summary(self): """Print the fitted parameters and their errors.""" if self._fitted: print("Performed WLS fit for: δ_expt = m * σ_calc + c") print(f"m = {self.fit_gradient:3.3f} ± {self.fit_results.bse[1]:3.3f}") print(f"c = {self.fit_intercept:3.3f} ± {self.fit_results.bse[0]:3.3f} ppm") print(f"R² = {self.fit_rsquared:3.3f}.") else: raise RuntimeError("Fit has not yet been performed.")
[docs] @plotting_function def plot_fit(self, ax=None, padding=100, figsize=None): """Plot the fit of shifts vs shieldings to experimental data. Parameters: ax (matplotlib.axes.Axes): Axes to plot on. padding (float): Padding to add to each axis limit. figsize (tuple): Figure size. Returns: The axis object used for plotting. """ import matplotlib.pyplot as plt import seaborn as sns if ax is None: _, ax = plt.subplots(figsize=figsize) ax.grid(False) ax.set_xlim( np.min(self._calc_shifts) - padding, np.max(self._calc_shifts) + padding ) ax.set_ylim( np.min(self._expt_shifts) - padding, np.max(self._expt_shifts) + padding ) ax = sns.regplot( y=self._expt_shifts, x=self._calc_shifts, scatter=False, ax=ax, line_kws={"linewidth": 0}, color="grey", truncate=False, ) sns.scatterplot( y=self._expt_shifts, x=self._calc_shifts, hue=self._fit_structures, palette="Dark2", ax=ax, zorder=100, ) ax.plot( np.asarray(ax.get_xlim()), self.fit_gradient * np.asarray(ax.get_xlim()) + self.fit_intercept, zorder=10, lw=1.5, c="grey", ) ax.legend(ncol=len(set(self._fit_structures)) // 5) ax.set_ylabel("$\\delta_\\mathrm{expt}$ (ppm)") ax.set_xlabel("$\\sigma_\\mathrm{calc}$ (ppm)") ax.text( 0.1, 0.1, f"$m = {self.fit_gradient:3.3f}; c = {self.fit_intercept:3.0f}\\,ppm; R^2 = {self.fit_rsquared:3.3f}$", bbox=dict(alpha=0.8, facecolor="w", boxstyle="Round"), transform=ax.transAxes, ) return ax
[docs] @plotting_function def plot_fit_and_predictions(self, ax=None, padding=100): """Make a plot of the fit and predictions for the experimental chemical shifts. Parameters: ax (matplotlib.axes.Axes): Optional axis object to plot on. padding (float): Padding to add to each axis limit. """ import matplotlib.pyplot as plt if ax is None: _, ax = plt.subplots() self.plot_fit(ax=ax, padding=padding) for doc in self.structures: _calc_shifts = [ site["chemical_shielding_iso"] for site in doc if site.species == self.species ] _predicted_shifts, _predicted_errs = self.predict(_calc_shifts) ax.scatter(_predicted_shifts, _calc_shifts, s=5, c="k") ax.errorbar( _predicted_shifts, _calc_shifts, fmt="None", xerr=_predicted_errs, lw=0.5, c="k", ) return ax