""" This file implements turning matador Crystal objects
into CrystalGraph objects.
"""
import networkx as nx
import numpy as np
import itertools
# Tolerance on the norm of an integer periodic-image translation below which
# the translation is treated as the zero vector (i.e. the home unit cell).
EPS = 1e-12
class CrystalGraph(nx.MultiDiGraph):
    """ networkx.MultiDiGraph of atoms (nodes) and bonds (directed edges).

    Nodes carry a ``species`` attribute; edges carry ``dist`` (bond length)
    and ``image`` (whether the bond connects to a periodic image).

    NOTE: the constructor refuses to run without ``structure`` or ``graph``,
    so networkx helpers that build new graphs via ``self.__class__()``
    (e.g. ``subgraph`` views in networkx >= 2.2) cannot be used on it
    directly; the methods below avoid them.
    """

    def __init__(self, structure=None, graph=None, coordination_cutoff=1.1, bond_tolerance=1e20, num_images=1,
                 debug=False, separate_images=False, delete_one_way_bonds=False, max_bond_length=5):
        """ Create networkx.MultiDiGraph object with extra functionality for atomic networks.

        Keyword Arguments:
            structure (matador.Crystal): crystal structure to network-ify;
                takes precedence over ``graph`` if both are given
            graph (nx.MultiDiGraph): initialise from graph, copying node
                ``species``/``image`` and edge ``dist``/``image`` data
            coordination_cutoff (float) : max multiplier of first
                coordination sphere for edge drawing: an edge i -> j is only
                drawn if d_ij <= coordination_cutoff * (shortest distance
                from atom i to any other atom)
            bond_tolerance (float): an edge i -> j is only drawn if d_ij is at
                most this multiple of the shortest distance found between the
                species of i and j
            num_images (int): number of periodic images to include in
                each direction
            debug (bool): print the per-atom and per-species-pair minimum
                distances, and forward the flag to the distance calculation
            separate_images (bool): whether or not to include image
                atoms as new nodes
            delete_one_way_bonds (bool): when initialising from ``graph``,
                skip edges that have no reverse edge in ``graph``
            max_bond_length (float): cutoff passed on to
                calc_pairwise_distances_pbc (presumably in the same length
                units as the structure positions -- confirm)

        Raises:
            RuntimeError: if neither structure nor graph is provided.
        """
        super().__init__()
        if graph is None and structure is None:
            raise RuntimeError('No structure or graph to initialise network from.')
        if structure is not None:
            self._add_bonds_from_structure(structure, coordination_cutoff, bond_tolerance, num_images,
                                           debug, separate_images, max_bond_length)
        else:
            self._copy_from_graph(graph, delete_one_way_bonds)

    @staticmethod
    def _iter_pair_distances(distances, images, num_atoms):
        """ Yield (image_index, i, j, dist) for every unmasked entry of ``distances``.

        Each flat index is decoded as image_index * num_atoms**2 + i * num_atoms + j;
        an atom paired with itself in the home cell is skipped.
        """
        num_pairs = num_atoms ** 2
        # getmaskarray always returns a full boolean mask, even when nothing is masked
        for index in np.where(~np.ma.getmaskarray(distances))[0]:
            image_index, pair_index = divmod(int(index), num_pairs)
            i, j = divmod(pair_index, num_atoms)
            if i == j and np.linalg.norm(images[image_index]) <= EPS:
                continue
            yield image_index, i, j, distances[index]

    def _add_bonds_from_structure(self, structure, coordination_cutoff, bond_tolerance, num_images,
                                  debug, separate_images, max_bond_length):
        """ Add one node per site of ``structure`` and an edge per bonded pair. """
        from matador.utils.cell_utils import calc_pairwise_distances_pbc
        atoms = structure.sites
        num_atoms = len(atoms)
        # integer lattice translations, each component in [-num_images, num_images]
        images = list(itertools.product(range(-num_images, num_images + 1), repeat=3))
        distances = calc_pairwise_distances_pbc(structure.positions_abs,
                                                images,
                                                structure.lattice_cart,
                                                max_bond_length,
                                                compress=False,
                                                debug=debug)
        for i, atom in enumerate(atoms):
            self.add_node(i, species=atom.species)

        candidates = list(self._iter_pair_distances(distances, images, num_atoms))

        # first pass: shortest distance per (sorted) species pair and per atom
        element_bonds = {}
        min_dists = [1e20] * num_atoms
        for _, i, j, dist in candidates:
            pair_key = tuple(sorted([atoms[i].species, atoms[j].species]))
            if pair_key not in element_bonds or element_bonds[pair_key] > dist:
                element_bonds[pair_key] = dist
            if dist < min_dists[i]:
                min_dists[i] = dist
        if debug:
            print(min_dists)
            print(element_bonds)

        # second pass: draw an edge for every pair within both tolerances
        image_number = 0
        for image_index, i, j, dist in candidates:
            pair_key = tuple(sorted([atoms[i].species, atoms[j].species]))
            within_tolerance = (dist <= min_dists[i] * coordination_cutoff
                                and dist <= element_bonds[pair_key] * bond_tolerance)
            if not within_tolerance:
                continue
            is_image = np.linalg.norm(images[image_index]) > EPS
            if separate_images and all(val <= 1e-8 for val in images[image_index]):
                # NOTE(review): this condition also matches the (0, 0, 0) home
                # cell and excludes images with any positive component -- confirm.
                # Image atoms get fresh ids after the real atoms, so they can
                # never collide with (and overwrite) another atom's node.
                image_node = num_atoms + image_number
                image_number += 1
                self.add_node(image_node, species=atoms[j].species, image=is_image)
                self.add_edge(i, image_node, dist=dist, image=is_image)
                self.add_edge(image_node, i, dist=dist, image=is_image)
            else:
                self.add_edge(i, j, dist=dist, image=is_image)

    def _copy_from_graph(self, graph, delete_one_way_bonds):
        """ Copy nodes and edges, with their CrystalGraph attributes, from ``graph``. """
        for node, data in graph.nodes.data():
            self.add_node(node, species=data['species'], image=data.get('image', False))
        for node_in, node_out, data in graph.edges.data():
            if not delete_one_way_bonds or (node_out, node_in) in graph.edges():
                self.add_edge(node_in, node_out, dist=data.get('dist', 0), image=data.get('image', False))

    def get_strongly_connected_component_subgraphs(self, delete_one_way_bonds=True):
        """ Return generator of strongly-connected subgraphs in CrystalGraph format.

        Keyword Arguments:
            delete_one_way_bonds (bool): drop edges without a reverse edge
                inside each component.
        """
        # Take views from a plain MultiDiGraph copy: self.subgraph() would
        # instantiate CrystalGraph() without arguments, which raises.
        plain_graph = nx.MultiDiGraph(self)
        return (CrystalGraph(graph=plain_graph.subgraph(component),
                             delete_one_way_bonds=delete_one_way_bonds)
                for component in nx.strongly_connected_components(plain_graph))

    def get_communities(self, graph=None, **louvain_kwargs):
        """ Return list of community subgraphs in CrystalGraph format.

        Communities are found with python-louvain's ``best_partition`` on an
        undirected version of ``graph``.

        Keyword Arguments:
            graph (networkx graph): graph to partition; defaults to self.
            **louvain_kwargs: passed on to ``community.best_partition``.

        Returns:
            (list(CrystalGraph), dict): one CrystalGraph per community (list
                position = community number; assumes best_partition numbers
                communities 0..k-1) with edges carrying dist=0/image=False,
                and the node -> community number mapping.
        """
        import community as louvain
        if graph is None:
            graph = self
        if graph.is_directed():
            undirected_graph = self.remove_directionality(graph=graph)
        else:
            # an undirected graph can be partitioned as-is
            undirected_graph = graph
        partition = louvain.best_partition(undirected_graph, **louvain_kwargs)
        num_communities = len(set(partition.values()))
        subgraphs = [nx.MultiGraph() for _ in range(num_communities)]
        # read nodes/edges from ``graph`` (not self) so a passed-in graph is
        # consistent with the partition computed from it
        for node, community_index in partition.items():
            subgraphs[community_index].add_node(node, species=graph.nodes[node]['species'])
        for node_in, node_out in graph.edges():
            if partition[node_in] == partition[node_out]:
                subgraphs[partition[node_in]].add_edge(node_in, node_out)
        return [CrystalGraph(graph=subgraph) for subgraph in subgraphs], partition

    def remove_directionality(self, graph=None):
        """ Return an undirected nx.MultiGraph copy of ``graph`` (default: self).

        Nodes keep their ``species``; a directed edge is skipped when its
        endpoints are already joined in the undirected copy, otherwise it is
        added with its ``dist`` and ``image`` data.
        """
        if graph is None:
            graph = self
        undirected_graph = nx.MultiGraph()
        for node, data in graph.nodes(data=True):
            undirected_graph.add_node(node, species=data['species'])
        for node_in, node_out, data in graph.edges(data=True):
            if (node_out, node_in) not in undirected_graph.edges():
                undirected_graph.add_edge(node_in, node_out,
                                          dist=data.get('dist', 0), image=data.get('image', False))
        return undirected_graph

    def set_unique_subgraphs(self, method='community'):
        """ Filter subgraphs for isomorphism with others inside CrystalGraph.

        Sets self.unique_subgraphs to a set of mutually non-isomorphic subgraphs.

        Keyword Arguments:
            method (str): 'community' (Louvain communities), 'strongly_connected'
                (strongly connected components) or 'both' (communities of each
                strongly connected component).

        Raises:
            ValueError: if method is not one of the above.
        """
        if method == 'community':
            subgraphs = self.get_communities()[0]
        elif method == 'strongly_connected':
            subgraphs = self.get_strongly_connected_component_subgraphs()
        elif method == 'both':
            subgraphs = []
            for strong_subgraph in self.get_strongly_connected_component_subgraphs():
                # get_communities returns (subgraphs, partition); keep the subgraphs
                subgraphs.extend(strong_subgraph.get_communities()[0])
        else:
            raise ValueError('Unknown subgraph method {!r}: expected '
                             "'community', 'strongly_connected' or 'both'.".format(method))
        self.unique_subgraphs = get_unique_subgraphs(subgraphs)

    def get_bonds_per_atom(self):
        """ Return the number of mutual bonds per atom.

        Every ordered pair of distinct nodes joined by edges in both directions
        is counted (so each mutual bond counts once from each end), and the
        total is divided by the number of nodes. Returns 0.0 for an empty graph.
        """
        num_nodes = self.number_of_nodes()
        if num_nodes == 0:
            return 0.0
        num_bonds = sum(
            1
            for node_in in self.nodes()
            for node_out in self.successors(node_in)
            if node_out != node_in and self.has_edge(node_out, node_in)
        )
        return num_bonds / num_nodes
