Conversation
Force-pushed 8675187 to fe5a82c
Update: the original difference in PSNR was due to a dataloader bug in the original implementation (sarafridov/K-Planes#11). Both implementations now give comparable PSNR values, and I've updated the PR description accordingly.
tancik left a comment:
Thanks for reimplementing k-planes.
I left some initial comments. One broader one is that I would like to avoid adding the alpha compositing in this PR. It's mostly tangential to k-planes and touches many base files.
nerfstudio/engine/schedulers.py (Outdated)

```python
@dataclass
class CosineWithWarmupSchedulerConfig(SchedulerConfig):
```
There is now a CosineDecaySchedulerConfig; it can be used instead.
```python
# in x,y,z order
camera_to_world[..., 3] *= self.scale_factor
scene_box = SceneBox(aabb=torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]], dtype=torch.float32))
radius = 1.3 if "ship" not in str(self.data) else 1.5
```
Remove special logic per scene.
nerfstudio/fields/kplanes_field.py (Outdated)

```python
CONSOLE = Console(width=120)

def init_grid_param(grid_nd: int, in_dim: int, out_dim: int, reso: List[int], a: float = 0.1, b: float = 0.5):
```
Here grid_nd is always 2; it would be cleaner to remove the configuration option and other references.
Also, in_dim can easily be inferred from reso (in_dim = len(reso)); best to remove it from the arguments.
nerfstudio/fields/kplanes_field.py (Outdated)

```python
def interpolate_ms_features(
    pts: torch.Tensor,
    ms_grids: Union[nn.ModuleList, List[nn.ParameterList]],
    grid_dimensions: int,
```
Similarly, grid_dimensions is always 2.
nerfstudio/fields/kplanes_field.py (Outdated)

```python
# Input should be in [-1, 1]
grid_input = (grid_input - self.aabb[0]) * (2.0 / (self.aabb[1] - self.aabb[0])) - 1.0

if self.grid_config[0]["input_coordinate_dim"] == 4:
```
Time needs to be normalized: in nerfstudio datasets it's in [0, 1], but for interpolation it must be in [-1, 1] (do `timestamps = timestamps * 2 - 1`).
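A minimal sketch of the normalization being asked for (function and variable names are illustrative, not the actual PR code): spatial coordinates are mapped from the scene AABB to [-1, 1], and timestamps from [0, 1] to [-1, 1] in the same way, so that everything matches what `grid_sample`-style interpolation expects:

```python
import torch

def normalize_for_grid_interp(positions: torch.Tensor, aabb: torch.Tensor,
                              times: torch.Tensor = None) -> torch.Tensor:
    """Map inputs into [-1, 1] for grid interpolation.

    positions: (N, 3) points inside the axis-aligned bounding box aabb of shape (2, 3).
    times: optional (N, 1) timestamps in [0, 1], as in nerfstudio datasets.
    """
    coords = (positions - aabb[0]) * (2.0 / (aabb[1] - aabb[0])) - 1.0
    if times is not None:
        # [0, 1] -> [-1, 1], the normalization requested in this review comment
        coords = torch.cat([coords, times * 2.0 - 1.0], dim=-1)
    return coords

aabb = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]])
pts = torch.zeros(2, 3)              # AABB center maps to ~0
t = torch.tensor([[0.0], [1.0]])
out = normalize_for_grid_interp(pts, aabb, t)
print(out[:, 3])  # time column now spans [-1, 1]
```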
nerfstudio/fields/kplanes_field.py (Outdated)

```python
grid_input = (grid_input - self.aabb[0]) * (2.0 / (self.aabb[1] - self.aabb[0])) - 1.0

if ray_samples.times is not None and self.hexplane:
    grid_input = torch.cat([grid_input, ray_samples.times.view(-1, 1)], -1)
```
As above, time needs to be normalized.
nerfstudio/models/kplanes.py (Outdated)

```python
grid_config: List[Dict] = field(
    default_factory=lambda: [
        {"grid_dimensions": 2, "input_coordinate_dim": 3, "output_coordinate_dim": 32, "resolution": [64, 64, 64]}
```
Here grid_dimensions is always 2; I would delete it.
nerfstudio/models/kplanes.py (Outdated)

```python
    ]
)

is_ndc: bool = False
```
is_ndc is not really used.
nerfstudio/models/kplanes.py (Outdated)

```python
multiscale_res: List[int] = field(default_factory=lambda: [1, 2, 4, 8])
concat_features_across_scales: bool = True
linear_decoder: bool = False
linear_decoder_layers: Optional[int] = 4
```
Defaulting to 1 is better for linear_decoder_layers.
nerfstudio/models/kplanes.py (Outdated)

```python
    self.config.proposal_update_every,
)

if self.config.is_contracted or self.config.is_ndc:
```
There's no point in using the is_ndc config here, as nerfstudio doesn't have NDC datasets.
nerfstudio/models/kplanes.py (Outdated)

```python
metrics_dict = {}
metrics_dict["psnr"] = self.psnr(outputs["rgb"], image)
if self.training:
    metrics_dict["plane_tv"] = compute_grid_tv(self.field.grids)
```
Maybe these should be computed directly in the get_loss_dict function (they're not really metrics the way PSNR is).
I personally like being able to look at these values before they get scaled by the loss coefficients, but I don't have very strong opinions if others disagree.
Is there a reason you left behind the temporal losses (sparse transients, temporal smoothness)? Also, I'm not sure I understand the code for the plane TV loss. In the paper, it is applied in 2D only on space-only planes, and in 1D on space-time planes. However, the implementation here computes it in 2D twice over space-only planes and once over space-time planes:

```python
total = 0.0
for grids in multi_res_grids:
    if len(grids) == 3:
        spatial_grids = [0, 1, 2]
    else:
        spatial_grids = [0, 1, 3]  # These are the spatial grids; the others are spatiotemporal
    for grid_id in spatial_grids:
        total += compute_plane_tv(grids[grid_id])
    for grid in grids:
        total += compute_plane_tv(grid)
```
Actually, I don't even understand exactly what is going on in the compute_plane_tv function:

```python
batch_size, c, h, w = t.shape
count_h = batch_size * c * (h - 1) * w
count_w = batch_size * c * h * (w - 1)
h_tv = torch.square(t[..., 1:, :] - t[..., : h - 1, :]).sum()
w_tv = torch.square(t[..., :, 1:] - t[..., :, : w - 1]).sum()
return 2 * (h_tv / count_h + w_tv / count_w)  # This is summing over batch and c instead of avg
```
If anybody understands this better, I'd like to know about it. From the equation in the paper, I would have done this:

```python
def compute_plane_tv(t, only_w=False):
    batch_size, c, h, w = t.shape
    h_tv = torch.square(t[..., 1:, :] - t[..., : h - 1, :]).mean()
    w_tv = torch.square(t[..., :, 1:] - t[..., :, : w - 1]).mean()
    return h_tv + w_tv if not only_w else w_tv


def space_tv_loss(multi_res_grids):
    total = 0.0
    for grids in multi_res_grids:
        if len(grids) == 3:
            spatial_grids = [0, 1, 2]
        else:
            spatial_grids = [0, 1, 3]  # These are the spatial grids; the others are spatiotemporal
        for grid_id, grid in enumerate(grids):
            if grid_id in spatial_grids:
                total += compute_plane_tv(grid)
            else:
                # Space is the last dimension for space-time planes.
                total += compute_plane_tv(grid, only_w=True)
    return total
```
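As a self-contained sanity check of this mean-based TV formulation (restated here so the snippet runs on its own; the values below are for illustrative toy tensors, not real K-Planes grids): a constant plane should have zero TV, and a unit ramp along the last axis should have TV exactly 1 in w and 0 in h.

```python
import torch

def compute_plane_tv(t: torch.Tensor, only_w: bool = False) -> torch.Tensor:
    """Mean squared difference between neighboring grid entries (as proposed above)."""
    _, _, h, w = t.shape
    h_tv = torch.square(t[..., 1:, :] - t[..., : h - 1, :]).mean()
    w_tv = torch.square(t[..., :, 1:] - t[..., :, : w - 1]).mean()
    return w_tv if only_w else h_tv + w_tv

# A constant plane has zero TV.
flat = torch.ones(1, 2, 4, 4)
print(float(compute_plane_tv(flat)))  # 0.0

# A unit ramp along w: every neighboring difference along w is 1, none along h.
ramp = torch.arange(4.0).expand(1, 2, 4, 4)
print(float(compute_plane_tv(ramp)))               # 1.0
print(float(compute_plane_tv(ramp, only_w=True)))  # 1.0
```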
Higher-level comments:
Totally agree. Our implementation was dragged over from previous code, and the normalization especially was very arbitrary. I'm not sure about normalizing the space dimension on space-time planes (it should work, but needs testing), but otherwise it looks good to me! Another thing is that we might want L1 TV instead of L2, but better to leave that for a later date. Of course the relative loss weights will differ from the paper, but it doesn't matter!
Regarding time: the PR already seems to support time datasets; you can test that with dycheck or dnerf, which are in nerfstudio. Importance sampling is definitely out of scope though!
I indeed think importance sampling is out of scope, but it can be a very important convergence factor. I have a large dataset (10500 images) of a large-scale scene with small dynamic objects, and implementing (some sort of) importance sampling made the qualitative results incredibly better. However, my implementation changes more and bigger parts of nerfstudio (and is probably not the cleanest), so I'll probably do it separately.
Force-pushed f2ed1eb to a19cb6d
@tancik I think I've addressed most of the feedback in this PR, but I'm now getting a test failure (https://github.com/nerfstudio-project/nerfstudio/actions/runs/4423210082/jobs/7755716203) - do you know what might be going on?
tancik left a comment:
Thoughts on moving the k-planes encoding out of the field and into encodings.py, similar to what we did with tensorf? This will make it easier for others to use these encodings in their projects.
nerfstudio/fields/kplanes_field.py (Outdated)

```python
def init_grid_param(out_dim: int, reso: List[int], a: float = 0.1, b: float = 0.5):
    """Initializes the grid parameters."""
```
For all docstrings, can you follow the form:

```python
"""Desc

Args:
    a: stuff

Returns:
    desc
"""
```
nerfstudio/models/kplanes.py (Outdated)

```python
far_plane: float = 6.0
"""How far along the ray to stop sampling."""

grid_config: List[Dict] = field(
```
Add argstrings to all arguments.
@tancik I think we could move the k-planes encoding to …
Creating a …
Is there something that's blocking this PR? Happy to get the branch updated and get through the final cleanup if no one else is working on it.
If we can merge the encoding first (#1658) and use it from this PR, it should be good to go?
Hey all, I've been dealing with a paper-related deadline, but getting this over the line should be straightforward and I'll take care of it today.
Force-pushed df2bf79 to 2e03ecb
Hey @tancik - are there any todos that you'd still like me to address in this PR?
Sorry for the delay, I'm traveling at the moment and won't have a chance to do a thorough look for a few days.
No rush on my side!
Sorry for taking so long to get back to this. Since my last comment, we have started rethinking how we organize new models. Would you be willing to host the k-planes repo on your own GitHub? We have a guide for setting it up here; given this PR, it should slot in nicely. We can then add a page to methods so that others can easily find it. For example, we just did this for instruct-nerf2nerf.
sgtm - this will realistically happen near the end of May at the earliest, but I'll let you know once I give it a go.
Hi! K-planes integration is live at https://github.com/Giodiro/kplanes_nerfstudio. I'll write a PR for the docs soon, and I think we can close this one!
This is great, thanks for your effort! We should keep the embedding in the main repo. Looking forward to the docs PR. Will close this PR. |
This should be a fairly faithful reimplementation of the official K-Planes code at https://github.com/sarafridov/K-Planes. It can load models trained with the official implementation and gives comparable PSNR when trained from scratch (31.69 vs 31.86 dB).