# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import sys
import shutil
import zipfile
from paddle.utils.download import get_path_from_url

# Nothing is exported through `from <module> import *`.
__all__ = []

# Default cache location (not referenced by the code visible in this file).
DEFAULT_CACHE_DIR = '~/.cache'
# Name of the optional hubconf attribute that lists required packages.
VAR_DEPENDENCY = 'dependencies'
# File name of the entry-point module imported from a hub repo directory.
MODULE_HUBCONF = 'hubconf.py'
# Directory under the user's home where downloaded repos are cached.
HUB_DIR = os.path.expanduser(os.path.join('~', '.cache', 'paddle', 'hub'))


def _remove_if_exists(path):
    """Delete ``path`` if it is present; do nothing otherwise.

    Regular files and symbolic links (including dangling links and links
    that point at directories) are unlinked. Anything else is removed
    recursively with ``shutil.rmtree``. The target of a symbolic link is
    never touched.

    Args:
        path (str): file, symbolic link or directory to delete.
    """
    # lexists() also reports dangling symlinks, which exists() would miss.
    if not os.path.lexists(path):
        return

    # rmtree() refuses to operate on a symlink, so links must be unlinked.
    if os.path.islink(path) or os.path.isfile(path):
        os.remove(path)
    else:
        shutil.rmtree(path)


def _import_module(name, repo_dir):
    """Import module ``name`` from ``repo_dir`` without leaving global traces.

    ``repo_dir`` is put at the front of ``sys.path`` for the duration of the
    import and is always taken off again, whether the import succeeds or
    fails. On success the module is also dropped from ``sys.modules`` so that
    a later call imports it afresh.

    Args:
        name (str): module name to import (without the ``.py`` suffix).
        repo_dir (str): directory that contains the module.

    Returns:
        module: the imported module object.

    Raises:
        RuntimeError: if importing the module raises ``ImportError``.
        Exception: any other error raised while executing the module is
            propagated unchanged.
    """
    sys.path.insert(0, repo_dir)
    try:
        hub_module = __import__(name)
        # Tolerate a module that already removed itself from sys.modules.
        sys.modules.pop(name, None)
    except ImportError:
        raise RuntimeError(
            'Please make sure config exists or repo error messages above fixed when importing'
        )
    finally:
        # Always restore sys.path, including for non-ImportError failures.
        # The membership check avoids masking the original error if the
        # imported code already removed the entry itself.
        if repo_dir in sys.path:
            sys.path.remove(repo_dir)

    return hub_module


def _git_archive_link(repo_owner, repo_name, branch, source):
    """Build the zip-archive download URL of a repo branch or tag.

    Args:
        repo_owner (str): owner of the repository.
        repo_name (str): name of the repository.
        branch (str): branch or tag name.
        source (str): hosting site, `github` or `gitee`.

    Returns:
        str | None: the archive URL, or ``None`` for any other ``source``.
    """
    if source == 'github':
        url_template = 'https://github.com/{}/{}/archive/{}.zip'
    elif source == 'gitee':
        url_template = 'https://gitee.com/{}/{}/repository/archive/{}.zip'
    else:
        # Unknown hosting sites yield no link.
        return None

    return url_template.format(repo_owner, repo_name, branch)


def _parse_repo_info(repo, source):
    """Split a repo spec into owner, name and branch.

    Args:
        repo (str): spec of the form "repo_owner/repo_name[:branch]".
        source (str): hosting site; selects the default branch when the spec
            names none (`main` for `github`, `master` otherwise).

    Returns:
        tuple: ``(repo_owner, repo_name, branch)``.
    """
    if ':' in repo:
        # An explicit branch/tag follows the colon.
        repo_info, branch = repo.split(':')
    else:
        repo_info = repo
        if source == 'github':
            branch = 'main'
        else:
            branch = 'master'

    repo_owner, repo_name = repo_info.split('/')
    return repo_owner, repo_name, branch


def _make_dirs(dirname):
    """Create directory ``dirname`` together with any missing parents.

    An already existing directory is not an error.

    Args:
        dirname (str): path of the directory to create.
    """
    try:
        from pathlib import Path
    except ImportError:
        from pathlib2 import Path
    # parents=True is required: intermediate directories of the cache path
    # may not exist yet, and mkdir() would otherwise raise FileNotFoundError.
    Path(dirname).mkdir(parents=True, exist_ok=True)