def node_match(n1, n2):
    """ Node matcher for isomorphism checks: True when both node attribute
    dicts carry the same ``species``. """
    species_1, species_2 = n1['species'], n2['species']
    return species_1 == species_2
def get_unique_subgraphs(subgraphs):
    """ Filter subgraphs, keeping one representative per isomorphism class.

    Input:
        | subgraphs: iterable of CrystalGraph, subgraph objects to filter
          (e.g. strongly connected components or communities)

    Returns:
        | unique_subgraphs: set(CrystalGraph), subgraphs that are pairwise
          different according to are_graphs_the_same; the first subgraph of
          each class encountered is the one kept
    """
    unique_subgraphs = set()
    for subgraph in subgraphs:
        # generator lets any() stop at the first match instead of running
        # every (expensive) isomorphism check
        if not any(are_graphs_the_same(subgraph, other_subgraph) for other_subgraph in unique_subgraphs):
            unique_subgraphs.add(subgraph)
    return unique_subgraphs
def are_graphs_the_same(g1, g2, edge_match=None):
    """ Return whether two graphs are isomorphic with matching species and bonds.

    Nodes match when their ``species`` agree (node_match). By default, the
    edge bundles between two matched node pairs match when they hold the same
    number of parallel edges and, after sorting each bundle by (image, dist),
    paired edges have equal ``image`` flags and ``dist`` values within
    0.1 + 0.05 * dist of each other.

    Parameters:
        g1, g2: graphs to compare; the default edge matcher expects
            multigraphs (e.g. CrystalGraph), whose edge data map edge keys to
            attribute dicts. Missing ``image``/``dist`` default to False/0.
        edge_match (callable): optional replacement edge matcher, passed on
            to nx.is_isomorphic.

    Returns:
        bool: True if an isomorphism satisfying both matchers exists.
    """
    if edge_match is None:
        atol = 0.1
        rtol = 0.05

        def _bond_signature(data):
            return (bool(data.get('image', False)), data.get('dist', 0))

        def edge_match(e1, e2):
            # compare every parallel edge, not only key 0
            bonds_1 = sorted(_bond_signature(data) for data in e1.values())
            bonds_2 = sorted(_bond_signature(data) for data in e2.values())
            if len(bonds_1) != len(bonds_2):
                return False
            return all(
                image_1 == image_2 and abs(dist_1 - dist_2) <= atol + rtol * dist_2
                for (image_1, dist_1), (image_2, dist_2) in zip(bonds_1, bonds_2)
            )

    return nx.is_isomorphic(g1, g2, node_match=node_match, edge_match=edge_match)
