from parallels.core import messages
import logging
import threading
import Queue 
from contextlib import contextmanager

from parallels.core.workflow.runner.base import Runner
from parallels.core.actions.base.action_pointer import ActionPointer
from parallels.core.actions.base.common_action import CommonAction
from parallels.core.actions.base.condition_action_pointer import ConditionActionPointer
from parallels.core.actions.base.subscription_action import SubscriptionAction
from parallels.core.actions.base.compound_action import CompoundAction
from parallels.core.migrator_config import MultithreadingStatus
from parallels.core.utils.migrator_utils import trace_step

logger = logging.getLogger(__name__)


class ActionRunnerByLayer(Runner):
	"""Executes workflow actions and entry points by layer

	"By layer" means that:
	the 1st action is executed for each subscription
	the 2nd action is executed for each subscription
	and so on

	Considers their sequence, handles exceptions so non-critical failures do
	not break migration
	"""

	def __init__(self, multithreading=None):
		"""Create a layer-by-layer runner

		:param multithreading: multithreading settings forwarded to the base runner
		:type multithreading: parallels.core.migrator_config.MultithreadingParams | None
		"""
		super(ActionRunnerByLayer, self).__init__(multithreading)
		# '/'-joined paths of the actions that run() has finished without an
		# exception; consulted by _shutdown to pick shutdown actions
		self._executed_actions = set()

	def run(self, action, action_id=None, path=None):
		"""Execute a single workflow action and remember it as executed

		Action pointers are resolved first, then the action is dispatched by
		its kind: compound, per-subscription or common.

		:param action: action object, ActionPointer or ConditionActionPointer
		:param action_id: identifier of the action, used for tracing and in the
			error message for unsupported action kinds
		:param path: list of action ids leading to this action; once the action
			finishes, the ids joined by '/' are stored in self._executed_actions
		:raises Exception: when the resolved action is of an unsupported kind
		"""
		current_path = [] if path is None else path

		resolved = action
		if isinstance(resolved, ActionPointer):
			resolved = resolved.resolve()
		elif isinstance(resolved, ConditionActionPointer):
			# condition pointers need the runner context to choose the action
			resolved = resolved.resolve(self._context)

		if isinstance(resolved, CompoundAction):
			self._run_compound_action(resolved, action_id, current_path)
		elif isinstance(resolved, SubscriptionAction):
			self._run_subscription_action(action_id, resolved)
		elif isinstance(resolved, CommonAction):
			self._run_action(action_id, resolved)
		else:
			error_text = messages.INVALID_ACTION_S_S_IS_NOT % (
				action_id, resolved.__class__.__name__
			)
			raise Exception(error_text)

		executed_path = '/'.join(current_path)
		self._executed_actions.add(executed_path)

	def _run_compound_action(self, action, action_id, path):
		"""Run each child of a compound action through run()

		Children are taken in the order returned by action.get_all_actions(),
		and each one gets its own id appended to the parent's path. The whole
		group is wrapped in _profile_action (defined in the base runner) and
		_trace_action.
		"""
		with self._profile_action(action_id, action):
			with self._trace_action(action_id, action):
				for sub_id, sub_action in action.get_all_actions():
					sub_path = path + [sub_id]
					self.run(sub_action, sub_id, sub_path)

	def _run_action(self, action_id, action):
		"""Run an action once, passing it the runner's context

		The call is wrapped in _profile_action (defined in the base runner)
		and _trace_action.
		"""
		profiling = self._profile_action(action_id, action)
		with profiling:
			tracing = self._trace_action(action_id, action)
			with tracing:
				action.run(self._context)

	def _run_subscription_action(self, action_id, action):
		"""Run specified action on all migrated subscriptions

		Subscriptions are taken from the context and filtered with
		action.filter_subscription. Every selected subscription is processed by
		_run_single_subscription (safe execution with reruns). Depending on the
		configured multithreading status and the action's multithreading
		properties, subscriptions are processed one by one or by a pool of
		worker threads.

		:type action: parallels.core.actions.base.subscription_action.SubscriptionAction
		:raises Exception: the first exception that escaped
			_run_single_subscription; in threaded mode no new subscriptions are
			started once such an error is recorded
		"""
		with self._profile_action(action_id, action):
			logger.debug(messages.RUN_SUBSCRIPTION_ACTION_CLASS_S, action.__class__.__name__)
			logger.debug(action.get_description())
			# A real list is needed: it is measured with len() and iterated or
			# queued afterwards (filter() is only a list on Python 2)
			subscriptions = [
				subscription for subscription in self._context.iter_all_subscriptions()
				if action.filter_subscription(self._context, subscription)
			]

			if not subscriptions:
				return

			with self._trace_action(action_id, action):
				if action.get_logging_properties().info_log:
					logging_function = logger.info
				else:
					logging_function = logger.debug

				if self._is_multithreading_enabled_for_action(action):
					self._run_subscriptions_in_threads(action, subscriptions, logging_function)
				else:
					for n, subscription in enumerate(subscriptions, start=1):
						logging_function(
							messages.LOG_PROCESS_SUBSCRIPTION,
							subscription.name, n, len(subscriptions)
						)
						self._run_single_subscription(action, subscription)

	def _is_multithreading_enabled_for_action(self, action):
		"""Tell whether the action's subscriptions should be processed by threads

		Combines the runner-wide multithreading status with the action's own
		multithreading properties.

		:raises AssertionError: on an unknown multithreading status
		"""
		action_threading_props = action.get_multithreading_properties()
		status = self._multithreading.status

		if status == MultithreadingStatus.DISABLED:
			return False
		elif status == MultithreadingStatus.DEFAULT:
			return action_threading_props.use_threads_by_default
		elif status == MultithreadingStatus.FULL:
			return action_threading_props.can_use_threads
		else:
			# Explicit raise instead of "assert False": asserts are stripped under
			# "python -O", which would leave the result undefined
			raise AssertionError('Unknown multithreading status: %s' % (status,))

	def _run_subscriptions_in_threads(self, action, subscriptions, logging_function):
		"""Process the action's subscriptions with a pool of worker threads

		All subscriptions are queued before the workers start, so a worker
		finishes as soon as the queue is empty. An exception escaping
		_run_single_subscription is recorded, and no further subscriptions are
		started. Once all workers have finished, the first recorded exception is
		re-raised in the calling thread.
		"""
		queue = Queue.Queue()
		for subscription in subscriptions:
			queue.put(subscription)

		# Thread-safe FIFO of exceptions raised inside worker threads
		errors = Queue.Queue()

		def worker():
			while errors.empty():
				try:
					subscription = queue.get(False)
				except Queue.Empty:
					return  # no more subscriptions in a queue

				try:
					logging_function(
						messages.START_PROCESSING_SUBSCRIPTION_S,
						subscription.name
					)
					self._run_single_subscription(action, subscription)
					logging_function(
						messages.FINISH_PROCESSING_SUBSCRIPTION_S,
						subscription.name
					)
				except Exception as e:
					# Log the traceback here: on Python 2 the re-raise in the
					# calling thread does not carry the worker's traceback
					logger.debug(
						'Exception in worker thread while processing subscription %s',
						subscription.name, exc_info=True
					)
					errors.put(e)

		# No more workers than subscriptions, and always at least one worker;
		# with zero workers the subscriptions would never be processed
		num_workers = max(1, min(self._multithreading.num_workers, len(subscriptions)))
		threads = [threading.Thread(target=worker) for _ in range(num_workers)]
		for thread in threads:
			thread.start()
		# Join the threads rather than the queue: queue.join() hangs forever if a
		# worker dies before marking its item as done
		for thread in threads:
			thread.join()

		if not errors.empty():
			raise errors.get()

	def _run_single_subscription(self, action, subscription):
		"""Run the action for one subscription through the context's safe wrapper

		Delegates to safe.try_subscription_with_rerun. It receives a callable
		that runs the action, the subscription name, the action's failure
		message, a "repeat the operation" message built from that failure
		message, and the action's repeat count.
		"""
		safe = self._context.safe

		# Compute the failure message once; it serves both as the error message
		# and as the argument of the repeat message
		failure_message = action.get_failure_message(self._context, subscription)

		safe.try_subscription_with_rerun(
			lambda: action.run(self._context, subscription),
			subscription.name,
			error_message=failure_message,
			repeat_error=messages.S_TRY_REPEAT_OPERATION_ONCE_MORE_1 % (failure_message,),
			repeat_count=action.get_repeat_count()
		)

	def _shutdown(self, entry_point):
		"""Run shutdown actions that correspond to what has been executed

		entry_point.get_shutdown_actions() is used here as a mapping. Its keys
		are executed-action paths, or None for actions that should always run.
		Its values are collections of shutdown action paths. Shutdown paths
		that were already executed are skipped.

		NOTE(review): candidates are gathered in a set, so the order in which
		shutdown actions run is unspecified; confirm the order does not matter.
		NOTE(review): run() is called here without action_id/path, so every
		shutdown action is recorded in self._executed_actions as ''.
		"""
		shutdown_map = entry_point.get_shutdown_actions()
		pending_paths = set()

		# Shutdown actions tied to set-up actions that actually ran
		for executed_path in self._executed_actions:
			if executed_path in shutdown_map:
				pending_paths.update(shutdown_map[executed_path])

		# Shutdown actions registered for the whole entry point
		if None in shutdown_map:
			pending_paths.update(shutdown_map[None])

		not_yet_run = pending_paths - self._executed_actions
		for shutdown_path in not_yet_run:
			shutdown_action = entry_point.get_path(shutdown_path)
			self.run(shutdown_action)

	@contextmanager
	def _trace_action(self, action_id, action):
		"""Wrap the enclosed code into a trace_step describing the action

		Actions without a description (get_description() returns None) are not
		traced, and the enclosed code runs as is. Otherwise the step is logged at
		'info' or 'debug' level depending on the action's info_log logging
		property, and is marked compound according to its compound property.
		"""
		if action.get_description() is None:
			yield
			return

		log_level = 'info' if action.get_logging_properties().info_log else 'debug'
		with trace_step(
			action_id,
			action.get_description(),
			log_level=log_level,
			compound=action.get_logging_properties().compound
		):
			yield

	def _get_subscription_action_queue(self, action, action_id, path):
		"""Get plain action list out of CompoundAction or SubscriptionAction"""

		actions = []
		if isinstance(action, CompoundAction):
			for child_action_id, child_action in action.get_all_actions():
				actions.extend(
					self._get_subscription_action_queue(
						child_action, child_action_id, path + [child_action_id]
					)
				)
			return actions
		elif isinstance(action, SubscriptionAction):
			return [(action, action_id, path)]
		elif isinstance(action, CommonAction):
			raise Exception(
				messages.CAN_NOT_RUN_COMMON_ACTION_IN)