# trial_state_manager.py
import json
import os
import sys
import platform
import hashlib
from pathlib import Path
from datetime import datetime
from typing import Optional, Dict, Any
from .fails_open_integrity import get_stable_machine_fingerprint

class TrialStateManager:
    """Persist trial-license state to several redundant locations.

    State is written as JSON to one or more *primary* files and to *backup*
    locations (an extra file and, on Windows, a key under HKEY_CURRENT_USER).
    Each stored record carries a truncated SHA-256 ``integrity_checksum`` and,
    when a ``machine_fingerprint`` was supplied, a truncated hash of it that is
    compared against the current machine on load.

    NOTE(review): the checksum uses no secret, so it detects accidental
    corruption but not deliberate edits by someone who recomputes it.
    """

    # Registry key (under HKEY_CURRENT_USER) used for the Windows backup copy.
    _REGISTRY_KEY_PATH = r"SOFTWARE\AIMMS\TrialState"

    def __init__(self):
        self.system = platform.system().lower()

        # Define storage locations once; every other method iterates these.
        self.primary_locations = self._get_primary_storage_locations()
        self.backup_locations = self._get_backup_storage_locations()

    # ------------------------------------------------------------------ #
    # Location discovery
    # ------------------------------------------------------------------ #

    def _get_primary_storage_locations(self) -> list:
        """Return the primary JSON file paths for the current platform.

        Always includes a dotfile in the user's home directory, plus one
        per-platform application-data location.
        """
        home_dir = Path.home()
        locations = [home_dir / ".aimms_trial.json"]

        if self.system == "windows":
            # An unset/empty APPDATA would otherwise yield a *relative* path
            # and silently write the state into the current working directory.
            appdata = os.environ.get("APPDATA")
            app_data = Path(appdata) if appdata else home_dir / "AppData" / "Roaming"
            locations.append(app_data / "AIMMS" / "trial_state.json")
        elif self.system == "darwin":  # macOS
            app_support = home_dir / "Library" / "Application Support" / "AIMMS"
            locations.append(app_support / "trial_state.json")
        else:  # Linux and others
            locations.append(home_dir / ".config" / "aimms" / "trial_state.json")

        return locations

    def _get_backup_storage_locations(self) -> list:
        """Return backup locations.

        On Windows the list starts with the ``"registry"`` sentinel string
        (not a Path), which readers/writers handle specially. It always
        contains a backup file in the home directory.
        """
        locations = []
        if self.system == "windows":
            locations.append("registry")
        locations.append(Path.home() / ".aimms_trial_backup.json")
        return locations

    # ------------------------------------------------------------------ #
    # Small shared helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _hash_prefix(value: str, length: int) -> str:
        """Return the first *length* hex chars of SHA-256(*value*)."""
        return hashlib.sha256(value.encode()).hexdigest()[:length]

    @staticmethod
    def _temp_path(location: Path) -> Path:
        """Return the sibling ``<name>.tmp`` path used for probes and atomic writes."""
        return location.with_name(location.name + ".tmp")

    @staticmethod
    def _report_error(message: str, console_message: Optional[str] = None) -> None:
        """Send *message* to ``logging_manager.log_boot_error`` if importable.

        If importing or logging fails, print *console_message* instead
        (defaults to *message*).
        """
        try:
            from logging_manager import log_boot_error
            log_boot_error(message)
        except Exception:
            print(console_message if console_message is not None else message)

    @staticmethod
    def _read_json_file(location: Path) -> Any:
        """Parse and return the JSON content of *location*."""
        with open(location, 'r', encoding='utf-8') as f:
            return json.load(f)

    # ------------------------------------------------------------------ #
    # Saving
    # ------------------------------------------------------------------ #

    def save_trial_state(self, trial_data: Dict[str, Any]) -> bool:
        """Write *trial_data* to every primary and backup location.

        A writability probe runs first. If nothing at all is writable, one
        detailed error listing every failure is reported and False is
        returned. Individual write failures are reported and skipped.

        Args:
            trial_data: Trial record to store. It is transformed by
                ``_prepare_safe_data`` before being written.

        Returns:
            True if at least one location was written successfully.
        """
        write_errors = []
        if not self._probe_writable_location(write_errors):
            # Fail fast with a detailed message
            msg_lines = ["No writable storage locations available for trial state. Checked locations:"]
            msg_lines.extend(f" - {loc}: {err}" for loc, err in write_errors)
            self._report_error("\n".join(msg_lines))
            return False

        # Build the on-disk representation once instead of once per location.
        try:
            safe_data = self._prepare_safe_data(trial_data)
        except Exception as e:
            self._report_error(
                f"Could not prepare trial state for saving: {e}",
                f"Warning: Could not prepare trial state for saving: {e}",
            )
            return False

        success_count = 0

        for location in self.primary_locations:
            try:
                # Primary files are hidden on Windows (best-effort).
                self._write_json_file(location, safe_data, hide=self.system == "windows")
                success_count += 1
            except Exception as e:
                self._report_error(
                    f"Could not save trial state to {location}: {e}",
                    f"Warning: Could not save trial state to {location}: {e}",
                )

        for location in self.backup_locations:
            if location == "registry":
                try:
                    self._save_to_registry(trial_data)
                    success_count += 1
                except Exception as e:
                    # The exception text already begins with "Registry save failed".
                    self._report_error(str(e), f"Warning: {e}")
                continue

            try:
                self._write_json_file(location, safe_data)
                success_count += 1
            except Exception as e:
                self._report_error(
                    f"Could not save backup trial state to {location}: {e}",
                    f"Warning: Could not save backup trial state: {e}",
                )

        return success_count > 0

    def _probe_writable_location(self, write_errors: list) -> bool:
        """Return True as soon as any storage location accepts a test write.

        Primary locations are probed before backups. Each failure is
        appended to *write_errors* as a ``(location, error_text)`` tuple.
        """
        for location in [*self.primary_locations, *self.backup_locations]:
            try:
                if location == "registry":
                    self._probe_registry()
                else:
                    self._probe_file(location)
                return True
            except Exception as e:
                write_errors.append((str(location), str(e)))
        return False

    def _probe_registry(self) -> None:
        """Create the registry key, then write and delete a throwaway value."""
        import winreg
        with winreg.CreateKey(winreg.HKEY_CURRENT_USER, self._REGISTRY_KEY_PATH) as key:
            winreg.SetValueEx(key, "TrialTest", 0, winreg.REG_SZ, "1")
            winreg.DeleteValue(key, "TrialTest")

    def _probe_file(self, location: Path) -> None:
        """Create *location*'s directory, then write and delete a sibling temp file."""
        location.parent.mkdir(parents=True, exist_ok=True)
        test_path = self._temp_path(location)
        with open(test_path, 'w') as tf:
            tf.write('test')
        try:
            test_path.unlink()
        except OSError:
            pass  # A leftover probe file is harmless.

    def _write_json_file(self, location: Path, data: Dict[str, Any], hide: bool = False) -> None:
        """Atomically write *data* as JSON to *location*.

        The JSON goes to a sibling temp file, which is then moved into place
        with ``os.replace``. A crash mid-write therefore leaves the previous
        file intact instead of a truncated one. When *hide* is true (Windows),
        the destination's attributes are reset before the replace and the
        hidden attribute is set afterwards.

        Raises:
            OSError: If creating the directory, writing the temp file or the
                final replace fails. The temp file is removed in that case.
        """
        location.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self._temp_path(location)
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)
            if hide and location.exists():
                # Hidden or read-only targets may refuse to be overwritten on
                # Windows, so reset their attributes first.
                self._make_hidden_file(location, hidden=False)
            os.replace(tmp_path, location)
        except Exception:
            try:
                tmp_path.unlink()
            except OSError:
                pass
            raise
        if hide:
            self._make_hidden_file(location)

    # ------------------------------------------------------------------ #
    # Loading
    # ------------------------------------------------------------------ #

    def load_trial_state(self) -> Optional[Dict[str, Any]]:
        """Load trial state, trying primary locations first, then backups.

        A location that is missing or cannot be read or parsed is skipped.
        The first location that yields data decides the result. If that
        data fails its integrity check or was recorded on a different
        machine, None is returned immediately without trying later
        locations. (Exception: a registry entry that fails its integrity
        check is skipped.)

        Returns:
            The stored record (without ``integrity_checksum``), or None.
        """
        for location in self.primary_locations:
            try:
                if not location.exists():
                    continue
                restored = self._restore_original_data(self._read_json_file(location))
                if restored is None:
                    return None
                if not self._fingerprint_matches(
                    restored,
                    "Warning: Trial state machine fingerprint does not match current machine",
                    "Warning: Could not validate machine fingerprint",
                ):
                    return None
                return restored
            except Exception as e:
                print(f"Warning: Could not load trial state from {location}: {e}")

        for location in self.backup_locations:
            try:
                if location == "registry":
                    # _load_from_registry already verifies and strips the checksum.
                    restored = self._load_from_registry()
                    if not restored:
                        continue
                    mismatch_message = "Warning: Backup trial state fingerprint mismatch"
                    error_prefix = "Warning: Could not validate machine fingerprint from registry"
                else:
                    if not location.exists():
                        continue
                    restored = self._restore_original_data(self._read_json_file(location))
                    if restored is None:
                        return None
                    mismatch_message = "Warning: Backup trial state fingerprint mismatch (file)"
                    error_prefix = "Warning: Could not validate machine fingerprint from backup file"

                if not self._fingerprint_matches(restored, mismatch_message, error_prefix):
                    return None
                return restored
            except Exception as e:
                print(f"Warning: Could not load backup trial state: {e}")

        return None

    def _current_fingerprint_hash(self) -> str:
        """Hash the current machine fingerprint the same way ``_prepare_safe_data`` does."""
        return self._hash_prefix(get_stable_machine_fingerprint(), 16)

    def _fingerprint_matches(self, restored: Dict[str, Any], mismatch_message: str, error_prefix: str) -> bool:
        """Check a restored record's ``machine_fingerprint_hash`` against this machine.

        Records without that field are accepted. On a mismatch,
        *mismatch_message* is printed. If the current fingerprint cannot be
        computed, ``"{error_prefix}: {error}"`` is printed. False is
        returned in both cases.
        """
        stored_hash = restored.get('machine_fingerprint_hash')
        if not stored_hash:
            return True
        try:
            current_hash = self._current_fingerprint_hash()
        except Exception as e:
            print(f"{error_prefix}: {e}")
            return False
        if stored_hash != current_hash:
            print(mismatch_message)
            return False
        return True

    # ------------------------------------------------------------------ #
    # (De)serialisation transforms
    # ------------------------------------------------------------------ #

    def _prepare_safe_data(self, trial_data: Dict[str, Any]) -> Dict[str, Any]:
        """Return a storage-ready copy of *trial_data* with an integrity checksum.

        - ``machine_fingerprint`` is replaced by ``machine_fingerprint_hash``
          (first 16 hex chars of its SHA-256).
        - ``install_path_hash`` is replaced by the first 8 hex chars of the
          SHA-256 of its current value.
        - ``integrity_checksum`` is the first 16 hex chars of the SHA-256 of
          the resulting dict, serialized with sorted keys. Any stale checksum
          in the input is discarded first so it cannot poison the new one.

        NOTE(review): ``install_path_hash`` is re-hashed on every save, so
        saving a record that was itself loaded from storage hashes the hash
        again. Confirm that callers always pass the raw value here.
        """
        safe_data = trial_data.copy()  # shallow copy; the caller's dict is untouched
        safe_data.pop('integrity_checksum', None)

        if 'machine_fingerprint' in safe_data:
            safe_data['machine_fingerprint_hash'] = self._hash_prefix(
                safe_data.pop('machine_fingerprint'), 16
            )

        if 'install_path_hash' in safe_data:
            safe_data['install_path_hash'] = self._hash_prefix(safe_data['install_path_hash'], 8)

        data_string = json.dumps(safe_data, sort_keys=True)
        safe_data['integrity_checksum'] = self._hash_prefix(data_string, 16)
        return safe_data

    def _restore_original_data(self, safe_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Verify and strip ``integrity_checksum`` from stored data.

        Records without a checksum are returned as-is (as a copy). Hashed
        fields are not reversed.

        Returns:
            The record without its checksum, or None if the checksum does
            not match or the data cannot be processed.
        """
        try:
            restored_data = safe_data.copy()
            if 'integrity_checksum' in restored_data:
                expected_checksum = restored_data.pop('integrity_checksum')
                data_string = json.dumps(restored_data, sort_keys=True)
                if self._hash_prefix(data_string, 16) != expected_checksum:
                    print("Warning: Trial state integrity check failed")
                    return None
            return restored_data
        except Exception as e:
            print(f"Warning: Could not restore trial data: {e}")
            return None

    # ------------------------------------------------------------------ #
    # Platform helpers
    # ------------------------------------------------------------------ #

    def _make_hidden_file(self, file_path: Path, hidden: bool = True):
        """Set (or, with ``hidden=False``, reset) the Windows hidden attribute.

        Uses ``SetFileAttributesW``, because ``os.chmod`` on Windows only
        toggles the read-only flag and cannot hide a file. Resetting uses
        FILE_ATTRIBUTE_NORMAL, which also clears a read-only flag.
        Best-effort: all failures are ignored.
        """
        try:
            import ctypes
            import stat
            attrs = stat.FILE_ATTRIBUTE_HIDDEN if hidden else stat.FILE_ATTRIBUTE_NORMAL
            ctypes.windll.kernel32.SetFileAttributesW(str(file_path), attrs)
        except Exception:
            pass  # Not critical if hiding fails

    def _save_to_registry(self, trial_data: Dict[str, Any]):
        """Store the prepared record (JSON) and a timestamp under the registry key.

        Raises:
            RuntimeError: Wrapping any failure, with a message starting
                "Registry save failed:".
        """
        try:
            import winreg

            safe_data = self._prepare_safe_data(trial_data)
            with winreg.CreateKey(winreg.HKEY_CURRENT_USER, self._REGISTRY_KEY_PATH) as key:
                winreg.SetValueEx(key, "TrialData", 0, winreg.REG_SZ, json.dumps(safe_data))
                winreg.SetValueEx(key, "LastUpdated", 0, winreg.REG_SZ, datetime.now().isoformat())
        except Exception as e:
            raise RuntimeError(f"Registry save failed: {e}") from e

    def _load_from_registry(self) -> Optional[Dict[str, Any]]:
        """Read the registry copy and return it already passed through ``_restore_original_data``.

        Returns None if the key or value is missing, unparseable, or fails
        the integrity check.
        """
        try:
            import winreg

            with winreg.OpenKey(winreg.HKEY_CURRENT_USER, self._REGISTRY_KEY_PATH) as key:
                trial_data_str = winreg.QueryValueEx(key, "TrialData")[0]
            return self._restore_original_data(json.loads(trial_data_str))
        except Exception:
            return None

    # ------------------------------------------------------------------ #
    # Deletion / existence / status
    # ------------------------------------------------------------------ #

    def delete_trial_state(self) -> bool:
        """Delete trial state from all locations.

        Returns:
            False if any existing location could not be removed, else True.
        """
        success = True

        for location in self.primary_locations:
            try:
                if location.exists():
                    location.unlink()
            except Exception:
                success = False

        for location in self.backup_locations:
            try:
                if location == "registry":
                    self._delete_from_registry()
                elif location.exists():
                    location.unlink()
            except Exception:
                success = False

        return success

    def _delete_from_registry(self):
        """Delete the registry key. A missing key is not an error.

        Raises:
            OSError: For failures other than the key not existing, so that
                ``delete_trial_state`` can report them.
        """
        import winreg

        try:
            winreg.DeleteKey(winreg.HKEY_CURRENT_USER, self._REGISTRY_KEY_PATH)
        except FileNotFoundError:
            pass  # Key doesn't exist, that's fine

    def trial_state_exists(self) -> bool:
        """Return True if any primary/backup file or the registry key exists."""
        for location in self.primary_locations:
            if location.exists():
                return True

        for location in self.backup_locations:
            if location == "registry":
                try:
                    import winreg
                    # Close the handle immediately; only existence matters here.
                    with winreg.OpenKey(winreg.HKEY_CURRENT_USER, self._REGISTRY_KEY_PATH):
                        return True
                except (ImportError, OSError):
                    pass
            elif location.exists():
                return True

        return False

    def get_trial_status(self) -> tuple[bool, str, int]:
        """Return ``(active, message, days_remaining)`` for the stored trial.

        ``end_date`` must be an ISO-8601 string. If it carries a timezone,
        "now" is taken in that same timezone, so naive and aware datetimes
        are never compared.
        """
        trial_data = self.load_trial_state()
        if not trial_data:
            return False, "No trial found", 0

        if not isinstance(trial_data, dict):
            return False, "Corrupt trial data: unexpected format", 0

        end_date_str = trial_data.get('end_date')
        if not end_date_str:
            return False, "Corrupt trial data: missing end_date", 0

        try:
            trial_end_date = datetime.fromisoformat(end_date_str)
        except Exception as e:
            return False, f"Corrupt trial data: invalid end_date ({e})", 0

        # tzinfo is None for naive dates -> naive now(); otherwise aware now().
        current_date = datetime.now(trial_end_date.tzinfo)

        if current_date > trial_end_date:
            return False, "Trial expired", 0

        days_remaining = (trial_end_date - current_date).days
        return True, f"Trial active - {days_remaining} days remaining", days_remaining