Skip to content

Commit 1815725

Browse files
committed
DOC more informative make_multilabel_classification example
1 parent d732830 commit 1815725

File tree

2 files changed

+50
-29
lines changed

2 files changed

+50
-29
lines changed

doc/sphinxext/gen_rst.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -870,7 +870,7 @@ def generate_file_rst(fname, target_dir, src_dir, root_dir, plot_gallery):
870870
my_stdout = my_stdout.replace(
871871
my_globals['__doc__'],
872872
'')
873-
my_stdout = my_stdout.strip()
873+
my_stdout = my_stdout.strip().expandtabs()
874874
if my_stdout:
875875
stdout = '**Script output**::\n\n %s\n\n' % (
876876
'\n '.join(my_stdout.split('\n')))

examples/datasets/plot_random_multilabel_dataset.py

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,66 +10,87 @@
1010
Points are labeled as follows, where Y means the class is present:
1111
1212
===== ===== ===== ======
13-
Classes Color
14-
------------------- ------
13+
1 2 3 Color
14+
===== ===== ===== ======
1515
Y N N Red
1616
N Y N Blue
1717
N N Y Yellow
1818
Y Y N Purple
1919
Y N Y Orange
2020
N Y Y Green
21-
Y Y Y Black
21+
Y Y Y Brown
2222
===== ===== ===== ======
2323
24-
Below the scatter of these data points we show the underlying class
25-
distribution.
24+
A star marks the expected sample for each class; its size reflects the
25+
probability of selecting that class label.
2626
2727
The left and right examples highlight the ``n_labels`` parameter:
28-
more of the samples in the right plot have 2 or 3 classes.
28+
more of the samples in the right plot have 2 or 3 labels.
2929
30+
Note that this two-dimensional example is very degenerate:
31+
generally the number of features would be much greater than the
32+
"document length", while here we have much larger documents than vocabulary.
33+
Similarly, with ``n_classes > n_features``, it is much less likely that a
34+
feature distinguishes a particular class.
3035
"""
3136

37+
from __future__ import print_function
3238
import numpy as np
3339
import matplotlib.pyplot as plt
3440

3541
from sklearn.datasets import make_multilabel_classification as make_ml_clf
3642

37-
COLORS = np.array(['!', 'r', 'b', 'purple', 'y', 'orange', 'g', 'k'])
43+
print(__doc__)
44+
45+
COLORS = np.array(['!',
46+
'#FF3333', # red
47+
'#0198E1', # blue
48+
'#BF5FFF', # purple
49+
'#FCD116', # yellow
50+
'#FF7216', # orange
51+
'#4DBD33', # green
52+
'#87421F' # brown
53+
])
54+
55+
# Use same random seed for multiple calls to make_multilabel_classification to
56+
# ensure same distributions
57+
RANDOM_SEED = np.random.randint(2 ** 10)
3858

3959

4060
def plot_2d(ax, n_labels=1, n_classes=3, length=50):
41-
X, Y, p_c, p_w_c = make_ml_clf(n_samples=100, n_features=2,
61+
X, Y, p_c, p_w_c = make_ml_clf(n_samples=150, n_features=2,
4262
n_classes=n_classes, n_labels=n_labels,
4363
length=length, allow_unlabeled=False,
4464
return_indicator=True,
45-
return_distributions=True)
46-
47-
ax.scatter(X[:, 0], X[:, 1], color=COLORS.take((Y * [4, 2, 1]
48-
).sum(axis=1)))
65+
return_distributions=True,
66+
random_state=RANDOM_SEED)
67+
68+
ax.scatter(X[:, 0], X[:, 1], color=COLORS.take((Y * [1, 2, 4]
69+
).sum(axis=1)),
70+
marker='.')
71+
ax.scatter(p_w_c[0] * length, p_w_c[1] * length,
72+
marker='*', linewidth=.5, edgecolor='black',
73+
s=20 + 1500 * p_c ** 2,
74+
color=COLORS.take([1, 2, 4]))
4975
ax.set_xlabel('Feature 0 count')
5076
return p_c, p_w_c
5177

5278

53-
_, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2,
54-
sharex='row',
55-
sharey='row')
56-
plt.subplots_adjust(hspace=.3)
79+
_, (ax1, ax2) = plt.subplots(1, 2, sharex='row', sharey='row', figsize=(8, 4))
80+
plt.subplots_adjust(bottom=.15)
5781

82+
p_c, p_w_c = plot_2d(ax1, n_labels=1)
5883
ax1.set_title('n_labels=1, length=50')
5984
ax1.set_ylabel('Feature 1 count')
60-
p_c, p_w_c = plot_2d(ax1, n_labels=1)
85+
86+
plot_2d(ax2, n_labels=3)
87+
ax2.set_title('n_labels=3, length=50')
6188
ax2.set_xlim(left=0, auto=True)
6289
ax2.set_ylim(bottom=0, auto=True)
6390

64-
ax3.bar([1, 2, 3], p_c, color=COLORS.take([1, 2, 4]))
65-
ax3.set_ylabel('Class prior probability')
66-
67-
ax2.set_title('n_labels=2, length=50')
68-
p_c, p_w_c = plot_2d(ax2, n_labels=2)
69-
70-
ax4.bar([1, 2, 3], p_c, color=COLORS.take([1, 2, 4]))
71-
ax4.set_ylim(0, 1)
72-
ax4.set_xticks([])
91+
plt.show()
7392

74-
###plt.show()
75-
plt.savefig('/tmp/foo.pdf')
93+
print('The data was generated from (random_state=%d):' % RANDOM_SEED)
94+
print('Class', 'P(C)', 'P(w0|C)', 'P(w1|C)', sep='\t')
95+
for k, p, p_w in zip(['red', 'blue', 'yellow'], p_c, p_w_c.T):
96+
print('%s\t%0.2f\t%0.2f\t%0.2f' % (k, p, p_w[0], p_w[1]))

0 commit comments

Comments
 (0)