Relationship embedding models

You may have trained a knowledge graph embedding (KGE) model outside of the Graph Data Science (GDS) library and stored the training output in a Neo4j database. In that case, GDS can use the stored embeddings together with a KGE scoring function to infer new relationships in a GDS graph projection. The scoring functions currently supported are TransE and DistMult.
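To make the two scoring functions concrete, here is a minimal plain-Python sketch. It uses the commonly published formulations: TransE as the negated Euclidean distance between head + relationship and tail, and DistMult as a trilinear product. The norm and sign convention used inside GDS are not specified on this page, so treat this as an illustration rather than the GDS implementation. The helper names transe_score and distmult_score are hypothetical; the embeddings are copied from the toy example later on this page.

```python
import math


def transe_score(head, rel, tail):
    """Negated L2 distance between (head + rel) and tail; values closer to 0 score higher."""
    return -math.sqrt(sum((h + r - t) ** 2 for h, r, t in zip(head, rel, tail)))


def distmult_score(head, rel, tail):
    """Trilinear product sum_i head_i * rel_i * tail_i; larger values score higher."""
    return sum(h * r * t for h, r, t in zip(head, rel, tail))


# Embeddings taken from the toy graph in the example section (assumed to be TransE output)
brian = [0.4142066, 0.18411476, 0.68245374]
lives_in = [0.94185368, 0.60460752, 0.92028837]
new_york = [0.52173235, 0.85803989, 0.31678055]

print(transe_score(brian, lives_in, new_york))
print(distmult_score(brian, lives_in, new_york))
```

Note that DistMult is symmetric in head and tail, whereas TransE is not, which matters when the relationship type is directed.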

Below we walk through how to use these capabilities, first by looking at the methods and their signatures, and then by going through an end-to-end example on a small toy graph.

In the examples below we assume that we have an instantiated GraphDataScience object called gds. Read more about this in Getting started.

1. Creating a relationship embedding model

The first step in using a pre-trained KGE model to predict new relationships in GDS is to create a relationship embedding model.

There are two methods for doing this, one for each supported KGE scoring function:

  • gds.model.transe.create for creating a model using the TransE scoring function, and

  • gds.model.distmult.create for creating a model using the DistMult scoring function.

Both of these methods return a SimpleRelEmbeddingModel whose usage we will look into shortly. They also take the same parameters:

Table 1. KGE based relationship model creation parameters
| Name | Type | Description |
| --- | --- | --- |
| G | Graph | The object representing the graph the model is trained on |
| node_embedding_property | str | The name of the node property under which the KGE model embeddings are stored |
| relationship_type_embeddings | dict[str, list[float]] | A mapping of relationship type names to the KGE model’s relationship type embeddings |
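Both TransE and DistMult combine node and relationship type embeddings element by element, so each relationship type embedding is expected to have the same length as the node embeddings. The following sketch checks this before a model is created. The helper check_embedding_dimensions is hypothetical and not part of the GDS client; whether GDS performs an equivalent check itself is not stated here.

```python
def check_embedding_dimensions(node_embedding_dim, relationship_type_embeddings):
    """Raise ValueError if any relationship type embedding has an unexpected length."""
    for rel_type, embedding in relationship_type_embeddings.items():
        if len(embedding) != node_embedding_dim:
            raise ValueError(
                f"Embedding for {rel_type!r} has length {len(embedding)}, "
                f"expected {node_embedding_dim}"
            )


# Same shape as the relationship_type_embeddings parameter described above
relationship_type_embeddings = {
    "ROAD": [0.88355126, 0.15116676, 0.24225456],
    "LIVES_IN": [0.94185368, 0.60460752, 0.92028837],
}

# Node embeddings in the toy example have three dimensions, so this passes silently
check_embedding_dimensions(3, relationship_type_embeddings)
```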

2. Making predictions with the relationship embedding model

The SimpleRelEmbeddingModel represents a relationship embedding model based on a KGE model. It has three methods for predicting new relationships. The computation that infers the relationships is the same in all three; they differ only in how the predicted relationships are handled afterwards:

  • predict_stream for streaming back the predicted relationships,

  • predict_mutate for adding the relationships to the projected graph,

  • predict_write for writing back the relationships to the Neo4j database.

Since the prediction parts of the computations in these methods are the same, the methods share a set of parameters:

Table 2. Shared relationship embedding model prediction parameters
| Name | Type | Description |
| --- | --- | --- |
| source_node_filter | Union[str, int, list[int]] | The specification of source nodes to consider. Either a node label, a node ID, or a list of node IDs |
| target_node_filter | Union[str, int, list[int]] | The specification of target nodes to consider. Either a node label, a node ID, or a list of node IDs |
| relationship_type | str | The name of the relationship type whose embedding will be used in the computation |
| top_k | int | How many relationships to produce for each source node. The top_k target nodes with the highest scores are kept for each source node |
| general_config | **dict[str, Any] | General GDS algorithm configuration as optional keyword parameters |

In particular, the general algorithm configuration parameters supported as keyword parameters for these methods are concurrency, jobId and logProgress. You can read more about them in the GDS manual.
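The effect of top_k can be sketched in plain Python: every source node is scored against every candidate target node, and only the top_k highest-scoring targets are kept per source. This is an illustration of the selection step only, not the GDS implementation. The function top_k_targets and the stand-in score function dot are hypothetical, and details such as tie-breaking or the handling of already existing relationships are not covered.

```python
def top_k_targets(score_fn, sources, targets, k):
    """For each source, return the k (target_id, score) pairs with the highest score."""
    result = {}
    for source_id, source_emb in sources.items():
        scored = [
            (target_id, score_fn(source_emb, target_emb))
            for target_id, target_emb in targets.items()
        ]
        # Highest score first, then keep only the first k entries
        scored.sort(key=lambda pair: pair[1], reverse=True)
        result[source_id] = scored[:k]
    return result


def dot(a, b):
    """Stand-in score function; a real KGE score would also use the relationship type embedding."""
    return sum(x * y for x, y in zip(a, b))


sources = {0: [1.0, 0.0], 1: [0.0, 1.0]}
targets = {10: [1.0, 0.0], 11: [0.0, 2.0], 12: [0.5, 0.5]}
print(top_k_targets(dot, sources, targets, k=2))
```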

Let us now outline the differences between these prediction methods.

2.1. Streaming predicted relationships

The predict_stream method returns a pandas.DataFrame which contains three columns: sourceNodeId, targetNodeId and score. These refer to the source node ID, the target node ID, and the score from running the KGE model scoring function on the node pair and relationship type, respectively.

There are no extra parameters to this method other than the ones outlined above.
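As a sketch of how a streaming call might look, the hypothetical helper below wraps predict_stream using only the shared parameters from Table 2, plus concurrency as a general configuration keyword. The "Person" and "City" labels and the "LIVES_IN" type are assumptions borrowed from the example later on this page; model is assumed to be a SimpleRelEmbeddingModel.

```python
def stream_likely_moves(model, top_k=2):
    """Stream the top_k highest-scoring City targets for every Person node."""
    return model.predict_stream(
        source_node_filter="Person",
        target_node_filter="City",
        relationship_type="LIVES_IN",
        top_k=top_k,
        concurrency=4,  # general GDS configuration passed as a keyword parameter
    )
```

Since the result is a pandas.DataFrame, it can be post-processed as usual, for example with `stream_likely_moves(model).sort_values("score", ascending=False)`.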

2.2. Mutating graph projection with predicted relationships

The predict_mutate method adds the predicted relationships to the graph projection under a new type specified via the mutate_relationship_type parameter. Each such relationship will have a property, specified via the mutate_property parameter, representing the output from running the KGE model scoring function on the node pair and relationship type.

In addition to the shared parameters outlined above, this method takes two more positional parameters, after the top_k parameter, in order:

Table 3. Input parameters specific to .predict_mutate
| Name | Type | Description |
| --- | --- | --- |
| mutate_relationship_type | str | The name of the new relationship type for the predicted relationships |
| mutate_property | str | The name of the property on the new relationships which will store the model prediction score |

The method returns a pandas.Series with metadata about the computation:

Table 4. Fields of the pandas.Series object returned by .predict_mutate
| Name | Type | Description |
| --- | --- | --- |
| relationshipsWritten | int | The number of relationships created |
| mutateMillis | int | Milliseconds for adding properties to the projected graph |
| postProcessingMillis | int | Milliseconds for computing percentiles |
| preProcessingMillis | int | Milliseconds for preprocessing the data |
| computeMillis | int | Milliseconds for running the prediction algorithm |
| configuration | dict[str, Any] | The configuration used for running the algorithm |

2.3. Writing back predicted relationships to database

The predict_write method writes the predicted relationships back to the Neo4j database under a new type specified via the write_relationship_type parameter. Each such relationship will have a property, specified via the write_property parameter, representing the output from running the KGE model scoring function on the node pair and relationship type.

In addition to the shared parameters outlined above, this method takes two more positional parameters, after the top_k parameter, in order:

Table 5. Input parameters specific to .predict_write
| Name | Type | Description |
| --- | --- | --- |
| write_relationship_type | str | The name of the new relationship type for the predicted relationships |
| write_property | str | The name of the property on the new relationships which will store the model prediction score |

The method returns a pandas.Series with metadata about the computation:

Table 6. Fields of the pandas.Series object returned by .predict_write
| Name | Type | Description |
| --- | --- | --- |
| relationshipsWritten | int | The number of relationships created |
| writeMillis | int | Milliseconds for writing result data back to the Neo4j database |
| preProcessingMillis | int | Milliseconds for preprocessing the data |
| computeMillis | int | Milliseconds for running the prediction algorithm |
| configuration | dict[str, Any] | The configuration used for running the algorithm |
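A write-back call might look like the sketch below, which uses only the parameters from Table 2 and Table 5. The helper write_likely_moves is hypothetical, and the labels, relationship type and property name are assumptions borrowed from the example later on this page. Unlike predict_mutate, this call changes the Neo4j database itself.

```python
def write_likely_moves(model):
    """Write the single best City target per Person back to the database."""
    return model.predict_write(
        source_node_filter="Person",
        target_node_filter="City",
        relationship_type="LIVES_IN",
        top_k=1,
        write_relationship_type="MIGHT_MOVE",
        write_property="likeliness_score",
    )
```

The returned pandas.Series can then be inspected, for instance by reading its "relationshipsWritten" field.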

3. Inspecting relationship embedding models

The SimpleRelEmbeddingModel class has a few methods that let us inspect a model. None of them takes any input; each simply returns information about the model. They are listed below.

Table 7. SimpleRelEmbeddingModel getter methods for inspection
| Name | Return type | Description |
| --- | --- | --- |
| scoring_function | str | Returns the name of the scoring function the model is using |
| graph_name | str | Returns the name of the graph the model is based on |
| node_embedding_property | str | Returns the name of the node property storing embeddings in the graph |
| relationship_type_embeddings | dict[str, list[float]] | Returns the relationship type embeddings of the model |
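The getters above can be combined into a small summary, for example for logging. The helper describe_model below is hypothetical; it only calls the four methods listed in Table 7 and assumes model is a SimpleRelEmbeddingModel.

```python
def describe_model(model):
    """Collect the inspection getters of a relationship embedding model into one dict."""
    return {
        "scoring_function": model.scoring_function(),
        "graph_name": model.graph_name(),
        "node_embedding_property": model.node_embedding_property(),
        # Only list the type names; the embedding vectors themselves can be long
        "relationship_types": sorted(model.relationship_type_embeddings()),
    }
```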

4. Example

In this section, we demonstrate how to create and use a relationship embedding model based on a KGE model trained with the TransE scoring function. For this we need a Graph object representing a projection that contains the KGE model’s node embeddings.

We start by introducing a small road-network graph with some inhabitants:

gds.run_cypher(
  """
  CREATE
    (a:City {name: "New York City", settled: 1624, emb: [0.52173235, 0.85803989, 0.31678055]}),
    (b:City {name: "Philadelphia", settled: 1682, emb: [0.61455845, 0.79957553, 0.83513986]}),
    (c:City:Capital {name: "Washington D.C.", settled: 1790, emb: [0.54354943, 0.64039515, 0.23094848]}),
    (d:City {name: "Baltimore", settled: 1729, emb: [0.67689553, 0.28851121, 0.43250516]}),
    (e:City {name: "Atlantic City", settled: 1854, emb: [0.79804478, 0.81980933, 0.9322812]}),
    (f:City {name: "Boston", settled: 1822, emb: [0.15583946, 0.16060805, 0.52078528]}),

    (g:Person {name: "Brian", emb: [0.4142066 , 0.18411476, 0.68245374]}),
    (h:Person {name: "Olga", emb: [0.61230904, 0.7735076 , 0.09668418]}),
    (i:Person {name: "Jacob", emb: [0.87470625, 0.63589938, 0.33536311]}),

    (a)-[:ROAD {cost: 50}]->(b),
    (a)-[:ROAD {cost: 50}]->(c),
    (a)-[:ROAD {cost: 100}]->(d),
    (b)-[:ROAD {cost: 40}]->(d),
    (c)-[:ROAD {cost: 40}]->(d),
    (c)-[:ROAD {cost: 80}]->(e),
    (d)-[:ROAD {cost: 30}]->(e),
    (d)-[:ROAD {cost: 80}]->(f),
    (e)-[:ROAD {cost: 40}]->(f),

    (g)-[:LIVES_IN]->(a),
    (h)-[:LIVES_IN]->(f),
    (i)-[:LIVES_IN]->(e);
  """
)

# Project both node labels with their "emb" property, plus both relationship types
G, project_result = gds.graph.project(
    graph_name="road_graph",
    node_spec={"City": {"properties": ["emb"]}, "Person": {"properties": ["emb"]}},
    relationship_spec=["ROAD", "LIVES_IN"]
)

# Sanity check
assert G.relationship_count() == 12

The "emb" node property here contains the TransE node embeddings we will use in our computation to infer new relationships.

4.1. Creating and inspecting our model

Using our graph G and our precomputed relationship type embeddings we can now construct a TransE relationship embedding model.

transe_model = gds.model.transe.create(
    G,
    node_embedding_property="emb",
    relationship_type_embeddings={
        "ROAD": [0.88355126, 0.15116676, 0.24225456],
        "LIVES_IN": [0.94185368, 0.60460752, 0.92028837]
    }
)

# Sanity check
assert transe_model.scoring_function() == "transe"

With our model created, we can start predicting new relationships in our graph.

4.2. Making predictions

Let’s have our model predict where our three inhabitants of interest are likely to move in the future, and mutate our GDS projection represented by G with these new relationships.

result = transe_model.predict_mutate(
    source_node_filter="Person",
    target_node_filter="City",
    relationship_type="LIVES_IN",
    top_k=1,
    mutate_relationship_type="MIGHT_MOVE",
    mutate_property="likeliness_score"
)

# Let us make sure the number of new relationships makes sense
assert result["relationshipsWritten"] == 3
assert G.relationship_count() == 12 + 3

Using TransE embeddings and the relationship embedding model capabilities of GDS we were able to infer where our inhabitants of interest might move in the future. The new "MIGHT_MOVE" relationships we created are now part of the GDS graph projection represented by G.