XiaomengFanmcislab/Riemannian-implicit-differentiation

Riemannian Implicit Differentiation via a Fixed-Point Equation for Riemannian Bilevel Optimization

🔥TNNLS 2025

📄 Paper

Key idea

Extending implicit differentiation to new Riemannian bilevel optimization tasks is nontrivial because each task requires case-by-case derivations by experts. In this article, we propose a Riemannian implicit differentiation method that provides a unified expression for outer gradients, so it can be applied flexibly to other tasks with less expert involvement.
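As a rough illustration of the fixed-point idea, the sketch below works in plain Euclidean space with NumPy. It is not the authors' Riemannian method and does not use the code in this repository. It assumes a toy quadratic inner problem whose solution is the fixed point of a gradient-step map `T(x, y)`, and it obtains the outer gradient by solving one linear system at that fixed point instead of differentiating through the inner iterations. All names (`fixed_point_map`, `solve_inner`, `implicit_outer_gradient`, `A`, `B`, `target`, `lam`) are hypothetical and chosen only for this example.

```python
import numpy as np

def fixed_point_map(x, y, A, B, alpha):
    """One gradient step on the toy inner objective g(x, y) = 0.5*y@A@y - y@B@x."""
    return y - alpha * (A @ y - B @ x)

def solve_inner(x, A, B, alpha, iters=2000):
    """Iterate the map until y (numerically) satisfies y = T(x, y)."""
    y = np.zeros(A.shape[0])
    for _ in range(iters):
        y = fixed_point_map(x, y, A, B, alpha)
    return y

def outer_objective(x, A, B, target, alpha, lam):
    """Toy outer objective f(x, y*(x)) = 0.5*||y* - target||^2 + 0.5*lam*||x||^2."""
    y = solve_inner(x, A, B, alpha)
    return 0.5 * np.sum((y - target) ** 2) + 0.5 * lam * np.sum(x ** 2)

def implicit_outer_gradient(x, A, B, target, alpha, lam):
    """Outer gradient from the fixed-point equation y = T(x, y) (Euclidean case)."""
    y = solve_inner(x, A, B, alpha)
    m = A.shape[0]
    dT_dy = np.eye(m) - alpha * A  # Jacobian of T with respect to y
    dT_dx = alpha * B              # Jacobian of T with respect to x
    # Solve (I - dT/dy)^T v = grad_y f rather than unrolling the inner loop.
    v = np.linalg.solve((np.eye(m) - dT_dy).T, y - target)
    return lam * x + dT_dx.T @ v

def finite_difference_gradient(fn, x, eps=1e-5):
    """Central differences, used here only as a sanity check."""
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        grad[i] = (fn(x + step) - fn(x - step)) / (2 * eps)
    return grad

# Toy problem data (assumed, for illustration only).
rng = np.random.default_rng(0)
m, n = 4, 3
M = rng.standard_normal((m, m))
A = M @ M.T + np.eye(m)                    # symmetric positive definite
B = rng.standard_normal((m, n))
target = rng.standard_normal(m)
x = rng.standard_normal(n)
lam = 0.1
alpha = 1.0 / np.linalg.eigvalsh(A).max()  # step size that makes T a contraction

g_implicit = implicit_outer_gradient(x, A, B, target, alpha, lam)
g_fd = finite_difference_gradient(
    lambda z: outer_objective(z, A, B, target, alpha, lam), x
)
print("max abs difference:", np.max(np.abs(g_implicit - g_fd)))
```

On a manifold, the Euclidean Jacobians and the linear solve above would presumably be replaced by their Riemannian counterparts; the paper gives the actual expression.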


Performance Highlights

PCA task on the Grassmann manifolds:

[Figure: results for the PCA task on the Grassmann manifolds]

Clustering task on the SPD manifolds:

[Figure: results for the clustering task on the SPD manifolds]

Long-tailed classification on Stiefel manifolds:

TOP-1 ERROR (%) ON CIFAR-LT-10/CIFAR-LT-100

| Method | CIFAR-LT-10 (200) | CIFAR-LT-10 (100) | CIFAR-LT-10 (50) | CIFAR-LT-10 (20) | CIFAR-LT-100 (200) | CIFAR-LT-100 (100) | CIFAR-LT-100 (50) | CIFAR-LT-100 (20) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Cross-entropy training | 34.32 | 29.63 | 25.19 | 17.77 | 65.16 | 61.68 | 56.15 | 48.86 |
| Class-balanced cross-entropy loss [65] | 31.11 | 27.63 | 21.95 | 15.64 | 64.30 | 61.44 | 55.45 | 42.88 |
| Class-balanced fine-tuning [66] | 33.76 | 28.66 | 22.56 | 16.78 | 61.34 | 58.5 | 53.78 | 47.70 |
| L2RW [67] | 33.75 | 27.77 | 23.55 | 18.65 | 67.00 | 61.10 | 56.83 | 49.25 |
| Meta-weight net [68] | 32.8 | 26.43 | 20.9 | 15.55 | 63.38 | 58.39 | 54.34 | 46.96 |
| Two-component weighting [69] | 29.34 | 23.59 | 19.49 | 13.54 | 60.69 | 56.65 | 51.47 | 44.38 |
| Divide and Retain [70] | - | - | - | - | 59.47 | 55.21 | 50.68 | - |
| Ours | 20.4 | 17.49 | 14.74 | 11.54 | 57.45 | 52.07 | 47.64 | 41.25 |

Few-shot classification:

ACCURACY (%) ON THE MINI-IMAGENET DATASET

| Method | Backbone | 1-shot 5-way | 5-shot 5-way |
| --- | --- | --- | --- |
| MAML [72] | ResNet12 | 51.03 ± 0.50 | 68.26 ± 0.47 |
| L2F [73] | ResNet12 | 57.48 ± 0.49 | 74.68 ± 0.43 |
| CAML [74] | ResNet12 | 59.23 ± 0.99 | 72.35 ± 0.71 |
| ALFA [75] | ResNet12 | 60.06 ± 0.49 | 77.42 ± 0.42 |
| MetaOptNet [76] | ResNet12 | 62.64 ± 0.61 | 78.63 ± 0.46 |
| MetaFun [77] | ResNet12 | 62.12 ± 0.30 | 78.20 ± 0.16 |
| DSN [78] | ResNet12 | 62.64 ± 0.66 | 78.83 ± 0.45 |
| Chen et al. [79] | ResNet12 | 63.17 ± 0.23 | 79.26 ± 0.17 |
| MeTAL [80] | ResNet12 | 59.64 ± 0.38 | 76.20 ± 0.19 |
| LEO [81] | WRN-28-10 | 61.76 ± 0.08 | 77.59 ± 0.12 |
| Con-MetaReg [82] | ResNet12 | 53.68 ± 0.50 | 66.88 ± 0.42 |
| Hyper ProtoNet [4] | ResNet18 | 59.47 ± 0.20 | 76.84 ± 0.14 |
| Hyperbolic kernel [83] | ResNet18 | 61.04 ± 0.21 | 77.33 ± 0.15 |
| CurAML [84] | ResNet12 | 63.13 ± 0.41 | 81.04 ± 0.39 |
| Poincaré radial kernel [85] | ResNet18 | 62.15 ± 0.20 | 77.81 ± 0.15 |
| Ours | ResNet12 | 64.5 ± 0.23 | 82.1 ± 0.15 |

ACCURACY (%) ON THE TIERED-IMAGENET DATASET

| Method | Backbone | 1-shot 5-way | 5-shot 5-way |
| --- | --- | --- | --- |
| ProtoNet [86] | ResNet12 | 53.51 ± 0.89 | 72.69 ± 0.74 |
| MAML [72] | ResNet12 | 58.58 ± 0.49 | 71.24 ± 0.43 |
| L2F [73] | ResNet12 | 63.94 ± 0.48 | 77.61 ± 0.41 |
| ALFA [75] | ResNet12 | 64.43 ± 0.49 | 81.77 ± 0.39 |
| DSN [78] | ResNet12 | 66.22 ± 0.75 | 82.79 ± 0.48 |
| MetaOptNet [76] | ResNet12 | 65.99 ± 0.72 | 83.28 ± 0.12 |
| MetaFun [77] | ResNet12 | 67.72 ± 0.14 | 78.20 ± 0.16 |
| Chen et al. [79] | ResNet12 | 68.62 ± 0.27 | 83.74 ± 0.18 |
| MeTAL [80] | ResNet12 | 63.89 ± 0.43 | 80.14 ± 0.40 |
| LEO [81] | WRN-28-10 | 66.33 ± 0.05 | 81.44 ± 0.09 |
| Con-MetaReg [82] | ResNet12 | 54.41 ± 0.53 | 68.23 ± 0.47 |
| Hyper ProtoNet [4] | ResNet18 | 54.44 ± 0.23 | 71.96 ± 0.20 |
| Hyperbolic kernel [83] | ResNet18 | 57.78 ± 0.23 | 76.48 ± 0.18 |
| CurAML [84] | ResNet12 | 68.46 ± 0.56 | 83.84 ± 0.40 |
| Poincaré radial kernel [85] | ResNet18 | 65.33 ± 0.21 | 77.48 ± 0.20 |
| Ours | ResNet12 | 71.56 ± 0.46 | 85.75 ± 0.20 |

Experiments

PCA task on the Grassmann manifolds:

Run the script below to train your model with our method.

python Grassmann_pca/train/train.py

Evaluate the trained model using the following code.

python Grassmann_pca/test/test.py

Clustering task on the SPD manifolds:

Run the script below to train your model with our method.

python SPD_clustering/train/train.py

Evaluate the trained model using the following code.

python SPD_clustering/test/evaluation.py

Long-tailed classification on Stiefel manifolds:

Run the script below to train and test your model with our method.

bash Stiefel_C-LT/vali_20.sh
bash Stiefel_C-LT/vali_50.sh
bash Stiefel_C-LT/vali_100.sh
bash Stiefel_C-LT/vali_200.sh
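To run the four settings back to back, a small driver along the following lines may be convenient. This is only a sketch: it assumes it is launched from the repository root, it uses nothing beyond the script paths listed above, and the helper names `build_commands` and `run_all` are hypothetical. By default it performs a dry run that only prints the commands.

```python
import subprocess

# Numeric suffixes taken from the script names listed above.
SETTINGS = (20, 50, 100, 200)

def build_commands(settings=SETTINGS, folder="Stiefel_C-LT"):
    """Return one `bash <script>` command per setting, in order."""
    return [["bash", f"{folder}/vali_{s}.sh"] for s in settings]

def run_all(commands, dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops the sweep at the first failing script.
            subprocess.run(cmd, check=True)

commands = build_commands()
run_all(commands)  # dry run: prints the commands without executing them
```

Calling `run_all(commands, dry_run=False)` would execute the scripts one after another and stop at the first failure.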

Few-shot classification on hyperbolic manifolds:

Run the script below to train and test your model with our method.

bash Hyperbolic_few-shot/miniimagenet/miniimagenet_shot1.sh
bash Hyperbolic_few-shot/miniimagenet/miniimagenet_shot5.sh
bash Hyperbolic_few-shot/tieredimagenet/tieredimagenet_shot1.sh
bash Hyperbolic_few-shot/tieredimagenet/tieredimagenet_shot5.sh

Citation

If you find our work helpful, please consider citing our paper 📝 and starring the repository ⭐️!

@ARTICLE{11247945,
  author={Fan, Xiaomeng and Wu, Yuwei and Gao, Zhi and Lu, Zhipeng and Li, Feng and Harandi, Mehrtash and Jia, Yunde},
  journal={IEEE Transactions on Neural Networks and Learning Systems}, 
  title={Riemannian Implicit Differentiation via a Fixed-Point Equation for Riemannian Bilevel Optimization}, 
  year={2025},
  volume={},
  number={},
  pages={1-15},
  doi={10.1109/TNNLS.2025.3624316}}
