Skip to content

Commit c3df0e0

Browse files
committed
Fix bug
1 parent 9935c3c commit c3df0e0

1 file changed

Lines changed: 4 additions & 14 deletions

File tree

src/Microsoft.ML.Transforms/Dracula/Featurizer.cs

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6-
using System.Diagnostics;
76
using System.Linq;
87
using Microsoft.ML;
98
using Microsoft.ML.Data;
@@ -157,12 +156,11 @@ public void GetFeatures(int iCol, int iSlot, Random rand, long key, Span<float>
157156
countsBuffer[i++] = count;
158157
}
159158

160-
if (rand != null)
161-
sum = AddLaplacianNoisePerLabel(iCol, rand, countsBuffer);
159+
sum = AddLaplacianNoisePerLabel(iCol, rand, countsBuffer);
162160

163161
// add log odds in the next _logOddsCount indices.
164162
GenerateLogOdds(iCol, countTable, countsBuffer, features.Slice(_labelBinCount, _logOddsCount), sum);
165-
AssertValidOutput(features);
163+
_host.Assert(FloatUtils.IsFinite(features));
166164

167165
// Add the last feature: an indicator for isGarbage.
168166
features[NumFeatures - 1] = isGarbage ? 1 : 0;
@@ -173,12 +171,11 @@ public void GetFeatures(int iCol, int iSlot, Random rand, long key, Span<float>
173171
private float AddLaplacianNoisePerLabel(int iCol, Random rand, Span<float> counts)
174172
{
175173
_host.Assert(_labelBinCount == counts.Length);
176-
_host.AssertValue(rand);
177174

178175
float sum = 0;
179176
for (int ifeat = 0; ifeat < _labelBinCount; ifeat++)
180177
{
181-
if (LaplaceScale[iCol] > 0)
178+
if (rand != null && LaplaceScale[iCol] > 0)
182179
counts[ifeat] += Stats.SampleFromLaplacian(rand, 0, LaplaceScale[iCol]);
183180

184181
// Clamp to zero when noise is too big and negative.
@@ -192,7 +189,7 @@ private float AddLaplacianNoisePerLabel(int iCol, Random rand, Span<float> count
192189
}
193190

194191
// Fills _labelBinCount log odds features. One per class, or only one if 2 classes.
195-
private void GenerateLogOdds(int iCol, ICountTable countTable, Span<float> counts, Span<float> logOdds, Single sum)
192+
private void GenerateLogOdds(int iCol, ICountTable countTable, Span<float> counts, Span<float> logOdds, float sum)
196193
{
197194
_host.Assert(counts.Length == _labelBinCount);
198195
_host.Assert(logOdds.Length == _logOddsCount);
@@ -210,12 +207,5 @@ private void GenerateLogOdds(int iCol, ICountTable countTable, Span<float> count
210207
}
211208
}
212209
}
213-
214-
[Conditional("DEBUG")]
215-
private void AssertValidOutput(Span<float> features)
216-
{
217-
foreach (var feature in features)
218-
_host.Assert(FloatUtils.IsFinite(feature));
219-
}
220210
}
221211
}

0 commit comments

Comments
 (0)