Skip to content

Commit 0e3dc6a

Browse files
committed
PR feedback.
1 parent e74e52a commit 0e3dc6a

5 files changed

Lines changed: 74 additions & 66 deletions

File tree

src/Microsoft.ML.Data/DataLoadSave/LegacyCompositeDataLoader.cs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -412,7 +412,7 @@ internal TransformerChain<ITransformer> GetTransformer()
412412
}
413413
else
414414
{
415-
ITransformer transformer = new TransformWrapper(_host, transform.Transform, false, true);
415+
ITransformer transformer = new TransformWrapper(_host, transform.Transform);
416416
result = result.Append(transformer);
417417
}
418418
}

src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs

Lines changed: 6 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -27,61 +27,28 @@ internal sealed class TransformWrapper : ITransformer
2727

2828
private readonly IHost _host;
2929
private readonly IDataView _xf;
30-
private readonly bool _allowSave;
3130
private readonly bool _isRowToRowMapper;
32-
private readonly bool _useLastTransformOnly;
3331

34-
public TransformWrapper(IHostEnvironment env, IDataView xf, bool allowSave = false, bool useLastTransformOnly = false)
32+
public TransformWrapper(IHostEnvironment env, IDataView xf)
3533
{
3634
Contracts.CheckValue(env, nameof(env));
3735
_host = env.Register(nameof(TransformWrapper));
3836
_host.CheckValue(xf, nameof(xf));
3937
_xf = xf;
40-
_allowSave = allowSave;
4138
_isRowToRowMapper = IsChainRowToRowMapper(_xf);
42-
_useLastTransformOnly = useLastTransformOnly;
4339
}
4440

4541
public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
4642
{
4743
_host.CheckValue(inputSchema, nameof(inputSchema));
4844

4945
var dv = new EmptyDataView(_host, inputSchema);
50-
var output = _useLastTransformOnly ? ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, dv) :
51-
ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, dv);
46+
var output = ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, dv);
5247

5348
return output.Schema;
5449
}
5550

56-
void ICanSaveModel.Save(ModelSaveContext ctx)
57-
{
58-
if (!_allowSave)
59-
throw _host.Except("Saving is not permitted.");
60-
ctx.CheckAtModel();
61-
ctx.SetVersionInfo(GetVersionInfo());
62-
63-
var dataPipe = _xf;
64-
var transforms = new List<IDataTransform>();
65-
while (dataPipe is IDataTransform xf)
66-
{
67-
// REVIEW: a malicious user could construct a loop in the Source chain, that would
68-
// cause this method to iterate forever (and throw something when the list overflows). There's
69-
// no way to insulate from ALL malicious behavior.
70-
transforms.Add(xf);
71-
dataPipe = xf.Source;
72-
Contracts.AssertValue(dataPipe);
73-
}
74-
transforms.Reverse();
75-
76-
ctx.SaveSubModel("Loader", c => BinaryLoader.SaveInstance(_host, c, dataPipe.Schema));
77-
78-
ctx.Writer.Write(transforms.Count);
79-
for (int i = 0; i < transforms.Count; i++)
80-
{
81-
var dirName = string.Format(TransformDirTemplate, i);
82-
ctx.SaveModel(transforms[i], dirName);
83-
}
84-
}
51+
void ICanSaveModel.Save(ModelSaveContext ctx) => throw _host.Except("Saving is not permitted.");
8552

8653
private static VersionInfo GetVersionInfo()
8754
{
@@ -100,7 +67,6 @@ private TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
10067
Contracts.CheckValue(env, nameof(env));
10168
_host = env.Register(nameof(TransformWrapper));
10269
_host.CheckValue(ctx, nameof(ctx));
103-
_allowSave = true;
10470
ctx.CheckAtModel(GetVersionInfo());
10571
int n = ctx.Reader.ReadInt32();
10672
_host.CheckDecode(n >= 0);
@@ -119,8 +85,7 @@ private TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
11985
_isRowToRowMapper = IsChainRowToRowMapper(_xf);
12086
}
12187

122-
public IDataView Transform(IDataView input) => _useLastTransformOnly ? ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, input) :
123-
ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input);
88+
public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, input);
12489

12590
private static bool IsChainRowToRowMapper(IDataView view)
12691
{
@@ -137,30 +102,8 @@ private static bool IsChainRowToRowMapper(IDataView view)
137102
IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
138103
{
139104
_host.CheckValue(inputSchema, nameof(inputSchema));
140-
var input = new EmptyDataView(_host, inputSchema);
141-
IDataView chain;
142-
if (_useLastTransformOnly)
143-
{
144-
chain = ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, input);
145-
return new CompositeRowToRowMapper(inputSchema, new[] { (IRowToRowMapper)chain });
146-
}
147-
else
148-
{
149-
var revMaps = new List<IRowToRowMapper>();
150-
for (chain = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input);
151-
chain is IDataTransform xf;
152-
chain = xf.Source)
153-
{
154-
// Everything in the chain ought to be a row mapper.
155-
_host.Assert(xf is IRowToRowMapper);
156-
revMaps.Add((IRowToRowMapper)xf);
157-
}
158-
159-
// The walkback should have ended at the input.
160-
Contracts.Assert(chain == input);
161-
revMaps.Reverse();
162-
return new CompositeRowToRowMapper(inputSchema, revMaps.ToArray());
163-
}
105+
return new CompositeRowToRowMapper(inputSchema,
106+
new[] { (IRowToRowMapper)ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, new EmptyDataView(_host, inputSchema)) });
164107
}
165108
}
166109

