# Workflow generator: renders a Temporal workflow module from a JSON flow definition.
import json
|
|
import os
|
|
import sys
|
|
from jinja2 import Environment, FileSystemLoader
|
|
import networkx as nx
|
|
import re # Import the re module for regular expressions
|
|
import argparse # Import the argparse module for parsing arguments
|
|
|
|
# Directory that holds the Jinja2 workflow templates.
TEMPLATE_DIR = 'templates'

# Jinja environment used to render the workflow template; the "do"
# extension enables {% do %} statements inside the templates.
env = Environment(
    loader=FileSystemLoader(TEMPLATE_DIR),
    extensions=["jinja2.ext.do"],
)
|
|
|
|
def regex_replace(s, pattern, repl):
    """Jinja filter: replace every match of *pattern* in *s* with *repl*."""
    return re.sub(pattern, repl, s)
|
|
|
|
# Expose regex_replace to templates: {{ value | regex_replace(pattern, repl) }}.
env.filters['regex_replace'] = regex_replace
|
|
|
|
# Retrieve required configuration from the environment.
REPO_NAME = os.getenv('REPO_NAME')
BRANCH_NAME = os.getenv('BRANCH_NAME')
COMMIT_ID = os.getenv('VERSION')
NAMESPACE = os.getenv('NAMESPACE')

# Fail fast if any required variable is missing or empty.
# (The original check tested BRANCH_NAME twice and never REPO_NAME, which
# let a missing REPO_NAME slip through and crash later during name building;
# the message also listed BRANCH_NAME twice.)
if not REPO_NAME or not BRANCH_NAME or not COMMIT_ID or not NAMESPACE:
    raise ValueError(
        "Environment variables REPO_NAME, BRANCH_NAME, VERSION, NAMESPACE must be set."
    )

# Shorten the commit ID to its first 10 characters for use in generated names.
COMMIT_ID_SHORT = COMMIT_ID[:10]
|
|
|
|
def sanitize_name(name):
    """Reduce *name* to a safe identifier-like token for task queue names.

    Non-word characters (and the position before a leading digit) become
    underscores, runs of underscores collapse to one, and underscores at
    both ends are stripped.
    """
    # NOTE(review): strip('_') removes *leading* underscores as well, which
    # undoes the underscore inserted before a leading digit — confirm intent.
    safe = re.sub(r'\W|^(?=\d)', '_', name)
    safe = re.sub(r'_+', '_', safe)
    return safe.strip('_')
|
|
|
|
# Compose the flow name from repo and branch, then sanitize both the flow
# name and the shortened commit ID for use in generated identifiers.
FLOW_NAME = REPO_NAME + "_" + BRANCH_NAME
flow_name_safe = sanitize_name(FLOW_NAME)
commit_id_safe = sanitize_name(COMMIT_ID_SHORT)

# The generated workflow class is named after the flow only; the commit-id
# suffixed variant was considered and kept here for reference:
# workflow_class_name = f"{flow_name_safe}_{commit_id_safe}"
workflow_class_name = flow_name_safe
|
|
|
|
def load_flow_definition(file_path):
    """Load and return the flow definition dict from a JSON file.

    Prints an error and exits the process with status 1 when the file
    cannot be read or does not contain valid JSON.
    """
    try:
        with open(file_path, "r") as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError) as e:
        # Originally only JSONDecodeError was caught, so a missing or
        # unreadable file produced a raw traceback instead of this message.
        print(f"Error loading JSON file '{file_path}': {e}")
        sys.exit(1)
|
|
|
|
def determine_flow_type(flow_definition):
    """Classify the flow as "hybrid", "sequential", or "parallel".

    A node with more than one outgoing edge marks the flow as hybrid;
    otherwise the presence of any edge makes it sequential, and an
    edge-less flow is treated as parallel.
    """
    edges = flow_definition["edges"]
    node_ids = [node["id"] for node in flow_definition["nodes"]]

    fan_out = any(
        sum(1 for e in edges if e["source"] == nid) > 1 for nid in node_ids
    )
    if fan_out:
        return "hybrid"
    return "sequential" if edges else "parallel"
