import re
from pyedm_platform_selector import pyedm
import bpy
import os.path
import pickle
import copy
from collections import namedtuple
from typing import List, Union, Dict, NamedTuple, Callable, Type

import block_builder
from utils import md5, EDMPath, make_acro_map, make_socket_map
from enums import NodeGroupTypeEnum, NodeSocketInDefaultEnum, NodeSocketInDeckEnum, BpyShaderNode, NodeSocketCommonEnum
from serializer import SLink, SInput
from logger import log
from material_wrap import DefMaterialWrap, DeckMaterialWrap, GlassMaterialWrap, MirrorMaterialWrap, MaterialWrapCustomType, FakeOmniLightMaterialWrap, FakeSpotLightMaterialWrap
from mesh_storage import MeshStorage
from serializer_tools import MatDesc
from version_specific import InterfaceNodeSocket, get_version, IS_BLENDER_4, IS_BLENDER_3
from bpy.types import Object, ShaderNodeGroup, Material
from custom_sockets import TransparencyEnumItems, ShadowCasterEnumItems
from export_fake_lights import make_fake_omni_edm_mat_blocks, make_fake_spot_edm_mat_blocks

def filter_materials(edm_group_name: str) -> List[Material]:
    """Collect node-based materials that use a group node with the given tree name.

    A material is inspected only when ``use_nodes`` is set and it has a node
    tree. A node counts as a hit when its ``bl_idname`` is one of the group
    node identifiers listed below, it has a ``node_tree`` assigned, and that
    tree's name equals ``edm_group_name`` exactly.

    :param edm_group_name: exact name of the node tree to look for.
    :return: the matching materials; empty when nothing matches or when
        ``bpy.data`` has no ``materials`` attribute.
    """
    # Same guard as filter_materials_re, so both helpers degrade to an empty
    # result instead of raising AttributeError when the collection is missing.
    # NOTE(review): presumably this happens while Blender restricts bpy.data
    # (e.g. during add-on registration) - confirm.
    if not hasattr(bpy.data, "materials"):
        return []

    group_node_ids = (
        BpyShaderNode.NODE_GROUP,
        BpyShaderNode.NODE_GROUP_EDM,
        BpyShaderNode.NODE_GROUP_DEFAULT,
        BpyShaderNode.NODE_GROUP_DECK,
        BpyShaderNode.NODE_GROUP_FAKE_OMNI,
        BpyShaderNode.NODE_GROUP_FAKE_SPOT,
    )

    matched: List[Material] = []
    for material in bpy.data.materials:
        if not (material.use_nodes and material.node_tree):
            continue

        for node in material.node_tree.nodes:
            # Only touch `node_tree` on group-like nodes.
            if node.bl_idname not in group_node_ids:
                continue
            tree = node.node_tree
            if tree and tree.name == edm_group_name:
                # NOTE(review): a material is appended once per matching node,
                # so it can appear several times - confirm callers expect that.
                matched.append(material)

    return matched

def filter_materials_re(edm_group_regex) -> List[Material]:
    """Collect node-based materials that use a group node whose tree name matches a regex.

    The pattern is applied with ``re.match``, i.e. anchored at the start of
    the node tree name.

    :param edm_group_regex: pattern (string or compiled) handed to ``re.match``;
        a falsy value yields an empty result.
    :return: the matching materials, one entry per matching node; empty when
        ``bpy.data`` has no ``materials`` attribute.
    """
    if not edm_group_regex or not hasattr(bpy.data, "materials"):
        return []

    group_node_ids = (
        BpyShaderNode.NODE_GROUP,
        BpyShaderNode.NODE_GROUP_EDM,
        BpyShaderNode.NODE_GROUP_DEFAULT,
        BpyShaderNode.NODE_GROUP_DECK,
        BpyShaderNode.NODE_GROUP_FAKE_OMNI,
        BpyShaderNode.NODE_GROUP_FAKE_SPOT,
    )

    found: List[Material] = []
    for material in bpy.data.materials:
        if not (material.use_nodes and material.node_tree):
            continue

        for node in material.node_tree.nodes:
            # Only touch `node_tree` on group-like nodes.
            if node.bl_idname not in group_node_ids:
                continue
            if node.node_tree and re.match(edm_group_regex, node.node_tree.name):
                found.append(material)

    return found

