"""
1: Generate a self-signed certificate with OpenSSL

ECDSA (Recommended):
    $ openssl ecparam -out key.pem -name secp256r1 -genkey
    $ openssl req -new -key key.pem -x509 -nodes -days 365 -out cert.pem

RSA:
    $ openssl req -x509 -newkey rsa:2048 -nodes -out cert.pem -keyout key.pem -days 365

2: Test the server with cURL
    $ curl --cacert cert.pem https://localhost:5000
    $ curl --cacert cert.pem https://localhost:5000/download --output download.bin
    $ curl -k -X POST --cacert cert.pem https://127.0.0.1:5000/anything -d "hello from curl"

Source(s):
 - https://blog.miguelgrinberg.com/post/running-your-flask-application-over-https
 - https://curl.se/docs/sslcerts.html

 TODO:

  - Investigate mTLS : https://en.wikipedia.org/wiki/Mutual_authentication

"""
import os
import re
import hashlib
from flask import (Flask, send_file, jsonify, request)

# Flask application object; the route decorators in this file register
# their handlers on it.
app = Flask(__name__)

# Chunk size in bytes. Reported to devices by /check-update and
# /firmware/info, used as the read size by /firmware/chunk/<index>, and as
# the default "length" of /firmware/range.
CHUNK_SIZE = 1024
# Relative path of the firmware image served to devices.
FIRMWARE_PATH = "../iap_demo.img"
# C header that get_latest_firmware() scans for the APP_VERSION_STRING define.
VERSION_HEADER_PATH = "../src/version.h"

