@@ -48,31 +48,73 @@ const squeezeShape = (shape: readonly number[], adjustedPerm: number[]): { newSh
4848 return { newShape, newPerm } ;
4949} ;
5050
/**
 * Decides whether a transpose can be treated as a plain reshape.
 *
 * As long as the dims with values other than 1 stay in the same relative order,
 * the transpose is a reshape; dims equal to 1 are ignored because moving them
 * does not change that order.
 * Example: shape = (1, 1, 1024, 4096) with perm = (2, 0, 3, 1) qualifies.
 *
 * @param perm - Axis indices into `shape`, visited in order.
 * @param shape - Dimensions looked up through the entries of `perm`.
 * @returns `true` when the axes whose dim is not 1 appear in non-decreasing
 *          order within `perm`; `false` as soon as one appears out of order.
 */
const isTransposeReshape = (perm: readonly number[], shape: readonly number[]): boolean => {
  // Largest axis index seen so far among axes whose dim is not 1.
  let lastPermutedAxis = 0;
  for (const axis of perm) {
    if (shape[axis] === 1) {
      // Unit dims may sit anywhere in the permutation.
      continue;
    }
    if (axis < lastPermutedAxis) {
      // A non-unit axis moved ahead of a later one: a real transpose is needed.
      return false;
    }
    lastPermutedAxis = axis;
  }
  return true;
};
66+
5167export const createTransposeProgramInfo = ( inputTensor : TensorView , permAttr : number [ ] ) : ProgramInfo => {
5268 const inputDataType = inputTensor . dataType ;
5369 const inputRank = inputTensor . dims . length ;
5470 const perm = getAdjustedPerm ( inputRank , permAttr ) ;
5571 const outputShape = getOutputShape ( inputTensor . dims , perm ) ;
72+ let newInputShape = inputTensor . dims ;
73+ let newOutputShape = outputShape ;
74+ const transposeAsReshape = isTransposeReshape ( perm , inputTensor . dims ) ;
75+ let getShaderSource ;
76+ if ( transposeAsReshape ) {
77+ getShaderSource = ( shaderHelper : ShaderHelper ) => {
78+ const input = inputVariable ( 'input' , inputDataType , newInputShape , 4 ) ;
79+ const output = outputVariable ( 'output' , inputDataType , newOutputShape , 4 ) ;
80+ return `
81+ ${ shaderHelper . registerUniform ( 'output_size' , 'u32' ) . declareVariables ( input , output ) }
82+ ${ shaderHelper . mainStart ( ) }
83+ ${ shaderHelper . guardAgainstOutOfBoundsWorkgroupSizes ( 'uniforms.output_size' ) }
84+ output[global_idx] = input[global_idx];
85+ }` ;
86+ } ;
87+
88+ return {
89+ name : 'TransposeCopy' ,
90+ shaderCache : { inputDependencies : [ 'type' ] } ,
91+ getRunData : ( ) => {
92+ const outputSize = ShapeUtil . size ( outputShape ) ;
93+ return {
94+ outputs : [ { dims : outputShape , dataType : inputTensor . dataType } ] ,
95+ dispatchGroup : { x : Math . ceil ( outputSize / 64 /* workgroup size */ / 4 /* components */ ) } ,
96+ programUniforms : [ { type : DataType . uint32 , data : Math . ceil ( outputSize / 4 ) } ] ,
97+ } ;
98+ } ,
99+ getShaderSource,
100+ } ;
101+ }
56102 const { newShape, newPerm } = squeezeShape ( inputTensor . dims , perm ) ;
57103 const channelsLast = ShapeUtil . areEqual ( newPerm , [ 2 , 3 , 1 ] ) ;
58104 const channelsFirst = ShapeUtil . areEqual ( newPerm , [ 3 , 1 , 2 ] ) ;
59- const useShared = ( newShape . length === 2 && newPerm [ 0 ] > newPerm [ 1 ] ) || channelsLast || channelsFirst ;
60- let newInputShape = useShared ? newShape : inputTensor . dims ;
61- let newOutputShape = outputShape ;
105+ const useShared = newShape . length === 2 || channelsLast || channelsFirst ;
62106 if ( useShared ) {
63107 newInputShape = channelsLast
64108 ? [ newShape [ 0 ] , newShape [ 1 ] * newShape [ 2 ] ]
65109 : channelsFirst
66110 ? [ newShape [ 0 ] * newShape [ 1 ] , newShape [ 2 ] ]
67111 : newShape ;
68112 newOutputShape = [ newInputShape [ 1 ] , newInputShape [ 0 ] ] ;
69- }
70- const input = inputVariable ( 'a' , inputDataType , newInputShape . length ) ;
71- const output = outputVariable ( 'output' , inputDataType , newOutputShape . length ) ;
72- const tileSize = 16 ;
73- let getShaderSource ;
74- if ( useShared ) {
75- getShaderSource = ( shaderHelper : ShaderHelper ) => `
113+ const tileSize = 16 ;
114+ getShaderSource = ( shaderHelper : ShaderHelper ) => {
115+ const input = inputVariable ( 'a' , inputDataType , newInputShape . length ) ;
116+ const output = outputVariable ( 'output' , inputDataType , newOutputShape . length ) ;
117+ return `
76118 ${ shaderHelper . registerUniform ( 'output_size' , 'u32' ) . declareVariables ( input , output ) }
77119 var<workgroup> tile : array<array<${ output . type . value } , ${ tileSize + 1 } >, ${ tileSize } >;
78120 ${ shaderHelper . mainStart ( [ tileSize , tileSize , 1 ] ) }
@@ -92,8 +134,29 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
92134 ${ output . setByIndices ( `${ output . type . indices } (output_row, output_col)` , 'tile[local_id.x][local_id.y]' ) }
93135 }
94136 }` ;
95- } else {
96- getShaderSource = ( shaderHelper : ShaderHelper ) => `
137+ } ;
138+ return {
139+ name : 'TransposeShared' ,
140+ shaderCache : { inputDependencies : [ 'type' ] } ,
141+ getRunData : ( ) => {
142+ const outputSize = ShapeUtil . size ( outputShape ) ;
143+ return {
144+ outputs : [ { dims : outputShape , dataType : inputTensor . dataType } ] ,
145+ dispatchGroup : { x : Math . ceil ( newOutputShape [ 1 ] / tileSize ) , y : Math . ceil ( newOutputShape [ 0 ] / tileSize ) } ,
146+ programUniforms : [
147+ { type : DataType . uint32 , data : outputSize } ,
148+ ...createTensorShapeVariables ( newInputShape , newOutputShape ) ,
149+ ] ,
150+ } ;
151+ } ,
152+ getShaderSource,
153+ } ;
154+ }
155+
156+ getShaderSource = ( shaderHelper : ShaderHelper ) => {
157+ const input = inputVariable ( 'a' , inputDataType , newInputShape . length ) ;
158+ const output = outputVariable ( 'output' , inputDataType , newOutputShape . length ) ;
159+ return `
97160 ${ shaderHelper . registerUniform ( 'output_size' , 'u32' ) . declareVariables ( input , output ) }
98161
99162 ${ permFunctionBody ( perm , inputRank , input , output ) }
@@ -106,17 +169,15 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
106169
107170 ${ output . setByOffset ( 'global_idx' , input . getByIndices ( 'aIndices' ) ) }
108171 }` ;
109- }
172+ } ;
110173 return {
111- name : useShared ? 'TransposeShared' : 'Transpose' ,
174+ name : 'Transpose' ,
112175 shaderCache : { hint : `${ permAttr } ` , inputDependencies : [ 'rank' ] } ,
113176 getRunData : ( ) => {
114177 const outputSize = ShapeUtil . size ( outputShape ) ;
115178 return {
116179 outputs : [ { dims : outputShape , dataType : inputTensor . dataType } ] ,
117- dispatchGroup : useShared
118- ? { x : Math . ceil ( newOutputShape [ 1 ] / tileSize ) , y : Math . ceil ( newOutputShape [ 0 ] / tileSize ) }
119- : { x : Math . ceil ( outputSize / 64 /* workgroup size */ ) } ,
180+ dispatchGroup : { x : Math . ceil ( outputSize / 64 /* workgroup size */ ) } ,
120181 programUniforms : [
121182 { type : DataType . uint32 , data : outputSize } ,
122183 ...createTensorShapeVariables ( newInputShape , newOutputShape ) ,
0 commit comments