This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Conversation

@ain-soph (Contributor) commented May 9, 2022

Fix a small bug:

```python
    if compute == 'full':
        return result
    if compute == 'trace':
        return torch.einsum('NMKK->NM')         # should be torch.einsum('NMKK->NM', result)
    if compute == 'diagonal':
        return torch.einsum('NMKK->NMK')        # should be torch.einsum('NMKK->NMK', result)
```

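For context, a minimal repro sketch (the tensor shape below is made up for illustration): `torch.einsum` needs the operand passed explicitly, and with `result` supplied the two expressions give the per-block trace and diagonal reductions.

```python
import torch

# result stands in for the full NTK; the (2, 3, 4, 4) shape is hypothetical.
result = torch.randn(2, 3, 4, 4)

# torch.einsum('NMKK->NM') fails because no operand is supplied;
# passing result performs the intended reductions:
trace = torch.einsum('NMKK->NM', result)    # (2, 3): trace over each trailing K x K block
diag = torch.einsum('NMKK->NMK', result)    # (2, 3, 4): diagonal of each K x K block
print(trace.shape, diag.shape)
```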
@ain-soph (Contributor, Author) commented May 9, 2022

I followed the tutorial and implemented a version without using functorch. I wonder what the advantage of using functorch is?

```python
import functools

import torch
import torch.nn as nn
from torch.nn.utils import _stateless


def ntk(module: nn.Module, input1: torch.Tensor, input2: torch.Tensor,
        parameters: dict[str, nn.Parameter] | None = None,
        compute: str = 'full') -> torch.Tensor:
    einsum_expr: str = ''
    match compute:
        case 'full':
            einsum_expr = 'Naf,Mbf->NMab'
        case 'trace':
            einsum_expr = 'Naf,Maf->NM'
        case 'diagonal':
            einsum_expr = 'Naf,Maf->NMa'
        case _:
            raise ValueError(compute)

    if parameters is None:
        parameters = dict(module.named_parameters())
    keys, values = zip(*parameters.items())

    def func(*params: torch.Tensor, _input: torch.Tensor = None):
        # Re-run the module with the given parameter tensors swapped in.
        _output: torch.Tensor = _stateless.functional_call(
            module, {n: p for n, p in zip(keys, params)}, _input)
        return _output  # (N, C)

    # Per-parameter Jacobians of the outputs w.r.t. the parameters;
    # each entry has shape (N, C, *param.shape).
    jac1: tuple[torch.Tensor, ...] = torch.autograd.functional.jacobian(
        functools.partial(func, _input=input1), values, vectorize=True)
    jac2: tuple[torch.Tensor, ...] = torch.autograd.functional.jacobian(
        functools.partial(func, _input=input2), values, vectorize=True)
    # Flatten each parameter's dimensions so every Jacobian is (N, C, P).
    jac1 = [j.flatten(2) for j in jac1]
    jac2 = [j.flatten(2) for j in jac2]
    # Contract over the flattened parameter dimension and sum over parameters.
    result = torch.stack([torch.einsum(einsum_expr, j1, j2)
                          for j1, j2 in zip(jac1, jac2)]).sum(0)
    return result
```
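For what it's worth, a minimal usage sketch (the small MLP and batch sizes below are made up for illustration; it assumes the imports and `ntk` definition above):

```python
# Hypothetical example: a small MLP and two random batches.
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))
x1 = torch.randn(5, 10)
x2 = torch.randn(7, 10)

K_full = ntk(model, x1, x2, compute='full')    # (5, 7, 3, 3)
K_trace = ntk(model, x1, x2, compute='trace')  # (5, 7)
print(K_full.shape, K_trace.shape)
```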

@Chillee (Contributor) commented May 10, 2022

Thanks!

@Chillee merged commit a7a8e66 into pytorch:main on May 10, 2022
zou3519 pushed a commit to zou3519/pytorch that referenced this pull request Jul 20, 2022
bigfootjon pushed a commit to pytorch/pytorch that referenced this pull request Jul 21, 2022