"""
3D Thumbnail Generator for AIMMS
Generates static 2D thumbnails for 3D models to be used as previews
"""

import os
import json
from pathlib import Path
import tempfile
import time
from datetime import datetime

import trimesh
import plotly.graph_objects as go
from PIL import Image, ImageDraw
import numpy as np


def generate_3d_thumbnail(model_path, output_path=None, size=(280, 180)):
    """
    Generate a 2D thumbnail for a 3D model.

    The model is loaded with ``load_3d_model`` and rendered to disk by
    ``create_fallback_thumbnail``. If the saved image is not exactly ``size``
    it is resampled to that size (aspect ratio is not preserved).

    Args:
        model_path (str or Path): Path to the 3D model file
        output_path (str or Path, optional): Path to save the thumbnail. If None,
            ``<model folder>/thumbnails/<model stem>_thumbnail.png`` is used.
            The destination folder is created if it does not exist.
        size (tuple): Thumbnail size as (width, height)

    Returns:
        Path: Path to the generated thumbnail, or None if the model file does
        not exist or could not be loaded/rendered.
    """
    model_path = Path(model_path)

    if not model_path.exists():
        return None

    try:
        if output_path is None:
            # Default location: thumbnails subfolder next to the model.
            output_path = model_path.parent / "thumbnails" / f"{model_path.stem}_thumbnail.png"
        output_path = Path(output_path)
        # Ensure the destination folder exists for explicit paths too.
        output_path.parent.mkdir(parents=True, exist_ok=True)

        mesh = load_3d_model(model_path)
        if mesh is None:
            return None

        # No headless-browser capture is available, so the image is produced
        # directly by create_fallback_thumbnail. (The Plotly figure from
        # create_thumbnail_figure is not built here: it would be unused work.)
        thumbnail_path = create_fallback_thumbnail(mesh, output_path)
        if thumbnail_path is None:
            return None

        # Honour the requested size; a resize failure keeps the rendered image.
        try:
            _fit_thumbnail_to_size(thumbnail_path, size)
        except Exception as resize_error:
            print(f"Could not resize thumbnail {thumbnail_path} to {size}: {resize_error}")

        return Path(thumbnail_path)

    except Exception as e:
        print(f"Error generating 3D thumbnail for {model_path}: {e}")
        return None


def _fit_thumbnail_to_size(image_path, size):
    """Resample the image stored at *image_path* in place to exactly (width, height) = *size*."""
    target = (int(size[0]), int(size[1]))
    with Image.open(image_path) as img:
        if img.size == target:
            return
        # resize() loads the pixels, so the file can be overwritten once closed.
        resized = img.resize(target, Image.LANCZOS)
    resized.save(image_path)


def load_3d_model(model_path, max_faces=10000, target_faces=5000):
    """
    Load a 3D model with trimesh and return a single ``trimesh.Trimesh``.

    Meshes with more than ``max_faces`` faces are decimated toward
    ``target_faces`` for faster rendering. Decimation is best-effort: if the
    installed trimesh lacks a working backend, the full mesh is returned.

    Args:
        model_path (str or Path): Path to the 3D model file.
        max_faces (int): Face count above which decimation is attempted.
        target_faces (int): Face count requested from the decimator.

    Returns:
        trimesh.Trimesh or None: The (possibly simplified) mesh, or None if the
        file could not be loaded or contained no usable mesh.
    """
    try:
        mesh = trimesh.load(str(model_path), force='mesh', process=False)

        # force='mesh' is expected to yield a Trimesh, but guard against a
        # Scene anyway. Only the first Trimesh geometry is used, and scene-graph
        # transforms are not applied.
        if isinstance(mesh, trimesh.Scene):
            geometries = [g for g in mesh.geometry.values() if isinstance(g, trimesh.Trimesh)]
            if not geometries:
                raise ValueError("No valid mesh found in scene")
            mesh = geometries[0]

        if not isinstance(mesh, trimesh.Trimesh) or len(mesh.vertices) == 0:
            raise ValueError("Invalid mesh data")

        if len(mesh.faces) > max_faces:
            mesh = _decimate_best_effort(mesh, target_faces)

        return mesh

    except Exception as e:
        print(f"Error loading 3D model {model_path}: {e}")
        return None


