Skip to content

Commit b106ae0

Browse files
author
Harshitha Parnandi Venkata
committed
Changed Image Classification API to take in a VBuffer<byte> type instead of ImagePath.
1 parent edfd10f commit b106ae0

3 files changed

Lines changed: 39 additions & 20 deletions

File tree

docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ResnetV2101TransferLearningTrainTestSplit.cs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ public static void Example()
5959
IDataView testDataset = trainTestData.TestSet;
6060

6161
var pipeline = mlContext.Model.ImageClassification(
62-
"ImagePath", "Label",
62+
"ImageVBuf", "Label",
6363
// Just by changing/selecting InceptionV3 here instead of
6464
// ResnetV2101 you can try a different architecture/pre-trained
6565
// model.
@@ -129,9 +129,13 @@ private static void TrySinglePrediction(string imagesForPredictions,
129129
IEnumerable<ImageData> testImages = LoadImagesFromDirectory(
130130
imagesForPredictions, false);
131131

132+
byte[] imgBytes = File.ReadAllBytes(testImages.First().ImagePath);
133+
VBuffer<Byte> imgData = new VBuffer<byte>(imgBytes.Length, imgBytes);
134+
132135
ImageData imageToPredict = new ImageData
133136
{
134-
ImagePath = testImages.First().ImagePath
137+
ImagePath = testImages.First().ImagePath,
138+
ImageVBuf = imgData
135139
};
136140

137141
var prediction = predictionEngine.Predict(imageToPredict);
@@ -174,7 +178,7 @@ public static IEnumerable<ImageData> LoadImagesFromDirectory(string folder,
174178

175179
foreach (var file in files)
176180
{
177-
if (Path.GetExtension(file) != ".jpg")
181+
if (Path.GetExtension(file) != ".JPEG" && Path.GetExtension(file) != ".jpg")
178182
continue;
179183

180184
var label = Path.GetFileName(file);
@@ -192,10 +196,15 @@ public static IEnumerable<ImageData> LoadImagesFromDirectory(string folder,
192196
}
193197
}
194198

199+
// Get the buffer of bytes
200+
byte[] imgBytes = File.ReadAllBytes(Path.Combine(folder, file));
201+
VBuffer<Byte> imgData = new VBuffer<byte>(imgBytes.Length, imgBytes);
202+
195203
yield return new ImageData()
196204
{
197205
ImagePath = file,
198-
Label = label
206+
Label = label,
207+
ImageVBuf = imgData
199208
};
200209

201210
}
@@ -292,6 +301,8 @@ public class ImageData
292301

293302
[LoadColumn(1)]
294303
public string Label;
304+
305+
public VBuffer<byte> ImageVBuf;
295306
}
296307

297308
public class ImagePrediction

docs/samples/Microsoft.ML.Samples/Program.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ public static class Program
1010

1111
internal static void RunAll()
1212
{
13+
/*
1314
int samples = 0;
1415
foreach (var type in Assembly.GetExecutingAssembly().GetTypes())
1516
{
@@ -22,8 +23,11 @@ internal static void RunAll()
2223
samples++;
2324
}
2425
}
25-
26+
2627
Console.WriteLine("Number of samples that ran without any exception: " + samples);
28+
*/
29+
ResnetV2101TransferLearningTrainTestSplit.Example();
30+
2731
}
2832
}
2933
}

src/Microsoft.ML.Dnn/ImageClassificationTransform.cs

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -200,16 +200,16 @@ public ImageProcessor(ImageClassificationTransformer transformer)
200200
_imagePreprocessingRunner.AddOutputs(transformer._resizedImageTensorName);
201201
}
202202

203-
public Tensor ProcessImage(string path)
203+
public Tensor ProcessImage(VBuffer<byte> imgBuf)
204204
{
205-
var imageTensor = new Tensor(File.ReadAllBytes(path), TF_DataType.TF_STRING);
205+
var imageTensor = new Tensor(imgBuf.DenseValues().ToArray(), TF_DataType.TF_STRING);
206206
var processedTensor = _imagePreprocessingRunner.AddInput(imageTensor, 0).Run()[0];
207207
imageTensor.Dispose();
208208
return processedTensor;
209209
}
210210
}
211211

