[JIT] Experimental logging/counters API #18235
facebook-github-bot
left a comment
@jamesr66a has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
torch/csrc/jit/interpreter.cpp
Outdated
| // std::cout << "pop reg[" << reg << "];\n" << registers[reg] << "\n"; | ||
| } | ||
| pc = new_pc; | ||
| logging::getLogger()->addStatValue( |
We should be very careful about logging in hot portions of the core runtime, since it adds a synchronization point between threads which can quickly become contended. It's not clear to me whether the counters we have here are valuable enough to justify that overhead—they seem too low level to tell us very much.
torch/csrc/jit/script/logging.cpp
Outdated
| case AggregationType::SUM: { | ||
| float sum = 0; | ||
| for (auto x : kv.second) | ||
| sum += x; |
Braces here, please, and at the other loops in this file.
torch/csrc/jit/register_prim_ops.cpp
Outdated
| // HACK alert: stuffing a struct into a ByteTensor | ||
| Operator("prim::TimePoint() -> Tensor", [](const Node* node) { | ||
| return [](Stack& stack) { | ||
| logging::JITTimePoint *ptr = new logging::JITTimePoint(logging::timePoint()); |
Can we add a comment to revisit this when C++-bound script classes are available?
| TORCH_API void recordDurationSince(const std::string& name, JITTimePoint tp); | ||
|
|
||
| namespace runtime_counters { | ||
| constexpr const char* GRAPH_EXECUTORS_CONSTRUCTED = |
Why not just use Symbol for these?
torch/csrc/jit/script/logging.h
Outdated
| public: | ||
| TORCH_API virtual void addStatValue( | ||
| const std::string& stat_name, | ||
| float val) = 0; |
Let's use an int64_t instead of a float. Generally the stat values represent countable units, such as events or milliseconds. I think it also better matches fb303's interface.
torch/csrc/jit/script/logging.cpp
Outdated
| std::shared_ptr<LoggerBase> global_logger = std::make_shared<NoopLogger>(); | ||
|
|
||
| std::shared_ptr<LoggerBase> getLogger() { | ||
| std::unique_lock<std::mutex> lk(m); |
This will definitely be a problem, since all logging to all keys will contend for this lock. Do we have a RW lock/shared mutex implementation that we can use here? If we don't have that, we'll have to come up with a better way to make this thread safe
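As a rough illustration of the reader/writer idea raised here, the sketch below guards the global logger with `std::shared_mutex`. It assumes C++17 is available, which may not have been true for this codebase at the time. `LoggerBase` and `NoopLogger` are simplified stand-ins for the PR's types, and `logger_mutex` is a hypothetical name.

```cpp
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <utility>

// Simplified stand-ins for the PR's types; not the real torch headers.
struct LoggerBase {
  virtual ~LoggerBase() = default;
};
struct NoopLogger : LoggerBase {};

namespace {
std::shared_mutex logger_mutex;  // assumption: C++17 toolchain
std::shared_ptr<LoggerBase> global_logger = std::make_shared<NoopLogger>();
}  // namespace

// Readers take a shared lock, so concurrent getLogger() calls
// do not exclude each other.
std::shared_ptr<LoggerBase> getLogger() {
  std::shared_lock<std::shared_mutex> lk(logger_mutex);
  return global_logger;
}

// Writers take an exclusive lock. Returns the previously installed logger
// so the caller owns it and nothing is leaked.
std::shared_ptr<LoggerBase> setLogger(std::shared_ptr<LoggerBase> logger) {
  std::unique_lock<std::shared_mutex> lk(logger_mutex);
  std::swap(global_logger, logger);
  return logger;
}
```

Readers still touch shared state when they acquire the lock and copy the `shared_ptr`, so this reduces contention rather than removing it.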
torch/csrc/jit/script/logging.h
Outdated
| TORCH_API virtual void addStatValue( | ||
| const std::string& stat_name, | ||
| float val) = 0; | ||
| TORCH_API virtual std::unordered_map<std::string, float> getCounters() |
Maybe getCounterValue is a better interface, to avoid aggregating/copying the entire logger when we want to get a single counter
zdevito
left a comment
I am not opposed to having some logging. But at its size, it seems like too much code to support into the future given what it provides. I think it can be cut down quite a bit by using existing bindings and simplifying the aggregation code.
torch/csrc/jit/register_prim_ops.cpp
Outdated
| }), | ||
| Operator("prim::BumpCounter(str key, int val) -> ()", [](const Node* node) { | ||
| return [](Stack& stack) { | ||
| auto val = pop(stack).toInt(); |
Nit: use in-order pop rather than stack manipulation code.
std::string key;
int64_t val;
pop(stack, key, val);
torch/csrc/jit/register_prim_ops.cpp
Outdated
| return 0; | ||
| }; | ||
| }), | ||
| Operator("prim::GetCounters(str name) -> int", [](const Node* node) { |
nit: GetCounter
torch/csrc/jit/script/init.cpp
Outdated
| } else if ( | ||
| obj.ptr() == py::module::import("torch.jit").attr("annotate").ptr()) { | ||
| return std::make_shared<AnnotateValue>(); | ||
| } else if ( |
Please do not edit this core code with arbitrary primitive builtins. There is already a way to associate existing Python functions with builtins: see _find_builtin below.
torch/csrc/jit/register_prim_ops.cpp
Outdated
| }; | ||
| })}); | ||
| }), | ||
| Operator("prim::BumpCounter(str key, int val) -> ()", [](const Node* node) { |
Can you just use Peter's API for registering custom, traceable operators? This would automatically generate all of the (1) implementation, (2) Python binding, and (3) tracing code for you. It is what it was designed for. As it currently is, this PR reproduces that work.
| return previous; | ||
| } | ||
|
|
||
| JITTimePoint timePoint() { |
I'd just use inline int64_t getTime(), which returns nanoseconds as an int64_t. This means we stick with the type system already available in TorchScript.
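A minimal sketch of what such a helper could look like. The clock choice is an assumption: the PR as quoted uses `std::chrono::high_resolution_clock`, while this sketch uses `std::chrono::steady_clock` because it is monotonic. `elapsedSince` is a hypothetical helper added only for illustration.

```cpp
#include <chrono>
#include <cstdint>

// Returns a timestamp in nanoseconds as a plain int64_t, so no custom
// time-point type has to cross into TorchScript.
// Assumption: steady_clock stands in for whichever clock the PR settles on.
inline int64_t getTime() {
  using namespace std::chrono;
  return duration_cast<nanoseconds>(steady_clock::now().time_since_epoch())
      .count();
}

// Hypothetical helper: nanoseconds elapsed since an earlier getTime() value.
inline int64_t elapsedSince(int64_t start_ns) {
  return getTime() - start_ns;
}
```

With this shape, a duration is plain integer subtraction, and the result can be passed directly to an int64_t-valued stat.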
torch/csrc/jit/script/logging.h
Outdated
|
|
||
| private: | ||
| mutable std::mutex m; | ||
| std::unordered_map<std::string, std::vector<int64_t>> raw_counters; |
why not just:
struct Counter {
int64_t sum;
size_t count;
};
std::unordered_map<std::string, Counter> raw_counters;
You get both sum and avg, and do not ever need to allocate memory.
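The suggestion above can be sketched as follows. `CounterTable` and its method names are hypothetical, and locking is left out to keep the aggregation logic visible. Recording a sample updates two integers in place. The map itself still allocates the first time a new key is seen.

```cpp
#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>

struct Counter {
  int64_t sum = 0;
  size_t count = 0;
};

// Hypothetical wrapper around the suggested map; not thread safe as written.
class CounterTable {
 public:
  void addStatValue(const std::string& name, int64_t val) {
    Counter& c = counters_[name];  // zero-initialized on first use
    c.sum += val;
    c.count += 1;
  }

  // Sum of all recorded values, or 0 for an unknown key.
  int64_t sum(const std::string& name) const {
    auto it = counters_.find(name);
    return it == counters_.end() ? 0 : it->second.sum;
  }

  // Average of all recorded values, or 0.0 for an unknown key.
  double average(const std::string& name) const {
    auto it = counters_.find(name);
    if (it == counters_.end() || it->second.count == 0) {
      return 0.0;
    }
    return static_cast<double>(it->second.sum) /
        static_cast<double>(it->second.count);
  }

 private:
  std::unordered_map<std::string, Counter> counters_;
};
```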
| : AggregationType::SUM; | ||
| vals = &raw_counters.at(name); | ||
| } | ||
| switch (type) { |
race! You released the lock on the hash table. It has resized and then moved the contained vector to another address. *vals is invalid.
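One way to avoid this kind of race is to keep the lock held for the whole lookup and aggregation, so no pointer into the table outlives the critical section. The sketch below assumes the vector-based storage quoted in the diff. `LockedCounters` and `getSum` are hypothetical names. Whichever storage actually moves, reading the vector after dropping the lock would race with concurrent `push_back` calls.

```cpp
#include <cstdint>
#include <mutex>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical, simplified version of a locking counter store.
class LockedCounters {
 public:
  void addStatValue(const std::string& name, int64_t val) {
    std::lock_guard<std::mutex> lk(m_);
    raw_counters_[name].push_back(val);
  }

  // Lookup and aggregation both happen under the same lock, so the
  // vector cannot be modified or reallocated while it is being read.
  int64_t getSum(const std::string& name) const {
    std::lock_guard<std::mutex> lk(m_);
    auto it = raw_counters_.find(name);
    if (it == raw_counters_.end()) {
      return 0;
    }
    return std::accumulate(it->second.begin(), it->second.end(), int64_t{0});
  }

 private:
  mutable std::mutex m_;
  std::unordered_map<std::string, std::vector<int64_t>> raw_counters_;
};
```

The cost is a longer critical section. Storing a running sum and count per key, as suggested earlier in the review, would shorten it to a couple of integer reads.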
This looks amazing! How does this play in the CUDA world, since CUDA is asynchronous?
@sidazhang I believe this would not accurately measure CUDA runtime, for the reason you point out. I think I'll make a follow-up task to address that.
zdevito
left a comment
I think the implementation looks good. I have questions about naming consistency and what the API should be when retrieving a counter value.
| auto key = pop(stack).toString(); | ||
|
|
||
| auto schema = parseSchema("prim::BumpCounter(str key, int val) -> ()"); | ||
| if (jit::tracer::isTracing()) { |
I don't like the fact that we have to manually write tracing code for these. This is because of a bug in the custom API that @smessmer has a PR with a fix out for. Can we make sure to update this when the fix lands?
Can you use torch::RegisterOperators instead of torch::jit::RegisterOperators? It doesn't give you a Stack but concrete types and you don't need to do tracing manually. Or is this a special case that wouldn't work with that?
torch/csrc/jit/script/logging.cpp
Outdated
| void recordDurationSince(const std::string& name, JITTimePoint tp) { | ||
| auto end = std::chrono::high_resolution_clock::now(); | ||
| // Measurement in microseconds. | ||
| auto seconds = std::chrono::duration<double>(end - tp.point).count() * 1e6; |
Why not nanoseconds? Otherwise granularity is at the microsecond level for no good reason.
torch/csrc/jit/register_prim_ops.cpp
Outdated
| auto val = pop(stack).toInt(); | ||
| auto key = pop(stack).toString(); | ||
|
|
||
| auto schema = parseSchema("prim::BumpCounter(str key, int val) -> ()"); |
Why is this called prim::BumpCounter when the C++ function is called addStatValue? Can this be made consistent?
Because I didn't want to have to rebuild the world after changing interned_strings.h :p let me change it now
torch/csrc/jit/script/logging.h
Outdated
| TORCH_API virtual void addStatValue( | ||
| const std::string& stat_name, | ||
| int64_t val) = 0; | ||
| TORCH_API virtual int64_t getCounterValue(const std::string& name) const = 0; |
What is the expected contract here? At this level, what getCounterValue returns is left abstract. In the example it is an aggregated value, but the aggregation type is set per-stat by a subclass. It is weird that only half the API is in the abstract class.
| _(prim, SetAttr) \ | ||
| _(prim, GetAttr) \ | ||
| _(prim, AddStatValue) \ | ||
| _(prim, GetCounter) \ |
This is dead
facebook-github-bot
left a comment
@jamesr66a is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
This defines a generic counters API that users can utilize to provide monitoring functionality in e.g. a production service. We expose both counters for runtime internals as well as a TorchScript API to create user-defined counters. Synopsis of the API:
- `torch/csrc/jit/script/logging.h` specifies the externally-facing API in C++
- `torch/jit/_logging.py` specifies the Python API
We use an interface, `LoggerBase`, to define the interactions between users and a logging backend. Implementing a subclass of `LoggerBase` allows the user to handle these events in a custom way, such as logging into a DB or calling into an infra-specific counters API.
From the frontend perspective, we can create log events in two ways:
1. We provide an `add_stat_value(name, val)` function. This calls into the Logger backend with a key/value pair. For example, we might call `add_stat_value('foo', 1)` to bump an event counter.
2. We provide a `time_point()` function to record a timestamp in nanoseconds. This can be used in conjunction with `add_stat_value` to record runtime wall clock durations.
Examples of frontend usage can be found in `test_jit.py TestLogging`.
We provide a trivial `LockingLogger` implementation as an example and for testing purposes. It is likely not ready for production usage. It demonstrates that a backend implementing the API can do things like specify aggregation types and report these aggregate stats via the `get_counters()` API.
Pull Request resolved: pytorch/pytorch#18235
Differential Revision: D14545060
Pulled By: jamesr66a
fbshipit-source-id: 04099543a1898cfdd411511e46e03d5dce9b4881
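To make the backend/frontend split described in the summary above more concrete, here is a self-contained sketch. `LoggerBase` below is a simplified stand-in that keeps only the `addStatValue` signature quoted in the review, not the real header. `InMemoryLogger` and `recordDuration` are hypothetical names, and `std::chrono::steady_clock` is an assumed clock choice.

```cpp
#include <chrono>
#include <cstdint>
#include <map>
#include <mutex>
#include <string>

// Simplified stand-in for the interface quoted in the review.
class LoggerBase {
 public:
  virtual ~LoggerBase() = default;
  virtual void addStatValue(const std::string& stat_name, int64_t val) = 0;
};

// Hypothetical backend: keeps a running sum per key in memory. A real
// backend might forward to a database or an infra-specific counters API.
class InMemoryLogger : public LoggerBase {
 public:
  void addStatValue(const std::string& stat_name, int64_t val) override {
    std::lock_guard<std::mutex> lk(m_);
    sums_[stat_name] += val;
  }

  int64_t sum(const std::string& stat_name) const {
    std::lock_guard<std::mutex> lk(m_);
    auto it = sums_.find(stat_name);
    return it == sums_.end() ? 0 : it->second;
  }

 private:
  mutable std::mutex m_;
  std::map<std::string, int64_t> sums_;
};

// Runs fn() and reports its wall clock duration in nanoseconds under `key`,
// mirroring the "time point plus add_stat_value" pattern from the summary.
template <typename Fn>
void recordDuration(LoggerBase& logger, const std::string& key, Fn&& fn) {
  auto start = std::chrono::steady_clock::now();
  fn();
  auto end = std::chrono::steady_clock::now();
  logger.addStatValue(
      key,
      std::chrono::duration_cast<std::chrono::nanoseconds>(end - start)
          .count());
}
```

Bumping an event counter is then `logger.addStatValue("foo", 1)`, and timing a region is `recordDuration(logger, "region_ns", [] { /* work */ })`.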
@jamesr66a merged this pull request in 85f3601.
Summary:
Resubmit #20698, which got messed up. The idea is that when PyTorch is used in a custom build environment (e.g. Facebook), it's useful to track usage of various APIs centrally. This PR introduces a simple, very lightweight mechanism to do so: only the first invocation of a trigger point is logged. This is significantly more lightweight than #18235, so we can afford to put logging in e.g. TensorImpl.
Also adds an initial list of trigger points. Trigger points are added in such a way that no static initialization triggers them, i.e. just linking with libtorch.so will not cause any logging.
Further suggestions of what to log are welcome.
Pull Request resolved: #20745
Differential Revision: D15429196
Pulled By: dzhulgakov
fbshipit-source-id: a5e41a709a65b7ebccc6b95f93854e583cf20aca
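The "only the first invocation is logged" behavior can be sketched with a function-local static, whose initializer runs once per call site and is thread safe since C++11. All names here (`usageHandler`, `logUsageOnce`, `LOG_USAGE_ONCE`, `someApiEntryPoint`) are hypothetical. This illustrates the technique and is not the code from the referenced PR.

```cpp
#include <functional>
#include <string>

using UsageHandler = std::function<void(const std::string&)>;

// Hypothetical pluggable sink; the default handler does nothing.
inline UsageHandler& usageHandler() {
  static UsageHandler handler = [](const std::string&) {};
  return handler;
}

// Forwards the event to the handler; the return value only exists so the
// call can be used as a static initializer.
inline bool logUsageOnce(const std::string& event) {
  usageHandler()(event);
  return true;
}

// The static local is initialized the first time control reaches it,
// so the handler fires at most once per call site, and never during
// static initialization of the library.
#define LOG_USAGE_ONCE(event)                    \
  do {                                           \
    static bool logged_ = logUsageOnce(event);   \
    (void)logged_;                               \
  } while (0)

// Example trigger point.
inline void someApiEntryPoint() {
  LOG_USAGE_ONCE("example.api.entry");
}
```

After the first call, the cost at a trigger point is roughly a guard check on an already-initialized static, which is why this style is cheap enough for hot code.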
apaszke
left a comment
Is this actually used by anyone? It seems that we're starting to get a lot of requests for logging in the library, and maybe we should take a more principled approach instead of developing 3 different interfaces.
| previous = global_logger.load(); | ||
| } | ||
| return previous; | ||
| } |
This leaks the previous logger
| } | ||
|
|
||
| JITTimePoint timePoint() { | ||
| return JITTimePoint{std::chrono::high_resolution_clock::now()}; |
std::chrono is extremely slow on many platforms. We have better timing code in the autograd benchmark.
| public: | ||
| TORCH_API virtual void addStatValue( | ||
| const std::string& stat_name, | ||
| int64_t val) = 0; |
Why don't we make this into a generic API that takes an IValue instead?