-
Notifications
You must be signed in to change notification settings - Fork 75.3k
Expand file tree
/
Copy pathstate_ops.py
More file actions
1043 lines (831 loc) · 39.4 KB
/
state_ops.py
File metadata and controls
1043 lines (831 loc) · 39.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Variables.
See the [Variables](https://www.tensorflow.org/guide/variables) guide.
"""
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import state_grad # pylint: disable=unused-import
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_state_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util import deprecation
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access,g-doc-return-or-yield,g-doc-args
def variable_op(shape, dtype, name="Variable", set_shape=True, container="",
                shared_name=""):
  """Creates a legacy (reference-typed) variable op.

  Deprecated. Use `variable_op_v2` instead.

  Args:
    shape: The shape of the tensor managed by this variable. When `set_shape`
      is False this value is ignored and an unknown shape is used instead.
    dtype: The underlying type of the tensor values.
    name: Optional name to use for the variable op.
    set_shape: If True (the default), the static shape of the returned tensor
      is set to `shape`. If False, the op is created with an unknown shape.
    container: An optional string. Defaults to "". Passed through to the
      underlying `Variable` op.
    shared_name: An optional string. Defaults to "". Passed through to the
      underlying `Variable` op.

  Returns:
    The tensor produced by `gen_state_ops.variable`.
  """
  if not set_shape:
    shape = tensor_shape.unknown_shape()
  ret = gen_state_ops.variable(shape=shape, dtype=dtype, name=name,
                               container=container, shared_name=shared_name)
  # TODO(mrry): Move this to where it is used, so we can get rid of this op
  # wrapper?
  if set_shape:
    ret.set_shape(shape)
  return ret
def variable_op_v2(shape, dtype, name="Variable", container="", shared_name=""):
  """Builds a `VariableV2` operation and returns its output tensor.

  See also variables.Variable.

  Args:
    shape: Shape of the tensor held by this variable.
    dtype: Element type of the tensor held by this variable.
    name: Optional name for the variable op. Defaults to "Variable".
    container: Optional string, default "". When non-empty, the variable is
      placed in this container; otherwise a default container is used.
    shared_name: Optional string, default "". When non-empty, the variable is
      named in the given container with this shared_name; otherwise the node
      name is used instead.

  Returns:
    A variable tensor.
  """
  op_attrs = {
      "shape": shape,
      "dtype": dtype,
      "name": name,
      "container": container,
      "shared_name": shared_name,
  }
  return gen_state_ops.variable_v2(**op_attrs)
def init_variable(v, init, name="init"):
  """Builds the op that assigns an initial value to variable `v`.

  Behavior:
    * If `init` is callable, it is invoked as
      `init(v.get_shape().as_list(), v.dtype.base_dtype)` and `v` is assigned
      the result.
    * Otherwise `init` is converted to a Tensor and assigned to `v` directly.

  Args:
    v: Variable to initialize.
    init: Either a value to assign to `v` (a Tensor, or anything convertible
      to one such as a numpy array), or an initializer: a callable taking
      `(shape, dtype)` and returning the value `v` should be set to.
    name: Optional name for the op.

  Returns:
    The operation that initializes v.
  """
  with ops.name_scope(None, v.op.name + "/", [v, init]):
    with ops.name_scope(name) as scope:
      with ops.colocate_with(v):
        if not callable(init):
          initial_value = ops.convert_to_tensor(init, name="init")
        else:
          # An initializer needs a concrete shape to build its output.
          assert v.get_shape().is_fully_defined(), "Variable shape unknown."
          # TODO(mrry): Convert to v.shape when the property and
          # accessor are reconciled (and all initializers support
          # tf.TensorShape objects).
          static_dims = v.get_shape().as_list()
          initial_value = ops.convert_to_tensor(
              init(static_dims, v.dtype.base_dtype), name="value")
        return gen_state_ops.assign(v, initial_value, name=scope)
def is_variable_initialized(ref, name=None):
  """Reports whether a variable has been initialized.

  Produces a boolean scalar that is true once the variable holds a value.

  Args:
    ref: A mutable `Tensor`. Should be from a `Variable` node and may still
      be uninitialized. Non-reference (resource) variables are also accepted;
      for those the check is delegated to `ref.is_initialized`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `bool`.
  """
  if not ref.dtype._is_ref_dtype:
    # Resource variables answer the question themselves.
    return ref.is_initialized(name=name)
  return gen_state_ops.is_variable_initialized(ref=ref, name=name)
@tf_export(v1=["assign_sub"])
def assign_sub(ref, value, use_locking=None, name=None):
  """Update `ref` by subtracting `value` from it.

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.
  Unlike `tf.math.subtract`, this op does not broadcast. `ref` and `value`
  must have the same shape.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`,
      `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`. Should be
      from a `Variable` node.
    value: A `Tensor`. Must have the same shape and dtype as `ref`. The value to
      be subtracted from the variable.
    use_locking: An optional `bool`. Defaults to `False`. If True, the
      subtraction will be protected by a lock; otherwise the behavior is
      undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`. Returned as a convenience for operations that want
    to use the new value after the variable has been updated.

  @compatibility(TF2)
  `tf.compat.v1.assign_sub` is mostly compatible with eager
  execution and `tf.function`.

  To switch to the native TF2 style, one could use method 'assign_sub' of
  `tf.Variable`:

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `ref`                 | `self`          | In `assign_sub()` method   |
  | `value`               | `value`         | In `assign_sub()` method   |
  | `use_locking`         | `use_locking`   | In `assign_sub()` method   |
  | `name`                | `name`          | In `assign_sub()` method   |
  | -                     | `read_value`    | Set to True to replicate   |
  :                       :                 : behavior (True is default) :

  #### Before & After Usage Example

  Before:

  >>> with tf.Graph().as_default():
  ...   with tf.compat.v1.Session() as sess:
  ...     a = tf.compat.v1.Variable(1, dtype=tf.int64)
  ...     sess.run(a.initializer)
  ...     update_op = tf.compat.v1.assign_sub(a, 1)
  ...     res_a = sess.run(update_op)
  ...     res_a
  0

  After:

  >>> b = tf.Variable(1, dtype=tf.int64)
  >>> res_b = b.assign_sub(1)
  >>> res_b.numpy()
  0

  @end_compatibility
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign_sub(
        ref, value, use_locking=use_locking, name=name)
  # Resource variables: forward `use_locking` and `name` so the arguments
  # documented above (and in the TF2 mapping table) are honored on this path
  # as well instead of being silently dropped.
  return ref.assign_sub(value, use_locking=use_locking, name=name)
@tf_export(v1=["assign_add"])
def assign_add(ref, value, use_locking=None, name=None):
  """Update `ref` by adding `value` to it.

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.
  Unlike `tf.math.add`, this op does not broadcast. `ref` and `value` must have
  the same shape.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`,
      `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`. Should be
      from a `Variable` node.
    value: A `Tensor`. Must have the same shape and dtype as `ref`. The value to
      be added to the variable.
    use_locking: An optional `bool`. Defaults to `False`. If True, the addition
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`. Returned as a convenience for operations that want
    to use the new value after the variable has been updated.

  @compatibility(TF2)
  `tf.compat.v1.assign_add` is mostly compatible with eager
  execution and `tf.function`.

  To switch to the native TF2 style, one could use method 'assign_add' of
  `tf.Variable`:

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `ref`                 | `self`          | In `assign_add()` method   |
  | `value`               | `value`         | In `assign_add()` method   |
  | `use_locking`         | `use_locking`   | In `assign_add()` method   |
  | `name`                | `name`          | In `assign_add()` method   |
  | -                     | `read_value`    | Set to True to replicate   |
  :                       :                 : behavior (True is default) :

  #### Before & After Usage Example

  Before:

  >>> with tf.Graph().as_default():
  ...   with tf.compat.v1.Session() as sess:
  ...     a = tf.compat.v1.Variable(0, dtype=tf.int64)
  ...     sess.run(a.initializer)
  ...     update_op = tf.compat.v1.assign_add(a, 1)
  ...     res_a = sess.run(update_op)
  ...     res_a
  1

  After:

  >>> b = tf.Variable(0, dtype=tf.int64)
  >>> res_b = b.assign_add(1)
  >>> res_b.numpy()
  1

  @end_compatibility
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign_add(
        ref, value, use_locking=use_locking, name=name)
  # Resource variables: forward `use_locking` and `name` so the arguments
  # documented above (and in the TF2 mapping table) are honored on this path
  # as well instead of being silently dropped.
  return ref.assign_add(value, use_locking=use_locking, name=name)
@tf_export(v1=["assign"])
def assign(ref, value, validate_shape=None, use_locking=None, name=None):
  """Update `ref` by assigning `value` to it.

  This operation outputs a Tensor that holds the new value of `ref` after
  the value has been assigned. This makes it easier to chain operations that
  need to use the new value.

  Args:
    ref: A mutable `Tensor`. Should be from a `Variable` node. May be
      uninitialized.
    value: A `Tensor`. Must have the same shape and dtype as `ref`. The value to
      be assigned to the variable.
    validate_shape: An optional `bool`. Defaults to `True`. If true, the
      operation will validate that the shape of 'value' matches the shape of the
      Tensor being assigned to.  If false, 'ref' will take on the shape of
      'value'. Only honored for reference (non-resource) variables.
    use_locking: An optional `bool`. Defaults to `True`. If True, the assignment
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` that will hold the new value of `ref` after
      the assignment has completed.

  @compatibility(TF2)
  `tf.compat.v1.assign` is mostly compatible with eager
  execution and `tf.function`. However, argument 'validate_shape' will be
  ignored. To avoid shape validation, set 'shape' to tf.TensorShape(None) when
  constructing the variable:

  >>> import tensorflow as tf
  >>> a = tf.Variable([1], shape=tf.TensorShape(None))
  >>> tf.compat.v1.assign(a, [2,3])

  To switch to the native TF2 style, one could use method 'assign' of
  `tf.Variable`:

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `ref`                 | `self`          | In `assign()` method       |
  | `value`               | `value`         | In `assign()` method       |
  | `validate_shape`      | Not supported   | Specify `shape` in the     |
  :                       :                 : constructor to replicate   :
  :                       :                 : behavior                   :
  | `use_locking`         | `use_locking`   | In `assign()` method       |
  | `name`                | `name`          | In `assign()` method       |
  | -                     | `read_value`    | Set to True to replicate   |
  :                       :                 : behavior (True is default) :

  #### Before & After Usage Example

  Before:

  >>> with tf.Graph().as_default():
  ...   with tf.compat.v1.Session() as sess:
  ...     a = tf.compat.v1.Variable(0, dtype=tf.int64)
  ...     sess.run(a.initializer)
  ...     update_op = tf.compat.v1.assign(a, 2)
  ...     res_a = sess.run(update_op)
  ...     res_a
  2

  After:

  >>> b = tf.Variable(0, dtype=tf.int64)
  >>> res_b = b.assign(2)
  >>> res_b.numpy()
  2

  @end_compatibility
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign(
        ref, value, use_locking=use_locking, name=name,
        validate_shape=validate_shape)
  # Resource variables: `validate_shape` has no equivalent here (see the
  # compatibility notes above), but `use_locking` is forwarded along with
  # `name` so the documented argument is not silently dropped.
  return ref.assign(value, use_locking=use_locking, name=name)
@tf_export(v1=["count_up_to"])
@deprecated(None, "Prefer Dataset.range instead.")
def count_up_to(ref, limit, name=None):
  r"""Increments 'ref' until it reaches 'limit'.

  Args:
    ref: A Variable of type `int32` or `int64`, expected to come from a
      scalar `Variable` node.
    limit: An `int`. If incrementing `ref` would take it above `limit`, an
      'OutOfRange' error is generated instead.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as `ref`, holding a copy of the input taken
    before the increment. As long as nothing else modifies the input, every
    value produced is distinct.
  """
  if not ref.dtype._is_ref_dtype:
    # Resource variables go through the handle-based kernel.
    return gen_state_ops.resource_count_up_to(
        ref.handle, limit, T=ref.dtype, name=name)
  return gen_state_ops.count_up_to(ref, limit=limit, name=name)
@tf_export(v1=["scatter_update"])
def scatter_update(ref, indices, updates, use_locking=True, name=None):
  # pylint: disable=line-too-long
  r"""Applies sparse updates to a variable reference.

  Computes:

  ```python
      # Scalar indices
      ref[indices, ...] = updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] = updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]
  ```

  The result is `ref` after the update, so downstream operations can consume
  the updated value directly.

  When `indices` contains duplicates, the same entry of `ref` is written more
  than once and the order in which those writes land is undefined.

  Requires `updates.shape = indices.shape + ref.shape[1:]`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterUpdate.png" alt>
  </div>

  Args:
    ref: A `Variable`.
    indices: A `Tensor` of type `int32` or `int64`; indices into the first
      dimension of `ref`.
    updates: A `Tensor` with the same type as `ref`; the new values to store
      in `ref`.
    use_locking: An optional `bool`, default `True`. If True, the assignment
      is protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`, returned as a convenience for operations that want to use
    the updated values after the update is done.
  """
  if not ref.dtype._is_ref_dtype:
    var_handle = ref.handle
    new_values = ops.convert_to_tensor(updates, ref.dtype)
    scatter_op = gen_resource_variable_ops.resource_scatter_update(
        var_handle, indices, new_values, name=name)
    return ref._lazy_read(scatter_op)  # pylint: disable=protected-access
  return gen_state_ops.scatter_update(
      ref, indices, updates, use_locking=use_locking, name=name)
@tf_export(v1=["scatter_nd_update"])
def scatter_nd_update(ref, indices, updates, use_locking=True, name=None):
  r"""Applies sparse `updates` to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
  ```

  For example, say we want to update 4 scattered elements of a rank-1 tensor
  with 8 elements. In Python, that update would look like this:

  ```python
      ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
      indices = tf.constant([[4], [3], [1] ,[7]])
      updates = tf.constant([9, 10, 11, 12])
      update = tf.compat.v1.scatter_nd_update(ref, indices, updates)
      with tf.compat.v1.Session() as sess:
        print(sess.run(update))
  ```

  The resulting update to ref would look like this:

      [1, 11, 3, 10, 9, 6, 7, 12]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A Variable.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to store in ref.
    use_locking: An optional `bool`. Defaults to `True`. If True, the
      assignment will be protected by a lock; otherwise the behavior is
      undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    The value of the variable after the update.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_nd_update(
        ref, indices, updates, use_locking, name)
  # Forward `use_locking` so the documented argument is honored for resource
  # variables as well as reference variables.
  return ref._lazy_read(gen_state_ops.resource_scatter_nd_update(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      use_locking=use_locking, name=name))
@tf_export(v1=["scatter_add"])
def scatter_add(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Adds sparse updates to the variable referenced by `ref`.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] += updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] += updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] += updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.
  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions add.

  Requires `updates.shape = indices.shape + ref.shape[1:]`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
  </div>

  Args:
    ref: A `Variable`.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of values to add to `ref`.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the assignment will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`.  Returned as a convenience for operations that want
    to use the updated values after the update is done.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_add(ref, indices, updates,
                                     use_locking=use_locking, name=name)
  # NOTE: the resource-variable op is called without `use_locking`.
  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_add(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))
