Source code for matador.plotting.magres_plotting

# coding: utf-8
# Distributed under the terms of the MIT License.

""" This submodule contains functions to plot chemical shifts
and chemical shieldings from NMR calculations.

"""

from typing import List, Optional, Dict, Union, Tuple

import numpy as np
from matador.crystal import Crystal
from matador.plotting.plotting import plotting_function
from matador.fingerprints import Fingerprint

__all__ = ['plot_magres']


@plotting_function
def plot_magres(
    magres: Union[List[Crystal], Crystal],
    species: str,
    magres_key: str = "chemical_shielding_iso",
    xlabel: Optional[str] = None,
    broadening_width: float = 1,
    text_offset: float = 0.1,
    ax=None,
    figsize: Optional[Tuple[float, float]] = None,
    show: bool = False,
    savefig: Optional[str] = None,
    signal_labels: Optional[Union[str, List[str]]] = None,
    signal_limits: Optional[Tuple[float, float]] = None,
    line_kwargs: Optional[Union[Dict, List[Dict]]] = None,
    invert_xaxis: bool = True,
):
    """ Plot magnetic resonance.

    One normalised curve is drawn per structure, each offset vertically
    by 1.1 from the previous one.

    Parameters:
        magres (Union[Crystal, List[Crystal]]): list of :class:`Crystal`
            containing magres data. Plain dictionaries are wrapped in
            :class:`Crystal` before use.
        species (str): the species to plot the shifts of.

    Keyword arguments:
        ax (matplotlib.axes.Axes): an existing axis on which to plot.
        magres_key (str): the data key for which the magres site data is
            stored under.
        show (bool): whether to show plot in an X window.
        figsize (Tuple[float]): overrides the default size for the
            matplotlib figure. Only used when ``ax`` is not provided.
        broadening_width (float): the Lorentzian width to apply to the
            shifts. If not positive, the raw histogram is plotted at the
            bin centres instead.
        text_offset (float): vertical offset (in data units) of each
            curve's text label above that curve's baseline.
        xlabel (str): a custom label for the x-axis.
        savefig (str): filename to use to save the plot.
        signal_labels (list): optional list of labels for the curves in
            the magres list.
        signal_limits (Tuple[float]): values at which to clip the magres
            signals. Defaults to the maximum and minimum shifts across
            all passed structures.
        line_kwargs (list or dict): parameters to pass to the curve
            plotter, if a list then the line kwargs will be passed to
            each line individually.
        invert_xaxis (bool): whether to invert the direction of the
            x-axis after plotting.

    Returns:
        matplotlib.axes.Axes: the axis that was plotted on.

    Raises:
        RuntimeError: if no structures or no species are passed, if the
            number of labels or line kwargs does not match the number of
            structures, if no site of ``species`` exists in any
            structure (and no ``signal_limits`` were given), or if the
            structures report different units.

    """
    import matplotlib.pyplot as plt

    # Normalise scalar arguments into per-structure lists.
    if not isinstance(magres, list):
        magres = [magres]
    if signal_labels is not None and not isinstance(signal_labels, list):
        signal_labels = [signal_labels]
    if line_kwargs is not None and not isinstance(line_kwargs, list):
        line_kwargs = len(magres) * [line_kwargs]

    # Validate everything up front so that a bad call does not leave an
    # empty figure behind.
    if not magres:
        raise RuntimeError("No magres data was passed for plotting.")

    if species is None:
        raise RuntimeError("You must provide a species label for plotting.")

    if signal_labels is not None and len(signal_labels) != len(magres):
        raise RuntimeError(
            "Wrong number of labels passed for number of magres: {} vs {}"
            .format(len(signal_labels), len(magres))
        )

    if line_kwargs is not None and len(line_kwargs) != len(magres):
        raise RuntimeError(
            "Wrong number of line kwargs passed for number of magres: {} vs {}"
            .format(len(line_kwargs), len(magres))
        )

    # Wrap raw dictionaries and collect the per-site values of interest
    # once, so they can be reused for both the limits and the curves.
    _magres = []
    site_values = []
    for doc in magres:
        _doc = Crystal(doc) if isinstance(doc, dict) else doc
        _magres.append(_doc)
        site_values.append(
            [atom[magres_key] for atom in _doc if atom.species == species]
        )

    if signal_limits is not None:
        min_shielding, max_shielding = signal_limits
    else:
        all_values = [value for values in site_values for value in values]
        if not all_values:
            raise RuntimeError(f"No sites of {species} found in any of the passed crystals.")
        min_shielding = np.min(all_values)
        max_shielding = np.max(all_values)

    _buffer = 0.2 * np.abs(min_shielding - max_shielding)
    if _buffer == 0:
        # All values coincide (e.g. a single site): without padding the
        # grid would collapse to 1000 identical points.
        _buffer = max(5 * broadening_width, 1.0)

    s_space = np.linspace(
        min_shielding - _buffer,
        max_shielding + _buffer,
        num=1000
    )
    # x positions used when the unbroadened histogram is plotted.
    bin_centres = s_space[:-1] + (s_space[1] - s_space[0]) / 2

    if ax is None:
        if figsize is None:
            _user_default_figsize = plt.rcParams.get('figure.figsize', (8, 6))
            height = len(magres) * max(0.5, _user_default_figsize[1] / 1.5 / len(magres))
            figsize = (_user_default_figsize[0], height)
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)

    # Repeat the colour cycle so there is at least one colour per curve.
    _padded_colours = list(plt.rcParams["axes.prop_cycle"].by_key()["color"])
    _padded_colours = (1 + (len(magres) // len(_padded_colours))) * _padded_colours

    for ind, (doc, shifts) in enumerate(zip(_magres, site_values)):

        if signal_labels is None:
            _label = doc.formula_tex
        else:
            _label = signal_labels[ind]

        _line_kwargs = dict(line_kwargs[ind]) if line_kwargs is not None else {}
        # 'c' and 'color' are aliases; only supply the default colour if
        # the caller has not chosen one under either name.
        if "c" not in _line_kwargs and "color" not in _line_kwargs:
            _line_kwargs["c"] = _padded_colours[ind]

        # Use a per-curve x array: the shared grid must not be modified,
        # otherwise later curves would be binned on a different grid.
        x_values = s_space

        if not shifts:
            print(f"No sites of {species} found in {doc.root_source}, signal will be empty.")
            signal = np.zeros_like(s_space)
        else:
            hist, _ = np.histogram(shifts, bins=s_space)

            if broadening_width > 0:
                signal = Fingerprint._broadening_unrolled(
                    hist, s_space, broadening_width, broadening_type="lorentzian"
                )
            else:
                signal = np.array(hist, dtype=np.float64)
                x_values = bin_centres

        # Normalise each curve to unit height, flattening numerically
        # empty signals to exactly zero.
        peak = np.max(signal)
        if peak > 1e-10:
            signal = signal / peak
        else:
            signal = np.zeros_like(signal, dtype=np.float64)

        ax.plot(x_values, signal + (ind * 1.1), **_line_kwargs)

        if _label is not None:
            # x is given in axes coordinates, y in data coordinates.
            ax.text(
                0.95,
                (ind * 1.1) + text_offset,
                _label,
                transform=ax.get_yaxis_transform(),
                horizontalalignment='right'
            )

    if xlabel is None:
        unit = {entry.get("magres_units", {}).get("ms", "ppm") for entry in magres}
        if len(unit) > 1:
            raise RuntimeError(f"Multiple incompatible units found for chemical shift: {unit}")
        unit = list(unit)[0]

        # Shieldings are denoted sigma and shifts delta.
        default_labels = {
            "chemical_shielding_iso":
                f"{species}: Isotropic chemical shielding $\\sigma_\\mathrm{{iso}}$ ({unit})",
            "chemical_shift_iso":
                f"{species}: Isotropic chemical shift $\\delta_\\mathrm{{iso}}$ ({unit})",
            "chemical_shift_aniso":
                f"{species}: Anisotropic chemical shift $\\delta_\\mathrm{{aniso}}$ ({unit})",
            "chemical_shift_asymmetry":
                f"{species}: Chemical shift asymmetry, $\\eta$",
        }
        # Unknown keys leave the label unset, as before.
        xlabel = default_labels.get(magres_key)

    ax.set_xlabel(xlabel)
    ax.set_ylabel("Intensity (arb. units)")

    if len(magres) > 1:
        # Stacked, individually normalised curves share no meaningful y scale.
        ax.set_yticks([])
    else:
        ax.set_yticks(np.linspace(0, 1, 5, endpoint=True))

    ax.set_ylim(-0.1, 1.1 * len(magres))

    if invert_xaxis:
        ax.invert_xaxis()

    if savefig:
        plt.savefig(savefig)
        print('Wrote {}'.format(savefig))
    elif show:
        plt.show()

    return ax