def process_def_links(old_links: List[SLink], old_version: int, material_name: str) -> List[SLink]:
    """Rename legacy target sockets on links saved by an older default material.

    Only links whose ``to_type`` is ``'ShaderNodeGroup'`` are considered.
    A link that needs a rename is shallow-copied first, so the objects in
    ``old_links`` are never modified. ``material_name`` is not used.

    :param old_links: links serialized from the old node tree.
    :param old_version: version of the material the links were saved from.
        Versions 0 and 1 get the normal-socket rename, versions up to 11 get
        the damage-socket renames, newer versions are copied through.
    :param material_name: unused.
    :return: a new list with the (possibly renamed) links in the same order.
    """
    def legacy_normal(socket_name):
        # Spellings of the normal input that versions 0/1 may carry.
        if socket_name in ('Normal', 'Normal  (Non-Color)', 'Normal (Non color)'):
            return NodeSocketInDefaultEnum.NORMAL
        return None

    def legacy_damage(socket_name):
        if socket_name == 'Damage Color':
            return NodeSocketInDefaultEnum.DAMAGE_COLOR
        if socket_name in ('Damage Map', 'Damage Map (Non-Color)'):
            # NOTE(review): this is the Deck enum inside the default-material
            # migration - confirm it is intentional.
            return NodeSocketInDeckEnum.DAMAGE_MASK
        if socket_name == 'Damage Normal':
            return NodeSocketInDefaultEnum.DAMAGE_NORMAL
        return None

    if old_version in (0, 1):
        rename = legacy_normal
    elif old_version <= 11:
        rename = legacy_damage
    else:
        # Nothing to migrate: hand back a fresh list with the same links.
        return list(old_links)

    migrated: List[SLink] = []
    for link in old_links:
        if link.to_type == 'ShaderNodeGroup':
            replacement = rename(link.to_socket)
            if replacement is not None:
                link = copy.copy(link)
                link.to_socket = replacement
        migrated.append(link)

    return migrated

def process_deck_links(old_links: List[SLink], old_version: int, material_name: str) -> List[SLink]:
    """Rename legacy damage sockets on links saved by an older deck material.

    For versions below 7, links whose ``to_type`` is ``'ShaderNodeGroup'`` and
    whose ``to_socket`` carries one of the old damage socket names are
    shallow-copied and retargeted. All other links, and every link of newer
    versions, are passed through unchanged. ``material_name`` is not used.

    :param old_links: links serialized from the old node tree.
    :param old_version: version of the material the links were saved from.
    :param material_name: unused.
    :return: a new list with the (possibly renamed) links in the same order.
    """
    def legacy_damage(socket_name):
        if socket_name == 'Damage Color':
            # NOTE(review): this is the Default enum inside the deck-material
            # migration - confirm it is intentional.
            return NodeSocketInDefaultEnum.DAMAGE_COLOR
        if socket_name in ('Damage Map', 'Damage Map (Non-Color)'):
            return NodeSocketInDeckEnum.DAMAGE_MASK
        if socket_name == 'Damage Normal':
            return NodeSocketInDeckEnum.DAMAGE_NORMAL
        return None

    if old_version < 7:
        migrated: List[SLink] = []
        for link in old_links:
            if link.to_type == 'ShaderNodeGroup':
                replacement = legacy_damage(link.to_socket)
                if replacement is not None:
                    link = copy.copy(link)
                    link.to_socket = replacement
            migrated.append(link)
        return migrated

    # Nothing to migrate: hand back a fresh list with the same links.
    return list(old_links)


def get_first_socket_by_name(sockest_list: List[SInput], name: str) -> Union[SInput, None]:
    """Return the first entry of ``sockest_list`` whose ``name`` equals ``name``.

    :param sockest_list: serialized sockets to search, in priority order.
    :param name: socket name to look for.
    :return: the first matching socket, or ``None`` when there is no match.
    """
    return next((candidate for candidate in sockest_list if candidate.name == name), None)

def restore_defaults(old_sockest: List[SInput], new_node_group: ShaderNodeGroup, old_version: int, material_name: str) -> None:
    """No-op ``restore_defaults`` handler: all arguments are ignored.

    Registered in ``MATERIALS`` for the fake omni and fake spot materials.
    """
    return None

