-
Notifications
You must be signed in to change notification settings - Fork 75.3k
Expand file tree
/
Copy pathnumerics.py
More file actions
131 lines (110 loc) · 5.03 KB
/
numerics.py
File metadata and controls
131 lines (110 loc) · 5.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Connects all half, float and double tensors to CheckNumericsOp."""
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@tf_export(v1=["debugging.assert_all_finite", "verify_tensor_all_finite"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("verify_tensor_all_finite")
def verify_tensor_all_finite(t=None, msg=None, name=None, x=None, message=None):
  """Assert that the tensor does not contain any NaN's or Inf's.

  Each pair of aliased arguments (`t`/`x` and `msg`/`message`) is resolved to a
  single value with `deprecation.deprecated_argument_lookup`, and the actual
  check is delegated to `verify_tensor_all_finite_v2`.

  Args:
    t: Tensor to check.
    msg: Message to log on failure.
    name: A name for this operation (optional).
    x: Alias for t.
    message: Alias for msg.

  Returns:
    Same tensor as `t`.
  """
  # Collapse each alias pair into one value before delegating.
  tensor = deprecation.deprecated_argument_lookup("x", x, "t", t)
  failure_message = deprecation.deprecated_argument_lookup(
      "message", message, "msg", msg)
  return verify_tensor_all_finite_v2(tensor, failure_message, name)
@tf_export("debugging.assert_all_finite", v1=[])
@dispatch.add_dispatch_support
def verify_tensor_all_finite_v2(x, message, name=None):
  """Assert that the tensor does not contain any NaN's or Inf's.

  >>> @tf.function
  ... def f(x):
  ...   x = tf.debugging.assert_all_finite(x, 'Input x must be all finite')
  ...   return x + 1

  >>> f(tf.constant([np.inf, 1, 2]))
  Traceback (most recent call last):
     ...
  InvalidArgumentError: ...

  The input is converted to a tensor, a `check_numerics` op is created for it
  (colocated with the tensor), and the tensor is returned with a control
  dependency on that check.

  Args:
    x: Tensor to check.
    message: Message to log on failure.
    name: A name for this operation (optional).

  Returns:
    Same tensor as `x`.
  """
  with ops.name_scope(name, "VerifyFinite", [x]):
    tensor = ops.convert_to_tensor(x, name="x")
    with ops.colocate_with(tensor):
      finite_check = array_ops.check_numerics(tensor, message=message)
      # Tie the returned tensor to the check so that consuming the result
      # forces the check to run.
      return control_flow_ops.with_dependencies([finite_check], tensor)
@tf_export(v1=["add_check_numerics_ops"])
def add_check_numerics_ops():
  """Connect a `tf.debugging.check_numerics` to every floating point tensor.

  A `check_numerics` op is created for every `half`, `float`, or `double`
  output tensor of every op in the current default graph. The checks are
  chained with control dependencies in graph-insertion order, so the checks on
  an op's (`half`, `float`, or `double`) inputs are guaranteed to run before
  the checks on any of its outputs.

  Note: This API is not compatible with the use of `tf.cond` or
  `tf.while_loop`, and will raise a `ValueError` if you attempt to call it
  in such a graph.

  Returns:
    A `group` op depending on all `check_numerics` ops added.

  Raises:
    ValueError: If the graph contains any numeric operations in a control flow
      structure.
    RuntimeError: If called with eager execution enabled.

  @compatibility(eager)
  Not compatible with eager execution. To check for `Inf`s and `NaN`s under
  eager execution, call `tf.debugging.enable_check_numerics()` once before
  executing the checked operations.
  @end_compatibility
  """
  if context.executing_eagerly():
    raise RuntimeError(
        "add_check_numerics_ops() is not compatible with eager execution. "
        "To check for Inf's and NaN's under eager execution, call "
        "tf.debugging.enable_check_numerics() once before executing the "
        "checked operations.")

  checked_dtypes = (dtypes.float16, dtypes.float32, dtypes.float64)

  # Holds at most one element: the most recently created check op. Each new
  # check depends on the previous one, which forms a single chain.
  latest_check = []

  # The chain is only meaningful because get_operations() lists ops in the
  # order they were added: an op can only be added once its inputs exist, so a
  # tensor's producer always precedes its consumers in this list.
  for graph_op in ops.get_default_graph().get_operations():
    for tensor in graph_op.outputs:
      if tensor.dtype not in checked_dtypes:
        continue
      if graph_op._get_control_flow_context() is not None:  # pylint: disable=protected-access
        raise ValueError("`tf.add_check_numerics_ops() is not compatible "
                         "with TensorFlow control flow operations such as "
                         "`tf.cond()` or `tf.while_loop()`.")
      label = f"{graph_op.name}:{tensor.value_index}"
      with ops.control_dependencies(latest_check):
        latest_check = [array_ops.check_numerics(tensor, message=label)]

  return control_flow_ops.group(*latest_check)