# Source code for regraph.backends.neo4j.graphs

"""Neo4j-based persistent graph objects.

This module implements data structures that allow working with persistent
graphs stored in an instance of the Neo4j database.
"""
import os
import json
import warnings

from neo4j.v1 import GraphDatabase

from regraph.graphs import Graph
from regraph.utils import (normalize_attrs,
                           load_nodes_from_json,
                           load_edges_from_json,)
from regraph.exceptions import ReGraphError
from .cypher_utils import generic
from .cypher_utils import rewriting


class Neo4jGraph(Graph):
    """Class implementing Neo4j graph instance.

    This class encapsulates a neo4j.v1.GraphDatabase driver object.
    It provides an interface for accessing a graph sitting in the DB.
    This interface is intended to be as similar as possible to the
    `networkx.DiGraph` object.

    Attributes
    ----------
    _driver : neo4j.v1.GraphDatabase
        Driver providing connection to a Neo4j database
    _node_label : str
        Label of nodes inducing the manipulated subgraph.
    _edge_label : str
        Type of relations used in the manipulated subgraph.
    unique_node_ids : bool
        Value of the `unique_node_ids` flag passed to the constructor.
    """

    def __init__(self, driver=None, uri=None, user=None, password=None,
                 node_label="node", edge_label="edge", unique_node_ids=True):
        """Initialize Neo4jGraph object.

        If a database driver is provided, it is used for connecting to
        the database, otherwise a new driver object is created using
        the provided credentials.

        Parameters
        ----------
        driver : neo4j.v1.direct.DirectDriver, optional
            Driver providing connection to a Neo4j database
        uri : str, optional
            Uri for a new Neo4j database connection (bolt)
        user : str, optional
            Username for the Neo4j database connection
        password : str, optional
            Password for the Neo4j database connection
        node_label : optional
            Label of nodes inducing the subgraph to scope.
            By default `"node"`.
        edge_label : optional
            Type of relations inducing the subgraph to scope.
            By default `"edge"`.
        unique_node_ids : bool, optional
            Flag, if True the uniqueness constraint on the property
            'id' of nodes is imposed, by default True
        """
        if driver is None:
            self._driver = GraphDatabase.driver(uri, auth=(user, password))
        else:
            self._driver = driver

        self._node_label = node_label
        self._edge_label = edge_label
        self.unique_node_ids = unique_node_ids

        if unique_node_ids:
            # Best effort: a failure to create the constraint is reported
            # as a warning and does not abort construction.
            try:
                self._set_constraint('id')
            except Exception:
                warnings.warn("Failed to create id uniqueness constraint")

    def _execute(self, query):
        """Execute a Cypher query.

        Parameters
        ----------
        query : str
            Cypher query to run

        Returns
        -------
        result
            Object returned by `session.run`, or None if the query
            is empty (nothing is sent to the database in that case).
        """
        result = None
        with self._driver.session() as session:
            if len(query) > 0:
                result = session.run(query)
        return result

    def _close(self):
        """Close connection to the database."""
        self._driver.close()

    def _clear(self):
        """Clear graph database.

        Runs the query produced by `generic.clear_graph` for the node
        label of this graph.

        Returns
        -------
        result : BoltStatementResult
        """
        query = generic.clear_graph(self._node_label)
        return self._execute(query)

    def _set_constraint(self, prop):
        """Set a uniqueness constraint on the property.

        Parameters
        ----------
        prop : str
            Name of the property that is required to be unique
            for the nodes of the database

        Returns
        -------
        result : BoltStatementResult
        """
        query = "CREATE " + generic.constraint_query(
            'n', self._node_label, prop)
        return self._execute(query)

    def _drop_constraint(self, prop):
        """Drop a uniqueness constraint on the property.

        Failures are not propagated: a warning is emitted and None
        is returned instead.

        Parameters
        ----------
        prop : str
            Name of the property

        Returns
        -------
        result : BoltStatementResult
        """
        try:
            query = "DROP " + generic.constraint_query(
                'n', self._node_label, prop)
            # Must go through `_execute`, the query runner defined by
            # this class.
            return self._execute(query)
        except Exception:
            warnings.warn("Failed to drop constraint")

    def nodes(self, data=False):
        """Return a list of nodes of the graph.

        Parameters
        ----------
        data : bool, optional
            If True, every element of the result is a pair
            `(node_id, attrs)`, otherwise it is just `node_id`.

        Returns
        -------
        node_list : list
        """
        query = generic.get_nodes(node_label=self._node_label, data=data)
        result = self._execute(query)

        node_list = []
        for record in result:
            node_id = record["node_id"]
            if data:
                attrs = record["attrs"]
                # The id is already reported separately as `node_id`.
                del attrs["id"]
                normalize_attrs(attrs)
                node_list.append((node_id, attrs))
            else:
                node_list.append(node_id)
        return node_list

    def edges(self, data=False):
        """Return the list of edges of the graph.

        Parameters
        ----------
        data : bool, optional
            If True, every element of the result is a triple
            `(source, target, attrs)`, otherwise a pair
            `(source, target)`.

        Returns
        -------
        edges : list
        """
        query = generic.get_edges(
            self._node_label,
            self._node_label,
            self._edge_label,
            data=data)
        result = self._execute(query)

        # Fetch the node ids once instead of re-querying the database
        # for every edge endpoint.
        known_nodes = set(self.nodes())

        def _as_node_id(raw_id):
            # Endpoint ids that are not node ids as-is are converted
            # to integers.
            return raw_id if raw_id in known_nodes else int(raw_id)

        edges = []
        for record in result:
            s = _as_node_id(record["source_id"])
            t = _as_node_id(record["target_id"])
            if data:
                normalize_attrs(record["attrs"])
                edges.append((s, t, record["attrs"]))
            else:
                edges.append((s, t))
        return edges

    def get_node(self, node_id):
        """Get node attributes.

        Parameters
        ----------
        node_id : hashable
            Id of the node whose attributes are retrieved.

        Returns
        -------
        attrs
            Value produced by `generic.properties_to_attributes`
            from the query result.
        """
        query = generic.get_node_attrs(
            node_id, self._node_label, "attributes")
        result = self._execute(query)
        return generic.properties_to_attributes(result, "attributes")

    def get_edge(self, s, t):
        """Get edge attributes.

        Parameters
        ----------
        s : hashable
            Source node id.
        t : hashable
            Target node id.

        Returns
        -------
        attrs
            Value produced by `generic.properties_to_attributes`
            from the query result.
        """
        query = generic.get_edge_attrs(
            s, t, self._node_label, self._edge_label, "attributes")
        result = self._execute(query)
        return generic.properties_to_attributes(result, "attributes")

    def add_node(self, node, attrs=None, ignore_naming=False):
        """Add a node to the graph.

        Parameters
        ----------
        node : hashable
            Id of the node to add (passed to the query builder
            as a literal id).
        attrs : dict, optional
            Node attributes. Normalized in place.
        ignore_naming : bool, optional
            Forwarded as is to `rewriting.add_node`.

        Returns
        -------
        new_id
            Value of `new_id` returned by the executed query.
        """
        if attrs is None:
            attrs = dict()
        normalize_attrs(attrs)

        # `rewriting.add_node` returns a sequence whose first element
        # is the query text.
        creation_query = rewriting.add_node(
            "n", node, 'new_id',
            node_label=self._node_label,
            attrs=attrs,
            literal_id=True,
            ignore_naming=ignore_naming)[0]
        query = creation_query + generic.return_vars(['new_id'])

        result = self._execute(query)
        return result.single()['new_id']

    def remove_node(self, node):
        """Remove node.

        Parameters
        ----------
        node : hashable
            Id of the node to remove.

        Returns
        -------
        result : BoltStatementResult
        """
        query = (
            generic.match_node("n", node, node_label=self._node_label) +
            rewriting.remove_node("n")
        )
        return self._execute(query)

    def add_edge(self, s, t, attrs=None, **attr):
        """Add an edge to a graph.

        Parameters
        ----------
        s : hashable
            Source node id.
        t : hashable
            Target node id.
        attrs : dict, optional
            Edge attributes. Normalized in place.
        **attr
            Accepted but currently not used by this method.

        Returns
        -------
        result : BoltStatementResult
        """
        if attrs is None:
            attrs = dict()
        normalize_attrs(attrs)

        query = generic.match_nodes(
            {"s": s, "t": t}, node_label=self._node_label)
        query += rewriting.add_edge(
            edge_var='new_edge',
            source_var="s",
            target_var="t",
            edge_label=self._edge_label,
            attrs=attrs)
        return self._execute(query)

    def remove_edge(self, s, t):
        """Remove edge from the graph.

        Parameters
        ----------
        s : hashable
            Source node id.
        t : hashable
            Target node id.

        Returns
        -------
        result : BoltStatementResult
        """
        # Use the edge label of this graph (not a hard-coded 'edge'),
        # so that graphs created with a custom `edge_label` match the
        # right relation type.
        query = (
            generic.match_edge(
                "s", "t", s, t, 'edge_var',
                self._node_label, self._node_label,
                edge_label=self._edge_label) +
            rewriting.remove_edge('edge_var')
        )
        return self._execute(query)

    def update_node_attrs(self, node_id, attrs, normalize=True):
        """Update attributes of a node.

        Parameters
        ----------
        node_id : hashable
            Node to update.
        attrs : dict
            New attributes to assign to the node. Normalized in place.
        normalize : bool, optional
            Accepted but currently not consulted: the attributes are
            always normalized.

        Returns
        -------
        result : BoltStatementResult
        """
        normalize_attrs(attrs)
        query = (
            generic.match_node("n", node_id, self._node_label) +
            generic.set_attributes("n", attrs, update=True)
        )
        return self._execute(query)

    def update_edge_attrs(self, s, t, attrs, normalize=True):
        """Update attributes of an edge.

        Parameters
        ----------
        s : hashable
            Source node of the edge to update.
        t : hashable
            Target node of the edge to update.
        attrs : dict
            New attributes to assign to the edge. Normalized in place.
        normalize : bool, optional
            Accepted but currently not consulted: the attributes are
            always normalized.

        Returns
        -------
        result : BoltStatementResult
        """
        normalize_attrs(attrs)
        query = (
            generic.match_edge(
                "s", "t", s, t, "rel",
                self._node_label, self._node_label,
                self._edge_label) +
            generic.set_attributes("rel", attrs, update=True)
        )
        return self._execute(query)

    def successors(self, node_id):
        """Return the set of successors.

        Parameters
        ----------
        node_id : hashable
            Id of the node whose successors are collected.

        Returns
        -------
        succ : set
        """
        query = generic.successors_query(
            node_id, node_id,
            node_label=self._node_label,
            edge_label=self._edge_label)
        result = self._execute(query)
        return {
            record["suc"] for record in result
            if record["suc"] is not None
        }

    def predecessors(self, node_id):
        """Return the set of predecessors.

        Parameters
        ----------
        node_id : hashable
            Id of the node whose predecessors are collected.

        Returns
        -------
        pred : set
        """
        query = generic.predecessors_query(
            node_id, node_id,
            node_label=self._node_label,
            edge_label=self._edge_label)
        result = self._execute(query)
        return {
            record["pred"] for record in result
            if record["pred"] is not None
        }

    def find_matching(self, pattern, nodes=None,
                      graph_typing=None, pattern_typing=None):
        """Find matching of a pattern in a graph.

        Parameters
        ----------
        pattern
            Pattern graph; only its `nodes()` method is used here,
            the object itself is forwarded to
            `rewriting.find_matching`.
        nodes : iterable, optional
            Graph nodes allowed to take part in the matching.
        graph_typing : dict, optional
            Maps a typing graph id to a dictionary typing the nodes
            of this graph. Indexed only when `pattern_typing` is given.
        pattern_typing : dict, optional
            Maps a typing graph id to a dictionary typing the nodes
            of the pattern.

        Returns
        -------
        instances : list of dict
            Every dictionary maps pattern nodes to ids of graph nodes.
            Empty list if the pattern has no nodes.
        """
        # Both node collections are loop invariants: fetch them once
        # rather than on every iteration.
        pattern_nodes = pattern.nodes()
        if len(pattern_nodes) == 0:
            return []
        graph_nodes = self.nodes()

        # Filter nodes by typing.
        # NOTE(review): when `nodes` is None or empty nothing is ever
        # added to `matching_nodes`; how an empty set is interpreted
        # depends on `rewriting.find_matching` -- confirm it is intended.
        matching_nodes = set()
        for pattern_node in pattern_nodes:
            for node in graph_nodes:
                type_matches = True
                if pattern_typing:
                    # Check types match
                    for graph, pattern_mapping in pattern_typing.items():
                        if (node in graph_typing[graph].keys() and
                                pattern_node in pattern_mapping.keys()):
                            if (graph_typing[graph][node] !=
                                    pattern_mapping[pattern_node]):
                                type_matches = False
                if type_matches and nodes and node in nodes:
                    matching_nodes.add(node)

        query = rewriting.find_matching(
            pattern,
            node_label=self._node_label,
            edge_label=self._edge_label,
            nodes=matching_nodes,
            pattern_typing=pattern_typing)
        result = self._execute(query)

        instances = list()
        for record in result:
            instance = dict()
            for key, value in record.items():
                matched_id = dict(value)["id"]
                # Keys that are not pattern nodes as-is are converted
                # to integers.
                if key not in pattern_nodes:
                    instance[int(key)] = matched_id
                else:
                    instance[key] = matched_id
            instances.append(instance)
        return instances

    def relabel_node(self, node_id, new_id):
        """Relabel a node in the graph.

        Parameters
        ----------
        node_id : hashable
            Id of the node to relabel.
        new_id : hashable
            New label of a node.

        Returns
        -------
        result : BoltStatementResult

        Raises
        ------
        ReGraphError
            If a node with the id `new_id` already exists in the graph.
        """
        if new_id in self.nodes():
            raise ReGraphError(
                "Cannot relabel '{}' to '{}', '{}' ".format(
                    node_id, new_id, new_id) +
                "already exists in the graph")
        query = generic.set_id(self._node_label, node_id, new_id)
        return self._execute(query)

    @classmethod
    def from_json(cls, driver=None, uri=None, user=None,
                  password=None, json_data=None,
                  node_label="node", edge_label="edge"):
        """Create a Neo4jGraph from a json-like dictionary.

        Parameters
        ----------
        driver : neo4j.v1.direct.DirectDriver, optional
            Driver providing connection to a Neo4j database
        uri : str, optional
            Uri for a new Neo4j database connection (bolt)
        user : str, optional
            Username for the Neo4j database connection
        password : str, optional
            Password for the Neo4j database connection
        json_data : dict
            JSON-like dictionary with graph representation
        node_label : optional
            Label of nodes inducing the subgraph to scope.
            By default `"node"`.
        edge_label : optional
            Type of relations inducing the subgraph to scope.
            By default `"edge"`.

        Returns
        -------
        graph
            New instance of `cls` populated from `json_data`.
        """
        graph = cls(
            driver=driver, uri=uri, user=user, password=password,
            node_label=node_label, edge_label=edge_label)
        graph.add_nodes_from(load_nodes_from_json(json_data))
        graph.add_edges_from(load_edges_from_json(json_data))
        return graph

    @classmethod
    def load(cls, driver=None, uri=None, user=None,
             password=None, filename=None,
             node_label="node", edge_label="edge"):
        """Load a Neo4jGraph from a JSON file.

        Create a graph object from a JSON representation stored
        in a file.

        Parameters
        ----------
        driver : neo4j.v1.direct.DirectDriver, optional
            Driver providing connection to a Neo4j database
        uri : str, optional
            Uri for a new Neo4j database connection (bolt)
        user : str, optional
            Username for the Neo4j database connection
        password : str, optional
            Password for the Neo4j database connection
        filename : str, optional
            Name of the file to load the json serialization of the graph
        node_label : optional
            Label of nodes inducing the subgraph to scope.
            By default `"node"`.
        edge_label : optional
            Type of relations inducing the subgraph to scope.
            By default `"edge"`.

        Returns
        -------
        Graph object

        Raises
        ------
        ReGraphError
            If the file does not exist
        """
        if not os.path.isfile(filename):
            raise ReGraphError(
                "Error loading graph: file '{}' does not exist!".format(
                    filename)
            )
        # The file is only read, so read-only mode is enough ("r+"
        # would additionally require write permission).
        with open(filename, "r") as f:
            j_data = json.load(f)
        return cls.from_json(
            driver=driver, uri=uri, user=user, password=password,
            json_data=j_data, node_label=node_label,
            edge_label=edge_label)