import json
import base64
import hmac
import hashlib
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
import os
import time, requests, uuid

def sign_jws_billdesk(payload_string, signing_key, key_id, client_id, algorithm="HS256"):
    """
    Wrap a string in a compact, HMAC-signed JWS for BillDesk.

    The protected header carries ``alg``, ``typ`` ("JWT"), ``kid`` and the
    custom ``clientid`` parameter, serialized as compact JSON. Header, payload
    and signature are each base64url-encoded without '=' padding and joined
    with dots.

    Args:
        payload_string (str): Text to sign (typically a compact JWE).
        signing_key (str): BillDesk provided HMAC signing key; its UTF-8 bytes
            are used as the MAC key.
        key_id (str): BillDesk provided signing key ID (``kid``).
        client_id (str): BillDesk provided client ID (``clientid``).
        algorithm (str): One of "HS256" (default), "HS384" or "HS512".

    Returns:
        str: ``<header>.<payload>.<signature>``.

    Raises:
        ValueError: If ``algorithm`` is not one of the supported HMAC variants.
    """

    def to_b64url(raw):
        # JWS compact serialization uses unpadded base64url.
        return base64.urlsafe_b64encode(raw).decode('utf-8').rstrip('=')

    protected = {
        "alg": algorithm,
        "typ": "JWT",
        "kid": key_id,
        "clientid": client_id,
    }
    encoded_header = to_b64url(json.dumps(protected, separators=(',', ':')).encode('utf-8'))
    encoded_payload = to_b64url(payload_string.encode('utf-8'))
    message = f"{encoded_header}.{encoded_payload}".encode('utf-8')
    mac_key = signing_key.encode('utf-8')

    supported = (
        ("HS256", hashlib.sha256),
        ("HS384", hashlib.sha384),
        ("HS512", hashlib.sha512),
    )
    for name, digestmod in supported:
        if algorithm == name:
            break
    else:
        raise ValueError(f"Unsupported algorithm: {algorithm}")

    mac = hmac.new(mac_key, message, digestmod).digest()
    return ".".join((encoded_header, encoded_payload, to_b64url(mac)))


def encrypt_and_sign_billdesk_payment(merc_id, amount, return_url,
                                      encryption_key, enc_key_id, enc_client_id,
                                      signing_key, sign_key_id, sign_client_id,
                                      currency="356", itemcode="DIRECT", user_ip="123.0.0.1",
                                      order_id=None):
    """
    Complete BillDesk payment flow: Create payload -> Encrypt with JWE -> Sign with JWS

    Args:
        merc_id (str): Merchant ID
        amount (str): Payment amount
        return_url (str): Return URL
        encryption_key (str): Encryption key
        enc_key_id (str): Encryption key ID
        enc_client_id (str): Encryption client ID
        signing_key (str): Signing key
        sign_key_id (str): Signing key ID
        sign_client_id (str): Signing client ID
        currency (str): Currency code (default: "356" for INR)
        itemcode (str): Item code (default: "DIRECT")
        user_ip (str): User's IP address (the default is a placeholder value)
        order_id (str, optional): Order id to send. If None (default), an id of
            the form ``"TEST<epoch seconds>"`` is generated; note that such
            ids collide for two orders created within the same second.

    Returns:
        tuple: (payload_json, encrypted_data, signed_data) where
        ``payload_json`` is the compact JSON payload string,
        ``encrypted_data`` is the JWE from ``encrypt_jwe_billdesk`` and
        ``signed_data`` is the JWS from ``sign_jws_billdesk`` wrapping it.
    """
    from datetime import datetime, timedelta, timezone

    # order_date is labelled "+05:30" (IST), so it must be IST wall-clock time.
    # time.strftime() would format the *host's* local time, which mislabels
    # the timestamp on any server not running in IST (e.g. a UTC container).
    now_ist = datetime.now(timezone(timedelta(hours=5, minutes=30)))

    # Step 1: Create payload
    payload = {
        "orderid": order_id if order_id is not None else f"TEST{int(now_ist.timestamp())}",
        "mercid": merc_id,
        "order_date": now_ist.strftime("%Y-%m-%dT%H:%M:%S+05:30"),
        "amount": amount,
        "currency": currency,
        "ru": return_url,
        "itemcode": itemcode,
        "device": {
            "init_channel": "internet",
            "ip": user_ip,
            "user_agent": "Mozilla/5.0 (Windows NT 10.0; WOW64; rv:51.0) Gecko/20100101 Firefox/51.0",
            "accept_header": "application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9"
        }
    }

    payload_json = json.dumps(payload, separators=(',', ':'))

    # Step 2: Encrypt with JWE
    encrypted_data = encrypt_jwe_billdesk(payload_json, encryption_key, enc_key_id, enc_client_id)

    # Step 3: Sign the encrypted data with JWS
    signed_data = sign_jws_billdesk(encrypted_data, signing_key, sign_key_id, sign_client_id)

    return payload_json, encrypted_data, signed_data



