@@ -22,6 +22,9 @@ namespace paddle {
2222namespace primitive {
2323namespace details {
2424
25+ // empty_shape means x.shape=[]
26+ static std::vector<int64_t > empty_shape;
27+
2528template <typename T>
2629Tensor mean_decomp (const Tensor& x, const IntArray& axis, bool keepdim) {
2730 auto org_dtype = x.dtype ();
@@ -345,62 +348,66 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_decomp(
345348
346349 // cast dtype to float32 if dtype =float16 or bfloat16
347350 if (need_cast) {
348- x_cast = cast<T>(x_cast, phi:: DataType::FLOAT32);
351+ x_cast = cast<T>(x_cast, DataType::FLOAT32);
349352 }
350353
351354 auto x_dim = common::vectorize<int64_t >(x.dims ());
352355 for (size_t i = begin_norm_axis; i < x_dim.size (); i++) {
353356 axis.push_back (static_cast <int64_t >(i));
354357 }
355- auto mean_ = mean_decomp<T>(x_cast, IntArray ( axis) , true );
358+ auto mean_ = mean_decomp<T>(x_cast, axis, true );
356359 auto difference = x_cast - mean_;
357360 auto var_tmp1 = difference * difference;
358- auto variance = mean_decomp<T>(var_tmp1, IntArray ( axis) , true );
361+ auto variance = mean_decomp<T>(var_tmp1, axis, true );
359362 auto var_tmp3 = variance + epsilon;
360363 auto rsqrt_var = elementwise_pow<T>(
361- var_tmp3,
362- full<T>(common::vectorize (var_tmp3.dims ()), -0.5 , var_tmp3.dtype ()));
364+ var_tmp3, full<T>(empty_shape, -0.5 , var_tmp3.dtype ()));
363365 auto out = difference * rsqrt_var;
364366
365367 auto scale_ptr = scale.get_ptr ();
366368 auto bias_ptr = bias.get_ptr ();
367369
368- std::vector<int64_t > slice_shape;
369- for (int64_t i = begin_norm_axis; i < static_cast <int64_t >(x_dim.size ());
370- i++) {
371- slice_shape.push_back (x_dim[i]);
370+ std::vector<int64_t > slice_shape_l;
371+ std::vector<int64_t > slice_shape_r;
372+ for (int64_t i = 0 ; i < static_cast <int64_t >(x_dim.size ()); i++) {
373+ if (i < begin_norm_axis) {
374+ slice_shape_l.push_back (x_dim[i]);
375+ } else {
376+ slice_shape_r.push_back (x_dim[i]);
377+ }
372378 }
373379 Tensor scale_cast;
374380 if (scale_ptr) {
375- if (slice_shape != scale_ptr->shape ()) {
376- scale_cast = reshape<T>(*scale_ptr, slice_shape );
381+ if (slice_shape_r != scale_ptr->shape ()) {
382+ scale_cast = reshape<T>(*scale_ptr, slice_shape_r );
377383 } else {
378384 scale_cast = *scale_ptr;
379385 }
380386 if (need_cast) {
381- scale_cast = cast<T>(scale_cast, phi:: DataType::FLOAT32);
387+ scale_cast = cast<T>(scale_cast, DataType::FLOAT32);
382388 }
383389 out = out * scale_cast;
384390 }
385391 Tensor bias_cast;
386392 if (bias_ptr) {
387- if (slice_shape != bias_ptr->shape ()) {
388- bias_cast = reshape<T>(*bias_ptr, slice_shape );
393+ if (slice_shape_r != bias_ptr->shape ()) {
394+ bias_cast = reshape<T>(*bias_ptr, slice_shape_r );
389395 } else {
390396 bias_cast = *bias_ptr;
391397 }
392398 if (need_cast) {
393- bias_cast = cast<T>(bias_cast, phi:: DataType::FLOAT32);
399+ bias_cast = cast<T>(bias_cast, DataType::FLOAT32);
394400 }
395401 out = out + bias_cast;
396402 }
397- mean_ = reshape<T>(mean_, std::vector< int64_t >({- 1 }) );
398- variance = reshape<T>(variance, std::vector< int64_t >({- 1 }) );
403+ mean_ = reshape<T>(mean_, slice_shape_l );
404+ variance = reshape<T>(variance, slice_shape_l );
399405
406+ // same as LayerNormInferMeta
407+ // x: float32 --> out: float32, mean: float32, variance: float32
408+ // x: float16 --> out: float16, mean: float32, variance: float32
400409 if (need_cast) {
401410 out = cast<T>(out, org_dtype);
402- mean_ = cast<T>(mean_, org_dtype);
403- variance = cast<T>(variance, org_dtype);
404411 }
405412
406413 return std::make_tuple (out, mean_, variance);
0 commit comments