Commit ab38b44

Update on "[dynamo] Remove transformers ModelOutput hack"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
2 parents: 08c4317 + 683f0a6

119 files changed (+4847, −1028 lines)

.lintrunner.toml

Lines changed: 14 additions & 0 deletions
@@ -1733,3 +1733,17 @@ include_patterns = [
     'torch/**/not-exist.py'
 ]
 is_formatter = false
+
+# `import_linter` reports on importing disallowed third party libraries.
+[[linter]]
+code = 'IMPORT_LINTER'
+command = [
+    'python3',
+    'tools/linter/adapters/import_linter.py',
+    '--',
+    '@{{PATHSFILE}}'
+]
+include_patterns = [
+    'torch/_dynamo/**',
+]
+is_formatter = false
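
Usage note (assuming the standard lintrunner CLI that consumes this config): the new check can be run in isolation with `lintrunner --take IMPORT_LINTER`, and it only applies to files matching the `include_patterns` above, i.e. `torch/_dynamo/**`.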

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
@@ -1095,6 +1095,9 @@ if(NOT MSVC)
   append_cxx_flag_if_supported("-Wno-error=redundant-move" CMAKE_CXX_FLAGS)
 endif()
 else()
+  # Define export functions for AOTI.
+  add_compile_definitions(EXPORT_AOTI_FUNCTIONS)
+
   # skip unwanted includes from windows.h
   add_compile_definitions(WIN32_LEAN_AND_MEAN)
   # Windows SDK broke compatibility since version 25131, but introduced this
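
For context, a minimal sketch of how a Windows-only define like `EXPORT_AOTI_FUNCTIONS` is conventionally consumed; the `AOTI_API` macro and the symbol below are hypothetical, since the header that reads the define is not part of this diff:

// Hypothetical header sketch (not from this commit): the build-time define
// flips AOTI entry points between dllexport (building the library) and
// dllimport (consuming it).
#ifdef EXPORT_AOTI_FUNCTIONS
#define AOTI_API __declspec(dllexport)
#else
#define AOTI_API __declspec(dllimport)
#endif

extern "C" AOTI_API void aoti_example_entry_point();  // hypothetical symbol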

aten/src/ATen/Context.cpp

Lines changed: 4 additions & 0 deletions
@@ -543,6 +543,10 @@ void Context::setDisplayVmapFallbackWarnings(bool enabled) {
   display_vmap_fallback_warnings_ = enabled;
 }
 
+bool Context::isDefaultMobileCPUAllocatorSet() {
+  return prev_allocator_ptr_ != nullptr;
+}
+
 void Context::setDefaultMobileCPUAllocator() {
   TORCH_CHECK(prev_allocator_ptr_ == nullptr,
       "Already within the scope of another non-default cpu allocator."

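A minimal usage sketch for the new query; the helper below is hypothetical, but `at::globalContext()` is the existing accessor. Since `setDefaultMobileCPUAllocator()` fails a `TORCH_CHECK` when another allocator is already in scope, callers can now probe first:

#include <ATen/Context.h>

// Hypothetical call site: enter the mobile-CPU-allocator scope only if it is
// not already active, avoiding the TORCH_CHECK failure on re-entry.
void maybe_enable_mobile_allocator() {
  auto& ctx = at::globalContext();
  if (!ctx.isDefaultMobileCPUAllocatorSet()) {
    ctx.setDefaultMobileCPUAllocator();
  }
}
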
aten/src/ATen/Context.h

Lines changed: 1 addition & 0 deletions
@@ -347,6 +347,7 @@ class TORCH_API Context {
   void setDisplayVmapFallbackWarnings(bool enabled);
   bool areVmapFallbackWarningsEnabled() const;
 
+  bool isDefaultMobileCPUAllocatorSet();
   void setDefaultMobileCPUAllocator();
   void unsetDefaultMobileCPUAllocator();
   bool allowFP16ReductionCPU() const;

aten/src/ATen/autocast_mode.h

Lines changed: 14 additions & 4 deletions
@@ -124,6 +124,16 @@ TORCH_API inline void set_autocast_gpu_dtype(at::ScalarType dtype) {
 // deprecated other backend specific autocast APIs
 AT_FORALL_DEPRECATED_AUTOCAST_BAKCNEDS(DECLARE_DEPRECATED_AUTOCAST_APIS)
 
+const std::array<at::DeviceType, 8> _AUTOCAST_SUPPORTED_DTYPES{
+    at::kCPU,
+    at::kCUDA,
+    at::kXPU,
+    at::kIPU,
+    at::kHPU,
+    at::kXLA,
+    at::kPrivateUse1,
+    at::kMPS};
+
 namespace {
 inline bool is_autocast_eligible(
     const Tensor& tensor,
@@ -179,10 +189,10 @@ inline DispatchKey get_autocast_dispatch_key_from_device_type(
 }
 
 inline bool is_autocast_available(c10::DeviceType device_type) {
-  if (device_type == at::kCPU || device_type == at::kCUDA ||
-      device_type == at::kXPU || device_type == at::kIPU ||
-      device_type == at::kHPU || device_type == at::kXLA ||
-      device_type == at::kPrivateUse1 || device_type == at::kMPS) {
+  if (std::find(
+          _AUTOCAST_SUPPORTED_DTYPES.begin(),
+          _AUTOCAST_SUPPORTED_DTYPES.end(),
+          device_type) != _AUTOCAST_SUPPORTED_DTYPES.end()) {
     return true;
   } else {
     return false;
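
The refactor replaces a chain of `==` comparisons with a table lookup; behavior is unchanged (note the table holds device types despite the `_DTYPES` suffix). A standalone sketch of the pattern, with simplified stand-ins for the `at::` types:

#include <algorithm>
#include <array>

enum class DeviceType { CPU, CUDA, XPU, IPU, HPU, XLA, PrivateUse1, MPS, Vulkan };

// Mirror of _AUTOCAST_SUPPORTED_DTYPES: supporting a new backend is a
// one-entry table edit rather than another `||` clause.
constexpr std::array<DeviceType, 8> kAutocastSupported{
    DeviceType::CPU, DeviceType::CUDA, DeviceType::XPU, DeviceType::IPU,
    DeviceType::HPU, DeviceType::XLA, DeviceType::PrivateUse1, DeviceType::MPS};

bool is_autocast_available(DeviceType d) {
  // Anything not in the table (e.g. Vulkan here) returns false.
  return std::find(kAutocastSupported.begin(), kAutocastSupported.end(), d) !=
      kAutocastSupported.end();
}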

aten/src/ATen/mps/MPSDevice.h

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ enum class MacOSVersion : uint32_t {
   MACOS_VER_14_4_PLUS,
   MACOS_VER_15_0_PLUS,
   MACOS_VER_15_1_PLUS,
+  MACOS_VER_15_2_PLUS,
 };
 
 //-----------------------------------------------------------------

aten/src/ATen/mps/MPSDevice.mm

Lines changed: 3 additions & 0 deletions
@@ -73,6 +73,7 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
   static bool _macos_14_4_plus = is_os_version_at_least(14, 4);
   static bool _macos_15_0_plus = is_os_version_at_least(15, 0);
   static bool _macos_15_1_plus = is_os_version_at_least(15, 1);
+  static bool _macos_15_2_plus = is_os_version_at_least(15, 2);
 
   switch (version) {
     case MacOSVersion::MACOS_VER_13_1_PLUS:
@@ -89,6 +90,8 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
       return _macos_15_0_plus;
     case MacOSVersion::MACOS_VER_15_1_PLUS:
       return _macos_15_1_plus;
+    case MacOSVersion::MACOS_VER_15_2_PLUS:
+      return _macos_15_2_plus;
     default:
       return false;
   }
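
A hypothetical call site for the new enum value, assuming the existing `at::mps::is_macos_13_or_newer(MacOSVersion)` helper that these cached flags back:

// Gate an MPS code path on macOS 15.2+ (hypothetical usage).
if (at::mps::is_macos_13_or_newer(at::mps::MacOSVersion::MACOS_VER_15_2_PLUS)) {
  // take the 15.2-only path
} else {
  // fall back to the older implementation
}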

aten/src/ATen/native/FractionalMaxPool2d.cpp

Lines changed: 7 additions & 4 deletions
@@ -109,10 +109,13 @@ TORCH_META_FUNC(fractional_max_pool2d_backward)(
   /* get contiguous gradOutput */
   auto gradOutput = gradOutput_.contiguous();
 
-  TORCH_CHECK(outputW == gradOutput.size(widthDim),
-    "fractional_max_pool2d_backward(): gradOutput width unexpected");
-  TORCH_CHECK(outputH == gradOutput.size(heightDim),
-    "fractional_max_pool2d_backward(): gradOutput height unexpected");
+  auto expectedOutputShape = IntArrayRef(input.sizes().data(), ndims - 2).vec();
+  expectedOutputShape.push_back(outputH);
+  expectedOutputShape.push_back(outputW);
+  TORCH_CHECK(gradOutput.sizes().equals(expectedOutputShape),
+    "fractional_max_pool2d_backward(): gradOutput sizes unexpected");
+  TORCH_CHECK(indices.sizes().equals(expectedOutputShape),
+    "fractional_max_pool2d_backward(): indices sizes unexpected");
 
   /* resize */
   if (ndims == 3) {
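
The old checks validated only the two spatial extents of `gradOutput`; the new code builds the full expected shape (the input's leading dimensions followed by `outputH` and `outputW`) and checks both `gradOutput` and `indices` against it, so batch/channel mismatches are now caught too. A standalone sketch of the shape construction:

#include <cstdint>
#include <vector>

// Expected gradOutput/indices shape: every input dim except the trailing two
// spatial ones, then the pooled extents. A (N, C, H, W) input yields
// (N, C, outputH, outputW); a (C, H, W) input yields (C, outputH, outputW).
std::vector<int64_t> expected_output_shape(
    const std::vector<int64_t>& input_sizes,  // assumes a 3-D or 4-D input
    int64_t outputH,
    int64_t outputW) {
  std::vector<int64_t> shape(input_sizes.begin(), input_sizes.end() - 2);
  shape.push_back(outputH);
  shape.push_back(outputW);
  return shape;
}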

aten/src/ATen/native/cuda/EmbeddingBag.cu

Lines changed: 4 additions & 3 deletions
@@ -136,9 +136,10 @@ __global__ void EmbeddingBag_updateOutputKernel_sum_mean(
   accscalar_t weightFeatSum = 0;
   int64_t bag_size_ = 0;
   for (int64_t emb = begin; emb < end; emb++) {
-    bool pad = (input[emb] == padding_idx);
-    CUDA_KERNEL_ASSERT(input[emb] < numRows);
-    const int64_t weightRow = input[emb] * weight_stride0;
+    index_t input_idx = input[emb];
+    bool pad = (input_idx == padding_idx);
+    CUDA_KERNEL_ASSERT(0 <= input_idx && input_idx < numRows);
+    const int64_t weightRow = input_idx * weight_stride0;
     scalar_t weightValue = weightFeat[weightRow];
     weightValue = pad ? static_cast<scalar_t>(0) : weightValue;
     if (per_sample_weights) {
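
Caching `input[emb]` in a local guarantees the asserted value is the same one used to compute `weightRow` (one global-memory read instead of three), and the widened assert now also rejects negative indices, which previously passed the `< numRows` check and indexed out of bounds.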

aten/src/ATen/native/cuda/RowwiseScaledMM.cu

Lines changed: 18 additions & 4 deletions
@@ -43,6 +43,7 @@ static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled(
 }
 
 
+#include <cutlass/version.h>
 #include <cutlass/core_io.h>
 #include <cutlass/cutlass.h>
 #include <cutlass/gemm/device/gemm.h>
@@ -174,7 +175,11 @@ void f8f8bf16_rowwise_impl(
 
   // Implement rowwise scaling epilogue.
   constexpr int ColBroadcastStages = 0;
+#if CUTLASS_VERSION == 351
+  constexpr int RowBroadcastStages = 0;
+#else
   constexpr int RowBroadcastStages = PingPong::value ? 2 : 1;
+#endif
 
   using XScale = cutlass::epilogue::fusion::
       Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>;
@@ -191,15 +196,24 @@ void f8f8bf16_rowwise_impl(
 
   using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
 
+#if CUTLASS_VERSION == 351
+  using AccumScale = cutlass::epilogue::fusion::Sm90EVT<
+      Multiply,
+      WScale,
+      cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>;
+#else
+  using AccumScale = cutlass::epilogue::fusion::Sm90EVT<
+      Multiply,
+      XScale,
+      cutlass::epilogue::fusion::Sm90EVT<Multiply, WScale, Accum>>;
+#endif
+
   using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
       Cast,
       cutlass::epilogue::fusion::Sm90EVT<
          Add,
          Bias,
-         cutlass::epilogue::fusion::Sm90EVT<
-             Multiply,
-             XScale,
-             cutlass::epilogue::fusion::Sm90EVT<Multiply, WScale, Accum>>>>;
+         AccumScale>>;
 
   using CollectiveEpilogue =
       typename cutlass::epilogue::collective::CollectiveBuilder<
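
Both `AccumScale` branches encode the same elementwise math; only the association order of the two scale multiplies differs, presumably to sidestep an EVT limitation in CUTLASS 3.5.1 (the same version gate also forces `RowBroadcastStages` to 0 above). In either case the fused epilogue computes:

// CUTLASS 3.5.1 ordering:  out = cast(bias + w_scale * (x_scale * acc))
// default ordering:        out = cast(bias + x_scale * (w_scale * acc))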