def encrypt_billdesk_payment(merc_id, amount, return_url, encryption_key, key_id, client_id,
                             currency="356", itemcode="DIRECT", user_ip="123.0.0.1",
                             order_id=None):
    """
    Create and encrypt a BillDesk payment payload

    Args:
        merc_id (str): Merchant ID provided by BillDesk
        amount (str): Payment amount
        return_url (str): Return URL after payment
        encryption_key (str): BillDesk provided encryption key
        key_id (str): BillDesk provided encryption key ID
        client_id (str): BillDesk provided client ID
        currency (str): Currency code (default: "356" for INR)
        itemcode (str): Item code (default: "DIRECT")
        user_ip (str): User's IP address (the default is a placeholder value)
        order_id (str, optional): Order id to send. If None (default), an id of
            the form ``"TEST<epoch seconds>"`` is generated; note that such
            ids collide for two orders created within the same second.

    Returns:
        str: JWE encrypted payment data (not signed; see
        ``sign_jws_billdesk``)
    """
    from datetime import datetime, timedelta, timezone

    # order_date is labelled "+05:30" (IST), so it must be IST wall-clock time.
    # time.strftime() would format the *host's* local time, which mislabels
    # the timestamp on any server not running in IST (e.g. a UTC container).
    now_ist = datetime.now(timezone(timedelta(hours=5, minutes=30)))

    # Create the payment payload
    payload = {
        "orderid": order_id if order_id is not None else f"TEST{int(now_ist.timestamp())}",
        "mercid": merc_id,
        "order_date": now_ist.strftime("%Y-%m-%dT%H:%M:%S+05:30"),
        "amount": amount,
        "currency": currency,   # "356" = INR
        "ru": return_url,
        "itemcode": itemcode,
        "device": {
            "init_channel": "internet",
            "ip": user_ip,
            "user_agent": "Mozilla/5.0 (Windows NT 10.0; WOW64; rv:51.0) Gecko/20100101 Firefox/51.0",
            "accept_header": "application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9"
        }
    }

    # Compact JSON (no whitespace) before encryption
    payload_json = json.dumps(payload, separators=(',', ':'))

    return encrypt_jwe_billdesk(payload_json, encryption_key, key_id, client_id)




def encrypt_jwe_billdesk(response_string, encryption_key, key_id, client_id):
    """
    Encrypt data as a compact JWE (``alg="dir"``, ``enc="A256GCM"``) for BillDesk.

    Args:
        response_string (str): The payload to encrypt (UTF-8 encoded first).
        encryption_key (str): BillDesk provided encryption key. Either a
            32-character string whose UTF-8 bytes are the AES key (like Java's
            ``encryptionKey.getBytes()``), or ``"b64:"`` followed by the
            standard base64 encoding of 32 raw bytes. These are the same
            formats ``decrypt`` accepts.
        key_id (str): BillDesk provided encryption key ID (``kid``).
        client_id (str): BillDesk provided client ID (custom ``clientid``).

    Returns:
        str: ``<header>..<iv>.<ciphertext>.<tag>``. The second segment is
        empty because direct encryption has no encrypted key.

    Raises:
        ValueError: If the key does not resolve to exactly 32 bytes.
    """

    def to_b64url(raw):
        # JOSE compact serialization uses unpadded base64url.
        return base64.urlsafe_b64encode(raw).decode('utf-8').rstrip('=')

    if encryption_key.startswith('b64:'):
        try:
            key = base64.b64decode(encryption_key[4:])
        except ValueError as e:
            raise ValueError(f"Invalid encryption key: {e}") from e
    else:
        key = encryption_key.encode('utf-8')

    # AES-256 needs exactly 32 bytes. Zero-padding or truncating a wrong-size
    # key would silently encrypt under a key different from the configured
    # one (and a padded short key is much weaker); decrypt() in this module
    # rejects such keys too, so fail loudly here instead.
    if len(key) != 32:
        raise ValueError(f"Encryption key must be 32 bytes, got {len(key)}")

    jwe_header = {
        "alg": "dir",          # Direct encryption (JWEAlgorithm.DIR)
        "enc": "A256GCM",      # AES-256-GCM (EncryptionMethod.A256GCM)
        "kid": key_id,
        "clientid": client_id  # custom parameter
    }
    header_b64 = to_b64url(json.dumps(jwe_header, separators=(',', ':')).encode('utf-8'))

    # Fresh random 96-bit IV per message; an IV must never repeat under one GCM key.
    iv = os.urandom(12)

    encryptor = Cipher(algorithms.AES(key), modes.GCM(iv), backend=default_backend()).encryptor()
    # The AAD is the ASCII form of the encoded protected header, so the
    # header is integrity-protected by the GCM tag.
    encryptor.authenticate_additional_data(header_b64.encode('ascii'))
    ciphertext = encryptor.update(response_string.encode('utf-8')) + encryptor.finalize()

    # Construct JWE: header.encrypted_key.iv.ciphertext.tag (encrypted_key is empty for "dir")
    return f"{header_b64}..{to_b64url(iv)}.{to_b64url(ciphertext)}.{to_b64url(encryptor.tag)}"


