-
-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Expand file tree
/
Copy path_thin_plate_splines.py
More file actions
251 lines (201 loc) · 8.01 KB
/
_thin_plate_splines.py
File metadata and controls
251 lines (201 loc) · 8.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
from typing import Self
import numpy as np
from scipy.spatial import distance_matrix
from _skimage2._shared.utils import check_nD, _deprecate_estimate, FailedEstimation
class ThinPlateSplineTransform:
    """Thin-plate spline transformation.

    Given two matching sets of points, source and destination, this class
    estimates the thin-plate spline (TPS) transformation which transforms
    each point in source into its destination counterpart.

    Attributes
    ----------
    src : ndarray of shape (N, 2) or None
        Coordinates of control points in source image. ``None`` until an
        estimation has been attempted.

    References
    ----------
    .. [1] Bookstein, Fred L. "Principal warps: Thin-plate splines and the
           decomposition of deformations," IEEE Transactions on pattern analysis
           and machine intelligence 11.6 (1989): 567–585.
           DOI:`10.1109/34.24792`
           https://user.engineering.uiowa.edu/~aip/papers/bookstein-89.pdf

    Examples
    --------
    >>> import skimage as ski

    Define source and destination control points such that they simulate
    rotating by 90 degrees and generate a meshgrid from them:

    >>> src = np.array([[0, 0], [0, 5], [5, 5], [5, 0]])
    >>> dst = np.array([[5, 0], [0, 0], [0, 5], [5, 5]])

    Estimate the transformation:

    >>> tps = ski.transform.ThinPlateSplineTransform.from_estimate(src, dst)

    Applying the transformation to `src` approximates `dst`:

    >>> np.round(tps(src), 4)  # doctest: +FLOAT_CMP
    array([[5., 0.],
           [0., 0.],
           [0., 5.],
           [5., 5.]])

    Create a meshgrid to apply the transformation to:

    >>> grid = np.meshgrid(np.arange(5), np.arange(5))
    >>> grid[1]
    array([[0, 0, 0, 0, 0],
           [1, 1, 1, 1, 1],
           [2, 2, 2, 2, 2],
           [3, 3, 3, 3, 3],
           [4, 4, 4, 4, 4]])

    >>> coords = np.vstack([grid[0].ravel(), grid[1].ravel()]).T
    >>> transformed = tps(coords)
    >>> np.round(transformed[:, 1]).reshape(5, 5).astype(int)
    array([[0, 1, 2, 3, 4],
           [0, 1, 2, 3, 4],
           [0, 1, 2, 3, 4],
           [0, 1, 2, 3, 4],
           [0, 1, 2, 3, 4]])

    The estimation can fail - for example, if all the input or output points
    are the same. If this happens, you will get a transform that is not
    "truthy" - meaning that ``bool(tform)`` is ``False``:

    >>> if tps:
    ...     print("Estimation succeeded.")
    Estimation succeeded.

    Not so for a degenerate transform with identical points.

    >>> bad_src = np.ones((4, 2))
    >>> bad_tps = ski.transform.ThinPlateSplineTransform.from_estimate(
    ...      bad_src, dst)
    >>> if not bad_tps:
    ...     print("Estimation failed.")
    Estimation failed.

    Trying to use this failed estimation transform result will give a suitable
    error:

    >>> bad_tps.params  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    FailedEstimationAccessError: No attribute "params" for failed estimation ...
    """

    def __init__(self):
        self._estimated = False
        # Solution of the TPS linear system: n kernel-weight rows followed by
        # the affine rows. ``None`` means "no usable transformation".
        self._spline_mappings = None
        self.src = None

    def __call__(self, coords):
        """Apply the estimated transformation to a set of coordinates.

        Parameters
        ----------
        coords : array_like of shape (N, 2)
            x, y coordinates to transform.

        Returns
        -------
        transformed_coords : ndarray of shape (N, 2)
            Destination coordinates.

        Raises
        ------
        ValueError
            If no transformation has been estimated, or if `coords` does not
            have shape (N, 2).
        """
        if self._spline_mappings is None:
            msg = (
                "Transformation is undefined, define it by calling `estimate` "
                "before applying it"
            )
            raise ValueError(msg)
        coords = np.array(coords)
        if coords.ndim != 2 or coords.shape[1] != 2:
            msg = "Input `coords` must have shape (N, 2)"
            raise ValueError(msg)
        radial_dist = self._radial_distance(coords)
        transformed_coords = self._spline_function(coords, radial_dist)
        return transformed_coords

    @property
    def inverse(self):
        """Not available: a thin-plate spline has no closed-form inverse here.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("Not supported")

    @classmethod
    def from_estimate(cls, src, dst) -> Self | FailedEstimation:
        """Estimate optimal spline mappings between source and destination points.

        Parameters
        ----------
        src : array_like of shape (N, 2)
            Control points at source coordinates.
        dst : array_like of shape (N, 2)
            Control points at destination coordinates.

        Returns
        -------
        tform : Self or ``FailedEstimation``
            An instance of the transformation if the estimation succeeded.
            Otherwise, we return a special ``FailedEstimation`` object to
            signal a failed estimation. Testing the truth value of the failed
            estimation object will return ``False``. E.g.

            .. code-block:: python

                tform = ThinPlateSplineTransform.from_estimate(...)
                if not tform:
                    raise RuntimeError(f"Failed estimation: {tform}")

        Raises
        ------
        ValueError
            If `src`/`dst` are not 2-D, have fewer than 3 points, differ in
            shape, or do not have shape (N, 2).

        Notes
        -----
        The number N of source and destination points must match.
        """
        tf = cls()
        msg = tf._estimate(src, dst)
        return tf if msg is None else FailedEstimation(f'{cls.__name__}: {msg}')

    def _estimate(self, src, dst):
        """Try to estimate the spline mappings; return the reason on failure.

        Returns ``None`` on success, otherwise a message string explaining why
        the linear system could not be solved. Invalid input raises
        ``ValueError`` instead.
        """
        # Accept any array_like as documented: ``check_nD`` is used here only
        # for validation, while the code below needs ``.shape`` on an ndarray.
        src = np.asarray(src)
        dst = np.asarray(dst)
        check_nD(src, 2, arg_name="src")
        check_nD(dst, 2, arg_name="dst")

        if src.shape[0] < 3 or dst.shape[0] < 3:
            msg = "Need at least 3 points in `src` and `dst`"
            raise ValueError(msg)
        if src.shape != dst.shape:
            msg = f"Shape of `src` and `dst` didn't match, {src.shape} != {dst.shape}"
            raise ValueError(msg)
        if src.shape[1] != 2:
            # The affine block below is sized for 2-D points only.
            msg = f"`src` and `dst` must have shape (N, 2), got {src.shape}"
            raise ValueError(msg)

        self.src = src
        # Drop mappings from any earlier estimation so that a failure below
        # cannot leave stale mappings paired with the new ``src``.
        self._spline_mappings = None
        n, d = src.shape

        # Assemble the system [[K, P], [P.T, 0]] @ [w; a] = [dst; 0], where K
        # holds kernel values between control points and P = [1, x, y].
        dist = distance_matrix(src, src)
        K = self._radial_basis_kernel(dist)
        P = np.hstack([np.ones((n, 1)), src])
        size = n + d + 1
        # float64: kernel values grow like r**2 log r**2, so float32 would
        # discard significant precision for image-scale coordinates.
        L = np.zeros((size, size), dtype=np.float64)
        L[:n, :n] = K
        L[:n, n:] = P
        L[n:, :n] = P.T
        V = np.vstack([dst, np.zeros((d + 1, d))])
        try:
            mappings = np.linalg.solve(L, V)
        except np.linalg.LinAlgError:
            return 'Unable to solve for spline mappings'
        if not np.all(np.isfinite(mappings)):
            # e.g. NaN in the input: solve() does not raise but yields NaNs.
            return 'Unable to solve for spline mappings'
        self._spline_mappings = mappings
        return None

    def _radial_distance(self, coords):
        """Return kernel values U(|coords_i - src_j|) for every pair (i, j)."""
        dists = distance_matrix(coords, self.src)
        return self._radial_basis_kernel(dists)

    def _spline_function(self, coords, radial_dist):
        """Evaluate the spline at `coords` given their kernel values.

        The result is the affine part ``a[0] + coords @ a[1:]`` plus the
        non-affine part ``radial_dist @ w``.
        """
        n = self.src.shape[0]
        # First n rows: kernel weights; remaining rows: affine coefficients
        # (offset row followed by the linear part), as assembled in _estimate.
        w = self._spline_mappings[:n]
        a = self._spline_mappings[n:]
        transformed_coords = a[0] + np.dot(coords, a[1:]) + np.dot(radial_dist, w)
        return transformed_coords

    @staticmethod
    def _radial_basis_kernel(r):
        """Compute the radial basis function for thin-plate splines.

        Parameters
        ----------
        r : ndarray
            Euclidean distances between pairs of points (any shape).

        Returns
        -------
        U : ndarray
            Kernel ``r**2 * log(r**2)`` evaluated element-wise, with the
            limit value 0 where ``r == 0``; same shape as `r`.
        """
        _small = 1e-8  # Small value to avoid divide-by-zero
        r_sq = r**2
        U = np.where(r == 0.0, 0.0, r_sq * np.log(r_sq + _small))
        return U

    @_deprecate_estimate
    def estimate(self, src, dst):
        """Estimate optimal spline mappings between source and destination points.

        Parameters
        ----------
        src : array_like of shape (N, 2)
            Control points at source coordinates.
        dst : array_like of shape (N, 2)
            Control points at destination coordinates.

        Returns
        -------
        success: bool
            True indicates that the estimation was successful.

        Notes
        -----
        The number N of source and destination points must match.
        """
        return self._estimate(src, dst) is None