@tf_export(v1=["scatter_nd_add"])
def scatter_nd_add(ref, indices, updates, use_locking=False, name=None):
  r"""Applies sparse addition to individual values or slices in a Variable.

  With `ref` of rank `P` and `indices` of rank `Q`, `indices` must be an
  integer tensor of shape `[d_0, ..., d_{Q-2}, K]` with `0 < K <= P`, holding
  indices into `ref`.

  Its innermost dimension (length `K`) addresses individual elements when
  `K = P`, or slices along the `K`th dimension of `ref` when `K < P`.

  `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]
  ```

  Example: adding 4 scattered elements into a rank-1 tensor of 8 elements:

  ```python
      ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
      indices = tf.constant([[4], [3], [1], [7]])
      updates = tf.constant([9, 10, 11, 12])
      add = tf.compat.v1.scatter_nd_add(ref, indices, updates)
      with tf.compat.v1.Session() as sess:
        print(sess.run(add))
  ```

  which leaves `ref` as:

      [1, 13, 3, 14, 14, 6, 7, 20]

  `tf.scatter_nd` describes in more detail how updates apply to slices.

  Args:
    ref: A mutable `Tensor` of one of: `float32`, `float64`, `int32`, `uint8`,
      `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`,
      `bfloat16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`. Should
      come from a Variable node.
    indices: A `Tensor` of type `int32` or `int64`; indices into ref.
    updates: A `Tensor` with the same type as `ref`; the values to add to ref.
    use_locking: An optional `bool`, default `False`. If True, the update is
      protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor` with the same type as `ref`.
  """
  if not ref.dtype._is_ref_dtype:
    var_handle = ref.handle
    addends = ops.convert_to_tensor(updates, ref.dtype)
    scatter_op = gen_state_ops.resource_scatter_nd_add(
        var_handle, indices, addends, name=name)
    return ref._lazy_read(scatter_op)  # pylint: disable=protected-access
  return gen_state_ops.scatter_nd_add(
      ref, indices, updates, use_locking, name)
