You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

386 lines
13 KiB

#!/usr/bin/env python3
"""
DaSiWa I2V/FLF2V API Server для ComfyUI.
Работает рядом с ComfyUI на той же машине.
Принимает HTTP запросы с HMAC авторизацией,
отправляет workflow в ComfyUI, возвращает видео.
"""
import base64
import binascii
import json
import logging
import math
import os
import random
import shutil
import subprocess
import sys
import threading
import time
import urllib.error
import urllib.parse
import urllib.request
import uuid

import websocket as ws_client
from flask import Flask, request, jsonify

from hmac_auth import verify_request
# ============================================================================
# Конфигурация
# ============================================================================
# Settings that can be overridden through environment variables.
COMFY_HOST = os.environ.get("COMFY_HOST", "127.0.0.1")
COMFY_PORT = os.environ.get("COMFY_PORT", "8188")  # kept as a string; only used in URLs
API_PORT = int(os.environ.get("API_PORT", "8080"))

# Data files are resolved relative to this script, not the working directory.
WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "workflow_api.json")
KEYS_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "keys.json")

# ComfyUI's output directory (emptied by the /generate handler after each request).
COMFY_OUTPUT_DIR = os.environ.get("COMFY_OUTPUT_DIR", "/ComfyUI/output")
# ============================================================================
# Инициализация
# ============================================================================
# Root logging: INFO and above, timestamped, with the level in brackets.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# The Flask application that all routes below are registered on.
app = Flask(__name__)
# Load the HMAC credentials. Without them no request can be authenticated,
# so a missing key file stops the process at startup.
if os.path.exists(KEYS_FILE):
    with open(KEYS_FILE, "r") as f:
        keys = json.load(f)
else:
    logger.error(f"❌ Файл ключей не найден: {KEYS_FILE}")
    logger.error(" Запусти: python generate_keys.py")
    sys.exit(1)

CLIENT_ID = keys["client_id"]
SECRET_KEY = keys["secret_key"]
logger.info(f"🔐 Ключи загружены. Client ID: {CLIENT_ID}")

# Nonces already seen; passed to verify_request for replay-attack protection.
# NOTE(review): nothing in this file ever removes entries, so the set grows for
# the lifetime of the process unless verify_request prunes it. Confirm in hmac_auth.
used_nonces = set()

# One ComfyUI WebSocket client id for this whole process (used both when
# connecting the socket and when queueing prompts).
ws_client_id = str(uuid.uuid4())
# ============================================================================
# Утилиты
# ============================================================================
def to_nearest_multiple_of_16(value):
    """Round *value* to the nearest multiple of 16, never returning less than 16.

    Accepts anything ``float()`` understands (ints, floats, numeric strings).
    Exact halfway cases follow Python's ``round`` (round-half-to-even), so
    24 -> 32 and 40 -> 32.

    Raises:
        ValueError: if *value* is not numeric, or is NaN or infinite.
    """
    try:
        numeric_value = float(value)
    except (TypeError, ValueError) as err:
        raise ValueError(f"width/height value is not a number: {value}") from err
    # Without this check, round() would fail on NaN (ValueError) or inf
    # (OverflowError) with an unrelated message. Report them like any other
    # non-numeric input.
    if not math.isfinite(numeric_value):
        raise ValueError(f"width/height value is not a number: {value}")
    adjusted = int(round(numeric_value / 16.0) * 16)
    return max(adjusted, 16)
