Skip to content

Commit 96c782c

Browse files
committed
Preprocess custom_symbols in tqdm.__init__
Signed-off-by: Stephen L. <[email protected]>
1 parent 23ebf76 commit 96c782c

1 file changed

Lines changed: 66 additions & 53 deletions

File tree

tqdm/_tqdm.py

Lines changed: 66 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def print_status(s):
121121
@staticmethod
122122
def format_meter(n, total, elapsed, ncols=None, prefix='',
123123
ascii=False, unit='it', unit_scale=False, rate=None,
124-
bar_format=None):
124+
bar_format=None, custom_symbols=None):
125125
"""
126126
Return a string-based progress bar given some parameters
127127
@@ -167,27 +167,6 @@ def format_meter(n, total, elapsed, ncols=None, prefix='',
167167
-------
168168
out : Formatted meter and stats, ready to display.
169169
"""
170-
171-
def extract_symbols(s, start_tag, end_tag):
172-
"""
173-
Extract custom symbols enclosed by tags, with the first character being the separator.
174-
Eg, extract_symbols('before{start},1,2,3,4,#{end}after', '{start}', '{end}')
175-
176-
Returns
177-
-------
178-
out, out2 : list of symbols, input string without tagged part
179-
"""
180-
start = s.find(start_tag)
181-
start_content = start+len(start_tag)
182-
end_content = s.find(end_tag)
183-
end = end_content+len(end_tag)
184-
185-
sep = s[start_content:start_content+1]
186-
return s[start_content+1:end_content].split(sep), s[:start] + s[end:]
187-
188-
# Custom symbols variables
189-
c_symbols = None
190-
looping = False
191170

192171
# sanity check: total
193172
if total and n > total:
@@ -256,22 +235,6 @@ def extract_symbols(s, start_tag, end_tag):
256235
# 'bar': full_bar # replaced by procedure below
257236
}
258237

259-
# Custom symbols extraction
260-
for tag in ['bar_symbols', 'bar_symbols_ascii', 'bar_symbols_loop', 'bar_symbols_loop_ascii']:
261-
start_tag = '{' + tag + '}'
262-
end_tag = '{/' + tag + '}'
263-
# Check if tag is found in the template
264-
if start_tag in bar_format and end_tag in bar_format:
265-
# Get ascii symbols if ascii env, else unicode
266-
if (ascii and 'ascii' in tag) or (not ascii and not 'ascii' in tag):
267-
c_symbols, bar_format = extract_symbols(bar_format, start_tag, end_tag)
268-
# Looping symbol?
269-
if 'loop' in tag:
270-
looping = True
271-
# Need to clean all tags from template
272-
else:
273-
_, bar_format = extract_symbols(bar_format, start_tag, end_tag)
274-
275238
# Interpolate supplied bar format with the dict
276239
if '{bar}' in bar_format:
277240
# Format left/right sides of the bar, and format the bar
@@ -289,28 +252,33 @@ def extract_symbols(s, start_tag, end_tag):
289252
else 10
290253

291254
# custom symbols format
292-
# need to provide both ascii and unicode versions of custom symbols,
293-
# eg, if ascii env but user provided only unicode symbols, then
294-
# will revert to default ascii bar.
295-
if c_symbols:
255+
# need to provide both ascii and unicode versions of custom symbols
256+
if custom_symbols:
257+
# get ascii or unicode template
258+
if ascii:
259+
c_symb = custom_symbols[1]
260+
else:
261+
c_symb = custom_symbols[2]
296262
# looping symbols: just update the symbol animation at each iteration
297-
if looping:
263+
if custom_symbols[0] == 'loop':
298264
# increment one step in the animation at each step
299-
bar = c_symbols[divmod(n, len(c_symbols))[1]]
265+
bar = c_symb[divmod(n, len(c_symb))[1]]
300266
frac_bar = ''
301267

302268
bar_length = N_BARS # avoid the filling
303269
frac_bar_length = len(frac_bar)
304270
# normal progress symbols
305271
else:
272+
nb_symb = len(c_symb)
273+
len_filler = len(c_symb[-1])
306274
bar_length, frac_bar_length = divmod(
307-
int((frac/len(c_symbols[-1])) * N_BARS * len(c_symbols)), len(c_symbols))
275+
int((frac/len_filler) * N_BARS * nb_symb), nb_symb)
308276

309-
bar = c_symbols[-1] * bar_length # last symbol is always the filler
310-
frac_bar = c_symbols[frac_bar_length] if frac_bar_length \
277+
bar = c_symb[-1] * bar_length # last symbol is always the filler
278+
frac_bar = c_symb[frac_bar_length] if frac_bar_length \
311279
else ' '
312280
# update real bar length (if symbols > 1 char) for correct filler
313-
bar_length = bar_length * len(c_symbols[-1])
281+
bar_length = bar_length * len_filler
314282

