import asyncio
import json
import logging
import os
import tempfile

from fastapi import FastAPI, UploadFile, File, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from live_mode_example import clean_live_text, client  # your existing module


# Module-level logger; the handlers in this file report failures through it.
logger = logging.getLogger("vrs-app")

# Application object that every route in this file is registered on.
app = FastAPI()

# Allow browser JS to call the API.
# NOTE(review): every origin, method and header is allowed ("*") — confirm
# this is intended outside of local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve static files (front-end) under the /static URL prefix.
# NOTE(review): "static" is a relative path, so it presumably resolves against
# the process working directory — confirm how the app is launched.
app.mount("/static", StaticFiles(directory="static"), name="static")


@app.get("/")
async def root():
    """Return the front-end entry page, ``static/index.html``."""
    index_page = FileResponse("static/index.html")
    return index_page


def _transcribe_audio_file(path: str):
    """Send the audio file at *path* to OpenAI and return the raw result.

    Deliberately synchronous: ``transcribe`` runs it in a worker thread so the
    blocking HTTP request does not stall the event loop.
    """
    with open(path, "rb") as audio:
        return client.audio.transcriptions.create(
            model="gpt-4o-mini-transcribe",
            file=audio,
            response_format="text",  # plain string
            language="en",           # force English
        )


@app.post("/api/transcribe")
async def transcribe(
    mode: str = Query("smart", regex="^(direct|smart)$"),
    specialty: str = Query("radiology"),  # Optional, defaults to radiology
    file: UploadFile = File(...),
):
    """
    Transcribe an uploaded audio recording.

    mode = "direct"  -> return raw transcript only
    mode = "smart"   -> try to return cleaned transcript using clean_live_text()
                        (falls back to raw if cleaner fails)

    Returns a dict with the keys ``mode``, ``raw``, ``cleaned`` and
    ``clean_error``; ``cleaned`` and ``clean_error`` are ``None`` unless the
    smart-mode cleaner ran (respectively: failed).

    Raises:
        HTTPException(400): the upload is empty, or no text was recognized.
        HTTPException(500): the transcription call failed.
    """

    # Read the upload before creating the temp file so a failed read cannot
    # leave a stray file behind.
    data = await file.read()
    if not data:
        raise HTTPException(
            status_code=400,
            detail="Uploaded audio file is empty."
        )

    tmp_path = None
    try:
        # 1) Save uploaded file to temp. The path is recorded before writing
        #    so the finally-block can remove the file even if the write fails.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".webm") as tmp:
            tmp_path = tmp.name
            tmp.write(data)

        # 2) Transcription with OpenAI (common for both modes), off the
        #    event loop because the client call is blocking.
        try:
            transcription = await asyncio.to_thread(
                _transcribe_audio_file, tmp_path
            )
        except Exception as e:
            logger.exception("OpenAI transcription error: %s", e)
            raise HTTPException(
                status_code=500,
                detail=f"Transcription error: {e}"
            ) from e

        # response_format="text" is requested, but tolerate an object with a
        # ``.text`` attribute (possibly None) as well.
        if isinstance(transcription, str):
            raw_text = transcription
        else:
            raw_text = getattr(transcription, "text", "") or ""

        if not raw_text.strip():
            raise HTTPException(
                status_code=400,
                detail="No text recognized from audio."
            )

        cleaned = None
        clean_error = None

        # 3) Optional cleaning for SMART mode
        if mode == "smart":
            try:
                cleaned = await asyncio.to_thread(
                    clean_live_text,
                    raw_transcript=raw_text,
                    specialty=specialty,
                    language="en",
                )
            except Exception as e:
                clean_error = str(e)
                cleaned = None
                logger.exception("clean_live_text failed: %s", e)
                # do NOT raise; we still return raw text

        # 4) Respond
        return {
            "mode": mode,
            "raw": raw_text,
            "cleaned": cleaned,
            "clean_error": clean_error,
        }

    finally:
        # Best-effort cleanup: a failure here must not mask the response,
        # but it should at least be visible in the logs.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                logger.warning(
                    "Could not remove temp file %s", tmp_path, exc_info=True
                )


# ============================================
# NEW ENDPOINTS: Review, Modify, Generate Report
# ============================================

# System prompt for AI tasks.
# call_ai_task() sends this text as the "system" message of every chat
# completion. The accompanying "user" message is a JSON object whose "task"
# field ("review", "modify" or "smart_report") selects one of the roles
# described below.
AI_SYSTEM_PROMPT = """
You are an AI assistant specialized in healthcare documentation for a Voice Reporting System (VRS).

Your role depends on the task:

1. REVIEW (task="review"):
   - Clean and polish the provided medical text
   - Fix grammar, punctuation, and capitalization
   - Ensure proper medical terminology and formatting
   - Preserve ALL medical meaning, findings, diagnoses, and numbers
   - Do NOT add or remove medical information
   - Maintain the original structure and flow

2. MODIFY (task="modify"):
   - Apply the user's natural language instruction to modify the existing text
   - The instruction will describe what change to make (e.g., "add that the patient has diabetes",
     "remove the mention of chest pain", "change the date to January 15th")
   - Make the requested change while preserving the rest of the text
   - Ensure the modified text remains grammatically correct and medically accurate

3. SMART_REPORT (task="smart_report"):
   - Convert the transcription into a structured, professional medical report
   - Format according to the specified specialty's conventions
   - Include appropriate sections (e.g., History, Findings, Impression for radiology)
   - Use proper medical terminology and formatting
   - If a signature is provided, append it at the end
   - Ensure the report is ready for clinical use

Always respect medical accuracy and never invent or remove medical information.
"""


