"""Temporal worker that builds SQL from a template, executes it via the SQLPad API, and validates input/output against JSON schemas."""
import importlib
|
|
import json
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import re
|
|
import sys
|
|
import requests
|
|
from temporalio import activity
|
|
from temporalio.exceptions import ApplicationError
|
|
from jsonschema import validate, ValidationError
|
|
from temporalio.client import Client
|
|
from temporalio.worker import Worker
|
|
import time
|
|
|
|
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
)
# Module-level logger; inherits the root handler installed by basicConfig above.
logger = logging.getLogger(__name__)

# Automatically determine if in a test environment
# ("unittest" appears in sys.modules once a unittest-based runner has started,
# so no environment flag is needed to switch error types in the validators).
IS_TEST_ENVIRONMENT = "unittest" in sys.modules
|
|
|
|
# Environment variables
REPO_NAME = os.getenv('REPO_NAME')
BRANCH_NAME = os.getenv('BRANCH_NAME')
COMMIT_ID = os.getenv('VERSION')
NAMESPACE = os.getenv('NAMESPACE')
FLOWX_ENGINE_ADDRESS = os.getenv('FLOWX_ENGINE_ADDRESS')
SQLPAD_API_URL = os.getenv('SQLPAD_API_URL')


# Fail fast with a clear message instead of a later TypeError.
# REPO_NAME is included because it is concatenated into BLOCK_NAME below;
# SQLPAD_API_URL is deliberately left optional, matching the original check.
_missing = [var_name for var_name, var_value in (
    ("REPO_NAME", REPO_NAME),
    ("BRANCH_NAME", BRANCH_NAME),
    ("VERSION", COMMIT_ID),
    ("NAMESPACE", NAMESPACE),
    ("FLOWX_ENGINE_ADDRESS", FLOWX_ENGINE_ADDRESS),
) if not var_value]
if _missing:
    raise ValueError(
        f"Missing required environment variables: {', '.join(_missing)}"
    )

# Short commit hash used in the task-queue name.
COMMIT_ID_SHORT = COMMIT_ID[:10]
|
|
|
|
# Sanitize name function
def sanitize_name(name):
    """Return *name* with runs of non-word characters collapsed to single
    underscores and leading/trailing underscores removed.

    NOTE(review): the `^(?=\\d)` branch prefixes a leading digit with an
    underscore, but the final strip('_') removes it again, so names that
    start with a digit keep that digit — confirm this is intended.
    """
    return re.sub(r'_+', '_', re.sub(r'\W|^(?=\d)', '_', name)).strip('_')
|
|
|
|
# Logical block identifier: repository plus branch.
BLOCK_NAME = REPO_NAME + "_" + BRANCH_NAME
# Sanitized components so the queue name contains only word characters.
block_name_safe = sanitize_name(BLOCK_NAME)
commit_id_safe = sanitize_name(COMMIT_ID_SHORT)

# Construct the task queue name
TASK_QUEUE = f"{block_name_safe}_{commit_id_safe}"
|
|
|
|
# Load JSON schema
def load_schema(schema_path):
    """Load and parse a JSON document from *schema_path*.

    Returns the parsed object. On any failure (missing file, malformed
    JSON, ...) the error is logged and re-raised as ApplicationError — or
    ValueError when running under unit tests, so tests do not need the
    Temporal runtime to catch it.
    """
    try:
        # Explicit encoding: schema files are UTF-8 regardless of the
        # platform's default locale encoding.
        with open(schema_path, 'r', encoding='utf-8') as schema_file:
            return json.load(schema_file)
    except Exception as e:
        logger.error("Failed to load schema from %s: %s", schema_path, e)
        # Chain the original exception so the root cause stays in the traceback.
        if not IS_TEST_ENVIRONMENT:
            raise ApplicationError(f"Schema loading failed: {e}") from e
        raise ValueError(f"Schema loading failed: {e}") from e
|
|
|
|
# Validate input against request schema
def validate_input(input_data, schema_path="/app/request_schema.json"):
    """Validate *input_data* against the request JSON schema.

    *schema_path* defaults to the container path but may be overridden
    (e.g. in tests) — backward compatible with the original signature.
    Raises ApplicationError on validation failure, or ValueError when
    running under unit tests.
    """
    request_schema = load_schema(schema_path)
    try:
        validate(instance=input_data, schema=request_schema)
        logger.info("Input data validated successfully")
    except ValidationError as e:
        logger.error("Input validation failed: %s", e)
        # Chain the cause so the offending schema path in the payload
        # remains visible in the traceback.
        if not IS_TEST_ENVIRONMENT:
            raise ApplicationError(f"Input validation error: {e}") from e
        raise ValueError(f"Input validation error: {e}") from e
