import streamlit as st
import pandas as pd
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from axe_selenium_python import Axe
import time
import requests
import logging
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
from queue import Queue
from urllib.parse import urlparse
import os
import io
import json
import re

# ----------------------------
# Logging configuration
# ----------------------------
# Module-level logging setup: records are written to a local file in append
# mode (filemode="a"), formatted as "<timestamp> - <level> - <message>".
# The level is INFO, so the logging.debug(...) calls in this file fall below
# the configured threshold.
logging.basicConfig(
    filename="accessibility_checker.log",
    filemode="a",
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# ----------------------------
# Helper / Utility functions
# ----------------------------

def initialize_driver(wait_time: int, headless: bool = True) -> webdriver.Chrome:
    """Create and return a configured Selenium Chrome driver.

    Args:
        wait_time: Implicit wait, in seconds, applied to the new driver.
        headless: When true, Chrome is started with ``--headless=new``.

    Returns:
        The ready-to-use Chrome driver; the caller is responsible for
        calling ``quit()`` on it.
    """
    # Flags are applied in this order; the headless flag (if any) goes first.
    chrome_flags = ["--disable-gpu", "--no-sandbox"]
    if headless:
        chrome_flags.insert(0, "--headless=new")

    chrome_options = webdriver.ChromeOptions()
    for flag in chrome_flags:
        chrome_options.add_argument(flag)

    browser = webdriver.Chrome(options=chrome_options)
    browser.implicitly_wait(wait_time)
    logging.info("Initialised Selenium driver (headless=%s).", headless)
    return browser


def accept_cookies(driver: webdriver.Chrome, element_wait_time: int):
    """Click the cookie-consent "accept" button if it shows up (MSD sites).

    Waits up to ``element_wait_time`` seconds for the element with id
    ``onetrust-accept-btn-handler`` to become clickable. This is best-effort:
    any failure (timeout, click error, ...) is logged at DEBUG level and
    otherwise ignored, so the function never raises.
    """
    banner_button = (By.ID, "onetrust-accept-btn-handler")
    try:
        waiter = WebDriverWait(driver, element_wait_time)
        waiter.until(EC.element_to_be_clickable(banner_button)).click()
        logging.info("Accepted cookies banner.")
        # Short pause after the click before the caller continues.
        time.sleep(1)
    except Exception:
        # Banner absent (or not clickable): nothing to dismiss.
        logging.debug("Cookie banner not found or already accepted.")


def login_to_portal(
    driver: webdriver.Chrome,
    username: str,
    password: str,
    login_url: str,
    element_wait_time: int,
) -> bool:
    """Sign in to the MSD Profesionales portal through its two-step form.

    Args:
        driver: Selenium driver used to drive the login page.
        username: Portal user name, typed in both form steps.
        password: Portal password, typed in the second step.
        login_url: Address of the login page.
        element_wait_time: Seconds to wait for each awaited form field.

    Returns:
        True when the browser has left the login page after submitting,
        False when it is still on a URL containing "login" or when any
        exception occurred (the exception is logged, never propagated).
    """

    def wait_for(element_id: str):
        """Return the element with *element_id* once it is present in the DOM."""
        locator = (By.ID, element_id)
        return WebDriverWait(driver, element_wait_time).until(
            EC.presence_of_element_located(locator)
        )

    def fill(field, value: str) -> None:
        """Replace the content of *field* with *value*."""
        field.clear()
        field.send_keys(value)

    try:
        driver.get(login_url)
        logging.info("Navigated to login page (%s)", login_url)
        time.sleep(1)
        accept_cookies(driver, element_wait_time)

        # Step 1: user name only, then "next".
        fill(wait_for("capture_signInFull_username"), username)
        driver.find_element(By.ID, "buttonNext_signInFull").click()
        time.sleep(2)

        # Step 2: the form asks for the user name again, plus the password.
        fill(wait_for("capture_signInFull_signInUsername"), username)
        fill(driver.find_element(By.ID, "capture_signInFull_currentPassword"), password)
        driver.find_element(
            By.XPATH, "//button[@type='submit' and contains(., 'Acceda')]"
        ).click()
        logging.info("Submitted credentials.")
        time.sleep(4)

        # Success heuristic: the browser is no longer on a "login" URL.
        if "login" not in driver.current_url:
            logging.info("Login successful! Current URL: %s", driver.current_url)
            return True

        logging.error("Login appears to have failed (still on login page).")
        return False
    except Exception as exc:
        logging.exception("Exception during login: %s", exc)
        return False


def get_session_cookies(driver: webdriver.Chrome) -> dict:
    """Collect the driver's cookies into a ``{name: value}`` mapping.

    The result has the shape ``requests`` accepts for its cookie jar. When
    two cookies share a name, the later one in ``driver.get_cookies()`` wins.
    """
    jar = {}
    for cookie in driver.get_cookies():
        jar[cookie["name"]] = cookie["value"]
    logging.debug("Extracted %d cookies from driver.", len(jar))
    return jar


def is_portal_url(url: str) -> bool:
    """Return True when *url* is hosted on the MSD professionals portal.

    The decision is made on the parsed host name only (port and any
    ``user:password@`` prefix are ignored) and accepts the portal host
    itself or one of its subdomains. A plain substring test on the network
    location would also accept look-alike hosts such as
    ``profesionales.msd.es.evil.example`` or ``notprofesionales.msd.es``.

    URLs without a scheme have no parsed host and therefore yield False.
    """
    portal_host = "profesionales.msd.es"
    # ``hostname`` is None when the URL has no network location.
    host = (urlparse(url).hostname or "").lower()
    return host == portal_host or host.endswith("." + portal_host)


def _safe_slug(text: str, maxlen: int = 60) -> str:
    """Return filesystem-safe slug of URL."""
    slug = re.sub(r"[^a-zA-Z0-9_-]", "_", text)
    return slug[:maxlen]


def run_axe(driver: webdriver.Chrome, wcag_tags: list[str], *, save_json: bool = False, url: str | None = None):
    """Run an axe-core audit on the page currently loaded in *driver*.

    axe-core is injected through ``axe_selenium_python``; if that fails
    (e.g. because of the page's CSP) the script is loaded from a CDN instead.

    Args:
        driver: Selenium driver already showing the page to audit.
        wcag_tags: axe tag names used for the ``runOnly`` filter.
        save_json: When true (and *url* is given), the raw axe result is
            written to ``axe_<slug-of-url>.json`` in the working directory.
        url: Page address, used only for the JSON file name and log messages.

    Returns:
        The list of axe violation dicts. An empty list is also returned when
        the audit itself fails (the failure is logged with its traceback).
        NOTE(review): callers cannot distinguish "audit crashed" from "page
        is clean" with this contract — consider surfacing the failure.
    """
    # Rules that are explicitly switched on in addition to the tag filter.
    always_enabled_rules = (
        "image-alt",
        "image-redundant-alt",
        "input-image-alt",
        "object-alt",
        "area-alt",
        "aria-hidden-focus",
        "duplicate-id",
        "form-field-multiple-labels",
    )

    try:
        axe = Axe(driver)
        try:
            axe.inject()  # may fail due to CSP
        except Exception as inject_err:
            logging.warning("Axe inline inject failed (%s). Trying CDN injection…", inject_err)
            axe_cdn = "https://cdnjs.cloudflare.com/ajax/libs/axe-core/4.9.1/axe.min.js"
            # The URL is handed over as a script argument rather than being
            # spliced into the JavaScript source.
            driver.execute_script(
                "if(!window.axeLoaded){var s=document.createElement('script');"
                "s.src=arguments[0];s.onload=function(){window.axeLoaded=true;};"
                "document.head.appendChild(s);}",
                axe_cdn,
            )
            WebDriverWait(driver, 10).until(lambda d: d.execute_script("return typeof axe === 'object'"))

        axe_options = {
            "runOnly": {"type": "tag", "values": list(wcag_tags)},
            "rules": {rule_id: {"enabled": True} for rule_id in always_enabled_rules},
        }
        # Serialise to JSON before handing the options over: axe-selenium-python
        # interpolates `options` into the JavaScript call as text, and the repr
        # of a Python dict contains `True`, which is not valid JavaScript
        # (`true`). NOTE(review): confirm against the installed
        # axe-selenium-python version.
        results = axe.run(options=json.dumps(axe_options))

        # Optional: save raw JSON
        if save_json and url:
            fname = f"axe_{_safe_slug(url)}.json"
            try:
                with open(fname, "w", encoding="utf-8") as f:
                    json.dump(results, f, ensure_ascii=False, indent=2)
                logging.info("Saved Axe JSON to %s", fname)
            except Exception as save_err:
                # Saving the dump is a debugging aid; never fail the audit for it.
                logging.warning("Could not save Axe JSON for %s: %s", url, save_err)

        return results.get("violations", [])

    except Exception as e:
        # One record with the traceback instead of an error + exception pair.
        logging.exception("Error running Axe: %s", e)
        return []


def _report_row(
    url: str,
    status: str,
    *,
    rule_id: str = "-",
    impact: str = "-",
    description: str = "-",
    help_text: str = "-",
    wcag_tags: str = "",
    help_url: str = "-",
    html: str = "-",
) -> dict:
    """Build one report row with the fixed column set shared by every outcome."""
    return {
        "URL": url,
        "Status": status,
        "Rule ID": rule_id,
        "Impact": impact,
        "Description": description,
        "Help": help_text,
        "WCAG Tags": wcag_tags,
        "Help URL": help_url,
        "HTML": html,
    }


def _violation_rows(url: str, violations: list) -> list[dict]:
    """Flatten axe violations into one report row per affected node."""
    rows = []
    for violation in violations:
        shared = {
            "rule_id": violation.get("id", "unknown"),
            "impact": violation.get("impact", "unknown"),
            "description": violation.get("description", "No description"),
            "help_text": violation.get("help", "No help text available"),
            "help_url": violation.get("helpUrl", "#"),
            "wcag_tags": ",".join(violation.get("tags", [])),
        }
        # A rule can hit several elements; a violation without nodes still
        # gets a single placeholder row.
        nodes = violation.get("nodes", [])
        snippets = [node.get("html", "") for node in nodes] or ["No specific element found"]
        rows.extend(_report_row(url, "Violation", html=snippet, **shared) for snippet in snippets)
    return rows


def analyse_url(
    url: str,
    driver: webdriver.Chrome,
    session: requests.Session | None,
    wcag_tags: list[str],
    login_url: str,
    reauth_callback,
    element_wait_time: int,
    analysis_wait: int,
    debug_mode: bool,
):
    """Load *url* in *driver*, audit it with axe and return report rows.

    Args:
        url: Page to audit; surrounding whitespace is stripped and
            ``https://`` is prepended when no http(s) scheme is present.
        driver: Selenium driver used to load and inspect the page.
        session: Accepted for interface compatibility; not used here.
        wcag_tags: axe tag names forwarded to ``run_axe``.
        login_url: Accepted for interface compatibility; not used here.
        reauth_callback: Zero-argument callable returning True when a new
            login succeeded; called when a portal URL lands on a "login"
            URL. May be None/falsy, in which case re-authentication fails.
        element_wait_time: Seconds to wait for the cookie banner.
        analysis_wait: Extra seconds to sleep before running axe.
        debug_mode: When true, a screenshot and the raw axe JSON are saved.

    Returns:
        A list of row dicts (keys: URL, Status, Rule ID, Impact, Description,
        Help, WCAG Tags, Help URL, HTML). One row per affected element for
        violations, a single "Success" row for a clean page, or a single
        "Error" row when anything went wrong. This function never raises.

    NOTE(review): ``run_axe`` returns an empty list when the audit itself
    fails, so such a failure ends up in the "no violations" path below.
    """
    try:
        url = str(url).strip()
        if not url:
            raise ValueError("Empty URL")
        # Ensure URL has a scheme
        if not url.startswith(("http://", "https://")):
            url = "https://" + url

        logging.info("Accessing URL: %s", url)
        driver.get(url)
        time.sleep(2)  # Give the page some time to load

        logging.info("Loaded %s (title=%s)", driver.current_url, driver.title)
        logging.info("Page source length: %d", len(driver.page_source))

        if debug_mode:
            screenshot_path = f"screenshot_{int(time.time())}.png"
            try:
                driver.save_screenshot(screenshot_path)
                logging.info("Screenshot saved to %s", screenshot_path)
            except Exception as scr_err:
                # Screenshots are a debugging aid; never fail the audit for them.
                logging.warning("Could not save screenshot for %s: %s", url, scr_err)

        # Chrome's own error page instead of the requested site
        if "This site can't be reached" in driver.page_source:
            raise RuntimeError(f"Cannot reach URL: {url}")

        # Portal pages redirect to the login form once the session is gone.
        if "login" in driver.current_url and is_portal_url(url):
            logging.info("Detected login redirect for %s", url)
            if not (reauth_callback and reauth_callback()):
                logging.warning("Could not re-authenticate for URL: %s", url)
                raise RuntimeError("Authentication failed after redirect to login")
            driver.get(url)  # Retry after reauth
            time.sleep(2)

        accept_cookies(driver, element_wait_time)

        # Wait until document readyState is complete (best-effort)
        try:
            WebDriverWait(driver, 15).until(
                lambda d: d.execute_script("return document.readyState") == "complete"
            )
        except Exception:
            logging.warning("Document readyState not complete within timeout for %s", url)

        time.sleep(analysis_wait)

        logging.info("Running Axe analysis for: %s", url)
        violations = run_axe(driver, wcag_tags, save_json=debug_mode, url=url)
        logging.info("Found %d violations on %s", len(violations), url)

        if violations:
            return _violation_rows(url, violations)

        # Fallback when axe reported nothing: count <img> elements that have
        # no alt attribute at all. An empty alt="" is deliberately NOT
        # counted: it is the valid way to mark decorative images.
        try:
            missing_alt = driver.execute_script(
                "return Array.from(document.images).filter(i => !i.hasAttribute('alt')).length;"
            ) or 0
        except Exception as js_err:
            logging.warning("JS alt check failed on %s: %s", url, js_err)
            missing_alt = 0

        if missing_alt > 0:
            return [_report_row(
                url,
                "Violation",
                rule_id="image-alt-js",
                impact="minor",
                description=f"{missing_alt} images without alt attribute detected (manual JS check)",
                help_text="Add meaningful alt text to all images",
                wcag_tags="image-alt,manual-check",
                help_url="https://www.w3.org/WAI/WCAG22/Techniques/html/H37",
                html="<img> elements missing alt",
            )]

        return [_report_row(
            url,
            "Success",
            description="No accessibility violations found",
            wcag_tags=",".join(wcag_tags),
        )]

    except Exception as exc:
        # Single log record (with traceback) per failed URL.
        logging.exception("Error analyzing %s", url)
        return [_report_row(
            url,
            "Error",
            rule_id="ERROR",
            impact="Critical",
            description=f"Failed to analyze: {exc}",
            help_text="Check if the URL is correct and accessible",
            wcag_tags="",
            help_url="#",
            html="-",
        )]

# ----------------------------
# Streamlit UI
# ----------------------------

# Page chrome: page title and wide layout, followed by the logo image and a
# centred HTML heading (raw HTML is allowed for this one markdown call).
st.set_page_config(page_title="Accessibility Compliance Checker", layout="wide")
st.image("https://msd.softwareinc.net/MSD_logo.png", width=180)
st.markdown(
    """<h2 style='text-align:center'>WCAG 2.2 Accessibility Compliance Checker</h2>""",
    unsafe_allow_html=True,
)

# Settings form: every value collected here is read by the `if submit_btn:`
# section that follows.
with st.form("accessibility_form"):
    st.subheader("Settings")
    # Credentials are only needed when the audit has to log in first.
    # NOTE(review): the credential inputs are disabled from the checkbox value
    # of the current script run; inside a form this presumably only changes
    # after a submit — confirm the intended UX.
    without_login = st.checkbox("Check URLs without authentication", value=False)
    username = st.text_input("Portal Username", disabled=without_login)
    password = st.text_input("Portal Password", type="password", disabled=without_login)

    # The selected label is mapped to a list of axe tags after submission.
    wcag_option = st.selectbox(
        "WCAG Standard / Levels", [
            "WCAG 2.2 A & AA",
            "WCAG 2.1 A & AA",
            "WCAG 2.0 A & AA",
        ],
    )

    # Input workbook; must contain a column named 'URL'.
    uploaded_file = st.file_uploader("Excel file with a column named 'URL'", type=["xlsx"])

    test_mode = st.checkbox("Test mode (first 5 rows only)", value=True, help="Only process first 5 rows for faster testing")
    # Debug mode runs Chrome non-headless and is passed on to analyse_url.
    debug_mode = st.checkbox(
        "Debug mode (visible browser & screenshots)",
        value=False,
        help="Shows the real browser, logs page title/URL, and saves screenshots."
    )
    # Timing knobs (seconds): driver implicit wait, explicit element wait,
    # and an extra pause before axe runs.
    wait_time = st.number_input("Implicit wait (s)", min_value=1, max_value=20, value=4)
    element_wait_time = st.number_input("Element wait (s)", min_value=5, max_value=30, value=15)
    analysis_wait = st.number_input("Extra page load wait (s)", min_value=0, max_value=20, value=3, help="Additional seconds to wait after page load before running Axe")
    submit_btn = st.form_submit_button("Run Accessibility Audit")

# Audit run: validate the form input, read the URL list, audit every URL with
# a single Chrome driver and offer the result as an Excel download.
if submit_btn:
    if uploaded_file is None:
        st.error("Please upload an Excel file.")
        st.stop()

    if not without_login and (not username or not password):
        st.error("Username and password are required.")
        st.stop()

    # Map wcag_option to axe tags (every choice also includes "best-practice")
    option_tag_map = {
        "WCAG 2.2 A & AA": ["wcag2a", "wcag2aa", "wcag21a", "wcag21aa", "wcag22a", "wcag22aa", "best-practice"],
        "WCAG 2.1 A & AA": ["wcag2a", "wcag2aa", "wcag21a", "wcag21aa", "best-practice"],
        "WCAG 2.0 A & AA": ["wcag2a", "wcag2aa", "best-practice"],
    }
    wcag_tags = option_tag_map[wcag_option]

    # Read Excel; a broken upload becomes a user-facing error, not a traceback.
    try:
        df_urls = pd.read_excel(uploaded_file)
    except Exception as read_err:
        logging.exception("Could not read uploaded Excel file")
        st.error(f"Could not read the Excel file: {read_err}")
        st.stop()

    if "URL" not in df_urls.columns:
        st.error("Excel must contain a column named 'URL'.")
        st.stop()

    # Limit to 5 rows if test mode is on. The total is captured BEFORE
    # truncating so the message reports the real size of the sheet.
    if test_mode:
        total_rows = len(df_urls)
        df_urls = df_urls.head(5)
        st.info(f"Test mode: Processing first {len(df_urls)} rows out of {total_rows} total rows.")

    # Unique, non-blank URLs in sheet order (cells are stripped of whitespace).
    url_cells = df_urls["URL"].dropna().astype(str).str.strip()
    urls = [u for u in url_cells.unique().tolist() if u]
    total = len(urls)
    if not urls:
        # Nothing to audit: do not launch a browser or log in.
        st.warning("No results produced.")
        st.stop()

    login_url = "https://profesionales.msd.es/login/"
    session = requests.Session()
    results_rows: list[dict] = []

    driver = initialize_driver(wait_time=wait_time, headless=not debug_mode)
    try:
        # Authenticate if required
        if not without_login:
            with st.spinner("Logging in to portal..."):
                login_ok = login_to_portal(driver, username, password, login_url, element_wait_time)
            if not login_ok:
                st.error("Login failed. Please verify credentials.")
                st.stop()  # the finally block below still closes the browser
            session.cookies.update(get_session_cookies(driver))

        # Prepare progress UI
        progress_bar = st.progress(0.0, text="Starting analysis…")
        status_placeholder = st.empty()

        def reauth():
            """Log in again after a redirect to the login page; False if auth is off."""
            if without_login:
                return False
            return login_to_portal(driver, username, password, login_url, element_wait_time)

        for done, u in enumerate(urls, start=1):
            rows = analyse_url(
                u,
                driver,
                session,
                wcag_tags,
                login_url,
                reauth,
                element_wait_time,
                analysis_wait,
                debug_mode,
            )
            results_rows.extend(rows)
            logging.info("Completed analysis for %s (rows=%d)", u, len(rows))
            progress_bar.progress(done / total, text=f"Processed {done}/{total}")
            status_placeholder.text(f"Processed {done}/{total} URLs")
    finally:
        # Always release the browser and the HTTP session, including when the
        # run is stopped or an unexpected error occurs mid-audit.
        driver.quit()
        session.close()

    if not results_rows:
        st.warning("No results produced.")
        st.stop()

    df_report = pd.DataFrame(results_rows)

    # Create Excel in-memory
    output = io.BytesIO()
    with pd.ExcelWriter(
        output,
        engine="xlsxwriter",
        date_format="yyyy-mm-dd",
        datetime_format="yyyy-mm-dd hh:mm:ss",
    ) as writer:
        df_report.to_excel(writer, index=False, sheet_name="Accessibility Report")
        worksheet = writer.sheets["Accessibility Report"]
        # NOTE(review): report rows carry nine columns (A–I); widths are set
        # for A–H only, so the last column keeps the default width.
        worksheet.set_column("A:A", 50)
        worksheet.set_column("B:B", 18)
        worksheet.set_column("C:C", 10)
        worksheet.set_column("D:D", 30)
        worksheet.set_column("E:E", 30)
        worksheet.set_column("F:F", 20)
        worksheet.set_column("G:G", 60)
        worksheet.set_column("H:H", 40)
    excel_bytes = output.getvalue()

    st.success("Audit complete! Download the report below.")
    st.download_button(
        label="Download Accessibility Report",
        data=excel_bytes,
        file_name=f"accessibility_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.xlsx",
        mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    )

# Version footer: a fixed-position, centred grey line at the bottom of the
# page, rendered from raw HTML/CSS on every script run.
st.markdown(
    """
    <style>
        .version-footer {position: fixed; bottom: 10px; left: 0; width: 100%; text-align: center; font-size: 11px; color: grey;}
    </style>
    <div class="version-footer">v0.1 Accessibility Checker</div>
    """,
    unsafe_allow_html=True,
)