@tf_export(v1=["scatter_sub"])
def scatter_sub(ref, indices, updates, use_locking=False, name=None):
  r"""Subtracts sparse updates from a variable reference.

  ```python
      # Scalar indices
      ref[indices, ...] -= updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] -= updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...]
  ```

  The result is `ref` after the update, so downstream operations can consume
  the updated value directly.

  Duplicate indices are handled correctly: when several `indices` point at the
  same location, their (negated) contributions add.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or
  `updates.shape = []`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%"
       src="https://www.tensorflow.org/images/ScatterSub.png" alt>
  </div>

  Args:
    ref: A mutable `Tensor` of one of: `float32`, `float64`, `int32`, `uint8`,
      `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`,
      `bfloat16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`. Should
      come from a `Variable` node.
    indices: A `Tensor` of type `int32` or `int64`; indices into the first
      dimension of `ref`.
    updates: A `Tensor` with the same type as `ref`; the values to subtract
      from `ref`.
    use_locking: An optional `bool`, default `False`. If True, the subtraction
      is protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor` with the same type as `ref`.
  """
  if not ref.dtype._is_ref_dtype:
    var_handle = ref.handle
    subtrahends = ops.convert_to_tensor(updates, ref.dtype)
    scatter_op = gen_resource_variable_ops.resource_scatter_sub(
        var_handle, indices, subtrahends, name=name)
    return ref._lazy_read(scatter_op)  # pylint: disable=protected-access
  return gen_state_ops.scatter_sub(
      ref, indices, updates, use_locking=use_locking, name=name)
