Skip to content
13 changes: 12 additions & 1 deletion lib/streamlit/DeltaGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from streamlit.proto.NumberInput_pb2 import NumberInput
from streamlit.proto.TextInput_pb2 import TextInput
from streamlit.logger import get_logger
from streamlit.type_util import is_type

LOGGER = get_logger(__name__)

Expand Down Expand Up @@ -1720,7 +1721,17 @@ def _check_and_convert_to_indices(options, default_values):
return None

if not isinstance(default_values, list):
default_values = [default_values]
# This if is done before others because calling if not x (done
# right below) when x is of type pd.Series() or np.array() throws a
# ValueError exception.
if is_type(default_values, "numpy.ndarray") or is_type(
default_values, "pandas.core.series.Series"
):
default_values = list(default_values)
Comment on lines +1727 to +1730
Copy link
Copy Markdown
Collaborator Author

@arnaudmiribel arnaudmiribel May 4, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Support for pd and np:
This if is done before others because calling if not x (done right below) when x is of type pd.Series() or np.array() throws a ValueError exception.

Also added protos of that in the tests.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, so I just added what you wrote here to the code for future reference.

elif not default_values:
default_values = [default_values]
Comment on lines +1731 to +1732
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Useful for empty objects e.g. x in ("", None, {})

else:
default_values = list(default_values)

for value in default_values:
if value not in options:
Expand Down
17 changes: 17 additions & 0 deletions lib/tests/streamlit/multiselect_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,23 @@ def test_defaults(self, defaults, expected):
self.assertListEqual(c.default[:], expected)
self.assertEqual(c.options, ["Coffee", "Tea", "Water"])

@parameterized.expand(
[
(("Tea", "Water"), [1, 2]),
((i for i in ("Tea", "Water")), [1, 2]),
(np.array(["Coffee", "Tea"]), [0, 1]),
(pd.Series(np.array(["Coffee", "Tea"])), [0, 1]),
]
)
def test_default_types(self, defaults, expected):
"""Test that iterables other than lists can be passed as defaults."""
st.multiselect("the label", ["Coffee", "Tea", "Water"], defaults)

c = self.get_delta_from_queue().new_element.multiselect
self.assertEqual(c.label, "the label")
self.assertListEqual(c.default[:], expected)
self.assertEqual(c.options, ["Coffee", "Tea", "Water"])

@parameterized.expand(
[
(["Tea", "Vodka", None], StreamlitAPIException),
Expand Down