def restore_def_mat_defaults(old_sockest: List[SInput], new_node_group: ShaderNodeGroup, old_version: int, material_name: str) -> None:
    """Copy saved input values of an old default material onto a new group node.

    Old sockets are matched to new ones by name (first match wins). The
    version socket always receives the version of the new node tree. Version 0
    restores nothing.

    :param old_sockest: sockets serialized from the old group node.
    :param new_node_group: freshly created group node that receives the values.
    :param old_version: version of the material the sockets were saved from.
    :param material_name: key into the map returned by ``make_socket_map``.
    """
    version_new: int = get_version(new_node_group.node_tree)
    if old_version == 0:
        return

    if old_version < 8:
        for target in new_node_group.inputs:
            saved: SInput = get_first_socket_by_name(old_sockest, target.name)
            if not saved:
                continue

            if target.name == NodeSocketInDefaultEnum.SHADOW_CASTER:
                # The saved value is used as an index into the enum item
                # table; an undefined socket or a falsy value selects item 0.
                unusable = saved.bl_socket_idname == 'NodeSocketUndefined' or not saved.instance_value
                index = 0 if unusable else saved.instance_value
                target.default_value = ShadowCasterEnumItems[index][0]
            elif target.name == NodeSocketInDefaultEnum.TRANSPARENCY:
                unusable = saved.bl_socket_idname == 'NodeSocketUndefined' or not saved.instance_value
                index = 0 if unusable else saved.instance_value
                target.default_value = TransparencyEnumItems[index][0]
            elif target.name == NodeSocketCommonEnum.VERSION:
                target.default_value = version_new
            elif hasattr(saved, 'instance_value'):
                target.default_value = saved.instance_value
    elif old_version >= 8:
        socket_map: Dict[str, Dict[str, str]] = make_socket_map(new_node_group)
        socket_acro_map: Dict[str, str] = make_acro_map(new_node_group)

        # First pass: values stored as attributes on the node itself.
        per_material = socket_map.get(material_name)
        if per_material:
            uses_acro_props = new_node_group.bl_idname in (
                BpyShaderNode.NODE_GROUP_DEFAULT,
                BpyShaderNode.NODE_GROUP_DECK,
                BpyShaderNode.NODE_GROUP_FAKE_OMNI,
                BpyShaderNode.NODE_GROUP_FAKE_SPOT,
            )
            for socket_name, enum_name in per_material.items():
                saved = get_first_socket_by_name(old_sockest, socket_name)
                if uses_acro_props:
                    if hasattr(saved, 'instance_value'):
                        # NOTE(review): the acro map lookup may yield None,
                        # which setattr rejects - confirm every socket name
                        # has an entry.
                        setattr(new_node_group, socket_acro_map.get(socket_name), saved.instance_value)
                elif saved and hasattr(saved, 'instance_value'):
                    setattr(new_node_group, enum_name, saved.instance_value)

        # Second pass: plain socket default values.
        for target in new_node_group.inputs:
            saved = get_first_socket_by_name(old_sockest, target.name)
            if not saved:
                continue

            if target.name == NodeSocketCommonEnum.VERSION:
                target.default_value = version_new
            elif saved.instance_value:
                # NOTE(review): falsy saved values (0, 0.0, False) are not
                # restored here - confirm that is intended.
                # Under Blender 4 the transparency and shadow caster sockets
                # are left untouched by this pass.
                if IS_BLENDER_3 or (
                    IS_BLENDER_4
                    and target.name not in (NodeSocketInDefaultEnum.TRANSPARENCY, NodeSocketInDefaultEnum.SHADOW_CASTER)
                ):
                    target.default_value = saved.instance_value

def restore_deck_mat_defaults(old_sockest: List[SInput], new_node_group: ShaderNodeGroup, old_version: int, material_name: str) -> None:
    """Copy saved input values of an old deck material onto a new group node.

    Old sockets are matched to new ones by name (first match wins). The
    version socket always receives the version of the new node tree; the
    transparency socket is never overwritten; every other socket receives the
    saved value when that value is truthy. Nothing is restored for versions
    below 1. ``material_name`` is not used.

    :param old_sockest: sockets serialized from the old group node.
    :param new_node_group: freshly created group node that receives the values.
    :param old_version: version of the material the sockets were saved from.
    :param material_name: unused.
    """
    version_new: int = get_version(new_node_group.node_tree)
    if not old_version >= 1:
        return

    for new_socket in new_node_group.inputs:
        old_socket_wrp: SInput = get_first_socket_by_name(old_sockest, new_socket.name)
        if not old_socket_wrp:
            continue

        if new_socket.name == NodeSocketCommonEnum.VERSION:
            new_socket.default_value = version_new
            continue

        # Compare by equality. The previous `not in (X)` had no trailing
        # comma, so `(X)` was not a tuple and the test checked containment
        # in the transparency name itself instead of excluding that socket.
        if old_socket_wrp.instance_value and new_socket.name != NodeSocketInDeckEnum.TRANSPARENCY:
            new_socket.default_value = old_socket_wrp.instance_value

