Relationship embedding models
A fairly common scenario is that one has trained a knowledge graph embedding (KGE) model outside of the Graph Data Science (GDS) library and stored the model training’s output in a Neo4j database. For such cases, GDS supports using the KGE model output together with a KGE scoring function to infer new relationships in a GDS graph projection. The scoring functions currently supported are TransE and DistMult.
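The two scoring functions can be sketched in NumPy using their standard textbook formulations, where h, r and t are the embeddings of the source node, the relationship type and the target node. This is a minimal sketch of the general technique, not of GDS internals. It assumes the L2 variant of TransE and negates the distance so that, for both functions, a higher score means a more plausible relationship. The exact convention GDS uses is not stated here.

```python
import numpy as np


def transe_score(h, r, t):
    """TransE: (h, r, t) is plausible when h + r lies close to t.

    Returns the negated L2 distance -||h + r - t||, so higher means more plausible.
    """
    h, r, t = (np.asarray(x, dtype=float) for x in (h, r, t))
    return float(-np.linalg.norm(h + r - t))


def distmult_score(h, r, t):
    """DistMult: trilinear product sum_i h_i * r_i * t_i; higher means more plausible."""
    h, r, t = (np.asarray(x, dtype=float) for x in (h, r, t))
    return float(np.sum(h * r * t))


# Toy 2-dimensional vectors, chosen only for illustration.
h, r, t = [1.0, 2.0], [3.0, 4.0], [5.0, 6.0]
print(transe_score(h, r, t))    # -1.0, since h + r - t = [-1.0, 0.0]
print(distmult_score(h, r, t))  # 63.0, since 1*3*5 + 2*4*6 = 15 + 48
```

TransE treats a relationship type as a translation in embedding space, while DistMult treats it as a diagonal scaling, which is why the two functions combine the same three vectors so differently.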
Below we walk through how to use these capabilities, first by looking at the methods and their signatures, and then by going through an end-to-end example on a small toy graph.
In the examples below we assume that we have an instantiated GraphDataScience object called gds. Read more about this in Getting started.
1. Creating a relationship embedding model
The first part of the workflow of using a pre-trained KGE model to predict new relationships in GDS is to create a relationship embedding model.
There are two methods for doing this, one for each supported KGE scoring function:
- gds.model.transe.create for creating a model using the TransE scoring function, and
- gds.model.distmult.create for creating a model using the DistMult scoring function.
Both of these methods return a SimpleRelEmbeddingModel, whose usage we will look into shortly.
They also take the same parameters:
Name | Type | Description |
---|---|---|
G | Graph | The object representing the graph the model is trained on |
node_embedding_property | str | The name of the node property under which the KGE model embeddings are stored |
relationship_type_embeddings | dict[str, list[float]] | A mapping of relationship type names to the KGE model’s relationship type embeddings |
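The relationship type embeddings are passed as plain Python lists. As a hedged illustration of what a well-formed relationship_type_embeddings argument looks like, the hypothetical helper below (not part of the GDS client) checks that every relationship type embedding is a list of numbers with the same dimension as the node embeddings. Both TransE and DistMult combine node and relationship embeddings element-wise, so the dimensions have to agree.

```python
from numbers import Real


def check_rel_type_embeddings(rel_type_embeddings, node_embedding_dim):
    """Hypothetical sanity check, not part of the GDS client.

    Raises ValueError if any relationship type embedding is not a list of
    numbers of length node_embedding_dim.
    """
    if not rel_type_embeddings:
        raise ValueError("at least one relationship type embedding is required")
    for rel_type, emb in rel_type_embeddings.items():
        if not isinstance(rel_type, str):
            raise ValueError(f"relationship type name must be a string: {rel_type!r}")
        if len(emb) != node_embedding_dim:
            raise ValueError(
                f"{rel_type}: expected dimension {node_embedding_dim}, got {len(emb)}"
            )
        if not all(isinstance(x, Real) for x in emb):
            raise ValueError(f"{rel_type}: embedding must contain only numbers")


# The 3-dimensional relationship type embeddings used in the example of section 4.
check_rel_type_embeddings(
    {
        "ROAD": [0.88355126, 0.15116676, 0.24225456],
        "LIVES_IN": [0.94185368, 0.60460752, 0.92028837],
    },
    node_embedding_dim=3,
)
```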
2. Making predictions with the relationship embedding model
The SimpleRelEmbeddingModel represents a relationship embedding model based on a KGE model.
It has three methods for predicting new relationships.
The computation that infers the new relationships is the same in all of them; what differs is how the predicted relationships are handled afterwards:
- predict_stream for streaming back the predicted relationships,
- predict_mutate for adding the relationships to the projected graph, and
- predict_write for writing back the relationships to the Neo4j database.
Since the prediction part of the computation is the same in these methods, they share a set of parameters:
Name | Type | Description |
---|---|---|
source_node_filter | Union[int, list[int], str] | The specification of source nodes to consider. Either a node label, a node ID, or a list of node IDs |
target_node_filter | Union[int, list[int], str] | The specification of target nodes to consider. Either a node label, a node ID, or a list of node IDs |
relationship_type | str | The name of the relationship type whose embedding will be used in the computation |
top_k | int | How many relationships to produce for each source node. The target nodes with the highest scores are kept |
general_config | **kwargs | General GDS algorithm configuration as optional keyword parameters |
In particular, the general algorithm configuration parameters supported as keyword parameters for this algorithm are concurrency, jobId and logProgress.
You can read more about them in the GDS manual.
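The shared parameters can be read as follows: for every node matched by source_node_filter, score every node matched by target_node_filter using the embedding of relationship_type, and keep the top_k highest-scoring targets. Below is a minimal NumPy sketch of that computation for a TransE-style score. It assumes node IDs are plain row indices into an embedding matrix, that the filters have already been resolved to lists of node IDs, and that the score is the negated L2 distance. The function name is hypothetical and this is not GDS’s implementation.

```python
import numpy as np


def predict_top_k_transe(node_embs, rel_emb, source_ids, target_ids, top_k):
    """Sketch of top-k relationship prediction with a TransE-style score.

    node_embs:  2D array-like, row i is the embedding of node ID i.
    rel_emb:    embedding of the relationship type used for scoring.
    source_ids: node IDs standing in for an already resolved source node filter.
    target_ids: node IDs standing in for an already resolved target node filter.
    Returns a list of (source_id, target_id, score) tuples, top_k per source.
    """
    node_embs = np.asarray(node_embs, dtype=float)
    rel_emb = np.asarray(rel_emb, dtype=float)
    targets = node_embs[target_ids]  # shape (n_targets, dim)
    results = []
    for s in source_ids:
        # Score every candidate target: negated L2 distance, higher is better.
        scores = -np.linalg.norm(node_embs[s] + rel_emb - targets, axis=1)
        # Indices of the top_k highest scores, best first.
        best = np.argsort(-scores, kind="stable")[:top_k]
        for i in best:
            results.append((s, target_ids[i], float(scores[i])))
    return results


# Hypothetical 2-dimensional embeddings for node IDs 0..3.
embs = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(predict_top_k_transe(embs, [1.0, 0.0], source_ids=[0, 2], target_ids=[1, 3], top_k=1))
```

This sketch does not skip self-pairs or relationships that already exist in the graph; how GDS handles those cases is not covered in this section.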
Let us now outline the differences between these prediction methods.
2.1. Streaming predicted relationships
The predict_stream method returns a pandas.DataFrame which contains three columns: sourceNodeId, targetNodeId and score.
These refer to the source node ID, the target node ID, and the score from running the KGE model scoring function on the node pair and relationship type, respectively.
There are no extra parameters to this method other than the ones outlined above.
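Since the result is an ordinary pandas.DataFrame, standard pandas operations apply to it. As an illustration, the sketch below builds a stand-in DataFrame with the same three columns (the IDs and scores are made up, not GDS output) and keeps the single best prediction per source node, assuming, as the top_k description suggests, that higher scores are better.

```python
import pandas as pd

# Stand-in for the DataFrame returned by predict_stream; the values are made up.
predictions = pd.DataFrame(
    {
        "sourceNodeId": [6, 6, 7, 7],
        "targetNodeId": [0, 4, 1, 5],
        "score": [0.9, 0.4, 0.2, 0.7],
    }
)

# For each source node, keep the row with the highest score.
best = predictions.loc[predictions.groupby("sourceNodeId")["score"].idxmax()]
print(best)
```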
2.2. Mutating graph projection with predicted relationships
The predict_mutate method adds the predicted relationships to the graph projection under a new type specified via the mutate_relationship_type parameter.
Each such relationship will have a property, specified via the mutate_property parameter, holding the score from running the KGE model scoring function on the node pair and relationship type.
In addition to the shared parameters outlined above, this method takes two more positional parameters, after the top_k parameter, in order:
Name | Type | Description |
---|---|---|
mutate_relationship_type | str | The name of the new relationship type for the predicted relationships |
mutate_property | str | The name of the property on the new relationships which will store the model prediction score |
The method returns a pandas.Series with metadata about the computation:
Name | Type | Description |
---|---|---|
relationshipsWritten | int | The number of relationships created |
mutateMillis | int | Milliseconds for adding properties to the projected graph |
postProcessingMillis | int | Milliseconds for computing percentiles |
preProcessingMillis | int | Milliseconds for preprocessing the data |
computeMillis | int | Milliseconds for running the prediction algorithm |
configuration | dict | The configuration used for running the algorithm |
2.3. Writing back predicted relationships to database
The predict_write method writes back the predicted relationships to the Neo4j database under a new type specified via the write_relationship_type parameter.
Each such relationship will have a property, specified via the write_property parameter, holding the score from running the KGE model scoring function on the node pair and relationship type.
In addition to the shared parameters outlined above, this method takes two more positional parameters, after the top_k parameter, in order:
Name | Type | Description |
---|---|---|
write_relationship_type | str | The name of the new relationship type for the predicted relationships |
write_property | str | The name of the property on the new relationships which will store the model prediction score |
The method returns a pandas.Series with metadata about the computation:
Name | Type | Description |
---|---|---|
relationshipsWritten | int | The number of relationships created |
writeMillis | int | Milliseconds for writing result data back to the Neo4j database |
preProcessingMillis | int | Milliseconds for preprocessing the data |
computeMillis | int | Milliseconds for running the prediction algorithm |
configuration | dict | The configuration used for running the algorithm |
3. Inspecting relationship embedding models
There are a few methods on the SimpleRelEmbeddingModel class which let us inspect it.
None of them take any input; they simply return information about the model.
They are listed below.
Name | Return type | Description |
---|---|---|
scoring_function() | str | Returns the name of the scoring function the model is using |
graph_name() | str | Returns the name of the graph the model is based on |
node_embedding_property() | str | Returns the name of the node property storing embeddings in the graph |
relationship_type_embeddings() | dict[str, list[float]] | Returns the relationship type embeddings of the model |
4. Example
In this section, we will exemplify creating and using a relationship embedding model based on a KGE model trained using the TransE scoring function.
For this we need a Graph object representing a projection that contains the KGE model’s node embeddings, so we start by introducing a small road network graph with some inhabitants:
gds.run_cypher(
    """
    CREATE
      (a:City {name: "New York City", settled: 1624, emb: [0.52173235, 0.85803989, 0.31678055]}),
      (b:City {name: "Philadelphia", settled: 1682, emb: [0.61455845, 0.79957553, 0.83513986]}),
      (c:City:Capital {name: "Washington D.C.", settled: 1790, emb: [0.54354943, 0.64039515, 0.23094848]}),
      (d:City {name: "Baltimore", settled: 1729, emb: [0.67689553, 0.28851121, 0.43250516]}),
      (e:City {name: "Atlantic City", settled: 1854, emb: [0.79804478, 0.81980933, 0.9322812]}),
      (f:City {name: "Boston", settled: 1822, emb: [0.15583946, 0.16060805, 0.52078528]}),
      (g:Person {name: "Brian", emb: [0.4142066, 0.18411476, 0.68245374]}),
      (h:Person {name: "Olga", emb: [0.61230904, 0.7735076, 0.09668418]}),
      (i:Person {name: "Jacob", emb: [0.87470625, 0.63589938, 0.33536311]}),
      (a)-[:ROAD {cost: 50}]->(b),
      (a)-[:ROAD {cost: 50}]->(c),
      (a)-[:ROAD {cost: 100}]->(d),
      (b)-[:ROAD {cost: 40}]->(d),
      (c)-[:ROAD {cost: 40}]->(d),
      (c)-[:ROAD {cost: 80}]->(e),
      (d)-[:ROAD {cost: 30}]->(e),
      (d)-[:ROAD {cost: 80}]->(f),
      (e)-[:ROAD {cost: 40}]->(f),
      (g)-[:LIVES_IN]->(a),
      (h)-[:LIVES_IN]->(f),
      (i)-[:LIVES_IN]->(e);
    """
)
G, project_result = gds.graph.project(
    graph_name="road_graph",
    node_spec={"City": {"properties": ["emb"]}, "Person": {"properties": ["emb"]}},
    relationship_spec=["ROAD", "LIVES_IN"]
)

# Sanity check: 9 ROAD + 3 LIVES_IN relationships
assert G.relationship_count() == 12
The "emb" node property here contains the TransE node embeddings we will use in our computation to infer new relationships.
4.1. Creating and inspecting our model
Using our graph G and our precomputed relationship type embeddings, we can now construct a TransE relationship embedding model.
transe_model = gds.model.transe.create(
    G,
    node_embedding_property="emb",
    relationship_type_embeddings={
        "ROAD": [0.88355126, 0.15116676, 0.24225456],
        "LIVES_IN": [0.94185368, 0.60460752, 0.92028837]
    }
)

# Sanity check
assert transe_model.scoring_function() == "transe"
With our model created, we can start predicting new relationships in our graph.
4.2. Making predictions
Let’s have our model predict where our three inhabitants of interest are likely to move in the future, and mutate our GDS projection represented by G with these new relationships.
result = transe_model.predict_mutate(
    source_node_filter="Person",
    target_node_filter="City",
    relationship_type="LIVES_IN",
    top_k=1,
    mutate_relationship_type="MIGHT_MOVE",
    mutate_property="likeliness_score"
)

# Let us make sure the number of new relationships makes sense
assert result["relationshipsWritten"] == 3
assert G.relationship_count() == 12 + 3
Using TransE embeddings and the relationship embedding model capabilities of GDS, we were able to infer where our inhabitants of interest might move in the future.
The new "MIGHT_MOVE" relationships we created are now part of the GDS graph projection represented by G.
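To get a feel for what the model computes in this example, the sketch below scores every (person, LIVES_IN, city) triple directly from the "emb" values and the "LIVES_IN" embedding used above, and keeps the highest-scoring city per person (top_k=1). It assumes a standard TransE score, the negated L2 distance -||h + r - t||. GDS’s exact scoring, and whether it skips relationships that already exist, may differ, so the cities picked here are not guaranteed to match the MIGHT_MOVE relationships GDS created.

```python
import numpy as np

# Node embeddings copied from the "emb" properties created in the Cypher statement above.
cities = {
    "New York City": [0.52173235, 0.85803989, 0.31678055],
    "Philadelphia": [0.61455845, 0.79957553, 0.83513986],
    "Washington D.C.": [0.54354943, 0.64039515, 0.23094848],
    "Baltimore": [0.67689553, 0.28851121, 0.43250516],
    "Atlantic City": [0.79804478, 0.81980933, 0.9322812],
    "Boston": [0.15583946, 0.16060805, 0.52078528],
}
people = {
    "Brian": [0.4142066, 0.18411476, 0.68245374],
    "Olga": [0.61230904, 0.7735076, 0.09668418],
    "Jacob": [0.87470625, 0.63589938, 0.33536311],
}
# The "LIVES_IN" relationship type embedding passed to gds.model.transe.create above.
lives_in = np.array([0.94185368, 0.60460752, 0.92028837])

city_names = list(cities)
city_embs = np.array([cities[name] for name in city_names])  # shape (6, 3)

closest = {}
for person, emb in people.items():
    # TransE-style score of (person, LIVES_IN, city) for every city: -||h + r - t||.
    scores = -np.linalg.norm(np.array(emb) + lives_in - city_embs, axis=1)
    # top_k=1: keep only the highest-scoring city.
    closest[person] = city_names[int(np.argmax(scores))]

print(closest)
```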