# main.py
import crafter
import numpy as np
from fastapi import FastAPI, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from io import BytesIO
from PIL import Image
import base64
import json

# ASGI application object; the route handlers below register themselves on it.
app = FastAPI()

# Allow frontend connection.
# NOTE(review): origins, methods and headers are all wildcarded, so any site
# may call this API -- presumably intended for local development; confirm
# before exposing the server publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve frontend files: the contents of the "frontend" directory are exposed
# under the /static URL prefix.
app.mount("/static", StaticFiles(directory="frontend"), name="static")

@app.get("/")
async def index():
    """Serve the frontend's entry page for requests to the site root."""
    landing_page = "frontend/index.html"
    return FileResponse(landing_page)

# Start environment: a single module-level Crafter instance, created at import
# time and used by every function in this file (and therefore by every
# WebSocket connection).
env = crafter.Env()


# Most recent observation; websocket_endpoint rebinds this global after each
# step and after each reset.
obs = env.reset()

# Action mapping: frontend control name -> integer action id.
# NOTE(review): decode_action() resolves names through env.action_names and
# never consults this table -- confirm whether ACTION_MAP is still needed and
# whether these ids line up with the environment's own action indices.
ACTION_MAP = {
    "left": 0,
    "right": 1,
    "up": 2,
    "down": 3,
    "jump": 4,
    "inventory": 5,
    "attack": 6,
}

def decode_action(action_str):
    """Translate an action name into its index in ``env.action_names``.

    Args:
        action_str: Action name as received from the client.

    Returns:
        The position of ``action_str`` in ``env.action_names``, or the
        position of ``'noop'`` when the name is not listed there.
    """
    known_actions = env.action_names
    chosen = action_str if action_str in known_actions else 'noop'  # fallback
    return known_actions.index(chosen)



def render_obs_raw(obs):
    """Render the current environment frame and return it as base64 text.

    The frame is always taken from ``env.render(size=(512, 512))``. The
    ``obs`` argument never contributed to the output, so it is no longer
    inspected: touching it could only make an otherwise renderable frame
    fail (for example when ``obs`` has no ``dtype`` attribute).

    Args:
        obs: Unused. Kept so existing callers can keep passing the latest
            observation.

    Returns:
        The rendered array's raw bytes, base64-encoded and decoded to a
        ``str``; an empty string if rendering or encoding raises.
    """
    try:
        # Higher-quality render than the observation returned by step/reset.
        frame = env.render(size=(512, 512))

        # Flatten the rendered array to raw bytes and base64 encode.
        # NOTE(review): the client presumably needs to know the frame's
        # shape and dtype to rebuild the image -- confirm against frontend.
        return base64.b64encode(frame.tobytes()).decode("utf-8")
    except Exception as e:
        # Best effort: a failed frame must not break the caller's send loop.
        print("❌ Raw render error:", e)
        return ""



@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Drive the shared environment from actions received over a WebSocket.

    Protocol, as implemented here:
      * On connect, one JSON message is sent carrying the current frame with
        ``reward`` 0.0 and ``done`` false.
      * Every text message from the client is treated as an action name,
        decoded with ``decode_action`` and applied with ``env.step``.
      * After each step one JSON message is sent with the keys
        ``frame_raw`` (base64 text from ``render_obs_raw``), ``reward`` and
        ``done``. When the episode ends the environment is reset before the
        frame is rendered, so a ``done: true`` message already shows the
        fresh episode.

    The module-level ``env`` and ``obs`` are used directly, so every
    connection acts on the same environment.

    Args:
        websocket: The client connection supplied by FastAPI.
    """
    # Function-scope import so this handler does not depend on a change to
    # the file's import block.
    from fastapi import WebSocketDisconnect

    global obs

    def build_message(reward, done):
        # Single place that defines the wire format. float()/bool() coerce
        # values that json.dumps may reject (e.g. NumPy scalars) -- an
        # unserializable `done` would otherwise end the connection.
        return json.dumps({
            "frame_raw": render_obs_raw(obs),
            "reward": float(reward),
            "done": bool(done),
        })

    await websocket.accept()
    print("✅ WebSocket connected.")

    try:
        # Send first frame
        await websocket.send_text(build_message(0.0, False))

        while True:
            data = await websocket.receive_text()
            print("👉 Received action:", data)
            obs, reward, done, _info = env.step(decode_action(data))

            if done:
                print("🔁 Episode done. Resetting env.")
                obs = env.reset()

            await websocket.send_text(build_message(reward, done))
    except WebSocketDisconnect:
        # Normal end of session: the client went away.
        print("🔌 WebSocket disconnected.")
    except Exception as e:
        # Top-level boundary for this connection: report and stop serving it.
        print("❌ Connection closed or error:", e)