test/Microsoft.ML.Functional.Tests/ModelFiles.cs

Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -213,6 +213,43 @@ void AssertIsGam(ITransformer trans)
213213
Done();
214214
}
215215

216+
public class ModelInput
217+
{
218+
#pragma warning disable SA1401
219+
public string[] CategoricalFeatures;
220+
public float[] NumericalFeatures;
221+
#pragma warning restore SA1401
222+
}
223+
224+
public class ModelOutput
225+
{
226+
#pragma warning disable SA1401
227+
public float[] Score;
228+
#pragma warning restore SA1401
229+
}
230+
231+
232+
[Fact]
233+
public void LoadModelWithOptionalColumnTransform()
234+
{
235+
SchemaDefinition inputSchemaDefinition = SchemaDefinition.Create(typeof(ModelInput));
236+
inputSchemaDefinition[nameof(ModelInput.CategoricalFeatures)].ColumnType = new VectorDataViewType(TextDataViewType.Instance, 5);
237+
inputSchemaDefinition[nameof(ModelInput.NumericalFeatures)].ColumnType = new VectorDataViewType(NumberDataViewType.Single, 3);
238+
var mlContext = new MLContext();
239+
ITransformer trainedModel;
240+
DataViewSchema dataViewSchema;
241+
using (var stream = new FileStream(Path.Join(Directory.GetCurrentDirectory(), @"..\..\..\..\test\data\backcompat\modelwithoptionalcolumntransform.zip"),
242+
FileMode.Open, FileAccess.Read, FileShare.Read))
243+
{
244+
trainedModel = mlContext.Model.Load(stream, out dataViewSchema);
245+
}
246+
247+
var model = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(trainedModel, inputSchemaDefinition: inputSchemaDefinition);
248+
var prediction = model.Predict(new ModelInput() { CategoricalFeatures = new[] { "ABC", "ABC", "ABC", "ABC", "ABC" }, NumericalFeatures = new float [] { 1, 1, 1 } });
249+
250+
Assert.Equal(1, prediction.Score[0]);
251+
}
252+
216253
[Fact]
217254
public void SaveAndLoadModelWithLoader()
218255
{

test/Microsoft.ML.Tests/Scenarios/WordBagTest.cs

Lines changed: 30 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,12 +31,11 @@ public static void WordBags()
3131
var textPipeline =
3232
mlContext.Transforms.Text.ProduceWordBags("Text", "Text",
3333
ngramLength: 3, useAllLengths: false, weighting: NgramExtractingEstimator.WeightingCriteria.Tf).Append(
34-
mlContext.Transforms.Text.ProduceWordBags("Text2", "Text2",
34+
mlContext.Transforms.Text.ProduceWordBags("Text2", new[] { "Text2", "Text2" },
3535
ngramLength: 3, useAllLengths: false, weighting: NgramExtractingEstimator.WeightingCriteria.Tf));
3636

3737

3838
var textTransformer = textPipeline.Fit(dataview);
39-
var transformedDataView = textTransformer.Transform(dataview);
4039
var predictionEngine = mlContext.Model.CreatePredictionEngine<TextData, TransformedTextData>(textTransformer);
4140
var prediction = predictionEngine.Predict(samples[0]);
4241
Assert.Equal(prediction.Text, new float[] {
@@ -46,6 +45,35 @@ public static void WordBags()
4645
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 });
4746
}
4847

48+
[Fact]
49+
public static void WordBagsHash()
50+
{
51+
var mlContext = new MLContext();
52+
var samples = new List<TextData>()
53+
{
54+
new TextData(){ Text = "This is an example to compute bag-of-word features." },
55+
new TextData(){ Text = "ML.NET's ProduceWordBags API produces bag-of-word features from input text." },
56+
new TextData(){ Text = "It does so by first tokenizing text/string into words/tokens then " },
57+
new TextData(){ Text = "computing n-grams and their neumeric values." },
58+
new TextData(){ Text = "Each position in the output vector corresponds to a particular n-gram." },
59+
new TextData(){ Text = "The value at each position corresponds to," },
60+
new TextData(){ Text = "the number of times n-gram occured in the data (Tf), or" },
61+
new TextData(){ Text = "the inverse of the number of documents contain the n-gram (Idf)," },
62+
new TextData(){ Text = "or compute both and multipy together (Tf-Idf)." },
63+
};
64+
65+
var dataview = mlContext.Data.LoadFromEnumerable(samples);
66+
var textPipeline =
67+
mlContext.Transforms.Text.ProduceHashedWordBags("Text", "Text", ngramLength: 3, useAllLengths: false).Append(
68+
mlContext.Transforms.Text.ProduceHashedWordBags("Text2", new[] { "Text2", "Text2" }, ngramLength: 3, useAllLengths: false));
69+
70+
71+
var textTransformer = textPipeline.Fit(dataview);
72+
var predictionEngine = mlContext.Model.CreatePredictionEngine<TextData, TransformedTextData>(textTransformer);
73+
var prediction = predictionEngine.Predict(samples[0]);
74+
Assert.Equal(65536, prediction.Text.Length);
75+
}
76+
4977
private class TextData
5078
{
5179
public string Text { get; set; }
9.05 KB
Binary file not shown.

0 commit comments

Comments
 (0)