# Pydantic models for request bodies
class ReviewRequest(BaseModel):
    # JSON body accepted by POST /api/review.

    # Text to review; forwarded to call_ai_task as ``input_text``.
    text: str
    # Specialty label forwarded to call_ai_task unchanged.
    specialty: str
    # Required by the schema but not read by the review handler in this file.
    # NOTE(review): presumably tells the front end which field to update — confirm.
    target: str


class ModifyRequest(BaseModel):
    # JSON body accepted by POST /api/modify.

    # Existing text to change; forwarded to call_ai_task as ``input_text``.
    text: str
    # Natural-language description of the change to apply.
    instruction: str
    # Specialty label forwarded to call_ai_task unchanged.
    specialty: str
    # Required by the schema but not read by the modify handler in this file.
    # NOTE(review): presumably tells the front end which field to update — confirm.
    target: str


class ReportRequest(BaseModel):
    # JSON body accepted by POST /api/report.

    # Transcription to turn into a report; forwarded as ``input_text``.
    text: str
    # Specialty label forwarded to call_ai_task unchanged.
    specialty: str
    # Optional sign-off; call_ai_task adds it to the report when it is
    # non-blank.
    signature: str | None = None


def call_ai_task(
    task: str,
    specialty: str,
    input_text: str,
    instruction: str = None,
    signature: str = None,
) -> str:
    """Call OpenAI with the appropriate task and return the output text."""
    try:
        # Build user message based on task
        if task == "review":
            user_content = {
                "task": "review",
                "specialty": specialty,
                "input_text": input_text,
            }

        elif task == "modify":
            user_content = {
                "task": "modify",
                "specialty": specialty,
                "input_text": input_text,
                "instruction": instruction,
            }

        elif task == "smart_report":
            user_content = {
                "task": "smart_report",
                "specialty": specialty,
                "input_text": input_text,
                "signature": signature,
            }

        else:
            return input_text  # Unknown task, return original

        # Call OpenAI
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": AI_SYSTEM_PROMPT},
                {"role": "user", "content": json.dumps(user_content)},
            ],
            temperature=0.3,
            max_tokens=2000,
        )

        output_text = response.choices[0].message.content.strip()

        # For smart_report, append signature if provided
        if task == "smart_report" and signature and signature.strip():
            output_text += f"\n\n{signature.strip()}"

        return output_text

    except Exception as e:
        logger.exception("AI task error (%s): %s", task, e)
        # Return original text on error so UI still works
        return input_text


@app.post("/api/review")
async def review(request: ReviewRequest):
    """Review and clean medical text.

    Returns ``{"output_text": ...}`` holding the reviewed text, or the
    original text if anything goes wrong.
    """
    try:
        # call_ai_task performs a blocking OpenAI request; run it in a worker
        # thread so the event loop keeps serving other requests meanwhile.
        output_text = await asyncio.to_thread(
            call_ai_task,
            task="review",
            specialty=request.specialty,
            input_text=request.text,
        )
        return {"output_text": output_text}

    except Exception as e:
        logger.exception("Review endpoint error: %s", e)
        return {"output_text": request.text}  # Return original on error


@app.post("/api/modify")
async def modify(request: ModifyRequest):
    """Modify text based on voice instruction.

    Returns ``{"output_text": ...}`` holding the modified text, or the
    original text if anything goes wrong.
    """
    try:
        # call_ai_task performs a blocking OpenAI request; run it in a worker
        # thread so the event loop keeps serving other requests meanwhile.
        output_text = await asyncio.to_thread(
            call_ai_task,
            task="modify",
            specialty=request.specialty,
            input_text=request.text,
            instruction=request.instruction,
        )
        return {"output_text": output_text}

    except Exception as e:
        logger.exception("Modify endpoint error: %s", e)
        return {"output_text": request.text}  # Return original on error


@app.post("/api/report")
async def generate_report(request: ReportRequest):
    """Generate structured smart report from transcription.

    Returns ``{"output_text": ...}`` holding the generated report, or the
    original text if anything goes wrong.
    """
    try:
        # call_ai_task performs a blocking OpenAI request; run it in a worker
        # thread so the event loop keeps serving other requests meanwhile.
        output_text = await asyncio.to_thread(
            call_ai_task,
            task="smart_report",
            specialty=request.specialty,
            input_text=request.text,
            signature=request.signature,
        )
        return {"output_text": output_text}

    except Exception as e:
        logger.exception("Report endpoint error: %s", e)
        return {"output_text": request.text}  # Return original on error
