5555 WorkflowError ,
5656 print_exception_warning ,
5757)
58+ from snakemake .settings .types import PrintDag
5859from snakemake .io import (
5960 _IOFile ,
6061 PeriodicityDetector ,
@@ -2382,7 +2383,10 @@ def rule_dot(self):
23822383 graph = defaultdict (set )
23832384 for job in self .jobs :
23842385 graph [job .rule ].update (dep .rule for dep in self ._dependencies [job ])
2385- return self ._dot (graph )
2386+ if self .workflow .dag_settings .print_dag_as == str (PrintDag .DOT ):
2387+ return self ._dot (graph )
2388+ elif self .workflow .dag_settings .print_dag_as == str (PrintDag .MERMAID_JS ):
2389+ return self ._mermaid_js (graph )
23862390
23872391 def dot (self ):
23882392 def node2style (job ):
@@ -2407,6 +2411,70 @@ def format_wildcard(wildcard):
24072411 dag , node2rule = node2rule , node2style = node2style , node2label = node2label
24082412 )
24092413
2414+ def mermaid_js (self ):
2415+ def node2style (job ):
2416+ if not self .needrun (job ):
2417+ return ",stroke-dasharray: 5 5"
2418+ return ""
2419+
2420+ def format_wildcard (wildcard ):
2421+ name , value = wildcard
2422+ return f"{ name } : { value } "
2423+
2424+ node2label = lambda job : " - " .join (
2425+ chain (
2426+ [job .rule .name ], sorted (map (format_wildcard , self .new_wildcards (job )))
2427+ )
2428+ )
2429+
2430+ dag = {job : self ._dependencies [job ] for job in self .jobs }
2431+
2432+ return self ._mermaid_js (dag , node2style = node2style , node2label = node2label )
2433+
2434+ def _mermaid_js (
2435+ self , graph , node2style = lambda node : "" , node2label = lambda node : node
2436+ ):
2437+ def hsv_to_htmlhexrgb (h , s , v ):
2438+ import colorsys
2439+
2440+ hex_r , hex_g , hex_b = (round (255 * x ) for x in colorsys .hsv_to_rgb (h , s , v ))
2441+ return "#{hex_r:0>2X}{hex_g:0>2X}{hex_b:0>2X}" .format (
2442+ hex_r = hex_r , hex_g = hex_g , hex_b = hex_b
2443+ )
2444+
2445+ # color the rules - sorting by name first gives deterministic output
2446+ rules = sorted (self .rules , key = lambda r : r .name )
2447+ huefactor = 2 / (3 * len (rules ))
2448+ rulecolor = {
2449+ rule .name : hsv_to_htmlhexrgb (i * huefactor , 0.6 , 0.85 )
2450+ for i , rule in enumerate (rules )
2451+ }
2452+ nodes_headers = [
2453+ f"\t id{ index } [{ node2label (node )} ]" for index , node in enumerate (graph )
2454+ ]
2455+ nodes_styles = [
2456+ f"\t style id{ index } fill:{ rulecolor [str (node )]} ,stroke-width:2px,color:#333333{ node2style (node )} "
2457+ for index , node in enumerate (graph )
2458+ ]
2459+ edges = [
2460+ f"\t id{ index } --> id{ index_dep } "
2461+ for index , (_ , deps ) in enumerate (graph .items ())
2462+ for index_dep , _ in enumerate (deps )
2463+ ]
2464+ return (
2465+ textwrap .dedent (
2466+ """\
2467+ ---
2468+ title: DAG
2469+ ---
2470+ flowchart TB
2471+ """
2472+ )
2473+ + "{}\n {}\n {}" .format (
2474+ "\n " .join (nodes_headers ), "\n " .join (nodes_styles ), "\n " .join (edges )
2475+ )
2476+ )
2477+
24102478 def _dot (
24112479 self ,
24122480 graph ,
@@ -3021,7 +3089,10 @@ def norm_rule_relpath(f, rule):
30213089 return files
30223090
30233091 def __str__ (self ):
3024- return self .dot ()
3092+ if self .workflow .dag_settings .print_dag_as == str (PrintDag .DOT ):
3093+ return self .dot ()
3094+ if self .workflow .dag_settings .print_dag_as == str (PrintDag .MERMAID_JS ):
3095+ return self .mermaid_js ()
30253096
30263097 def __len__ (self ):
30273098 return self ._len
0 commit comments