def _decimate_best_effort(mesh, target_faces):
    """Return *mesh* decimated to about *target_faces* faces, or *mesh* unchanged if that fails."""
    try:
        # `face_count` is the keyword name in both trimesh 3.x (open3d backend)
        # and 4.x (fast_simplification backend); positional use differs.
        simplified = mesh.simplify_quadric_decimation(face_count=target_faces)
    except Exception as e:
        print(f"Mesh simplification skipped: {e}")
        return mesh
    if isinstance(simplified, trimesh.Trimesh) and len(simplified.faces) > 0:
        return simplified
    return mesh


def create_thumbnail_figure(mesh, size):
    """
    Build a Plotly ``Mesh3d`` figure of *mesh* styled for thumbnails.

    All axes, grids and tick labels are hidden, backgrounds are ``#2b2b2b``,
    margins are zero, and the camera sits at (1.25, 1.25, 1.25) looking at the
    origin with +Z as the up direction.

    Args:
        mesh: Object with ``vertices`` and ``faces`` arrays indexable by column
            (e.g. a ``trimesh.Trimesh``); columns 0-2 are used from each.
        size (tuple): Figure (width, height) in pixels.

    Returns:
        plotly.graph_objects.Figure: The configured figure.
    """
    verts = mesh.vertices
    tris = mesh.faces

    surface = go.Mesh3d(
        x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
        i=tris[:, 0], j=tris[:, 1], k=tris[:, 2],
        color='lightblue',
        opacity=0.9,
        flatshading=True,
    )

    def hidden_axis():
        # Fresh dict per axis: no grid, zero line, ticks, or axis at all.
        return dict(showgrid=False, zeroline=False, showticklabels=False, visible=False)

    camera = dict(
        eye=dict(x=1.25, y=1.25, z=1.25),  # three-quarter view
        up=dict(x=0, y=0, z=1),
        center=dict(x=0, y=0, z=0),
    )
    scene = dict(
        xaxis=hidden_axis(),
        yaxis=hidden_axis(),
        zaxis=hidden_axis(),
        aspectmode='data',
        bgcolor='#2b2b2b',
        camera=camera,
    )

    figure = go.Figure(data=[surface])
    figure.update_layout(
        scene=scene,
        margin=dict(l=0, r=0, b=0, t=0),
        paper_bgcolor='#2b2b2b',
        showlegend=False,
        width=size[0],
        height=size[1],
    )
    return figure


def create_fallback_thumbnail(mesh, output_path, size=(280, 180)):
    """
    Render *mesh* to a PNG thumbnail at *output_path*.

    First tries trimesh's scene renderer (``Scene.save_image``). If any part of
    that path fails (camera setup, rendering, or an empty result), a flat
    wireframe drawn with PIL is saved instead. With no mesh, or a mesh without
    vertices, a blank background image is written.

    Args:
        mesh (trimesh.Trimesh or None): Mesh to render.
        output_path (str or Path): Destination image file.
        size (tuple): Image size as (width, height); defaults to 280x180.

    Returns:
        The ``output_path`` argument on success, or None if writing failed.
    """
    try:
        width, height = int(size[0]), int(size[1])
        has_geometry = mesh is not None and len(mesh.vertices) > 0

        if has_geometry:
            png = _render_scene_png(mesh, (width, height))
            if png:
                with open(output_path, 'wb') as f:
                    f.write(png)
                print(f"Generated actual 3D model thumbnail: {output_path}")
                return output_path

        # Fallback: simple wireframe (blank canvas when there is no geometry).
        image = _draw_wireframe(mesh if has_geometry else None, width, height)
        image.save(output_path)
        print(f"Generated wireframe thumbnail: {output_path}")
        return output_path

    except Exception as e:
        print(f"Error creating thumbnail: {e}")
        return None