|
|
|
|
# Validate output against response schema
def validate_output(output_data, schema_path="/app/response_schema.json"):
    """Validate *output_data* against the response JSON schema.

    *schema_path* defaults to the container path but may be overridden
    (e.g. in tests) — backward compatible with the original signature.
    Raises ApplicationError on validation failure, or ValueError when
    running under unit tests.
    """
    response_schema = load_schema(schema_path)
    try:
        validate(instance=output_data, schema=response_schema)
        logger.info("Output data validated successfully")
    except ValidationError as e:
        logger.error("Output validation failed: %s", e)
        # Chain the cause so the failing field stays visible in the traceback.
        if not IS_TEST_ENVIRONMENT:
            raise ApplicationError(f"Output validation error: {e}") from e
        raise ValueError(f"Output validation error: {e}") from e
|
|
|
|
# Get the connection ID from config.json
def get_connection_id(namespace):
    """Return the SQLPad connectionId configured for *namespace*.

    /app/config.json is expected to be a list of objects with "namespace"
    and "connectionId" keys. Raises ValueError when no entry matches.
    """
    config_entries = load_schema("/app/config.json")
    matched = next(
        (entry for entry in config_entries if entry.get("namespace") == namespace),
        None,
    )
    if matched is not None:
        logger.info("Got the connectionID")
        return matched.get("connectionId")
    logger.error("Provided Namespace not found.")
    raise ValueError(f"Namespace '{namespace}' not found")
|
|
|
|
# Read SQL file and replace placeholders
def construct_sql(input_data):
    """Read /app/main.sql and substitute `$key` placeholders with SQL literals.

    Value rendering: None -> NULL, bool -> TRUE/FALSE, str -> quoted literal
    with embedded single quotes doubled, anything else -> str(value).
    Returns the stripped SQL text; raises ApplicationError on any failure.

    NOTE(review): this is textual substitution, not parameter binding — it
    escapes quotes but offers weaker guarantees than a parameterized query.
    """
    try:
        with open("/app/main.sql", "r", encoding="utf-8") as sql_file:
            sql_template = sql_file.read()

        # Substitute longer keys first so e.g. `$user_id` is not partially
        # clobbered by a shorter key `$user`.
        for key in sorted(input_data, key=len, reverse=True):
            value = input_data[key]
            placeholder = f"${key}"

            if value is None:
                replacement = "NULL"
            elif isinstance(value, bool):
                # bool must be checked before other types (bool is an int subclass).
                replacement = "TRUE" if value else "FALSE"
            elif isinstance(value, str):
                # Double embedded single quotes so apostrophes cannot
                # terminate the literal early (e.g. O'Brien -> 'O''Brien').
                replacement = "'" + value.replace("'", "''") + "'"
            else:
                replacement = str(value)

            sql_template = sql_template.replace(placeholder, replacement)

        logger.info("SQL query constructed.")
        return sql_template.strip()
    except Exception as e:
        logger.error("Error processing SQL template: %s", e)
        raise ApplicationError(f"SQL template error: {e}") from e
|
|
|
|
def get_batch_results(batch_id, retry_interval=0.05, max_retries=5):
    """Poll SQLPad until batch *batch_id* reaches a terminal status.

    Returns a tuple (status, statement_id, error, columns, is_select_query)
    taken from the first statement of the batch. Raises ApplicationError on
    HTTP failure, an empty statements list, a statement-level SQL error, a
    SELECT that yields no columns, or when *max_retries* polls are exhausted.
    """
    for _ in range(max_retries):
        try:
            response = requests.get(f"{SQLPAD_API_URL}/api/batches/{batch_id}")
            response.raise_for_status()
            batch_status = response.json()
        except requests.RequestException as e:
            logger.error("Failed to fetch batch results: %s", e)
            raise ApplicationError(f"Failed to fetch batch results: {e}")

        status = batch_status.get("status")
        if status not in ("finished", "error"):
            # Not terminal yet — wait and poll again.
            time.sleep(retry_interval)
            continue

        statements = batch_status.get("statements", [])
        if not statements:
            raise ApplicationError("No statements found in batch response.")

        first_statement = statements[0]
        statement_id = first_statement.get("id")
        error = first_statement.get("error")
        columns = first_statement.get("columns", None)
        sql_text = batch_status.get("batchText", "").strip().lower()
        logger.info(f"statements: {statements}")
        logger.info(f"error from batches result {error}, statement: {statement_id}, columns: {columns}")

        if error:
            raise ApplicationError(f"SQL execution failed: {error}")

        # Heuristic: plain SELECTs or CTEs that contain a SELECT return rows.
        is_select_query = sql_text.startswith("select") or (
            sql_text.startswith("with") and "select" in sql_text
        )
        if is_select_query and not columns:
            raise ApplicationError("SELECT query did not return columns, cannot process data.")

        return status, statement_id, error, columns, is_select_query

    raise ApplicationError("SQLPad batch execution timed out.")
