# Source code for regraph.backends.networkx.graphs

"""NetworkX-based in-memory graph objects.

This module implements data structures wrapping the `networkx.DiGraph` class.
"""
import itertools
import networkx as nx
from networkx.algorithms import isomorphism

import warnings

from regraph.exceptions import (ReGraphError,
                                GraphError,
                                GraphAttrsWarning,
                                )
from regraph.graphs import Graph

from regraph.utils import (normalize_attrs,
                           safe_deepcopy_dict,
                           valid_attributes,
                           )


class NXGraph(Graph):
    """Wrapper for NetworkX directed graphs.

    Node and edge attributes live in the wrapped ``networkx.DiGraph``
    stored in ``self._graph``. The mutating methods of this class also
    maintain two mirror dictionaries: ``self.node`` (node id -> attribute
    dict) and ``self.adj`` (source id -> target id -> attribute dict).
    """

    node_dict_factory = dict
    adj_dict_factory = dict

    def __init__(self, incoming_graph_data=None, **attr):
        """Initialize an empty NetworkX-backed graph.

        Parameters
        ----------
        incoming_graph_data : optional
            Accepted for signature compatibility only; it is currently
            ignored and the graph always starts empty.
        **attr
            Accepted for signature compatibility only; currently ignored.
        """
        self.node_dict_factory = ndf = self.node_dict_factory
        self.adj_dict_factory = adf = self.adj_dict_factory
        super().__init__()
        self._graph = nx.DiGraph()
        # Mirror dictionaries, kept in sync by the mutating methods below.
        self.node = ndf()
        self.adj = adf()

    def nodes(self, data=False):
        """Return the nodes of the graph.

        Parameters
        ----------
        data : bool, optional
            If True, return a list of ``(node_id, attrs)`` pairs; otherwise
            return the node view of the wrapped ``networkx.DiGraph``.
        """
        if data:
            return [(n, self.get_node(n)) for n in self._graph.nodes()]
        return self._graph.nodes()

    def edges(self, data=False):
        """Return the edges of the graph.

        Parameters
        ----------
        data : bool, optional
            If True, return a list of ``(source, target, attrs)`` triples;
            otherwise return the edge view of the wrapped graph.
        """
        if data:
            return [(s, t, self.get_edge(s, t)) for s, t in self.edges()]
        return self._graph.edges()

    def get_node(self, n):
        """Get node attributes.

        Parameters
        ----------
        n : hashable
            Node id.

        Returns
        -------
        dict
            The attribute dictionary stored in the wrapped graph (not a copy).

        Raises
        ------
        KeyError
            If the node does not exist.
        """
        # `DiGraph.node` was removed in networkx 2.4; `.nodes[n]` works in
        # every 2.x release.
        return self._graph.nodes[n]

    def get_edge(self, s, t):
        """Get edge attributes.

        Parameters
        ----------
        s : hashable
            Source node id.
        t : hashable
            Target node id.

        Returns
        -------
        dict
            The attribute dictionary stored in the wrapped graph (not a copy).
        """
        return self._graph.adj[s][t]

    def add_node(self, node_id, attrs=None):
        """Add a node to the graph.

        Parameters
        ----------
        node_id : hashable
            Id of the new node.
        attrs : dict, optional
            Node attributes; they are deep-copied and normalized with
            ``normalize_attrs`` before being stored.

        Returns
        -------
        hashable
            The id of the added node.

        Raises
        ------
        GraphError
            If a node with this id already exists.
        """
        new_attrs = dict() if attrs is None else safe_deepcopy_dict(attrs)
        normalize_attrs(new_attrs)
        if node_id in self.nodes():
            raise GraphError("Node '{}' already exists!".format(node_id))
        self._graph.add_node(node_id)
        # Update the stored dict directly so that non-string keys are
        # supported (they cannot be passed as **kwargs).
        self._graph.nodes[node_id].update(new_attrs)
        self.node[node_id] = dict(new_attrs)
        return node_id

    def remove_node(self, node_id):
        """Remove a node together with all its incident edges.

        Parameters
        ----------
        node_id : hashable
            Node to remove.

        Raises
        ------
        GraphError
            If the node does not exist.
        """
        if node_id not in self.nodes():
            raise GraphError("Node '{}' does not exist!".format(node_id))
        self._graph.remove_node(node_id)
        self.node.pop(node_id, None)
        # networkx drops incident edges itself; mirror that in `self.adj`
        # by removing both outgoing and incoming entries.
        self.adj.pop(node_id, None)
        for targets in self.adj.values():
            targets.pop(node_id, None)

    def add_edge(self, s, t, attrs=None, **attr):
        """Add an edge to the graph.

        Parameters
        ----------
        s : hashable
            Source node id.
        t : hashable
            Target node id.
        attrs : dict, optional
            Edge attributes. The caller's dictionary is not modified.
        **attr
            Additional edge attributes; they take precedence over `attrs`.

        Raises
        ------
        ReGraphError
            If `attrs` is not a dictionary.
        GraphError
            If either endpoint does not exist or the edge already exists.
        """
        if attrs is None:
            attrs = {}
        try:
            # Work on copies so the caller's dict is never mutated.
            new_attrs = safe_deepcopy_dict(attrs)
            new_attrs.update(safe_deepcopy_dict(attr))
        except AttributeError:
            raise ReGraphError(
                "The attr_dict argument must be a dictionary."
            )
        if s not in self.nodes():
            raise GraphError("Node '{}' does not exist!".format(s))
        if t not in self.nodes():
            raise GraphError("Node '{}' does not exist!".format(t))
        normalize_attrs(new_attrs)
        if self._graph.has_edge(s, t):
            raise GraphError(
                "Edge '{}'->'{}' already exists!".format(s, t))
        self._graph.add_edge(s, t)
        self._graph.adj[s][t].update(new_attrs)
        self.adj.setdefault(s, {})[t] = new_attrs

    def remove_edge(self, s, t):
        """Remove an edge from the graph.

        Parameters
        ----------
        s : hashable
            Source node id.
        t : hashable
            Target node id.

        Raises
        ------
        GraphError
            If the edge does not exist.
        """
        if not self._graph.has_edge(s, t):
            raise GraphError(
                "Edge '{}->{}' does not exist!".format(s, t))
        self._graph.remove_edge(s, t)
        self.adj.get(s, {}).pop(t, None)

    def update_node_attrs(self, node_id, attrs, normalize=True):
        """Replace the attributes of a node.

        Parameters
        ----------
        node_id : hashable
            Node to update.
        attrs : dict
            New attributes; they replace (not merge with) the current ones.
            The caller's dictionary is deep-copied, not stored.
        normalize : bool, optional
            If exactly True, normalize the new attributes with
            ``normalize_attrs``.

        Raises
        ------
        GraphError
            If the node does not exist.

        Warns
        -----
        GraphAttrsWarning
            If `attrs` is None; the node is then left unchanged.
        """
        if node_id not in self.nodes():
            raise GraphError(
                "Node '{}' does not exist!".format(node_id))
        if attrs is None:
            warnings.warn(
                "You want to update '{}' attrs with an empty attrs_dict!".format(
                    node_id), GraphAttrsWarning
            )
            return
        new_attrs = safe_deepcopy_dict(attrs)
        if normalize is True:
            normalize_attrs(new_attrs)
        node_attrs = self._graph.nodes[node_id]
        node_attrs.clear()
        node_attrs.update(new_attrs)
        self.node[node_id] = new_attrs

    def update_edge_attrs(self, s, t, attrs, normalize=True):
        """Replace the attributes of an edge.

        Parameters
        ----------
        s : hashable
            Source node of the edge to update.
        t : hashable
            Target node of the edge to update.
        attrs : dict
            New attributes; they replace (not merge with) the current ones.
            The caller's dictionary is deep-copied, not stored.
        normalize : bool, optional
            If exactly True, normalize the new attributes with
            ``normalize_attrs``.

        Raises
        ------
        GraphError
            If the edge does not exist.

        Warns
        -----
        GraphAttrsWarning
            If `attrs` is None; the edge is then left unchanged.
        """
        if not self._graph.has_edge(s, t):
            raise GraphError("Edge '{}->{}' does not exist!".format(
                s, t))
        if attrs is None:
            # Previously execution continued and crashed on None.keys().
            warnings.warn(
                "You want to update '{}->{}' attrs with an empty attrs_dict".format(
                    s, t), GraphAttrsWarning
            )
            return
        new_attrs = safe_deepcopy_dict(attrs)
        if normalize is True:
            normalize_attrs(new_attrs)
        edge_attrs = self._graph.adj[s][t]
        edge_attrs.clear()
        edge_attrs.update(new_attrs)
        self.adj.setdefault(s, {})[t] = new_attrs

    def successors(self, node_id):
        """Return an iterator over the successors of a node.

        Raises
        ------
        GraphError
            If the node does not exist.
        """
        if node_id not in self.nodes():
            raise GraphError(
                "Node '{}' does not exist in the graph".format(
                    node_id))
        return self._graph.successors(node_id)

    def predecessors(self, node_id):
        """Return an iterator over the predecessors of a node.

        Raises
        ------
        GraphError
            If the node does not exist.
        """
        if node_id not in self.nodes():
            raise GraphError(
                "Node '{}' does not exist in the graph".format(
                    node_id))
        return self._graph.predecessors(node_id)

    def get_relabeled_graph(self, mapping):
        """Return a graph with node labeling specified in the mapping.

        Parameters
        ----------
        mapping : dict
            A dictionary with keys being old node ids and their values
            being new id's of the respective nodes. Nodes that are not keys
            of the mapping, and edges incident to them, are left out.

        Returns
        -------
        g : networkx.DiGraph
            New graph isomorphic to the (mapped part of the) graph, with
            relabeled nodes and shallow copies of the attribute dicts.

        Raises
        ------
        ReGraphError
            If new id's do not define a set of distinct node id's.
        GraphError
            If a key of the mapping is not a node of the graph.

        See also
        --------
        regraph.primitives.relabel_nodes
        """
        new_ids = list(mapping.values())
        if len(set(new_ids)) != len(new_ids):
            raise ReGraphError(
                "Attempt to relabel nodes failed: the IDs are not unique!")

        g = nx.DiGraph()
        for old_node, new_node in mapping.items():
            try:
                node_attrs = self.get_node(old_node)
            except KeyError:
                raise GraphError(
                    "Node '{}' does not exist!".format(old_node))
            g.add_node(new_node)
            g.nodes[new_node].update(node_attrs)

        for s, t in self.edges():
            # Skip edges touching unmapped nodes: those nodes are not in `g`.
            if s in mapping and t in mapping:
                new_s, new_t = mapping[s], mapping[t]
                g.add_edge(new_s, new_t)
                g.adj[new_s][new_t].update(self.get_edge(s, t))
        return g

    def subgraph(self, nodes):
        """Get a subgraph induced by the collection of nodes.

        Parameters
        ----------
        nodes : iterable
            Node ids to keep; ids not present in the graph are ignored.

        Returns
        -------
        NXGraph
            A new graph with copies of the selected nodes, the edges between
            them, and their attributes.
        """
        # Materialize once: O(1) membership and safe for one-shot iterables.
        node_set = set(nodes)
        g = NXGraph()
        g.add_nodes_from([
            (n, attrs) for n, attrs in self.nodes(data=True)
            if n in node_set])
        for s, t, attrs in self.edges(data=True):
            if s in node_set and t in node_set:
                g.add_edge(s, t, attrs)
        return g

    def find_matching(self, pattern, nodes=None,
                      graph_typing=None, pattern_typing=None):
        """Find matching of a pattern in a graph.

        This function takes as an input a pattern and, optionally, a
        collection of nodes specifying the subgraph of the original graph
        where the matching should be searched. It then searches for
        matchings of the pattern inside the graph (or induced subgraph),
        which corresponds to solving the subgraph matching problem.
        The matching is defined by a map from the nodes of the pattern
        to the nodes of the graph such that:

        * edges are preserved, i.e. if there is an edge between nodes `n1`
          and `n2` in the pattern, there is an edge between the nodes of
          the graph that correspond to the image of `n1` and `n2`; moreover,
          the attribute dictionary of the edge between `n1` and `n2` is a
          subdictionary of the edge it corresponds to in the graph;
        * the attribute dictionary of a pattern node is a subdictionary of
          its image in the graph.

        Uses `networkx.isomorphism.DiGraphMatcher`, which implements a
        subgraph matching algorithm.

        In addition, two parameters `graph_typing` and `pattern_typing`
        can be specified. They restrict the space of admissible solutions
        by checking that an isomorphic subgraph found in the input graph
        respects the provided pattern typings according to the specified
        graph typings.

        Parameters
        ----------
        pattern : regraph Graph wrapper or networkx.DiGraph
            Pattern graph to search for. For a wrapper, its `_graph`
            attribute is used.
        nodes : iterable, optional
            Subset of nodes to search for matching.
        graph_typing : dict of dict, optional
            Dictionary defining typing of graph nodes.
        pattern_typing : dict of dict, optional
            Dictionary defining typing of pattern nodes.

        Returns
        -------
        instances : list of dict's
            List of instances of matching found in the graph. Every
            instance is represented with a dictionary where keys are nodes
            of the pattern, and values are corresponding nodes of the graph.

        Raises
        ------
        ReGraphError
            If the pattern typing refers to a typing graph that is absent
            from `graph_typing`.
        """
        if pattern_typing is None:
            pattern_typing = {}
        if graph_typing is None:
            graph_typing = {}

        # Check that graph/pattern typings are consistent.
        for typing_graph in pattern_typing:
            if typing_graph not in graph_typing:
                raise ReGraphError(
                    "Graph is not typed by '{}' from the specified ".format(
                        typing_graph) +
                    "pattern typing")

        pattern_graph = (
            pattern._graph if isinstance(pattern, Graph) else pattern)

        base = (self._graph.subgraph(nodes) if nodes is not None
                else self._graph)

        # Relabel the searched (sub)graph to consecutive integers. The
        # relabeled graph is built from `base` itself, so restricting the
        # search to `nodes` works (relabeling the full graph with a partial
        # mapping used to raise KeyError).
        labels_mapping = {n: i + 1 for i, n in enumerate(base.nodes())}
        inverse_mapping = {v: k for k, v in labels_mapping.items()}
        g = nx.DiGraph()
        g.add_nodes_from(
            (labels_mapping[n], attrs) for n, attrs in base.nodes(data=True))
        g.add_edges_from(
            (labels_mapping[s], labels_mapping[t], attrs)
            for s, t, attrs in base.edges(data=True))

        def node_matches(pattern_node, node):
            """Check typing and attributes of a pattern/graph node pair."""
            # Typings are keyed by ORIGINAL node ids, not relabeled ones.
            original_node = inverse_mapping[node]
            for typing_graph, pattern_mapping in pattern_typing.items():
                graph_mapping = graph_typing[typing_graph]
                if (original_node in graph_mapping and
                        pattern_node in pattern_mapping and
                        graph_mapping[original_node] !=
                        pattern_mapping[pattern_node]):
                    return False
            return valid_attributes(
                pattern_graph.nodes[pattern_node], g.nodes[node])

        # Keep only graph nodes that can be the image of some pattern node.
        matching_nodes = {
            node for node in g.nodes()
            if any(node_matches(pattern_node, node)
                   for pattern_node in pattern_graph.nodes())
        }

        pattern_edges = list(pattern_graph.edges())
        reduced_graph = g.subgraph(matching_nodes)
        instances = []
        for sub_nodes in itertools.combinations(
                reduced_graph.nodes(), pattern_graph.number_of_nodes()):
            subg = reduced_graph.subgraph(sub_nodes)
            for edgeset in itertools.combinations(
                    subg.edges(), len(pattern_edges)):
                edge_induced_graph = nx.DiGraph(list(edgeset))
                edge_induced_graph.add_nodes_from(
                    [n for n in subg.nodes()
                     if n not in edge_induced_graph.nodes()])
                matcher = isomorphism.DiGraphMatcher(
                    pattern_graph, edge_induced_graph)
                for mapping in matcher.isomorphisms_iter():
                    # Node typing / attribute constraints.
                    if not all(node_matches(p_node, node)
                               for p_node, node in mapping.items()):
                        continue
                    # Edge attribute constraints.
                    if not all(
                            valid_attributes(
                                pattern_graph.adj[ps][pt],
                                subg.adj[mapping[ps]][mapping[pt]])
                            for ps, pt in pattern_edges):
                        continue
                    # Bring back the original labeling.
                    instances.append({
                        p_node: inverse_mapping[node]
                        for p_node, node in mapping.items()
                    })
        return instances

    @classmethod
    def copy(cls, graph):
        """Copy the input graph object.

        Parameters
        ----------
        graph : Graph
            Graph whose nodes and edges (with attributes) are copied.

        Returns
        -------
        cls
            A new instance of `cls` built with ``add_nodes_from`` and
            ``add_edges_from``.
        """
        new_graph = cls()
        new_graph.add_nodes_from(graph.nodes(data=True))
        new_graph.add_edges_from(graph.edges(data=True))
        return new_graph