Implementing a Graph Neural Network in Keras

Nathan Bailey
8 min read · Jul 31, 2024


Introduction

In my previous blog, I detailed the architecture behind graph neural networks. These networks are modelled on convolutional neural networks and take in a graph of labelled nodes and edges. They can perform many tasks, one of which is node classification.

This blog details the implementation of a Graph Neural Network in Keras. The full code can be found here and is adapted from the excellent Keras tutorial.

Implementation

Dataset

First, let us inspect the dataset we are working with. We will be using the Cora dataset, which consists of 2708 papers, each of which can cite other papers. Each paper is classified into one of seven classes and has a binary word vector of size 1433 indicating whether each dictionary word is present in the paper.

Therefore, in this case, the nodes of our graph are the papers, and each can link to another node through a citation. The data associated with each node is its binary word vector; these vectors can be grouped into a matrix X.

We first read the citation data, which is set up so that a source paper cites a target paper. Some examples from this dataset can be seen below:

      target   source
4816  337766   110162
716     3095    83449
32        35  1130856
736     3187   129897
3975  105856   193931

We can then read the data on the papers, which classifies each paper into one of seven categories. Some random examples from this dataset are shown below.

                  2369             1604        857               68             1873
paper_id       1119987          1121603     218682            13212           662250
term_0               0                0          0                0                0
term_1               0                0          0                0                0
term_2               0                0          0                0                0
term_3               1                0          0                0                0
...                ...              ...        ...              ...              ...
term_1429            0                0          0                0                0
term_1430            0                0          0                0                0
term_1431            0                0          0                0                0
term_1432            0                0          0                0                0
subject     Case_Based  Neural_Networks  Case_Based  Neural_Networks  Neural_Networks
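For reference, a sketch of how these two dataframes might be loaded from the raw Cora files; the file paths and column names here are assumptions based on the standard cora.cites and cora.content distribution (both tab-separated).

import pandas as pd

# Each row of cora.cites is one citation edge.
citations = pd.read_csv(
    "cora/cora.cites",
    sep="\t",
    header=None,
    names=["target", "source"],
)
# Each row of cora.content is one paper: id, 1433 binary word features, subject.
papers = pd.read_csv(
    "cora/cora.content",
    sep="\t",
    header=None,
    names=["paper_id"] + [f"term_{idx}" for idx in range(1433)] + ["subject"],
)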

We preprocess the data so that the paper IDs start from 0 and the subjects are mapped to integer class labels.

class_values = sorted(papers["subject"].unique())
paper_values = sorted(papers["paper_id"].unique())
class_idx = {name: idx for idx, name in enumerate(class_values)}
paper_idx = {name: idx for idx, name in enumerate(paper_values)}

papers["paper_id"] = papers["paper_id"].apply(lambda name: paper_idx[name])
citations["source"] = citations["source"].apply(lambda name: paper_idx[name])
citations["target"] = citations["target"].apply(lambda name: paper_idx[name])
papers["subject"] = papers["subject"].apply(lambda value: class_idx[value])

The citation data can be visualised using the networkx library, which produces a graph similar to the figure below.

plt.figure(figsize=(10, 10))
cora_graph = nx.from_pandas_edgelist(citations.sample(n=1500))
subjects = list(
    papers[papers["paper_id"].isin(list(cora_graph.nodes))]["subject"]
)
nx.draw_spring(cora_graph, node_size=15, node_color=subjects)
Graph Data Visualized

The papers are randomly split into train and test datasets and shuffled. Since we will be performing node classification, our training inputs are the paper IDs and the targets are the papers' subject labels.

train_data, test_data = [], []
# Get the papers from each class (subject)
for _, group_data in papers.groupby("subject"):
    random_selection = np.random.rand(len(group_data.index)) <= 0.5
    train_data.append(group_data[random_selection])
    test_data.append(group_data[~random_selection])

# Shuffle the data
train_data = pd.concat(train_data).sample(frac=1)
test_data = pd.concat(test_data).sample(frac=1)

feature_names = list(set(papers.columns) - {"paper_id", "subject"})
num_features = len(feature_names)
num_classes = len(class_idx)

x_train = train_data.paper_id.to_numpy() # type: ignore[attr-defined]
x_test = test_data.paper_id.to_numpy() # type: ignore[attr-defined]
y_train = train_data["subject"] # type: ignore[call-overload]
y_test = test_data["subject"] # type: ignore[call-overload]

Finally, we create the nodes and edges for our network. The node features X are simply the word vectors shaped into a matrix.

node_features = tf.cast(
    papers.sort_values("paper_id")[feature_names].to_numpy(),
    dtype=tf.dtypes.float32,
)

Typically, our adjacency matrix would be an NxN matrix (where N is the number of nodes). However, it is more efficient here to use the citation data directly, as a 2 x num_edges array of (source, target) index pairs, in place of a dense adjacency matrix.

edges = citations[["source", "target"]].to_numpy().T
# Set to ones, as no weights needed here
edge_weights = tf.ones(shape=edges.shape[1])

graph_info = (node_features, edges, edge_weights)

Model Building Blocks

To create our graph neural network, we will need a few Python functions.

First, we create a simple function that produces a feed-forward network. This takes in a list of hidden layer units and outputs a neural network.

def create_feed_forward_layer(
    hidden_units: list[int], dropout_rate: float
) -> keras.Sequential:
    """Create a feed forward network."""
    ffn_layers = []
    for units in hidden_units:
        ffn_layers.append(keras.layers.BatchNormalization())
        ffn_layers.append(keras.layers.Dropout(dropout_rate))
        ffn_layers.append(keras.layers.Dense(units))
        ffn_layers.append(keras.layers.Activation("gelu"))

    return keras.Sequential(ffn_layers)
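As a quick sanity check (not part of the original code), the helper can be applied to a dummy batch of node features; the last Dense layer determines the output width:

# 4 dummy nodes, each with 1433 word features, passed through the FFN
ffn = create_feed_forward_layer(hidden_units=[64, 32], dropout_rate=0.2)
dummy_nodes = tf.random.normal(shape=(4, 1433))
print(ffn(dummy_nodes).shape)  # (4, 32)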

Then, we create the main layer for our graph neural network. This should take in node representations and our adjacency matrix, perform aggregation on the neighbour nodes and update the node representations based on these.

class GraphConvLayer(keras.layers.Layer):  # type: ignore[misc]
    """Creates a Graph Convolutional Layer."""

    def __init__(
        self,
        hidden_units: list[int],
        dropout_rate: float = 0.2,
        aggregation_type: str = "mean",
        combination_type: str = "concat",
    ) -> None:
        """Init variables and layers."""
        super().__init__()
        self.aggregation_type = aggregation_type
        self.combination_type = combination_type

        self.ffn_prepare = create_feed_forward_layer(
            hidden_units, dropout_rate
        )
        self.update_fn = create_feed_forward_layer(hidden_units, dropout_rate)

First, we gather the neighbour nodes for each node in the graph.

def call(
    self, inputs: tuple[tf.Tensor, NDArray[Any], tf.Tensor | None]
) -> tf.Tensor:
    """Forward Pass."""
    node_representations, edges, edge_weights = inputs
    node_indices, neighbour_indices = edges[0], edges[1]
    neighbour_representations = tf.gather(
        node_representations, neighbour_indices
    )

Given a list of indices, the tf.gather function gathers the vectors from the input matrix according to the indices. For example, if we had the matrix:

[[1, 2, 3, 4],
 [3, 4, 5, 6],
 [6, 7, 8, 9]]

and the indices [0, 0, 1, 1, 2], we would produce the following matrix:

[[1, 2, 3, 4],
 [1, 2, 3, 4],
 [3, 4, 5, 6],
 [3, 4, 5, 6],
 [6, 7, 8, 9]]
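A minimal snippet reproducing this example:

matrix = tf.constant([[1, 2, 3, 4], [3, 4, 5, 6], [6, 7, 8, 9]])
indices = [0, 0, 1, 1, 2]
print(tf.gather(matrix, indices).numpy())
# [[1 2 3 4]
#  [1 2 3 4]
#  [3 4 5 6]
#  [3 4 5 6]
#  [6 7 8 9]]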

The resulting data is then passed into the prepare function, which runs it through a feed-forward network. The network sees a batch of neighbour node features and applies the same weights to each feature in the batch, allowing the nodes to share weights here.

def prepare(
    self,
    node_representations: tf.Tensor,
    weights: tf.Tensor | None = None,
) -> tf.Tensor:
    """Pass neighbour features through a NN to produce messages."""
    messages = self.ffn_prepare(node_representations)
    if weights is not None:
        messages = messages * tf.expand_dims(weights, -1)
    return messages


def call(
    self, inputs: tuple[tf.Tensor, NDArray[Any], tf.Tensor | None]
) -> tf.Tensor:
    ...
    neighbour_messages = self.prepare(
        neighbour_representations, edge_weights
    )

We can then aggregate the neighbour node features using the following function.

def aggregate(
    self,
    node_indices: NDArray[Any],
    neighbour_messages: tf.Tensor,
    node_representations: tf.Tensor,
) -> tf.Tensor:
    """Aggregate messages from neighbours."""
    num_nodes = node_representations.shape[0]
    if self.aggregation_type == "sum":
        aggregated_message = tf.math.unsorted_segment_sum(
            neighbour_messages, node_indices, num_segments=num_nodes
        )
    elif self.aggregation_type == "mean":
        aggregated_message = tf.math.unsorted_segment_mean(
            neighbour_messages, node_indices, num_segments=num_nodes
        )
    else:
        aggregated_message = tf.math.unsorted_segment_max(
            neighbour_messages, node_indices, num_segments=num_nodes
        )

    return aggregated_message

def call(
    self, inputs: tuple[tf.Tensor, NDArray[Any], tf.Tensor | None]
) -> tf.Tensor:
    ...
    aggregated_messages = self.aggregate(
        node_indices, neighbour_messages, node_representations
    )

Each unsorted segment function takes the neighbour node features and aggregates them according to the indices. Given our example above and the indices [0, 0, 1, 1, 2], unsorted_segment_sum would give:

[[2, 4, 6, 8],
 [6, 8, 10, 12],
 [6, 7, 8, 9]]
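A small snippet reproducing this behaviour:

# The rows of messages are the gathered neighbour messages from the example above
messages = tf.constant(
    [[1, 2, 3, 4], [1, 2, 3, 4], [3, 4, 5, 6], [3, 4, 5, 6], [6, 7, 8, 9]]
)
segment_ids = [0, 0, 1, 1, 2]
print(
    tf.math.unsorted_segment_sum(messages, segment_ids, num_segments=3).numpy()
)
# [[ 2  4  6  8]
#  [ 6  8 10 12]
#  [ 6  7  8  9]]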

Lastly, we pass the aggregated messages through the update function which adds or concatenates them to the node features and passes the result through a fully connected layer. Again, since we will be passing through a batch of node features, the same weights are applied to each feature.

def update(
    self, node_representations: tf.Tensor, aggregated_messages: tf.Tensor
) -> tf.Tensor:
    """Update node representations based on the incoming messages."""
    if self.combination_type == "concat":
        h = tf.concat([node_representations, aggregated_messages], axis=1)
    else:
        h = node_representations + aggregated_messages

    node_embeddings = self.update_fn(h)
    return node_embeddings

def call(
    self, inputs: tuple[tf.Tensor, NDArray[Any], tf.Tensor | None]
) -> tf.Tensor:
    ...
    return self.update(node_representations, aggregated_messages)

Given an aggregation function of a sum and a non-concatenation update, the layer satisfies the following equations:

Aggregation and Update Function [1]
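In symbols, a sketch of the standard aggregate-and-update formulation (the notation here is an assumption, following Bishop [1]; the two MLPs correspond to ffn_prepare and update_fn):

z_n^{(\ell)} = \sum_{m \in \mathcal{N}(n)} \mathrm{MLP}_{\mathrm{prepare}}\left(h_m^{(\ell)}\right), \qquad h_n^{(\ell+1)} = \mathrm{MLP}_{\mathrm{update}}\left(h_n^{(\ell)} + z_n^{(\ell)}\right)

where \mathcal{N}(n) denotes the set of neighbours of node n.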

Model

We incorporate these layers into our model, which classifies nodes. It is constructed from a combination of feed-forward networks and graph convolutional layers.

When creating the model, we pass in our node features, adjacency matrix and edge weights, the latter of which are scaled to sum to 1.

class GNNNodeClassifier(keras.models.Model):  # type: ignore[misc]
    """Graph Neural Network Model."""

    def __init__(
        self,
        graph_info: tuple[tf.Tensor, NDArray[Any], tf.Tensor | None],
        num_classes: int,
        hidden_units: list[int],
        aggregation_type: str = "sum",
        combination_type: str = "concat",
        dropout_rate: float = 0.2,
    ) -> None:
        """Init variables and layers."""
        super().__init__()
        node_features, edges, edge_weights = graph_info
        self.node_features = node_features
        self.edges = edges
        self.edge_weights = edge_weights

        if self.edge_weights is None:
            self.edge_weights = tf.ones(shape=edges.shape[1])

        self.edge_weights /= tf.math.reduce_sum(self.edge_weights)

        self.preprocess = create_feed_forward_layer(
            hidden_units, dropout_rate
        )

        self.conv1 = GraphConvLayer(
            hidden_units,
            dropout_rate,
            aggregation_type,
            combination_type,
        )
        self.conv2 = GraphConvLayer(
            hidden_units,
            dropout_rate,
            aggregation_type,
            combination_type,
        )

        self.postprocess = create_feed_forward_layer(
            hidden_units, dropout_rate
        )
        self.compute_logits = keras.layers.Dense(units=num_classes)

At each forward pass, our model takes in a list of node indices to classify. Our node features are then passed through the sequence of feedforward and graph convolutional layers.

Then, given our final node embeddings, we retrieve the node embeddings corresponding to the node indices we are classifying and then pass them through a final layer for classification.

def call(self, input_node_indices: NDArray[Any]) -> tf.Tensor:
    """Model Forward Pass."""
    x = self.preprocess(self.node_features)
    x1 = self.conv1((x, self.edges, self.edge_weights))
    x = x1 + x
    x2 = self.conv2((x, self.edges, self.edge_weights))
    x = x2 + x
    x = self.postprocess(x)
    node_embeddings = tf.gather(x, input_node_indices)
    return self.compute_logits(node_embeddings)

Using a regular adjacency matrix

As seen in our model, our adjacency representation differs from the usual NxN matrix; the edge list is more efficient and allows us to easily perform different types of aggregation. If we did want to use a regular NxN adjacency matrix, we could, and it is formed as follows.

edge_data = citations[["source", "target"]].to_numpy().T
paper_ids = sorted(papers["paper_id"].unique())

# Create proper NxN adjacency matrix
edges = np.zeros(shape=(len(paper_ids), len(paper_ids)))

for value in paper_ids:
    idxs = np.where(edge_data[0] == value)[0]
    neighbours = edge_data[1, idxs].tolist()
    edges[value, neighbours] = 1.0
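As an aside (not from the original code), the same matrix can be built without the Python loop using NumPy fancy indexing, since each (source, target) pair simply sets one entry to 1:

# Set edges[source, target] = 1 for every citation in one vectorised step
edges = np.zeros(shape=(len(paper_ids), len(paper_ids)))
edges[edge_data[0], edge_data[1]] = 1.0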

Then, when it comes time to aggregate the neighbour node features we can simply perform the matrix calculation AX, which will sum up all the neighbouring node features for each node.

def aggregate(
    self,
    neighbour_messages: tf.Tensor,
    adjacency_matrix: tf.Tensor,
) -> tf.Tensor:
    """Aggregate messages from neighbours."""
    aggregated_message = tf.matmul(adjacency_matrix, neighbour_messages)
    return aggregated_message

def call(
    self, inputs: tuple[tf.Tensor, NDArray[Any]]
) -> tf.Tensor:
    """Forward Pass."""
    node_representations, adjacency_matrix = inputs
    neighbour_messages = self.ffn_prepare(node_representations)

    aggregated_messages = self.aggregate(
        neighbour_messages, adjacency_matrix,
    )
    ...

If we wanted to perform different types of aggregation, then this adjacency matrix would have to be manipulated first, leading to further complexity.
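For example, a mean aggregation could be obtained by row-normalising the adjacency matrix before the multiplication. The sketch below divides each row of A by the node's degree; it is an illustration under the assumption that adjacency_matrix is a float tensor, not part of the project code:

# Divide each row of A by that node's degree so that matmul(A, X)
# averages neighbour features instead of summing them.
degrees = tf.reduce_sum(adjacency_matrix, axis=1, keepdims=True)
normalised_adjacency = adjacency_matrix / tf.maximum(degrees, 1.0)
aggregated_messages = tf.matmul(normalised_adjacency, neighbour_messages)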

Running the Model

We train the model as if it were a normal classification model. The following graphs show the training loss and training accuracy; on our test set we achieved over 80% accuracy.

graph_info = (node_features, edges, edge_weights)

hidden_units = [32, 32]
learning_rate = 0.01
dropout_rate = 0.5
num_epochs = 300
batch_size = 256


gnn_model = GNNNodeClassifier(
graph_info=graph_info,
num_classes=num_classes,
hidden_units=hidden_units,
dropout_rate=dropout_rate,
)

# Compile and train model
history = compile_and_train_model(
gnn_model, x_train, y_train, learning_rate, num_epochs, batch_size
)

# Test Model
x_test = test_data.paper_id.to_numpy()
_, test_accuracy = gnn_model.evaluate(x=x_test, y=y_test, verbose=0)
print(f"Test accuracy: {round(test_accuracy * 100, 2)}%")
Loss vs Epochs
Accuracy vs Epochs
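The compile_and_train_model helper is not shown here. A minimal sketch of what it might look like, assuming sparse categorical cross-entropy on the logits and early stopping on a validation split (the exact settings are assumptions):

def compile_and_train_model(
    model, x_train, y_train, learning_rate, num_epochs, batch_size
):
    """Compile the model and fit it on the training node indices."""
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
    )
    # Stop early if validation accuracy plateaus and keep the best weights
    early_stopping = keras.callbacks.EarlyStopping(
        monitor="val_acc", patience=50, restore_best_weights=True
    )
    return model.fit(
        x=x_train,
        y=y_train,
        epochs=num_epochs,
        batch_size=batch_size,
        validation_split=0.15,
        callbacks=[early_stopping],
    )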

Conclusions

This blog detailed the implementation of a graph convolutional layer in Keras. The full code for this project can be found on GitHub.

  1. Deep Learning: Foundations and Concepts, Christopher M Bishop
