import logging
import threading
from contextlib import closing

import paramiko
from paramiko.ssh_exception import SSHException

from parallels.core import messages
from parallels.core.utils.steps_profiler import sleep
from parallels.core import MigrationError
from parallels.core.utils.common import poll_data, get_env_str

logger = logging.getLogger(__name__)

# Interval schedule handed to poll_data() when opening an SSH exec channel in
# run_unchecked(): three entries of 5 (units are defined by poll_data, which
# lives elsewhere -- presumably seconds; confirm there).
POLL_INTERVALS = [5] * 3


class ExecutionError(Exception):
    """Exception describing a failed command execution.

    The message is built from ``messages.SSH_EXECUTION_ERROR``.
    The raw details stay available as attributes for callers that
    want to inspect them:

    command     -- the command that was executed
    exit_status -- its exit status
    stdout      -- captured standard output
    stderr      -- captured standard error
    """

    def __init__(self, command, exit_status, stdout, stderr):
        super(ExecutionError, self).__init__(
            messages.SSH_EXECUTION_ERROR % (command, exit_status, stdout, stderr)
        )
        self.command, self.exit_status = command, exit_status
        self.stdout, self.stderr = stdout, stderr


# Host-key policy that silently accepts hosts whose keys are not already known.
# It exists to avoid the "missing host key" warning at WARNING logging level,
# which clutters migrator output.
# NOTE: how missing keys should be handled is debatable. Accepting them
# unconditionally means unknown hosts are not verified, which is not the
# safest choice.
class IgnoreMissingKeyPolicy(paramiko.MissingHostKeyPolicy):
    """Accept a missing host key without logging and without storing it."""

    def missing_host_key(self, client, hostname, key):
        # Returning normally (instead of raising) tells paramiko to accept the key.
        return None


def _read_in_background(stream):
    """Start reading ``stream`` to EOF on a daemon thread.

    Returns a zero-argument callable. The callable waits for the reader to
    finish and returns the data read. If the reader failed, it re-raises
    that exception in the caller's thread.
    """
    outcome = {}

    def reader():
        try:
            outcome['data'] = stream.read()
        except Exception as exc:  # handed back to the caller by wait()
            outcome['error'] = exc

    thread = threading.Thread(target=reader)
    # Never keep the interpreter alive just because a channel is stuck.
    thread.daemon = True
    thread.start()

    def wait():
        thread.join()
        if 'error' in outcome:
            raise outcome['error']
        return outcome['data']

    return wait


def _exception_text(exception):
    """Return the sole message argument of ``exception``, or u'' otherwise.

    This mirrors Python 2's deprecated ``BaseException.message``, which
    does not exist in Python 3.
    """
    if len(exception.args) == 1:
        return exception.args[0]
    return u''


def run_unchecked(
    ssh_client, command, stdin_content=None, output_codepage='utf-8', error_policy='strict', env=None
):
    """Run ``command`` over ``ssh_client`` and return its result without checking the exit status.

    command - is either an unicode string (u''), or byte string that consists of only ascii symbols

    :param ssh_client: client whose ``exec_command`` is used. Its ``hostname``
        attribute is used in the error message when no channel can be opened.
    :param stdin_content: if not None, written to the command's stdin before
        stdin is closed.
    :param output_codepage: codec used to decode stdout and stderr.
    :param error_policy: ``errors`` argument passed to ``decode``.
    :param env: if not None, ``get_env_str(env)`` is prepended to the command.
    :returns: tuple ``(exit_status, stdout_content, stderr_content)``.
    :raises MigrationError: if no exec channel could be opened after polling.
    """
    if env is not None:
        command = u'%s%s' % (get_env_str(env), command)

    def exec_command():
        try:
            return ssh_client.exec_command(command.encode('utf-8'))
        except SSHException as e:
            # Known transient channel failures yield None. Judging by the
            # `result is None` check below, poll_data presumably retries
            # while None is returned.
            expected_messages = [u'Unable to open channel.', u'Channel closed.']
            if _exception_text(e) in expected_messages:
                logger.debug(messages.LOG_EXCEPTION, exc_info=True)
                return None
            raise

    result = poll_data(exec_command, POLL_INTERVALS)
    if result is None:
        assert(ssh_client.hostname is not None)
        raise MigrationError(messages.UNABLE_OPEN_CHANNEL_SSH_CONNECTION_S % ssh_client.hostname)
    stdin, stdout, stderr = result

    channel = stdout.channel

    # Drain stdout and stderr concurrently, starting before stdin is written.
    # Reading one stream to EOF before the other can deadlock: unread data on
    # the other stream fills the channel window, and the remote command blocks
    # on it. The same happens if the command writes output while we are still
    # sending a large stdin_content.
    wait_stdout = _read_in_background(stdout)
    wait_stderr = _read_in_background(stderr)

    if stdin_content is not None:
        stdin.write(stdin_content)

    # Close stdin so program waiting for input will get EOF.
    # stdin.close() does nothing beside flushing and setting internal 'closed' flag,
    # so it is needed to explicitly close writing side of channel.
    stdin.close()
    channel.shutdown_write()

    stdout_content = wait_stdout().decode(output_codepage, error_policy)
    stderr_content = wait_stderr().decode(output_codepage, error_policy)

    stdout.close()
    stderr.close()

    exit_status = channel.recv_exit_status()

    return exit_status, stdout_content, stderr_content