def _get_cache_or_reload(repo, force_reload, verbose=True, source='github'):
    """Return the local directory of ``repo``, downloading it when needed.

    The repo is cached under ``HUB_DIR`` in a folder named
    ``<owner>_<name>_<branch>``. When that folder exists and
    ``force_reload`` is false it is returned as is. Otherwise the zip archive
    of the branch is downloaded, extracted into ``HUB_DIR`` and its top-level
    folder is renamed to the cache folder.

    Args:
        repo (str): spec of the form "repo_owner/repo_name[:branch]".
        force_reload (bool): download again even if a cached copy exists.
        verbose (bool): report on stderr when the cached copy is used.
        source (str): `github` or `gitee`.

    Returns:
        str: path of the cached repo directory.
    """
    # Every downloaded repo lives under the hub cache directory.
    hub_dir = HUB_DIR
    _make_dirs(hub_dir)

    repo_owner, repo_name, branch = _parse_repo_info(repo, source)

    # Branch names may contain '/', which would be read as a path separator
    # on both Linux and Windows, so it is flattened to '_'. Backslashes are
    # not allowed in branch names and need no handling.
    normalized_br = branch.replace('/', '_')

    # The top-level folder inside the archive is only known after download
    # (e.g. tag v1.x.x becomes repo-1.x.x), so the cache uses its own
    # predictable folder name instead.
    cache_name = '_'.join((repo_owner, repo_name, normalized_br))
    repo_dir = os.path.join(hub_dir, cache_name)

    if not force_reload and os.path.exists(repo_dir):
        if verbose:
            sys.stderr.write('Using cache found in {}\n'.format(repo_dir))
        return repo_dir

    # Drop any archive left over from an earlier download.
    archive_path = os.path.join(hub_dir, normalized_br + '.zip')
    _remove_if_exists(archive_path)

    archive_url = _git_archive_link(repo_owner,
                                    repo_name,
                                    branch,
                                    source=source)
    download_method = 'wget' if source == 'gitee' else 'get'
    downloaded_path = get_path_from_url(archive_url,
                                        hub_dir,
                                        check_exist=not force_reload,
                                        decompress=False,
                                        method=download_method)
    shutil.move(downloaded_path, archive_path)

    with zipfile.ZipFile(archive_path) as archive:
        # The first archive entry is taken as the extracted base folder.
        top_level_entry = archive.infolist()[0].filename
        extracted_repo = os.path.join(hub_dir, top_level_entry)
        _remove_if_exists(extracted_repo)
        archive.extractall(hub_dir)

    # The archive is no longer needed; replace any stale cache folder with
    # the freshly extracted one.
    _remove_if_exists(archive_path)
    _remove_if_exists(repo_dir)
    shutil.move(extracted_repo, repo_dir)

    return repo_dir


def _load_entry_from_hubconf(m, name):
    '''Look up the callable entry ``name`` on hubconf module ``m``.

    Args:
        m (module): imported hubconf module.
        name (str): name of the entry to fetch.

    Returns:
        callable: the attribute ``name`` of ``m``.

    Raises:
        ValueError: if ``name`` is not a str.
        RuntimeError: if ``m`` has no callable attribute called ``name``.
    '''
    if isinstance(name, str):
        entry = getattr(m, name, None)
        # A missing attribute comes back as None, which is not callable.
        if callable(entry):
            return entry
        raise RuntimeError('Cannot find callable {} in hubconf'.format(name))

    raise ValueError('Invalid input: model should be a str of function name')


def _check_module_exists(name):
    """Tell whether module ``name`` can be imported.

    The check performs a real import, so the module stays loaded afterwards.

    Args:
        name (str): module name to probe.

    Returns:
        bool: ``True`` if the import succeeds, ``False`` on ``ImportError``.
    """
    try:
        __import__(name)
    except ImportError:
        return False
    else:
        return True


def _check_dependencies(m):
    """Verify that every package required by hubconf module ``m`` imports.

    The requirements are read from the module attribute named by
    ``VAR_DEPENDENCY``; a module without that attribute has none.

    Args:
        m (module): imported hubconf module.

    Raises:
        RuntimeError: if one or more listed packages cannot be imported.
    """
    required = getattr(m, VAR_DEPENDENCY, None)
    if required is None:
        return

    missing = []
    for pkg in required:
        if not _check_module_exists(pkg):
            missing.append(pkg)

    if missing:
        raise RuntimeError('Missing dependencies: {}'.format(
            ', '.join(missing)))


