
Conversation

@jeremiedbb (Member) commented Jan 11, 2023

The online updates of the sufficient statistics should be normalized by batch_size, since each update is an average over the batch. When fitting with a constant batch size this doesn't matter, but when using partial_fit on batches of different sizes it has an impact.
It's described in Mairal et al.'s original paper, Online Learning for Matrix Factorization and Sparse Coding, section 3.4.3 "Mini-batch extension" (page 9 of the PDF): https://www.jmlr.org/papers/volume11/mairal10a/mairal10a.pdf
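
For illustration, here is a minimal sketch of the normalized update (not the actual scikit-learn code; the names A, B, X_batch, code and beta are illustrative). The point is that each batch's contribution to the sufficient statistics is divided by its size, so it enters as a per-batch average.

import numpy as np

def update_inner_stats(A, B, X_batch, code, beta):
    # A: (n_components, n_components) and B: (n_features, n_components) accumulated statistics
    # X_batch: (batch_size, n_features) current batch; code: (batch_size, n_components) its sparse codes
    # beta: forgetting factor applied to the old statistics
    batch_size = X_batch.shape[0]
    # dividing by batch_size turns the sum over the batch into an average,
    # so batches of different sizes contribute consistently
    A = beta * A + code.T @ code / batch_size
    B = beta * B + X_batch.T @ code / batch_size
    return A, B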

I computed the final objective function for 100 randomly generated datasets, fitted using partial_fit on batches of various sizes; the objective function is consistently lower with this fix (code below).

from sklearn.decomposition import MiniBatchDictionaryLearning
import numpy as np
 
# batch boundaries: logarithmically spaced, so partial_fit is called on batches of varying sizes
a = np.hstack(([0], np.logspace(1, 3, num=10).astype(int)))
slices = [slice(a[k], a[k+1]) for k in range(len(a) - 1)]

objs = []
for seed in range(100):
    X = np.random.RandomState(seed).random_sample((1000, 100))
    dl = MiniBatchDictionaryLearning(n_components=15, max_iter=10, random_state=0)
    for sl in slices:
        dl.partial_fit(X[sl])
    # final objective on the full dataset: reconstruction error + l1 penalty on the codes
    code = dl.transform(X)
    obj = 0.5 * np.sum((X - code @ dl.components_) ** 2) + dl.alpha * np.sum(np.abs(code))
    objs.append(obj)

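# objs_main and objs_this_pr: the objs arrays collected by running the script above on main and on this branch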
(objs_main - objs_this_pr).min()
# 156.0638606704997
(objs_main - objs_this_pr).max()
#  930.3064058826549
# This corresponds to an improvement of the objective function between 1% and 7%.

It's not really possible to add a test for this, but I think the results above are convincing enough.

Comment on lines +2081 to +2082
>>> np.mean(X_transformed == 0) < 0.5
True
@jeremiedbb (Member, Author) commented:

I changed that because there were doctest failures in some jobs. The computed value was slightly under 0.39 in some jobs and slightly over 0.39 in others.

Since the purpose is to show that the matrix has some sparsity, I think this change is acceptable.

@glemaitre glemaitre self-requested a review January 13, 2023 18:59
@glemaitre (Member) left a comment


Yes, it makes sense looking at the algorithm. LGTM.

@ogrisel ogrisel added this to the 1.2.1 milestone Jan 19, 2023
@ogrisel ogrisel enabled auto-merge (squash) January 19, 2023 11:04
@ogrisel ogrisel merged commit cfd428a into scikit-learn:main Jan 19, 2023
jjerphan pushed a commit to jjerphan/scikit-learn that referenced this pull request Jan 20, 2023
jjerphan pushed a commit to jjerphan/scikit-learn that referenced this pull request Jan 23, 2023
adrinjalali pushed a commit that referenced this pull request Jan 24, 2023