#!/usr/bin/env python3
"""Transcribe audio with Gemini 2.5 Pro + speaker diarization.

Usage:
    python3 transcribe_gemini.py /path/to/audio.m4a

Output: /path/to/audio_transcript_gemini.txt
"""

import sys
import os
import json
import time
import requests

# Load API key
ENV_FILE = os.path.expanduser("~/VSCodeProjects/62a-provision/.openai.env")


def load_env(path):
    """Parse a simple KEY=VALUE env file into a dict.

    Blank lines, lines starting with ``#`` and lines without ``=`` are
    skipped. Each remaining line is split on its first ``=``; the key and
    value are whitespace-stripped, and surrounding double/single quote
    characters are removed from the value.

    Args:
        path: Path of the env file to read.

    Returns:
        Dict mapping each key to its string value (a later duplicate key
        overwrites an earlier one).

    Raises:
        OSError: If the file cannot be opened.
    """
    env = {}
    # Decode explicitly instead of relying on the locale default; "utf-8-sig"
    # also drops a leading BOM so it cannot end up glued to the first key.
    with open(path, encoding="utf-8-sig") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            k, v = line.split("=", 1)
            env[k.strip()] = v.strip().strip('"').strip("'")
    return env

# Read credentials. Failures here are reported the same way as the rest of
# the script (an ERROR line followed by exit status 1) instead of surfacing
# as a raw OSError/KeyError traceback.
try:
    env = load_env(ENV_FILE)
except OSError as e:
    print(f"ERROR: Cannot read env file {ENV_FILE}: {e}")
    sys.exit(1)
API_KEY = env.get("GEMINI_API_KEY")
if not API_KEY:
    print(f"ERROR: GEMINI_API_KEY is missing or empty in {ENV_FILE}")
    sys.exit(1)
BASE_URL = "https://generativelanguage.googleapis.com"

# Audio to transcribe: first CLI argument, otherwise a default recording.
audio_path = sys.argv[1] if len(sys.argv) > 1 else os.path.expanduser("~/Downloads/New_Recording_4.m4a")
if not os.path.isfile(audio_path):
    print(f"ERROR: Audio file not found: {audio_path}")
    sys.exit(1)
# The transcript is written next to the audio file, replacing its extension.
output_path = os.path.splitext(audio_path)[0] + "_transcript_gemini.txt"

# --- Step 1: Upload file ---
# Stat the file once; the size feeds both the progress line and the header.
file_size = os.path.getsize(audio_path)
print(f"[1/3] Uploading {os.path.basename(audio_path)} ({file_size/1024/1024:.1f} MB)...")
mime = "audio/mp4"

with open(audio_path, "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/upload/v1beta/files?key={API_KEY}",
        headers={
            "X-Goog-Upload-Command": "start, upload, finalize",
            "X-Goog-Upload-Header-Content-Length": str(file_size),
            "X-Goog-Upload-Header-Content-Type": mime,
            "Content-Type": mime,
        },
        data=f,
        # Without a timeout, requests waits indefinitely on a stalled
        # connection; use the same limit as the transcription request.
        timeout=600,
    )
# NOTE(review): the HTTPError raised here includes the request URL, and the
# URL carries the API key as a query parameter — avoid sharing that traceback.
resp.raise_for_status()
file_info = resp.json()["file"]
file_uri = file_info["uri"]
print(f"    Uploaded: {file_info['name']} (state={file_info['state']})")

# --- Step 2: Wait for file to become ACTIVE ---
# Poll the file resource every 5 seconds, up to 30 times. The script aborts
# if processing fails or the file never becomes ACTIVE, so the transcription
# step is never attempted on an unusable file.
file_name = file_info["name"]  # kept separately so a bad poll body cannot lose it
for _ in range(30):
    state = file_info.get("state")
    if state == "ACTIVE":
        break
    if state == "FAILED":
        print(f"ERROR: File processing failed for {file_name} (state=FAILED)")
        sys.exit(1)
    print("    Waiting for file processing...")
    time.sleep(5)
    r = requests.get(f"{BASE_URL}/v1beta/{file_name}?key={API_KEY}", timeout=60)
    # Fail on an HTTP error here rather than on a missing "state" key later.
    r.raise_for_status()
    file_info = r.json()
else:
    # The loop ran out without a break; the result of the final poll has not
    # been examined yet, so check it before giving up.
    if file_info.get("state") != "ACTIVE":
        print(f"ERROR: File {file_name} not ACTIVE after polling (state={file_info.get('state')})")
        sys.exit(1)

# --- Step 3: Transcribe with Gemini 2.5 Pro ---
print("[2/3] Transcribing with Gemini 2.5 Pro (this may take a few minutes)...")

model = "gemini-2.5-pro"

