Model objects from the model catalog
Models of the GDS Model Catalog are represented as Model objects in the Python client, similar to how projected graphs are represented as graph objects.
Model objects are typically constructed by training a pipeline or a GraphSAGE model, in which case a reference to the trained model is returned in the form of a Model object.
Once created, Model objects can be passed as arguments to methods in the Python client, such as the model catalog operations.
Additionally, Model objects have convenience methods for inspecting the represented models without explicitly involving the model catalog.
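To make the idea of a "reference to the trained model" concrete, here is a toy sketch of the handle pattern in plain Python. Everything in it is hypothetical: ToyCatalog stands in for the server-side GDS Model Catalog, and ToyModel and get are illustrative names, not the client's actual classes or implementation. The handle keeps only the model's name and looks everything else up in the catalog on each call.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


# Hypothetical stand-in for the server-side model catalog:
# it maps model names to the metadata kept for each model.
@dataclass
class ToyCatalog:
    entries: Dict[str, Dict[str, Any]] = field(default_factory=dict)


class ToyModel:
    """Toy handle: stores only the model name and asks the catalog for everything else."""

    def __init__(self, name: str, catalog: ToyCatalog) -> None:
        self._name = name
        self._catalog = catalog

    def name(self) -> str:
        return self._name

    def exists(self) -> bool:
        return self._name in self._catalog.entries

    def train_config(self) -> Dict[str, Any]:
        # Looked up on every call, so the handle never holds a stale copy.
        return self._catalog.entries[self._name]["trainConfig"]

    def drop(self) -> None:
        # Remove the catalog entry; the Python handle itself stays alive.
        self._catalog.entries.pop(self._name, None)


def get(catalog: ToyCatalog, name: str) -> ToyModel:
    # Building a handle only needs the name; no model data is copied.
    return ToyModel(name, catalog)


catalog = ToyCatalog({"city-representation": {"trainConfig": {"concurrency": 4}}})
model = get(catalog, "city-representation")
print(model.train_config()["concurrency"])  # 4
model.drop()
print(model.exists())  # False
```

One consequence of this design is that a handle can outlive the model it points to: after drop(), the toy exists() reports False even though the Python object is still around.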
In the examples below, we assume that we have an instantiated GraphDataScience object called gds.
Read more about this in Getting started.
1. Constructing a model object
The primary way to construct a model object is by training a model. There are two types of models: pipeline models and GraphSAGE models. To train a pipeline model, a pipeline must first be created and configured. Read more about how to operate pipelines, including examples of using pipeline models, in Machine learning pipelines. In this section, we demonstrate how to create and use a GraphSAGE model object.
First, we introduce a small road-network graph:
gds.run_cypher(
"""
CREATE
(a:City {name: "New York City", settled: 1624}),
(b:City {name: "Philadelphia", settled: 1682}),
(c:City:Capital {name: "Washington D.C.", settled: 1790}),
(d:City {name: "Baltimore", settled: 1729}),
(e:City {name: "Atlantic City", settled: 1854}),
(f:City {name: "Boston", settled: 1822}),
(a)-[:ROAD {cost: 50}]->(b),
(a)-[:ROAD {cost: 50}]->(c),
(a)-[:ROAD {cost: 100}]->(d),
(b)-[:ROAD {cost: 40}]->(d),
(c)-[:ROAD {cost: 40}]->(d),
(c)-[:ROAD {cost: 80}]->(e),
(d)-[:ROAD {cost: 30}]->(e),
(d)-[:ROAD {cost: 80}]->(f),
(e)-[:ROAD {cost: 40}]->(f);
"""
)
G, project_result = gds.graph.project(
"road_graph",
{"City": {"properties": ["settled"]}},
{"ROAD": {"properties": ["cost"]}}
)
assert G.relationship_count() == 9
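As a quick sanity check of the assertion above, here is a minimal pure-Python sketch (no database involved) that recounts the relationships in the CREATE statement. It assumes, as the assertion implies, that the projection keeps each ROAD relationship exactly once, so the expected count equals the number of ROAD patterns. The roads list and the variable names are illustrative only.

```python
from collections import Counter

# ROAD relationships from the CREATE statement above, as (source, target, cost).
roads = [
    ("New York City", "Philadelphia", 50),
    ("New York City", "Washington D.C.", 50),
    ("New York City", "Baltimore", 100),
    ("Philadelphia", "Baltimore", 40),
    ("Washington D.C.", "Baltimore", 40),
    ("Washington D.C.", "Atlantic City", 80),
    ("Baltimore", "Atlantic City", 30),
    ("Baltimore", "Boston", 80),
    ("Atlantic City", "Boston", 40),
]

# Every city appears in at least one relationship, so this recovers all 6 nodes.
cities = {city for src, dst, _ in roads for city in (src, dst)}

# Out-degree per city; Counter returns 0 for a city with no outgoing ROAD (Boston).
out_degree = Counter(src for src, _, _ in roads)

print(len(roads))                # 9, matching the relationship count asserted above
print(len(cities))               # 6
print(sum(out_degree.values()))  # 9: each relationship has exactly one source
```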
Now we can use the graph G to train a GraphSAGE model:
model, train_result = gds.beta.graphSage.train(G, modelName="city-representation", featureProperties=["settled"], randomSeed=42)
assert train_result["modelInfo"]["metrics"]["ranEpochs"] == 1
where model is the model object, and train_result is a pandas Series containing metadata from the underlying procedure call.
Similarly, we can also get model objects from training machine learning pipelines.
To get a model object representing a model that has already been trained and is present in the model catalog, call the client-side-only get method and pass it the model's name:
model = gds.model.get("city-representation")
assert model.name() == "city-representation"
2. Inspecting a model object
There are convenience methods on all model objects that let us extract information about the represented model.
| Name | Arguments | Return type | Description |
|---|---|---|---|
| name() | | | The name of the model as it appears in the model catalog. |
| | | | The type of model it is, e.g. "graphSage". |
| train_config() | | | The configuration used for training the model. |
| | | | The schema of the graph on which the model was trained. |
| | | | |
| | | | |
| | | | Time when the model was created. |
| | | | |
| | | | |
| drop() | | | Removes the model from the GDS Model Catalog. |
For example, to get the training configuration of the model object model created above, we would do the following:
train_config = model.train_config()
assert train_config["concurrency"] == 4
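Note that the train call above never set concurrency, yet the training configuration reports 4. A reasonable reading is that the stored configuration holds the resolved values, with defaults filled in for parameters that were omitted. Below is a minimal sketch of that kind of merge; DEFAULTS and resolve_config are hypothetical, and only "concurrency": 4 is backed by the assertion above (the randomSeed default is a placeholder).

```python
from typing import Any, Dict

# Hypothetical defaults: only "concurrency": 4 is backed by the assertion above;
# the randomSeed default of None is a placeholder for illustration.
DEFAULTS: Dict[str, Any] = {"concurrency": 4, "randomSeed": None}


def resolve_config(user_config: Dict[str, Any]) -> Dict[str, Any]:
    # User-supplied values win; anything omitted falls back to its default.
    # A new dict is built, so DEFAULTS itself is never modified.
    return {**DEFAULTS, **user_config}


train_config = resolve_config(
    {"modelName": "city-representation", "featureProperties": ["settled"], "randomSeed": 42}
)
print(train_config["concurrency"])  # 4 (default, never passed explicitly)
print(train_config["randomSeed"])   # 42 (user value overrides the default)
```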
3. Using a model object
The primary way to use model objects is for prediction. How to do so for GraphSAGE is described below; for pipelines, see the Machine learning pipelines page.
Additionally, model objects can be used as input to GDS Model Catalog operations.
For instance, given the model object model created above, we could do the following:
# Store the model on disk (GDS Enterprise Edition)
_ = gds.alpha.model.store(model)
gds.beta.model.drop(model)  # Remove the model from the in-memory catalog; same as model.drop()
# Load the model again for further use
gds.alpha.model.load(model.name())
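The store, drop, and load sequence above works because the model was stored before it was dropped: dropping removes the in-memory copy, while the stored copy remains on disk and can be loaded back. The toy sketch below models that with two plain dictionaries. ToyModelStore and its methods are hypothetical and say nothing about how GDS actually persists models.

```python
from typing import Any, Dict


class ToyModelStore:
    """Hypothetical stand-in: one dict for the in-memory catalog, one for disk."""

    def __init__(self) -> None:
        self.memory: Dict[str, Any] = {}
        self.disk: Dict[str, Any] = {}

    def store(self, name: str) -> None:
        # Persist a copy; the in-memory model stays available.
        self.disk[name] = self.memory[name]

    def drop(self, name: str) -> None:
        # Remove only the in-memory copy; a stored copy (if any) survives.
        del self.memory[name]

    def load(self, name: str) -> None:
        # Bring a stored model back into memory; fails if it was never stored.
        if name not in self.disk:
            raise KeyError(f"model {name!r} was never stored")
        self.memory[name] = self.disk[name]


catalog = ToyModelStore()
catalog.memory["city-representation"] = {"type": "graphSage"}
catalog.store("city-representation")
catalog.drop("city-representation")
print("city-representation" in catalog.memory)  # False
catalog.load("city-representation")
print("city-representation" in catalog.memory)  # True
```

Dropping a model that was never stored leaves nothing to load, which is why the store step comes first.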
3.1. GraphSAGE
As exemplified above in Constructing a model object, training a GraphSAGE model with the Python client is analogous to its Cypher counterpart.
Once trained, in addition to the methods above, the GraphSAGE model object will have the following methods.
| Name | Arguments | Return type | Description |
|---|---|---|---|
| | | | Predict embeddings for nodes of the input graph and mutate the graph with the predictions. |
| | | | Predict embeddings for nodes of the input graph and stream the results. |
| predict_write() | | | Predict embeddings for nodes of the input graph and write the results back to the database. |
| metrics() | | | Returns the values of the metrics computed during training. |
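The three prediction methods differ mainly in where the embeddings end up: streaming returns them to the caller, mutating adds them as a node property of the in-memory projected graph, and writing persists them to the database. The toy sketch below illustrates only that difference. Plain dictionaries stand in for the projected graph and the database, and embed is a hypothetical placeholder, not GraphSAGE; none of these functions are the client's API.

```python
from typing import Any, Dict, List

Embedding = List[float]
NodeProps = Dict[str, Dict[str, Any]]  # node name -> property map


def embed(settled: int) -> Embedding:
    # Hypothetical stand-in for a trained model (NOT GraphSAGE).
    return [settled / 1000.0, 1.0]


def predict_stream(graph: NodeProps) -> Dict[str, Embedding]:
    # Return embeddings to the caller; nothing is stored anywhere.
    return {node: embed(props["settled"]) for node, props in graph.items()}


def predict_mutate(graph: NodeProps, mutate_property: str) -> int:
    # Add embeddings as a node property of the in-memory (projected) graph.
    results = predict_stream(graph)
    for node, vector in results.items():
        graph[node][mutate_property] = vector
    return len(results)


def predict_write(graph: NodeProps, database: NodeProps, write_property: str) -> int:
    # Persist embeddings to the database; the projected graph is left unchanged.
    results = predict_stream(graph)
    for node, vector in results.items():
        database.setdefault(node, {})[write_property] = vector
    return len(results)  # number of node properties written


projected: NodeProps = {"Boston": {"settled": 1822}, "Baltimore": {"settled": 1729}}
database: NodeProps = {}

written = predict_write(projected, database, "embedding")
print(written)                             # 2
print("embedding" in projected["Boston"])  # False: write leaves the graph untouched

predict_mutate(projected, "embedding")
print("embedding" in projected["Boston"])  # True
```

In this toy, the count returned by predict_write equals the number of nodes in the graph, which mirrors the nodePropertiesWritten check at the end of this section.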
So, given the GraphSAGE model object model we trained above, we could do the following:
# Make sure our training actually converged
metrics = model.metrics()
assert metrics["didConverge"]
# Predict on `G` and write embedding node properties back to the database
predict_result = model.predict_write(G, writeProperty="embedding")
assert predict_result["nodePropertiesWritten"] == G.node_count()