Skip to content

Commit 89cd059

Browse files
shub-krisShubham Krishna
andauthored
Add working example for Per Entity Training (#25081)
* Add working example for Per Entity Training * Fix linting and formatting * Add documentation, fix docstrings, linting * add dynamic destination writing * Update saving models, documentation * remove extra lines and fix import order * Update documentation * Update documentation Co-authored-by: Shubham Krishna <“[email protected]”>
1 parent 9aa2c52 commit 89cd059

4 files changed

Lines changed: 219 additions & 1 deletion

File tree

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
"""A pipeline to demonstrate per-entity training.
19+
20+
This pipeline reads data from a CSV file, that contains information
21+
about 15 different attributes like salary >=50k, education level,
22+
native country, age, occupation and others. The pipeline does some filtering
23+
by selecting certain education level, discarding missing values and empty rows.
24+
The pipeline then groups the rows based on education level and
25+
trains Decision Trees for each group and finally saves them.
26+
"""
27+
28+
import argparse
29+
import logging
30+
import pickle
31+
32+
import pandas as pd
33+
from sklearn.compose import ColumnTransformer
34+
from sklearn.pipeline import Pipeline
35+
from sklearn.preprocessing import LabelEncoder
36+
from sklearn.preprocessing import MinMaxScaler
37+
from sklearn.preprocessing import OneHotEncoder
38+
from sklearn.tree import DecisionTreeClassifier
39+
40+
import apache_beam as beam
41+
from apache_beam.io import fileio
42+
from apache_beam.options.pipeline_options import PipelineOptions
43+
from apache_beam.options.pipeline_options import SetupOptions
44+
45+
46+
class CreateKey(beam.DoFn):
  """Keys each CSV row by its education value for later grouping.

  The education field is removed from the row and emitted as the key,
  so downstream stages receive (education, remaining_fields) pairs.
  """
  def process(self, element, *args, **kwargs):
    # Column index 3 of the dataset holds the education level.
    education = element.pop(3)
    yield (education, element)
52+
53+
54+
def custom_filter(element):
  """Keeps a row only if it is usable for training.

  A row passes when it has all 15 features, contains no missing values
  (marked with '?'), and the person holds a Bachelors, Masters or a
  Doctorate degree. Values in the raw CSV carry a leading space, hence
  the ' Bachelors'-style literals.

  Args:
    element: A list of string fields from one CSV row.

  Returns:
    True if the row should be kept, False otherwise.
  """
  # Parentheses are required: `and` binds tighter than `or`, so without
  # them any row containing ' Masters' or ' Doctorate' would pass even
  # when it is malformed (wrong field count or missing values).
  return (len(element) == 15 and '?' not in element and
          (' Bachelors' in element or ' Masters' in element or
           ' Doctorate' in element))
61+
62+
63+
class PrepareDataforTraining(beam.DoFn):
  """Turns a grouped (key, rows) pair into training-ready arrays.

  The rows are loaded into a DataFrame; the last column becomes the
  label (binary-encoded) and the remaining columns are split into
  categorical and numerical feature groups for later transformation.
  """
  def process(self, element, *args, **kwargs):
    education, rows = element
    frame = pd.DataFrame(rows)
    target_col = len(frame.columns) - 1
    features = frame.drop(target_col, axis=1)
    # Label encode the target variable to have the classes 0 and 1.
    labels = LabelEncoder().fit_transform(frame[target_col])
    # Partition feature columns by dtype for the ColumnTransformer.
    categorical_cols = features.select_dtypes(
        include=['object', 'bool']).columns
    numerical_cols = features.select_dtypes(
        include=['int64', 'float64']).columns
    yield (features, labels, categorical_cols, numerical_cols, education)
77+
78+
79+
class TrainModel(beam.DoFn):
  """Fits one decision tree per group on the preprocessed data.

  Categorical columns are one-hot encoded and numerical columns are
  min-max scaled inside a scikit-learn Pipeline; the fitted pipeline is
  emitted keyed by the group's education level.
  """
  def process(self, element, *args, **kwargs):
    features, labels, categorical_cols, numerical_cols, education = element
    # One-hot encode categorical columns, normalize numerical ones.
    transformers = [
        ('c', OneHotEncoder(handle_unknown='ignore'), categorical_cols),
        ('n', MinMaxScaler(), numerical_cols),
    ]
    preprocessor = ColumnTransformer(transformers)
    # Wrap preprocessing and classifier in a single pipeline.
    model = Pipeline(
        steps=[('t', preprocessor), ('m', DecisionTreeClassifier())])
    model.fit(features, labels)
    yield (education, model)
95+
96+
97+
class ModelSink(fileio.FileSink):
  """Writes trained (key, model) records by pickling the model.

  Used with fileio.WriteToFiles; the key is dropped here because the
  destination file already identifies the group.
  """
  def open(self, fh):
    # Keep the file handle for subsequent write/flush calls.
    self._fh = fh

  def write(self, record):
    _, trained_model = record
    self._fh.write(pickle.dumps(trained_model))

  def flush(self):
    self._fh.flush()
108+
109+
110+
def parse_known_args(argv):
  """Parses args for the workflow.

  Args:
    argv: Command line arguments, or None to read sys.argv.

  Returns:
    A (known_args, remaining_args) tuple from argparse.
  """
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument(
      '--input', dest='input',
      help='Path to the text file containing sentences.')
  arg_parser.add_argument(
      '--output-dir', dest='output', required=True,
      help='Path of directory for saving trained models.')
  return arg_parser.parse_known_args(argv)
123+
124+
125+
def run(argv=None, save_main_session=True):
  """Builds and runs the per-entity training pipeline.

  Args:
    argv: Command line arguments defined for this example.
    save_main_session: Used for internal testing.
  """
  known_args, pipeline_args = parse_known_args(argv)
  pipeline_options = PipelineOptions(pipeline_args)
  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

  with beam.Pipeline(options=pipeline_options) as pipeline:
    # Read raw CSV lines and keep only rows usable for training.
    rows = (
        pipeline
        | "Read Data" >> beam.io.ReadFromText(known_args.input)
        | "Split data to make List" >> beam.Map(lambda x: x.split(','))
        | "Filter rows" >> beam.Filter(custom_filter))
    # Group rows by education level and train one model per group.
    models = (
        rows
        | "Create Key" >> beam.ParDo(CreateKey())
        | "Group by education" >> beam.GroupByKey()
        | "Prepare Data" >> beam.ParDo(PrepareDataforTraining())
        | "Train Model" >> beam.ParDo(TrainModel()))
    # Persist each trained model to the output directory.
    _ = (
        models
        | "Save" >> fileio.WriteToFiles(path=known_args.output,
                                        sink=ModelSink()))


if __name__ == "__main__":
  logging.getLogger().setLevel(logging.INFO)
  run()

website/www/site/content/en/documentation/ml/overview.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,4 +90,5 @@ You can find examples of end-to-end AI/ML pipelines for several use cases:
9090
* [Multi model pipelines in Beam](/documentation/ml/multi-model-pipelines): Explains how multi-model pipelines work and gives an overview of what you need to know to build one using the RunInference API.
9191
* [Online Clustering in Beam](/documentation/ml/online-clustering): Demonstrates how to set up a real-time clustering pipeline that can read text from Pub/Sub, convert the text into an embedding using a transformer-based language model with the RunInference API, and cluster the text using BIRCH with stateful processing.
9292
* [Anomaly Detection in Beam](/documentation/ml/anomaly-detection): Demonstrates how to set up an anomaly detection pipeline that reads text from Pub/Sub in real time and then detects anomalies using a trained HDBSCAN clustering model with the RunInference API.
93-
* [Large Language Model Inference in Beam](/documentation/ml/large-language-modeling): Demonstrates a pipeline that uses RunInference to perform translation with the T5 language model which contains 11 billion parameters.
93+
* [Large Language Model Inference in Beam](/documentation/ml/large-language-modeling): Demonstrates a pipeline that uses RunInference to perform translation with the T5 language model which contains 11 billion parameters.
94+
* [Per Entity Training in Beam](/documentation/ml/per-entity-training): Demonstrates a pipeline that trains a Decision Tree Classifier per education level for predicting if the salary of a person is >= 50k.
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
---
2+
title: "Per Entity Training"
3+
---
4+
<!--
5+
Licensed under the Apache License, Version 2.0 (the "License");
6+
you may not use this file except in compliance with the License.
7+
You may obtain a copy of the License at
8+
9+
http://www.apache.org/licenses/LICENSE-2.0
10+
11+
Unless required by applicable law or agreed to in writing, software
12+
distributed under the License is distributed on an "AS IS" BASIS,
13+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
See the License for the specific language governing permissions and
15+
limitations under the License.
16+
-->
17+
18+
# Per Entity Training
19+
The aim of this pipeline example is to demonstrate per entity training in Beam. Per entity training refers to the process of training a machine learning model for each individual entity, rather than training a single model for all entities. In this approach, a separate model is trained for each entity based on the data specific to that entity. Per entity training can be beneficial in the following scenarios:
20+
21+
* Having separate models allows for more personalized and tailored predictions for each group. Each group may have different characteristics, patterns, and behaviors that a single large model may not be able to capture effectively.
22+
23+
* Having separate models can also help to reduce the complexity of the overall model and make it more efficient. The overall model would only need to focus on the specific characteristics and patterns of the individual group, rather than trying to account for all possible characteristics and patterns across all groups.
24+
25+
* Having separate models can address issues of bias and fairness. Because a single model trained on a diverse dataset might not generalize well to certain groups, separate models for each group can reduce the impact of bias.
26+
27+
* This approach is often favored in production settings, because it makes it easier to detect issues specific to a limited segment of the overall population.
28+
29+
* When working with smaller models and datasets, the process of training and retraining can be completed more rapidly and efficiently. Both the training and retraining can be done in parallel, reducing the amount of time spent waiting for results. Furthermore, smaller models and datasets also have the advantage of being less resource-intensive, which allows them to be run on less expensive hardware.
30+
31+
## Dataset
32+
This example uses [Adult Census Income dataset](https://archive.ics.uci.edu/ml/datasets/adult). The dataset contains information about individuals, including their demographic characteristics, employment status, and income level. The dataset includes both categorical and numerical features, such as age, education, occupation, and hours worked per week, as well as a binary label indicating whether an individual's income is above or below 50,000 USD. The primary goal of this dataset is to be used for classification tasks, where the model will predict whether an individual's income is above or below a certain threshold based on the provided features. The pipeline expects the `adult.data` CSV file as an input. This file can be downloaded from [here](https://archive.ics.uci.edu/ml/machine-learning-databases/adult/).
33+
34+
### Run the Pipeline
35+
First, install the required packages `apache-beam==2.44.0`, `scikit-learn==1.0.2` and `pandas==1.3.5`.
36+
You can view the code on [GitHub](https://github.com/apache/beam/tree/master/sdks/python/apache_beam/examples/per_entity_training.py).
37+
Use `python per_entity_training.py --input path/to/adult.data --output-dir path/to/output/dir` (the `--output-dir` flag is required).
38+
39+
40+
### Train the pipeline
41+
The pipeline can be broken down into the following main steps:
42+
1. Read the data from the provided input path.
43+
2. Filter the data based on some criteria.
44+
3. Create key based on education level.
45+
4. Group dataset based on the key generated.
46+
5. Preprocess the dataset.
47+
6. Train model per education level.
48+
7. Save the trained models.
49+
50+
The following code snippet contains the detailed steps:
51+
52+
{{< highlight >}}
53+
with beam.Pipeline(options=pipeline_options) as pipeline:
54+
_ = (
55+
pipeline | "Read Data" >> beam.io.ReadFromText(known_args.input)
56+
| "Split data to make List" >> beam.Map(lambda x: x.split(','))
57+
| "Filter rows" >> beam.Filter(custom_filter)
58+
| "Create Key" >> beam.ParDo(CreateKey())
59+
| "Group by education" >> beam.GroupByKey()
60+
| "Prepare Data" >> beam.ParDo(PrepareDataforTraining())
61+
| "Train Model" >> beam.ParDo(TrainModel())
62+
|
63+
"Save" >> fileio.WriteToFiles(path=known_args.output, sink=ModelSink()))
64+
{{< /highlight >}}

website/www/site/layouts/partials/section-menu/en/documentation.html

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@
224224
<li><a href="/documentation/ml/runinference-metrics/">RunInference Metrics</a></li>
225225
<li><a href="/documentation/ml/anomaly-detection/">Anomaly Detection</a></li>
226226
<li><a href="/documentation/ml/large-language-modeling">Large Language Model Inference in Beam</a></li>
227+
<li><a href="/documentation/ml/per-entity-training">Per Entity Training in Beam</a></li>
227228
</ul>
228229
</li>
229230
<li class="section-nav-item--collapsible">

0 commit comments

Comments
 (0)