def list(repo_dir, source='github', force_reload=False):
    r"""
    List the names of all entrypoints exposed by a repo's hubconf.

    Args:
        repo_dir(str): remote repo spec or local path.

            remote spec (str): a str with format "repo_owner/repo_name[:tag_name]" with an optional
            tag/branch. When omitted, the branch defaults to `main` for `github` and `master` for `gitee`.

            local path (str): path of a local repo directory.

        source (str): `github` | `gitee` | `local`, default is `github`.
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download, default is `False`.
    Returns:
        entrypoints: a list with the names of all public callables in hubconf

    Example:
        .. code-block:: python

            import paddle

            paddle.hub.list('lyuwenyu/paddlehub_demo:main', source='github', force_reload=False)

    """
    is_remote = source in ('github', 'gitee')
    if not is_remote and source != 'local':
        raise ValueError(
            'Unknown source: "{}". Allowed values: "github" | "gitee" | "local".'
            .format(source))

    if is_remote:
        # Swap the repo spec for the directory of its cached download.
        repo_dir = _get_cache_or_reload(repo_dir,
                                        force_reload,
                                        verbose=True,
                                        source=source)

    module_name = MODULE_HUBCONF.partition('.')[0]
    hub_module = _import_module(module_name, repo_dir)

    # Public callables only: names with a leading underscore are skipped.
    entrypoints = []
    for attr in dir(hub_module):
        if callable(getattr(hub_module, attr)) and not attr.startswith('_'):
            entrypoints.append(attr)

    return entrypoints


def help(repo_dir, model, source='github', force_reload=False):
    """
    Return the documentation string of a hubconf entrypoint.

    Args:
        repo_dir(str): remote repo spec or local path.

            remote spec (str): a str with format "repo_owner/repo_name[:tag_name]" with an optional
            tag/branch. When omitted, the branch defaults to `main` for `github` and `master` for `gitee`.

            local path (str): path of a local repo directory.

        model (str): name of the entrypoint.
        source (str): `github` | `gitee` | `local`, default is `github`.
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download, default is `False`.
    Return:
        docs: the `__doc__` of the entrypoint

    Example:
        .. code-block:: python

            import paddle

            paddle.hub.help('lyuwenyu/paddlehub_demo:main', model='MM', source='github')

    """
    is_remote = source in ('github', 'gitee')
    if not is_remote and source != 'local':
        raise ValueError(
            'Unknown source: "{}". Allowed values: "github" | "gitee" | "local".'
            .format(source))

    if is_remote:
        # Swap the repo spec for the directory of its cached download.
        repo_dir = _get_cache_or_reload(repo_dir,
                                        force_reload,
                                        verbose=True,
                                        source=source)

    module_name = MODULE_HUBCONF.partition('.')[0]
    hub_module = _import_module(module_name, repo_dir)

    return _load_entry_from_hubconf(hub_module, model).__doc__


def load(repo_dir, model, source='github', force_reload=False, **kwargs):
    """
    Build a model by calling a hubconf entrypoint.

    Args:
        repo_dir(str): remote repo spec or local path.

            remote spec (str): a str with format "repo_owner/repo_name[:tag_name]" with an optional
            tag/branch. When omitted, the branch defaults to `main` for `github` and `master` for `gitee`.

            local path (str): path of a local repo directory.

        model (str): name of the entrypoint.
        source (str): `github` | `gitee` | `local`, default is `github`.
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download, default is `False`.
        **kwargs: keyword arguments forwarded to the entrypoint
    Return:
        whatever the entrypoint returns (a paddle model)
    Example:
        .. code-block:: python

            import paddle
            paddle.hub.load('lyuwenyu/paddlehub_demo:main', model='MM', source='github')

    """
    is_remote = source in ('github', 'gitee')
    if not is_remote and source != 'local':
        raise ValueError(
            'Unknown source: "{}". Allowed values: "github" | "gitee" | "local".'
            .format(source))

    if is_remote:
        # Swap the repo spec for the directory of its cached download.
        repo_dir = _get_cache_or_reload(repo_dir,
                                        force_reload,
                                        verbose=True,
                                        source=source)

    module_name = MODULE_HUBCONF.partition('.')[0]
    hub_module = _import_module(module_name, repo_dir)

    # Fail early if the packages declared by hubconf are not installed.
    _check_dependencies(hub_module)

    build_model = _load_entry_from_hubconf(hub_module, model)
    return build_model(**kwargs)