def save_base64_to_file(base64_data, temp_dir, filename):
    """Decode base64 *base64_data* and write the bytes to ``temp_dir/filename``.

    Accepts plain base64 or a ``data:<mime>;base64,<payload>`` URI, and
    ignores embedded whitespace and line breaks. *temp_dir* is created if needed.

    Returns:
        The absolute path of the written file.

    Raises:
        binascii.Error: (a ``ValueError`` subclass) if the payload is not valid
            base64. Nothing is written in that case.
    """
    payload = base64_data
    if isinstance(payload, str):
        # Clients often send data URIs; decoding the "data:...;base64," prefix
        # as payload would corrupt the image.
        if payload.startswith("data:") and "," in payload:
            payload = payload.split(",", 1)[1]
        # Strict validation below rejects whitespace, so strip it first.
        payload = "".join(payload.split())
    elif isinstance(payload, (bytes, bytearray)):
        payload = b"".join(payload.split())
    try:
        # validate=True: reject stray characters instead of silently dropping
        # them, which would otherwise produce a corrupt file.
        decoded = base64.b64decode(payload, validate=True)
    except ValueError as err:  # binascii.Error, or non-ASCII str input
        raise binascii.Error(f"Invalid base64 data: {err}") from err
    os.makedirs(temp_dir, exist_ok=True)
    file_path = os.path.abspath(os.path.join(temp_dir, filename))
    with open(file_path, "wb") as f:
        f.write(decoded)
    return file_path
def download_file(url, temp_dir, filename):
    """Download *url* to ``temp_dir/filename`` using the ``wget`` binary.

    *temp_dir* is created if needed.

    Returns:
        The absolute path of the downloaded file.

    Raises:
        ValueError: if *url* is not a non-empty string.
        RuntimeError: if wget exits with a non-zero status (stderr included).
        subprocess.TimeoutExpired: if the download takes longer than 120 s.
    """
    # The URL comes straight from the request body.
    # NOTE(review): any host reachable from this machine can be fetched (SSRF);
    # restrict schemes/hosts if clients are not fully trusted.
    if not isinstance(url, str) or not url.strip():
        raise ValueError(f"Invalid download URL: {url!r}")
    os.makedirs(temp_dir, exist_ok=True)
    file_path = os.path.abspath(os.path.join(temp_dir, filename))
    # "--" ends option parsing, so a URL such as "--execute=..." is treated as
    # a URL and not as a wget option.
    result = subprocess.run(
        ["wget", "-O", file_path, "--no-verbose", "--", url],
        capture_output=True, text=True, timeout=120,
    )
    if result.returncode != 0:
        raise RuntimeError(f"Download failed: {result.stderr}")
    return file_path
def process_image_input(job_input, prefix, temp_dir):
    """Resolve one image input (``prefix`` is "image" or "last_image").

    Checks, in priority order, ``<prefix>_path`` (returned as-is),
    ``<prefix>_url`` (downloaded) and ``<prefix>_base64`` (decoded). The first
    key with a truthy value wins. Downloaded or decoded images are written to
    ``temp_dir/<prefix>.png``.

    Returns:
        ``(file_path, True)`` if an input was found, otherwise ``(None, False)``.
    """
    local_path = job_input.get(f"{prefix}_path")
    if local_path:
        return local_path, True

    target_name = f"{prefix}.png"

    remote_url = job_input.get(f"{prefix}_url")
    if remote_url:
        return download_file(remote_url, temp_dir, target_name), True

    encoded = job_input.get(f"{prefix}_base64")
    if encoded:
        return save_base64_to_file(encoded, temp_dir, target_name), True

    return None, False
def queue_prompt(prompt):
    """Submit an API-format workflow *prompt* to ComfyUI's ``/prompt`` endpoint.

    The prompt is tagged with the module-level ``ws_client_id``.

    Returns:
        ComfyUI's decoded JSON response (the caller reads ``"prompt_id"``).

    Raises:
        RuntimeError: if ComfyUI answers with an HTTP error. The message
            includes ComfyUI's response body, so validation details reach the
            caller instead of a bare "HTTP Error 400".
    """
    url = f"http://{COMFY_HOST}:{COMFY_PORT}/prompt"
    data = json.dumps({"prompt": prompt, "client_id": ws_client_id}).encode("utf-8")
    req = urllib.request.Request(url, data=data, headers={"Content-Type": "application/json"})
    try:
        # "with" closes the HTTP response; the original relied on GC for that.
        with urllib.request.urlopen(req) as response:
            return json.loads(response.read())
    except urllib.error.HTTPError as err:
        detail = err.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"ComfyUI rejected the prompt (HTTP {err.code}): {detail}") from err