def _render_scene_png(mesh, resolution):
    """Render *mesh* with trimesh's scene renderer; return PNG bytes, or None on any failure."""
    try:
        scene = mesh.scene()
        try:
            camera_transform = trimesh.scene.cameras.look_at(
                points=[scene.centroid],
                distance=scene.scale * 1.2,  # close enough for the model to fill the view
                rotation=None,
                fov=60.0,
            )
        except TypeError:
            # Retry for trimesh versions whose look_at rejects rotation/fov keywords.
            camera_transform = trimesh.scene.cameras.look_at(
                points=[scene.centroid],
                distance=scene.scale * 1.2,
            )
        scene.camera_transform = camera_transform
        return scene.save_image(resolution=resolution, transparent=False, background=None)
    except Exception as render_error:
        # Camera or rendering problems (e.g. no OpenGL context) fall back to the wireframe.
        print(f"trimesh scene rendering failed: {render_error}")
        return None


def _draw_wireframe(mesh, width, height, max_faces=200):
    """Return a width x height RGB image with an XY-plane wireframe of *mesh* (blank if None)."""
    image = Image.new('RGB', (width, height), '#2b2b2b')
    if mesh is None or len(mesh.vertices) == 0 or len(mesh.faces) == 0:
        return image

    # Orthographic projection onto the XY plane (Z is dropped).
    xy = np.asarray(mesh.vertices, dtype=np.float64)[:, :2]
    lo = xy.min(axis=0)
    extent = xy.max(axis=0) - lo
    center = lo + extent / 2.0

    # A single uniform scale keeps the model's proportions and fits it in the
    # central 80% of the image; zero-extent axes are skipped so flat models draw.
    usable = np.array([width * 0.8, height * 0.8])
    positive = extent > 0
    scale = float(np.min(usable[positive] / extent[positive])) if positive.any() else 1.0

    px = width / 2.0 + (xy[:, 0] - center[0]) * scale
    # Image rows grow downward, so flip Y to keep +Y pointing up.
    py = height / 2.0 - (xy[:, 1] - center[1]) * scale
    points = list(zip(px.tolist(), py.tolist()))

    faces = np.asarray(mesh.faces)
    if len(faces) > max_faces:
        # Sample evenly across the whole mesh (limits clutter) rather than
        # drawing only the first faces in file order.
        faces = faces[np.linspace(0, len(faces) - 1, max_faces).astype(int)]

    draw = ImageDraw.Draw(image)
    for a, b, c in faces.tolist():
        draw.line([points[a], points[b], points[c], points[a]], fill='#44aaff', width=1)
    return image


def get_model_thumbnail_path(model_path):
    """
    Return the location where the thumbnail for *model_path* is stored.

    The layout is ``<model folder>/thumbnails/<model stem>_thumbnail.png``.
    As a side effect, the ``thumbnails`` folder is created if it is missing.
    The model's own folder must already exist.

    Args:
        model_path (str or Path): Path to the 3D model file

    Returns:
        Path: Expected thumbnail path in thumbnails subfolder
    """
    source = Path(model_path)
    folder = source.parent / "thumbnails"
    folder.mkdir(exist_ok=True)
    return folder / (source.stem + "_thumbnail.png")


def should_generate_thumbnail(model_path, thumbnail_path):
    """
    Decide whether a thumbnail needs to be (re)generated.

    Args:
        model_path (str or Path): Path to the 3D model file
        thumbnail_path (str or Path): Path to the thumbnail file

    Returns:
        bool: True if the thumbnail is missing or older than the model. False
        if it is up to date, or if the model file no longer exists (there is
        nothing to regenerate from).
    """
    model_path = Path(model_path)
    thumbnail_path = Path(thumbnail_path)

    # EAFP: a single stat per file avoids exists()/stat() races.
    try:
        thumbnail_mtime = thumbnail_path.stat().st_mtime
    except FileNotFoundError:
        return True

    try:
        model_mtime = model_path.stat().st_mtime
    except FileNotFoundError:
        # Model deleted but thumbnail still present: keep the existing one.
        return False

    return model_mtime > thumbnail_mtime


