import pandas as pd
import re
from django.core.management.base import BaseCommand
from django.db import transaction
from products.models import ProductCategory, ProductSubCategory, Products, SKU


class Command(BaseCommand):
    """Import categories, sub-categories, products and SKUs from an Excel sheet.

    Column names are read from the sheet's second row (``header=1``).  Each
    data row describes one SKU plus the category, sub-category and product it
    belongs to.  Rows are imported independently: a bad row is rolled back,
    reported and skipped while the remaining rows are still imported.
    """

    help = "Import Products & SKUs from Excel (robust unit detection + smart codes)"

    # Unit-column values (compared lower-cased) that mean "no unit given".
    EMPTY_UNIT_VALUES = frozenset({"nan", "none", "null", "-", ""})

    # (unit, pattern) pairs tried in order against the upper-cased SKU name;
    # the first match wins.  Each token must not be glued to other letters,
    # so "500G" / "1L" / "2 KGS" match while "EGG", "BACKGROUND" or
    # "MILLIGRAM" do not.  Digits directly before the token are allowed.
    UNIT_PATTERNS = (
        ("kg", re.compile(r"(?<![A-Z])(?:KGS?|KILOGRAMS?)(?![A-Z])")),
        ("gm", re.compile(r"(?<![A-Z])(?:GMS?|GRAMS?|G)(?![A-Z])")),
        ("ml", re.compile(r"(?<![A-Z])(?:ML|MILLILIT(?:ER|RE)S?)(?![A-Z])")),
        ("ltr", re.compile(r"(?<![A-Z])(?:LTRS?|LIT(?:ER|RE)S?|L)(?![A-Z])")),
        ("nos", re.compile(r"(?<![A-Z])(?:NOS|PCS|PIECES?)(?![A-Z])")),
    )

    def add_arguments(self, parser):
        """Register the positional path to the Excel workbook."""
        parser.add_argument("file_path", type=str, help="Path to Excel file")

    # ======================================================
    # HELPERS
    # ======================================================
    def v(self, row, col, default=None):
        """Return ``row[col]``, or *default* when the cell is effectively empty.

        A cell counts as empty when the column is absent, the value is
        NaN/None, or it is a whitespace-only string.
        """
        if col not in row:
            return default
        value = row[col]
        if not pd.notna(value):
            return default
        if isinstance(value, str) and not value.strip():
            return default
        return value

    def _cell_text(self, row, col, default=""):
        """Return the cell as a stripped string, or *default* when empty."""
        value = self.v(row, col)
        return default if value is None else str(value).strip()

    def _cell_code(self, row, col):
        """Return an identifier cell as a clean string ("" when empty).

        pandas reads an integer column that also contains blanks as float64,
        so code 1001 arrives as 1001.0.  Render integral floats without the
        ".0" so the same code always maps to the same database key.
        """
        value = self.v(row, col)
        if value is None:
            return ""
        if isinstance(value, float) and value.is_integer():
            value = int(value)
        return str(value).strip()

    def generate_code(self, name, prefix="", max_len=40):
        """Build an upper-case code from *name*, optionally prefixed.

        Punctuation is removed.  Names with more than three words become their
        initials; shorter names are joined with underscores.  The result is
        ``f"{prefix}_{code}"`` when *prefix* is given, truncated to *max_len*.
        NOTE(review): different names can collapse to the same code (shared
        initials or truncation); callers use the code as the lookup key.
        """
        cleaned = re.sub(r"[^\w\s]", "", str(name).upper().strip())
        words = cleaned.split()

        if len(words) > 3:
            code = "".join(word[0] for word in words)
        else:
            code = "_".join(words)

        code = re.sub(r"_+", "_", code)

        if prefix:
            code = f"{prefix}_{code}"

        return code[:max_len]

    def detect_unit(self, sku_name):
        """Guess the SKU unit from its name.

        Returns ``"kg"``, ``"gm"``, ``"ml"``, ``"ltr"`` or ``"nos"`` for the
        first pattern in ``UNIT_PATTERNS`` that matches, otherwise ``"unit"``.
        """
        name = "" if sku_name is None else str(sku_name).upper().strip()
        for unit, pattern in self.UNIT_PATTERNS:
            if pattern.search(name):
                return unit
        return "unit"

    def _import_row(self, row, products_cache):
        """Upsert the category, sub-category, product and SKU for one row.

        Intended to run inside a savepoint so that a failure undoes
        everything this row wrote.  *products_cache* is only read here; the
        caller adds the product once the row has committed.

        Returns ``(item_code, product, product_created, sku_code, sku_unit)``.
        Raises ``ValueError`` when a required cell is missing.  Conversion
        and database errors propagate unchanged.
        """
        # Validate required cells before any database write.  The order and
        # the messages match the checks the importer has always reported.
        category_name = self._cell_text(row, "Category")
        if not category_name:
            raise ValueError("Category missing")

        sub_name = self._cell_text(row, "Sub-Category")
        if not sub_name:
            raise ValueError("Sub-category missing")

        item_code = self._cell_code(row, "Item Code")
        if not item_code:
            raise ValueError("Item Code missing")

        sku_code = self._cell_code(row, "SKU Code")
        if not sku_code:
            raise ValueError("SKU Code missing")

        # ================= CATEGORY =================
        category_code = self.generate_code(category_name, prefix="CAT")
        category, _ = ProductCategory.objects.get_or_create(
            category_code=category_code,
            defaults={
                "category_name": category_name,
                "long_distance_availability": False,
            },
        )

        # ================= SUB CATEGORY =================
        # The sub-category code is namespaced by its category's code.
        sub_code = self.generate_code(sub_name, prefix=category_code)
        sub_category, _ = ProductSubCategory.objects.get_or_create(
            sub_category_code=sub_code,
            defaults={
                "sub_category_name": sub_name,
                "category": category,
                "long_distance_availability": False,
            },
        )

        # ================= PRODUCT =================
        # Only the first row seen for an item code creates or looks up the
        # product.  Later rows reuse the cached instance.
        product_created = False
        product = products_cache.get(item_code)
        if product is None:
            product, product_created = Products.objects.get_or_create(
                item_code=item_code,
                defaults={
                    "product_type": "Master Product",
                    "item_name": self.v(row, "Item Name"),
                    "item_category": category,
                    "item_sub_category": sub_category,
                    "item_description": self.v(row, "Item Description", ""),
                    "veg_or_non_veg_status": self.v(row, "Veg or Non-Veg", "Veg"),
                    "i_gst": float(self.v(row, "IGST", 0)),
                    "s_gst": float(self.v(row, "SGST", 0)),
                    "c_gst": float(self.v(row, "CGST", 0)),
                    "cess": float(self.v(row, "CESS", 0)),
                    "long_distance_availability": False,
                },
            )

        # ================= SKU =================
        sku_name = self._cell_text(row, "SKU Name")

        raw_sku_unit = self.v(row, "Unit")
        sku_unit = str(raw_sku_unit).strip().lower() if raw_sku_unit else ""
        if sku_unit in self.EMPTY_UNIT_VALUES:
            # No usable explicit unit: infer it from the SKU name.
            sku_unit = self.detect_unit(sku_name)

        sku_defaults = {
            "sku_name": sku_name,
            "sku_quantity": int(self.v(row, "Quantity", 1)),
            "sku_unit": sku_unit,
            "sku_mrp": float(self.v(row, "MRP", 0)),
            "sku_expiry_duration": int(self.v(row, "Expiry Duration(Days)", 0)),
            "sku_bulk_qty_limit": int(self.v(row, "Bulk Quantity Limit", 0)),
            # str() guards against non-text cells (e.g. numbers) before .title().
            "sku_status": str(self.v(row, "SKU Status", "Visible")).strip().title(),
            "long_distance_availability": False,
            "same_day_delivery": False,
            "customization_available": False,
        }

        SKU.objects.update_or_create(
            sku_code=sku_code,
            product=product,
            defaults=sku_defaults,
        )

        return item_code, product, product_created, sku_code, sku_unit

    # ======================================================
    # MAIN
    # ======================================================
    def handle(self, *args, **kwargs):
        """Read the workbook at ``file_path`` and import it row by row.

        All rows run in one outer transaction, and each row gets its own
        savepoint.  A row that raises, whether from validation or the
        database, is rolled back alone, counted as an error and reported.
        The summary then prints newly created products, saved SKUs and
        errors.
        """
        file_path = kwargs["file_path"]

        # header=1: column names sit on the sheet's second row, so the first
        # data row is spreadsheet row 3 (hence row_no = index + 3 below).
        df = pd.read_excel(file_path, header=1)
        # str() first: .str.strip() would turn non-text headers into NaN.
        df.columns = [str(col).strip() for col in df.columns]

        products_cache = {}
        product_count = 0
        sku_count = 0
        errors = 0

        with transaction.atomic():
            for index, row in df.iterrows():
                row_no = index + 3

                try:
                    # Per-row savepoint.  Catching a database error inside a
                    # single shared atomic block would leave the transaction
                    # unusable for every subsequent row.
                    with transaction.atomic():
                        item_code, product, created, sku_code, sku_unit = (
                            self._import_row(row, products_cache)
                        )
                except Exception as e:
                    errors += 1
                    self.stdout.write(self.style.ERROR(f"❌ Row {row_no}: {e}"))
                    continue

                # Cache only after the savepoint committed, so a product that
                # was rolled back with a failed row is never reused.
                products_cache[item_code] = product
                if created:
                    product_count += 1
                sku_count += 1
                self.stdout.write(f"✅ Row {row_no}: SKU saved → {sku_code} ({sku_unit})")

        # ================= SUMMARY =================
        self.stdout.write(self.style.SUCCESS("\n🎉 IMPORT COMPLETED"))
        self.stdout.write(f"📦 Products created : {product_count}")
        self.stdout.write(f"🏷️  SKUs processed   : {sku_count}")
        self.stdout.write(f"❌ Errors            : {errors}")