def get_history(prompt_id):
    """Fetch ComfyUI's ``/history/<prompt_id>`` and return the decoded JSON."""
    history_url = "http://{}:{}/history/{}".format(COMFY_HOST, COMFY_PORT, prompt_id)
    with urllib.request.urlopen(history_url) as resp:
        payload = resp.read()
    return json.loads(payload)
def _wait_for_prompt(ws, prompt_id):
    """Block on WebSocket *ws* until ComfyUI reports that *prompt_id* finished.

    Completion is an ``executing`` message with ``node`` set to None.

    Raises:
        RuntimeError: on an ``execution_error`` or ``execution_interrupted``
            message for this prompt, so the failure reason reaches the client.
    """
    while True:
        out = ws.recv()
        if not isinstance(out, str):
            continue  # only text frames carry JSON status messages
        message = json.loads(out)
        msg_type = message.get("type")
        data = message.get("data") or {}
        if data.get("prompt_id") != prompt_id:
            continue  # status for another prompt, or not prompt-specific
        if msg_type == "executing" and data.get("node") is None:
            return
        if msg_type == "execution_error":
            raise RuntimeError(
                f"ComfyUI execution error in node {data.get('node_id')} "
                f"({data.get('node_type')}): {data.get('exception_message')}"
            )
        if msg_type == "execution_interrupted":
            raise RuntimeError(f"ComfyUI execution interrupted at node {data.get('node_id')}")


def _read_first_video(history):
    """Return the first video listed under any node's ``"gifs"`` output as base64.

    The video file at its ``"fullpath"`` is deleted after reading (best effort).
    Returns None if no node produced one.
    """
    for node_output in history["outputs"].values():
        for video in node_output.get("gifs", []):
            video_path = video["fullpath"]
            with open(video_path, "rb") as f:
                video_b64 = base64.b64encode(f.read()).decode("utf-8")
            # Cleanup must not lose a video that has already been read.
            try:
                os.remove(video_path)
            except OSError:
                pass
            return video_b64
    return None


def generate_video(prompt):
    """Run *prompt* on ComfyUI and return the resulting video as a base64 string.

    Opens a WebSocket under the module-level ``ws_client_id``, queues the
    prompt, waits for it to finish, then reads the output from history.

    Returns:
        The base64-encoded video, or None if the prompt produced no "gifs" output.

    Raises:
        RuntimeError: if ComfyUI reports an execution error or interruption.
            Errors from queue_prompt/get_history and the WebSocket propagate.
    """
    ws_url = f"ws://{COMFY_HOST}:{COMFY_PORT}/ws?clientId={ws_client_id}"
    ws = ws_client.WebSocket()
    # The socket is opened before the prompt is queued, so status messages for
    # the prompt cannot arrive before we are listening.
    ws.connect(ws_url)
    logger.info("🔌 WebSocket подключён к ComfyUI")
    try:
        prompt_id = queue_prompt(prompt)["prompt_id"]
        logger.info(f"📤 Prompt отправлен: {prompt_id}")
        _wait_for_prompt(ws, prompt_id)
    finally:
        # Close on every path; the original leaked the socket on any exception.
        ws.close()
    logger.info("✅ Генерация завершена")

    history = get_history(prompt_id)[prompt_id]
    return _read_first_video(history)
