# system/block_wrapper.py

import asyncio
import json
import logging
import os
import re
import sys
import time

import requests
from jsonschema import validate, ValidationError
from temporalio import activity
from temporalio.client import Client
from temporalio.exceptions import ApplicationError
from temporalio.worker import Worker

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Automatically determine if in a test environment
IS_TEST_ENVIRONMENT = "unittest" in sys.modules

# Environment variables
REPO_NAME = os.getenv('REPO_NAME')
BRANCH_NAME = os.getenv('BRANCH_NAME')
COMMIT_ID = os.getenv('VERSION')
NAMESPACE = os.getenv('NAMESPACE')
FLOWX_ENGINE_ADDRESS = os.getenv('FLOWX_ENGINE_ADDRESS')
SQLPAD_API_URL = os.getenv('SQLPAD_API_URL')

# REPO_NAME and SQLPAD_API_URL are used below as well, so they must also be set
if not all([REPO_NAME, BRANCH_NAME, COMMIT_ID, NAMESPACE, FLOWX_ENGINE_ADDRESS, SQLPAD_API_URL]):
    raise ValueError("Missing required environment variables.")
COMMIT_ID_SHORT = COMMIT_ID[:10]

# Sanitize a name into a safe identifier: non-word characters and a leading
# digit are replaced with underscores, and runs of underscores are collapsed
def sanitize_name(name):
    sanitized = re.sub(r'\W|^(?=\d)', '_', name)
    sanitized = re.sub(r'_+', '_', sanitized)
    return sanitized.strip('_')

BLOCK_NAME = REPO_NAME + "_" + BRANCH_NAME
block_name_safe = sanitize_name(BLOCK_NAME)
commit_id_safe = sanitize_name(COMMIT_ID_SHORT)

# Construct the task queue name
TASK_QUEUE = f"{block_name_safe}_{commit_id_safe}"
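
# Example (hypothetical values): REPO_NAME="sql-block", BRANCH_NAME="main",
# VERSION="abc123def4567890" gives TASK_QUEUE "sql_block_main_abc123def4"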

# Load JSON schema
def load_schema(schema_path):
    try:
        with open(schema_path, 'r') as schema_file:
            return json.load(schema_file)
    except Exception as e:
        logger.error("Failed to load schema from %s: %s", schema_path, e)
        if not IS_TEST_ENVIRONMENT:
            raise ApplicationError(f"Schema loading failed: {e}")
        else:
            raise ValueError(f"Schema loading failed: {e}")

# Validate input against request schema
def validate_input(input_data):
    request_schema = load_schema("/app/request_schema.json")
    try:
        validate(instance=input_data, schema=request_schema)
        logger.info("Input data validated successfully")
    except ValidationError as e:
        logger.error("Input validation failed: %s", e)
        if not IS_TEST_ENVIRONMENT:
            raise ApplicationError(f"Input validation error: {e}")
        else:
            raise ValueError(f"Input validation error: {e}")

# Validate output against response schema
def validate_output(output_data):
    response_schema = load_schema("/app/response_schema.json")
    try:
        validate(instance=output_data, schema=response_schema)
        logger.info("Output data validated successfully")
    except ValidationError as e:
        logger.error("Output validation failed: %s", e)
        if not IS_TEST_ENVIRONMENT:
            raise ApplicationError(f"Output validation error: {e}")
        else:
            raise ValueError(f"Output validation error: {e}")

# Get the connection ID for a namespace from config.json
def get_connection_id(namespace):
    config = load_schema("/app/config.json")
    for item in config:
        if item.get("namespace") == namespace:
            logger.info("Found connection ID for namespace '%s'", namespace)
            return item.get("connectionId")
    logger.error("Namespace '%s' not found in config.json", namespace)
    raise ValueError(f"Namespace '{namespace}' not found")
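
# config.json is assumed to be a list of namespace-to-connection mappings, e.g.
# [{"namespace": "default", "connectionId": "abc123"}, ...]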

# Read the SQL template and replace $key placeholders with SQL literals
def construct_sql(input_data):
    try:
        with open("/app/main.sql", "r") as sql_file:
            sql_template = sql_file.read()
        for key, value in input_data.items():
            placeholder = f"${key}"
            if value is None:
                replacement = "NULL"
            elif isinstance(value, bool):
                replacement = "TRUE" if value else "FALSE"
            elif isinstance(value, str):
                # Escape embedded single quotes so the literal stays valid SQL
                replacement = "'{}'".format(value.replace("'", "''"))
            else:
                replacement = str(value)
            sql_template = sql_template.replace(placeholder, replacement)
        logger.info("SQL query constructed.")
        return sql_template.strip()
    except Exception as e:
        logger.error("Error processing SQL template: %s", e)
        raise ApplicationError(f"SQL template error: {e}")
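
# Example: with main.sql containing "SELECT * FROM users WHERE id = $id AND active = $active",
# input {"id": 7, "active": True} yields "SELECT * FROM users WHERE id = 7 AND active = TRUE"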

# Poll the SQLPad batch endpoint until the batch finishes (or errors), then
# extract the first statement's id, error, and columns
def get_batch_results(batch_id, retry_interval=0.05, max_retries=5):
    retries = 0
    while retries < max_retries:
        try:
            response = requests.get(f"{SQLPAD_API_URL}/api/batches/{batch_id}")
            response.raise_for_status()
            batch_status = response.json()
            status = batch_status.get("status")
            if status in ["finished", "error"]:
                statements = batch_status.get("statements", [])
                if not statements:
                    raise ApplicationError("No statements found in batch response.")
                statement = statements[0]
                statement_id = statement.get("id")
                error = statement.get("error")
                columns = statement.get("columns")
                sql_text = batch_status.get("batchText", "").strip().lower()
                logger.info("statements: %s", statements)
                logger.info("error from batch result: %s, statement: %s, columns: %s",
                            error, statement_id, columns)
                if error:
                    raise ApplicationError(f"SQL execution failed: {error}")
                # Treat plain SELECTs and WITH ... SELECT CTEs as row-returning queries
                is_select_query = sql_text.startswith("select") or (
                    sql_text.startswith("with") and "select" in sql_text
                )
                if is_select_query and not columns:
                    raise ApplicationError("SELECT query did not return columns, cannot process data.")
                return status, statement_id, error, columns, is_select_query
            time.sleep(retry_interval)
            retries += 1
        except requests.RequestException as e:
            logger.error("Failed to fetch batch results: %s", e)
            raise ApplicationError(f"Failed to fetch batch results: {e}")
    raise ApplicationError("SQLPad batch execution timed out.")
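
# Note: the defaults above sleep for at most ~0.25 s in total (5 x 0.05 s); slower
# queries may need e.g. get_batch_results(batch_id, retry_interval=1.0, max_retries=30)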

# Submit a batch to SQLPad, wait for it, and return rows (for SELECTs) or status
def execute_sqlpad_query(connection_id, sql_query):
    payload = {
        "connectionId": connection_id,
        "name": "",
        "batchText": sql_query,
        "selectedText": ""
    }
    try:
        response = requests.post(f"{SQLPAD_API_URL}/api/batches", json=payload)
        response.raise_for_status()
        batch_response = response.json()
        # Guard against a missing or empty statements list in the response
        batch_id = (batch_response.get("statements") or [{}])[0].get("batchId")
        logger.info("Batch ID from the batches API response: %s", batch_id)
        if not batch_id:
            raise ApplicationError("Batch ID not found in SQLPad response.")
        status, statement_id, error, columns, is_select_query = get_batch_results(batch_id)
        if not is_select_query:
            return {"status": status, "error": error}
        result_response = requests.get(f"{SQLPAD_API_URL}/api/statements/{statement_id}/results")
        result_response.raise_for_status()
        result_data = result_response.json()
        # Map SQLPad column datatypes onto Python converters
        type_mapping = {
            "number": float,
            "string": str,
            "date": str,
            "boolean": bool,
            "timestamp": str,
        }
        column_names_list = [col["name"] for col in columns]
        column_types_list = [col["datatype"] for col in columns]
        converted_data = [
            [
                type_mapping.get(dtype, str)(value) if value is not None else None
                for dtype, value in zip(column_types_list, row)
            ]
            for row in result_data
        ]
        results_dict_list = [dict(zip(column_names_list, row)) for row in converted_data]
        logger.info("results_dict_list: %s", results_dict_list)
        return {"results": results_dict_list}
    except requests.RequestException as e:
        logger.error("SQLPad API request failed: %s", e)
        raise ApplicationError(f"SQLPad API request failed: {e}")
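
# Illustrative shapes, based on what this code expects from SQLPad: columns like
# [{"name": "id", "datatype": "number"}, {"name": "email", "datatype": "string"}]
# and rows like [[1, "a@b.c"]] become {"results": [{"id": 1.0, "email": "a@b.c"}]}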

@activity.defn
async def block_main_activity(input_data):
    validate_input(input_data)
    try:
        sql_query = construct_sql(input_data)
        logger.info("Constructed SQL query: %s", sql_query)
        connection_id = get_connection_id(NAMESPACE)
        if connection_id:
            logger.info("Connection ID exists: %s", connection_id)
            result = execute_sqlpad_query(connection_id, sql_query)
            validate_output(result)
            logger.info("Final result for the query: %s", result)
            return result
        else:
            msg = "Connection ID not found; add one for this namespace to config.json."
            logger.error(msg)
            raise ApplicationError(msg)
    except Exception as e:
        logger.error("Error during query execution: %s", e)
        if not IS_TEST_ENVIRONMENT:
            raise ApplicationError(f"Error during block execution: {e}") from e
        else:
            raise RuntimeError("Error during query execution") from e

async def main():
    try:
        client = await Client.connect(FLOWX_ENGINE_ADDRESS, namespace=NAMESPACE)
        worker = Worker(
            client,
            task_queue=TASK_QUEUE,
            activities=[block_main_activity],
        )
        logger.info("Worker starting, listening to task queue: %s", TASK_QUEUE)
        await worker.run()
    except Exception as e:
        logger.critical("Worker failed to start: %s", e)
        raise

if __name__ == "__main__":
    asyncio.run(main())
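
# Example local run (hypothetical values; all env vars above are required):
#   REPO_NAME=sql-block BRANCH_NAME=main VERSION=abc123def4567890 \
#   NAMESPACE=default FLOWX_ENGINE_ADDRESS=localhost:7233 \
#   SQLPAD_API_URL=http://localhost:3010 python system/block_wrapper.py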