@tf_export(v1=["scatter_nd_sub"])
def scatter_nd_sub(ref, indices, updates, use_locking=False, name=None):
  r"""Applies sparse subtraction to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]
  ```

  For example, say we want to subtract 4 scattered elements from a rank-1 tensor
  with 8 elements. In Python, that update would look like this:

  ```python
      ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
      indices = tf.constant([[4], [3], [1] ,[7]])
      updates = tf.constant([9, 10, 11, 12])
      op = tf.compat.v1.scatter_nd_sub(ref, indices, updates)
      with tf.compat.v1.Session() as sess:
        print(sess.run(op))
  ```

  The resulting update to ref would look like this
  (e.g. `ref[4] = 5 - 9 = -4`):

      [1, -9, 3, -6, -4, 6, 7, -4]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. A mutable Tensor. Should be from a Variable node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of values to subtract from ref.
    use_locking: An optional `bool`. Defaults to `False`. If True, the
      subtraction will be protected by a lock; otherwise the behavior is
      undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_nd_sub(
        ref, indices, updates, use_locking, name)
  # NOTE: the resource-variable op is called without `use_locking`.
  return ref._lazy_read(gen_state_ops.resource_scatter_nd_sub(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))
@tf_export(v1=["scatter_mul"])
def scatter_mul(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Multiplies sparse updates into a variable reference.

  Computes:

  ```python
      # Scalar indices
      ref[indices, ...] *= updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] *= updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...]
  ```

  The result is `ref` after the update, so downstream operations can consume
  the updated value directly.

  Duplicate indices are handled correctly: when several `indices` point at the
  same location, their contributions multiply.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape =
  []`.

  Args:
    ref: A mutable `Tensor` of one of: `float32`, `float64`, `int32`, `uint8`,
      `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`,
      `bfloat16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`. Should
      come from a `Variable` node.
    indices: A `Tensor` of type `int32` or `int64`; indices into the first
      dimension of `ref`.
    updates: A `Tensor` with the same type as `ref`; the factors to multiply
      into `ref`.
    use_locking: An optional `bool`, default `False`. If True, the operation is
      protected by a lock; otherwise the behavior is undefined, but may exhibit
      less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor` with the same type as `ref`.
  """
  if not ref.dtype._is_ref_dtype:
    var_handle = ref.handle
    factors = ops.convert_to_tensor(updates, ref.dtype)
    scatter_op = gen_resource_variable_ops.resource_scatter_mul(
        var_handle, indices, factors, name=name)
    return ref._lazy_read(scatter_op)  # pylint: disable=protected-access
  return gen_state_ops.scatter_mul(
      ref, indices, updates, use_locking=use_locking, name=name)
