Skip to content

Commit 05fa959

Browse files
enya-yx
authored and committed
Revert "Fix passthrough feature reference in sql-based derived feature (feathr-ai#815)"
This reverts commit 6b5cd00.
1 parent 7f68162 commit 05fa959

File tree

2 files changed

+5
-26
lines changed

2 files changed

+5
-26
lines changed

src/main/scala/com/linkedin/feathr/offline/derived/strategies/SqlDerivationSpark.scala

Lines changed: 3 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@ import com.linkedin.feathr.common.exception.{ErrorLabel, FeathrFeatureTransforma
44
import com.linkedin.feathr.offline.client.DataFrameColName
55
import com.linkedin.feathr.offline.derived.DerivedFeature
66
import com.linkedin.feathr.offline.derived.functions.SQLFeatureDerivationFunction
7-
import com.linkedin.feathr.offline.job.FeatureTransformation.FEATURE_NAME_PREFIX
87
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
98
import org.apache.spark.sql.functions.expr
109
import org.apache.spark.sql.{DataFrame, SparkSession}
@@ -22,14 +21,12 @@ class SqlDerivationSpark extends SqlDerivationSparkStrategy {
2221
* @param deriveFeature derived feature definition
2322
* @param keyTag list of tags represented by integer
2423
* @param keyTagId2StringMap Map from the tag integer id to the string tag
25-
* @param asIsFeatureNames features names that does not to be rewritten, i.e. passthrough features, as they do not have key tags
2624
* @return Rewritten SQL expression
2725
*/
2826
private[offline] def rewriteDerivedFeatureExpression(
2927
deriveFeature: DerivedFeature,
3028
keyTag: Seq[Int],
31-
keyTagId2StringMap: Seq[String],
32-
asIsFeatureNames: Set[String]): String = {
29+
keyTagId2StringMap: Seq[String]): String = {
3330
if (!deriveFeature.derivation.isInstanceOf[SQLFeatureDerivationFunction]) {
3431
throw new FeathrFeatureTransformationException(ErrorLabel.FEATHR_ERROR, "Should not rewrite derived feature expression for non-SQLDerivedFeatures")
3532
}
@@ -45,7 +42,7 @@ class SqlDerivationSpark extends SqlDerivationSparkStrategy {
4542
val namePattern = if (parameterNames.isEmpty) consumeFeatureName.getFeatureName else parameterNames(index)
4643
// getBinding.map(keyTag.get) resolves the call tags
4744
val newName =
48-
if (!asIsFeatureNames.contains(FEATURE_NAME_PREFIX + consumeFeatureName.getFeatureName)
45+
if (!consumeFeatureName.getBinding.isEmpty // Passthrough features do not have keyTag
4946
// Feature generation code path does not create columns with tags.
5047
// The check ensures we do not run into IndexOutOfBoundsException when keyTag & keyTagId2StringMap are empty.
5148
&& keyTag.nonEmpty
@@ -101,15 +98,7 @@ class SqlDerivationSpark extends SqlDerivationSparkStrategy {
10198
derivationFunction: SQLFeatureDerivationFunction,
10299
mvelContext: Option[FeathrExpressionExecutionContext]): DataFrame = {
103100
// sql expression based derived feature needs rewrite, e.g, replace the feature names with feature column names in the dataframe
104-
// Passthrough fields do not need rewrite as they do not have tags.
105-
val passthroughFieldNames = df.schema.fields.map(f =>
106-
if (f.name.startsWith(FEATURE_NAME_PREFIX)) {
107-
f.name
108-
} else {
109-
FEATURE_NAME_PREFIX + f.name
110-
}
111-
).toSet
112-
val rewrittenExpr = rewriteDerivedFeatureExpression(derivedFeature, keyTags, keyTagList, passthroughFieldNames)
101+
val rewrittenExpr = rewriteDerivedFeatureExpression(derivedFeature, keyTags, keyTagList)
113102
val tags = Some(keyTags.map(keyTagList).toList)
114103
val featureColumnName = DataFrameColName.genFeatureColumnName(derivedFeature.producedFeatureNames.head, tags)
115104
df.withColumn(featureColumnName, expr(rewrittenExpr))

src/test/scala/com/linkedin/feathr/offline/AnchoredFeaturesIntegTest.scala

Lines changed: 2 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -484,16 +484,7 @@ class AnchoredFeaturesIntegTest extends FeathrIntegTest {
484484
|
485485
|derivations: {
486486
| f_trip_time_distance: {
487-
| definition: "f_trip_distance * f_trip_time_duration"
488-
| type: NUMERIC
489-
| }
490-
| f_trip_time_distance_sql: {
491-
| key: [trip]
492-
| inputs: {
493-
| trip_distance: { key: [trip], feature: f_trip_distance }
494-
| trip_time_duration: { key: [trip], feature: f_trip_time_duration }
495-
| }
496-
| definition.sqlExpr: "trip_distance * trip_time_duration"
487+
| definition: "f_trip_distance * f_trip_time_duration"
497488
| type: NUMERIC
498489
| }
499490
|}
@@ -523,8 +514,7 @@ class AnchoredFeaturesIntegTest extends FeathrIntegTest {
523514
|featureList: [
524515
| {
525516
| key: DOLocationID
526-
| featureList: [f_location_avg_fare, f_trip_time_distance, f_trip_distance,
527-
| f_trip_time_duration, f_is_long_trip_distance, f_day_of_week, f_trip_time_distance_sql]
517+
| featureList: [f_location_avg_fare, f_trip_time_distance, f_trip_distance, f_trip_time_duration, f_is_long_trip_distance, f_day_of_week]
528518
| }
529519
|]
530520
""".stripMargin

0 commit comments

Comments
 (0)