Skip to content

Commit 9f7dde6

Browse files
committed
Update
[ghstack-poisoned]
1 parent a67bf49 commit 9f7dde6

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

aten/src/ATen/native/mps/MetalShaderLibrary.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ constexpr bool has_size_type_v = has_size_type<T>::value;
4545

4646
class MetalKernelFunction {
4747
public:
48-
MetalKernelFunction(MTLComputePipelineState_t cps_) : cps(cps_) {}
49-
MetalKernelFunction(MetalKernelFunction&) = delete;
48+
MetalKernelFunction(MTLComputePipelineState_t cps_);
5049
~MetalKernelFunction();
50+
MetalKernelFunction(MetalKernelFunction&) = delete;
5151
// Shader properties
5252
uint64_t getMaxThreadsPerThreadgroup() const;
5353
uint64_t getThreadExecutionWidth() const;
@@ -152,6 +152,7 @@ class DynamicMetalShaderLibrary : public MetalShaderLibrary {
152152
// Compile right away
153153
getLibrary();
154154
}
155+
~DynamicMetalShaderLibrary();
155156
};
156157

157158
} // namespace at::native::mps

aten/src/ATen/native/mps/OperationUtils.mm

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -923,8 +923,17 @@ static dispatch_data_t getSectionData(const std::string& name) {
923923
return l;
924924
}
925925

926+
// DynamicMetalShaderLibrary implementation
927+
DynamicMetalShaderLibrary::~DynamicMetalShaderLibrary() {
928+
[library release];
929+
}
930+
926931
// MetalKernelFunction implementation
927-
MetalKernelFunction::~MetalKernelFunction() {}
932+
MetalKernelFunction::MetalKernelFunction(MTLComputePipelineState_t cps_) : cps([cps_ retain]) {}
933+
934+
MetalKernelFunction::~MetalKernelFunction() {
935+
[cps release];
936+
}
928937

929938
void MetalKernelFunction::runCommandBlock(std::function<void(void)> run) {
930939
dispatch_sync_with_rethrow(getCurrentMPSStream()->queue(), ^() {

0 commit comments

Comments
 (0)