@tf_export(v1=["scatter_div"])
def scatter_div(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Divides a variable reference by sparse updates.

  Computes:

  ```python
      # Scalar indices
      ref[indices, ...] /= updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] /= updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...]
  ```

  The result is `ref` after the update, so downstream operations can consume
  the updated value directly.

  Duplicate indices are handled correctly: when several `indices` point at the
  same location, their contributions divide.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape =
  []`.

  Args:
    ref: A mutable `Tensor` of one of: `float32`, `float64`, `int32`, `uint8`,
      `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`,
      `bfloat16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`. Should
      come from a `Variable` node.
    indices: A `Tensor` of type `int32` or `int64`; indices into the first
      dimension of `ref`.
    updates: A `Tensor` with the same type as `ref`; the divisors applied to
      `ref`.
    use_locking: An optional `bool`, default `False`. If True, the operation is
      protected by a lock; otherwise the behavior is undefined, but may exhibit
      less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor` with the same type as `ref`.
  """
  if not ref.dtype._is_ref_dtype:
    var_handle = ref.handle
    divisors = ops.convert_to_tensor(updates, ref.dtype)
    scatter_op = gen_resource_variable_ops.resource_scatter_div(
        var_handle, indices, divisors, name=name)
    return ref._lazy_read(scatter_op)  # pylint: disable=protected-access
  return gen_state_ops.scatter_div(
      ref, indices, updates, use_locking=use_locking, name=name)
@tf_export(v1=["scatter_max"])
def scatter_max(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Reduces sparse updates into a variable reference using the `max` operation.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] = max(ref[indices, ...], updates[...])

      # Vector indices (for each i)
      ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...])

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...],
      updates[i, ..., j, ...])
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.
  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions combine, so the location ends up
  holding the maximum of its previous value and every update aimed at it.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape =
  []`.

  The figure below illustrates the scatter index mapping using the analogous
  `scatter_add` operation; `scatter_max` uses the same mapping but combines
  values with `max` instead of addition.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png"
  alt>
  </div>

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `half`,
      `bfloat16`, `float32`, `float64`, `int32`, `int64`. Should be from a
      `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A
      tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated
      values to reduce into `ref`. On the resource-variable path it is
      converted to `ref.dtype` before the update runs.
    use_locking: An optional `bool`. Defaults to `False`. If True, the update
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention. Only forwarded for ref-typed variables; it is
      not passed to the resource-variable op.
    name: A name for the operation (optional).

  Returns:
    For a ref-typed `ref`: a mutable `Tensor` with the same type as `ref`.
    For a resource variable: the result of `ref._lazy_read` applied to the
    scatter op (presumably a read of the updated value -- confirm against
    the variable implementation).
  """
  if ref.dtype._is_ref_dtype:
    # Legacy ref-typed variable: dispatch to the ref kernel, which is the only
    # path that receives `use_locking`.
    return gen_state_ops.scatter_max(ref, indices, updates,
                                     use_locking=use_locking, name=name)
  # Resource variable: run the resource scatter op on the handle and wrap it
  # with `_lazy_read` so callers get a value tied to the update.
  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_max(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))