def decrypt(encrypted_data, encryption_key, signing_key=None):
    """
    Decrypt a BillDesk compact JWE, optionally wrapped in a JWS, into a dict.

    Accepts either a compact JWE (5 dot-separated parts) or a compact JWS
    (3 parts) whose payload is the compact JWE string. Only direct encryption
    (``alg="dir"``) with AES-256-GCM (``enc="A256GCM"``) is supported.

    Args:
        encrypted_data (str): Compact JWE, or compact JWS wrapping a JWE.
        encryption_key (str): 32-character key used as its UTF-8 bytes, or
            ``"b64:"`` followed by the standard base64 encoding of 32 bytes.
        signing_key (str, optional): When given and ``encrypted_data`` is a
            JWS, its HMAC signature is verified with this key before
            decrypting. The algorithm (HS256/HS384/HS512) is read from the JWS
            header. When None (the default), the JWS signature is NOT checked.

    Returns:
        dict: The decrypted JSON object.

    Raises:
        ValueError: For malformed input, an unsupported header, a bad key,
            a JWS signature mismatch, or a decryption/authentication failure.
    """

    def base64url_decode(data):
        # Compact JOSE serialization drops '=' padding; restore it before decoding.
        return base64.urlsafe_b64decode(data + '=' * (-len(data) % 4))

    try:
        parts = encrypted_data.split('.')

        # A 3-part token is a JWS whose payload is the compact JWE string.
        if len(parts) == 3:
            jws_header_b64, jws_payload_b64, jws_signature_b64 = parts
            try:
                jwe_encrypted_data = base64url_decode(jws_payload_b64).decode('utf-8')
            except Exception as e:
                raise ValueError(f"Failed to decode JWS payload: {e}") from e

            if signing_key is not None:
                # Opt-in HMAC verification of the outer JWS, before any decryption work.
                digests = {"HS256": hashlib.sha256, "HS384": hashlib.sha384, "HS512": hashlib.sha512}
                try:
                    jws_header = json.loads(base64url_decode(jws_header_b64).decode('utf-8'))
                    provided_signature = base64url_decode(jws_signature_b64)
                except Exception as e:
                    raise ValueError(f"Failed to decode JWS header or signature: {e}") from e
                jws_alg = jws_header.get('alg') if isinstance(jws_header, dict) else None
                digestmod = digests.get(jws_alg) if isinstance(jws_alg, str) else None
                if digestmod is None:
                    raise ValueError(f"Unsupported JWS algorithm: {jws_alg}")
                signing_input = f"{jws_header_b64}.{jws_payload_b64}".encode('utf-8')
                expected_signature = hmac.new(signing_key.encode('utf-8'), signing_input, digestmod).digest()
                # Constant-time comparison so the check does not leak timing information.
                if not hmac.compare_digest(expected_signature, provided_signature):
                    raise ValueError("JWS signature verification failed")

        elif len(parts) == 5:
            jwe_encrypted_data = encrypted_data
        else:
            raise ValueError(f"Invalid format - got {len(parts)} parts, expected 3 (JWS) or 5 (JWE)")

        # Process the JWE
        jwe_parts = jwe_encrypted_data.split('.')
        if len(jwe_parts) != 5:
            raise ValueError(f"Invalid JWE format - must have 5 parts, got {len(jwe_parts)}")

        header_b64, encrypted_key_b64, iv_b64, ciphertext_b64, tag_b64 = jwe_parts

        try:
            header = json.loads(base64url_decode(header_b64).decode('utf-8'))
        except Exception as e:
            raise ValueError(f"Failed to decode or parse JWE header: {e}") from e

        # Only direct AES-256-GCM encryption is supported.
        if header.get('alg') != 'dir':
            raise ValueError(f"Unsupported algorithm: {header.get('alg')}, expected 'dir'")
        if header.get('enc') != 'A256GCM':
            raise ValueError(f"Unsupported encryption method: {header.get('enc')}, expected 'A256GCM'")

        # Prepare the encryption key: raw UTF-8 bytes, or base64 after a "b64:" prefix.
        try:
            key = encryption_key.encode('utf-8')
            if encryption_key.startswith('b64:'):
                key = base64.b64decode(encryption_key[4:])
            if len(key) != 32:
                raise ValueError(f"Encryption key must be 32 bytes, got {len(key)}")
        except Exception as e:
            raise ValueError(f"Invalid encryption key: {e}") from e

        # "dir" means the key is used directly, so the encrypted-key segment must be empty.
        if encrypted_key_b64:
            raise ValueError("Expected empty encrypted key for 'dir' algorithm")

        try:
            iv = base64url_decode(iv_b64)
            ciphertext = base64url_decode(ciphertext_b64)
            auth_tag = base64url_decode(tag_b64)
        except Exception as e:
            raise ValueError(f"Failed to decode JWE components: {e}") from e

        if len(iv) != 12:
            raise ValueError(f"Invalid IV length: {len(iv)}, expected 12")

        # Decrypt using AES-256-GCM; the encoded header is the AAD, and a
        # tag mismatch (tampering or wrong key) raises in finalize().
        try:
            cipher = Cipher(algorithms.AES(key), modes.GCM(iv, auth_tag), backend=default_backend())
            decryptor = cipher.decryptor()
            decryptor.authenticate_additional_data(header_b64.encode('ascii'))
            decrypted_bytes = decryptor.update(ciphertext) + decryptor.finalize()
        except Exception as e:
            raise ValueError(f"Decryption failed: {e}") from e

        try:
            decrypted_json = json.loads(decrypted_bytes.decode('utf-8'))
        except json.JSONDecodeError as e:
            raise ValueError(f"Failed to parse decrypted JSON: {e}") from e
        except Exception as e:
            raise ValueError(f"Failed to decode decrypted data: {e}") from e

        if not isinstance(decrypted_json, dict):
            # Report only the type: the decrypted contents may be sensitive
            # and error messages tend to end up in logs.
            raise ValueError(f"Decrypted data is not a dictionary: {type(decrypted_json).__name__}")
        return decrypted_json
    except ValueError:
        # Every anticipated failure above is already a descriptive ValueError.
        raise
    except Exception as e:
        # Normalise unexpected failures (e.g. a non-str encrypted_data) to the documented type.
        raise ValueError(f"Decryption failed: {e}") from e


