from mathutils import Vector, Euler, Quaternion, Matrix
from pyedm_platform_selector import pyedm
from enum import Enum
from math_tools import euler_to_quat, ZERO_VEC3
from bpy.types import FCurve, Action, Object, AnimData
from typing import Union, Callable, Set, Tuple, List
import utils
import re
from armature_export import BoneNode

class Data_Path_Enum(str, Enum):
    """Names of the animation data paths this module looks for.

    Members also subclass ``str``, so they compare equal to the plain
    strings read from ``FCurve.data_path``.
    """
    # Object-level transform channels (these are the entries of OBJ_PATHS).
    LOCATION        = 'location'
    ROTATION_QUAT   = 'rotation_quaternion'
    ROTATION_EULER  = 'rotation_euler'
    SCALE           = 'scale'
    # Channels looked up on ``obj.data`` (these are the entries of DATA_PATHS).
    ENERGY          = 'energy'
    COLOR           = 'color'
    CUTOFF_DISTANCE = 'cutoff_distance'
    SPECULAR        = 'specular_factor'
    SPOT_SIZE       = 'spot_size'
    SPOT_BLEND      = 'spot_blend'

# Data paths that has_transform_anim() treats as object transform animation:
# an F-Curve on the object's own action with one of these paths makes the
# object count as transform-animated.
OBJ_PATHS = [
    Data_Path_Enum.ROTATION_QUAT,
    Data_Path_Enum.ROTATION_EULER,
    Data_Path_Enum.LOCATION,
    Data_Path_Enum.SCALE
]

# Data paths that has_data_anim() looks for on the action attached to
# ``obj.data`` (as opposed to the object's own action).
DATA_PATHS = [
    Data_Path_Enum.ENERGY,
    Data_Path_Enum.COLOR,
    Data_Path_Enum.CUTOFF_DISTANCE,
    Data_Path_Enum.SPECULAR,
    Data_Path_Enum.SPOT_SIZE,
    Data_Path_Enum.SPOT_BLEND
]

def get_anim_ch_paths(action: Action) -> Set[str]:
    """Return the distinct ``data_path`` values of the action's F-Curves.

    A falsy ``action`` (for example ``None``) yields an empty set.
    """
    if not action:
        return set()
    return {curve.data_path for curve in action.fcurves}

class DummyFCurve:
    """Constant stand-in for a channel that has no F-Curve of its own.

    Provides the members that fcurves_animation() uses on a curve:
    ``update()``, ``keyframe_points`` and ``evaluate()``, plus
    ``array_index`` and the stored ``val``.
    """

    def __init__(self, array_index: int, val) -> None:
        self.val = val
        self.array_index = array_index
        # No keys of its own, so it never adds frames to the sampled set.
        self.keyframe_points = []

    def update(self):
        """Do nothing; there is no curve data to refresh."""
        return None

    def evaluate(self, frame):
        """Return the stored constant, whatever ``frame`` is."""
        return self.val

def has_transform_anim(obj: Object) -> bool:
    """Tell whether the object's own action animates a transform channel.

    Returns ``True`` when ``obj.animation_data.action`` holds at least one
    F-Curve whose ``data_path`` is listed in ``OBJ_PATHS``; ``False`` when
    there is no animation data, no action, or no such curve.
    """
    anim_data = obj.animation_data
    action = anim_data.action if anim_data else None
    if not action:
        return False
    return any(curve.data_path in OBJ_PATHS for curve in action.fcurves)

def has_data_anim(obj: Object) -> bool:
    """Tell whether the action on ``obj.data`` animates a known data channel.

    Returns ``True`` when ``obj.data.animation_data.action`` holds at least
    one F-Curve whose ``data_path`` is listed in ``DATA_PATHS``.

    Returns ``False`` when ``obj`` has no ``data`` attribute, when ``data``
    is ``None``, when there is no animation data or action, or when no curve
    matches.
    """
    # hasattr() alone is not enough: the attribute can exist and be None,
    # in which case reading ``.animation_data`` would raise AttributeError.
    data = getattr(obj, 'data', None)
    if data is None:
        return False

    data_ad: AnimData = data.animation_data
    if not data_ad or not data_ad.action:
        return False

    return any(fcu.data_path in DATA_PATHS for fcu in data_ad.action.fcurves)

def has_path_anim(anim_data: AnimData, data_path: str) -> bool:
    """Tell whether the action of ``anim_data`` has a curve for ``data_path``.

    Returns ``False`` when ``anim_data`` is falsy, ``data_path`` is empty,
    there is no action, or no F-Curve has exactly that ``data_path``.
    """
    if not anim_data or not data_path or not anim_data.action:
        return False
    return any(data_path == curve.data_path for curve in anim_data.action.fcurves)