212-
private void CacheFeaturizedImagesToDisk(IDataView input, string labelColumnName, string imagepathColumnName,
212+
private void CacheFeaturizedImagesToDisk(IDataView input, string labelColumnName, string imageColumnName,
213213
ImageProcessor imageProcessor, string inputTensorName, string outputTensorName, string cacheFilePath,
214214
ImageClassificationMetrics.Dataset dataset, ImageClassificationMetricsCallback metricsCallback)
215215
{
@@ -220,16 +220,19 @@ private void CacheFeaturizedImagesToDisk(IDataView input, string labelColumnName
220220
labelColumnName, typeof(uint).ToString(),
221221
labelColumn.Type.RawType.ToString());
222222

223-
var imagePathColumn = input.Schema[imagepathColumnName];
223+
var imageBufColumn = input.Schema[imageColumnName];
224+
var imagePathColumn = input.Schema["ImagePath"];
224225
Runner runner = new Runner(_session);
225226
runner.AddOutputs(outputTensorName);
226227

227228
using (TextWriter writer = File.CreateText(cacheFilePath))
228-
using (var cursor = input.GetRowCursor(input.Schema.Where(c => c.Index == labelColumn.Index || c.Index == imagePathColumn.Index)))
229+
using (var cursor = input.GetRowCursor(input.Schema.Where(c => c.Index == labelColumn.Index || c.Index == imageBufColumn.Index || c.Index == imagePathColumn.Index)))
229230
{
230231
var labelGetter = cursor.GetGetter<uint>(labelColumn);
232+
var imageBufGetter = cursor.GetGetter<VBuffer<byte>>(imageBufColumn);
231233
var imagePathGetter = cursor.GetGetter<ReadOnlyMemory<char>>(imagePathColumn);
232234
UInt32 label = UInt32.MaxValue;
235+
VBuffer<byte> imageBuf = default;
233236
ReadOnlyMemory<char> imagePath = default;
234237
runner.AddInput(inputTensorName);
235238
ImageClassificationMetrics metrics = new ImageClassificationMetrics();
@@ -238,9 +241,10 @@ private void CacheFeaturizedImagesToDisk(IDataView input, string labelColumnName
238241
while (cursor.MoveNext())
239242
{
240243
labelGetter(ref label);
244+
imageBufGetter(ref imageBuf);
241245
imagePathGetter(ref imagePath);
242246
var imagePathStr = imagePath.ToString();
243-
var imageTensor = imageProcessor.ProcessImage(imagePathStr);
247+
var imageTensor = imageProcessor.ProcessImage(imageBuf);
244248
runner.AddInput(imageTensor, 0);
245249
var featurizedImage = runner.Run()[0]; // Reuse memory?
246250
writer.WriteLine(label - 1 + "," + string.Join(",", featurizedImage.ToArray<float>()));
@@ -795,8 +799,8 @@ public Mapper(ImageClassificationTransformer parent, DataViewSchema inputSchema)
795799
private class OutputCache
796800
{
797801
public long Position;
798-
private ValueGetter<ReadOnlyMemory<char>> _imagePathGetter;
799-
private ReadOnlyMemory<char> _imagePath;
802+
private ValueGetter<VBuffer<byte>> _imageBufGetter;
803+
private VBuffer<byte> _imageBuf;
800804
private Runner _runner;
801805
private ImageProcessor _imageProcessor;
802806
public UInt32 PredictedLabel { get; set; }
@@ -805,8 +809,8 @@ private class OutputCache
805809

806810
public OutputCache(DataViewRow input, ImageClassificationTransformer transformer)
807811
{
808-
_imagePath = default;
809-
_imagePathGetter = input.GetGetter<ReadOnlyMemory<char>>(input.Schema[transformer._inputs[0]]);
812+
_imageBuf = default;
813+
_imageBufGetter = input.GetGetter<VBuffer<byte>>(input.Schema[transformer._inputs[0]]);
810814
_runner = new Runner(transformer._session);
811815
_runner.AddInput(transformer._inputTensorName);
812816
_runner.AddOutputs(transformer._softmaxTensorName);
@@ -823,8 +827,8 @@ public void UpdateCacheIfNeeded()
823827
if (_inputRow.Position != Position)
824828
{
825829
Position = _inputRow.Position;
826-
_imagePathGetter(ref _imagePath);
827-
var processedTensor = _imageProcessor.ProcessImage(_imagePath.ToString());
830+
_imageBufGetter(ref _imageBuf);
831+
var processedTensor = _imageProcessor.ProcessImage(_imageBuf);
828832
var outputTensor = _runner.AddInput(processedTensor, 0).Run();
829833
ClassProbabilities = outputTensor[0].ToArray<float>();
830834
PredictedLabel = (UInt32)outputTensor[1].ToArray<long>()[0];
@@ -1189,9 +1193,9 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
11891193
var input = _options.InputColumns[i];
11901194
if (!inputSchema.TryFindColumn(input, out var col))
11911195
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
1192-
var expectedType = DnnUtils.Tf2MlNetType(_tfInputTypes[i]);
1193-
if (col.ItemType != expectedType)
1194-
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
1196+
//var expectedType = DnnUtils.Tf2MlNetType(_tfInputTypes[i]);
1197+
//if (col.ItemType != expectedType)
1198+
// throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
11951199
}
11961200
for (var i = 0; i < _options.OutputColumns.Length; i++)
11971201
{

0 commit comments

Comments
 (0)