def decrypt_jwe_billdesk(encrypted_data, encryption_key):
    """
    Decrypt JWE data using AES-256-GCM for BillDesk integration.

    Handles both direct JWE and nested JWS->JWE formats by delegating to
    ``decrypt``. The outer JWS signature is not verified here.

    Args:
        encrypted_data (str): The JWE encrypted string or JWS containing JWE
        encryption_key (str): BillDesk provided encryption key

    Returns:
        dict: Decrypted and parsed JSON payload. Its keys are not validated
        here; for an order-create response it presumably carries fields such
        as 'bdorderid' (confirm against the BillDesk API).

    Raises:
        ValueError: If decryption, JSON parsing, or structure validation
            fails. Any other Exception raised underneath is converted to
            ValueError, with the original chained as ``__cause__``.
    """
    try:
        return decrypt(encrypted_data, encryption_key)
    except ValueError:
        # Already descriptive; re-wrapping would only stack "Decryption failed:" prefixes.
        raise
    except Exception as e:
        raise ValueError(f"Decryption failed: {e}") from e

# Helper function to analyze the encrypted data structure
def analyze_encrypted_data(encrypted_data):
    """
    Print a human-readable breakdown of BillDesk encrypted data.

    Diagnostic helper. It decodes the JWS header, the JWS payload, and any
    JWE header it finds, and prints them. Nothing is verified or decrypted.
    Decoding failures are reported on stdout rather than raised.

    Args:
        encrypted_data (str): Compact JWS (3 parts) or compact JWE (5 parts).
    """

    def b64url_text(segment):
        # Restore the '=' padding stripped by compact serialization. Note that
        # "'=' * (4 - n % 4) % 4" would parse as "('=' * ...) % 4", which is
        # string formatting and always raises; "-n % 4" is the correct count.
        return base64.urlsafe_b64decode(segment + '=' * (-len(segment) % 4)).decode('utf-8')

    def print_json_header(kind, segment):
        # kind is "JWS" or "JWE"; decode errors are all ValueError subclasses
        # (binascii.Error, UnicodeDecodeError, json.JSONDecodeError).
        try:
            header = json.loads(b64url_text(segment))
        except ValueError:
            print(f"Could not decode {kind} header")
        else:
            print(f"{kind} Header: {json.dumps(header, indent=2)}")

    parts = encrypted_data.split('.')
    print(f"Number of parts: {len(parts)}")

    if len(parts) == 3:
        print("Format: JWS (JSON Web Signature)")
        jws_header_b64, jws_payload_b64, _jws_signature_b64 = parts
        print_json_header("JWS", jws_header_b64)

        try:
            jws_payload = b64url_text(jws_payload_b64)
        except ValueError as e:
            print(f"Could not decode JWS payload: {e}")
        else:
            print(f"JWS Payload (JWE): {jws_payload}")
            # A BillDesk JWS is expected to wrap a compact JWE.
            jwe_parts = jws_payload.split('.')
            print(f"JWE parts in payload: {len(jwe_parts)}")
            if len(jwe_parts) == 5:
                print_json_header("JWE", jwe_parts[0])

    elif len(parts) == 5:
        print("Format: JWE (JSON Web Encryption)")
        print_json_header("JWE", parts[0])

    print("-" * 50)
