[core] refactor AbstractTrainer #4804
```python
from autogluon.core.augmentation.distill_utils import augment_data, format_distillation_labels
from autogluon.core.calibrate import calibrate_decision_threshold
from autogluon.core.calibrate.conformity_score import compute_conformity_score
from autogluon.core.calibrate.temperature_scaling import apply_temperature_scaling, tune_temperature_scaling
from autogluon.core.callbacks import AbstractCallback
from autogluon.core.constants import BINARY, MULTICLASS, QUANTILE, REFIT_FULL_NAME, REGRESSION, SOFTCLASS
from autogluon.core.data.label_cleaner import LabelCleanerMulticlassToBinary
from autogluon.core.metrics import Scorer, compute_metric, get_metric
from autogluon.core.models import (
    AbstractModel,
    BaggedEnsembleModel,
    GreedyWeightedEnsembleModel,
    SimpleWeightedEnsembleModel,
    StackerEnsembleModel,
    WeightedEnsembleModel,
)
from autogluon.core.pseudolabeling.pseudolabeling import assert_pseudo_column_match
from autogluon.core.ray.distributed_jobs_managers import ParallelFitManager
from autogluon.core.utils import (
```
Why switch from relative to absolute imports? Tabular, core, common, and features all use relative imports.
Is there a general guideline you are following for favoring absolute over relative?
Nothing particular in this case, it's just that these will move into tabular anyway. I can switch to relative imports for the ones that remain after that.
PEP8 recommends absolute imports but I would vote in favor of consistency with other modules here.
Innixma left a comment:
Very nice! Added some minor comments but overall this is a great step in the right direction for unifying the API of trainer.
```python
if not isinstance(model, str):
    model = model.name
```
This logic and variants of it can probably be put into utility methods since they are so common. Maybe as a fast follow PR. Will change a lot of 2-4 line logic into 1 line, and improve type hinting in the IDE so we can be more explicit about the type of the variable in the code.
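For illustration, a minimal sketch of what such a utility could look like (the helper name `resolve_model_name` and the stand-in `AbstractModel` below are hypothetical, not code from this PR):

```python
from __future__ import annotations


class AbstractModel:
    """Minimal stand-in for AutoGluon's AbstractModel, which exposes a `name` attribute."""

    def __init__(self, name: str):
        self.name = name


def resolve_model_name(model: AbstractModel | str) -> str:
    """Collapse the common 2-4 line `isinstance` dance into one call.

    Accepts either a model object or its name and always returns the name,
    so callers can be explicit that they hold a `str` from this point on.
    """
    if not isinstance(model, str):
        model = model.name
    return model
```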
```python
if not self.low_memory:
    self.models[model.name] = model
```
The low_memory logic is something I implemented a very long time ago but at a certain point gave up on testing/using, since users didn't really care too much that files were being saved on disk, and it was a hassle trying to get everything working in-memory.
If time-series isn't using it, we may consider removing the low_memory logic entirely in a follow-up PR; it would make the code simpler and avoid us having to juggle multiple input/output variable types depending on the low_memory setting.
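For context, a sketch of the kind of branching the flag forces (a simplified stand-in, not the actual trainer code; the real low_memory path deserializes the model from disk instead of returning None):

```python
from types import SimpleNamespace


class Trainer:
    """Stand-in showing why low_memory makes callers juggle variable types."""

    def __init__(self, low_memory: bool = True):
        self.low_memory = low_memory
        self.models = {}  # name -> model object; only populated when low_memory is False

    def add_model(self, model) -> None:
        # With low_memory=True the model is only persisted to disk, so later
        # code must handle both "object available" and "name only" cases.
        if not self.low_memory:
            self.models[model.name] = model

    def load_model(self, name: str):
        if name in self.models:
            return self.models[name]  # in-memory fast path
        return None  # low_memory path: real code would load the model file here
```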
agreed. 👍
```python
@property
def path_root(self) -> str:
    """directory containing learner.pkl"""
    return os.path.dirname(self.path)
```
I remember wanting to refactor some of this logic, as currently we have self.path and self.path_root, which is a bit confusing. I might look into it closer as a follow-up.
The overall problem is that self.path doesn't contain all of the artifacts needed by the Trainer, such as those from path_utils. In an ideal world, having it be self-contained would be nice, kind of like how Predictor is self-contained. But I would need to think about it more, because I also thought about a world where we could have multiple trainers for a single predictor, in which case it would be good for trainers to re-use certain artifacts between them.
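To make the `path` / `path_root` relationship concrete, a minimal stand-in (the directory layout in the comments is illustrative, not the exact AutoGluon layout):

```python
import os


class Trainer:
    """Stand-in showing how path_root is currently derived from path."""

    def __init__(self, path: str):
        # In the real trainer, `path` points at the trainer's model directory,
        # while learner.pkl lives one level up in the output directory.
        self.path = path

    @property
    def path_root(self) -> str:
        """Directory containing learner.pkl."""
        return os.path.dirname(self.path)
```

For example, `Trainer("ag_output/models").path_root` is `"ag_output"`.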
shchur left a comment:
Thank you for this herculean effort! Only a few questions and minor comments.
```python
        model = model.name
        self.model_graph.nodes[model][attribute] = val

    def get_minimum_model_set(self, model, include_self=True) -> list:
```
Nit: Some methods are missing type hints.
Not a nit at all, and thanks for pointing it out.
A heads up: one thing that will come up in the next PR is how we handle this AbstractModel type being passed around. TimeSeriesTrainer will want to work only with AbstractTimeSeriesModel, but if we constrain it trivially in the function signature then pyright will complain that the LSP (Liskov substitution principle) is violated. We will then have to revisit this class to use generics:

```python
T = TypeVar("T", bound=AbstractModel)

class AbstractTrainer(Generic[T]):
    def work_with_model(self, model: T): ...

class AbstractTimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
    ...
```

Note that the syntax for working with generics gets much lighter starting from Python 3.12 (PEP 695), e.g. `class AbstractTrainer[T]: ...`.
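A self-contained sketch of that pre-3.12 generics spelling (the model classes here are trivial stand-ins for the AutoGluon ones):

```python
from typing import Generic, TypeVar


class AbstractModel:
    def fit(self) -> str:
        return "fit base model"


class AbstractTimeSeriesModel(AbstractModel):
    def fit(self) -> str:
        return "fit time series model"


# Bound the type variable so a trainer can only be parameterized by model types.
ModelT = TypeVar("ModelT", bound=AbstractModel)


class AbstractTrainer(Generic[ModelT]):
    def work_with_model(self, model: ModelT) -> str:
        # Subclasses narrow ModelT via the base-class type parameter instead of
        # overriding the method signature, so pyright sees no LSP violation.
        return model.fit()


class AbstractTimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
    pass
```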
I was really hoping that by specializing in Python instead of C++ I would be able to avoid learning what a generic is... welp
Issue #, if available:
Description of changes:

This is the first PR of a few for a major refactor of AbstractTrainer, unifying the interfaces of AbstractTimeSeriesTrainer and tabular's Trainer class into one abstract class in core.

This PR factors out AbstractTrainer in core, including model management behavior that is common to time series and tabular, as well as common interfaces.

Subsequent PRs will:
- move AbstractTrainer in core to tabular, since it is now almost exclusively tabular-specific concretizations;
- take SimpleAbstractTrainer from timeseries and move it to core.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.