-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Towards #1798 . #2170
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Towards #1798 . #2170
Conversation
rogancarr
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approved with some questions / comments!
| string weights = null) | ||
| { | ||
| Contracts.CheckValue(ctx, nameof(ctx)); | ||
| var env = CatalogUtils.GetEnvironment(ctx); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we do this when the other API needs the same context? #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the new convention is to have two APIs: one for basic usage that let's you set the column names, and most used params, the other one that takes in Options, and allows to set everything.
In reply to: 248767099 [](ancestors = 248767099)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, I meant, why do we generate a new env when the other API needs the same context as we have in ctx? #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the trainer ctor needs an IHostEnvironment.
The
CatalogUtils.GetEnvironment(ctx)
does not generate a new environment, just gets the one from within the context.
In reply to: 248781925 [](ancestors = 248781925)
| string labelColumn = DefaultColumnNames.Label, | ||
| string featureColumn = DefaultColumnNames.Features, | ||
| Action<SymSgdClassificationTrainer.Arguments> advancedSettings = null) | ||
| Action<SymSgdClassificationTrainer.Options> advancedSettings = null) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This argument is unused. Drop it? #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| /// <param name="labelColumn">The labelColumn column.</param> | ||
| /// <param name="featureColumn">The features column.</param> | ||
| /// <param name="advancedSettings">Algorithm advanced settings.</param> | ||
| public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(this BinaryClassificationContext.BinaryClassificationTrainers ctx, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Too long? #Resolved
| /// </summary> | ||
| /// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param> | ||
| /// <param name="options">Algorithm advanced options.</param> | ||
| public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(this BinaryClassificationContext.BinaryClassificationTrainers ctx, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Too long? #Resolved
| internal OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args) | ||
| internal OlsLinearRegressionTrainer(IHostEnvironment env, Options args) | ||
| : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), | ||
| TrainerUtils.MakeR4ScalarColumn(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we don't need the IsExplicit() field here anymore. #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| /// Initializes a new instance of <see cref="OlsLinearRegressionTrainer"/> | ||
| /// </summary> | ||
| internal OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args) | ||
| internal OlsLinearRegressionTrainer(IHostEnvironment env, Options args) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: args => options? #Resolved
| ShortName = ShortName, | ||
| XmlInclude = new[] { @"<include file='../Microsoft.ML.HalLearners/doc.xml' path='doc/members/member[@name=""OLS""]/*' />" })] | ||
| public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, Arguments input) | ||
| public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, Options input) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: input => options? #Resolved
|
|
||
| public override TrainerInfo Info { get; } | ||
| private readonly Arguments _args; | ||
| private readonly Options _args; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: _args => _options? #Resolved
| ShortName = SymSgdClassificationTrainer.ShortName, | ||
| XmlInclude = new[] { @"<include file='../Microsoft.ML.HalLearners/doc.xml' path='doc/members/member[@name=""SymSGD""]/*' />" })] | ||
| public static CommonOutputs.BinaryClassificationOutput TrainSymSgd(IHostEnvironment env, Arguments input) | ||
| public static CommonOutputs.BinaryClassificationOutput TrainSymSgd(IHostEnvironment env, Options input) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: input => options? #Resolved
| EntryPointUtils.CheckInputArgs(host, input); | ||
|
|
||
| return LearnerEntryPointsUtils.Train<Arguments, CommonOutputs.BinaryClassificationOutput>(host, input, | ||
| return LearnerEntryPointsUtils.Train<Options, CommonOutputs.BinaryClassificationOutput>(host, input, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: ditto #Resolved
bbcbf71 to
0976ad9
Compare
|
PR could use a more descriptive title. I have to dereference the #1798 pointer to understand or dive into the description instead of just scan the repo's PR titles. |
| /// </summary> | ||
| /// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param> | ||
| /// <param name="labelColumn">The labelColumn column.</param> | ||
| /// <param name="labelColumn">The label column.</param> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
column [](start = 48, length = 6)
Is it column or column name? #Pending
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let me do that on a separate PR; since it touches a lot more files, and is orthogonal to the purpose of this PR.
In reply to: 248854888 [](ancestors = 248854888)
| /// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param> | ||
| /// <param name="labelColumn">The labelColumn column.</param> | ||
| /// <param name="labelColumn">The label column.</param> | ||
| /// <param name="featureColumn">The featureColumn column.</param> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
featureColumn [](start = 44, length = 13)
feature vector column's name? #Pending
|
There are many doc strings describe their parameters as a column but they are actually columns' names. Do we want to distinguish column from column name in public APIs? @TomFinley, @eerhardt, any input please? #Pending |
| public sealed class OlsLinearRegressionTrainer : TrainerEstimatorBase<RegressionPredictionTransformer<OlsLinearRegressionModelParameters>, OlsLinearRegressionModelParameters> | ||
| { | ||
| public sealed class Arguments : LearnerInputBaseWithWeight | ||
| public sealed class Options : LearnerInputBaseWithWeight |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is LearnerInputBaseWithWeights still a good name? Would it be better to use LearnerOptionBaseWithWeight?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| // incredibly uesful thing to have around. | ||
| /// <summary> | ||
| /// L2 regularization weight. | ||
| /// </summary> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please merge them into summary. #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep.
// Adding L2 regularization turns this into a form of ridge regression,
// rather than, strictly speaking, ordinary least squares. But it is an
// incredibly useful thing to have around.
looks more like a doc string.
In reply to: 248858191 [](ancestors = 248858191,248857752)
| public float L2Weight = 1e-6f; | ||
|
|
||
| /// <summary> | ||
| /// Whether to calculate per parameter significance statistics. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
parameter [](start = 41, length = 9)
parameter --> parameter (e.g., the coefficient of the i-th input feature) #Resolved
| } | ||
|
|
||
| /// <summary> | ||
| /// Predict a target using a linear regression model trained with the <see cref="OlsLinearRegressionTrainer"/>. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Predict [](start = 12, length = 7)
It doesn't look like this function accepts a feature vector and then produce a float. #Pending
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it can accept a feature vector in the options. Is your comment different?
In reply to: 248869365 [](ancestors = 248869365)
| public sealed class Options : LearnerInputBaseWithLabel | ||
| { | ||
| /// <summary> | ||
| /// Degree of lock-free parallelism. Determinism not guaranteed. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Determinism not guaranteed. [](start = 49, length = 27)
not guaranteed if a value grater than one is set. #Resolved
| { | ||
| /// <summary> | ||
| /// Degree of lock-free parallelism. Determinism not guaranteed. | ||
| /// Multi-threading is not supported currently. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Multi-threading is not supported currently. [](start = 16, length = 43)
Does it mean that this argument is not used?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good point! It doesn't seem like it is getting used. I know we switched the parallelism off in the native code, due to MKL being sequential. I'll check if that is the case, and log an issue.
In reply to: 248885364 [](ancestors = 248885364)
| public int NumberOfIterations = 50; | ||
|
|
||
| /// <summary> | ||
| /// Tolerance for difference in average loss in consecutive passes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add If the reduction on loss is smaller than the specified tolerance in one iteration, training process will be terminated. #Resolved
| public float Tolerance = 1e-4f; | ||
|
|
||
| /// <summary> | ||
| /// Learning rate. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe extend it to Learning rate. A larger value can potentially reduce the training time but incur numerical instability and over-fitting. #Resolved
wschin
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Most of my comments are around doc strings. Not sure if you want to address here them but I think they should not block maybe.
| private TPredictor TrainCore(IChannel ch, RoleMappedData data, LinearModelParameters predictor, int weightSetCount) | ||
| { | ||
| int numFeatures = data.Schema.Feature.Value.Type.GetVectorSize(); | ||
| var cursorFactory = new FloatLabelCursor.Factory(data, CursOpt.Label | CursOpt.Features | CursOpt.Weight); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| CursOpt.Weight [](start = 100, length = 16)
You don't accept weight in options, and I don't see any calls to Weight column.
Would be nice to remove it from cursor creation. #Resolved
This PR addresses the estimators inside HalLearners: Two public extension methods, one for simple arguments and the other for advanced options Delete unecessary constructors Pass Options objects as arguments instead of Action delegate Rename Arguments to Options Rename Options objects as options (instead of args or advancedSettings used so far)
@wschin your feedback is incorporated. I appreciate the time to look at the PR, and the level of detail + suggestions. |
Fixing a typo.
04ecc89 to
5f4095a
Compare
This PR addresses the estimators inside HalLearners:
Two public extension methods, one for simple arguments and the other for advanced options
Delete unecessary constructors
Pass Options objects as arguments instead of Action delegate
Rename Arguments to Options
Rename Options objects as options (instead of args or advancedSettings used so far)