|
|
|
|
def collect_root_input_keys(flow_definition):
    """Collect the unique root input keys referenced by any node.

    A node property whose "source" is a string of the form "$root.<path>"
    contributes <path> (the full dotted path after the prefix). The keys
    are returned sorted so the generated code is deterministic across
    runs — the original returned list(set), whose order varies with
    string-hash randomization.
    """
    root_input_keys = set()
    for node in flow_definition.get("nodes", []):  # tolerate missing "nodes"
        properties = (
            node.get("data", {})
            .get("nodeConfig", {})
            .get("schema", {})
            .get("properties", {})
        )
        for value in properties.values():
            source = value.get("source")
            if isinstance(source, str) and source.startswith("$root."):
                root_input_keys.add(source[len("$root."):])
    return sorted(root_input_keys)
|
|
|
|
def build_execution_graph(flow_definition):
    """Build a directed networkx graph mirroring the flow's nodes and edges.

    Each graph node is keyed by the flow node's id and carries the full
    node dict as the "node" attribute; edges follow source -> target.
    """
    graph = nx.DiGraph()

    # Register every flow node (a repeated id overwrites, last one wins —
    # same outcome as the original's dict-based pass).
    for node in flow_definition["nodes"]:
        graph.add_node(node["id"], node=node)

    # Wire up the dependencies.
    for edge in flow_definition["edges"]:
        graph.add_edge(edge["source"], edge["target"])

    return graph
|
|
|
|
def get_execution_steps(G):
    """Group the graph's nodes into ordered batches of parallel-safe steps.

    Each step is a list of node ids with no dependencies on one another;
    steps are ordered so every node's predecessors land in earlier steps.
    Exits with status 1 if the graph is cyclic (no valid ordering exists).
    """
    try:
        # Consume the generator inside the try: NetworkXUnfeasible is
        # raised during iteration, not when the generator is created.
        steps = [list(generation) for generation in nx.topological_generations(G)]
    except nx.NetworkXUnfeasible:
        print("Error: Workflow graph has cycles.")
        sys.exit(1)
    return steps
|
|
|
|
def generate_workflow(flow_definition, output_file):
    """Render the workflow source for *flow_definition* into *output_file*."""
    template = env.get_template('workflow_template.py.j2')

    flow_type = determine_flow_type(flow_definition)
    root_input_keys = collect_root_input_keys(flow_definition)

    # Drop request nodes: they describe the incoming request, not a step.
    filtered_nodes = {
        node["id"]: node
        for node in flow_definition["nodes"]
        if node["type"] != "requestNode"
    }

    # Keep only edges whose endpoints both survived the filter above.
    filtered_edges = [
        edge
        for edge in flow_definition["edges"]
        if edge["source"] in filtered_nodes and edge["target"] in filtered_nodes
    ]

    # Derive the step schedule from the filtered graph.
    graph = build_execution_graph(
        {"nodes": list(filtered_nodes.values()), "edges": filtered_edges}
    )
    execution_steps = get_execution_steps(graph)

    # Render and write the generated workflow module.
    workflow_code = template.render(
        workflow_class_name=workflow_class_name,
        flow_type=flow_type,
        root_input_keys=root_input_keys,
        execution_steps=execution_steps,
        nodes=filtered_nodes,
    )
    with open(output_file, "w") as f:
        f.write(workflow_code)
    print(f"Generated workflow: {output_file}")
|
|
|
|
if __name__ == "__main__":
    # Command-line interface: where to read the flow and write the workflow.
    parser = argparse.ArgumentParser(
        description="Generate Temporal workflow from JSON flow definition."
    )
    parser.add_argument("--input-file", type=str, required=True,
                        help="Path to the flow definition JSON file.")
    parser.add_argument("--output-file", type=str, required=True,
                        help="Path to the generated workflow output file.")
    args = parser.parse_args()

    # Load the definition and emit the generated workflow module.
    generate_workflow(load_flow_definition(args.input_file), args.output_file)
|