import importlib
import json
import asyncio
import logging
import os
import re
import sys
import requests
from temporalio import activity
from temporalio.exceptions import ApplicationError
from jsonschema import validate, ValidationError
from temporalio.client import Client
from temporalio.worker import Worker
import time

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Automatically determine if in a test environment. When `unittest` has been
# imported, errors are raised as plain ValueError/RuntimeError instead of
# Temporal ApplicationError.
IS_TEST_ENVIRONMENT = "unittest" in sys.modules

# Environment variables
REPO_NAME = os.getenv('REPO_NAME')
BRANCH_NAME = os.getenv('BRANCH_NAME')
COMMIT_ID = os.getenv('VERSION')
NAMESPACE = os.getenv('NAMESPACE')
FLOWX_ENGINE_ADDRESS = os.getenv('FLOWX_ENGINE_ADDRESS')
SQLPAD_API_URL = os.getenv('SQLPAD_API_URL')

# REPO_NAME is concatenated into BLOCK_NAME below, so it is required too.
# Without this check a missing REPO_NAME fails later with an opaque TypeError.
_missing_env = [
    env_name
    for env_name, env_value in (
        ("REPO_NAME", REPO_NAME),
        ("BRANCH_NAME", BRANCH_NAME),
        ("VERSION", COMMIT_ID),
        ("NAMESPACE", NAMESPACE),
        ("FLOWX_ENGINE_ADDRESS", FLOWX_ENGINE_ADDRESS),
    )
    if not env_value
]
if _missing_env:
    logger.error("Missing required environment variables: %s", ", ".join(_missing_env))
    raise ValueError("Missing required environment variables.")

if not SQLPAD_API_URL:
    # Not fatal at startup (unchanged behavior), but every SQLPad call will fail.
    logger.warning("SQLPAD_API_URL is not set; SQLPad API calls will fail.")

# Per-request HTTP timeout in seconds for SQLPad calls. Without a timeout,
# `requests` waits forever on an unresponsive server.
SQLPAD_HTTP_TIMEOUT = float(os.getenv('SQLPAD_HTTP_TIMEOUT', '30'))
# Batch polling defaults. These equal the previously hard-coded values.
# NOTE(review): 5 polls x 0.05 s gives ~0.25 s total before "timed out".
# That may be too short for slow queries; tune via env if needed.
SQLPAD_POLL_INTERVAL = float(os.getenv('SQLPAD_POLL_INTERVAL', '0.05'))
SQLPAD_MAX_POLLS = int(os.getenv('SQLPAD_MAX_POLLS', '5'))

COMMIT_ID_SHORT = COMMIT_ID[:10]


# Sanitize name function
def sanitize_name(name):
    """Turn *name* into an identifier-safe string.

    Every non-word character, and a leading digit position, becomes "_".
    Runs of underscores are collapsed, and leading/trailing underscores
    are stripped.
    """
    sanitized = re.sub(r'\W|^(?=\d)', '_', name)
    sanitized = re.sub(r'_+', '_', sanitized)
    return sanitized.strip('_')


BLOCK_NAME = REPO_NAME + "_" + BRANCH_NAME
block_name_safe = sanitize_name(BLOCK_NAME)
commit_id_safe = sanitize_name(COMMIT_ID_SHORT)

# Construct the task queue name
TASK_QUEUE = f"{block_name_safe}_{commit_id_safe}"


# Load JSON schema
def load_schema(schema_path):
    """Load and return the parsed JSON document at *schema_path*.

    Raises:
        ApplicationError: on any read/parse failure (production).
        ValueError: on any read/parse failure (test environment).
    """
    try:
        with open(schema_path, 'r', encoding='utf-8') as schema_file:
            return json.load(schema_file)
    except Exception as e:
        logger.error("Failed to load schema from %s: %s", schema_path, e)
        if not IS_TEST_ENVIRONMENT:
            raise ApplicationError(f"Schema loading failed: {e}") from e
        raise ValueError(f"Schema loading failed: {e}") from e


def _validate_against_schema(data, schema_path, label, non_retryable=False):
    """Validate *data* against the JSON schema at *schema_path*.

    *label* ("Input"/"Output") prefixes every log and error message.
    *non_retryable* is forwarded to ApplicationError in production.
    """
    schema = load_schema(schema_path)
    try:
        validate(instance=data, schema=schema)
    except ValidationError as e:
        logger.error("%s validation failed: %s", label, e)
        message = f"{label} validation error: {e}"
        if not IS_TEST_ENVIRONMENT:
            raise ApplicationError(message, non_retryable=non_retryable) from e
        raise ValueError(message) from e
    logger.info("%s data validated successfully", label)


# Validate input against request schema
def validate_input(input_data):
    """Validate *input_data* against /app/request_schema.json.

    Failure is marked non-retryable: the same input will fail validation again,
    so letting Temporal retry it only burns attempts.
    """
    _validate_against_schema(
        input_data, "/app/request_schema.json", "Input", non_retryable=True
    )


