Conversation
qgallouedec (Member) commented on Sep 4, 2025

In SFT and GRPO, the entropy computation induces a huge memory spike when training on long contexts. This PR fixes it, reducing the spike from roughly 80 GB to around 0.23 GB, with no significant slowdown observed.

[Figures: peak_memory_comparison.png and entropy_from_logits_benchmark.png, generated by the two benchmark scripts below]
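For intuition on where the spike comes from, here is a rough back-of-the-envelope estimate (a sketch only, assuming fp32 logits and the vocab_size=150_000 used in the benchmarks below; exact figures depend on the model, dtype, and batch size):

# Rough per-intermediate memory, fp32 (4 bytes per element). The "before" path chunks over
# the batch dimension, so log_softmax/exp/mul each materialize a (seq_len, vocab) tensor;
# the "after" path chunks over flattened tokens, so each intermediate is (chunk_size, vocab).
seq_len, vocab_size, chunk_size, bytes_per_el = 32_768, 150_000, 128, 4
before_gb = seq_len * vocab_size * bytes_per_el / 1e9    # ~19.7 GB per intermediate
after_gb = chunk_size * vocab_size * bytes_per_el / 1e9  # ~0.08 GB per intermediate
print(f"before: ~{before_gb:.1f} GB per intermediate, after: ~{after_gb:.2f} GB per intermediate")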
import torch
import torch.nn.functional as F

def entropy_from_logits_before(logits: torch.Tensor, chunk_size: int = 1) -> torch.Tensor:
    # Old behavior: chunk over the batch dimension, so each chunk still spans the full
    # sequence and the log_softmax/exp intermediates are (chunk_size, seq_len, vocab).
    per_token_entropies = []
    for logits_chunk in logits.split(chunk_size, dim=0):
        logps = F.log_softmax(logits_chunk, dim=-1)
        chunk_entropy = -(torch.exp(logps) * logps).sum(-1)
        per_token_entropies.extend(chunk_entropy)

    per_token_entropies = torch.stack(per_token_entropies)
    return per_token_entropies


def entropy_from_logits_after(logits: torch.Tensor, chunk_size: int = 128) -> torch.Tensor:
    # New behavior: flatten (batch, seq_len) into tokens and chunk over tokens, so each
    # intermediate is only (chunk_size, vocab) regardless of sequence length.
    original_shape = logits.shape[:-1]
    num_classes = logits.shape[-1]

    flat_logits = logits.reshape(-1, num_classes)
    entropies = []
    for chunk in flat_logits.split(chunk_size, dim=0):
        logps = F.log_softmax(chunk, dim=-1)
        chunk_entropy = -(torch.exp(logps) * logps).sum(-1)
        entropies.append(chunk_entropy)

    return torch.cat(entropies, dim=0).reshape(original_shape)


if __name__ == "__main__":
    device = "cuda"

    peak_before = []
    peak_after_64 = []
    peak_after_128 = []
    peak_after_256 = []

    for seq_len in [1024, 2048, 4096, 8192, 16384, 32768]:
        # batch=1, varying seq_len, vocab=150k
        batch_size, vocab_size = 1, 150_000
        logits = torch.randn(batch_size, seq_len, vocab_size, device=device).contiguous()

        torch.cuda.empty_cache()
        mem_before = torch.cuda.memory_allocated(device)
        torch.cuda.reset_peak_memory_stats(device)
        entropy_from_logits_before(logits)
        mem_peak_total = torch.cuda.max_memory_allocated(device)
        mem_peak = mem_peak_total - mem_before
        peak_before.append(mem_peak / 1e9)

        torch.cuda.empty_cache()
        mem_before = torch.cuda.memory_allocated(device)
        torch.cuda.reset_peak_memory_stats(device)
        entropy_from_logits_after(logits, chunk_size=64)
        mem_peak_total = torch.cuda.max_memory_allocated(device)
        mem_peak = mem_peak_total - mem_before
        peak_after_64.append(mem_peak / 1e9)

        torch.cuda.empty_cache()
        mem_before = torch.cuda.memory_allocated(device)
        torch.cuda.reset_peak_memory_stats(device)
        entropy_from_logits_after(logits, chunk_size=128)
        mem_peak_total = torch.cuda.max_memory_allocated(device)
        mem_peak = mem_peak_total - mem_before
        peak_after_128.append(mem_peak / 1e9)

        torch.cuda.empty_cache()
        mem_before = torch.cuda.memory_allocated(device)
        torch.cuda.reset_peak_memory_stats(device)
        entropy_from_logits_after(logits, chunk_size=256)
        mem_peak_total = torch.cuda.max_memory_allocated(device)
        mem_peak = mem_peak_total - mem_before
        peak_after_256.append(mem_peak / 1e9)
    
    import matplotlib.pyplot as plt

    plt.plot(["1024", "2048", "4096", "8192", "16384", "32768"], peak_before, label="Before")
    plt.plot(["1024", "2048", "4096", "8192", "16384", "32768"], peak_after_64, label="After (chunk_size=64)")
    plt.plot(["1024", "2048", "4096", "8192", "16384", "32768"], peak_after_128, label="After (chunk_size=128 (default))")
    plt.plot(["1024", "2048", "4096", "8192", "16384", "32768"], peak_after_256, label="After (chunk_size=256)")
    plt.xlabel("Sequence Length")
    plt.ylabel("Peak Memory (GB)")
    plt.title("Peak Memory Usage Before and After Refactor")
    plt.legend()
    plt.ylim(0, 10)
    plt.grid()
    plt.savefig("peak_memory_comparison.png")
import torch
import torch.nn.functional as F
import timeit
def entropy_from_logits_before(logits, chunk_size: int = 1) -> torch.Tensor:
    per_token_entropies = []
    for logits_chunk in logits.split(chunk_size, dim=0):
        logps = F.log_softmax(logits_chunk, dim=-1)
        chunk_entropy = -(torch.exp(logps) * logps).sum(-1)
        per_token_entropies.extend(chunk_entropy)

    per_token_entropies = torch.stack(per_token_entropies)
    return per_token_entropies


def entropy_from_logits_after(logits: torch.Tensor, chunk_size: int = 128) -> torch.Tensor:
    original_shape = logits.shape[:-1]
    num_classes = logits.shape[-1]

    flat_logits = logits.reshape(-1, num_classes)
    entropies = []
    for chunk in flat_logits.split(chunk_size, dim=0):
        logps = F.log_softmax(chunk, dim=-1)
        chunk_entropy = -(torch.exp(logps) * logps).sum(-1)
        entropies.append(chunk_entropy)

    return torch.cat(entropies, dim=0).reshape(original_shape)


if __name__ == "__main__":
    device = "cuda"
    N = 25  # number of timed repetitions per configuration

    time_before = []
    time_after_64 = []
    time_after_128 = []
    time_after_256 = []

    for seq_len in [1024, 2048, 4096, 8192, 16384, 32768]:
        # batch=1, varying seq_len, vocab=150k
        batch_size, vocab_size = 1, 150_000
        logits = torch.randn(batch_size, seq_len, vocab_size, device=device).contiguous()

        ###################################################################
        def timed_fn():
            torch.cuda.synchronize()  # wait for all previous CUDA ops
            entropy = entropy_from_logits_before(logits)
            torch.cuda.synchronize()  # wait until entropy computation finishes
            return entropy

        elapsed_time = timeit.timeit(timed_fn, number=N)/N
        time_before.append(elapsed_time)

        ###################################################################
        def timed_fn():
            torch.cuda.synchronize()  # wait for all previous CUDA ops
            entropy = entropy_from_logits_after(logits, chunk_size=64)
            torch.cuda.synchronize()  # wait until entropy computation finishes
            return entropy

        elapsed_time = timeit.timeit(timed_fn, number=N)/N
        time_after_64.append(elapsed_time)

        ###################################################################
        def timed_fn():
            torch.cuda.synchronize()  # wait for all previous CUDA ops
            entropy = entropy_from_logits_after(logits, chunk_size=128)
            torch.cuda.synchronize()  # wait until entropy computation finishes
            return entropy

        elapsed_time = timeit.timeit(timed_fn, number=N)/N
        time_after_128.append(elapsed_time)

        ###################################################################
        def timed_fn():
            torch.cuda.synchronize()  # wait for all previous CUDA ops
            entropy = entropy_from_logits_after(logits, chunk_size=256)
            torch.cuda.synchronize()  # wait until entropy computation finishes
            return entropy

        elapsed_time = timeit.timeit(timed_fn, number=N)/N
        time_after_256.append(elapsed_time)

    import matplotlib.pyplot as plt

    plt.plot(["1024", "2048", "4096", "8192", "16384", "32768"], time_before, label="Before")
    plt.plot(["1024", "2048", "4096", "8192", "16384", "32768"], time_after_64, label="After (chunk_size=64)")
    plt.plot(["1024", "2048", "4096", "8192", "16384", "32768"], time_after_128, label="After (chunk_size=128 (default))")
    plt.plot(["1024", "2048", "4096", "8192", "16384", "32768"], time_after_256, label="After (chunk_size=256)")
    plt.xlabel("Sequence Length")
    plt.ylabel("Time per call (s)")
    plt.title("Entropy from logits computation time")
    plt.legend()
    plt.grid()
    plt.savefig("entropy_from_logits_benchmark.png")

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

qgallouedec changed the title from "Refactor entropy_from_logits for memory efficiency" to "🌵 Refactor entropy_from_logits for memory efficiency" on Sep 4, 2025
qgallouedec merged commit deae7e0 into main on Sep 4, 2025
10 of 11 checks passed
qgallouedec deleted the reduce-mem-peak branch on September 4, 2025 at 19:59