def process_general_links(old_links: List[SLink], old_version: int, material_name: str) -> List[SLink]:
    """Pass-through ``process_links`` handler.

    Returns the very same list object it was given; ``old_version`` and
    ``material_name`` are ignored.
    """
    return old_links

# Signatures of the per-material block builders: each takes the Blender
# object, the matching material wrapper and the mesh storage, and returns the
# pyedm node built for it.
DefCall = Callable[[Object, DefMaterialWrap, MeshStorage], pyedm.PBRNode]
DeckCall = Callable[[Object, DeckMaterialWrap, MeshStorage], pyedm.DeckNode]
FakeOmniCall = Callable[[Object, FakeOmniLightMaterialWrap, MeshStorage], pyedm.FakeOmniLights]
FakeSpotCall = Callable[[Object, FakeSpotLightMaterialWrap, MeshStorage], pyedm.FakeSpotLights]
GlassCall = Callable[[Object, GlassMaterialWrap, MeshStorage], pyedm.PBRNode]
MirrorCall = Callable[[Object, MirrorMaterialWrap, MeshStorage], pyedm.MirrorNode]
BuildBlocksCall = Union[DefCall, DeckCall, GlassCall, MirrorCall, FakeOmniCall, FakeSpotCall]

# Registry record describing one supported material type.
MatNamedTupleType = NamedTuple(
    "Mat",
    [
        ('name', str),
        ('build_blocks', BuildBlocksCall),
        ('factory', Type[MaterialWrapCustomType]),
        ('description_file_name', str),
        # Link migration handlers return the migrated list of links.
        ('process_links', Callable[[List[SLink], int, str], List[SLink]]),
        ('restore_defaults', Callable[[List[SInput], ShaderNodeGroup, int, str], None]),
    ]
)

# Constructor alias. Using the typed class itself (rather than a second,
# unrelated namedtuple with the same fields) makes records built through
# `Mat(...)` real instances of MatNamedTupleType.
Mat: Type[MatNamedTupleType] = MatNamedTupleType

# Registry of supported material types, keyed by node group type. Each record
# bundles the block builder, the wrapper class, the pickled description file
# and the two migration handlers for that material.
MATERIALS: Dict[NodeGroupTypeEnum, MatNamedTupleType] = {
    NodeGroupTypeEnum.DEFAULT: Mat(
        name=NodeGroupTypeEnum.DEFAULT,
        build_blocks=block_builder.make_def_edm_mat_blocks,
        factory=DefMaterialWrap,
        description_file_name='data/EDM_Default_Material.pickle',
        process_links=process_def_links,
        restore_defaults=restore_def_mat_defaults,
    ),
    NodeGroupTypeEnum.DECK: Mat(
        name=NodeGroupTypeEnum.DECK,
        build_blocks=block_builder.make_deck_edm_mat_blocks,
        factory=DeckMaterialWrap,
        description_file_name='data/EDM_Deck_Material.pickle',
        process_links=process_deck_links,
        restore_defaults=restore_deck_mat_defaults,
    ),
    NodeGroupTypeEnum.FAKE_OMNI: Mat(
        name=NodeGroupTypeEnum.FAKE_OMNI,
        build_blocks=make_fake_omni_edm_mat_blocks,
        factory=FakeOmniLightMaterialWrap,
        description_file_name='data/EDM_Fake_Omni_Material.pickle',
        process_links=process_general_links,
        restore_defaults=restore_defaults,
    ),
    NodeGroupTypeEnum.FAKE_SPOT: Mat(
        name=NodeGroupTypeEnum.FAKE_SPOT,
        build_blocks=make_fake_spot_edm_mat_blocks,
        factory=FakeSpotLightMaterialWrap,
        description_file_name='data/EDM_Fake_Spot_Material.pickle',
        process_links=process_general_links,
        restore_defaults=restore_defaults,
    ),
    # Disabled entries. They pass only five values, so a `restore_defaults`
    # handler has to be added before either can be re-enabled.
    # NodeGroupTypeEnum.GLASS : Mat(
    #     NodeGroupTypeEnum.GLASS,
    #     block_builder.make_glass_edm_mat_blocks,
    #     GlassMaterialWrap,
    #     'data/EDM_Glass_Material.pickle',
    #     process_general_links
    # ),
    # NodeGroupTypeEnum.MIRROR : Mat(
    #     NodeGroupTypeEnum.MIRROR,
    #     block_builder.make_mirror_edm_mat_blocks,
    #     MirrorMaterialWrap,
    #     'data/EDM_Mirror_Material.pickle',
    #     process_general_links
    # )
}

