@@ -29,23 +29,27 @@ internal sealed class TransformWrapper : ITransformer
2929 private readonly IDataView _xf ;
3030 private readonly bool _allowSave ;
3131 private readonly bool _isRowToRowMapper ;
32+ private readonly bool _useLastTransformOnly ;
3233
33- public TransformWrapper ( IHostEnvironment env , IDataView xf , bool allowSave = false )
34+ public TransformWrapper ( IHostEnvironment env , IDataView xf , bool allowSave = false , bool useLastTransformOnly = false )
3435 {
3536 Contracts . CheckValue ( env , nameof ( env ) ) ;
3637 _host = env . Register ( nameof ( TransformWrapper ) ) ;
3738 _host . CheckValue ( xf , nameof ( xf ) ) ;
3839 _xf = xf ;
3940 _allowSave = allowSave ;
4041 _isRowToRowMapper = IsChainRowToRowMapper ( _xf ) ;
42+ _useLastTransformOnly = useLastTransformOnly ;
4143 }
4244
4345 public DataViewSchema GetOutputSchema ( DataViewSchema inputSchema )
4446 {
4547 _host . CheckValue ( inputSchema , nameof ( inputSchema ) ) ;
4648
4749 var dv = new EmptyDataView ( _host , inputSchema ) ;
48- var output = ApplyTransformUtils . ApplyAllTransformsToData ( _host , _xf , dv ) ;
50+ var output = _useLastTransformOnly ? ApplyTransformUtils . ApplyTransformToData ( _host , ( IDataTransform ) _xf , dv ) :
51+ ApplyTransformUtils . ApplyAllTransformsToData ( _host , _xf , dv ) ;
52+
4953 return output . Schema ;
5054 }
5155
@@ -115,7 +119,8 @@ private TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
115119 _isRowToRowMapper = IsChainRowToRowMapper ( _xf ) ;
116120 }
117121
118- public IDataView Transform ( IDataView input ) => ApplyTransformUtils . ApplyAllTransformsToData ( _host , _xf , input ) ;
122+ public IDataView Transform ( IDataView input ) => _useLastTransformOnly ? ApplyTransformUtils . ApplyTransformToData ( _host , ( IDataTransform ) _xf , input ) :
123+ ApplyTransformUtils . ApplyAllTransformsToData ( _host , _xf , input ) ;
119124
120125 private static bool IsChainRowToRowMapper ( IDataView view )
121126 {
@@ -133,18 +138,29 @@ IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
133138 {
134139 _host . CheckValue ( inputSchema , nameof ( inputSchema ) ) ;
135140 var input = new EmptyDataView ( _host , inputSchema ) ;
136- var revMaps = new List < IRowToRowMapper > ( ) ;
137141 IDataView chain ;
138- for ( chain = ApplyTransformUtils . ApplyAllTransformsToData ( _host , _xf , input ) ; chain is IDataTransform xf ; chain = xf . Source )
142+ if ( _useLastTransformOnly )
143+ {
144+ chain = ApplyTransformUtils . ApplyTransformToData ( _host , ( IDataTransform ) _xf , input ) ;
145+ return new CompositeRowToRowMapper ( inputSchema , new [ ] { ( IRowToRowMapper ) chain } ) ;
146+ }
147+ else
139148 {
140- // Everything in the chain ought to be a row mapper.
141- _host . Assert ( xf is IRowToRowMapper ) ;
142- revMaps . Add ( ( IRowToRowMapper ) xf ) ;
149+ var revMaps = new List < IRowToRowMapper > ( ) ;
150+ for ( chain = ApplyTransformUtils . ApplyAllTransformsToData ( _host , _xf , input ) ;
151+ chain is IDataTransform xf ;
152+ chain = xf . Source )
153+ {
154+ // Everything in the chain ought to be a row mapper.
155+ _host . Assert ( xf is IRowToRowMapper ) ;
156+ revMaps . Add ( ( IRowToRowMapper ) xf ) ;
157+ }
158+
159+ // The walkback should have ended at the input.
160+ Contracts . Assert ( chain == input ) ;
161+ revMaps . Reverse ( ) ;
162+ return new CompositeRowToRowMapper ( inputSchema , revMaps . ToArray ( ) ) ;
143163 }
144- // The walkback should have ended at the input.
145- Contracts . Assert ( chain == input ) ;
146- revMaps . Reverse ( ) ;
147- return new CompositeRowToRowMapper ( inputSchema , revMaps . ToArray ( ) ) ;
148164 }
149165 }
150166
0 commit comments