Skip to content

Commit 0e83dac

Browse files
committed
Update plot_rbf_parameters.py
move the imports closer to their first usage
1 parent 7f59f1b commit 0e83dac

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

examples/svm/plot_rbf_parameters.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,20 +75,12 @@
7575
7676
"""
7777

78-
import numpy as np
79-
import matplotlib.pyplot as plt
80-
from matplotlib.colors import Normalize
81-
82-
from sklearn.svm import SVC
83-
from sklearn.preprocessing import StandardScaler
84-
from sklearn.datasets import load_iris
85-
from sklearn.model_selection import StratifiedShuffleSplit
86-
from sklearn.model_selection import GridSearchCV
87-
8878
# %%
8979
# Utility class to move the midpoint of a colormap to be around
9080
# the values of interest.
9181

82+
import numpy as np
83+
from matplotlib.colors import Normalize
9284

9385
class MidpointNormalize(Normalize):
9486
def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
@@ -106,6 +98,8 @@ def __call__(self, value, clip=None):
10698
#
10799
# dataset for grid search
108100

101+
from sklearn.datasets import load_iris
102+
109103
iris = load_iris()
110104
X = iris.data
111105
y = iris.target
@@ -126,6 +120,8 @@ def __call__(self, value, clip=None):
126120
# instead of fitting the transformation on the training set and
127121
# just applying it on the test set.
128122

123+
from sklearn.preprocessing import StandardScaler
124+
129125
scaler = StandardScaler()
130126
X = scaler.fit_transform(X)
131127
X_2d = scaler.fit_transform(X_2d)
@@ -138,6 +134,10 @@ def __call__(self, value, clip=None):
138134
# 10 is often helpful. Using a basis of 2, a finer
139135
# tuning can be achieved but at a much higher cost.
140136

137+
from sklearn.svm import SVC
138+
from sklearn.model_selection import StratifiedShuffleSplit
139+
from sklearn.model_selection import GridSearchCV
140+
141141
C_range = np.logspace(-2, 10, 13)
142142
gamma_range = np.logspace(-9, 3, 13)
143143
param_grid = dict(gamma=gamma_range, C=C_range)
@@ -169,6 +169,8 @@ def __call__(self, value, clip=None):
169169
#
170170
# draw visualization of parameter effects
171171

172+
import matplotlib.pyplot as plt
173+
172174
plt.figure(figsize=(8, 6))
173175
xx, yy = np.meshgrid(np.linspace(-3, 3, 200), np.linspace(-3, 3, 200))
174176
for (k, (C, gamma, clf)) in enumerate(classifiers):

0 commit comments

Comments
 (0)