Conversation

@wheeze01

What does this PR do?

Fixes a crash in Exaone4Config.__init__ when sliding_window_pattern is None (EXAONE-4.0-1.2B) or a string like "LLLG" (EXAONE-4.0-32B). The original code unconditionally performed a modulo operation on sliding_window_pattern, causing either a ZeroDivisionError or a TypeError. This PR also removes an incorrect "sliding_window" key check that left _attn_implementation unset. Now:

  • We branch safely on three cases for sliding_window_pattern:

    1. None or 0 → all layers use "full_attention".
    2. str (e.g. "LLLG") → map each character (L → "sliding_attention", others → "full_attention"), repeat to cover all layers, and force the final layer to "full_attention".
    3. positive int (e.g. 4) → every n‑th layer is "full_attention", others "sliding_attention", final layer forced "full_attention".
  • We remove the incorrect check for "sliding_window" in layer_types and no longer force _attn_implementation="hybrid"; we let Hugging Face’s internal _check_and_adjust_attn_implementation decide the proper backend (e.g., "eager", "sdpa", "flash_attention_*").

This resolves both the division/modulo crash and the risk of _attn_implementation remaining None downstream.

Fixes #39696

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Models:

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: exaone4

@wheeze01
Author

wheeze01 commented Jul 26, 2025

What does this PR do?

This PR fixes two issues in Exaone4Config.__init__ when loading EXAONE‑4.0 models:

  1. Crash on sliding_window_pattern=None or a string
    The original code unconditionally did

    (i + 1) % sliding_window_pattern

    even when sliding_window_pattern was None (EXAONE-4.0-1.2B) → ZeroDivisionError, or a string like "LLLG" (EXAONE-4.0-32B) → TypeError.

  2. Incorrect _attn_implementation override
    It looked for the literal "sliding_window" token in layer_types and forced _attn_implementation="hybrid", but the correct layer marker is "sliding_attention", so the check never matched. This left _attn_implementation = None downstream, causing subtle failures.

Now, we safely branch on three cases for sliding_window_pattern:

  1. None or 0 → every layer is "full_attention".
  2. str (e.g. "LLLG") → map each character: L → "sliding_attention", others → "full_attention"; repeat to cover all layers; force the final layer to "full_attention".
  3. positive int (e.g. 4) → every 4th layer is "full_attention", others "sliding_attention"; force the final layer to "full_attention".

We also remove the incorrect "sliding_window" check and no longer force _attn_implementation="hybrid". Instead, we defer to HF’s internal _check_and_adjust_attn_implementation logic (choosing "eager", "sdpa", "flash_attention_*", etc.) for proper backend selection.
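
For reference, here is a minimal sketch of the branching described above. The surrounding guard and attribute names (layer_types, sliding_window_pattern, num_hidden_layers) follow the existing Exaone4Config; the exact code in the PR may differ in form:

if self.layer_types is None:
    pattern = self.sliding_window_pattern
    if not pattern:  # case 1: None or 0 -> no sliding window anywhere
        self.layer_types = ["full_attention"] * self.num_hidden_layers
    elif isinstance(pattern, str):  # case 2: e.g. "LLLG", repeated over all layers
        self.layer_types = [
            "sliding_attention" if pattern[i % len(pattern)] == "L" else "full_attention"
            for i in range(self.num_hidden_layers)
        ]
    else:  # case 3: positive int, e.g. 4 -> every 4th layer is global
        self.layer_types = [
            "full_attention" if (i + 1) % pattern == 0 else "sliding_attention"
            for i in range(self.num_hidden_layers)
        ]
    # final layer forced to full attention, as described above
    self.layer_types[-1] = "full_attention"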

Fixes #39696


Why were these changes needed?

  1. Prevent ZeroDivisionError / TypeError

    • Unconditional % sliding_window_pattern fails when it’s None or non-numeric.
  2. Support both string and integer patterns

    • None/0: No sliding window → all layers use full_attention.
    • String (e.g. "LLLG"): Blueprint for local/global attention per layer.
    • Integer (e.g. 4): Periodic full-attention every N layers.
  3. Let HF pick the best attention backend

    • Removing the hardcoded "hybrid" default ensures Transformers chooses between "eager", "sdpa", "flash_attention_2", "flash_attention_3", etc., according to availability and best performance (see the snippet below).
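
For illustration, with the override removed the backend can either be left to automatic selection or requested explicitly via from_pretrained (the backend choice here is arbitrary):

from transformers import AutoModelForCausalLM

# With the hardcoded "hybrid" override gone, the attention backend can be
# left to auto-selection or requested explicitly:
model = AutoModelForCausalLM.from_pretrained(
    "LGAI-EXAONE/EXAONE-4.0-1.2B",
    attn_implementation="sdpa",  # or "eager", "flash_attention_2", ...
)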

Manual test script (scripts/manual_test_exaone4.py)

import argparse
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, help="Model ID or local path")
    parser.add_argument("--device", choices=["cpu", "cuda"], default="cuda")
    args = parser.parse_args()

    # 1) Load config and check layer_types & attn_impl
    cfg = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
    impl = getattr(cfg, "_attn_implementation", None)
    expect_hybrid = "32B" in args.model
    if expect_hybrid:
        assert impl is not None and impl != "eager", "32B model should use a hybrid backend"
    else:
        assert impl == "eager", "1.2B model should use eager backend"
    print("attn_impl OK:", impl)

    # 2) Simple forward pass
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    ).to(args.device)
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    sample_ids = tokenizer("Hello", return_tensors="pt").input_ids.to(args.device)
    out = model(sample_ids)
    print("Forward OK: logits", out.logits.shape)

    # 3) Causal mask and sliding-window mask
    seq_len = 8
    dummy = torch.randn(1, seq_len, cfg.hidden_size)
    mask_kwargs = {
        "config": cfg,
        "input_embeds": dummy,
        "attention_mask": None,
        "cache_position": torch.arange(seq_len),
        "past_key_values": None,
        "position_ids": torch.arange(seq_len).unsqueeze(0),
    }
    try:
        full = create_causal_mask(**mask_kwargs)
        print("full_mask:", full.shape)
    except Exception:
        print("full_mask: None (skipped)")

    if "sliding_attention" in cfg.layer_types:
        try:
            slide = create_sliding_window_causal_mask(**mask_kwargs)
            print("sliding_mask:", slide.shape)
        except Exception:
            print("sliding_mask: None (skipped)")

    print("All manual checks passed!")

if __name__ == "__main__":
    main()

How to run

# 1.2B on GPU
python scripts/manual_test_exaone4.py \
  --model LGAI-EXAONE/EXAONE-4.0-1.2B \
  --device cuda

# 32B on CPU to avoid OOM
python scripts/manual_test_exaone4.py \
  --model LGAI-EXAONE/EXAONE-4.0-32B \
  --device cpu

Both models now load without exception, complete a forward pass, and generate masks correctly.

Result

=== Checking LGAI-EXAONE/EXAONE-4.0-1.2B ===
 - first 8 layer_types: ['full_attention', 'full_attention', 'full_attention', 'full_attention', 'full_attention', 'full_attention', 'full_attention', 'full_attention']
 - total layers: 30
 - has sliding_attention? False
 - generate OK, logits shape: torch.Size([1, 37])
 - sample output: [|user|]\n너가 얼마나 대단한지 설명해 봐[|endofturn|]\n[|assistant|]\n<think>\n\n</think>\n\n저는 EXAONE으로, LG AI Research에서 개발된 대규모 언어 모델 ...
 - full_mask: None
=== Done: LGAI-EXAONE/EXAONE-4.0-1.2B ===

=== Checking LGAI-EXAONE/EXAONE-4.0-32B ===
 - first 8 layer_types: ['sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention', 'sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention']
 - total layers: 64
 - has sliding_attention? True
 - generate OK, logits shape: torch.Size([1, 37])
 - sample output: [|user|]\n너가 얼마나 대단한지 설명해 봐[|endofturn|]\n[|assistant|]\n<think>\n\n</think>\n\n저는 LG AI 연구원에서 개발한 EXAONE입니다. 저는 다양한 ...
 - full_mask: None
 - sliding_mask: None
=== Done: LGAI-EXAONE/EXAONE-4.0-32B ===

@wheeze01
Author

wheeze01 commented Jul 26, 2025


This PR infers and implements the intended behavior from LG AI Research’s existing code and PR discussion for EXAONE-4.0. It may differ slightly from the original developer’s intent, so any feedback is greatly appreciated.

We also verified that inference works correctly with the following script:

from transformers import AutoModelForCausalLM, AutoTokenizer

# model_name = "LGAI-EXAONE/EXAONE-4.0-1.2B"     # same result
model_name = "LGAI-EXAONE/EXAONE-4.0-32B"

model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype="bfloat16", device_map=None
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# choose your prompt
prompt = "너가 얼마나 대단한지 설명해 봐"  # roughly: "Explain how great you are"
messages = [{"role": "user", "content": prompt}]
input_ids = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
)

output = model.generate(
    input_ids.to(model.device),
    max_new_tokens=128,
    do_sample=False,
)
print(tokenizer.decode(output[0]))

output:

[|user|]
너가 얼마나 대단한지 설명해 봐[|endofturn|]
[|assistant|]
<think>

</think>

저는 EXAONE으로, LG AI Research에서 개발된 대규모 언어 모델입니다. 제 능력은 다음과 같은 점에서 뛰어납니다:

1. **복잡한 계산 처리**: 다양한 언어 작업을 빠르고 정확하게 수행할 수 있습니다.
2. **다양한 언어 이해 및 생성**: 한국어, 영어 등 여러 언어를 유창하게 이해하고 생성할 수 있습니다.
3. **빠른 응답 속도**: 긴 텍스트도 짧은 시간 내에 분석하고 요약하거나 새로운 답변을 제공할 수 있습니다.
4. **학습 데이터 활용**: 방대한

@lgai-exaone
Contributor

Hello, @wheeze01. Thank you for your attention and contribution!

Your PR appears to align with our intentions, except for one point:

  • We did not intend for models to be forced to use full attention in the last layer. We intend for all attention types to be controlled by sliding window options or layer types.

By the way, we will update the models' configuration with proper layer_types to address this issue first. After this update, this PR may become unnecessary for the released EXAONE 4.0 models, but it would be better for the maintainers to make the decision.

@lgai-exaone mentioned this pull request Jul 28, 2025

@ArthurZucker
Collaborator

Yes, as @lgai-exaone mentioned, the best is to align with layer_types, which should be explicit!

@lgai-exaone
Contributor

@ArthurZucker, thank you for your response!

Apart from handling sliding_window_pattern, it appears that forcing _attn_implementation="hybrid" needs to be fixed, as the current implementation does not support it.
Are there any plans for those lines to work properly?

@ArthurZucker
Collaborator

Hey! I am wondering what:

attn_implementation="hybrid"

would refer to? Currently, all attention implementations in transformers support both sliding and non-sliding attention.

@lgai-exaone
Contributor

@ArthurZucker, the current EXAONE 4.0 configuration includes this implementation:

if "sliding_window" in self.layer_types:
self._attn_implementation = "hybrid"
layer_type_validation(self.layer_types)

@ArthurZucker
Collaborator

Ah! Sorry, what I wanted to write is self.cache_implementation! I'll fix this and do a patch!
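
For context, the corrected lines would presumably look something like the following, with the layer marker fixed to "sliding_attention" and cache_implementation set instead of _attn_implementation; this is inferred from the thread, not the actual patch:

# Inferred from the discussion above, not the merged patch:
if "sliding_attention" in self.layer_types:
    self.cache_implementation = "hybrid"
layer_type_validation(self.layer_types)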

Development

Successfully merging this pull request may close these issues.

[exaone4] ZeroDivisionError/TypeError when sliding_window_pattern is None/"LLLG" and _attn_implementation stays None (4.54.0 & main)
