Skip to content

Conversation

@hjh0119
Copy link
Contributor

@hjh0119 hjh0119 commented Sep 17, 2025

Fix asymmetric patching/unpatching in InsertPostInitMethodToModuleSubClasses

Problem Description

The InsertPostInitMethodToModuleSubClasses context manager patches __init__ methods of model classes during entry and unpatches them during exit.

However, asymmetric condition checks between patching and unpatching can introduce subtle inheritance bugs.

Root Cause Analysis

The issue occurs with classes that have multiple inheritance where:

  1. Child class A does not override __init__
  2. Parent class B does not inherit from nn.Module
  3. Parent class C inherits from nn.Module

Current asymmetric logic:

# Patching (entry): Only patch classes with explicit __init__
def _enable_class(cls):
    if '__init__' in cls.__dict__:  # ✅ Strict check
        cls._old_init = cls.__init__
        cls.__init__ = partition_after(cls.__init__)

# Unpatching (exit): Restore any class with _old_init
def _disable_class(cls):
    if hasattr(cls, '_old_init'):  # ❌ Permissive check
        cls.__init__ = cls._old_init

Execution flow:

  1. During entry: Child A is skipped (no explicit __init__), Parent C is patched
  2. During exit: Child A inherits _old_init from Parent C and gets incorrectly "restored"

Result: Child A's __init__ points to Parent C's original __init__, bypassing Parent B and breaking the inheritance chain.

Reproduction Case

This pattern is common in Hugging Face models:

class Qwen3ForSequenceClassification(GenericForSequenceClassification, Qwen3PreTrainedModel):
    pass  # No explicit __init__

# GenericForSequenceClassification - not a nn.Module subclass
# Qwen3PreTrainedModel - inherits from nn.Module

Solution

Apply symmetric condition checking in both patch and unpatch operations:

def _disable_class(cls):
    # Match the patching condition: only restore classes we explicitly patched
    if '__init__' in cls.__dict__ and hasattr(cls, '_old_init'):
        cls.__init__ = cls._old_init
        delattr(cls, '_old_init')  # Optional cleanup

This ensures that only classes that were explicitly patched during entry get restored during exit.

Testing

The fix has been validated against the Qwen3ForSequenceClassification reproduction case and resolves the inheritance chain corruption.

Related Issues

@tohtana tohtana enabled auto-merge (squash) September 19, 2025 06:35
Copy link
Collaborator

@tohtana tohtana left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hjh0119 Thank you for the fix!

@tohtana tohtana merged commit e4f6da9 into deepspeedai:master Sep 19, 2025
12 checks passed
mauryaavinash95 pushed a commit to DataStates/DeepSpeed that referenced this pull request Oct 4, 2025
## Fix asymmetric patching/unpatching in
InsertPostInitMethodToModuleSubClasses

### Problem Description

The `InsertPostInitMethodToModuleSubClasses` context manager patches
`__init__` methods of model classes during entry and unpatches them
during exit.

However, asymmetric condition checks between patching and unpatching can
introduce subtle inheritance bugs.

### Root Cause Analysis

The issue occurs with classes that have multiple inheritance where:
1. **Child class A** does not override `__init__`
2. **Parent class B** does not inherit from `nn.Module`
3. **Parent class C** inherits from `nn.Module`

**Current asymmetric logic:**
```python
# Patching (entry): Only patch classes with explicit __init__
def _enable_class(cls):
    if '__init__' in cls.__dict__:  # ✅ Strict check
        cls._old_init = cls.__init__
        cls.__init__ = partition_after(cls.__init__)

# Unpatching (exit): Restore any class with _old_init
def _disable_class(cls):
    if hasattr(cls, '_old_init'):  # ❌ Permissive check
        cls.__init__ = cls._old_init
```

**Execution flow:**
1. **During entry**: Child A is skipped (no explicit `__init__`), Parent
C is patched
2. **During exit**: Child A inherits `_old_init` from Parent C and gets
incorrectly "restored"

**Result**: Child A's `__init__` points to Parent C's original
`__init__`, bypassing Parent B and breaking the inheritance chain.

### Reproduction Case

This pattern is common in Hugging Face models:
```python
class Qwen3ForSequenceClassification(GenericForSequenceClassification, Qwen3PreTrainedModel):
    pass  # No explicit __init__

# GenericForSequenceClassification - not a nn.Module subclass
# Qwen3PreTrainedModel - inherits from nn.Module
```

### Solution

Apply symmetric condition checking in both patch and unpatch operations:

```python
def _disable_class(cls):
    # Match the patching condition: only restore classes we explicitly patched
    if '__init__' in cls.__dict__ and hasattr(cls, '_old_init'):
        cls.__init__ = cls._old_init
        delattr(cls, '_old_init')  # Optional cleanup
```

This ensures that only classes that were explicitly patched during entry
get restored during exit.

### Testing

The fix has been validated against the Qwen3ForSequenceClassification
reproduction case and resolves the inheritance chain corruption.

### Related Issues
- External issue: modelscope/ms-swift#5820

Co-authored-by: Masahiro Tanaka <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants