
k-planes implementation#1584

Closed
hturki wants to merge 6 commits intonerfstudio-project:mainfrom
hturki:ht/k-planes

Conversation

@hturki
Contributor

@hturki hturki commented Mar 10, 2023

This should be a pretty faithful implementation of the official K-Planes implementation at https://github.com/sarafridov/K-Planes.

It can load models trained with the official implementation and gives comparable PSNR when trained from scratch (31.69 vs. 31.86 dB).

@hturki hturki force-pushed the ht/k-planes branch 15 times, most recently from 8675187 to fe5a82c Compare March 12, 2023 01:31
@hturki
Contributor Author

hturki commented Mar 12, 2023

Update - the original difference in PSNR was due to a dataloader bug in the original implementation (sarafridov/K-Planes#11). Both implementations now give comparable PSNR values and I've updated the PR description accordingly.

Contributor

@tancik tancik left a comment


Thanks for reimplementing k-planes.

I left some initial comments. One broader one is that I would like to avoid adding the alpha compositing in this PR. It's mostly tangential to k-planes and touches many base files.



@dataclass
class CosineWithWarmupSchedulerConfig(SchedulerConfig):
Contributor

There is now a `CosineDecaySchedulerConfig`; it can be used instead.

# in x,y,z order
camera_to_world[..., 3] *= self.scale_factor
scene_box = SceneBox(aabb=torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]], dtype=torch.float32))
radius = 1.3 if "ship" not in str(self.data) else 1.5
Contributor

Remove special logic per scene.

Copy link
Copy Markdown
Contributor

@Giodiro Giodiro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @hturki, thanks for the PR! I left a few minor comments on places that can be cleaned up, as well as on the normalization of time, which is important.
I'll try to add a commit on top for the missing regularizers!

CONSOLE = Console(width=120)


def init_grid_param(grid_nd: int, in_dim: int, out_dim: int, reso: List[int], a: float = 0.1, b: float = 0.5):
Contributor

Here `grid_nd` is always 2; it would be cleaner to remove the configuration option and other references.

Contributor

Also, `in_dim` can easily be inferred from `reso` (`in_dim = len(reso)`); best to remove it from the arguments.

def interpolate_ms_features(
pts: torch.Tensor,
ms_grids: Union[nn.ModuleList, List[nn.ParameterList]],
grid_dimensions: int,
Contributor

Similarly, `grid_dimensions` is always 2.

# Input should be in [-1, 1]
grid_input = (grid_input - self.aabb[0]) * (2.0 / (self.aabb[1] - self.aabb[0])) - 1.0

if self.grid_config[0]["input_coordinate_dim"] == 4:
Contributor

Time needs to be normalized: in nerfstudio datasets it's in [0, 1], but for interpolation it must be in [-1, 1] (do `timestamps = (timestamps * 2) - 1`).
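As a minimal illustration of the suggested remap (plain Python with hypothetical timestamp values; in the model this would be applied to the times tensor before concatenation):

```python
# Hypothetical timestamps as nerfstudio datasets provide them: in [0, 1].
timestamps = [0.0, 0.25, 0.5, 1.0]

# Remap to [-1, 1], the range grid-sample style interpolation expects:
normalized = [t * 2.0 - 1.0 for t in timestamps]
print(normalized)  # [-1.0, -0.5, 0.0, 1.0]
```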

grid_input = (grid_input - self.aabb[0]) * (2.0 / (self.aabb[1] - self.aabb[0])) - 1.0

if ray_samples.times is not None and self.hexplane:
grid_input = torch.cat([grid_input, ray_samples.times.view(-1, 1)], -1)
Contributor

As above, time needs to be normalized.


grid_config: List[Dict] = field(
default_factory=lambda: [
{"grid_dimensions": 2, "input_coordinate_dim": 3, "output_coordinate_dim": 32, "resolution": [64, 64, 64]}
Contributor

Here `grid_dimensions` is always 2; I would delete it.

]
)

is_ndc: bool = False
Contributor

@Giodiro Giodiro Mar 13, 2023

`is_ndc` is not really used.

multiscale_res: List[int] = field(default_factory=lambda: [1, 2, 4, 8])
concat_features_across_scales: bool = True
linear_decoder: bool = False
linear_decoder_layers: Optional[int] = 4
Contributor

@Giodiro Giodiro Mar 13, 2023

Defaulting `linear_decoder_layers` to 1 is better.

self.config.proposal_update_every,
)

if self.config.is_contracted or self.config.is_ndc:
Contributor

There's no point in using the `is_ndc` config here, as nerfstudio doesn't have NDC datasets.

metrics_dict = {}
metrics_dict["psnr"] = self.psnr(outputs["rgb"], image)
if self.training:
metrics_dict["plane_tv"] = compute_grid_tv(self.field.grids)
Contributor

Maybe these should be computed directly in the `get_loss_dict` function (they're not really metrics the way PSNR is).

Contributor Author

I personally like being able to look at these values before they get scaled with loss coefficients, but I don't have very strong opinions if others disagree

@iSach
Contributor

iSach commented Mar 13, 2023

Is there a reason you left behind the temporal losses? (sparse transients, temporal smoothness)

Also, I'm not sure I understand the code for the plane TV loss. In the paper, they apply it only on space-only planes and in 1D on space-time planes. However, the implementation here computes the 2D TV twice over the space-only planes, and once over the space-time planes?

    total = 0.0
    for grids in multi_res_grids:
        if len(grids) == 3:
            spatial_grids = [0, 1, 2]
        else:
            spatial_grids = [0, 1, 3]  # These are the spatial grids; the others are spatiotemporal
        for grid_id in spatial_grids:
            total += compute_plane_tv(grids[grid_id])
        for grid in grids:
            total += compute_plane_tv(grid)
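For context on the hard-coded indices above: assuming the six planes of a 4D grid are generated with `itertools.combinations` over the (x, y, z, t) axes (an assumption about the grid construction, consistent with the official K-Planes code), the space-only planes land at positions 0, 1, and 3:

```python
from itertools import combinations

# Assume the planes of a 4D (x, y, z, t) grid are generated in
# itertools.combinations order; time is the last axis (index 3).
planes = list(combinations(range(4), 2))
# planes == [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]

# Space-only planes are the ones that do not involve the time axis:
spatial = [i for i, pair in enumerate(planes) if 3 not in pair]
print(spatial)  # [0, 1, 3]
```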

@Giodiro
Contributor

Giodiro commented Mar 13, 2023

To add to @iSach 's comment, you can have a look here for his implementation of the regularizers!

@iSach
Contributor

iSach commented Mar 13, 2023

> Is there a reason you left behind the temporal losses? (sparse transients, temporal smoothness)
>
> Also, I'm not sure I understand the code for the plane TV loss. In the paper, they apply it only on space-only planes and in 1D on space-time planes. However, the implementation here computes the 2D TV twice over the space-only planes, and once over the space-time planes?
>
>     total = 0.0
>     for grids in multi_res_grids:
>         if len(grids) == 3:
>             spatial_grids = [0, 1, 2]
>         else:
>             spatial_grids = [0, 1, 3]  # These are the spatial grids; the others are spatiotemporal
>         for grid_id in spatial_grids:
>             total += compute_plane_tv(grids[grid_id])
>         for grid in grids:
>             total += compute_plane_tv(grid)

Actually, I don't even understand exactly what is going on in the `compute_plane_tv` function.

    batch_size, c, h, w = t.shape
    count_h = batch_size * c * (h - 1) * w
    count_w = batch_size * c * h * (w - 1)
    h_tv = torch.square(t[..., 1:, :] - t[..., : h - 1, :]).sum()
    w_tv = torch.square(t[..., :, 1:] - t[..., :, : w - 1]).sum()
    return 2 * (h_tv / count_h + w_tv / count_w)  # This is summing over batch and c instead of avg
  • batch_size is always 1. Why is there this dimension? It's not the batch size, it's number of models.
  • the sum() then dividing by count_h (or w) is exactly the same here as doing .mean() on h_tv (or w_tv). What is meant here by summing over batch and c instead of avg?
  • Why is the result multiplied by 2?

If anybody understands this better, I'd like to know about it.
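As an aside on the second bullet, a quick NumPy check (shapes hypothetical) confirms that dividing the sum by the element count is exactly `.mean()`:

```python
import numpy as np

# Hypothetical plane of shape (n_grids, channels, h, w); the leading dim is 1 in practice.
rng = np.random.default_rng(0)
t = rng.standard_normal((1, 8, 16, 16))
batch_size, c, h, w = t.shape

# Sum of squared vertical differences, divided by the number of difference terms...
h_tv_sum = np.square(t[..., 1:, :] - t[..., : h - 1, :]).sum()
count_h = batch_size * c * (h - 1) * w

# ...is identical to taking the mean directly.
h_tv_mean = np.square(t[..., 1:, :] - t[..., : h - 1, :]).mean()
assert np.isclose(h_tv_sum / count_h, h_tv_mean)
```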

From the equation in the paper, I would have done this:

    def compute_plane_tv(t, only_w=False):
        batch_size, c, h, w = t.shape
        h_tv = torch.square(t[..., 1:, :] - t[..., : h - 1, :]).mean()
        w_tv = torch.square(t[..., :, 1:] - t[..., :, : w - 1]).mean()
        return h_tv + w_tv if not only_w else w_tv

    def space_tv_loss(multi_res_grids):
        total = 0.0
        for grids in multi_res_grids:
            if len(grids) == 3:
                spatial_grids = [0, 1, 2]
            else:
                spatial_grids = [0, 1, 3]  # These are the spatial grids; the others are spatiotemporal

            for grid_id, grid in enumerate(grids):
                if grid_id in spatial_grids:
                    total += compute_plane_tv(grid)
                else:
                    # Space is the last dimension for space-time planes.
                    total += compute_plane_tv(grid, only_w=True)
        return total

@hturki
Contributor Author

hturki commented Mar 13, 2023

Higher-level comments:

  • Will split out the alpha compositing into its own PR
  • Re: some of the TV-related code - I mainly copied code from the official repo at https://github.com/sarafridov/K-Planes to make sure that the port was faithful / that the numbers line up with the paper. I'd refer any questions to the official implementation, but happy to simplify things / cut out parts that aren't relevant to nerfstudio
  • Re: not adding time-related losses - I was hoping to start with what's needed for static scenes in this PR (which I should have mentioned in the PR description), as the paper mentions doing things like importance sampling for the DyNeRF scenes that would likely involve making more changes than what I wanted to do with this initial PR. I'm happy to add the time losses in this PR, but I (or someone else) would ideally do any other required things (like the importance sampling) in a follow-up to this

@Giodiro
Contributor

Giodiro commented Mar 13, 2023

> Is there a reason you left behind the temporal losses? (sparse transients, temporal smoothness)
>
> Also, I'm not sure I understand the code for the plane TV loss. In the paper, they apply it only on space-only planes and in 1D on space-time planes. However, the implementation here computes the 2D TV twice over the space-only planes, and once over the space-time planes?
>
>     total = 0.0
>     for grids in multi_res_grids:
>         if len(grids) == 3:
>             spatial_grids = [0, 1, 2]
>         else:
>             spatial_grids = [0, 1, 3]  # These are the spatial grids; the others are spatiotemporal
>         for grid_id in spatial_grids:
>             total += compute_plane_tv(grids[grid_id])
>         for grid in grids:
>             total += compute_plane_tv(grid)
>
> Actually, I don't even understand exactly what is going on in the `compute_plane_tv` function.
>
>     batch_size, c, h, w = t.shape
>     count_h = batch_size * c * (h - 1) * w
>     count_w = batch_size * c * h * (w - 1)
>     h_tv = torch.square(t[..., 1:, :] - t[..., : h - 1, :]).sum()
>     w_tv = torch.square(t[..., :, 1:] - t[..., :, : w - 1]).sum()
>     return 2 * (h_tv / count_h + w_tv / count_w)  # This is summing over batch and c instead of avg
>
> * batch_size is always 1. Why is there this dimension? It's not the batch size, it's the number of models.
> * the sum() then dividing by count_h (or w) is exactly the same here as doing .mean() on h_tv (or w_tv). What is meant here by summing over batch and c instead of avg?
> * Why is the result multiplied by 2?
>
> If anybody understands this better, I'd like to know about it.
>
> From the equation in the paper, I would have done this:
>
>     def compute_plane_tv(t, only_w=False):
>         batch_size, c, h, w = t.shape
>         h_tv = torch.square(t[..., 1:, :] - t[..., : h - 1, :]).mean()
>         w_tv = torch.square(t[..., :, 1:] - t[..., :, : w - 1]).mean()
>         return h_tv + w_tv if not only_w else w_tv
>
>     def space_tv_loss(multi_res_grids):
>         total = 0.0
>         for grids in multi_res_grids:
>             if len(grids) == 3:
>                 spatial_grids = [0, 1, 2]
>             else:
>                 spatial_grids = [0, 1, 3]  # These are the spatial grids; the others are spatiotemporal
>
>             for grid_id, grid in enumerate(grids):
>                 if grid_id in spatial_grids:
>                     total += compute_plane_tv(grid)
>                 else:
>                     # Space is the last dimension for space-time planes.
>                     total += compute_plane_tv(grid, only_w=True)
>         return total

Totally agree. Our implementation was dragged over from previous code, and especially the normalization was very arbitrary. Not sure about normalizing the space dim on space-time planes (it should work, but needs testing), but otherwise looks good to me! Another thing is that we might want l1 TV instead of l2, but better leave it to a later date. Of course relative loss weights will differ from the paper but it doesn't matter!

@Giodiro
Contributor

Giodiro commented Mar 13, 2023

> Higher-level comments:
>
> * Will split out the alpha compositing into its own PR
> * Re: some of the TV-related code - I mainly copied code from the official repo at https://github.com/sarafridov/K-Planes to make sure that the port was faithful / that the numbers line up with the paper. I'd refer any questions to the official implementation, but happy to simplify things / cut out parts that aren't relevant to nerfstudio
> * Re: not adding time-related losses - I was hoping to start with what's needed for static scenes in this PR (which I should have mentioned in the PR description), as the paper mentions doing things like importance sampling for the DyNeRF scenes that would likely involve making more changes than what I wanted to do with this initial PR. I'm happy to add the time losses in this PR, but I (or someone else) would ideally do any other required things (like the importance sampling) in a follow-up to this

Regarding time: the PR already seems to support temporal datasets; you can test that with dycheck or dnerf, which are in nerfstudio. Importance sampling is definitely out of scope though!

@iSach
Contributor

iSach commented Mar 13, 2023

I indeed think importance sampling is out of scope, but it can be a very important convergence factor.

I have a large dataset (10,500 images) of a large-scale scene with small dynamic objects, and implementing (some sort of) importance sampling made the qualitative results dramatically better.

However, my implementation changes more and bigger parts of Nerfstudio (and is probably not the cleanest), so I'll probably do it separately.

@hturki hturki force-pushed the ht/k-planes branch 4 times, most recently from f2ed1eb to a19cb6d Compare March 15, 2023 05:18
@hturki
Contributor Author

hturki commented Mar 15, 2023

@tancik I think I've addressed most of the feedback in this PR, but I'm now getting a test failure (https://github.com/nerfstudio-project/nerfstudio/actions/runs/4423210082/jobs/7755716203) - do you know what might be going on?

Contributor

@tancik tancik left a comment


Thoughts on moving the k-planes encoding out of the field and into encodings.py, similar to what we did with TensoRF? This will make it easier for others to use these encodings in their projects.



def init_grid_param(out_dim: int, reso: List[int], a: float = 0.1, b: float = 0.5):
"""Initializes the grid parameters."""
Contributor

For all docstrings, can you follow the form:

"""Desc

Args:
    a: stuff

Returns:
    desc
"""

far_plane: float = 6.0
"""How far along the ray to stop sampling."""

grid_config: List[Dict] = field(
Contributor

Add argstrings to all arguments

@hturki
Contributor Author

hturki commented Mar 17, 2023

@tancik I think we could move the k-planes encoding to encodings.py, but there seems to be a TriplaneEncoding class there already for TensoRF that behaves differently from the K-Planes one, and I'm not sure it would be straightforward to cleanly merge the functionality needed to support both methods into that single encoding. Should I create a separate KPlanesEncoding, or do you have a better suggestion?

@tancik
Contributor

tancik commented Mar 17, 2023

> @tancik I think we could move the k-planes encoding to encodings.py, but there seems to be a TriplaneEncoding class there already for TensoRF that behaves differently from the K-Planes one and I'm not sure it would be straightforward to cleanly merge in the functionality needed to support both methods into that single encoding. Should I create a separate KPlanesEncoding or do you have a better suggestion?

Creating a KPlanesEncoding sounds good to me.

@akristoffersen
Contributor

Is there something that's blocking this PR? Happy to get the branch updated / get through the final cleanup if no one else is working on it.

@Giodiro
Contributor

Giodiro commented Mar 24, 2023

If we can merge the encoding first (#1658) and use it from this PR, it should be good to go?

@hturki
Contributor Author

hturki commented Mar 24, 2023

Hey all, I've been dealing with a paper deadline, but getting this over the line should be straightforward and I'll take care of it today.

@hturki hturki force-pushed the ht/k-planes branch 3 times, most recently from df2bf79 to 2e03ecb Compare March 24, 2023 22:12
@hturki
Contributor Author

hturki commented Mar 28, 2023

hey @tancik - are there any todos that you'd still like for me to address in this PR?

@tancik
Contributor

tancik commented Mar 28, 2023

Sorry for the delay, I'm traveling at the moment and won't have a chance to do a thorough look for a few days.

@hturki
Contributor Author

hturki commented Mar 28, 2023

no rush on my side!

@tancik
Contributor

tancik commented Apr 17, 2023

Sorry for taking so long to get back to this. Since my last comment, we have started rethinking how we organize new models. Would you be willing to host the k-planes repo on your own GitHub? We have a guide for setting it up here; given this PR, it should slot in nicely. We can then add a page to methods so that others can easily find it. For example, we just did this for instruct-nerf2nerf.

@hturki
Contributor Author

hturki commented Apr 17, 2023

sgtm, this will realistically happen near the end of May at the earliest, but I'll let you know once I give it a go

@Giodiro
Contributor

Giodiro commented Apr 26, 2023

Hi! K-planes integration is live at https://github.com/Giodiro/kplanes_nerfstudio
I assumed the k-planes embedding will still live in nerfstudio, but if you want to change that let me know!

I'll write a PR for the docs soon, and I think we can close this one!

@tancik
Contributor

tancik commented Apr 26, 2023

> Hi! K-planes integration is live at https://github.com/Giodiro/kplanes_nerfstudio I assumed the k-planes embedding will still live in nerfstudio, but if you want to change that let me know!
>
> I'll write a PR for the docs soon, and I think we can close this one!

This is great, thanks for your effort! We should keep the embedding in the main repo. Looking forward to the docs PR. Will close this PR.

@tancik tancik closed this Apr 26, 2023