import http.client
import json
import logging
import urllib.error
import urllib.request

import dataiku
from dataiku.llm.python import BaseLLM


# Module-level logger used for configuration errors, protection decisions and
# protection-check failures.
logger = logging.getLogger(__name__)

# Reply returned to the user whenever a Radware protection check does not allow
# the request or the planned tool call.
BLOCKED_MESSAGE = (
    "I cannot complete this request because it was blocked by the configured "
    "agentic AI protection policy."
)

# Reply returned when one or more required project variables are missing.
CONFIGURATION_FAILURE_MESSAGE = (
    "This assistant is not fully configured. Ask a project administrator to set "
    "the required Radware example project variables."
)

# Fallback "UserIdentifier" sent to Radware when the RADWARE_USER_IDENTIFIER
# project variable is not provided.
DEFAULT_USER_IDENTIFIER = "dataiku-user"
# Timeout, in seconds, passed to urlopen for each Radware protection request.
DEFAULT_RADWARE_TIMEOUT_SECONDS = 5

# Function-tool definitions sent as "ToolsInput" with every Radware protection
# check. The tool names mirror the tool_name values the agent passes
# ("llm_completion" for the planning call, "order_lookup" for the tool call).
TOOLS_INPUT = [
    {
        "type": "function",
        "name": "llm_completion",
        "description": "Generate a response or a tool plan with the configured LLM.",
        "parameters": {
            "type": "object",
            "properties": {
                "prompt": {
                    "type": "string",
                    "description": "The user request sent to the LLM.",
                }
            },
            "required": ["prompt"],
        },
    },
    {
        "type": "function",
        "name": "order_lookup",
        "description": "Look up example orders by status.",
        "parameters": {
            "type": "object",
            "properties": {
                # The declared schema only admits an equality filter on the
                # Status column with a string value.
                "filter": {
                    "type": "object",
                    "properties": {
                        "column": {"type": "string", "enum": ["Status"]},
                        "operator": {"type": "string", "enum": ["EQUALS"]},
                        "value": {"type": "string"},
                    },
                    "required": ["column", "operator", "value"],
                }
            },
            "required": ["filter"],
        },
    },
]