@tf_export(v1=["scatter_min"])
def scatter_min(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Reduces sparse updates into a variable reference using the `min` operation.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] = min(ref[indices, ...], updates[...])

      # Vector indices (for each i)
      ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...])

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...],
      updates[i, ..., j, ...])
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.
  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions combine, so the location ends up
  holding the minimum of its previous value and every update aimed at it.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape =
  []`.

  The figure below illustrates the scatter index mapping using the analogous
  `scatter_add` operation; `scatter_min` uses the same mapping but combines
  values with `min` instead of addition.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png"
  alt>
  </div>

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `half`,
      `bfloat16`, `float32`, `float64`, `int32`, `int64`. Should be from a
      `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A
      tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated
      values to reduce into `ref`. On the resource-variable path it is
      converted to `ref.dtype` before the update runs.
    use_locking: An optional `bool`. Defaults to `False`. If True, the update
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention. Only forwarded for ref-typed variables; it is
      not passed to the resource-variable op.
    name: A name for the operation (optional).

  Returns:
    For a ref-typed `ref`: a mutable `Tensor` with the same type as `ref`.
    For a resource variable: the result of `ref._lazy_read` applied to the
    scatter op (presumably a read of the updated value -- confirm against
    the variable implementation).
  """
  if ref.dtype._is_ref_dtype:
    # Legacy ref-typed variable: dispatch to the ref kernel, which is the only
    # path that receives `use_locking`.
    return gen_state_ops.scatter_min(ref, indices, updates,
                                     use_locking=use_locking, name=name)
  # Resource variable: run the resource scatter op on the handle and wrap it
  # with `_lazy_read` so callers get a value tied to the update.
  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_min(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))
@tf_export(v1=["batch_scatter_update"])
@deprecation.deprecated(
"2018-11-29", "Use the batch_scatter_update method of Variable instead.")
def batch_scatter_update(ref, indices, updates, use_locking=True, name=None):
"""Generalization of `tf.compat.v1.scatter_update` to axis different than 0.
Analogous to `batch_gather`. This assumes that `ref`, `indices` and `updates`
have a series of leading dimensions that are the same for all of them, and the
updates are performed on the last dimension of indices. In other words, the
dimensions should be the following:
`num_prefix_dims = indices.ndims - 1`
`batch_dim = num_prefix_dims + 1`
`updates.shape = indices.shape + var.shape[batch_dim:]`
where
`updates.shape[:num_prefix_dims]`
`== indices.shape[:num_prefix_dims]`
`== var.shape[:num_prefix_dims]`
And the operation performed can be expressed as:
`var[i_1, ..., i_n, indices[i_1, ..., i_n, j]] = updates[i_1, ..., i_n, j]`
When indices is a 1D tensor, this operation is equivalent to
`tf.compat.v1.scatter_update`.
To avoid this operation there would be 2 alternatives:
1) Reshaping the variable by merging the first `ndims` dimensions. However,
this is not possible because `tf.reshape` returns a Tensor, which we
cannot use `tf.compat.v1.scatter_update` on.
2) Looping over the first `ndims` of the variable and using
`tf.compat.v1.scatter_update` on the subtensors that result of slicing the
first
dimension. This is a valid option for `ndims = 1`, but less efficient than
this implementation.
See also `tf.compat.v1.scatter_update` and `tf.compat.v1.scatter_nd_update`.
Args:
ref: `Variable` to scatter onto.
indices: Tensor containing indices as described above.
updates: Tensor of updates to apply to `ref`.
use_locking: Boolean indicating whether to lock the writing operation.
name: Optional scope name string.
Returns:
Ref to `variable` after it has been modified.
Raises:
ValueError: If the initial `ndims` of `ref`, `indices`, and `updates` are
not the same.
"""
with ops.name_scope(name):