Skip to content

Commit 0879374

Browse files
committed
Add option to execute only the last transform in TransformWrapper.
1 parent f6faab1 commit 0879374

2 files changed

Lines changed: 32 additions & 13 deletions

File tree

src/Microsoft.ML.Data/DataLoadSave/LegacyCompositeDataLoader.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ public static IDataView LoadSelectedTransforms(ModelLoadContext ctx, IDataView s
403403
internal TransformerChain<ITransformer> GetTransformer()
404404
{
405405
var result = new TransformerChain<ITransformer>();
406+
IDataTransform lastTransformer = null;
406407
foreach (var transform in _transforms)
407408
{
408409
if (transform.Transform is RowToRowMapperTransform mapper)
@@ -412,9 +413,11 @@ internal TransformerChain<ITransformer> GetTransformer()
412413
}
413414
else
414415
{
415-
ITransformer transformer = new TransformWrapper(_host, transform.Transform);
416+
ITransformer transformer = new TransformWrapper(_host, transform.Transform, false, lastTransformer is RowToRowMapperTransform);
416417
result = result.Append(transformer);
417418
}
419+
420+
lastTransformer = transform.Transform;
418421
}
419422
return result;
420423
}

src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,23 +29,27 @@ internal sealed class TransformWrapper : ITransformer
2929
private readonly IDataView _xf;
3030
private readonly bool _allowSave;
3131
private readonly bool _isRowToRowMapper;
32+
private readonly bool _useLastTransformOnly;
3233

33-
public TransformWrapper(IHostEnvironment env, IDataView xf, bool allowSave = false)
34+
public TransformWrapper(IHostEnvironment env, IDataView xf, bool allowSave = false, bool useLastTransformOnly = false)
3435
{
3536
Contracts.CheckValue(env, nameof(env));
3637
_host = env.Register(nameof(TransformWrapper));
3738
_host.CheckValue(xf, nameof(xf));
3839
_xf = xf;
3940
_allowSave = allowSave;
4041
_isRowToRowMapper = IsChainRowToRowMapper(_xf);
42+
_useLastTransformOnly = useLastTransformOnly;
4143
}
4244

4345
public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
4446
{
4547
_host.CheckValue(inputSchema, nameof(inputSchema));
4648

4749
var dv = new EmptyDataView(_host, inputSchema);
48-
var output = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, dv);
50+
var output = _useLastTransformOnly ? ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, dv) :
51+
ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, dv);
52+
4953
return output.Schema;
5054
}
5155

@@ -115,7 +119,8 @@ private TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
115119
_isRowToRowMapper = IsChainRowToRowMapper(_xf);
116120
}
117121

118-
public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input);
122+
public IDataView Transform(IDataView input) => _useLastTransformOnly ? ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, input) :
123+
ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input);
119124

120125
private static bool IsChainRowToRowMapper(IDataView view)
121126
{
@@ -133,18 +138,29 @@ IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
133138
{
134139
_host.CheckValue(inputSchema, nameof(inputSchema));
135140
var input = new EmptyDataView(_host, inputSchema);
136-
var revMaps = new List<IRowToRowMapper>();
137141
IDataView chain;
138-
for (chain = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input); chain is IDataTransform xf; chain = xf.Source)
142+
if (_useLastTransformOnly)
143+
{
144+
chain = ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, input);
145+
return new CompositeRowToRowMapper(inputSchema, new[] { (IRowToRowMapper)chain });
146+
}
147+
else
139148
{
140-
// Everything in the chain ought to be a row mapper.
141-
_host.Assert(xf is IRowToRowMapper);
142-
revMaps.Add((IRowToRowMapper)xf);
149+
var revMaps = new List<IRowToRowMapper>();
150+
for (chain = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input);
151+
chain is IDataTransform xf;
152+
chain = xf.Source)
153+
{
154+
// Everything in the chain ought to be a row mapper.
155+
_host.Assert(xf is IRowToRowMapper);
156+
revMaps.Add((IRowToRowMapper)xf);
157+
}
158+
159+
// The walkback should have ended at the input.
160+
Contracts.Assert(chain == input);
161+
revMaps.Reverse();
162+
return new CompositeRowToRowMapper(inputSchema, revMaps.ToArray());
143163
}
144-
// The walkback should have ended at the input.
145-
Contracts.Assert(chain == input);
146-
revMaps.Reverse();
147-
return new CompositeRowToRowMapper(inputSchema, revMaps.ToArray());
148164
}
149165
}
150166

0 commit comments

Comments
 (0)