@@ -189,8 +189,9 @@ private void CheckTrainingParameters(ImageClassificationEstimator.Options option
189189 return ( jpegData , resizedImage ) ;
190190 }
191191
192- private static Tensor Encode ( VBuffer < byte > buffer , int length )
192+ private static Tensor Encode ( VBuffer < byte > buffer )
193193 {
194+ int length = buffer . Length ;
194195 var size = c_api . TF_StringEncodedSize ( ( UIntPtr ) length ) ;
195196 var handle = c_api . TF_AllocateTensor ( TF_DataType . TF_STRING , IntPtr . Zero , 0 , ( UIntPtr ) ( ( ulong ) size + 8 ) ) ;
196197 //AllocationType = AllocationType.Tensorflow;
@@ -221,7 +222,7 @@ public ImageProcessor(ImageClassificationTransformer transformer)
221222
222223 public Tensor ProcessImage ( VBuffer < byte > imgBuf )
223224 {
224- var imageTensor = Encode ( imgBuf , imgBuf . Length ) ;
225+ var imageTensor = Encode ( imgBuf ) ;
225226 var processedTensor = _imagePreprocessingRunner . AddInput ( imageTensor , 0 ) . Run ( ) [ 0 ] ;
226227 imageTensor . Dispose ( ) ;
227228 return processedTensor ;
@@ -1170,7 +1171,7 @@ internal sealed class Options : TransformInputBase
11701171 private readonly IHost _host ;
11711172 private readonly Options _options ;
11721173 private readonly DnnModel _dnnModel ;
1173- private readonly TF_DataType [ ] _tfInputTypes ;
1174+ private readonly DataViewType [ ] _inputTypes ;
11741175 private readonly DataViewType [ ] _outputTypes ;
11751176 private ImageClassificationTransformer _transformer ;
11761177
@@ -1179,7 +1180,7 @@ internal ImageClassificationEstimator(IHostEnvironment env, Options options, Dnn
11791180 _host = Contracts . CheckRef ( env , nameof ( env ) ) . Register ( nameof ( ImageClassificationEstimator ) ) ;
11801181 _options = options ;
11811182 _dnnModel = dnnModel ;
1182- _tfInputTypes = new [ ] { TF_DataType . TF_STRING } ;
1183+ _inputTypes = new [ ] { new VectorDataViewType ( NumberDataViewType . Byte ) } ;
11831184 _outputTypes = new [ ] { new VectorDataViewType ( NumberDataViewType . Single ) , NumberDataViewType . UInt32 . GetItemType ( ) } ;
11841185 }
11851186
@@ -1206,9 +1207,9 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
12061207 var input = _options . InputColumns [ i ] ;
12071208 if ( ! inputSchema . TryFindColumn ( input , out var col ) )
12081209 throw _host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input" , input ) ;
1209- // var expectedType = DnnUtils.Tf2MlNetType(_tfInputTypes [i]) ;
1210- // if (col.ItemType != expectedType)
1211- // throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
1210+ var expectedType = _inputTypes [ i ] ;
1211+ if ( ! col . ItemType . Equals ( expectedType ) )
1212+ throw _host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input" , input , expectedType . ToString ( ) , col . ItemType . ToString ( ) ) ;
12121213 }
12131214 for ( var i = 0 ; i < _options . OutputColumns . Length ; i ++ )
12141215 {
0 commit comments