Description
TL;DR so far: sklearn uses an SVD-based solver, which is more numerically stable, but perhaps it should default to something quicker, such as the approach suggested below, unless SVD is needed. A later post contains a detailed suggestion.
The standard linear regression implementation seems unnecessarily slow, at least in many cases. This may extend to other linear estimators, e.g. Ridge.
Steps/Code to Reproduce
from sklearn.linear_model import LinearRegression
import numpy as np
import time

def mylin(X, Y):
    # Normal equations: solve (X^T X / n) B = (X^T Y / n) for B.
    CXX = np.dot(X.T, X) / X.shape[0]
    CXY = np.dot(X.T, Y) / X.shape[0]
    return np.linalg.solve(CXX, CXY).T

X = np.random.normal(size=[500000, 40])
Y = np.random.normal(size=[500000, 10])
# Pass fit_intercept by keyword; recent releases reject the positional form.
lin = LinearRegression(fit_intercept=False)
t = time.time()
lin.fit(X, Y)
print('Time to fit sklearn linear:', time.time() - t)
t = time.time()
coef = mylin(X, Y)
print('Time to fit analytic linear regression solution:', time.time() - t)
print('Correlation between solutions:', np.corrcoef(coef.flatten(), lin.coef_.flatten())[0, 1])
Results
On my laptop, I see the following results:
Time to fit sklearn linear: 0.5804157257080078
Time to fit analytic linear regression solution: 0.023940324783325195
Correlation between solutions: 1.0
I see a similar speedup on my work machine, which also uses a conda install but runs Ubuntu 18.04.
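To make the speed versus stability trade-off concrete, here is a minimal sketch that compares the normal-equations solve with an SVD-based least-squares solve on a small, well-conditioned problem. The function names `normal_equations_fit` and `lstsq_fit` are hypothetical, and `np.linalg.lstsq` is used only as a stand-in for an SVD-based solver; this is not a reproduction of sklearn's internals. It also checks that the condition number of `X.T @ X` is the square of that of `X`, which is why the quicker approach can lose accuracy on ill-conditioned data.

```python
import numpy as np


def normal_equations_fit(X, Y):
    """Solve (X^T X) B = X^T Y; returns coefficients of shape (n_targets, n_features)."""
    gram = X.T @ X    # (n_features, n_features)
    cross = X.T @ Y   # (n_features, n_targets)
    return np.linalg.solve(gram, cross).T


def lstsq_fit(X, Y):
    """SVD-based least squares via np.linalg.lstsq; same output shape as above."""
    coef, *_ = np.linalg.lstsq(X, Y, rcond=None)
    return coef.T


rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 5))
Y = rng.normal(size=(2000, 3))

coef_normal = normal_equations_fit(X, Y)
coef_lstsq = lstsq_fit(X, Y)

# On well-conditioned data the two solutions agree closely.
max_diff = np.max(np.abs(coef_normal - coef_lstsq))

# Forming X^T X squares the 2-norm condition number of X.
cond_X = np.linalg.cond(X)
cond_gram = np.linalg.cond(X.T @ X)

print('Max abs difference between solutions:', max_diff)
print('cond(X):', cond_X, 'cond(X^T X):', cond_gram)
```

With random Gaussian features as in the benchmark above, the conditioning is benign and both routes give the same coefficients; nearly collinear columns are the case where the squared condition number matters.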
Versions
Could not locate executable g77
Could not locate executable f77
Could not locate executable ifort
Could not locate executable ifl
Could not locate executable f90
Could not locate executable efl
Could not locate executable gfortran
Could not locate executable f95
Could not locate executable g95
Could not locate executable efort
Could not locate executable efc
Could not locate executable flang
don't know how to compile Fortran code on platform 'nt'
System:
python: 3.7.1 (default, Dec 10 2018, 22:54:23) [MSC v.1915 64 bit (AMD64)]
executable: C:\Users---\Anaconda3\pythonw.exe
machine: Windows-10-10.0.17763-SP0
BLAS:
macros:
lib_dirs:
cblas_libs: cblas
Python deps:
pip: 18.1
setuptools: 40.6.3
sklearn: 0.20.1
numpy: 1.16.3
scipy: 1.1.0
Cython: 0.29.2
pandas: 0.23.4
C:\Users\ShakesBeer\Anaconda3\lib\site-packages\numpy\distutils\system_info.py:638: UserWarning:
Atlas (http://math-atlas.sourceforge.net/) libraries not found.
Directories to search for the libraries can be specified in the
numpy/distutils/site.cfg file (section [atlas]) or by setting
the ATLAS environment variable.
self.calc_info()
C:\Users\ShakesBeer\Anaconda3\lib\site-packages\numpy\distutils\system_info.py:638: UserWarning:
Blas (http://www.netlib.org/blas/) libraries not found.
Directories to search for the libraries can be specified in the
numpy/distutils/site.cfg file (section [blas]) or by setting
the BLAS environment variable.
self.calc_info()
C:\Users\ShakesBeer\Anaconda3\lib\site-packages\numpy\distutils\system_info.py:638: UserWarning:
Blas (http://www.netlib.org/blas/) sources not found.
Directories to search for the sources can be specified in the
numpy/distutils/site.cfg file (section [blas_src]) or by setting
the BLAS_SRC environment variable.
self.calc_info()
It is worth mentioning that I am using MKL, and np.show_config() confirms this.