@@ -210,12 +210,12 @@ __device__ __forceinline__ void welford_merge_block_vertical(C& count,
210210
211211template <typename input_scalar_t , typename stat_scalar_t , typename stat_accscalar_t , bool train, typename index_t >
212212__global__ void batch_norm_transform_input_kernel (
213- const GenericPackedTensorAccessor<input_scalar_t , 3 , RestrictPtrTraits, index_t > input,
213+ const GenericPackedTensorAccessor<const input_scalar_t , 3 , RestrictPtrTraits, index_t > input,
214214 GenericPackedTensorAccessor<input_scalar_t , 3 , RestrictPtrTraits, index_t > output,
215215 const GenericPackedTensorAccessor<typename std::conditional<train, stat_accscalar_t , stat_scalar_t >::type, 1 , RestrictPtrTraits, index_t > mean_,
216216 const GenericPackedTensorAccessor<typename std::conditional<train, stat_accscalar_t , stat_scalar_t >::type, 1 , RestrictPtrTraits, index_t > var_or_invstd,
217- const GenericPackedTensorAccessor<stat_scalar_t , 1 , RestrictPtrTraits, index_t > weight,
218- const GenericPackedTensorAccessor<stat_scalar_t , 1 , RestrictPtrTraits, index_t > bias,
217+ const GenericPackedTensorAccessor<const stat_scalar_t , 1 , RestrictPtrTraits, index_t > weight,
218+ const GenericPackedTensorAccessor<const stat_scalar_t , 1 , RestrictPtrTraits, index_t > bias,
219219 stat_accscalar_t epsilon) {
220220
221221 index_t plane = blockIdx .x ;
@@ -267,7 +267,7 @@ struct Var {
267267
268268template <typename VarTransform, typename input_scalar_t , typename stat_scalar_t , typename stat_accscalar_t , typename index_t >
269269__global__ void batch_norm_collect_statistics_kernel (
270- const GenericPackedTensorAccessor<input_scalar_t , 3 , RestrictPtrTraits, index_t > input,
270+ const GenericPackedTensorAccessor<const input_scalar_t , 3 , RestrictPtrTraits, index_t > input,
271271 const stat_accscalar_t epsilon,
272272 const stat_accscalar_t momentum,
273273 GenericPackedTensorAccessor<stat_accscalar_t , 1 , RestrictPtrTraits, index_t > save_mean,
@@ -582,7 +582,7 @@ __global__ void batch_norm_backward_elemt_kernel(
582582template <typename scalar_t , int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t >
583583static GenericPackedTensorAccessor<scalar_t , dim, PtrTraits, index_t > get_packed_accessor (
584584 const Tensor& t, c10::string_view var_name) {
585- constexpr auto expect_type = c10::CppTypeToScalarType<scalar_t >::value;
585+ constexpr auto expect_type = c10::CppTypeToScalarType<typename std::remove_const< scalar_t >::type >::value;
586586 const auto actual_type = t.scalar_type ();
587587 TORCH_CHECK (actual_type == expect_type, " Expected " , var_name,
588588 " to have type " , expect_type, " but got " , actual_type);
@@ -670,7 +670,7 @@ void batch_norm_stats_cuda_template(
670670 resize_output (out_mean, {n_input});
671671 resize_output (out_invstd, {n_input});
672672 auto input = get_packed_accessor<
673- scalar_t , 3 , RestrictPtrTraits, index_t >(input_reshaped, " input" );
673+ const scalar_t , 3 , RestrictPtrTraits, index_t >(input_reshaped, " input" );
674674 TORCH_INTERNAL_ASSERT (out_invstd.dim () == 1 && out_invstd.is_contiguous () &&
675675 out_invstd.sizes ()[0 ]);
676676 TORCH_INTERNAL_ASSERT (out_mean.dim () == 1 && out_mean.is_contiguous () &&
@@ -700,13 +700,13 @@ void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_,
700700 auto output_reshaped = output_.view ({input_.size (0 ), input_.size (1 ), -1 });
701701
702702 auto input = get_packed_accessor<
703- input_scalar_t , 3 , RestrictPtrTraits, index_t >(input_reshaped, " input" );
703+ const input_scalar_t , 3 , RestrictPtrTraits, index_t >(input_reshaped, " input" );
704704 auto output = get_packed_accessor<
705705 input_scalar_t , 3 , RestrictPtrTraits, index_t >(output_reshaped, " output" );
706706 auto weight = packed_accessor_or_dummy<
707- stat_scalar_t , 1 , RestrictPtrTraits, index_t >(weight_, " weight" );
707+ const stat_scalar_t , 1 , RestrictPtrTraits, index_t >(weight_, " weight" );
708708 auto bias = packed_accessor_or_dummy<
709- stat_scalar_t , 1 , RestrictPtrTraits, index_t >(bias_, " bias" );
709+ const stat_scalar_t , 1 , RestrictPtrTraits, index_t >(bias_, " bias" );
710710 auto mean = packed_accessor_or_dummy<
711711 stat_accscalar_t , 1 , RestrictPtrTraits, index_t >(mean_, " mean" );
712712 auto invstd = packed_accessor_or_dummy<
0 commit comments