|
|
|
|
def execute_sqlpad_query(connection_id, sql_query):
    """Run *sql_query* on SQLPad connection *connection_id* and return results.

    For SELECT-style queries returns {"results": [row_dict, ...]} with values
    coerced according to SQLPad's reported column datatypes; for any other
    statement returns {"status": ..., "error": ...}. Raises ApplicationError
    on HTTP failure or a malformed SQLPad response.

    NOTE(review): the requests calls have no timeout= and can block
    indefinitely on a hung SQLPad instance — consider adding one.
    """
    payload = {
        "connectionId": connection_id,
        "name": "",
        "batchText": sql_query,
        "selectedText": ""
    }

    try:
        response = requests.post(f"{SQLPAD_API_URL}/api/batches", json=payload)
        response.raise_for_status()
        batch_response = response.json()

        # Guard against both a missing AND an *empty* statements list; the
        # previous `.get("statements", [{}])[0]` raised IndexError on [].
        statements = batch_response.get("statements") or [{}]
        batch_id = statements[0].get("batchId")
        logger.info(f"Batch ID from the batches API response {batch_id}")
        if not batch_id:
            raise ApplicationError("Batch ID not found in SQLPad response.")

        status, statement_id, error, columns, is_select_query = get_batch_results(batch_id)
        if not is_select_query:
            # Non-SELECT statements have no result rows to fetch.
            return {"status": status, "error": error}

        result_response = requests.get(f"{SQLPAD_API_URL}/api/statements/{statement_id}/results")
        result_response.raise_for_status()
        result_data = result_response.json()

        # Map SQLPad datatype names to Python converters; unknown datatypes
        # fall back to str.
        type_mapping = {
            "number": float,
            "string": str,
            "date": str,
            "boolean": bool,
            "timestamp": str,
        }
        column_names_list = [col["name"] for col in columns]
        column_types_list = [col["datatype"] for col in columns]
        converted_data = [
            [
                type_mapping.get(dtype, str)(value) if value is not None else None
                for dtype, value in zip(column_types_list, row)
            ]
            for row in result_data
        ]
        # Pair each converted row with the column names to build row dicts.
        results_dict_list = [dict(zip(column_names_list, row)) for row in converted_data]
        logger.info(f"results_dict_list: {results_dict_list}")

        return {"results": results_dict_list}

    except requests.RequestException as e:
        logger.error("SQLPad API request failed: %s", e)
        raise ApplicationError(f"SQLPad API request failed: {e}") from e
|
|
|
|
@activity.defn
async def block_main_activity(input_data):
    """Temporal activity entry point for this block.

    Validates *input_data*, builds and runs the SQL, validates the result
    against the response schema, and returns it. Failures inside the try
    block are logged and re-raised as ApplicationError (RuntimeError under
    unit tests).
    """
    validate_input(input_data)
    try:
        sql_query = construct_sql(input_data)
        logger.info(f"constructed sql query: {sql_query}")
        connection_id = get_connection_id(NAMESPACE)
        # Guard clause: without a connection id there is nothing to execute.
        if not connection_id:
            logger.error("connection id not exists, please add the connection id according to the namespace.")
            raise ApplicationError("connection id not exists, please add the connection id according to the namespace.")
        logger.info(f"connection id exists {connection_id}")
        result = execute_sqlpad_query(connection_id, sql_query)
        validate_output(result)
        logger.info(f"final result for the query: {result}")
        return result
    except Exception as e:
        logger.error("Error executing query execution: %s", e)
        if not IS_TEST_ENVIRONMENT:
            raise ApplicationError(f"Error during block execution: {e}") from e
        raise RuntimeError("Error during query execution") from e
|
|
|
|
async def main():
    """Connect to the FlowX Temporal engine and run this block's worker.

    Registers block_main_activity on TASK_QUEUE and blocks until the worker
    stops. Startup failures are logged at critical level and re-raised.
    """
    try:
        temporal_client = await Client.connect(FLOWX_ENGINE_ADDRESS, namespace=NAMESPACE)
        activity_worker = Worker(
            temporal_client,
            task_queue=TASK_QUEUE,
            activities=[block_main_activity],
        )
        logger.info("Worker starting, listening to task queue: %s", TASK_QUEUE)
        await activity_worker.run()
    except Exception as e:
        logger.critical("Worker failed to start: %s", e)
        raise


if __name__ == "__main__":
    asyncio.run(main())
|