# Type aliases for the key lists built by fcurves_animation().
# First element of every key tuple.
KeyFrameTime = float
# Value of one key: a single float, or a list with one float per channel.
KeyFrameValue = Union[float, List[float]]
# Optional per-key conversion applied to each sampled value.
KeyFrameValueTransform = Callable[[KeyFrameValue], KeyFrameValue]
# One key as a (time, value) tuple.
KeyFramePoint = Tuple[KeyFrameTime, KeyFrameValue]
# Full key list of one animated property.
KeyFramePoints = List[KeyFramePoint]

def fcurves_animation(fcurves: List[FCurve], expected_num: int, def_value: KeyFrameValue, fn: KeyFrameValueTransform = lambda v: v) -> KeyFramePoints:
    """Sample a group of curves at every keyed frame and return the keys.

    Each curve is refreshed with ``update()``, the frames of all keyframe
    points (``kf.co[0]``) are merged into one sorted set, and the curves are
    evaluated at each of those frames.

    Args:
        fcurves: Curve-like objects providing ``update()``,
            ``keyframe_points`` and ``evaluate(frame)``; ``DummyFCurve``
            instances are handled through the same interface.
        expected_num: Channel count of the property. Values above 1 select
            the multi-channel result shape described below.
        def_value: Kept for signature compatibility; this function does not
            read it (defaults are already baked into the curves passed in).
        fn: Conversion applied to each sampled value before it is stored.

    Returns:
        A list of ``(time, value)`` tuples ordered by frame, where
        ``time = frame / 100.0 - 1.0``. For ``expected_num > 1`` there is one
        tuple per frame and ``value`` is ``fn([v0, v1, ...])`` with one entry
        per curve, in the order the curves were given. Otherwise there is
        one tuple per (frame, curve) pair and ``value`` is ``fn(v)``. The
        result is always a list; it is empty when no curve has keyframes.
    """
    # Refresh each curve right before reading its keyframes.
    frames = set()
    for fcu in fcurves:
        fcu.update()
        frames.update(kf.co[0] for kf in fcu.keyframe_points)
    ordered_frames = sorted(frames)

    if expected_num > 1:
        return [
            (frame / 100.0 - 1.0, fn([fcu.evaluate(frame) for fcu in fcurves]))
            for frame in ordered_frames
        ]

    return [
        (frame / 100.0 - 1.0, fn(fcu.evaluate(frame)))
        for frame in ordered_frames
        for fcu in fcurves
    ]

def action_animation(action: Action, data_path: str, expected_num: int, def_value: KeyFrameValue, fn: KeyFrameValueTransform = lambda v: v) -> KeyFramePoints:
    """Collect the keys of one property from ``action``.

    One slot per channel is prepared and pre-filled with a ``DummyFCurve``
    holding ``def_value[i]`` (or ``0`` when ``def_value`` is falsy). Every
    F-Curve of ``action`` whose ``data_path`` equals ``data_path`` replaces
    the slot at its ``array_index``. The slots are then sampled by
    fcurves_animation().

    Returns ``None`` when ``action`` is falsy or when no F-Curve matches
    ``data_path``; otherwise the key list from fcurves_animation().
    """
    if not action:
        return None

    defaults = def_value if def_value else [0] * expected_num
    channels = [DummyFCurve(idx, defaults[idx]) for idx in range(expected_num)]

    matched = False
    for curve in action.fcurves:
        if data_path == curve.data_path:
            matched = True
            channels[curve.array_index] = curve

    if not matched:
        return None
    return fcurves_animation(channels, expected_num, defaults, fn)

def extract_anim_float(action: Action, data_path: str, fn: KeyFrameValueTransform = lambda v: v) -> KeyFramePoints:
    """Collect the keys of a single-channel property via action_animation().

    The missing-channel default is left to action_animation() (``None`` is
    passed). Returns ``None`` when ``action`` is falsy or has no curve for
    ``data_path``.
    """
    return action_animation(action, data_path, expected_num=1, def_value=None, fn=fn)

def extract_anim_vec2(action: Action, data_path: str, def_value, fn: KeyFrameValueTransform = lambda v: v) -> KeyFramePoints:
    """Collect the keys of a 2-channel property via action_animation().

    ``def_value`` supplies the values of channels that have no curve.
    Returns ``None`` when ``action`` is falsy or has no curve for
    ``data_path``.
    """
    return action_animation(action, data_path, expected_num=2, def_value=def_value, fn=fn)

def extract_anim_vec3(action: Action, data_path: str, def_value, fn: KeyFrameValueTransform = lambda v: v) -> KeyFramePoints:
    """Collect the keys of a 3-channel property via action_animation().

    ``def_value`` supplies the values of channels that have no curve.
    Returns ``None`` when ``action`` is falsy or has no curve for
    ``data_path``.
    """
    return action_animation(action, data_path, expected_num=3, def_value=def_value, fn=fn)

