Collaborator

@CaoE CaoE commented Sep 18, 2024

`kernel_micro_gemm` generated using BRGEMM:

```
template <bool accum>
inline void kernel_micro_gemm(
    const half* __restrict__ A,
    const half* __restrict__ B,
    float* __restrict__ C,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc
) {
    at::native::cpublas::brgemm(
      M, N, K,
      lda, ldb, ldc,
      1.f, accum ? 1.f : 0.f,
      A,
      B,
      C);
}
```
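
For readers unfamiliar with the `accum` template flag, here is a minimal scalar sketch of the arithmetic the call above appears to request: alpha is fixed at `1.f`, and beta is `1.f` when `accum` is true and `0.f` otherwise. This is not the oneDNN BRGEMM code path. The name `reference_micro_gemm` is hypothetical, `float` inputs stand in for `half` so the sketch compiles anywhere, and plain row-major storage with leading dimensions `lda`, `ldb`, `ldc` is an assumption (any packed or blocked layout the real kernel may expect for `B` is ignored).

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical reference for illustration only; NOT the oneDNN BRGEMM path.
// Assumptions: float stands in for half; A is M x K with row stride lda,
// B is K x N with row stride ldb, C is M x N with row stride ldc.
template <bool accum>
inline void reference_micro_gemm(
    const float* A,
    const float* B,
    float* C,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc) {
  // Each row must fit inside its leading dimension.
  assert(lda >= K && ldb >= N && ldc >= N);
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      // Accumulate the dot product in float, matching the float* C output.
      float sum = 0.f;
      for (int64_t k = 0; k < K; ++k) {
        sum += A[m * lda + k] * B[k * ldb + n];
      }
      // accum == true  -> beta = 1: add to the existing C value.
      // accum == false -> beta = 0: overwrite C without reading it.
      C[m * ldc + n] = accum ? C[m * ldc + n] + sum : sum;
    }
  }
}
```

Calling `reference_micro_gemm<false>` on the first K-block and `reference_micro_gemm<true>` on later K-blocks would build up a full product block by block, which is presumably why the generated kernel is templated on `accum`.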

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @rec

pytorch-bot bot commented Sep 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136255

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 5070fdd with merge base 3672c68:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/inductor, module: cpu, module: dynamo, and module: inductor labels Sep 18, 2024
@CaoE CaoE added the ciflow/trunk and topic: not user facing labels Sep 18, 2024
@CaoE CaoE force-pushed the gemm_template branch 3 times, most recently from c27521a to febedb3 September 19, 2024 07:58
@CaoE CaoE force-pushed the gemm_template branch 2 times, most recently from a2a7a0f to 399b92d October 21, 2024 01:20
@CaoE CaoE requested a review from jgong5 October 21, 2024 13:02

def _check_brgemm_counter(self, vec_amx):
    if vec_amx and (
        torch.cpu._is_amx_fp16_supported() or torch.cpu._is_avx512_fp16_supported()
Collaborator

Why do we include AVX512-FP16 here? Are we relying on the FP16 FMA, which is less accurate than AMX?

Collaborator Author

Checking avx512_fp16 is enough. Modified. Thanks.

Collaborator

No, my question is why we care about avx512_fp16 here. We only use AMX FP16, right?

Collaborator Author

Removed non-amx fp16 support. Only check AMX FP16 here.

@CaoE CaoE changed the title from "Add oneDNN BRGEMM config for Half cpp gemm template" to "[Inductor][CPP] Add oneDNN BRGEMM config for Half cpp gemm template" Oct 23, 2024
@CaoE CaoE requested a review from jgong5 October 29, 2024 00:43
Collaborator

@jgong5 jgong5 left a comment

Please explain why we need avx512 fp16 here.

Collaborator Author

CaoE commented Oct 29, 2024

> Please explain why we need avx512 fp16 here.

Removed non-amx fp16 support.

@CaoE CaoE requested a review from jgong5 October 29, 2024 08:54
@CaoE CaoE marked this pull request as ready for review October 30, 2024 03:25
@CaoE CaoE requested a review from jansel October 30, 2024 03:26
Collaborator Author

CaoE commented Nov 3, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

Collaborator Author

CaoE commented Nov 3, 2024

@pytorchbot rebase

Collaborator Author

CaoE commented Nov 3, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased gemm_template onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout gemm_template && git pull --rebase)

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

Collaborator Author

CaoE commented Nov 5, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…ytorch#136255)

`kernel_micro_gemm` generated using BRGEMM:
```
template <bool accum>
inline void kernel_micro_gemm(
    const half* __restrict__ A,
    const half* __restrict__ B,
    float* __restrict__ C,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc
) {
    at::native::cpublas::brgemm(
      M, N, K,
      lda, ldb, ldc,
      1.f, accum ? 1.f : 0.f,
      A,
      B,
      C);
}
```

Pull Request resolved: pytorch#136255
Approved by: https://github.com/jgong5, https://github.com/jansel

Labels

ciflow/inductor, ciflow/trunk, Merged, module: cpu, module: dynamo, module: inductor, open source, topic: not user facing
