@@ -44,6 +44,7 @@ def __new__(
4444 ) -> _Feature :
4545 tensor = cls ._to_tensor (data , dtype = dtype , device = device , requires_grad = requires_grad )
4646 output = tensor .as_subclass (_Feature )
47+ output ._tensor = tensor
4748 return output
4849
4950 @classmethod
@@ -108,7 +109,9 @@ def __torch_function__(
108109 # Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`,
109110 # will retain the input type. Thus, we need to unwrap here.
110111 if isinstance (output , cls ):
111- return output .as_subclass (torch .Tensor )
112+ tensor = output .as_subclass (torch .Tensor )
113+ output ._tensor = tensor
114+ return tensor
112115
113116 return output
114117
@@ -134,23 +137,19 @@ def _F(self) -> ModuleType:
134137 # this way we return the result without passing into __torch_function__
135138 @property
136139 def shape (self ) -> _size : # type: ignore[override]
137- with DisableTorchFunction ():
138- return super ().shape
140+ return self ._tensor .shape
139141
140142 @property
141143 def ndim (self ) -> int : # type: ignore[override]
142- with DisableTorchFunction ():
143- return super ().ndim
144+ return self ._tensor .ndim
144145
145146 @property
146147 def device (self , * args : Any , ** kwargs : Any ) -> _device : # type: ignore[override]
147- with DisableTorchFunction ():
148- return super ().device
148+ return self ._tensor .device
149149
150150 @property
151151 def dtype (self ) -> _dtype : # type: ignore[override]
152- with DisableTorchFunction ():
153- return super ().dtype
152+ return self ._tensor .dtype
154153
155154 def horizontal_flip (self ) -> _Feature :
156155 return self
0 commit comments