# Example usage with BillDesk payload
# if __name__ == "__main__":
#
#     MERC_ID = "BDUAT2K676"          # UAT Merchant ID
#     CLIENT_ID = client_id =  "bduat2k676sj"      # Provided by BillDesk
#     KEY_ID =  key_id = "vstkEu52BWR9"            # Provided by BillDesk
#     ENCRYPTION_KEY = "e9Y5khgyMyluQRrH4XjgQLgZ9oZdwwk2"
#     SIGNING_KEY = "5ZvlXPsYeuk6VTJRlB5HyBOMpb4wPYQJ"
#
#
#
#     payload = {
#         "orderid": f"TEST{int(time.time())}",
#         "mercid": MERC_ID,
#         "order_date": time.strftime("%Y-%m-%dT%H:%M:%S+05:30"),
#         "amount": "29999.28",
#         "currency": "356",
#         "ru": "https://www.merchant.com",
#         "itemcode": "DIRECT",
#         "device": {
#             "init_channel": "internet",
#             "ip": "123.0.0.1",
#             "user_agent": "Mozilla/5.0 (Windows NT 10.0; WOW64; rv:51.0) Gecko/20100101 Firefox/51.0",
#             "accept_header": "application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9"
#         }
#     }
#
#     payload_json = json.dumps(payload, separators=(',', ':'))
#     print(payload_json)
#     print("paylodddddddddddddddddddddddd")
#
#     encrypted_data = encrypt_jwe_billdesk(payload_json, ENCRYPTION_KEY, KEY_ID, CLIENT_ID)
#     print(encrypted_data)
#     print("encryptttttttttttttttttttttttttttttttttttttttttttttttttttttttt")
#
#
#     encrypted_payment = encrypt_billdesk_payment(
#         merc_id=MERC_ID,
#         amount="29999.28",
#         return_url="http://127.0.0.1:8000/verifypayment/",
#         encryption_key=ENCRYPTION_KEY,
#         key_id=KEY_ID,
#         client_id=CLIENT_ID
#     )
#     signed_encrypted_data = sign_jws_billdesk(encrypted_payment, SIGNING_KEY, KEY_ID, CLIENT_ID)
#     print(signed_encrypted_data)
#
#
#
#     headers = {
#     "Content-Type": "application/jose",
#     "Accept": "application/jose",
#     "BD-Traceid": str(uuid.uuid4()).replace("-", "")[:20],  # Unique ≤35 chars
#     "BD-Timestamp": str(int(time.time()))  # Epoch in seconds
#     }
#
#     # -------------------------------
#     # Step 4: Call BillDesk API
#     # -------------------------------
#     url = "https://uat1.billdesk.com/u2/payments/ve1_2/orders/create"
#     print("11111111111111111111111111111111111111111111111111111111111111")
#     print(headers)
#     print("11111111111111111111111111111111111111111111111111111111111111")
#     print(signed_encrypted_data)
#     response = requests.post(url, headers=headers, data=signed_encrypted_data)
#
#     decrypted_data = decrypt_jwe_billdesk(response.text, ENCRYPTION_KEY)
#     print(decrypted_data)
#       ENCRYPTION_KEY = "e9Y5khgyMyluQRrH4XjgQLgZ9oZdwwk2"
#       encrypted_data = "eyJhbGciOiJIUzI1NiIsImNsaWVudGlkIjoiYmR1YXQyazY3NnNqIiwia2lkIjoiSE1BQyJ9.ZXlKamJHbGxiblJwWkNJNkltSmtkV0YwTW1zMk56WnphaUlzSW1WdVl5STZJa0V5TlRaSFEwMGlMQ0poYkdjaU9pSmthWElpTENKcmFXUWlPaUoyYzNSclJYVTFNa0pYVWpraWZRLi5qM1E0UURjTmdYZ2t1MWdBLkFRRmdSSUJQZ0JNajZSNjUyUW1LTzVyU3pHNzVFQ3RyTGhVTms5bHF2NTRsajZpWlpFR0w3dUhhdXd5azZrQlZMR04zV0VsaUJ4UWtPWURKUld1bXRrdW9ublBnTVc5M004NUJIUV9SZjJUR080OTg4dXZQT0dieVZ4UDdyNDFfOEJQaUlUOGdfdTh2aFlCQTJqaEhnbi13VGxDazVDU091NU0tbkFreGNvajZWeVJEcVdUbzZYZ0o0c0RiNnpDYmwzOXo3VjA3SnNHTnV5b0hVTnpNYm5SY1JJbThLTWwyRWlTdzZtTFk4cGdwbmpNV21YaHR4WGFsYXhfQTVKVWRQekgyVVlhdGJ6bENpQ251Z2dtalJCaDhlcjVmRE9PdTN3YktrM2pTdm1yZDIxM0V2TmFRdlZieG9PLUFCeC1STXkxVjJyZjFwY0ZjR3dRU0gxS3AtMUg4a3dzWEgwT0FkejEzRkp1M0pSUjBkOVVWbndvZVBidFpCOWlzZGhkTmxjX3NKYjlxTjhzQmh6Q2JIbWpxMjNBYmZmUDdxVVpXVmIwNTFRN2ptN1YwdVhwemxGU3JidVlWWDZRMHNtamdfU0FuQUtNN0ZHODNodVA5UkpqazRpZmY3dGg4NlliZ0JYRXRMWVVOV0VYbUVWMU9pU2NzNkhLcmhJdldZVjM3eVk0aFk5Y3haU0hDdlhNaUNSUUxBWUxHeGRsbGhRUXh0eVV1by1oeU1rZ1NRdHpxY1B5SXB6bV9BWjNEVktMX2kwU1ZKRGJoSnNNckVDMGZva1ZKUTNkUjlUSXVrVGN0em9RUXJScUxaUXZJN3ZfRm9fdDVueDEtM2pjSm1BMjFZM21Da1lXaFRkUk9Ta1NYTW04TGpxcVpfeDRKVXVna0ZOdkNkNjMyeUczc0NqM0lDZklnTVc0RDhVQVNBTF9VUC1pckQ3LWY4RXVvcVYtNjdwaTZjYkxVRmswajJPWWo0RVlFeEpqQUVfTGJZc2hLRzhBNW9YNXNXSFFIb1dOdVNMRFhmQVJ3cWhBOG1iY2pKdy5LcTB1aEliRGhsdlZjU1EzMlZhTmx3.aaeTOlzNe1GewLgraXFriE1MYemDDsr9zo8B4srz0eU"
#       decrypted_data = decrypt_jwe_billdesk(encrypted_data, ENCRYPTION_KEY)
#       print(decrypted_data)
    
   