# Instruction text sent alongside the uploaded audio: who the speakers are,
# the expected "[Name]: text" line format, and the transcription rules.
prompt = (
    "Transcribe this entire Vietnamese audio recording with speaker diarization.\n\n"
    "This is a family conversation. The speakers are:\n"
    "- [Duyên]: Con dâu (daughter-in-law), young woman, speaks respectfully but defensively\n"
    "- [Mẹ San]: Mẹ chồng (mother-in-law), older woman, speaks with authority and emotion\n"
    "- [Anh Hào]: Chồng (husband/son), male voice, often mediates saying 'nói bé' (speak softly)\n"
    "- [Ông nội]: Ông nội (grandfather), elderly male, says things like 'ông bế ông bế'\n\n"
    "Output format: one line per utterance, with speaker name in brackets:\n"
    "[Duyên]: text here\n"
    "[Mẹ San]: text here\n"
    "[Anh Hào]: text here\n"
    "[Ông nội]: text here\n\n"
    "Rules:\n"
    "- Include ALL speech from beginning to end, do not skip or summarize\n"
    "- If speakers talk over each other, transcribe both\n"
    "- Include filler words, repetitions, hesitations as spoken\n"
    "- Keep consistent speaker labels using the exact names above\n"
    "- Output ONLY the transcript lines, no commentary\n"
    "- Each new utterance or turn gets its own line\n"
    "- IMPORTANT: Do NOT repeat lines. If you find yourself outputting the same block of lines again, STOP immediately"
)

# Assemble the request piece by piece: the uploaded file reference first,
# then the text instructions, plus the generation settings.
endpoint = f"{BASE_URL}/v1beta/models/{model}:generateContent?key={API_KEY}"
audio_part = {"fileData": {"mimeType": mime, "fileUri": file_uri}}
generation_config = {
    "temperature": 0.1,
    "maxOutputTokens": 65536,
    "thinkingConfig": {"thinkingBudget": 1024},
}
payload = {
    "contents": [{"parts": [audio_part, {"text": prompt}]}],
    "generationConfig": generation_config,
}

resp = requests.post(
    endpoint,
    headers={"Content-Type": "application/json"},
    json=payload,
    timeout=600,
)
if resp.status_code != 200:
    # Show at most the first 1000 characters of the error body.
    print(f"ERROR {resp.status_code}: {resp.text[:1000]}")
    sys.exit(1)

# strict=False lets the JSON decoder accept raw control characters inside
# string values.
data = json.loads(resp.text, strict=False)

# --- Parse response ---
# Only the first candidate is used; the script exits if there are none.
candidates = data.get("candidates", [])
if not candidates:
    print("ERROR: No candidates in response")
    print(json.dumps(data, indent=2, ensure_ascii=False)[:2000])
    sys.exit(1)

candidate = candidates[0]
finish_reason = candidate.get("finishReason", "N/A")
parts = candidate.get("content", {}).get("parts", [])

# Collect the text of every part that carries a "text" field.
text_parts = []
for part in parts:
    if "text" in part:
        text_parts.append(part["text"])

if not text_parts:
    # Nothing to save: report the finish reason and token counts, then abort.
    print(f"ERROR: No text in response (finishReason={finish_reason})")
    usage = data.get("usageMetadata", {})
    prompt_tokens = usage.get('promptTokenCount')
    thought_tokens = usage.get('thoughtsTokenCount', 0)
    total_tokens = usage.get('totalTokenCount')
    print(f"  Tokens: prompt={prompt_tokens}, thoughts={thought_tokens}, total={total_tokens}")
    sys.exit(1)

transcript = "\n".join(text_parts)

# --- Post-process: trim repetition loops ---
# Slide a window of `block_size` lines over the transcript. A window only
# counts when all of its lines are non-blank; the first time such a window
# exactly matches one seen earlier, everything from that line on is dropped.
raw_lines = transcript.split("\n")
stripped_lines = [entry.strip() for entry in raw_lines]
block_size = 7  # detect repeating blocks of this size
seen_blocks = set()
cleaned = []
for i, line in enumerate(raw_lines):
    # Fingerprint of the current line plus the lines that follow it.
    block = tuple(entry for entry in stripped_lines[i:i + block_size] if entry)
    if len(block) == block_size:
        if block in seen_blocks:
            print(f"  Trimmed repetition starting at line {i+1} ({len(raw_lines)-i} lines removed)")
            break
        seen_blocks.add(block)
    cleaned.append(line)
transcript = "\n".join(cleaned)

# A MAX_TOKENS finish means the output was cut off; warn, but save what exists.
if finish_reason == "MAX_TOKENS":
    print("  WARNING: Output was truncated (MAX_TOKENS). Transcript may be incomplete.")

# --- Step 4: Save ---
with open(output_path, "w", encoding="utf-8") as f:
    f.write(transcript)

# Line count = number of newline-separated segments in the saved text.
lines = len(transcript.split("\n"))
print(f"[3/3] Saved to {output_path}")
print(f"    {len(transcript)} chars, {lines} lines, finishReason={finish_reason}")

# Token accounting as reported in the response's usageMetadata.
usage = data.get("usageMetadata", {})
prompt_tokens = usage.get('promptTokenCount')
output_tokens = usage.get('candidatesTokenCount', 'N/A')
thought_tokens = usage.get('thoughtsTokenCount', 0)
print(f"    Tokens: prompt={prompt_tokens}, output={output_tokens}, thoughts={thought_tokens}")
