import os, sys, time, queue, wave, threading, subprocess, collections
import numpy as np
import sounddevice as sd
import webrtcvad
from openwakeword.model import Model
from openai import OpenAI

# ========= CONFIG =========
WAKEWORD_NAME = "hey_birdie"            # logical label for logging; bundled models report their own names
WAKEWORD_THRESHOLD = 0.5                # tune 0.3..0.7 (lower = more sensitive)
COOLDOWN_SEC = 1.5                      # avoid multiple triggers per utterance

SAMPLE_RATE = 16000
FRAME_MS = 20                           # 20ms frames for VAD and wakeword
CHANNELS = 1
INPUT_BLOCK_FRAMES = int(SAMPLE_RATE * FRAME_MS / 1000)

MAX_RECORD_SEC = 15                     # hard cap per utterance
SILENCE_HANG_MS = 800                   # stop after this long of silence
VAD_AGGRESSIVENESS = 2                  # 0..3 (higher = more strict "voice only")

CHAT_MODEL = "gpt-4o-mini"
STT_MODEL  = "whisper-1"                # or "gpt-4o-mini-transcribe" if you have it
TTS_MODEL  = "gpt-4o-mini-tts"
TTS_VOICE  = "alloy"
OUTPUT_MP3 = "reply.mp3"
INPUT_WAV  = "utterance.wav"

PERSONA = (
  "You are Birdie, a friendly, whimsical bird perched among plush birds. "
  "Greet people warmly, keep responses concise, curious, and a little playful. "
  "Avoid sensitive topics; keep conversations light and fun."
)

# ========= INIT =========
if "OPENAI_API_KEY" not in os.environ:
    print("ERROR: OPENAI_API_KEY not set.")
    sys.exit(1)

client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

audio_q = queue.Queue()   # mic frames (int16) from callback
stop_flag = threading.Event()

# Wake word model (bundled general model)
oww = Model()  # downloads weights on first run to ~/.cache/openwakeword
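# If you later train a custom "Hey Birdie" model, load it explicitly instead of
# the bundled pack (file name here is hypothetical):
#   oww = Model(wakeword_models=["hey_birdie.onnx"])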

vad = webrtcvad.Vad(VAD_AGGRESSIVENESS)

def is_speech(frame_bytes):
    # frame_bytes: one 20 ms frame of 16-bit mono PCM; webrtcvad only accepts
    # 10/20/30 ms frames at 8/16/32/48 kHz
    return vad.is_speech(frame_bytes, SAMPLE_RATE)

def mic_callback(indata, frames, time_info, status):
    if status:
        # print(status)  # uncomment for debugging
        pass
    # indata: float32 [-1,1], convert to int16 PCM
    pcm16 = (indata[:,0] * 32767.0).astype(np.int16).tobytes()
    audio_q.put(pcm16)

def start_mic():
    stream = sd.InputStream(
        samplerate=SAMPLE_RATE,
        channels=CHANNELS,
        dtype='float32',
        blocksize=INPUT_BLOCK_FRAMES,
        callback=mic_callback,
        latency='low'
    )
    stream.start()
    return stream
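
# Tip: if the wrong input device is picked up (common with USB mics on a Pi),
# list devices with sd.query_devices() and pass device=<index> to sd.InputStream.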

def detect_wakeword_loop():
    """Continuously examine mic frames for wake word; trigger recording when detected."""
    last_trigger = 0
    ring = collections.deque(maxlen=50)   # last ~1s of audio, prepended to the recording
    pending = b""                         # accumulate 80 ms of PCM for openWakeWord

    while not stop_flag.is_set():
        try:
            frame = audio_q.get(timeout=0.5)  # bytes int16 mono
        except queue.Empty:
            continue

        ring.append(frame)
        pending += frame

        # openWakeWord expects 16 kHz, 16-bit PCM and works best on 1280-sample
        # (80 ms) chunks, so accumulate four 20 ms frames before scoring
        if len(pending) < 1280 * 2:
            continue
        pcm16 = np.frombuffer(pending, dtype=np.int16)
        pending = b""
        scores = oww.predict(pcm16)  # returns dict of {model_name: score}

        # The bundled models report names like "hey_jarvis"; fire on the best score
        kw, score = max(scores.items(), key=lambda kv: kv[1])

        now = time.time()
        if score >= WAKEWORD_THRESHOLD and (now - last_trigger) > COOLDOWN_SEC:
            last_trigger = now
            print(f"🔔 Wake word detected ({kw}={score:.2f}). Listening…")
            record_until_silence(preroll=list(ring))
            ring.clear()
            # Drain anything captured while we were replying (including our own
            # TTS playback) so it can't immediately re-trigger
            while not audio_q.empty():
                try:
                    audio_q.get_nowait()
                except queue.Empty:
                    break

def record_until_silence(preroll=None):
    """Record from the live mic queue until VAD reports trailing silence."""
    frames = list(preroll or [])            # pre-roll keeps words spoken during detection
    silent_recent = 0                       # consecutive non-speech frames
    max_frames = int(MAX_RECORD_SEC * 1000 / FRAME_MS)
    silence_frames_needed = int(SILENCE_HANG_MS / FRAME_MS)

    # Prime: small beep? uncomment if you want
    # subprocess.run(["speaker-test","-t","sine","-f","1000","-l","1","-p","10"])

    # read live frames
    while len(frames) < max_frames:
        try:
            frame = audio_q.get(timeout=1.0)
        except queue.Empty:
            continue

        frames.append(frame)
        if is_speech(frame):
            silent_recent = 0
        else:
            silent_recent += 1

        # stop after trailing silence (but only once we have a little audio)
        if silent_recent >= silence_frames_needed and len(frames) > 5:
            break

    # Save the utterance as a 16 kHz mono WAV using the standard library
    pcm = b"".join(frames)
    with wave.open(INPUT_WAV, "wb") as wf:
        wf.setnchannels(CHANNELS)
        wf.setsampwidth(2)          # 16-bit samples
        wf.setframerate(SAMPLE_RATE)
        wf.writeframes(pcm)
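    # Quick sanity check (assumes an ALSA system): running `aplay utterance.wav`
    # from a shell should play back what was just captured.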

    print(f"🎙️ Captured {len(frames)} frames (~{len(frames)*FRAME_MS/1000:.1f}s). Transcribing...")
    handle_request(INPUT_WAV)

def stt_transcribe(filepath):
    with open(filepath, "rb") as f:
        tx = client.audio.transcriptions.create(model=STT_MODEL, file=f)
    return tx.text.strip()
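
# Assumption: whisper-1 also accepts a `language` hint, e.g.
# client.audio.transcriptions.create(model=STT_MODEL, file=f, language="en"),
# which can cut down on misrecognized short utterances.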

def chat_reply(user_text):
    resp = client.chat.completions.create(
        model=CHAT_MODEL,
        messages=[
            {"role": "system", "content": PERSONA},
            {"role": "user", "content": user_text}
        ]
    )
    return resp.choices[0].message.content.strip()
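
# Optional sketch: chat_reply above is stateless. If you want Birdie to remember
# the last few exchanges, one approach is a rolling history (this helper is
# hypothetical and not wired into main()):
_history = collections.deque(maxlen=10)   # last ~5 user/assistant turns

def chat_reply_with_memory(user_text):
    _history.append({"role": "user", "content": user_text})
    resp = client.chat.completions.create(
        model=CHAT_MODEL,
        messages=[{"role": "system", "content": PERSONA}, *_history],
    )
    reply = resp.choices[0].message.content.strip()
    _history.append({"role": "assistant", "content": reply})
    return reply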

def tts_speak(text, out_mp3=OUTPUT_MP3):
    # Stream synthesized speech straight to an mp3 file
    with client.audio.speech.with_streaming_response.create(
        model=TTS_MODEL,
        voice=TTS_VOICE,
        input=text,
    ) as response:
        response.stream_to_file(out_mp3)
    # Play via mpg123; fall back to ffplay if it's missing or fails
    try:
        subprocess.run(["mpg123", "-q", out_mp3], check=True)
    except (FileNotFoundError, subprocess.CalledProcessError):
        subprocess.run(["ffplay", "-nodisp", "-autoexit", "-loglevel", "quiet", out_mp3], check=False)

def handle_request(wav_path):
    try:
        user_text = stt_transcribe(wav_path)
    except Exception as e:
        print("STT error:", e)
        return

    if not user_text:
        print("(Heard nothing.)")
        return

    print(f"🗣️  You: {user_text}")
    try:
        reply = chat_reply(user_text)
    except Exception as e:
        print("Chat error:", e)
        return

    print(f"🤖 Birdie: {reply}\n🔊 Speaking…")
    try:
        tts_speak(reply)
    except Exception as e:
        print("TTS error:", e)

def main():
    print("Birdie is perched and listening. Say the wake word (e.g., 'Hey Birdie'). Ctrl+C to stop.")
    stream = start_mic()
    try:
        detect_wakeword_loop()
    except KeyboardInterrupt:
        pass
    finally:
        stop_flag.set()
        stream.stop()
        stream.close()
        print("Goodbye.")

if __name__ == "__main__":
    main()