# ============================================================================
# API Endpoints
# ============================================================================
@app.before_request
def check_hmac_auth():
    """Reject requests with a bad HMAC signature. ``/health`` is exempt.

    Returns None to let Flask dispatch the request, or a JSON 401 response
    when verify_request rejects it.
    """
    if request.path != "/health":
        raw_body = request.get_data()
        signed_headers = {
            header_name: request.headers.get(header_name, "")
            for header_name in ("X-Client-Id", "X-Timestamp", "X-Nonce", "X-Signature")
        }
        ok, reason = verify_request(raw_body, signed_headers, SECRET_KEY, CLIENT_ID, used_nonces)
        if not ok:
            logger.warning(f"🚫 Auth failed: {reason} from {request.remote_addr}")
            return jsonify({"error": "Unauthorized", "detail": reason}), 401
    return None
@app.route("/health", methods=["GET"])
def health():
    """Health check. No HMAC auth: check_hmac_auth skips this path.

    ``"comfyui"`` is "ok" if ComfyUI's root URL responded within 5 seconds,
    otherwise "unavailable". ``"status"`` is always "ok".
    """
    try:
        # "with" closes the probe response instead of leaving it to GC.
        with urllib.request.urlopen(f"http://{COMFY_HOST}:{COMFY_PORT}/", timeout=5):
            pass
        comfy_status = "ok"
    except Exception:  # any failure (refused, timeout, HTTP error) counts as unavailable
        comfy_status = "unavailable"
    return jsonify({
        "status": "ok",
        "comfyui": comfy_status,
        "timestamp": int(time.time()),
    })
# Serializes video generation across request threads. Every request uses the
# same module-level ComfyUI client id (ws_client_id), and the post-generation
# cleanup wipes all of COMFY_OUTPUT_DIR. Without this lock, two overlapping
# generations could interfere and delete each other's output files.
# NOTE(review): this only serializes threads in one process; a multi-process
# WSGI deployment would need an inter-process lock.
_GENERATION_LOCK = threading.Lock()


def _as_int(value, name):
    """Coerce a JSON value to int: ints, integral floats (4.0), integer strings ("4").

    Raises ValueError for anything else, including booleans.
    """
    # bool is a subclass of int, but `true` is never a meaningful step count.
    if isinstance(value, bool):
        raise ValueError(f"{name} must be an integer, got {value!r}")
    if isinstance(value, int):
        return value  # returned as-is: large seeds must not pass through float
    if isinstance(value, float) and value.is_integer():
        return int(value)
    if isinstance(value, str):
        try:
            return int(value)
        except ValueError:
            pass
    raise ValueError(f"{name} must be an integer, got {value!r}")


def _as_float(value, name):
    """Coerce a JSON value to a finite float. Raise ValueError otherwise (incl. bools)."""
    if not isinstance(value, bool):
        try:
            number = float(value)
        except (TypeError, ValueError):
            pass
        else:
            if math.isfinite(number):
                return number
    raise ValueError(f"{name} must be a finite number, got {value!r}")


def _parse_generation_params(job_input):
    """Read generation parameters from *job_input*, applying defaults.

    Returns a dict with keys width, height, length, steps, cfg, seed, fps,
    sampler_name and scheduler. Width and height are snapped to multiples of
    16, and a seed of -1 is replaced by a random seed in [0, 2**63 - 1].

    Raises:
        ValueError: if a parameter has the wrong type or is out of range.
    """
    params = {
        "width": to_nearest_multiple_of_16(job_input.get("width", 528)),
        "height": to_nearest_multiple_of_16(job_input.get("height", 768)),
        "length": _as_int(job_input.get("length", 81), "length"),
        "steps": _as_int(job_input.get("steps", 4), "steps"),
        "cfg": _as_float(job_input.get("cfg", 1.0), "cfg"),
        "seed": _as_int(job_input.get("seed", -1), "seed"),
        "fps": _as_float(job_input.get("fps", 16), "fps"),
        "sampler_name": job_input.get("sampler_name", "euler"),
        "scheduler": job_input.get("scheduler", "linear_quadratic"),
    }
    if params["length"] < 1:
        raise ValueError("length must be >= 1")
    if params["steps"] < 1:
        raise ValueError("steps must be >= 1")
    if params["fps"] <= 0:
        raise ValueError("fps must be > 0")
    for key in ("sampler_name", "scheduler"):
        if not isinstance(params[key], str):
            raise ValueError(f"{key} must be a string")
    if params["seed"] == -1:
        params["seed"] = random.randint(0, 2**63 - 1)
    elif params["seed"] < 0:
        raise ValueError("seed must be -1 (random) or a non-negative integer")
    return params