def draw_network(structure,
                 layout=None, edge_labels=False, node_index=False,
                 curved_edges=True, node_colour='elem', partition=None, ax=None):
    """ Draw an atomic network with networkx and matplotlib.

    Parameters:
        structure: object exposing a ``network`` attribute (e.g. a matador
            Crystal), or the networkx graph itself.

    Keyword Arguments:
        layout (dict): node -> position mapping; computed with
            nx.spring_layout if None.
        edge_labels (bool): label each connected node pair with the number
            of edges between them.
        node_index (bool): append each node's index to its species label.
        curved_edges (bool): accepted for API compatibility; not used.
        node_colour (str): 'degree' colours nodes by degree, 'partition'
            colours by community index (requires ``partition``); any other
            value colours by element.
        partition (dict): node -> integer community index, e.g. as returned
            by CrystalGraph.get_communities; only used with
            node_colour='partition'.
        ax (matplotlib.axes.Axes): axes to draw on; a new figure is created
            if None.
    """
    from matador.utils.viz_utils import get_element_colours
    import matplotlib.pyplot as plt

    element_colours = get_element_colours()
    try:
        network = structure.network
    except AttributeError:
        # a bare graph was passed rather than an object wrapping one
        network = structure

    pos = nx.spring_layout(network) if layout is None else layout
    if ax is None:
        _, ax = plt.subplots()

    if node_colour == 'degree':
        degrees = list(set(dict(network.degree).values()))
        # plt.get_cmap(name, lut) works across matplotlib versions, unlike
        # plt.cm.get_cmap, which was removed in matplotlib 3.9
        cmap = plt.get_cmap('Dark2', len(degrees)).colors
        colours = [cmap[degrees.index(network.degree[node])] for node in network.nodes()]
    elif node_colour == 'partition' and partition is not None:
        num_partitions = len(set(partition.values()))
        cmap = plt.get_cmap('Dark2', num_partitions).colors
        colours = [cmap[partition[node]] for node in network.nodes()]
    else:
        colours = [element_colours.get(data['species']) for _, data in network.nodes.data()]

    if node_index:
        labels = {node: '{} \\#{}'.format(data['species'], node) for node, data in network.nodes.data()}
    else:
        labels = {node: str(data['species']) for node, data in network.nodes.data()}

    # grey for bonds into periodic images, black otherwise; a missing 'image'
    # flag is treated as False, as elsewhere in this module
    edge_colours = ['grey' if data.get('image', False) else 'black'
                    for _, _, data in network.edges(data=True)]

    nx.draw_networkx_nodes(network, pos, node_color=colours, edgecolors='black', linewidths=2, node_size=1000, ax=ax)
    nx.draw_networkx_edges(network, pos, edge_color=edge_colours, width=2, node_size=1000, ax=ax)

    if edge_labels:
        # count edges per unordered node pair, keyed by the orientation seen first
        edge_weight = {}
        for node_in, node_out in network.edges():
            key = (node_out, node_in) if (node_out, node_in) in edge_weight else (node_in, node_out)
            edge_weight[key] = edge_weight.get(key, 0) + 1
        nx.draw_networkx_edge_labels(network, pos, edge_labels=edge_weight, ax=ax)

    nx.draw_networkx_labels(network, pos, labels=labels, ax=ax)
    # act on the axes we drew on, not whichever axes happen to be current
    ax.axis('off')