def extract_anim_vec4(action: Action, data_path: str, def_value, fn: KeyFrameValueTransform = lambda v: v) -> KeyFramePoints:
    """Collect the keys of a 4-channel property via action_animation().

    ``def_value`` supplies the values of channels that have no curve.
    Returns ``None`` when ``action`` is falsy or has no curve for
    ``data_path``.
    """
    return action_animation(action, data_path, expected_num=4, def_value=def_value, fn=fn)

def euler_to_quat_anim(rot_anim):
    """Convert the key values of a two-element animation entry to quaternions.

    ``rot_anim[1]`` is a sequence of keys; the second item of each key is
    passed through ``euler_to_quat`` while the first item is kept as is.
    ``rot_anim[0]`` is carried over unchanged. The input is not modified;
    a new two-element list is returned.
    """
    quat_keys = [(key[0], euler_to_quat(key[1])) for key in rot_anim[1]]
    return [rot_anim[0], quat_keys]

def extract_transform_animation(parent: pyedm.Node, obj: Object, allowed_args = None) -> Union[pyedm.AnimationNode, pyedm.Transform]:
    """Add the object's transform below ``parent`` and return the last node added.

    Static case: a single ``pyedm.Transform(obj.name, obj.matrix_local)`` is
    added when the object has no transform F-Curves, when
    ``utils.extract_arg_number(action.name)`` is negative, or when
    ``allowed_args`` is non-empty and does not contain that number.
    ``allowed_args`` of ``None`` (or any other falsy value) allows any number.

    Animated case: a chain is added in this order: ``<name>_mat``
    (``matrix_local``), ``<name>_bmat_inv`` (inverse of ``matrix_basis``),
    location, scale, rotation. Each of the last three is an
    ``AnimationNode`` when keys were found, otherwise a static ``Transform``
    built from the matching component of the decomposed ``matrix_basis``.
    Euler rotation keys are tried first and converted with
    euler_to_quat_anim(); quaternion keys are used only when no Euler curve
    exists.

    Animation keys are in the space before parenting is applied
    (matrix_basis's space).
    """
    def add_static_transform():
        return parent.addChild(pyedm.Transform(obj.name, obj.matrix_local))

    if not has_transform_anim(obj):
        return add_static_transform()

    action = obj.animation_data.action
    arg = utils.extract_arg_number(action.name)
    if arg < 0:
        return add_static_transform()
    if allowed_args and arg not in allowed_args:
        return add_static_transform()

    name = obj.name
    local_matrix = obj.matrix_local
    basis = obj.matrix_basis
    basis_inverse = obj.matrix_basis.inverted()
    base_loc, base_rot, base_scale = basis.decompose()
    base_euler = base_rot.to_euler()

    loc_node = rot_node = scale_node = None

    # Location: an existing (even empty) key list produces an animation node.
    loc_keys = extract_anim_vec3(action, Data_Path_Enum.LOCATION, base_loc)
    if loc_keys is not None:
        loc_node = pyedm.AnimationNode('al_' + name)
        loc_node.setPositionAnimation([[arg, loc_keys]])

    # Scale: only a non-empty key list produces an animation node.
    scale_keys = extract_anim_vec3(action, Data_Path_Enum.SCALE, base_scale)
    if scale_keys:
        scale_node = pyedm.AnimationNode('as_' + name)
        scale_node.setScaleAnimation([[arg, scale_keys]])

    # Rotation: Euler curves win; quaternion curves are the fallback.
    euler_keys = extract_anim_vec3(action, Data_Path_Enum.ROTATION_EULER, base_euler)
    if euler_keys is not None:
        quat_anim = euler_to_quat_anim([arg, euler_keys])
        rot_node = pyedm.AnimationNode('ar_' + name)
        rot_node.setRotationAnimation([quat_anim])
    else:
        quat_keys = extract_anim_vec4(action, Data_Path_Enum.ROTATION_QUAT, base_rot)
        if quat_keys:
            rot_node = pyedm.AnimationNode('ar_' + name)
            rot_node.setRotationAnimation([[arg, quat_keys]])

    # Build the chain; every node becomes the child of the previous one.
    tail = parent.addChild(pyedm.Transform(name + '_mat', local_matrix))
    tail = tail.addChild(pyedm.Transform(name + '_bmat_inv', basis_inverse))
    tail = tail.addChild(
        loc_node if loc_node
        else pyedm.Transform('tl_' + name, Matrix.LocRotScale(base_loc, None, None))
    )
    tail = tail.addChild(
        scale_node if scale_node
        else pyedm.Transform('s_' + name, Matrix.LocRotScale(None, None, base_scale))
    )
    tail = tail.addChild(
        rot_node if rot_node
        else pyedm.Transform('tr_' + name, Matrix.LocRotScale(None, base_rot, None))
    )
    return tail