315283
# ascii format
316284
elif ascii:
@@ -642,6 +610,46 @@ def __init__(self, iterable=None, desc=None, total=None, leave=True,
642610
if ascii is None:
643611
ascii = not _supports_unicode(file)
644612

613+
614+
# Custom symbols extraction
615+
custom_symbols = None
616+
if bar_format:
617+
looping = None
618+
c_symbols_ascii = None
619+
c_symbols_unicode = None
620+
found_tag = False
621+
for tag in ['bar_symbols', 'bar_symbols_ascii', 'bar_symbols_loop', 'bar_symbols_loop_ascii']:
622+
start_tag = '{' + tag + '}'
623+
end_tag = '{/' + tag + '}'
624+
# Check if tag is found in the template
625+
if start_tag in bar_format and end_tag in bar_format:
626+
found_tag = True
627+
# Extract custom symbols enclosed by tags, with the first character being the separator.
628+
# Eg, extract_symbols('before{start},1,2,3,4,#{end}after', '{start}', '{end}')
629+
start = bar_format.find(start_tag)
630+
start_content = start+len(start_tag)
631+
end_content = bar_format.find(end_tag)
632+
end = end_content+len(end_tag)
633+
sep = bar_format[start_content:start_content+1]
634+
c_symbols = bar_format[start_content+1:end_content].split(sep)
635+
# Cleanup all weird tags from bar_format else .format() crash
636+
bar_format = bar_format[:start] + bar_format[end:]
637+
638+
if 'ascii' in tag:
639+
c_symbols_ascii = c_symbols
640+
else:
641+
c_symbols_unicode = c_symbols
642+
643+
# Looping symbol?
644+
if 'loop' in tag:
645+
looping = True
646+
else:
647+
looping = False
648+
649+
# Compile the ascii/unicode bars in a nice argument for format_meter
650+
if found_tag:
651+
custom_symbols = ['loop' if looping else 'bar', c_symbols_ascii, c_symbols_unicode]
652+
645653
if bar_format and not ascii:
646654
# Convert bar format into unicode since terminal uses unicode
647655
bar_format = _unicode(bar_format)
@@ -670,6 +678,7 @@ def __init__(self, iterable=None, desc=None, total=None, leave=True,
670678
self.avg_time = None
671679
self._time = time
672680
self.bar_format = bar_format
681+
self.custom_symbols = custom_symbols
673682

674683
# Init the iterations counters
675684
self.last_print_n = initial
@@ -686,7 +695,8 @@ def __init__(self, iterable=None, desc=None, total=None, leave=True,
686695
self.moveto(self.pos)
687696
self.sp(self.format_meter(self.n, total, 0,
688697
(dynamic_ncols(file) if dynamic_ncols else ncols),
689-
self.desc, ascii, unit, unit_scale, None, bar_format))
698+
self.desc, ascii, unit, unit_scale, None,
699+
bar_format, custom_symbols))
690700
if self.pos:
691701
self.moveto(-self.pos)
692702

@@ -713,7 +723,8 @@ def __repr__(self):
713723
time() - self.last_print_t,
714724
self.ncols, self.desc, self.ascii, self.unit,
715725
self.unit_scale, 1 / self.avg_time
716-
if self.avg_time else None, self.bar_format)
726+
if self.avg_time else None,
727+
self.bar_format, self.custom_symbols)
717728

718729
def __lt__(self, other):
719730
# try:
@@ -770,6 +781,7 @@ def __iter__(self):
770781
smoothing = self.smoothing
771782
avg_time = self.avg_time
772783
bar_format = self.bar_format
784+
custom_symbols = self.custom_symbols
773785
_time = self._time
774786
format_meter = self.format_meter
775787

@@ -808,7 +820,8 @@ def __iter__(self):
808820
(dynamic_ncols(self.fp) if dynamic_ncols
809821
else ncols),
810822
self.desc, ascii, unit, unit_scale,
811-
1 / avg_time if avg_time else None, bar_format))
823+
1 / avg_time if avg_time else None,
824+
bar_format, custom_symbols))
812825

813826
if self.pos:
814827
self.moveto(-self.pos)
@@ -894,7 +907,7 @@ def update(self, n=1):
894907
else self.ncols),
895908
self.desc, self.ascii, self.unit, self.unit_scale,
896909
1 / self.avg_time if self.avg_time else None,
897-
self.bar_format))
910+
self.bar_format, self.custom_symbols))
898911

899912
if self.pos:
900913
self.moveto(-self.pos)
@@ -961,7 +974,7 @@ def fp_write(s):
961974
(self.dynamic_ncols(self.fp) if self.dynamic_ncols
962975
else self.ncols),
963976
self.desc, self.ascii, self.unit, self.unit_scale, None,
964-
self.bar_format))
977+
self.bar_format, self.custom_symbols))
965978
if pos:
966979
self.moveto(-pos)
967980
else:

0 commit comments

Comments
 (0)