class SSHClientWithHostname(paramiko.SSHClient):
    """paramiko.SSHClient that remembers how it was connected, so it can reconnect.

    ``connect`` stores the hostname and the extra connect arguments. It then
    retries the connection on IOError. ``exec_command`` reconnects
    transparently if the transport is missing or no longer active.
    """

    # Number of connection attempts before giving up (at least one is always made).
    max_connect_attempts = 5
    # Value passed to sleep() between failed attempts (presumably seconds).
    connect_retry_interval = 10

    def __init__(self):
        paramiko.SSHClient.__init__(self)
        self.hostname = None
        # Extra positional/keyword arguments of the last connect() call, replayed on reconnect.
        self.args = ()
        self.kw = {}

    def connect(self, hostname, *args, **kw):
        """Connect to ``hostname``, retrying on IOError.

        Extra arguments are passed through to paramiko.SSHClient.connect.
        They are also remembered so that reconnect() can replay them.

        :raises MigrationError: if every attempt failed with IOError.
        """
        self.hostname = hostname
        self.args = args
        self.kw = kw
        self._connect_multiple_attempts()

    def _connect_multiple_attempts(self):
        """Call paramiko.SSHClient.connect with the stored parameters, retrying on IOError."""
        max_attempts = max(1, self.max_connect_attempts)
        interval_between_attempts = self.connect_retry_interval

        for attempt in range(max_attempts):
            try:
                paramiko.SSHClient.connect(self, self.hostname, *self.args, **self.kw)
            except IOError as e:
                logger.debug(messages.LOG_EXCEPTION, exc_info=True)
                if attempt >= max_attempts - 1:
                    raise MigrationError(
                        messages.UNABLE_CONNECT_HOST_BY_SSH_EXCEPTION.format(
                            host=self.hostname, exception=str(e)
                        )
                    )
                logger.error(
                    messages.SSH_UNABLE_TO_CONNECT_TO_HOST.format(
                        host=self.hostname,
                        interval_between_attempts=interval_between_attempts,
                        exception=str(e)))
                sleep(interval_between_attempts, messages.WAIT_RECONNECT_BY_SSH)
            else:
                if attempt > 0:
                    logger.info(messages.SUCCESSFULLY_CONNECTED_S_BY_SSH, self.hostname)
                return

    # Backward-compatible alias for the original, misspelled method name.
    _connect_multiple_attemtps = _connect_multiple_attempts

    def exec_command(self, *args, **kwargs):
        """Same as paramiko.SSHClient.exec_command, but reconnects first if the transport is gone."""
        transport = self.get_transport()
        if transport is None or not transport.is_active():
            logger.warning(messages.SSH_CONNECTION_S_WAS_UNEXPECTEDLY_CLOSED % self.hostname)
            self.reconnect()

        return super(SSHClientWithHostname, self).exec_command(*args, **kwargs)

    def reconnect(self):
        """Close the current connection and connect again with the stored parameters.

        :raises SSHException: if connect() has never been called, so there
            is nothing to reconnect to.
        """
        if self.hostname is None:
            raise SSHException(u'Unable to reconnect: SSH client was never connected')
        super(SSHClientWithHostname, self).close()
        self._connect_multiple_attempts()


def connect(settings):
    """Open an SSH connection described by ``settings``.

    The connection is made with an SSHClientWithHostname that silently
    accepts unknown host keys. ``settings.ssh_auth.connect(settings.ip, client)``
    performs the actual connection/authentication.

    :returns: ``contextlib.closing`` wrapper around the connected client, so
        callers can use it in a ``with`` statement.

    If connecting fails, the client (and any transport it already started) is
    closed before the exception propagates, so it does not leak.
    """
    client = SSHClientWithHostname()
    client.set_missing_host_key_policy(IgnoreMissingKeyPolicy())
    try:
        settings.ssh_auth.connect(settings.ip, client)
    except BaseException:
        client.close()
        raise
    return closing(client)
