# Transformers can do Bayesian Clustering
## About
Bayesian clustering accounts for uncertainty in cluster structure but is computationally demanding at scale. Moreover, real-world datasets often contain missing values, and simple imputation ignores the associated uncertainty, leading to suboptimal results. We present Cluster-PFN, a Transformer-based model that extends Prior-Data Fitted Networks (PFNs) to unsupervised Bayesian clustering. Trained entirely on synthetic datasets generated from a finite Gaussian Mixture Model (GMM) prior, Cluster-PFN learns to estimate the posterior distribution over both the number of clusters and the cluster assignments. It estimates the number of clusters more accurately than handcrafted model selection procedures such as AIC, BIC, and Variational Inference (VI), and achieves clustering quality competitive with VI while being orders of magnitude faster. Cluster-PFN can also be trained on more complex priors that include missing data, outperforming imputation-based baselines on real-world genomic datasets under high missingness. These results show that Cluster-PFN provides scalable and flexible Bayesian clustering.
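The training setup described above relies on sampling synthetic datasets from a finite GMM prior, optionally with values masked out to simulate missingness. The sketch below illustrates that generative step with hypothetical prior hyperparameters (uniform number of clusters, Dirichlet mixture weights, Gaussian means, fixed unit covariance, missing-completely-at-random masking); the paper's exact prior settings are not specified here.

```python
import numpy as np

def sample_gmm_dataset(rng, n_points=200, d=2, max_k=5):
    """Draw one synthetic dataset from a finite GMM prior.

    Hypothetical prior: K ~ Uniform{1..max_k}, weights ~ Dirichlet(1),
    means ~ N(0, 3^2 I), unit-variance isotropic components.
    """
    k = int(rng.integers(1, max_k + 1))            # number of clusters
    weights = rng.dirichlet(np.ones(k))            # mixture weights
    means = rng.normal(0.0, 3.0, size=(k, d))      # cluster centers
    assignments = rng.choice(k, size=n_points, p=weights)
    points = means[assignments] + rng.normal(0.0, 1.0, size=(n_points, d))
    return points, assignments, k

def mask_missing(rng, X, p_missing=0.3):
    """Simulate a missing-data prior by masking entries at random (MCAR)."""
    mask = rng.random(X.shape) < p_missing
    X_obs = X.copy()
    X_obs[mask] = np.nan
    return X_obs, mask

rng = np.random.default_rng(0)
X, z, k = sample_gmm_dataset(rng)
X_obs, mask = mask_missing(rng, X)
```

Each such (dataset, labels, K) triple would serve as one training example, letting the Transformer amortize posterior inference over the whole prior rather than fitting any single dataset.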
## Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Cluster count estimation | 1000 2D Easy synthetic (test) | Accuracy | 0.64 | 5 |
| Cluster count prediction | 1000 2D Easy | Accuracy | 64 | 5 |
| Cluster count prediction | 1000 5D Easy datasets | Accuracy | 72 | 5 |
| Cluster count prediction | 1000 2D Hard | Accuracy | 44 | 5 |
| Cluster count prediction | 1000 5D Hard | Accuracy | 52 | 5 |
| Clustering | Synthetic Clustering Datasets 5D Easy | Mean ARI Rank | 1.54 | 3 |
| Clustering | Synthetic Clustering Datasets 2D Hard | Mean ARI Rank | 1.5 | 3 |
| Clustering | Synthetic Clustering Datasets 5D Hard | Mean ARI Rank | 1.49 | 3 |
| Clustering | 30,000 Synthetic Datasets 2D Easy (test) | Mean AMI Rank | 1.57 | 3 |
| Clustering | 30,000 Synthetic Datasets 5D Easy (test) | Mean AMI Rank | 1.57 | 3 |