def get_latest_firmware(firmware_path=None, version_header_path=None):
    """
    Locate the firmware image and determine its version string.

    The version is read from the ``APP_VERSION_STRING`` define in the version
    header rather than from the firmware file name.

    Args:
        firmware_path: Path of the firmware image. Defaults to FIRMWARE_PATH.
        version_header_path: Path of the C header holding the version define.
            Defaults to VERSION_HEADER_PATH.

    Returns:
        A tuple ``(firmware_file_path, version_string)``. Both elements are
        ``None`` when the firmware image does not exist. The version is
        ``"1.0.0"`` when the header is missing or contains no matching define.
    """
    if firmware_path is None:
        firmware_path = FIRMWARE_PATH
    if version_header_path is None:
        version_header_path = VERSION_HEADER_PATH

    # Ensure firmware exists
    if not os.path.isfile(firmware_path):
        return None, None

    # Default version if header cannot be read
    version = "1.0.0"
    if not os.path.isfile(version_header_path):
        return firmware_path, version

    # Matches e.g.:  #define APP_VERSION_STRING "1.0.0"
    # Leading indentation and "#  define" spacing are accepted so that a
    # reformatted header does not silently fall back to the default version.
    define_re = re.compile(r'\s*#\s*define\s+APP_VERSION_STRING\s+"([^"]+)"')

    # Explicit encoding keeps the result independent of the host locale;
    # errors="replace" prevents stray non-UTF-8 bytes elsewhere in the header
    # from aborting the scan.
    with open(version_header_path, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            match = define_re.match(line)
            if match:
                version = match.group(1)
                break

    return firmware_path, version

def calculate_sha256(file_path):
    """Compute the SHA-256 digest of a file and return it as a hex string.

    The file is read in 8192-byte blocks, so its full content is never held
    in memory at once.
    """
    digest = hashlib.sha256()
    with open(file_path, "rb") as stream:
        while True:
            block = stream.read(8192)
            if not block:
                # An empty read means end of file.
                break
            digest.update(block)
    return digest.hexdigest()

"""
Expects from device :
{
  "device_id": "STM32-ABC123",
  "fw_version": "1.0.0"
}
"""
@app.route("/check-update", methods=["POST"])
def check_update():
    """Tell a device whether a newer firmware image is available.

    Reads ``device_id`` and ``fw_version`` from the JSON request body; a
    missing or falsy body is treated as an empty object.

    Returns:
        404 with an error object when no firmware image is found.
        200 with ``{"update_available": False}`` when the device already
        reports the server's version.
        200 with version, size, checksum, download URL and chunk size
        otherwise.
    """
    payload = request.get_json(silent=True) or {}
    reporting_device = payload.get("device_id")
    device_fw = payload.get("fw_version")

    print(f">>> check-update from {reporting_device}, fw={device_fw}")

    image_path, server_version = get_latest_firmware()
    if not image_path:
        return jsonify({"error": "Firmware not found"}), 404

    print(f"version available in server: {server_version}")

    # Device is already up to date: nothing more to compute.
    if device_fw == server_version:
        return jsonify({"update_available": False}), 200

    # Size and checksum are computed only when an update is actually offered.
    offer = {
        "update_available": True,
        "version": server_version,
        "size": os.path.getsize(FIRMWARE_PATH),
        "checksum": f"sha256:{calculate_sha256(FIRMWARE_PATH)}",
        "download_url": "/firmware",
        "chunk_size": CHUNK_SIZE,
    }
    return jsonify(offer), 200

"""
Direct data streaming, without chunking
"""
@app.route("/firmware", methods=["GET"])
def firmware():
    """Send the whole firmware image as a single attachment download.

    Returns:
        404 with a plain-text message when the firmware image is not a
        regular file; otherwise the file response built by ``send_file``.
    """
    # Use an absolute path (resolved against the current working directory,
    # like every other file access in this module). Flask's send_file resolves
    # *relative* paths against the application root instead, so passing the
    # relative constant could make the existence check and the actual send
    # look at two different locations.
    path = os.path.abspath(FIRMWARE_PATH)

    # isfile (not exists): a directory at this path must also yield 404.
    if not os.path.isfile(path):
        return "Firmware not found", 404

    return send_file(
        path,
        mimetype="application/octet-stream",
        as_attachment=True,
        conditional=True,
        # NOTE(review): etag=None presumably disables ETag generation;
        # confirm against the Flask/Werkzeug version in use.
        etag=None,
    )

@app.route("/firmware/info", methods=["GET"])
def firmware_info():
    """Describe the firmware image currently offered by the server.

    Returns:
        404 with an error object when no firmware image is found; otherwise
        200 with the image size in bytes, the chunk size, the SHA-256
        checksum (prefixed with ``sha256:``) and the version string.
    """
    # Use the same lookup as /check-update so both endpoints report the same
    # version and a missing image yields 404 instead of an unhandled error.
    firmware_file, latest_version = get_latest_firmware()
    if not firmware_file:
        return {"error": "Firmware not found"}, 404

    return {
        "size": os.path.getsize(firmware_file),
        "chunk_size": CHUNK_SIZE,
        "checksum": f"sha256:{calculate_sha256(firmware_file)}",
        "version": latest_version,
    }, 200

"""
Expects:
{
  "device_id": "STM32-ABC123",
  "fw_version": "1.1.0",
  "result": "success"
}
"""
@app.route("/update-status", methods=["POST"])
def update_status():
    """Receive a device's update report and acknowledge it.

    The JSON body (or an empty object when the body is missing or falsy) is
    only printed; nothing is stored.

    Returns:
        200 with ``{"status": "received"}``.
    """
    report = request.get_json(silent=True)
    if not report:
        report = {}

    print(">>> update-status")
    print(report)

    acknowledgement = {"status": "received"}
    return acknowledgement, 200

########################################################################################################################
########################################################################################################################
########################################################################################################################

"""
Not implemented in device.
We need to do something like: 

chunk = 0;
uint8_t *chunk_data;
while (download_chunk(chunk, chunk_data)) {
    updateProcess(chunk_data);
    chunk++;
}

"""
@app.route("/firmware/chunk/<int:index>", methods=["GET"])
def firmware_chunk(index):
    """Return one CHUNK_SIZE-sized slice of the firmware image.

    Args:
        index: Zero-based chunk number; the slice starts at byte
            ``index * CHUNK_SIZE``.

    Returns:
        200 with the raw bytes and an ``X-Chunk-Index`` header, 404 with an
        empty body when the slice is empty (index past the end of the image),
        or 404 with an error object when the image file does not exist.
    """
    try:
        with open(FIRMWARE_PATH, "rb") as image:
            image.seek(index * CHUNK_SIZE)
            payload = image.read(CHUNK_SIZE)
    except FileNotFoundError:
        return {"error": "firmware not found"}, 404

    if payload:
        headers = {
            "Content-Type": "application/octet-stream",
            "X-Chunk-Index": str(index),
        }
        return payload, 200, headers

    # Nothing was read: the requested chunk lies beyond the end of the file.
    return "", 404

"""
Not implemented on device
Byte based access, instead of index based
"""
@app.route("/firmware/range", methods=["GET"])
def firmware_range():
    """Return an arbitrary byte range of the firmware image.

    Query parameters:
        offset: First byte to return (integer >= 0, default 0).
        length: Maximum number of bytes to return (integer > 0, default
            CHUNK_SIZE).

    Returns:
        200 with the raw bytes and a ``Content-Range`` header on success.
        400 with an error object when a parameter is not an integer or is
        out of range.
        404 with an error object when the image file does not exist.
        416 with an empty body when the offset is at or past the end of the
        image.
        500 with an error object on any other OS-level read failure.
    """
    try:
        offset = int(request.args.get("offset", 0))
        length = int(request.args.get("length", CHUNK_SIZE))
    except ValueError as e:
        return {"error": str(e)}, 400

    # A negative length would make read() return the whole rest of the file,
    # and a negative offset is not a valid absolute seek position.
    if offset < 0 or length <= 0:
        return {"error": "offset must be >= 0 and length must be > 0"}, 400

    try:
        with open(FIRMWARE_PATH, "rb") as f:
            # Size is taken from the open handle so it matches the data read.
            total_size = os.fstat(f.fileno()).st_size
            f.seek(offset)
            data = f.read(length)
    except FileNotFoundError:
        return {"error": "firmware not found"}, 404
    except OSError as e:
        return {"error": str(e)}, 500

    if not data:
        return "", 416  # Range Not Satisfiable

    # NOTE(review): the status stays 200 for compatibility with the existing
    # contract; 206 Partial Content would be the conventional status for a
    # response carrying Content-Range.
    last_byte = offset + len(data) - 1
    return data, 200, {
        "Content-Type": "application/octet-stream",
        "Content-Range": f"bytes {offset}-{last_byte}/{total_size}",
    }

"""
Not implemented on device
"""
@app.route("/firmware/raw", methods=["GET"])
def firmware_raw():
    """Send the firmware image as an attachment; Range requests are rejected.

    Returns:
        501 with an error object when the request carries a ``Range`` header
        (range support is a skeleton only).
        404 with a plain-text message when the firmware image is not a
        regular file.
        Otherwise the file response built by ``send_file``.
    """
    range_header = request.headers.get("Range")

    # Skeleton only — implement later
    # NOTE(review): send_file may already honour Range headers when its
    # "conditional" option is enabled — confirm before hand-rolling this.
    if range_header:
        return {"error": "Range not implemented yet"}, 501

    # Absolute path (resolved against the current working directory, like the
    # open() calls in this module): Flask's send_file resolves relative paths
    # against the application root instead.
    path = os.path.abspath(FIRMWARE_PATH)

    # Answer 404 like /firmware does, instead of letting send_file raise.
    if not os.path.isfile(path):
        return "Firmware not found", 404

    return send_file(path, as_attachment=True)


if __name__ == "__main__":
    # Certificate and private key files, passed to Flask as the
    # (cert, key) pair for serving over TLS.
    tls_cert_and_key = ("cert.pem", "key.pem")
    app.run(host="0.0.0.0", ssl_context=tls_cert_and_key)