# Validate output against response schema
def validate_output(output_data):
    """Validate *output_data* against /app/response_schema.json."""
    _validate_against_schema(output_data, "/app/response_schema.json", "Output")


# Get the connection ID from config.json
def get_connection_id(namespace):
    """Return the "connectionId" of the /app/config.json entry for *namespace*.

    Assumes config.json is a list of objects with "namespace" and
    "connectionId" keys (TODO confirm). May return None when the matching
    entry has no "connectionId".

    Raises:
        ValueError: when no entry matches *namespace*.
    """
    config_entries = load_schema("/app/config.json")
    for entry in config_entries:
        if entry.get("namespace") == namespace:
            logger.info("Got the connectionID")
            return entry.get("connectionId")
    logger.error("Provided Namespace not found.")
    raise ValueError(f"Namespace '{namespace}' not found")


def _sql_literal(value):
    """Render a Python value as SQL text for placeholder substitution."""
    if value is None:
        return "NULL"
    # bool must be checked before the generic fallback: bool is a subclass of int.
    if isinstance(value, bool):
        return "TRUE" if value else "FALSE"
    if isinstance(value, str):
        # Double embedded single quotes (standard SQL escaping). Otherwise a
        # value containing ' ends the literal early and can inject SQL.
        return "'" + value.replace("'", "''") + "'"
    return str(value)


# Read SQL file and replace placeholders
def construct_sql(input_data):
    """Read /app/main.sql and substitute each "$<key>" with the SQL literal of input_data[key].

    Substitution is a single regex pass with the longest placeholders first.
    That way "$id" never clobbers the prefix of "$id_2", and a substituted
    value that itself contains "$something" is not substituted again.

    Raises:
        ApplicationError: on any failure reading or processing the template.
    """
    try:
        with open("/app/main.sql", "r", encoding="utf-8") as sql_file:
            sql_template = sql_file.read()
        replacements = {
            f"${key}": _sql_literal(value) for key, value in input_data.items()
        }
        # Guard: an empty alternation would match at every position.
        if replacements:
            pattern = re.compile(
                "|".join(
                    re.escape(placeholder)
                    for placeholder in sorted(replacements, key=len, reverse=True)
                )
            )
            sql_template = pattern.sub(
                lambda match: replacements[match.group(0)], sql_template
            )
        logger.info("SQL query constructed.")
        return sql_template.strip()
    except Exception as e:
        logger.error("Error processing SQL template: %s", e)
        raise ApplicationError(f"SQL template error: {e}") from e


def _parse_completed_batch(batch_status, status):
    """Extract (status, statement_id, error, columns, is_select_query) from a completed batch.

    Only the first statement of the batch is inspected.
    """
    statements = batch_status.get("statements", [])
    if not statements:
        raise ApplicationError("No statements found in batch response.")
    statement = statements[0]
    statement_id = statement.get("id")
    error = statement.get("error")
    columns = statement.get("columns", None)
    sql_text = batch_status.get("batchText", "").strip().lower()
    logger.info("statements: %s", statements)
    logger.info(
        "error from batches result %s, statement: %s, columns: %s",
        error, statement_id, columns,
    )
    if error:
        raise ApplicationError(f"SQL execution failed: {error}")
    is_select_query = sql_text.startswith("select") or (
        sql_text.startswith("with") and "select" in sql_text
    )
    if is_select_query and not columns:
        raise ApplicationError("SELECT query did not return columns, cannot process data.")
    return status, statement_id, error, columns, is_select_query


def get_batch_results(batch_id, retry_interval=SQLPAD_POLL_INTERVAL, max_retries=SQLPAD_MAX_POLLS):
    """Poll SQLPad until batch *batch_id* is "finished" or "error".

    Sleeps *retry_interval* seconds between polls. This is blocking, so it
    must run off the event loop.

    Returns:
        (status, statement_id, error, columns, is_select_query)

    Raises:
        ApplicationError: on an HTTP failure, an empty batch, a statement
            error, a SELECT without columns, or after *max_retries* polls
            without completion.
    """
    for _ in range(max_retries):
        try:
            response = requests.get(
                f"{SQLPAD_API_URL}/api/batches/{batch_id}",
                timeout=SQLPAD_HTTP_TIMEOUT,
            )
            response.raise_for_status()
            batch_status = response.json()
        except requests.RequestException as e:
            logger.error("Failed to fetch batch results: %s", e)
            raise ApplicationError(f"Failed to fetch batch results: {e}") from e

        status = batch_status.get("status")
        if status in ("finished", "error"):
            return _parse_completed_batch(batch_status, status)
        time.sleep(retry_interval)
    raise ApplicationError("SQLPad batch execution timed out.")


