Skip to content

Commit 74751fb

Browse files
committed
Revert "Update plot_rbf_parameters.py"
This reverts commit 0e83dac.
1 parent 0e83dac commit 74751fb

File tree

1 file changed

+10
-12
lines changed

1 file changed

+10
-12
lines changed

examples/svm/plot_rbf_parameters.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,20 @@
7575
7676
"""
7777

78+
import numpy as np
79+
import matplotlib.pyplot as plt
80+
from matplotlib.colors import Normalize
81+
82+
from sklearn.svm import SVC
83+
from sklearn.preprocessing import StandardScaler
84+
from sklearn.datasets import load_iris
85+
from sklearn.model_selection import StratifiedShuffleSplit
86+
from sklearn.model_selection import GridSearchCV
87+
7888
# %%
7989
# Utility class to move the midpoint of a colormap to be around
8090
# the values of interest.
8191

82-
import numpy as np
83-
from matplotlib.colors import Normalize
8492

8593
class MidpointNormalize(Normalize):
8694
def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
@@ -98,8 +106,6 @@ def __call__(self, value, clip=None):
98106
#
99107
# dataset for grid search
100108

101-
from sklearn.datasets import load_iris
102-
103109
iris = load_iris()
104110
X = iris.data
105111
y = iris.target
@@ -120,8 +126,6 @@ def __call__(self, value, clip=None):
120126
# instead of fitting the transformation on the training set and
121127
# just applying it on the test set.
122128

123-
from sklearn.preprocessing import StandardScaler
124-
125129
scaler = StandardScaler()
126130
X = scaler.fit_transform(X)
127131
X_2d = scaler.fit_transform(X_2d)
@@ -134,10 +138,6 @@ def __call__(self, value, clip=None):
134138
# 10 is often helpful. Using a basis of 2, a finer
135139
# tuning can be achieved but at a much higher cost.
136140

137-
from sklearn.svm import SVC
138-
from sklearn.model_selection import StratifiedShuffleSplit
139-
from sklearn.model_selection import GridSearchCV
140-
141141
C_range = np.logspace(-2, 10, 13)
142142
gamma_range = np.logspace(-9, 3, 13)
143143
param_grid = dict(gamma=gamma_range, C=C_range)
@@ -169,8 +169,6 @@ def __call__(self, value, clip=None):
169169
#
170170
# draw visualization of parameter effects
171171

172-
import matplotlib.pyplot as plt
173-
174172
plt.figure(figsize=(8, 6))
175173
xx, yy = np.meshgrid(np.linspace(-3, 3, 200), np.linspace(-3, 3, 200))
176174
for (k, (C, gamma, clf)) in enumerate(classifiers):

0 commit comments

Comments
 (0)