Rank-N-Contrast: Learning Continuous Representations for Regression

About

Deep regression models typically learn in an end-to-end fashion without explicitly emphasizing a regression-aware representation. Consequently, the learned representations exhibit fragmentation and fail to capture the continuous nature of sample orders, leading to suboptimal results across a wide range of regression tasks. To fill this gap, we propose Rank-N-Contrast (RNC), a framework that learns continuous representations for regression by contrasting samples against each other based on their rankings in the target space. We demonstrate, theoretically and empirically, that RNC guarantees that the learned representations are ordered in accordance with the target orders, yielding not only better performance but also significantly improved robustness, efficiency, and generalization. Extensive experiments on five real-world regression datasets spanning computer vision, human-computer interaction, and healthcare verify that RNC achieves state-of-the-art performance and highlight its intriguing properties, including better data efficiency, robustness to spurious targets and data corruptions, and generalization to distribution shifts. Code is available at: https://github.com/kaiwenzha/Rank-N-Contrast.

Kaiwen Zha, Peng Cao, Jeany Son, Yuzhe Yang, Dina Katabi • 2022
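The abstract describes RNC only at a high level ("contrasting samples against each other based on their rankings in the target space"). As a rough illustration, here is a minimal pure-Python sketch of a ranking-based contrastive loss in that spirit. It rests on an assumed formulation, not a transcription of the authors' code: for an anchor i and a positive j, the candidate set is every sample k ≠ i whose label is at least as far from the anchor's label as j's is; similarity is the negative Euclidean distance between embeddings divided by a temperature; and the loss is the average negative log-softmax probability of each positive within its candidate set. The function name `rank_n_contrast_loss`, its parameters, and the default temperature are hypothetical, labels are assumed to be scalars, and augmentation and batching are omitted. See the linked repository for the actual implementation.

```python
import math
from typing import Sequence


def rank_n_contrast_loss(
    features: Sequence[Sequence[float]],
    labels: Sequence[float],
    temperature: float = 1.0,
) -> float:
    """Hypothetical sketch of a ranking-based contrastive loss for regression.

    For each anchor i and positive j != i, the candidate set contains every
    k != i whose label is at least as far from labels[i] as labels[j] is
    (so it always includes j itself). Similarity is the negative Euclidean
    distance between embeddings, scaled by 1 / temperature.
    """
    n = len(features)
    if n < 2 or n != len(labels):
        raise ValueError("need at least two samples and one label per sample")

    def sim(a: Sequence[float], b: Sequence[float]) -> float:
        # Negative L2 distance: closer embeddings count as more similar.
        return -math.dist(a, b) / temperature

    total = 0.0
    for i in range(n):
        anchor_loss = 0.0
        for j in range(n):
            if j == i:
                continue
            d_ij = abs(labels[i] - labels[j])
            # Samples ranked at least as far from the anchor as j in label space.
            logits = [
                sim(features[i], features[k])
                for k in range(n)
                if k != i and abs(labels[i] - labels[k]) >= d_ij
            ]
            # Numerically stable log-sum-exp over the candidate set.
            m = max(logits)
            log_denom = m + math.log(sum(math.exp(x - m) for x in logits))
            # Negative log-softmax probability of the positive j in its set.
            anchor_loss += log_denom - sim(features[i], features[j])
        total += anchor_loss / (n - 1)
    return total / n


if __name__ == "__main__":
    labels = [0.0, 1.0, 2.0, 3.0]
    ordered = [[0.0], [10.0], [20.0], [30.0]]    # embedding order matches label order
    scrambled = [[30.0], [0.0], [20.0], [10.0]]  # same values, order broken
    print(rank_n_contrast_loss(ordered, labels) < rank_n_contrast_loss(scrambled, labels))  # True
```

Because each positive's own similarity is part of its candidate set, every term is non-negative, and the loss is small only when samples closer in label space are also closer in embedding space. That is why the ordered embeddings in the usage example score lower than the scrambled ones.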

Related benchmarks

Task | Dataset | Metric | Result | Rank
--- | --- | --- | --- | ---
Age Estimation | AgeDB (val) | Age MAE | 6.14 | 13
Regression | SkyFinder | MAE | 2.86 | 11
Dysarthric speech severity assessment | SAP In-domain (test) | SRCC | 0.726 | 10
Dysarthric speech severity assessment | EWA-DB Cross-domain (test) | SRCC | 0.714 | 10
Regression | TUAB | MAE | 6.97 | 10
Regression | MPIIFaceGaze | Angular Error | 5.27 | 10
Dysarthric speech severity assessment | UASpeech Cross-domain (test) | SRCC | 0.959 | 10
Dysarthric speech severity assessment | NeuroVoz Cross-domain (test) | SRCC | 0.577 | 10
Dysarthric speech severity assessment | EasyCall Cross-domain (test) | SRCC | 0.868 | 10
Dysarthric speech severity assessment | DysArinVox Cross-domain (test) | SRCC | 0.564 | 10