def _to_bool(value):
    """Convert a result cell to bool.

    String cells are parsed rather than passed to bool(), because
    bool("false") is True.
    """
    if isinstance(value, str):
        return value.strip().lower() in ("true", "t", "1", "yes", "y")
    return bool(value)


# Maps SQLPad column "datatype" to a cell converter; unknown types fall back to str.
_TYPE_CONVERTERS = {
    "number": float,
    "string": str,
    "date": str,
    "boolean": _to_bool,
    "timestamp": str,
}


def execute_sqlpad_query(connection_id, sql_query):
    """Run *sql_query* on SQLPad connection *connection_id*.

    All HTTP calls are blocking.

    Returns:
        {"status": ..., "error": ...} for non-SELECT statements, or
        {"results": [ {column_name: value, ...}, ... ]} for SELECT statements.
        Each value is converted according to the column's "datatype";
        None stays None.

    Raises:
        ApplicationError: on HTTP failures or a missing batch id, plus
            anything raised by get_batch_results.
    """
    payload = {
        "connectionId": connection_id,
        "name": "",
        "batchText": sql_query,
        "selectedText": ""
    }
    try:
        response = requests.post(
            f"{SQLPAD_API_URL}/api/batches", json=payload, timeout=SQLPAD_HTTP_TIMEOUT
        )
        response.raise_for_status()
        batch_response = response.json()
        # `or [{}]` also covers an explicitly empty "statements" list, which
        # previously raised IndexError instead of the intended error below.
        first_statement = (batch_response.get("statements") or [{}])[0]
        batch_id = first_statement.get("batchId")
        logger.info("Batch ID from the batches API response %s", batch_id)
        if not batch_id:
            raise ApplicationError("Batch ID not found in SQLPad response.")

        status, statement_id, error, columns, is_select_query = get_batch_results(batch_id)
        if not is_select_query:
            return {"status": status, "error": error}

        result_response = requests.get(
            f"{SQLPAD_API_URL}/api/statements/{statement_id}/results",
            timeout=SQLPAD_HTTP_TIMEOUT,
        )
        result_response.raise_for_status()
        result_data = result_response.json()

        column_names_list = [col["name"] for col in columns]
        converters = [_TYPE_CONVERTERS.get(col.get("datatype"), str) for col in columns]
        # Assumes each row is a positional list aligned with `columns`
        # (TODO confirm); zip truncates to the shorter of the two.
        results_dict_list = [
            {
                name: (convert(value) if value is not None else None)
                for name, convert, value in zip(column_names_list, converters, row)
            }
            for row in result_data
        ]
        logger.info("results_dict_list: %s", results_dict_list)
        return {"results": results_dict_list}
    except requests.RequestException as e:
        logger.error("SQLPad API request failed: %s", e)
        raise ApplicationError(f"SQLPad API request failed: {e}") from e


def _run_block(input_data):
    """Synchronous body of the activity.

    Builds the SQL, resolves the connection, executes the query and validates
    the output. It performs blocking file and HTTP I/O, so it is run in an
    executor thread.
    """
    sql_query = construct_sql(input_data)
    logger.info("constructed sql query: %s", sql_query)
    connection_id = get_connection_id(NAMESPACE)
    if not connection_id:
        logger.error("connection id not exists, please add the connection id according to the namespace.")
        raise ApplicationError("connection id not exists, please add the connection id according to the namespace.")
    logger.info("connection id exists %s", connection_id)
    result = execute_sqlpad_query(connection_id, sql_query)
    validate_output(result)
    logger.info("final result for the query: %s", result)
    return result


@activity.defn
async def block_main_activity(input_data):
    """Temporal activity: validate *input_data*, run the templated SQL on SQLPad, and return the result.

    Input-validation errors propagate unchanged. Any later failure is wrapped:
    as ApplicationError in production, or RuntimeError under unittest.

    The blocking work (requests + time.sleep) is moved to the default thread
    pool. An async activity runs on the worker's event loop, and blocking it
    would stall the whole worker.
    """
    validate_input(input_data)
    loop = asyncio.get_running_loop()
    try:
        return await loop.run_in_executor(None, _run_block, input_data)
    except Exception as e:
        logger.error("Error executing query execution: %s", e)
        if not IS_TEST_ENVIRONMENT:
            raise ApplicationError(f"Error during block execution: {e}") from e
        raise RuntimeError("Error during query execution") from e


async def main():
    """Connect to the Temporal server and run a worker for TASK_QUEUE until it stops."""
    try:
        client = await Client.connect(FLOWX_ENGINE_ADDRESS, namespace=NAMESPACE)
        worker = Worker(
            client,
            task_queue=TASK_QUEUE,
            activities=[block_main_activity],
        )
        logger.info("Worker starting, listening to task queue: %s", TASK_QUEUE)
        await worker.run()
    except Exception as e:
        logger.critical("Worker failed to start: %s", e)
        raise


if __name__ == "__main__":
    asyncio.run(main())