论文Accurate predictions on small data with a tabular foundation model提出了TabPFN(Tabular Prior-data Fitted Network)这一基于生成式Transformer模型的表格数据基础模型。该模型旨在为小型到中型数据集提供更高效的预测,并且能够自动适应表格数据中的各种挑战。
论文作者为Noah Hollmann, Samuel Müller, Lennart Purucker, Arjun Krishnakumar, Max Körfer, Shi Bin Hoo, Robin Tibor Schirrmeister & Frank Hutter,来自如下机构:Machine Learning Lab, University of Freiburg, Freiburg, Germany. Computational Medicine, Berlin Institute of Health at Charité, Universitätsmedizin Berlin, Berlin, Germany. Prior Labs, Freiburg, Germany. Neuromedical AI Lab, Department of Neurosurgery, Medical Center – University of Freiburg, Faculty of Medicine, University of Freiburg, Freiburg, Germany. Medical Physics, Department of Diagnostic and Interventional Radiology, Medical Center – University of Freiburg, Faculty of Medicine, University of Freiburg, Freiburg, Germany. ELLIS Institute
Tübingen, Tübingen, Germany.
一、研究背景与动机
表格数据(Tabular data)是广泛应用于多个领域的重要数据类型,包括生物医学、粒子物理学、经济学、气候科学等。传统上,表格数据的预测任务依赖于手工设计的算法和特征工程,其中填补缺失值是其中的一项核心任务。尽管深度学习在处理原始数据(如图像、文本等)方面取得了显著进展,但在表格数据的处理上,深度学习模型的表现并不理想。这是因为表格数据的异质性较强,特征类型众多(如布尔值、分类数据、连续值等),且常常包含缺失值、异常值、不平衡数据等,使得深度学习模型难以有效处理。
过去20年来,基于决策树的集成方法(如梯度提升决策树、随机森林等)在表格数据上表现卓越,并且成为了主流方法。然而,这些方法在面对小数据集时往往表现不佳,且它们无法有效应对数据漂移(数据分布随时间变化)或领域迁移等问题。为了克服这些局限,研究者提出了TabPFN(Tabular Prior-data Fitted Network),这是一种基于生成式Transformer模型的表格数据基础模型,旨在为小型到中型数据集提供更高效的预测,并且能够自动适应表格数据中的各种挑战。
二、方法与实现
1. 内在学习机制(In-context Learning,ICL)
TabPFN的核心思想是利用内在学习(ICL,或称为“上下文学习” 、 “情境学习”)机制,借鉴了大语言模型(如GPT-3)中的机制。在ICL中,模型不依赖于显式的训练过程,而是通过从输入示例中直接学习如何进行推理。在TabPFN中,这一机制使得模型在遇到新的数据集时,能够快速调整并生成预测结果,无需通过传统的逐样本训练。具体来说,TabPFN通过在数百万个合成数据集上进行训练,自动学习如何应对包括缺失值、噪声、不平衡等复杂数据挑战,并通过一次前向传播完成训练和预测。
2. 合成数据生成与预训练
为了充分训练TabPFN,研究者首先生成大量的合成数据集,这些数据集包含了各种常见的数据挑战,如缺失值、异常值、噪声等。为了生成这些数据,研究者使用了结构化因果模型(SCM)。这种模型通过构建一个有向无环图(DAG)来表示数据特征之间的因果关系,从而能够生成具有复杂依赖关系的合成数据。通过生成大量具有不同因果结构的合成数据集,TabPFN能够学习到多样的预测任务,并自动掌握如何处理不同数据集中的问题。
TabPFN的预训练过程仅需一次,而不需要每次遇到新数据集时都进行训练。这使得模型在面对新数据集时能够迅速适应,并在非常短的时间内完成推理任务。TabPFN通过在合成数据集上的训练,学会了生成一个通用的学习算法,能够有效地处理各种表格数据问题。
3. Transformer架构与优化
TabPFN的架构基于Transformer模型,并针对表格数据进行了优化。Transformer模型在处理序列数据时表现出色,能够有效地捕捉长距离依赖关系。然而,传统的Transformer架构并不适用于表格数据,因为表格数据的结构是二维的,而Transformer模型通常处理的是一维序列。为了克服这一限制,TabPFN采用了二维注意力机制,即每个表格单元(单元格)不仅会与其同行的其他特征进行交互,还会与其他样本的相应特征进行交互。这种设计确保了模型能够充分利用表格数据的二维结构,提高了预测的效率和准确性。
此外,TabPFN对传统Transformer模型进行了多项优化,以减少内存和计算需求。例如,使用了Flash Attention技术来加速计算,使用半精度(half-precision)计算来降低内存占用,还通过激活检查点技术来进一步优化计算过程。这些优化使得TabPFN能够在计算资源有限的情况下处理大规模数据集。
4. 推理与训练分离
TabPFN的另一个创新之处在于它能够将训练和推理过程分离。在传统的机器学习模型中,训练过程和推理过程通常是紧密耦合的,即每次推理都需要重新训练模型。然而,在TabPFN中,训练过程仅在合成数据集上进行一次,而推理过程可以直接在实际数据集上进行。通过将训练和推理分离,TabPFN避免了重复计算的浪费,显著提高了推理效率。
三、实验结果与分析
1. 定性分析
在定性实验中,TabPFN展现了其卓越的建模能力。例如,在多个玩具问题(toy problem[)上,TabPFN能够很好地模拟平滑函数和非平滑函数。尤其是在步进函数的建模上,TabPFN表现得尤为突出,能够有效拟合这些函数,而传统的多层感知机(MLP)和基于树的模型(如CatBoost)则难以处理。TabPFN在建模这些函数时,既没有出现灾难性失败,也能够准确地捕捉到函数的变化趋势。
此外,TabPFN在复杂的物理实验模拟中也表现出色。在经典的双缝干涉实验中,TabPFN能够仅通过一次前向传播,在不到2秒的时间内预测出多模态的光强分布。这种能力使得TabPFN能够在处理复杂分布时,仍然保持较高的效率和准确性。
2. 定量分析
TabPFN在多个真实数据集上的定量实验也验证了其强大的预测能力。在分类任务中,TabPFN的表现超越了当前最强的基线模型CatBoost,并且在推理时间上实现了巨大的加速。例如,在分类任务中,TabPFN的默认配置比CatBoost调优后的配置快了5140倍,而在回归任务中,TabPFN的推理速度比CatBoost快了3000倍。同时,TabPFN在这些任务中的预测准确性也优于现有基线模型。
3. 鲁棒性分析
TabPFN对噪声、缺失值和冗余特征具有很强的鲁棒性。在实验中,TabPFN在处理具有缺失值和冗余特征的数据时,依然能够保持较好的性能。此外,当数据集的样本或特征数量较少时,TabPFN仍然能够表现出色,且其预测结果未受到样本数量减少的显著影响。
四、模型的优势与应用
1. 可解释性
TabPFN不仅在预测性能上表现出色,还具备较强的可解释性。通过SHAP值(Shapley Additive Explanations),TabPFN能够计算各个特征对最终预测的贡献度,从而为模型的预测结果提供清晰的解释。这使得TabPFN特别适合用于那些对模型可解释性有严格要求的领域,如医疗诊断、金融风险评估等。
2. 数据生成与密度估计
TabPFN还具有数据生成和密度估计的能力。通过生成新的数据样本,TabPFN能够用于数据增强,尤其在数据稀缺的领域中,能够有效地扩充训练数据集。此外,TabPFN的密度估计能力使得它能够识别异常数据点,应用于异常检测、欺诈检测等任务。
3. 高效性与应用潜力
TabPFN能够在极短的时间内完成训练和预测,尤其在小型到中型数据集(如10,000个样本和500个特征)上表现尤为出色。其高效性使得它在学术研究和工业应用中均具有巨大的潜力,能够加速数据科学家的工作流程,特别是在需要快速原型设计和迭代的任务中。
五、研究意义与展望
TabPFN的提出标志着表格数据建模方法的重大突破。通过内在学习和生成式模型的结合,TabPFN能够自动学习应对各种数据挑战的策略,并为表格数据分析提供了一种全新的思路。未来的研究可以围绕以下几个方向展开:
- 扩展TabPFN,以支持更大规模的数据集,探索其在大规模数据处理中的应用;
- 进一步研究TabPFN在多模态数据、时间序列数据等复杂数据类型中的应用;
- 研究TabPFN对数据漂移和领域迁移的适应能力,提升其在动态数据环境中的表现。