-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathmulti_GPU_demo.py
More file actions
127 lines (100 loc) · 3.58 KB
/
multi_GPU_demo.py
File metadata and controls
127 lines (100 loc) · 3.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# multi_GPU_demo.py
# ================
#
# This Jython script demonstrates how to process images
# on multiple GPUs (or: OpenCL devices) in parallel.
#
# Author: Robert Haase, [email protected]
# August 2019
#
########################################################
from ij import IJ;
from java.lang import Thread;
# The Processor class extends Thread so that we can run it
# in parallel
class Processor(Thread):
    """Worker thread running one background-subtracted maximum projection
    on the OpenCL device wrapped by the given CLIJ instance.

    Usage: construct with a CLIJ instance, call setImage(), then start().
    Poll isWorking() / isFinished() to follow progress. A Java thread can
    only be started once, so a new Processor is needed for every image.
    If run() fails, the exception propagates out of the thread, the GPU
    buffers are still released, and once the thread has died both
    isWorking() and isFinished() return False.
    """
    # the CLIJ instance doing the heavy work
    clij = None
    # the image which should be processed
    image = None
    # a flag that says run() is currently executing
    working = False
    # a flag that says the last run() completed successfully
    finished = False

    def __init__(self, clij):
        """Create a processor bound to one CLIJ instance (one device)."""
        # explicitly initialise the java.lang.Thread part of this object
        Thread.__init__(self)
        self.clij = clij

    def setImage(self, image):
        """Set the image which should be processed."""
        self.image = image

    def run(self):
        """Do the actual processing; invoked on the new thread by start()."""
        # set status flags and initialize
        self.finished = False
        self.working = True
        clij = self.clij
        # GPU buffers allocated so far, in allocation order; released in the
        # finally clause so a failing operation cannot leak device memory
        gpu_images = []
        try:
            # push the image to GPU memory
            input_image = clij.push(self.image)
            gpu_images.append(input_image)
            # allocate more memory on the GPU for temp and result images
            temp_image = clij.create(input_image)
            gpu_images.append(temp_image)
            backgroundSubtracted_image = clij.create(input_image)
            gpu_images.append(backgroundSubtracted_image)
            max_projection_image = clij.create([input_image.getWidth(), input_image.getHeight()], input_image.getNativeType())
            gpu_images.append(max_projection_image)
            # perform a background-subtracted maximum projection
            clij.op().blur(input_image, temp_image, 5, 5, 1)
            clij.op().subtract(input_image, temp_image, backgroundSubtracted_image)
            clij.op().maximumZProjection(backgroundSubtracted_image, max_projection_image)
            # pull the result back from GPU memory; the transfer is part of the
            # work even though the result is not displayed by default
            result = clij.pull(max_projection_image)
            # to display the result, uncomment:
            # result.show()
            # IJ.run("Enhance Contrast", "saturated=0.35")
            # mark success BEFORE clearing 'working', so a poller never sees
            # this processor as idle-but-unfinished after a successful run
            self.finished = True
        finally:
            # clean up by the end, even if an operation above raised
            for gpu_image in gpu_images:
                gpu_image.close()
            self.working = False

    def isWorking(self):
        """Return True while this processor is started and not yet done.

        isAlive() covers the gap between start() and the first statement of
        run(), in which the 'working' flag is not yet set; without it a poller
        could mistake a just-started processor for an idle one and call
        start() on it a second time.
        """
        return self.working or self.isAlive()

    def isFinished(self):
        """Return True if the last run() completed successfully."""
        return self.finished

    def getCLIJ(self):
        """Return the CLIJ instance (device) this processor works on."""
        return self.clij
# Load the example dataset. To use a local copy instead, e.g.:
# imp = IJ.openImage("C:/structure/data/2018-05-23-16-18-13-89-Florence_multisample/processed/tif/000116.raw.tif")
imp = IJ.openImage("https://bds.mpi-cbg.de/CLIJ_benchmarking_data/000461.raw.tif")
if imp is None:
    # IJ.openImage returns null (None) on failure; fail here instead of
    # letting every worker thread crash on a missing image
    raise RuntimeError("Could not open the example image")

from net.haesleinhuepf.clij import CLIJ

# number of images to process before the benchmark stops
NUMBER_OF_IMAGES = 10

# print out available OpenCL devices
device_names = CLIJ.getAvailableDeviceNames()
print("Available devices:")
for name in device_names:
    print(name)
if len(device_names) == 0:
    # without any processor the loop below would never terminate
    raise RuntimeError("No OpenCL devices available")

# initialize one processor per device
processors = [Processor(CLIJ(i)) for i in range(len(device_names))]

from java.lang import System
startTime = System.currentTimeMillis()

# Loop until NUMBER_OF_IMAGES images were processed. Images are only handed
# out while fewer than NUMBER_OF_IMAGES were started, so no extra work is left
# running when the time is measured and the count never overshoots.
started_images = 0
processed_images = 0
while processed_images < NUMBER_OF_IMAGES:
    # go through all processors and see if one is doing nothing
    for j, processor in enumerate(processors):
        state = processor.getState()
        if state == Thread.State.TERMINATED:
            # this thread has run and cannot be restarted: check the outcome
            if not processor.isFinished():
                raise RuntimeError("Processing failed on " + str(processor.getCLIJ()))
            processed_images += 1
            # update log
            IJ.log("\\Clear")
            IJ.log("Processed images: " + str(processed_images))
            # replace it with a new (not yet started) processor for the same device
            processor = Processor(processor.getCLIJ())
            processors[j] = processor
            state = processor.getState()
        # start a processor that was never started, if work is left to hand out
        if state == Thread.State.NEW and started_images < NUMBER_OF_IMAGES:
            processor.setImage(imp)
            processor.start()
            started_images += 1
    # wait a moment, but not after the last image, so the sleep is not
    # added to the measured time
    if processed_images < NUMBER_OF_IMAGES:
        Thread.sleep(100)

print("Processing on " + str(len(device_names)) + " devices took " + str(System.currentTimeMillis() - startTime) + " ms")