
Conversation

@Aidyn-A
Collaborator

@Aidyn-A Aidyn-A commented Dec 3, 2024

This PR implements 128-bit vectorization. It improves the performance of contiguous elementwise ops by 4-10% on Hopper H100.
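For intuition, here is a minimal, self-contained sketch (not the PR's actual kernel; the struct and kernel names are made up for illustration) of what 128-bit vectorization means for a contiguous elementwise op: each thread moves 16 bytes per load/store, i.e. 8 half-precision elements at a time.

```cuda
// Illustrative sketch only: a contiguous ReLU where each thread loads and
// stores 16 bytes (128 bits) per transaction. With __half data this packs
// 8 elements per thread ("vec8"). Assumes n_vec * 8 elements total, with
// 16-byte-aligned pointers; tail handling is omitted for brevity.
#include <cuda_fp16.h>

struct alignas(16) half8 {   // 8 x 2-byte elements = 16 bytes = 128 bits
  __half val[8];
};

__global__ void relu_vec8(const half8* in, half8* out, long n_vec) {
  long i = blockIdx.x * (long)blockDim.x + threadIdx.x;
  if (i < n_vec) {
    half8 v = in[i];                       // one 128-bit load
    #pragma unroll
    for (int k = 0; k < 8; ++k) {
      v.val[k] = __float2half(fmaxf(__half2float(v.val[k]), 0.0f));
    }
    out[i] = v;                            // one 128-bit store
  }
}
```

Presumably the previous maximum was 4 elements per thread, so float32 already used 16-byte accesses; that would explain why fp16/bf16 improve in the tables below while float32 stays roughly flat.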

The benchmark code used
import time
import torch
from torch.profiler import profile, ProfilerActivity


def benchmark(function, dtype=torch.float32, check_numerics=True, print_profile=False):
    device = torch.device("cuda")

    shapes = []
    for p in range(24, 30):
        shape = 1<<p
        shapes.append(shape)

    for shape in shapes:
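        # warm-up: allocate the input and run the op a few times before timing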
        for _ in range(6):
            x = torch.randn(shape, device=device, dtype=dtype)
            y = function(x)

        if print_profile:
            x = torch.randn(shape, device=device, dtype=dtype)
            with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
                y = function(x)
            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

        x = torch.randn(shape, device=device, dtype=dtype)
        torch.cuda.synchronize()
        t1 = time.perf_counter()
        for _ in range(6):
            y = function(x)
        torch.cuda.synchronize()
        t2 = time.perf_counter()
        perf_time = (t2 - t1) / 6

        print(f"{function.__name__}, {dtype}, {shape}, {perf_time}")
        if check_numerics:
            x_cpu = x.cpu()
            y_cpu = function(x_cpu).cuda()
            try:
                torch.testing.assert_allclose(y_cpu, y)
            except AssertionError as error:
                print("An exception occurred:", error)


def main():
    ops = [
            torch.relu,
            torch.sigmoid,
            torch.tanh,
            torch.nn.functional.gelu,
            torch.sin,
            torch.exp,
    ]

    dtypes = [
            torch.float16,
            torch.bfloat16,
            torch.float32,
    ]

    for op in ops:
        for dtype in dtypes:
            benchmark(op, dtype=dtype)
            torch.cuda.empty_cache()

if __name__ == "__main__":
    main()
Results
| op | dtype | size | time after | time before | % improvement |
| ---- | ---- | ---- | ---- | ---- | ---- |
| relu | torch.float16 | 33554432 | 4.84E-05 | 5.06E-05 | 4.66296539127052 |
| relu | torch.float16 | 67108864 | 9.22E-05 | 9.64E-05 | 4.56491432752297 |
| relu | torch.float16 | 134217728 | 0.000180343495837102 | 0.000187981834945579 | 4.23543919508829 |
| relu | torch.float16 | 268435456 | 0.000355071155354381 | 0.000370856161074092 | 4.44558942107169 |
| relu | torch.float16 | 536870912 | 0.000704489842367669 | 0.000736006341564159 | 4.47366268483987 |
| relu | torch.bfloat16 | 16777216 | 3.03E-05 | 3.04E-05 | 0.166504085842689 |
| relu | torch.bfloat16 | 33554432 | 4.89E-05 | 5.06E-05 | 3.45848238875716 |
| relu | torch.bfloat16 | 67108864 | 9.32E-05 | 9.65E-05 | 3.56122651631445 |
| relu | torch.bfloat16 | 134217728 | 0.000180805509444326 | 0.000187998676362137 | 3.97840029317567 |
| relu | torch.bfloat16 | 268435456 | 0.000356242332297067 | 0.000371279485989362 | 4.22104627356745 |
| relu | torch.bfloat16 | 536870912 | 0.000708114336399982 | 0.000736773828975856 | 4.04729732229083 |
| relu | torch.float32 | 16777216 | 5.61E-05 | 5.61E-05 | 0.0442587268354941 |
| relu | torch.float32 | 33554432 | 9.33E-05 | 9.30E-05 | -0.259070913799022 |
| relu | torch.float32 | 67108864 | 0.000181321326332788 | 0.000181289506144822 | -0.0175490597877115 |
| relu | torch.float32 | 134217728 | 0.000356896334172537 | 0.000356570177245885 | -0.0913870206618981 |
| relu | torch.float32 | 268435456 | 0.000709421835684528 | 0.000707465515006334 | -0.275762681635911 |
| relu | torch.float32 | 536870912 | 0.00141372415237129 | 0.00141036518228551 | -0.237597276678471 |
| sigmoid | torch.float16 | 16777216 | 3.10E-05 | 3.16E-05 | 2.10012593866895 |
| sigmoid | torch.float16 | 33554432 | 4.91E-05 | 5.23E-05 | 6.37710600666122 |
| sigmoid | torch.float16 | 67108864 | 9.30E-05 | 0.000100057009452333 | 7.61866144555331 |
| sigmoid | torch.float16 | 134217728 | 0.000180928347011407 | 0.000194982004662355 | 7.76752669390248 |
| sigmoid | torch.float16 | 268435456 | 0.000355658994521946 | 0.00038468533117945 | 8.16128288742412 |
| sigmoid | torch.float16 | 536870912 | 0.000705982849467546 | 0.000764021339515845 | 8.22094900634937 |
| sigmoid | torch.bfloat16 | 16777216 | 3.08E-05 | 3.17E-05 | 2.90965915673149 |
| sigmoid | torch.bfloat16 | 33554432 | 4.87E-05 | 5.24E-05 | 7.63503884668234 |
| sigmoid | torch.bfloat16 | 67108864 | 9.33E-05 | 0.000100019678939134 | 7.21238137428013 |
| sigmoid | torch.bfloat16 | 134217728 | 0.000180786165098349 | 0.000194868014659733 | 7.78922964250206 |
| sigmoid | torch.bfloat16 | 268435456 | 0.000355564659306159 | 0.000384909333661199 | 8.25297835063321 |
| sigmoid | torch.bfloat16 | 536870912 | 0.000705831005082776 | 0.000764102345177283 | 8.2557070566308 |
| sigmoid | torch.float32 | 16777216 | 4.93E-05 | 5.65E-05 | 14.5314136197766 |
| sigmoid | torch.float32 | 33554432 | 9.32E-05 | 9.31E-05 | -0.120169865610833 |
| sigmoid | torch.float32 | 67108864 | 0.000181328505277634 | 0.000180455681402236 | -0.481349512069855 |
| sigmoid | torch.float32 | 134217728 | 0.000357362829769651 | 0.000356093340087682 | -0.35523831137877 |
| sigmoid | torch.float32 | 268435456 | 0.000708921831877281 | 0.000707052337626616 | -0.263709504574663 |
| sigmoid | torch.float32 | 536870912 | 0.00141358317341656 | 0.0014090768333214 | -0.318788464654745 |
| tanh | torch.float16 | 16777216 | 3.03E-05 | 3.03E-05 | -0.0912564658661808 |
| tanh | torch.float16 | 33554432 | 4.90E-05 | 5.07E-05 | 3.46644442974484 |
| tanh | torch.float16 | 67108864 | 9.30E-05 | 9.68E-05 | 3.99871369815531 |
| tanh | torch.float16 | 134217728 | 0.00018052199933057 | 0.000188717152923346 | 4.53969799978138 |
| tanh | torch.float16 | 268435456 | 0.000355684508879979 | 0.000373026006855071 | 4.8755280430115 |
| tanh | torch.float16 | 536870912 | 0.000706660988119741 | 0.000740105014604827 | 4.73268328765002 |
| tanh | torch.bfloat16 | 16777216 | 2.99E-05 | 3.03E-05 | 1.21049563135981 |
| tanh | torch.bfloat16 | 33554432 | 4.89E-05 | 5.06E-05 | 3.48836101041744 |
| tanh | torch.bfloat16 | 67108864 | 9.28E-05 | 9.69E-05 | 4.39944918036626 |
| tanh | torch.bfloat16 | 134217728 | 0.000180710999605556 | 0.000189167990659674 | 4.67984299382829 |
| tanh | torch.bfloat16 | 268435456 | 0.000356062994493792 | 0.000372666652159144 | 4.66312363882606 |
| tanh | torch.bfloat16 | 536870912 | 0.000707100164921333 | 0.000740134331863374 | 4.67178040408393 |
| tanh | torch.float32 | 16777216 | 5.61E-05 | 5.64E-05 | 0.439595755746353 |
| tanh | torch.float32 | 33554432 | 9.31E-05 | 9.31E-05 | 0.00287633090228212 |
| tanh | torch.float32 | 67108864 | 0.000181465332085888 | 0.000180895323865116 | -0.31411411437098 |
| tanh | torch.float32 | 134217728 | 0.000356963835656643 | 0.000356073161431899 | -0.249513854283251 |
| tanh | torch.float32 | 268435456 | 0.000709201170442005 | 0.00070707315656667 | -0.300057862849997 |
| tanh | torch.float32 | 536870912 | 0.00141367283261692 | 0.00141030051357423 | -0.238550176877922 |
| gelu | torch.float16 | 16777216 | 2.73E-05 | 3.17E-05 | 15.921079070745 |
| gelu | torch.float16 | 33554432 | 5.06E-05 | 5.55E-05 | 9.76345374333098 |
| gelu | torch.float16 | 67108864 | 9.65E-05 | 0.000106600326641152 | 10.4308039074712 |
| gelu | torch.float16 | 134217728 | 0.000187776672343413 | 0.000208565829476962 | 11.0712139447915 |
| gelu | torch.float16 | 268435456 | 0.000370216167842348 | 0.000412251994324227 | 11.3544005187205 |
| gelu | torch.float16 | 536870912 | 0.000737301345604161 | 0.000819394170927505 | 11.1342296895002 |
| gelu | torch.bfloat16 | 16777216 | 3.02E-05 | 3.08E-05 | 1.78405479367653 |
| gelu | torch.bfloat16 | 33554432 | 5.13E-05 | 5.69E-05 | 10.9929393318302 |
| gelu | torch.bfloat16 | 67108864 | 9.76E-05 | 0.00010968199543034 | 12.3420807512356 |
| gelu | torch.bfloat16 | 134217728 | 0.000189661824454864 | 0.000214487663470209 | 13.0895287371091 |
| gelu | torch.bfloat16 | 268435456 | 0.000374197009174774 | 0.000423670164309442 | 13.2211519391275 |
| gelu | torch.bfloat16 | 536870912 | 0.000743675006863972 | 0.000842577001700799 | 13.299088166737 |
| gelu | torch.float32 | 16777216 | 5.06E-05 | 5.04E-05 | -0.413385894716413 |
| gelu | torch.float32 | 33554432 | 9.31E-05 | 9.32E-05 | 0.134157041722546 |
| gelu | torch.float32 | 67108864 | 0.000181480175039421 | 0.000180836669945469 | -0.354586992112075 |
| gelu | torch.float32 | 134217728 | 0.000356874331676712 | 0.000356305002545317 | -0.159532104402047 |
| gelu | torch.float32 | 268435456 | 0.000708909006789327 | 0.000706991491218408 | -0.270488250615287 |
| gelu | torch.float32 | 536870912 | 0.00141321367118508 | 0.00140937082081412 | -0.271922813181618 |
| sin | torch.float16 | 16777216 | 3.04E-05 | 3.11E-05 | 2.21834939018859 |
| sin | torch.float16 | 33554432 | 4.85E-05 | 5.23E-05 | 7.72165512511596 |
| sin | torch.float16 | 67108864 | 9.31E-05 | 9.98E-05 | 7.24947099480072 |
| sin | torch.float16 | 134217728 | 0.000180371008658161 | 0.000194791161144773 | 7.99471744039613 |
| sin | torch.float16 | 268435456 | 0.000355454161763191 | 0.000384903668115536 | 8.28503630574026 |
| sin | torch.float16 | 536870912 | 0.000705183832906187 | 0.000764360166310022 | 8.39161799270973 |
| sin | torch.bfloat16 | 16777216 | 3.11E-05 | 3.10E-05 | -0.257677954940036 |
| sin | torch.bfloat16 | 33554432 | 4.89E-05 | 5.24E-05 | 7.34808420323539 |
| sin | torch.bfloat16 | 67108864 | 9.26E-05 | 0.000100248667877167 | 8.22347488801205 |
| sin | torch.bfloat16 | 134217728 | 0.000180674154156198 | 0.00019567032965521 | 8.30012215584937 |
| sin | torch.bfloat16 | 268435456 | 0.000355360486234228 | 0.000386023331278314 | 8.62865913118873 |
| sin | torch.bfloat16 | 536870912 | 0.00070483615854755 | 0.000766805159704139 | 8.79197248964745 |
| sin | torch.float32 | 16777216 | 5.67E-05 | 5.64E-05 | -0.441348534920039 |
| sin | torch.float32 | 33554432 | 9.34E-05 | 9.30E-05 | -0.496458540364117 |
| sin | torch.float32 | 67108864 | 0.000181706990891447 | 0.000180556671693921 | -0.633062708199702 |
| sin | torch.float32 | 134217728 | 0.000356894995396336 | 0.000356046327700218 | -0.237791985616354 |
| sin | torch.float32 | 268435456 | 0.000708777321657787 | 0.000707602652255446 | -0.165731798471427 |
| sin | torch.float32 | 536870912 | 0.00141263716310884 | 0.00140912582476934 | -0.248566187496451 |
| exp | torch.float16 | 16777216 | 3.00E-05 | 3.04E-05 | 1.40099098901014 |
| exp | torch.float16 | 33554432 | 4.86E-05 | 5.03E-05 | 3.44611943643906 |
| exp | torch.float16 | 67108864 | 9.37E-05 | 9.55E-05 | 1.96412400380129 |
| exp | torch.float16 | 134217728 | 0.000180913504057874 | 0.000187193179347863 | 3.47109262113439 |
| exp | torch.float16 | 268435456 | 0.00035607748820136 | 0.000369079003576189 | 3.65131630210701 |
| exp | torch.float16 | 536870912 | 0.000707551507124056 | 0.000732363162872692 | 3.50669251620789 |
| exp | torch.bfloat16 | 16777216 | 2.98E-05 | 3.04E-05 | 1.74345594341654 |
| exp | torch.bfloat16 | 33554432 | 4.88E-05 | 5.04E-05 | 3.40217856534821 |
| exp | torch.bfloat16 | 67108864 | 9.32E-05 | 9.62E-05 | 3.29219958210226 |
| exp | torch.bfloat16 | 134217728 | 0.000180999826019009 | 0.000187239318620414 | 3.44723679499521 |
| exp | torch.bfloat16 | 268435456 | 0.000355944503098726 | 0.000369370992605885 | 3.77207384585864 |
| exp | torch.bfloat16 | 536870912 | 0.000707135167128096 | 0.000733066000975668 | 3.66702648277075 |
| exp | torch.float32 | 16777216 | 4.89E-05 | 5.63E-05 | 15.1245314346532 |
| exp | torch.float32 | 33554432 | 9.34E-05 | 9.31E-05 | -0.259945454477446 |
| exp | torch.float32 | 67108864 | 0.000181152504713585 | 0.000180474346658836 | -0.374357536939058 |
| exp | torch.float32 | 134217728 | 0.000356771342922002 | 0.000355627329554409 | -0.3206573034212 |
| exp | torch.float32 | 268435456 | 0.000708404501589636 | 0.00070713268360123 | -0.179532736671163 |
| exp | torch.float32 | 536870912 | 0.00141283582585553 | 0.00140944866385932 | -0.23974208002295 |

cc @msaroufim @ptrblck @eqy @manuelcandales @SherlockNoMad @angelayi

@pytorch-bot

pytorch-bot bot commented Dec 3, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141959

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 9292e7c with merge base aa95618:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@Aidyn-A Aidyn-A added the module: performance, module: cuda, topic: not user facing, and module: core aten labels Dec 3, 2024
@Aidyn-A Aidyn-A requested a review from ngimel December 3, 2024 12:04
@ngimel
Collaborator

ngimel commented Dec 3, 2024

This unconditionally sets items per thread to 8. Is it neutral/beneficial on A100?

using cpp_type = typename function_traits<func_t>::result_type;
uint16_t vec_size = 16 / static_cast<uint16_t>(sizeof(cpp_type));
if(vec_size == 16) {
vec_size = 4;
Collaborator

why is this vec_size set to 4 and not 8? Also, why is this conditional needed at all?

Collaborator Author

I used this condition to avoid vec8 for 1-byte data (bool and uint8), as some tests in test/test_binary_ufuncs.py were failing with mismatches.
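For reference, a self-contained sketch of the selection logic being discussed (illustrative only; the helper name is made up, and the real code lives in the CUDA loop helpers):

```cuda
// Illustrative sketch: target 16 bytes (128 bits) per thread, but cap
// 1-byte dtypes (bool, uint8) at vec4, since vec8/vec16 for those types
// produced mismatches in test/test_binary_ufuncs.py. Not the PR's code.
#include <cstdint>

template <typename scalar_t>
constexpr int pick_vec_size() {
  int vec_size = 16 / static_cast<int>(sizeof(scalar_t));  // 128-bit target
  return (vec_size >= 16) ? 4 : vec_size;  // 1-byte types fall back to vec4
}

static_assert(pick_vec_size<float>()   == 4, "4 x 4 B = 16 B");
static_assert(pick_vec_size<int16_t>() == 8, "8 x 2 B = 16 B");
static_assert(pick_vec_size<int8_t>()  == 4, "1-byte dtypes capped at vec4");
```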

@cpuhrsch cpuhrsch added the triaged label Dec 3, 2024
@Aidyn-A
Collaborator Author

Aidyn-A commented Dec 5, 2024

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased 128_bit_vectorization onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout 128_bit_vectorization && git pull --rebase)

@Aidyn-A
Collaborator Author

Aidyn-A commented Dec 5, 2024

This unconditionally sets items per thread to 8. Is it neutral/beneficial on A100?

It improves performance on A100 as well.

A100 benchmark results are here
| op | dtype | size | time after | time before | % improvement |
| ---- | ---- | ---- | ---- | ---- | ---- |
| relu | torch.float16 | 16777216 | 5.52E-05 | 5.56E-05 | 0.588398905051735 |
| relu | torch.float16 | 33554432 | 0.000101789004272885 | 0.000103080593463447 | 1.26888871719382 |
| relu | torch.float16 | 67108864 | 0.000197273265156481 | 0.000199393058816592 | 1.07454685176394 |
| relu | torch.float16 | 134217728 | 0.000388443004339933 | 0.000392856593761179 | 1.13622574533054 |
| relu | torch.float16 | 268435456 | 0.00077106234514051 | 0.000779860394282474 | 1.14102954156841 |
| relu | torch.float16 | 536870912 | 0.00153737427252862 | 0.00155346540527211 | 1.04666332922463 |
| relu | torch.bfloat16 | 16777216 | 5.50E-05 | 5.44E-05 | -1.00206226691304 |
| relu | torch.bfloat16 | 33554432 | 0.000104047668476899 | 0.000104331825342443 | 0.273102578561746 |
| relu | torch.bfloat16 | 67108864 | 0.000202173677583536 | 0.000203764117840264 | 0.786670290483738 |
| relu | torch.bfloat16 | 134217728 | 0.000398641969594691 | 0.000402679149475363 | 1.01273327662332 |
| relu | torch.bfloat16 | 268435456 | 0.000800862442702055 | 0.000801527769201332 | 0.0830762517757711 |
| relu | torch.bfloat16 | 536870912 | 0.00159262172463867 | 0.00157401508962115 | -1.16830222328839 |
| relu | torch.float32 | 16777216 | 0.000105040096160438 | 0.000104826926771137 | -0.202940969299259 |
| relu | torch.float32 | 33554432 | 0.000205832430058055 | 0.000205058294037978 | -0.376100121763734 |
| relu | torch.float32 | 67108864 | 0.000406690562764804 | 0.000405893247160647 | -0.196049694081057 |
| relu | torch.float32 | 134217728 | 0.000808526658349567 | 0.000807985611673858 | -0.0669176050191012 |
| relu | torch.float32 | 268435456 | 0.00159855828517013 | 0.00159612013440993 | -0.152521855650756 |
| relu | torch.float32 | 536870912 | 0.00310465165724357 | 0.00309997693532043 | -0.150571543581468 |
| sigmoid | torch.float16 | 16777216 | 5.53E-05 | 5.60E-05 | 1.37442506088692 |
| sigmoid | torch.float16 | 33554432 | 0.000104938064598375 | 0.000106826372858551 | 1.79945024467809 |
| sigmoid | torch.float16 | 67108864 | 0.000204297041313516 | 0.000208476403107246 | 2.04572800803136 |
| sigmoid | torch.float16 | 134217728 | 0.0004032871996363 | 0.000411878443426556 | 2.13030411032227 |
| sigmoid | torch.float16 | 268435456 | 0.000802099549522003 | 0.000818967301812437 | 2.10294997678111 |
| sigmoid | torch.float16 | 536870912 | 0.00159777545680602 | 0.00163260997376508 | 2.18018851213879 |
| sigmoid | torch.bfloat16 | 16777216 | 5.53E-05 | 5.60E-05 | 1.27645767333306 |
| sigmoid | torch.bfloat16 | 33554432 | 0.00010527613469296 | 0.000107017759647634 | 1.65433976062424 |
| sigmoid | torch.bfloat16 | 67108864 | 0.000203829206940201 | 0.000208569638844993 | 2.3256882445623 |
| sigmoid | torch.bfloat16 | 134217728 | 0.000402993574324581 | 0.00041222023881144 | 2.28953141556227 |
| sigmoid | torch.bfloat16 | 268435456 | 0.0008021114497549 | 0.000818662500629822 | 2.06343530939241 |
| sigmoid | torch.bfloat16 | 536870912 | 0.00159755944170886 | 0.00163286800185839 | 2.21015626885024 |
| sigmoid | torch.float32 | 16777216 | 0.000105733621037669 | 0.000104870129790571 | -0.816666674822408 |
| sigmoid | torch.float32 | 33554432 | 0.000206765667017963 | 0.000204619951546192 | -1.03775230323157 |
| sigmoid | torch.float32 | 67108864 | 0.000407101896901925 | 0.000405989535566833 | -0.273239044955886 |
| sigmoid | torch.float32 | 134217728 | 0.000807929680579238 | 0.000807615410950449 | -0.0388981413041223 |
| sigmoid | torch.float32 | 268435456 | 0.00161020312872198 | 0.00161120223088397 | 0.0620482064758621 |
| sigmoid | torch.float32 | 536870912 | 0.00312601159223252 | 0.00311755073360271 | -0.270659860981737 |
| tanh | torch.float16 | 16777216 | 5.50E-05 | 5.56E-05 | 1.01392536515987 |
| tanh | torch.float16 | 33554432 | 0.00010461484392484 | 0.000106048201107317 | 1.37012791751241 |
| tanh | torch.float16 | 67108864 | 0.000203798576775524 | 0.000206987477011151 | 1.56473135685284 |
| tanh | torch.float16 | 134217728 | 0.000403522617287106 | 0.000408950727432966 | 1.3451811406145 |
| tanh | torch.float16 | 268435456 | 0.000802456504768795 | 0.000812884057975478 | 1.29945400712861 |
| tanh | torch.float16 | 536870912 | 0.0015979148964915 | 0.00162050236637394 | 1.41355900317532 |
| tanh | torch.bfloat16 | 16777216 | 5.50E-05 | 5.58E-05 | 1.32688536364782 |
| tanh | torch.bfloat16 | 33554432 | 0.000104613498681121 | 0.000106201765851842 | 1.51822392974537 |
| tanh | torch.bfloat16 | 67108864 | 0.00020353009717332 | 0.000207337292118205 | 1.87058081225351 |
| tanh | torch.bfloat16 | 134217728 | 0.000403028550661272 | 0.000410060760461622 | 1.74484159715513 |
| tanh | torch.bfloat16 | 268435456 | 0.00080103049468663 | 0.000814526652296384 | 1.68484941575588 |
| tanh | torch.bfloat16 | 536870912 | 0.00159850344061852 | 0.0016241325582895 | 1.6033195187287 |
| tanh | torch.float32 | 16777216 | 0.000105516829838355 | 0.000104838620043463 | -0.642750351703303 |
| tanh | torch.float32 | 33554432 | 0.000205859490152862 | 0.000204412835753626 | -0.702738745812404 |
| tanh | torch.float32 | 67108864 | 0.000406586409856876 | 0.00040571458844675 | -0.214424631269161 |
| tanh | torch.float32 | 134217728 | 0.000808000409354766 | 0.000807130088408788 | -0.107712933793336 |
| tanh | torch.float32 | 268435456 | 0.00160988306419717 | 0.00161015692477425 | 0.017011209271689 |
| tanh | torch.float32 | 536870912 | 0.00313035196935137 | 0.00310210293779771 | -0.90242349199855 |
| gelu | torch.float16 | 16777216 | 5.58E-05 | 5.70E-05 | 2.15048278510104 |
| gelu | torch.float16 | 33554432 | 0.000105667704095443 | 0.000109181273728609 | 3.32511211750415 |
| gelu | torch.float16 | 67108864 | 0.000205330136749479 | 0.000213289943834146 | 3.87658977424148 |
| gelu | torch.float16 | 134217728 | 0.000404131753991048 | 0.000421578219781319 | 4.31702424221241 |
| gelu | torch.float16 | 268435456 | 0.00080191721725795 | 0.000838274239665932 | 4.53376254126332 |
| gelu | torch.float16 | 536870912 | 0.00159792177793052 | 0.00167205422702763 | 4.63930400855495 |
| gelu | torch.bfloat16 | 16777216 | 5.56E-05 | 5.71E-05 | 2.7460248048456 |
| gelu | torch.bfloat16 | 33554432 | 0.000105594181352192 | 0.000109837338742283 | 4.01836288302555 |
| gelu | torch.bfloat16 | 67108864 | 0.000204587148295508 | 0.000214392060620917 | 4.79253580056083 |
| gelu | torch.bfloat16 | 134217728 | 0.000403126029090749 | 0.000423436725719108 | 5.03829948023189 |
| gelu | torch.bfloat16 | 268435456 | 0.000799921082539691 | 0.000842164632760816 | 5.28096472804605 |
| gelu | torch.bfloat16 | 536870912 | 0.00159319071099162 | 0.00167886317811078 | 5.37741442553605 |
| gelu | torch.float32 | 16777216 | 0.00010601943358779 | 0.000104866456240416 | -1.08751509827605 |
| gelu | torch.float32 | 33554432 | 0.000206458692749341 | 0.000204429340859254 | -0.982933614013981 |
| gelu | torch.float32 | 67108864 | 0.000406732886201806 | 0.000404650604145394 | -0.511953207387805 |
| gelu | torch.float32 | 134217728 | 0.00080782449286845 | 0.000810090193731917 | 0.280469443978126 |
| gelu | torch.float32 | 268435456 | 0.00160836293879482 | 0.00160547635621495 | -0.179473333428126 |
| gelu | torch.float32 | 536870912 | 0.00312541264833676 | 0.00309180561453104 | -1.07528309337327 |
| sin | torch.float16 | 16777216 | 5.58E-05 | 5.68E-05 | 1.84432136815518 |
| sin | torch.float16 | 33554432 | 0.000105774857931667 | 0.000109315746360355 | 3.34757096149971 |
| sin | torch.float16 | 67108864 | 0.000204949174076319 | 0.000213460582825873 | 4.15293635015306 |
| sin | torch.float16 | 134217728 | 0.000403335266229179 | 0.000421439038796557 | 4.48851714273144 |
| sin | torch.float16 | 268435456 | 0.000803771014842722 | 0.00083896823020445 | 4.37901027926653 |
| sin | torch.float16 | 536870912 | 0.00159420336907109 | 0.00167318914706508 | 4.95456097549341 |
| sin | torch.bfloat16 | 16777216 | 5.58E-05 | 5.70E-05 | 2.05289950280472 |
| sin | torch.bfloat16 | 33554432 | 0.000105607685529523 | 0.000109381404601865 | 3.57333754018021 |
| sin | torch.bfloat16 | 67108864 | 0.000205246110757192 | 0.000213417379806439 | 3.98120530474448 |
| sin | torch.bfloat16 | 134217728 | 0.000403606125877963 | 0.000421721229536666 | 4.48831236624492 |
| sin | torch.bfloat16 | 268435456 | 0.000800546879569689 | 0.000839088215596146 | 4.81437589853244 |
| sin | torch.bfloat16 | 536870912 | 0.00159501155010528 | 0.0016719801351428 | 4.8255816725868 |
| sin | torch.float32 | 16777216 | 0.0001063231482274 | 0.000104805506351921 | -1.42738613442208 |
| sin | torch.float32 | 33554432 | 0.000206721118754811 | 0.00020440181510316 | -1.1219480939447 |
| sin | torch.float32 | 67108864 | 0.000406984757218096 | 0.000404441884408395 | -0.624807874153088 |
| sin | torch.float32 | 134217728 | 0.000807934802853399 | 0.000804758371992244 | -0.393154354774283 |
| sin | torch.float32 | 268435456 | 0.00160756231182151 | 0.00160335816649927 | -0.261523008553233 |
| sin | torch.float32 | 536870912 | 0.00312547411562668 | 0.00309828793009122 | -0.869825969747518 |
| exp | torch.float16 | 16777216 | 5.46E-05 | 5.49E-05 | 0.526774300459265 |
| exp | torch.float16 | 33554432 | 0.000104664100541009 | 0.000104676155994336 | 0.0115182314324169 |
| exp | torch.float16 | 67108864 | 0.000204241420659754 | 0.000204416768004497 | 0.0858529793694496 |
| exp | torch.float16 | 134217728 | 0.000404592706925339 | 0.000404040846559736 | -0.136398989936448 |
| exp | torch.float16 | 268435456 | 0.000805885789708959 | 0.00080300003497137 | -0.35808482720997 |
| exp | torch.float16 | 536870912 | 0.00160873257037666 | 0.00160096456400222 | -0.482864990581999 |
| exp | torch.bfloat16 | 16777216 | 5.49E-05 | 5.47E-05 | -0.489032335861872 |
| exp | torch.bfloat16 | 33554432 | 0.000104703112608857 | 0.0001046407657365 | -0.0595463408897046 |
| exp | torch.bfloat16 | 67108864 | 0.000204433790511555 | 0.000204202770772907 | -0.113004674065675 |
| exp | torch.bfloat16 | 134217728 | 0.000404823674923844 | 0.000403735734936264 | -0.268744160722478 |
| exp | torch.bfloat16 | 268435456 | 0.000805160548124048 | 0.000802269826332728 | -0.359024271377362 |
| exp | torch.bfloat16 | 536870912 | 0.00160597870126367 | 0.00160156599142485 | -0.274767643888552 |
| exp | torch.float32 | 16777216 | 0.000105293364160591 | 0.00010481502653824 | -0.454290378282118 |
| exp | torch.float32 | 33554432 | 0.000205790779242913 | 0.000204942189157009 | -0.412355737718517 |
| exp | torch.float32 | 67108864 | 0.000406338781532314 | 0.000405922169900603 | -0.102528149082948 |
| exp | torch.float32 | 134217728 | 0.000808410087807311 | 0.000807947479188442 | -0.057224498536812 |
| exp | torch.float32 | 268435456 | 0.00161091559049156 | 0.00161218053350846 | 0.0785232338904818 |
| exp | torch.float32 | 536870912 | 0.0031157395698958 | 0.00311509809560246 | -0.020588187136561 |

@Aidyn-A Aidyn-A added the ciflow/trunk label Dec 5, 2024
auto stream = at::cuda::getCurrentCUDAStream();
using cpp_type = typename function_traits<func_t>::result_type;
const uint16_t max_vec_size = memory::can_vectorize_up_to<func_t>(data, sizeof(cpp_type) == 2);
// Here 1-byte dtypes are intentionally avoided as the vec8 causes incorrect outputs
Collaborator

Sorry, what? Why does vec8 cause incorrect outputs? That has to be fixed. Some 1-byte ops already use 8 elements per thread, so why would 8-element vectorization be different?

Collaborator

@ngimel ngimel left a comment

We need more info on 8-vectorization causing incorrect results

@Aidyn-A Aidyn-A marked this pull request as draft December 9, 2024 20:05
@eqy
Collaborator

eqy commented Jan 23, 2025

The compiler issue looks real. For now, after discussing with @Aidyn-A, we'll gate the problematic cases via CUDA_VERSION and __CUDA_ARCH__.
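A rough sketch of that kind of gating (the version and architecture thresholds below are placeholders, not the values actually chosen):

```cuda
// Hypothetical compile-time gate for 128-bit vectorization: only enable it
// when both the CUDA toolkit version and the target architecture are known
// to be safe. The 12040 / 900 thresholds are illustrative placeholders.
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12040) && \
    defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
constexpr int kMaxVecBytes = 16;  // allow vec8 for 2-byte dtypes (128-bit)
#else
constexpr int kMaxVecBytes = 8;   // conservative 64-bit fallback
#endif
```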

@Aidyn-A
Collaborator Author

Aidyn-A commented Jan 24, 2025

I will need to rebase and resolve conflicts with #143269, which apparently did the same thing.

@ngimel
Collaborator

ngimel commented Jan 25, 2025

Thanks for digging in, ping me when PR is ready for review!

pytorchmergebot pushed a commit that referenced this pull request Jan 29, 2025
This is a re-base PR to my previous one #141959.

Description from the original PR:

This PR implements 128-bit vectorization. It improves the performance of contiguous elementwise ops by 4-10% on Hopper H100.


Pull Request resolved: #145746
Approved by: https://github.com/eqy, https://github.com/ngimel
@Aidyn-A
Collaborator Author

Aidyn-A commented Jan 29, 2025

Closing in favor of #145746.

@Aidyn-A Aidyn-A closed this Jan 29, 2025
pytorchmergebot pushed a commit that referenced this pull request Jan 31, 2025
This is a re-base PR to my previous one #141959.

Pull Request resolved: #145746
Approved by: https://github.com/eqy, https://github.com/ngimel

Co-authored-by: Aaron Gokaslan <[email protected]>

Labels

ciflow/trunk, module: core aten, module: cuda, module: performance, open source, topic: not user facing, triaged