def generate_or_get_thumbnail(model_path):
    """
    Return the thumbnail for a 3D model, generating it first if it is stale.

    Staleness is decided by ``should_generate_thumbnail``. When regeneration is
    needed, the result of ``generate_3d_thumbnail`` is returned as-is.

    Args:
        model_path (str or Path): Path to the 3D model file

    Returns:
        Path: Path to the thumbnail file, or None if generation failed or the
        expected thumbnail file is absent.
    """
    model = Path(model_path)
    target = get_model_thumbnail_path(model)

    if not should_generate_thumbnail(model, target):
        # Considered up to date; only hand it back if the file is really there.
        return target if target.exists() else None

    return generate_3d_thumbnail(model, target)


def generate_thumbnails_for_asset_folder(asset_folder_path):
    """
    Generate thumbnails for all 3D models in an asset folder.

    Only the folder's direct children are scanned (no recursion), in sorted
    order. A missing path, or a path that is not a directory, yields an empty
    dict.

    Args:
        asset_folder_path (str or Path): Path to the asset folder

    Returns:
        dict: Mapping of model path (str) to thumbnail path (str) for every
        model whose thumbnail is available.
    """
    asset_folder_path = Path(asset_folder_path)
    thumbnail_map = {}

    # is_dir() rather than exists(): iterdir() on a regular file would raise
    # NotADirectoryError.
    if not asset_folder_path.is_dir():
        return thumbnail_map

    # Supported 3D file extensions (matched case-insensitively).
    # NOTE(review): loading goes through trimesh; confirm .fbx/.usd actually
    # load in this environment, otherwise they are retried and fail each call.
    supported_extensions = {'.fbx', '.glb', '.usd', '.obj', '.stl', '.ply'}

    for file_path in sorted(asset_folder_path.iterdir()):
        if file_path.is_file() and file_path.suffix.lower() in supported_extensions:
            thumbnail_path = generate_or_get_thumbnail(file_path)
            if thumbnail_path:
                thumbnail_map[str(file_path)] = str(thumbnail_path)

    return thumbnail_map


# In-memory memo used by get_cached_thumbnail: model path (str) -> thumbnail
# path (str). The whole memo is cleared once more than _CACHE_TIMEOUT seconds
# have passed since _last_cache_update (a time.time() timestamp).
_CACHE_TIMEOUT = 5 * 60  # seconds (5 minutes)
_last_cache_update = 0
_thumbnail_cache = {}


def get_cached_thumbnail(model_path, force_refresh=False):
    """
    Return the thumbnail for a model, consulting the in-memory cache first.

    The cache is cleared wholesale when more than ``_CACHE_TIMEOUT`` seconds
    have elapsed since the previous clear. A cached entry is returned only when
    *force_refresh* is false and the cached file still exists. Otherwise
    ``generate_or_get_thumbnail`` is called, and a truthy result is cached.

    Args:
        model_path (str or Path): Path to the 3D model file
        force_refresh (bool): Skip the cache lookup and regenerate/re-fetch

    Returns:
        Path: Path to the thumbnail file, or None if not available
    """
    # Only the timestamp is rebound; the cache dict is mutated in place.
    global _last_cache_update

    key = str(Path(model_path))

    now = time.time()
    if now - _last_cache_update > _CACHE_TIMEOUT:
        _thumbnail_cache.clear()
        _last_cache_update = now

    if not force_refresh:
        hit = _thumbnail_cache.get(key)
        if hit is not None and Path(hit).exists():
            return Path(hit)

    result = generate_or_get_thumbnail(key)
    if result:
        _thumbnail_cache[key] = str(result)
    return result


if __name__ == "__main__":
    # Command-line entry point: generate (or reuse) the thumbnail for one model.
    # Exits with status 1 on a usage error, a missing file, or failure, so
    # calling scripts can detect problems.
    import sys

    if len(sys.argv) < 2:
        print(f"Usage: python {Path(sys.argv[0]).name} <3d_model_path>")
        sys.exit(1)

    model_path = Path(sys.argv[1])
    if not model_path.exists():
        print(f"Model file not found: {model_path}")
        sys.exit(1)

    thumbnail_path = generate_or_get_thumbnail(model_path)
    if thumbnail_path:
        print(f"Thumbnail generated: {thumbnail_path}")
    else:
        print("Failed to generate thumbnail")
        sys.exit(1)