Skip to content

Commit 986d59e

Browse files
authored
[BEAM-9547] Add DataFrame.insert implementation (#14663)
[BEAM-9547] Add DataFrame.insert implementation
2 parents 5e53622 + 69db5e7 commit 986d59e

File tree

2 files changed

+65
-11
lines changed

2 files changed

+65
-11
lines changed

sdks/python/apache_beam/dataframe/frames.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1476,6 +1476,47 @@ def explode(self, column, ignore_index):
14761476
preserves_partition_by=preserves,
14771477
requires_partition_by=partitionings.Arbitrary()))
14781478

1479+
@frame_base.args_to_kwargs(pd.DataFrame)
1480+
@frame_base.populate_defaults(pd.DataFrame)
1481+
def insert(self, value, **kwargs):
1482+
if isinstance(value, list):
1483+
raise frame_base.WontImplementMethod(
1484+
"insert(value=list) is not supported because it joins the input "
1485+
"list to the deferred DataFrame based on the order of the data.",
1486+
reason="order-sensitive")
1487+
1488+
if isinstance(value, pd.core.generic.NDFrame):
1489+
value = frame_base.DeferredFrame.wrap(
1490+
expressions.ConstantExpression(value))
1491+
1492+
if isinstance(value, frame_base.DeferredFrame):
1493+
def func_zip(df, value):
1494+
df = df.copy()
1495+
df.insert(value=value, **kwargs)
1496+
return df
1497+
1498+
inserted = frame_base.DeferredFrame.wrap(
1499+
expressions.ComputedExpression(
1500+
'insert',
1501+
func_zip,
1502+
[self._expr, value._expr],
1503+
requires_partition_by=partitionings.Index(),
1504+
preserves_partition_by=partitionings.Arbitrary()))
1505+
else:
1506+
def func_elementwise(df):
1507+
df = df.copy()
1508+
df.insert(value=value, **kwargs)
1509+
return df
1510+
inserted = frame_base.DeferredFrame.wrap(
1511+
expressions.ComputedExpression(
1512+
'insert',
1513+
func_elementwise,
1514+
[self._expr],
1515+
requires_partition_by=partitionings.Arbitrary(),
1516+
preserves_partition_by=partitionings.Arbitrary()))
1517+
1518+
self._expr = inserted._expr
1519+
14791520
@frame_base.args_to_kwargs(pd.DataFrame)
14801521
@frame_base.populate_defaults(pd.DataFrame)
14811522
def aggregate(self, func, axis=0, *args, **kwargs):

sdks/python/apache_beam/dataframe/frames_test.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,18 @@ def _run_error_test(
9898
f'Expected {expected_error!r} to be raised, but got {actual!r}'
9999
) from actual
100100

101+
def _run_inplace_test(self, func, arg, **kwargs):
102+
"""Verify an inplace operation performed by func.
103+
104+
Checks that func performs the same inplace operation on arg, in pandas and
105+
in Beam."""
106+
def wrapper(df):
107+
df = df.copy()
108+
func(df)
109+
return df
110+
111+
self._run_test(wrapper, arg, **kwargs)
112+
101113
def _run_test(self, func, *args, distributed=True, nonparallel=False):
102114
"""Verify that func(*args) produces the same result in pandas and in Beam.
103115
@@ -185,13 +197,12 @@ def test_get_column(self):
185197
def test_set_column(self):
186198
def new_column(df):
187199
df['NewCol'] = df['Speed']
188-
return df
189200

190201
df = pd.DataFrame({
191202
'Animal': ['Falcon', 'Falcon', 'Parrot', 'Parrot'],
192203
'Speed': [380., 370., 24., 26.]
193204
})
194-
self._run_test(new_column, df)
205+
self._run_inplace_test(new_column, df)
195206

196207
def test_str_split(self):
197208
s = pd.Series([
@@ -212,13 +223,12 @@ def test_str_split(self):
212223
def test_set_column_from_index(self):
213224
def new_column(df):
214225
df['NewCol'] = df.index
215-
return df
216226

217227
df = pd.DataFrame({
218228
'Animal': ['Falcon', 'Falcon', 'Parrot', 'Parrot'],
219229
'Speed': [380., 370., 24., 26.]
220230
})
221-
self._run_test(new_column, df)
231+
self._run_inplace_test(new_column, df)
222232

223233
def test_tz_localize_ambiguous_series(self):
224234
# This replicates a tz_localize doctest:
@@ -706,11 +716,7 @@ def test_dataframe_eval_query(self):
706716
self._run_test(lambda df: df.eval('foo = a + b - c'), df)
707717
self._run_test(lambda df: df.query('a > b + c'), df)
708718

709-
def eval_inplace(df):
710-
df.eval('foo = a + b - c', inplace=True)
711-
return df.foo
712-
713-
self._run_test(eval_inplace, df)
719+
self._run_inplace_test(lambda df: df.eval('foo = a + b - c'), df)
714720

715721
# Verify that attempting to access locals raises a useful error
716722
deferred_df = frame_base.DeferredFrame.wrap(
@@ -726,9 +732,8 @@ def test_index_name_assignment(self):
726732

727733
def change_index_names(df):
728734
df.index.names = ['A', None]
729-
return df
730735

731-
self._run_test(change_index_names, df)
736+
self._run_inplace_test(change_index_names, df)
732737

733738
@parameterized.expand((x, ) for x in [
734739
0,
@@ -1046,6 +1051,14 @@ def test_dataframe_sum_nonnumeric_raises(self):
10461051
# projecting only numeric columns should too
10471052
self._run_test(lambda df: df[['foo', 'bar']].sum(), GROUPBY_DF)
10481053

1054+
def test_insert(self):
1055+
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
1056+
1057+
self._run_inplace_test(lambda df: df.insert(1, 'C', df.A * 2), df)
1058+
self._run_inplace_test(
1059+
lambda df: df.insert(0, 'foo', pd.Series([8], index=[1])), df)
1060+
self._run_inplace_test(lambda df: df.insert(2, 'bar', value='q'), df)
1061+
10491062

10501063
class AllowNonParallelTest(unittest.TestCase):
10511064
def _use_non_parallel_operation(self):

0 commit comments

Comments
 (0)