def check_if_referenced_file(blend_file_name):
    """Tell whether a file is named after one of the registered materials.

    The directory part and the extension of ``blend_file_name`` are dropped
    and the remainder is compared, case-insensitively, with every key of
    ``MATERIALS``.

    :param blend_file_name: path or file name to test.
    :return: ``True`` when a key matches, ``False`` otherwise.
    """
    file_name = os.path.basename(blend_file_name)
    stem, _extension = os.path.splitext(file_name)
    wanted = stem.lower()
    return any(key.lower() == wanted for key in MATERIALS)

def check_md5(material_name: NodeGroupTypeEnum, material_desc: MatDesc):
    """Verify the bundled .blend file of a material against its recorded hash.

    The check only runs when at least one material in the current file uses a
    node group whose tree name contains the material name. In that case the
    MD5 of ``data/<material>.blend`` inside the plugin directory is computed
    and, if ``material_desc`` carries a ``blend_file_md5`` attribute that
    differs from it, the mismatch is reported through ``log.fatal``.

    :param material_name: material type whose ``value`` names both the node
        group and the .blend file.
    :param material_desc: loaded material description.
    """
    # Take the name from `.value`, exactly as the file name below does:
    # formatting an enum member inside an f-string is version-dependent and
    # can yield "Class.MEMBER" instead of the value. The name is escaped so it
    # is always matched literally.
    group_name: str = str(material_name.value)
    node_tree_name = re.compile(f'[A-Za-z0-9_-]*{re.escape(group_name)}[A-Za-z0-9_.-]*')

    mat_list: List[Material] = filter_materials_re(node_tree_name)
    if not mat_list:
        return

    blend_file_name: str = 'data/' + group_name + '.blend'
    blend_file_path: str = os.path.join(EDMPath.full_plugin_path, blend_file_name)
    blend_file_md5: str = md5(blend_file_path)
    if hasattr(material_desc, 'blend_file_md5') and material_desc.blend_file_md5 != blend_file_md5:
        log.fatal(f"Hash of material file {blend_file_path} is invalid.")
        
def build_material_descriptions(materials: Dict[NodeGroupTypeEnum, MatNamedTupleType]) -> Dict[NodeGroupTypeEnum, MatDesc]:
    """Load the pickled description of every registered material.

    For each entry the file named by ``description_file_name`` (relative to
    the plugin directory) is unpickled and then checked with ``check_md5``.
    An ``OSError`` raised while handling an entry is logged and the loop moves
    on to the next material.

    :param materials: material registry, e.g. ``MATERIALS``.
    :return: loaded descriptions keyed like ``materials``; entries whose file
        could not be opened are missing.
    """
    material_descs: Dict[NodeGroupTypeEnum, MatDesc] = {}
    for name, mat in materials.items():
        try:
            pickle_file_name: str = os.path.join(EDMPath.full_plugin_path, mat.description_file_name)
            # NOTE: pickle executes code on load; this file is read from the
            # plugin directory and must never come from an untrusted source.
            with open(pickle_file_name, 'rb') as f:
                material_desc = pickle.load(f)
            material_descs[name] = material_desc

            # Pass the description that was just loaded. Passing the whole
            # dict (as before) made the `blend_file_md5` attribute test in
            # check_md5 always fail, so the hash was never verified.
            check_md5(name, material_desc)
        except OSError as e:
            log.error(f"Can't open material description file: {mat.description_file_name}. Reason: {e}.")
    return material_descs