Skip to content

Commit 9590265

Browse files
committed
Factor out getting correct column class
1 parent 9ed8815 commit 9590265

File tree

1 file changed

+32
-5
lines changed

1 file changed

+32
-5
lines changed

astropy/table/table.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ class used to create new non-mixin columns, and this is a function of
5858
generic way to copy a mixin object but not the data.
5959
6060
- Be aware of column objects that have indices set.
61+
62+
- `cls.ColumnClass` is a property that effectively uses the `masked` attribute
63+
to choose either `cls.Column` or `cls.MaskedColumn`.
6164
"""
6265

6366
__doctest_skip__ = ['Table.read', 'Table.write', 'Table._read',
@@ -905,10 +908,7 @@ def _convert_data_to_col(self, data, copy=True, default_name=None, dtype=None, n
905908
# gets upgraded to MaskedColumn, but the converse (pre-4.0) behavior
906909
# of downgrading from MaskedColumn to Column (for non-masked table)
907910
# does not happen.
908-
if issubclass(self.ColumnClass, data.__class__):
909-
col_cls = self.ColumnClass
910-
else:
911-
col_cls = data.__class__
911+
col_cls = self._get_col_cls_for_table(data)
912912

913913
elif self._is_mixin_for_table(data):
914914
# Copy the mixin column attributes if they exist since the copy below
@@ -956,6 +956,31 @@ def _init_from_dict(self, data, names, dtype, n_cols, copy):
956956
data_list = [data[name] for name in names]
957957
self._init_from_list(data_list, names, dtype, n_cols, copy)
958958

959+
def _get_col_cls_for_table(self, col):
960+
"""Get the correct column class to use for upgrading any Column-like object.
961+
962+
For a masked table, ensure any Column-like object is a subclass
963+
of the table MaskedColumn.
964+
965+
For unmasked table, ensure any MaskedColumn-like object is a subclass
966+
of the table MaskedColumn. If not a MaskedColumn, then ensure that any
967+
Column-like object is a subclass of the table Column.
968+
"""
969+
970+
col_cls = col.__class__
971+
972+
if self.masked:
973+
if isinstance(col, Column) and not isinstance(col, self.MaskedColumn):
974+
col_cls = self.MaskedColumn
975+
else:
976+
if isinstance(col, MaskedColumn):
977+
if not isinstance(col, self.MaskedColumn):
978+
col_cls = self.MaskedColumn
979+
elif isinstance(col, Column) and not isinstance(col, self.Column):
980+
col_cls = self.Column
981+
982+
return col_cls
983+
959984
def _convert_col_for_table(self, col):
960985
"""
961986
Make sure that all Column objects have correct base class for this type of
@@ -964,7 +989,9 @@ def _convert_col_for_table(self, col):
964989
override this method.
965990
"""
966991
if isinstance(col, Column) and not isinstance(col, self.ColumnClass):
967-
col = self.ColumnClass(col) # copy attributes and reference data
992+
col_cls = self._get_col_cls_for_table(col)
993+
col = col_cls(col, copy=False)
994+
968995
return col
969996

970997
def _init_from_cols(self, cols):

0 commit comments

Comments
 (0)