class RadwareProtectedOrderAgent(BaseLLM):
    """Example order-lookup agent guarded by Radware agentic AI protection.

    Each request passes two Radware protection checks:

    1. one on the user request, before the configured LLM is asked for a plan;
    2. one on the planned ``order_lookup`` call, before the tool runs.

    Anything short of an explicit "not blocked" decision from Radware is
    treated as a block (fail-close).
    """

    # Declares the "orders" dataset as a dependency of this agent.
    dependencies = [{"type": "DATASET", "ref": "orders"}]

    def __init__(self):
        # Project handle used to resolve the LLM and the agent tool by id.
        self.project = dataiku.api_client().get_default_project()

    def process(self, query, settings, trace):
        """Answer *query*, running Radware checks around the LLM and the tool.

        Args:
            query: Completion query; the latest user message (or the
                ``prompt`` field) is used as the request.
            settings: Completion settings (not used by this agent).
            trace: Trace object, passed through to the planner.

        Returns:
            A dict with a single ``"text"`` key holding the reply.
        """
        user_prompt = self.get_latest_user_message(query)
        variables = dataiku.get_custom_variables() or {}

        radware_url = variables.get("RADWARE_URL")
        radware_api_key = variables.get("RADWARE_API_KEY")
        # `or` rather than a .get() default, so a variable that is defined but
        # left empty still falls back to the default identifier.
        user_identifier = (
            variables.get("RADWARE_USER_IDENTIFIER") or DEFAULT_USER_IDENTIFIER
        )
        llm_id = variables.get("RADWARE_DSS_LLM_ID")
        model_to_use = variables.get("RADWARE_MODEL_NAME")
        order_lookup_tool_id = variables.get("RADWARE_ORDER_LOOKUP_TOOL_ID")

        required_variables = {
            "RADWARE_URL": radware_url,
            "RADWARE_API_KEY": radware_api_key,
            "RADWARE_DSS_LLM_ID": llm_id,
            "RADWARE_MODEL_NAME": model_to_use,
            "RADWARE_ORDER_LOOKUP_TOOL_ID": order_lookup_tool_id,
        }
        missing_variables = [
            name for name, value in required_variables.items() if not value
        ]
        if missing_variables:
            logger.error(
                "Missing required Radware example variables: %s",
                ", ".join(missing_variables),
            )
            return {"text": CONFIGURATION_FAILURE_MESSAGE}

        llm = self.project.get_llm(llm_id)
        order_lookup_tool = self.project.get_agent_tool(order_lookup_tool_id)

        # Keyword arguments shared by both protection checks below.
        radware_context = {
            "radware_url": radware_url,
            "radware_api_key": radware_api_key,
            "user_prompt": user_prompt,
            "user_context": json.dumps(
                {
                    "agent": "radware_protected_order_agent",
                    "projectKey": self.project.project_key,
                }
            ),
            "user_identifier": user_identifier,
            "model_to_use": model_to_use,
        }

        # Protection point 1: check the request before asking the LLM to plan.
        if not self.radware_allows(
            tool_name="llm_completion",
            args_input={"prompt": user_prompt},
            **radware_context,
        ):
            return {"text": BLOCKED_MESSAGE}

        plan = self.call_llm_for_plan(llm, user_prompt, trace)

        if plan.get("tool") != "order_lookup":
            answer = plan.get("answer")
            # The reply must be text: fall back when the planner gave no usable
            # answer (missing, null, or a non-string JSON value).
            if not isinstance(answer, str):
                answer = "I do not know how to answer that."
            return {"text": answer}

        tool_arguments = plan.get("arguments")
        # Only execute calls matching the order_lookup schema declared in
        # TOOLS_INPUT; anything else produced by the planner is rejected here.
        if not self._is_valid_order_lookup_arguments(tool_arguments):
            return {
                "text": "I could not safely determine how to call the order lookup tool."
            }

        # Protection point 2: check the selected tool call before execution.
        if not self.radware_allows(
            tool_name="order_lookup",
            args_input=tool_arguments,
            **radware_context,
        ):
            return {"text": BLOCKED_MESSAGE}

        tool_result = order_lookup_tool.run(tool_arguments)
        return {"text": self.format_order_lookup_result(tool_result)}

    @staticmethod
    def _is_valid_order_lookup_arguments(arguments):
        """Return True if *arguments* match the ``order_lookup`` schema in TOOLS_INPUT."""
        if not isinstance(arguments, dict):
            return False
        lookup_filter = arguments.get("filter")
        if not isinstance(lookup_filter, dict):
            return False
        return (
            lookup_filter.get("column") == "Status"
            and lookup_filter.get("operator") == "EQUALS"
            and isinstance(lookup_filter.get("value"), str)
        )

    def radware_allows(
        self,
        radware_url,
        radware_api_key,
        user_prompt,
        user_context,
        tool_name,
        args_input,
        user_identifier,
        model_to_use,
    ):
        """Ask the Radware endpoint whether a tool call may proceed.

        Args:
            radware_url: Endpoint the JSON payload is POSTed to.
            radware_api_key: Sent as ``ApiKey`` in the payload; redacted from
                logged HTTP error bodies.
            user_prompt: The user's request text (``UserPrompt``).
            user_context: JSON string describing the caller (``UserContext``).
            tool_name: Tool about to be used (``ToolName``).
            args_input: Arguments for that tool (``ArgsInput``).
            user_identifier: Sent as ``UserIdentifier``.
            model_to_use: Sent as ``ModelToUse``.

        Returns:
            True only when the endpoint answers with a JSON object whose
            ``IsBlocked`` field is exactly ``false``. Every other outcome
            returns False: HTTP error, network failure or timeout, an
            undecodable or non-object body, or a missing/null/non-boolean
            ``IsBlocked``.
        """
        payload = {
            "ApiKey": radware_api_key,
            "UserPrompt": user_prompt,
            "UserContext": user_context,
            "ToolName": tool_name,
            "ArgsInput": args_input,
            "ToolsInput": TOOLS_INPUT,
            "UserIdentifier": user_identifier,
            "ModelToUse": model_to_use,
        }

        request = urllib.request.Request(
            radware_url,
            data=json.dumps(payload).encode("utf-8"),
            headers={"Content-Type": "application/json", "Accept": "application/json"},
            method="POST",
        )

        # This example uses fail-close behavior. Change this only after
        # explicitly accepting the risk of running tool calls without a decision.
        try:
            with urllib.request.urlopen(
                request, timeout=DEFAULT_RADWARE_TIMEOUT_SECONDS
            ) as response:
                raw_body = response.read()
        except urllib.error.HTTPError as exc:
            # HTTPError subclasses URLError (and OSError), so handle it first.
            try:
                body = exc.read().decode("utf-8", errors="replace")
            except (OSError, http.client.HTTPException):
                body = ""
            logger.exception(
                "Radware protection check failed with HTTP %s: %s",
                exc.code,
                self.redact_secret(body, radware_api_key)[:500],
            )
            return False
        except (OSError, http.client.HTTPException) as exc:
            # OSError covers URLError, connection resets and socket timeouts
            # raised while reading (socket.timeout is only an alias of
            # TimeoutError from Python 3.10); HTTPException covers e.g.
            # IncompleteRead.
            logger.exception("Radware protection check failed: %s", exc)
            return False

        try:
            result = json.loads(raw_body.decode("utf-8"))
        except ValueError as exc:
            # Covers both json.JSONDecodeError and UnicodeDecodeError.
            logger.exception("Radware protection check failed: %s", exc)
            return False

        if not isinstance(result, dict):
            logger.error(
                "Radware protection check returned a non-object JSON payload: %s",
                type(result).__name__,
            )
            return False

        is_blocked = result.get("IsBlocked", True)
        event_id = result.get("EventId")
        logger.info(
            "Radware decision for %s: blocked=%s event_id=%s",
            tool_name,
            is_blocked,
            event_id,
        )
        return self._is_allowed_decision(result)

    @staticmethod
    def _is_allowed_decision(result):
        """Return True only for a decision object whose ``IsBlocked`` is exactly False."""
        # `is False` (not `not ...`) so that null, 0, "" or a missing field
        # are treated as blocked rather than allowed.
        return isinstance(result, dict) and result.get("IsBlocked", True) is False

    def redact_secret(self, value, secret):
        """Return *value* with every occurrence of *secret* replaced by ``<redacted>``.

        *value* is returned unchanged when either argument is empty.
        """
        if not value or not secret:
            return value
        return value.replace(secret, "<redacted>")

    def call_llm_for_plan(self, llm, user_prompt, trace):
        """Ask *llm* for a JSON plan for *user_prompt*.

        Args:
            llm: LLM handle exposing ``new_completion()``.
            user_prompt: The user's request text.
            trace: Trace object (currently unused).

        Returns:
            Always a dict. It is the parsed JSON object when the model returned
            one (optionally wrapped in a Markdown code fence). It is
            ``{"answer": <raw text>}`` when the text is not a JSON object, and
            ``{}`` when the model returned no text.
        """
        completion = llm.new_completion()
        completion.with_message(
            "You are a planning assistant for an example agent. "
            "Return only JSON. "
            "If the user asks for orders by status, extract the requested status "
            "from the user request, uppercase it, and return JSON in this exact shape: "
            '{"tool": "order_lookup", "arguments": {"filter": {"column": "Status", "operator": "EQUALS", "value": "<REQUESTED_STATUS>"}}}. '
            "For example, if the user asks for shipped orders, use "
            '{"tool": "order_lookup", "arguments": {"filter": {"column": "Status", "operator": "EQUALS", "value": "SHIPPED"}}}. '
            "If the user asks for orders but does not specify a status, ask them to provide a status. "
            "For other requests, return JSON in this exact shape: "
            '{"answer": "brief answer text"}.',
            role="system",
        )
        completion.with_message(user_prompt, role="user")
        response = completion.execute()

        text = response.text
        if not text:
            logger.warning("LLM planner returned an empty response.")
            return {}

        try:
            plan = json.loads(self._strip_code_fence(text))
        except json.JSONDecodeError:
            logger.warning("LLM planner did not return valid JSON: %s", text)
            return {"answer": text}

        if not isinstance(plan, dict):
            logger.warning("LLM planner returned JSON that is not an object: %s", text)
            return {"answer": text}
        return plan

    @staticmethod
    def _strip_code_fence(text):
        """Remove a surrounding Markdown code fence (e.g. ```json ... ```) if present."""
        stripped = text.strip()
        lines = stripped.splitlines()
        if len(lines) >= 2 and lines[0].startswith("```") and lines[-1].strip() == "```":
            return "\n".join(lines[1:-1])
        return stripped

    def format_order_lookup_result(self, tool_result):
        """Render the order lookup tool result as user-facing text.

        The payload is read from ``tool_result["output"]`` when present, else
        from *tool_result* itself. A string payload is returned verbatim. A
        dict payload yields its ``message`` (or a count of the rows),
        followed by one line per row.
        """
        unexpected_result = "The order lookup tool returned an unexpected result."
        if not isinstance(tool_result, dict):
            return unexpected_result

        output = tool_result.get("output", tool_result)
        if isinstance(output, str):
            return output
        if not isinstance(output, dict):
            return unexpected_result

        rows = output.get("rows")
        if rows is None:
            rows = []
        # Every row is read with .get(), so each must be a mapping.
        if not isinstance(rows, list) or not all(isinstance(row, dict) for row in rows):
            return "The order lookup tool returned rows in an unexpected format."

        message = output.get("message") or f"Found {len(rows)} order(s)."
        lines = [str(message)]
        lines.extend(
            "- Order {order_id}: {customer_name}; status={status}; total={total}; region={region}".format(
                order_id=row.get("OrderID", "unknown"),
                customer_name=row.get("CustomerName", "unknown customer"),
                status=row.get("Status", "unknown"),
                total=row.get("Total", "unknown"),
                region=row.get("Region", "unknown"),
            )
            for row in rows
        )
        return "\n".join(lines)

    def get_latest_user_message(self, query):
        """Return the content of the most recent ``user`` message in *query*.

        Falls back to ``query["prompt"]`` when there is no user message, and
        returns ``""`` when neither yields text.
        """
        messages = query.get("messages") or []
        for message in reversed(messages):
            if isinstance(message, dict) and message.get("role") == "user":
                return message.get("content") or ""
        return query.get("prompt") or ""