def _clear_comfy_output_dir():
    """Best-effort removal of every file, symlink and directory in COMFY_OUTPUT_DIR.

    Failures are logged per entry instead of aborting the whole sweep.
    """
    try:
        entries = os.listdir(COMFY_OUTPUT_DIR)
    except FileNotFoundError:
        return
    except OSError as e:
        logger.warning(f"⚠️ Не удалось прочитать {COMFY_OUTPUT_DIR}: {e}")
        return
    for fname in entries:
        fpath = os.path.join(COMFY_OUTPUT_DIR, fname)
        try:
            # Symlinks are unlinked, never followed (rmtree refuses them anyway).
            if os.path.islink(fpath) or os.path.isfile(fpath):
                os.unlink(fpath)
            elif os.path.isdir(fpath):
                shutil.rmtree(fpath)
        except OSError as e:
            logger.warning(f"⚠️ Не удалось удалить {fpath}: {e}")


@app.route("/generate", methods=["POST"])
def generate():
    """Main endpoint: fill the workflow template from the request and generate a video.

    Expects a JSON object with a first-frame image (``image_base64``,
    ``image_url`` or ``image_path``), an optional last frame (``last_image_*``,
    which selects FLF2V mode) and optional generation parameters.

    Responds with ``{"video", "seed", "mode", "elapsed"}`` on success, 400 for
    malformed input, and 500 for server-side or ComfyUI failures.
    """
    start_time = time.time()
    # silent=True: an absent or unparsable JSON body yields None (-> {}) instead
    # of raising, and is then rejected by the input-image check below.
    job_input = request.get_json(silent=True) or {}
    if not isinstance(job_input, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    logger.info("=" * 60)
    logger.info("🎬 Новый запрос на генерацию")
    logger.info("=" * 60)
    # Log the parameters without the (potentially large) base64 payloads.
    log_input = {k: v for k, v in job_input.items() if not k.endswith("_base64")}
    logger.info(f"Параметры: {json.dumps(log_input, ensure_ascii=False)}")

    task_id = f"task_{uuid.uuid4().hex[:8]}"
    temp_dir = os.path.join("/tmp", task_id)
    try:
        # === Input images ===
        try:
            image_path, has_image = process_image_input(job_input, "image", temp_dir)
            if not has_image:
                return jsonify({"error": "No input image provided. Use image_base64, image_url, or image_path"}), 400
            last_image_path, use_flf2v = process_image_input(job_input, "last_image", temp_dir)
        except (TypeError, ValueError) as e:
            # Malformed client data, e.g. invalid base64 (binascii.Error is a ValueError).
            return jsonify({"error": f"Invalid image input: {e}"}), 400
        mode = "FLF2V" if use_flf2v else "I2V"
        logger.info(f"🎬 Режим: {mode}")

        # === Workflow template ===
        if not os.path.exists(WORKFLOW_FILE):
            return jsonify({"error": f"Workflow file not found: {WORKFLOW_FILE}"}), 500
        with open(WORKFLOW_FILE, "r") as f:
            prompt = json.load(f)

        # === Generation parameters ===
        try:
            params = _parse_generation_params(job_input)
        except (TypeError, ValueError, OverflowError) as e:
            return jsonify({"error": f"Invalid generation parameters: {e}"}), 400
        seed = params["seed"]
        steps = params["steps"]
        logger.info(
            f"📐 {params['width']}x{params['height']}, {params['length']} frames, "
            f"{steps} steps, CFG {params['cfg']}, seed {seed}"
        )

        # === Fill in the workflow (node ids come from workflow_api.json) ===
        prompt["5"]["inputs"]["text"] = job_input.get("prompt", "")  # positive prompt
        # Negative prompt: default to the text already in the workflow.
        negative_prompt = job_input.get("negative_prompt", prompt["6"]["inputs"]["text"])
        prompt["6"]["inputs"]["text"] = negative_prompt
        prompt["7"]["inputs"]["image"] = image_path  # first frame

        if use_flf2v and last_image_path:
            prompt["15"]["inputs"]["image"] = last_image_path
            logger.info(f"🎬 FLF2V: last frame = {last_image_path}")
        else:
            # I2V: switch node 8 to WanImageToVideo and drop the last-frame
            # input and its loader node (15).
            prompt["8"]["class_type"] = "WanImageToVideo"
            prompt["8"]["inputs"].pop("end_image", None)
            prompt.pop("15", None)
            logger.info("🎬 I2V: single image mode")

        video_inputs = prompt["8"]["inputs"]
        video_inputs["width"] = params["width"]
        video_inputs["height"] = params["height"]
        video_inputs["length"] = params["length"]

        # Two samplers split the schedule: node 11 ("high") stops at
        # steps // 2 and node 12 ("low") starts there.
        for node_id in ("11", "12"):
            sampler_inputs = prompt[node_id]["inputs"]
            sampler_inputs["noise_seed"] = seed
            sampler_inputs["steps"] = steps
            sampler_inputs["cfg"] = params["cfg"]
            sampler_inputs["sampler_name"] = params["sampler_name"]
            sampler_inputs["scheduler"] = params["scheduler"]
        prompt["11"]["inputs"]["end_at_step"] = steps // 2
        prompt["12"]["inputs"]["start_at_step"] = steps // 2

        prompt["14"]["inputs"]["frame_rate"] = params["fps"]  # output video fps

        # === Generation (serialized by _GENERATION_LOCK) ===
        with _GENERATION_LOCK:
            try:
                video_b64 = generate_video(prompt)
            finally:
                # Runs under the lock, so it never wipes another request's output.
                _clear_comfy_output_dir()
        if not video_b64:
            return jsonify({"error": "Video generation failed — no output"}), 500

        elapsed = time.time() - start_time
        logger.info(f"✅ Видео сгенерировано за {elapsed:.1f}s")
        return jsonify({
            "video": video_b64,
            "seed": seed,
            "mode": mode,
            "elapsed": round(elapsed, 1),
        })
    except Exception as e:
        # Top-level boundary: log the traceback and report a 500 to the client.
        logger.error(f"❌ Ошибка: {e}", exc_info=True)
        return jsonify({"error": str(e)}), 500
    finally:
        # Remove this task's downloaded or decoded input images.
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)
# ============================================================================
# Запуск
# ============================================================================
if __name__ == "__main__":
    logger.info("=" * 60)
    logger.info("🚀 DaSiWa API Server")
    logger.info(f" ComfyUI: http://{COMFY_HOST}:{COMFY_PORT}")
    logger.info(f" API Port: {API_PORT}")
    logger.info(f" Workflow: {WORKFLOW_FILE}")
    logger.info("=" * 60)
    # Informational startup probe only: the server starts even if ComfyUI is down.
    try:
        # "with" closes the probe response instead of leaving it to GC.
        with urllib.request.urlopen(f"http://{COMFY_HOST}:{COMFY_PORT}/", timeout=5):
            pass
        logger.info("✅ ComfyUI доступен")
    except Exception:
        logger.warning(" ComfyUI недоступен — запросы будут ждать")
    # NOTE(review): app.run starts Flask's built-in development server;
    # consider a production WSGI server for real deployments.
    app.run(host="0.0.0.0", port=API_PORT, debug=False)