# Matches '<prefix>["<name>"].<property>'; the leading '.*' is greedy, so the
# LAST '["..."].' group in the path is the one that gets captured.
bone_path_re_c = re.compile(r'.*\["(.*)"\]\.(.*)')
def split_data_path(data_path):
    """Split a quoted-name data path into ``(name, property_path)``.

    Returns ``(None, data_path)`` when the path contains no
    ``["name"].property`` part.
    """
    found = bone_path_re_c.match(data_path)
    if found is None:
        return (None, data_path)
    name, prop = found.group(1, 2)
    return (name, prop)
    

def extract_bone_animation(parent: pyedm.Node, bone: BoneNode):
    """Add the nodes for one bone below ``parent`` and return the last one added.

    A ``pyedm.Bone`` named ``bone.name`` is created and given ``bone.mat_inv``
    as its inverted base matrix.

    Static case: when the armature has no action, or
    ``utils.extract_arg_number(action.name)`` is negative, the bone receives
    ``bone.mat`` via ``setMatrix`` and is added directly under ``parent``.

    Animated case: location and quaternion-rotation F-Curves whose data path
    names ``bone.bone.name`` are collected and turned into animation nodes.
    The chain added under ``parent`` is, in order:

    1. static location ``tlb_<bone>`` (only when there is no location animation)
    2. static rotation ``trb_<bone>`` (only when there is no rotation animation)
    3. location animation ``alb_<bone>`` (when present)
    4. rotation animation ``arb_<bone>`` (when present)
    5. static scale ``sb_<bone>``
    6. the ``pyedm.Bone`` itself

    Only a single node is returned: the result of the final ``addChild``.
    """
    edm_bone = pyedm.Bone(bone.name)
    mat = bone.mat
    edm_bone.setInvertedBaseBoneMatrix(bone.mat_inv)

    anim_data = bone.armature.animation_data
    if not anim_data or not anim_data.action:
        edm_bone.setMatrix(mat)
        return parent.addChild(edm_bone)

    action = anim_data.action
    arg = utils.extract_arg_number(action.name)
    if arg < 0:
        edm_bone.setMatrix(mat)
        return parent.addChild(edm_bone)

    loc, rot, sca = mat.decompose()
    bone_name = bone.bone.name

    # Pick the curves that belong to this bone.
    fcu_loc = []
    fcu_rot = []
    for fc in action.fcurves:
        (name, dp) = split_data_path(fc.data_path)
        if bone_name == name:
            if dp == Data_Path_Enum.LOCATION:
                fcu_loc.append(fc)
            elif dp == Data_Path_Enum.ROTATION_QUAT:
                fcu_rot.append(fc)

    # fcurves_animation() builds each value from the curves in list order, so
    # put the channels in array_index order instead of relying on the order in
    # which the action happens to store them.
    fcu_loc.sort(key=lambda fc: fc.array_index)
    fcu_rot.sort(key=lambda fc: fc.array_index)
    # NOTE(review): if only some channels of a property are keyed, the value
    # list is shorter than 3 / 4 entries - confirm whether that can occur and
    # how Vector() / Quaternion() below should handle it.

    bloc, brot, _ = bone.pbone.matrix.decompose()
    inv_brot = brot.inverted()
    _, arot, _ = bone.armature.matrix_basis.decompose()

    al = ar = None

    def f_loc(v):
        # Animated location is stored relative to the bone's base location.
        return -loc + Vector(v)

    # bloc / arot are forwarded as def_value exactly as before;
    # fcurves_animation() does not read that argument.
    if fcu_loc:
        loc_keys = fcurves_animation(fcu_loc, 3, bloc, f_loc)
        if loc_keys is not None:
            al = pyedm.AnimationNode('alb_' + bone_name)
            al.setPositionAnimation([[arg, loc_keys]])

    if fcu_rot:
        rot_keys = fcurves_animation(fcu_rot, 4, arot, lambda v: rot @ inv_brot @ Quaternion(v))
        if rot_keys is not None:
            ar = pyedm.AnimationNode('arb_' + bone_name)
            ar.setRotationAnimation([[arg, rot_keys]])

    # Static fallbacks are added first, then the animation nodes, then scale.
    if not al:
        parent = parent.addChild(pyedm.Transform('tlb_' + bone_name, Matrix.LocRotScale(loc, None, None)))

    if not ar:
        parent = parent.addChild(pyedm.Transform('trb_' + bone_name, Matrix.LocRotScale(None, rot, None)))

    if al:
        parent = parent.addChild(al)

    if ar:
        parent = parent.addChild(ar)

    parent = parent.addChild(pyedm.Transform('sb_' + bone_name, Matrix.LocRotScale(None, None, sca)))

    parent = parent.addChild(edm_bone)

    return parent