1. Introduction
Single-cell RNA sequencing (scRNA-seq) is a powerful tool that provides in-depth insights into the gene expression profiles of individual cells [1]. This technology is invaluable for distinguishing cell types at fine granularity [2], studying developmental biology, identifying the mechanisms of complex diseases [3], and mapping the developmental trajectories of cells [4]. In the analysis of scRNA-seq data, precisely discerning the different cell types is an essential step [4]. As such, cell clustering methods have become a crucial component of scRNA-seq data analysis, as they can identify a variety of cell types without any preset assumptions [2]. Traditional clustering techniques, such as K-means [5], hierarchical clustering [6], and density-based clustering [7], have been applied to this task. Nonetheless, the clustering analysis of scRNA-seq data still presents computational and statistical challenges, because limitations in sequencing technology and environmental factors lead to extreme sparsity and a high incidence of zeros in the data [8,9]. Therefore, more efficient and accurate clustering methods for scRNA-seq data are urgently needed.
Numerous clustering methods have been developed to address these challenges. For instance, CIDR employs a fast PCA-based approach using dissimilarity matrices for data imputation and clustering [
10]. SC3 introduces a consensus clustering framework tailored for scRNA-seq data, reducing dimensionality through PCA and Laplacian transformations of cell–cell distance matrices [
11]. SIMLR utilizes multi-kernel learning to attain more robust distance metrics and to handle extensive data missingness [
12]. Despite these advances, the extreme sparsity of gene expression measurements means that these methods often yield suboptimal solutions on scRNA-seq data [13]. Furthermore, these methods commonly rely on computationally intensive full-graph Laplacian matrices, which demand considerable computational and storage resources. AutoClass learns the data distribution from raw scRNA-seq data and reconstructs gene expression values conditioned on specific cell types; however, such methods depend on the original distribution of the scRNA-seq data and overlook the topological structure information inherent in the data. In recent years, deep embedding clustering methods, such as scDeepCluster and scziDesk, have emerged as successful approaches for modeling high-dimensional and sparse scRNA-seq data. These methods refine clusters iteratively by learning highly confident assignments and leveraging auxiliary target distributions, ultimately leading to improved clustering results. Nevertheless, these deep embedding clustering methods frequently overlook the propagation of structural information and the relationships between cells.
Recently, graph neural networks (GNNs) have garnered attention from researchers due to their ability to capture the relational information between neighboring nodes in a graph [
14]. GNNs can reveal the connections between a target node and its surrounding nodes, thereby enhancing the representation of node features [
15]. This has made GNNs a popular choice for processing scRNA-seq data. For example, scGAE utilizes a graph autoencoder to preserve the topological structure of scRNA-seq data while performing dimensionality reduction. GraphSCC combines graph convolutional networks with a denoising autoencoder to simultaneously capture the complex relationships between cells and the intrinsic characteristics of individual cells. scTAG leverages GNNs to aggregate information from adjacent nodes and maps cell expression data onto a ZINB model [13]. Although current GNN-based methods have achieved remarkable results in clustering scRNA-seq data, these strategies often overlook the global information of the graph, which limits their ability to extract effective latent features.
Because cell-type labels are scarce and costly to obtain, and because learning effective feature representations from unlabeled data is difficult, graph contrastive learning has demonstrated strong potential. Its core idea is to improve the quality of feature representations by increasing the similarity between positive samples while decreasing the similarity between negative samples. Contrastive learning is mainly used in unsupervised representation learning, where it can fully exploit large amounts of unlabeled data; on such data it has shown superior performance, even surpassing some supervised learning methods [
16]. This makes contrastive learning naturally suited for scRNA-seq data analysis. The first method to apply this to scRNA-seq data clustering was contrastive-sc [
17], which uses a dropout neural network layer to randomly mask a set of genes by assigning them a weight of zero. In this case, features that are important for model learning may be masked out, potentially reducing clustering accuracy [
16]. scNAME [
18] combines neighborhood contrastive loss with an auxiliary masking estimation task to delve deeper into the correlations between features and similarities between cells. While common contrastive learning methods effectively utilize genes as features of cells, they do not consider the interrelations between cells. GNNs, however, have the capability to capture and represent the complex high-order structural relationships between cells [
19]. Therefore, using graph contrastive learning for cell clustering is a novel approach that could potentially improve the accuracy of cell clustering.
Consequently, we introduce a novel deep graph contrastive learning clustering method called scZAG. scZAG utilizes a ZINB graph convolutional autoencoder to capture the zero-inflation characteristic of scRNA-seq data and reduce the impact of noise. It employs a joint adaptive data augmentation strategy targeting topological structures and node attributes, preserving key structures and attributes within the cell graph. Using a graph convolutional network based on approximate personalized propagation of neural predictions (APPNPGCN) as the encoder, scZAG effectively captures both local and global graph structure information. By leveraging graph contrastive learning, scZAG learns representative node features and optimizes the cell clustering process using the Kullback–Leibler (KL) divergence, ensuring that similar cells are assigned to the same cluster while dissimilar ones are separated.
3. Discussion
We have introduced a deep learning method for scRNA-seq clustering named scZAG. Although traditional graph neural networks excel at processing graph-structured information, they can sometimes be overly influenced by the graph structure, especially when the edge weights are not set properly. To overcome this issue, scZAG incorporates APPNPGCN, a method that achieves a better balance between node features and graph structure information through an improved propagation scheme based on personalized PageRank. Given that scRNA-seq data typically contain multiple cell types and states, which present complex patterns and structures in high-dimensional space, the iterative propagation mechanism of APPNPGCN aids the model in capturing the intricate relationships between these cells.
Furthermore, we have implemented a novel adaptive data augmentation approach. Combined with graph contrastive learning, this approach learns the similarities and differences between nodes, helping to identify and emphasize the key features that distinguish cell types or states and thereby enhancing the clustering outcome. To further improve the model’s performance, we have integrated a ZINB model within the APPNPGCN encoder. The ZINB model handles the zero inflation and the discrete, over-dispersed counts commonly found in scRNA-seq data, providing a more accurate representation of gene expression profiles. By integrating the ZINB model with the GNN, the network can exploit these more precise cell feature representations to learn more discriminative embeddings, improving the clustering algorithm’s ability to identify cell types and states accurately. Personalized PageRank, combined with the ZINB model, offers more accurate information propagation by incorporating both the graph structure and cell features, allowing the clustering algorithm to exploit cell–cell relationships more effectively and to identify different cell types and states more precisely. Together, the more accurate feature representations provided by the ZINB model and the graph structure modeling provided by personalized PageRank reduce the algorithm’s sensitivity to noise and perturbations in the data, improving clustering stability and accuracy. Lastly, scZAG utilizes a self-optimizing deep embedding clustering approach, feeding the latent features extracted by APPNPGCN into an adaptive clustering module and employing the KL divergence to fulfill the clustering task.
To validate the clustering efficacy of scZAG, we compared it against other state-of-the-art scRNA-seq clustering methods on ten real datasets. Based on the clustering performance, and corroborated by the results of the Mann–Whitney U test, scZAG significantly outperformed the other methods. We also conducted a thorough analysis of hyperparameters to identify the optimal settings for the scZAG approach. Ablation studies confirmed that each component within scZAG positively contributes to the overall performance of the model. Finally, our visualization analysis demonstrated that, compared to other methods, scZAG’s latent embedding representations more effectively differentiate and separate the various cell populations.
When the number of cell types in a dataset increases, both our method and the competing methods may experience a decrease in accuracy. As the diversity of cell types grows, so does the complexity of the dataset, making it more challenging to differentiate between cell types. Additionally, different cell types may have overlapping or similar gene expression patterns, further complicating accurate classification. In the future, we will continue to improve scZAG and apply it to the integration of single-cell multi-omics data. Furthermore, we aim to enhance the interpretability of the model by integrating topic modeling techniques.
4. Materials and Methods
The architecture of scZAG is illustrated in
Figure 5. scZAG can be divided into three modules: the ZINB-based autoencoder module, the graph contrastive learning module, and the clustering module. First, the ZINB-based autoencoder module employs a ZINB graph convolutional autoencoder to capture the zero-inflation characteristics of scRNA-seq data and to extract global probabilistic information that reduces the impact of noise. The graph contrastive learning module introduces a joint adaptive data augmentation strategy targeting topological structures and node attributes, including edge deletion and feature masking, and uses centrality metrics to identify important edges and feature dimensions. Using the APPNPGCN as the encoder of scZAG, features are extracted from the augmented cell graphs while considering higher-order neighbor information, which better captures both local and global graph structures. Based on the dimensionality-reduced data, graph contrastive learning captures the topological structures and relationships between nodes in the graph, thus learning more representative node features. The clustering module optimizes the cell clustering process using the KL divergence, a measure of the difference between two probability distributions; minimizing it ensures that similar cells are assigned to the same cluster while dissimilar ones are separated, and that the clustering results closely follow the underlying distribution of the data.
4.1. Data Sources
To validate the clustering performance of the scZAG model on scRNA-seq data, we compared it with several state-of-the-art scRNA-seq clustering methods across ten real datasets, which are detailed in
Table 2 and originate from recently published papers on scRNA-seq clustering. These datasets come from various sequencing platforms, species, and organs. To assess the clustering performance, we employed two common evaluation metrics: the adjusted Rand index (ARI) and normalized mutual information (NMI), both of which are used to measure the consistency between the generated clusters and the true groups. For both of these evaluation metrics, higher values indicate better clustering performance.
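Both metrics are implemented in scikit-learn; the brief sketch below, with hypothetical label vectors, illustrates how they are typically computed and is not part of the scZAG pipeline itself.

```python
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

# Hypothetical true cell-type labels and predicted cluster assignments.
true_labels = [0, 0, 1, 1, 2, 2]
pred_labels = [1, 1, 0, 0, 2, 2]

# Both scores compare the predicted partition with the ground truth;
# higher values indicate better agreement (NMI lies in [0, 1], ARI can be negative).
ari = adjusted_rand_score(true_labels, pred_labels)
nmi = normalized_mutual_info_score(true_labels, pred_labels)
print(f"ARI = {ari:.3f}, NMI = {nmi:.3f}")
```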
4.2. Data Pre-Processing
The scRNA-seq gene expression matrix $X \in \mathbb{R}^{N \times O}$ is used as the input for our model, where $x_{ij}$ represents the expression count of the $j$th gene ($1 \le j \le O$) in the $i$th cell ($1 \le i \le N$). To ensure the quality and reliability of the data, we pre-process the raw scRNA-seq gene expression matrix as follows. First, quality control and data filtering constitute the initial step of our pre-processing. Following scGNN [31], we filter out genes that are expressed (non-zero) in fewer than 1% of cells, as well as genes that are not expressed in any cell. Next, because the count matrix is discrete and subject to large-scale factor variations, we normalize it and then rescale the discrete data using a natural logarithm transformation. The normalization is defined as follows:
$$\tilde{x}_{ij} = \ln\left(\frac{x_{ij}}{\sum_{j'=1}^{O} x_{ij'}} \times m + 1\right),$$
where $m$ represents the median of the total expression counts of the cells. Lastly, we select the top 500 highly variable genes based on the normalized discrete values calculated by the scanpy package [
32]. This approach is intended to highlight key variations within the data, thus improving the accuracy and interpretability of further analysis.
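A minimal sketch of this pre-processing pipeline using scanpy is shown below; the file name is hypothetical, and the exact filtering thresholds and arguments used in scZAG may differ.

```python
import scanpy as sc

# Hypothetical raw count matrix (cells x genes) stored as an AnnData object.
adata = sc.read_h5ad("raw_counts.h5ad")

# Quality control: keep genes expressed (non-zero) in at least 1% of cells.
sc.pp.filter_genes(adata, min_cells=int(0.01 * adata.n_obs))

# Library-size normalization to the median total count per cell,
# followed by the natural-log transformation ln(1 + x).
sc.pp.normalize_total(adata, target_sum=None)  # target_sum=None uses the median count
sc.pp.log1p(adata)

# Retain the top 500 highly variable genes.
sc.pp.highly_variable_genes(adata, n_top_genes=500, subset=True)
```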
4.3. Cell Graph
Similar to previous work [
13], we use the KNN (K-nearest neighbors) algorithm to construct a cell graph from the pre-processed data, where each node represents a cell. For each cell, we identify its K nearest neighbors under the Euclidean distance and connect the cell to them, so that every cell is linked to its K closest cells. In our experiments, we set K to 15. The resulting cell graph is undirected, with all edges weighted equally at 1.
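The sketch below illustrates one way to build such a graph with scikit-learn; it is a simplified reconstruction, and the function name is ours rather than part of scZAG.

```python
from sklearn.neighbors import kneighbors_graph

def build_cell_graph(X, k=15):
    """Build an undirected, unweighted KNN cell graph from a cells x genes matrix."""
    # Connect each cell to its k nearest neighbors under the Euclidean distance.
    A = kneighbors_graph(X, n_neighbors=k, mode="connectivity", metric="euclidean")
    # Symmetrize so the graph is undirected; all retained edges carry weight 1.
    return A.maximum(A.T)

# Usage with a hypothetical pre-processed expression matrix (cells x 500 genes):
# A = build_cell_graph(X_preprocessed, k=15)
```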
4.4. Graph Contrastive Learning Framework
The graph contrastive learning framework we employ follows the common graph contrastive learning paradigm, where the model aims to maximize the consistency of representations across different views. Specifically, we begin by generating two graph views through random graph augmentations applied to the input data. We then utilize a contrastive objective that ensures that the encoded embeddings of each node within the two different views remain consistent with each other and distinguishable from the embeddings of other nodes.
In each iteration of scZAG, we employ two random augmentation functions, $t \sim \mathcal{T}$ and $t' \sim \mathcal{T}$, where $\mathcal{T}$ represents the set of all possible augmentation functions. We then obtain two augmented graphs, $\widetilde{G}_1 = t(G)$ and $\widetilde{G}_2 = t'(G)$, with node embeddings $U$ and $V$, where $\widetilde{X}$ represents the feature matrix of a view and $\widetilde{A}$ represents the adjacency matrix of a view. A discriminator (contrastive objective) is then employed to distinguish between embeddings of the same node in these two views and embeddings of other nodes. For any node, its embedding $u_i$ in one view is the anchor, and its embedding $v_i$ in the other view is the “positive sample.” All other embeddings in the two views are treated as negative samples, representing nodes different from the anchor. To facilitate meaningful feature representations, we utilize the InfoNCE multi-view contrastive loss function, defining the pairwise objective for each positive pair as follows:
$$\ell(u_i, v_i) = \log \frac{e^{\theta(u_i, v_i)/\tau}}{e^{\theta(u_i, v_i)/\tau} + \sum_{k \ne i} e^{\theta(u_i, v_k)/\tau} + \sum_{k \ne i} e^{\theta(u_i, u_k)/\tau}}, \quad (2)$$
where $\tau$ is a temperature parameter. We define the critic function $\theta(u, v) = s(g(u), g(v))$, where $s(\cdot, \cdot)$ represents a predetermined similarity function and $g(\cdot)$ is a non-linear projection function, which aims to enhance the expressive power of the critic function. The projection function $g(\cdot)$ is implemented through a two-layer perceptron model. Through this design, the model is able to learn a powerful critic function that can accurately evaluate the similarity of node embeddings in different views, thereby contributing to the achievement of the contrastive learning objective.
For each pair of positive samples, negative samples are defined from both inter-view and intra-view nodes, corresponding to the second and third terms in the denominator of Equation (2). As the two views are symmetric, the loss from the alternate view is similarly denoted as $\ell(v_i, u_i)$. The overall objective to maximize, representing the average over all positive sample pairs, is defined as follows:
$$\mathcal{J} = \frac{1}{2N} \sum_{i=1}^{N} \left[\ell(u_i, v_i) + \ell(v_i, u_i)\right]. \quad (3)$$
In summary, each training round involves applying two random data augmentation functions, $t$ and $t'$, to generate the augmented graphs $\widetilde{G}_1$ and $\widetilde{G}_2$. Node features within these augmented graphs are learned using a graph convolutional autoencoder based on personalized PageRank propagation, resulting in node embeddings $U$ and $V$. The model optimizes the objective function in Equation (3) during training, adjusting its parameters to maximize this function and learn node embeddings that effectively capture relationships between nodes.
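A minimal PyTorch sketch of this symmetric InfoNCE objective is given below; the cosine-similarity critic and the projection module passed in as `proj` are assumptions consistent with common graph contrastive learning implementations, not scZAG’s exact code.

```python
import torch
import torch.nn.functional as F

def contrastive_loss(u, v, proj, tau=0.5):
    """Symmetric InfoNCE loss between the node embeddings of two views (both N x d)."""
    h1 = F.normalize(proj(u), dim=1)  # non-linear projection g(.) of view 1
    h2 = F.normalize(proj(v), dim=1)  # non-linear projection g(.) of view 2

    def one_direction(a, b):
        between = torch.exp(a @ b.t() / tau)   # inter-view similarities
        within = torch.exp(a @ a.t() / tau)    # intra-view similarities
        pos = between.diag()                   # positive pair: same node in the other view
        # Denominator: positive pair + inter-view negatives + intra-view negatives.
        denom = between.sum(dim=1) + within.sum(dim=1) - within.diag()
        return -torch.log(pos / denom)

    # Average over both directions and all nodes (Equation (3), negated for minimization).
    return 0.5 * (one_direction(h1, h2) + one_direction(h2, h1)).mean()
```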
4.5. Adaptive Graph Augmentation
In our scZAG model, we employ an adaptive augmentation approach [
20] that preserves important structures and attributes while perturbing less significant edges and nodes. This means that when we randomly delete edges and mask node features, the probability of deletion varies according to the importance of each edge and node. Edges or features with lower importance are more likely to be removed or masked; conversely, those with higher importance have a lower probability of being disrupted. Overall, we emphasize the preservation of important structures and attributes rather than random destruction of the view. This method better guides the model to retain fundamental topological structures and semantic graph patterns.
4.5.1. Topology-Level Augmentation
For the topology-level augmentation, we randomly drop edges from the graph with a bias towards the importance of the edges. Formally, we sample a modified edge subset $\widetilde{\mathcal{E}}$ from the original edge set $\mathcal{E}$ with the following probability:
$$P\{(u, v) \in \widetilde{\mathcal{E}}\} = 1 - p^{e}_{uv},$$
where $\widetilde{\mathcal{E}}$ represents the set of edges in the generated view and $p^{e}_{uv}$ is the removal probability of edge $(u, v)$. The importance of edge $(u, v)$, denoted by $w^{e}_{uv}$, allows the augmentation function to preferentially disrupt edges of lesser importance, ensuring that the generated view maintains critical connectivity structures. Node centrality is employed to assess the prominence of nodes, and we define the edge centrality $w^{e}_{uv}$ based on the centrality of the connecting nodes $u$ and $v$. Specifically, $w^{e}_{uv} = \left(\varphi_c(u) + \varphi_c(v)\right)/2$, where $\varphi_c(\cdot)$ is a node centrality measure.
To assess the likelihood of edge removal based on centrality, we introduce $s^{e}_{uv} = \log w^{e}_{uv}$ to account for the widely varying magnitudes of centrality values. Subsequently, we normalize these values to transform them into probabilities, defined as follows:
$$p^{e}_{uv} = \min\left(\frac{s^{e}_{\max} - s^{e}_{uv}}{s^{e}_{\max} - \mu^{e}_{s}} \cdot p_{e},\; p_{\tau}\right),$$
where $p_{e}$ is a hyperparameter that governs the overall likelihood of edge deletion, and $\mu^{e}_{s}$ and $s^{e}_{\max}$ represent the mean and maximum values of $s^{e}_{uv}$, respectively. The term $p_{\tau}$ is a cutoff probability that truncates the probability to prevent an excessively high chance of deletion, which would lead to an overly disrupted graph structure.
We define PageRank centrality as the node centrality function. PageRank centrality is determined by the PageRank weights derived from the PageRank algorithm, which disseminates influence across directed edges, and nodes that accumulate the greatest influence are considered important. Formally, the centrality values are computed as follows:
$$\sigma = \alpha A D^{-1} \sigma + \mathbf{1},$$
where $\sigma \in \mathbb{R}^{N}$ is the vector of PageRank centrality scores for each node and $\alpha$ is a damping factor that mitigates the absorption of ranks by sink nodes in the graph. Following the recommendation of Lawrence et al. [33], we set the damping factor $\alpha$ to 0.85. Since our cell graph is undirected, we transform it into a directed graph before applying the PageRank algorithm, where each undirected edge is replaced by two directed edges.
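A numpy sketch of this centrality-aware edge dropping is shown below, assuming a dense adjacency matrix and a standard PageRank power iteration; the helper functions and default hyperparameters are ours, while the variable names follow the notation above.

```python
import numpy as np

def pagerank_centrality(A, alpha=0.85, iters=100):
    """Approximate PageRank scores by power iteration on a dense adjacency matrix A."""
    n = A.shape[0]
    deg = A.sum(axis=1).clip(min=1.0)
    P = A / deg[:, None]                       # row-normalized transition matrix
    sigma = np.full(n, 1.0 / n)
    for _ in range(iters):
        sigma = alpha * P.T @ sigma + (1.0 - alpha) / n
    return sigma

def edge_drop_probabilities(edges, sigma, p_e=0.3, p_tau=0.7):
    """Removal probabilities biased so that less central edges are dropped more often."""
    # Edge centrality: mean PageRank centrality of the two endpoints, log-scaled.
    s = np.log(np.array([(sigma[u] + sigma[v]) / 2 for u, v in edges]) + 1e-12)
    p = (s.max() - s) / (s.max() - s.mean() + 1e-12) * p_e
    return np.minimum(p, p_tau)                # truncate to avoid over-corrupting the graph

# Sampling a view: keep edge (u, v) with probability 1 - p_uv, e.g.
# keep = np.random.rand(len(edges)) > edge_drop_probabilities(edges, sigma)
```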
4.5.2. Node-Attribute-Level Augmentation
At the node attribute level, we introduce noise to node attributes by randomly masking a fraction of the dimensions of the node features with zeros. Formally, we first adopt a random vector $\widetilde{m} \in \{0, 1\}^{F}$, where each dimension is independently drawn from a Bernoulli distribution, i.e., $\widetilde{m}_i \sim \mathrm{Bern}(1 - p^{f}_{i})$. The resulting node feature matrix $\widetilde{X}$ is
$$\widetilde{X} = \left[x_1 \circ \widetilde{m};\; x_2 \circ \widetilde{m};\; \cdots;\; x_N \circ \widetilde{m}\right]^{\top},$$
where $[\cdot\,;\,\cdot]$ denotes the concatenation operation, and $\circ$ is the element-wise multiplication. Similar to the topology-level augmentation, the probability $p^{f}_{i}$ reflects the importance of the node feature in the $i$th dimension. We assume that feature dimensions that appear in nodes with greater influence are important and define the weight of feature dimensions as follows. For sparse one-hot node features, we compute the dimension weight $w^{f}_{i}$ as follows:
$$w^{f}_{i} = \sum_{u \in \mathcal{V}} x_{ui} \cdot \varphi_c(u),$$
where $x_{ui} \in \{0, 1\}$ indicates the presence of feature dimension $i$ in node $u$, and $\varphi_c(u)$ measures the importance of node $u$. For dense node features of a node $u$, we take the absolute value $|x_{ui}|$ to assess the feature weight in dimension $i$:
$$w^{f}_{i} = \sum_{u \in \mathcal{V}} |x_{ui}| \cdot \varphi_c(u).$$
Similar to the topology-level augmentation, we then normalize these weights to obtain the masking probability for each feature dimension:
$$p^{f}_{i} = \min\left(\frac{s^{f}_{\max} - s^{f}_{i}}{s^{f}_{\max} - \mu^{f}_{s}} \cdot p_{f},\; p_{\tau}\right),$$
where $s^{f}_{i} = \log w^{f}_{i}$, $\mu^{f}_{s}$ and $s^{f}_{\max}$ denote the mean and maximum values of $s^{f}_{i}$, and $p_{f}$ is a hyperparameter controlling the overall magnitude of feature masking. Finally, combining the topology-level and node-attribute-level augmentations, we generate two augmented views, $\widetilde{G}_1$ and $\widetilde{G}_2$.
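A companion numpy sketch of the node-attribute-level augmentation for dense features is given below; as above, the function names and default hyperparameters are illustrative assumptions rather than scZAG’s exact implementation.

```python
import numpy as np

def feature_mask_probabilities(X, sigma, p_f=0.3, p_tau=0.7):
    """Per-dimension masking probabilities weighted by node centrality (dense features)."""
    # w_i = sum_u |x_ui| * centrality(u), then log-scaled and normalized as above.
    w = np.log((np.abs(X) * sigma[:, None]).sum(axis=0) + 1e-12)
    p = (w.max() - w) / (w.max() - w.mean() + 1e-12) * p_f
    return np.minimum(p, p_tau)

def mask_features(X, p):
    """Zero out whole feature dimensions according to their masking probabilities."""
    keep = np.random.rand(X.shape[1]) > p      # Bernoulli(1 - p_i) keep-mask per dimension
    return X * keep                            # broadcast the mask over all cells
```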
4.6. Graph Convolution Based on Personalized PageRank Propagation
To better capture the global structural information in graphs while maintaining computational efficiency, we utilize a graph convolutional autoencoder based on approximate personalized propagation of neural predictions (APPNPGCN) [
34]. The core idea is to propagate neural predictions through an approximation of personalized PageRank, which is guided by the graph’s edge structure. This helps generate node embeddings that reflect a node’s global position and its neighborhood information within the graph. Using this approach enhances the model’s ability to learn from scRNA-seq data. In the case of standard GCNs, when multiple layers are involved, the mean aggregation approach can lead to over-smoothing issues. Thus, standard GCNs lose the capability to capture local structures. Resorting to larger neighborhoods would inevitably increase the neural network’s depth and the number of learnable parameters.
To overcome the loss of local structure capture, we draw on the connection between the limiting distribution and PageRank. Using a personalized PageRank variant with a root node $x$, the equation
$$\pi_{\mathrm{ppr}}(i_x) = (1 - \alpha)\,\hat{\tilde{A}}\,\pi_{\mathrm{ppr}}(i_x) + \alpha\, i_x$$
computes the distribution of personalized PageRank values starting from the root node $x$, where $i_x$ is the indicator (one-hot) vector of node $x$, $\alpha$ is a damping factor determining the likelihood of returning to the root node $x$ for random walk restarts, and $\hat{\tilde{A}} = \tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}$ is the normalized adjacency matrix (with added self-loops) describing node connectivity. Solving this equation iteratively yields the personalized PageRank values $\pi_{\mathrm{ppr}}(i_x)$ for each node based on the root node $x$. By solving this equation, we can obtain the following:
$$\pi_{\mathrm{ppr}}(i_x) = \alpha\left(I_n - (1 - \alpha)\hat{\tilde{A}}\right)^{-1} i_x.$$
The indicator vector $i_x$ allows us to preserve the local neighborhood of nodes even within the limiting distribution.
We begin with initial predictions based on each node’s unique features and enhance them using personalized PageRank, defining the core concept of personalized propagation of neural predictions (APPNP). APPNP employs a power iteration method for efficient topic-sensitive PageRank approximation with linear computational complexity. Unlike traditional PageRank, each power iteration step corresponds to a random walk with restarts, improving prediction accuracy. The computation for each power iteration step in the topic-sensitive PageRank variant follows the formulae below:
$$Z^{(0)} = H = f_{\theta}(X),$$
$$Z^{(k+1)} = (1 - \alpha)\,\hat{\tilde{A}}\,Z^{(k)} + \alpha H,$$
$$Z^{(K)} = \mathrm{softmax}\!\left((1 - \alpha)\,\hat{\tilde{A}}\,Z^{(K-1)} + \alpha H\right),$$
where $X$ denotes the feature matrix and $f_{\theta}$ is a neural network with parameters $\theta$ that generates the predictive results $H$. The prediction matrix $H$ serves as both the starting vector and the teleport set, with $K$ being the number of iterations and $k \in [0, K-2]$. During the process of generating predictions and continuously propagating these predictions, the model undergoes end-to-end training. This means that during back propagation, gradients flow through the propagation scheme (implicitly involving an infinite number of neighborhood aggregation layers). Incorporating these propagation effects can significantly enhance the model’s accuracy.
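The propagation scheme itself can be summarized in a few lines of PyTorch; the sketch below is a simplified illustration of the iteration above (dense normalized adjacency matrix, fixed number of steps, no final softmax), not scZAG’s actual implementation.

```python
import torch

def appnp_propagate(H, A_hat, alpha=0.1, K=10):
    """Approximate personalized-PageRank propagation of initial predictions.

    H: initial node representations f_theta(X), shape (N, d)
    A_hat: symmetrically normalized adjacency matrix with self-loops, shape (N, N)
    """
    Z = H
    for _ in range(K):
        # Each step propagates along edges and teleports back to the
        # initial predictions H with probability alpha (random walk with restart).
        Z = (1.0 - alpha) * (A_hat @ Z) + alpha * H
    return Z
```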
4.7. ZINB-Based Graph Convolutional Autoencoder
To address the issues of excessive zeros and over-dispersion in scRNA-seq data, we incorporate a zero-inflated negative binomial (ZINB) distribution into the graph convolutional autoencoder based on personalized PageRank propagation to learn low-dimensional embeddings of gene expression. The reconstruction of the scRNA-seq data by the ZINB-based graph convolutional autoencoder is modeled by the ZINB distribution, defined as follows:
$$\mathrm{NB}(x \mid \mu, \theta) = \frac{\Gamma(x + \theta)}{x!\,\Gamma(\theta)}\left(\frac{\theta}{\theta + \mu}\right)^{\theta}\left(\frac{\mu}{\theta + \mu}\right)^{x},$$
$$\mathrm{ZINB}(x \mid \pi, \mu, \theta) = \pi\,\delta_0(x) + (1 - \pi)\,\mathrm{NB}(x \mid \mu, \theta),$$
where $\theta$ and $\mu$ denote the dispersion and mean parameters, respectively, while $\pi$ represents the zero-inflation probability, which is the likelihood of an observation being zero. To estimate the three critical parameters of the ZINB distribution, $\pi$, $\mu$, and $\theta$, we employ three distinct fully connected layers within our computational framework:
$$\pi = \mathrm{sigmoid}(W_{\pi} D), \qquad \mu = \exp(W_{\mu} D), \qquad \theta = \exp(W_{\theta} D),$$
where $D$ denotes the output of a fully connected decoder network comprising three hidden layers with 128, 256, and 512 nodes, respectively, and $W_{\pi}$, $W_{\mu}$, and $W_{\theta}$ represent the three weight matrices corresponding to the parameters of our model. Here, $\pi$, $\mu$, and $\theta$ denote the zero-inflation probability, the mean of the negative binomial distribution, and the dispersion parameter, respectively. We utilize the negative log-likelihood of the ZINB distribution as the reconstruction loss for the original data $X$:
$$L_{\mathrm{ZINB}} = -\log \mathrm{ZINB}(X \mid \pi, \mu, \theta).$$
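A PyTorch sketch of this ZINB negative log-likelihood is given below; it follows the standard mean/dispersion parameterization and is an illustrative reconstruction rather than scZAG’s exact loss code.

```python
import torch

def zinb_nll(x, mu, theta, pi, eps=1e-8):
    """Negative log-likelihood of a zero-inflated negative binomial distribution."""
    # Log-likelihood of the negative binomial component (mean mu, dispersion theta).
    log_nb = (torch.lgamma(x + theta) - torch.lgamma(theta) - torch.lgamma(x + 1.0)
              + theta * torch.log(theta / (theta + mu) + eps)
              + x * torch.log(mu / (theta + mu) + eps))
    # Probability mass of the zero-inflated mixture at zero.
    nb_zero = theta * torch.log(theta / (theta + mu) + eps)
    log_zero = torch.log(pi + (1.0 - pi) * torch.exp(nb_zero) + eps)
    # Non-zero counts can only come from the negative binomial component.
    log_nonzero = torch.log(1.0 - pi + eps) + log_nb
    log_likelihood = torch.where(x < eps, log_zero, log_nonzero)
    return -log_likelihood.mean()
```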
4.8. Self-Optimizing Deep Graph-Embedded Clustering
Self-optimizing deep embedding clustering integrates graph embedding with deep clustering, aiming to optimize clustering performance by learning non-linear embeddings within the graph structure. Traditional clustering algorithms are unsupervised and label-free, lacking optimization feedback during training. Therefore, we employ self-optimizing deep embedding clustering, which receives optimization feedback throughout the training process, yielding more efficient and accurate results when clustering graph-structured data. We utilize the KL divergence to measure the discrepancy between two probability distributions, $P$ and $Q$, where $P$ represents the auxiliary target distribution and $Q$ represents the model’s soft assignment distribution. We define the clustering loss as follows:
$$L_c = \mathrm{KL}(P \,\|\, Q) = \sum_{i}\sum_{k} p_{ik} \log \frac{p_{ik}}{q_{ik}},$$
where $q_{ik}$ is the soft label of the embedded node $z_i$; it measures the similarity between $z_i$ and the cluster center $\mu_k$ using a Student’s $t$-distribution and is defined as follows:
$$q_{ik} = \frac{\left(1 + \lVert z_i - \mu_k \rVert^2\right)^{-1}}{\sum_{k'}\left(1 + \lVert z_i - \mu_{k'} \rVert^2\right)^{-1}}.$$
Initial cluster centers $\{\mu_k\}$ are generated via spectral clustering after pre-training our model. $P$ is the auxiliary target distribution, refined as follows:
$$p_{ik} = \frac{q_{ik}^{2} \big/ \sum_{i} q_{ik}}{\sum_{k'}\left(q_{ik'}^{2} \big/ \sum_{i} q_{ik'}\right)}.$$
This approach allows the model to iteratively improve the cluster assignments through feedback mechanisms inherently built into the training process.
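A compact PyTorch sketch of the soft assignments, target distribution, and KL clustering loss described above is shown below; it mirrors the standard deep embedding clustering formulation and is an illustrative reconstruction, not scZAG’s exact code.

```python
import torch
import torch.nn.functional as F

def soft_assignments(z, centers):
    """Student's t similarity between embeddings z (N x d) and cluster centers (K x d)."""
    q = 1.0 / (1.0 + torch.cdist(z, centers) ** 2)
    return q / q.sum(dim=1, keepdim=True)

def target_distribution(q):
    """Sharpened auxiliary distribution that emphasizes high-confidence assignments."""
    weight = q ** 2 / q.sum(dim=0)
    return weight / weight.sum(dim=1, keepdim=True)

def clustering_loss(q):
    """KL(P || Q): pull the soft assignments toward the sharpened targets."""
    p = target_distribution(q).detach()
    return F.kl_div(q.log(), p, reduction="batchmean")
```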