Source code for tsuchinoko.core

import threading
import time
from asyncio import events
from enum import Enum, auto
from queue import Empty, Queue

from loguru import logger

from .messages import FullDataRequest, FullDataResponse, PartialDataRequest, PartialDataResponse, StartRequest, \
    UnknownResponse, PauseRequest, StateRequest, GetParametersRequest, SetParameterRequest, GetParametersResponse, \
    SetParameterResponse, StopRequest, StateResponse, MeasureRequest, \
    MeasureResponse, ConnectRequest, ConnectResponse, ExceptionResponse, PushDataRequest, PushDataResponse, \
    GraphsResponse, ReplayResponse
from ..adaptive import Engine as AdaptiveEngine, Data
from ..execution import Engine as ExecutionEngine
from ..utils.logging import log_time


class CoreState(Enum):
    """Lifecycle states of a ``Core``.

    The "-ing" states (Starting, Pausing, Resuming, Stopping) are transitional:
    ``Core._main`` moves each of them to a settled state on its next pass.
    Member order matters, because ``auto()`` numbers the members 1..10 in
    declaration order.
    """

    # Not assigned anywhere in this module's visible code.
    Connecting = auto()
    # Initial state of a Core; also where Stopping settles.
    Inactive = auto()
    # _main creates the experiment thread here, then switches to Running.
    Starting = auto()
    # The experiment thread executes iterations only in this state.
    Running = auto()
    # Set on pause requests or iteration errors; _main settles it to Paused.
    Pausing = auto()
    # The experiment thread idles (sleeps) while paused.
    Paused = auto()
    # _main settles this to Running.
    Resuming = auto()
    # The experiment thread returns; _main settles to Inactive and resets data.
    Stopping = auto()
    # Not handled by the visible state machine.
    Restarting = auto()
    # Ends both _main and the experiment thread.
    Exiting = auto()


SLEEP_FOR_FRESH_DATA_TIME = .1


class Core:
    """Coordinate an execution engine and an adaptive engine through a state machine.

    Two loops cooperate:

    * ``_main`` is an asyncio coroutine, run by ``main``. It advances the
      ``CoreState`` machine and calls ``notify_clients`` between steps.
    * ``experiment_loop`` runs in a worker thread that is started on
      ``CoreState.Starting``. It calls ``experiment_iteration`` repeatedly
      while the state is ``Running``.

    Positions and measurements can be injected through
    ``_forced_position_queue`` and ``_forced_measurement_queue``. A queued item
    takes precedence over the engines in the iteration that consumes it.
    """

    def __init__(self, execution_engine: ExecutionEngine = None, adaptive_engine: AdaptiveEngine = None):
        self.execution_engine = execution_engine
        self.adaptive_engine = adaptive_engine
        self.iteration = 0
        self._state = CoreState.Inactive
        # Exceptions raised by experiment iterations (filled by experiment_loop).
        self._exception_queue = Queue()
        # Externally forced positions/measurements that override the engines.
        self._forced_position_queue = Queue()
        self._forced_measurement_queue = Queue()
        # True once new measurements have arrived. It gates requesting new
        # targets and training, and is cleared after targets are sent.
        self._has_fresh_data = True
        self.data = Data()
        self._graphs = []
        self.experiment_thread = None

    @property
    def state(self):
        """Current ``CoreState``; assignments are logged."""
        return self._state

    @state.setter
    def state(self, value):
        logger.info(f'Changing core state to {value}')
        self._state = value

    def set_execution_engine(self, engine: ExecutionEngine):
        """Replace the execution engine."""
        self.execution_engine = engine

    def set_adaptive_engine(self, engine: AdaptiveEngine):
        """Replace the adaptive engine."""
        self.adaptive_engine = engine

    def main(self, debug=False):
        """Run ``_main`` to completion on a brand-new event loop and return its result.

        The new loop is installed as the current loop, replacing any existing
        one. On exit, async generators are shut down, the current loop is unset
        and the loop is closed. Unlike ``asyncio.run``, pending tasks are not
        cancelled first.
        """
        loop = events.new_event_loop()  # deliberately replaces the current loop
        try:
            events.set_event_loop(loop)
            loop.set_debug(debug)
            return loop.run_until_complete(self._main())
        finally:
            try:
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                events.set_event_loop(None)
                loop.close()

    async def _main(self, min_response_sleep=.1):
        """Advance the state machine until the state becomes ``Exiting``.

        Each transitional state is moved to its settled counterpart.
        ``notify_clients`` is awaited after each pass unless the state is
        Stopping, Exiting, Resuming or Restarting.

        NOTE(review): ``min_response_sleep`` is currently unused. The only
        await in this loop is ``notify_clients``, and the base implementation
        returns immediately, so a plain ``Core`` spins without yielding.
        ``Restarting`` is neither handled here nor followed by
        ``notify_clients``; if it were ever set, this loop would spin.
        """
        while self.state != CoreState.Exiting:
            if self.state == CoreState.Running:
                pass
            elif self.state == CoreState.Starting:
                # Only build fresh data and reset the engine if no data was
                # provided beforehand (e.g. pushed or initialized).
                if not len(self.data):
                    self.data = Data(dimensionality=self.adaptive_engine.dimensionality)
                    self.adaptive_engine.reset()
                # Keep a reference so the thread can be joined later.
                self.experiment_thread = threading.Thread(target=self.experiment_loop, args=())
                self.experiment_thread.start()
                self.state = CoreState.Running
            elif self.state == CoreState.Inactive:
                pass
            elif self.state == CoreState.Paused:
                pass
            elif self.state == CoreState.Pausing:
                self.state = CoreState.Paused
            elif self.state == CoreState.Resuming:
                self.state = CoreState.Running
            elif self.state == CoreState.Stopping:
                self.state = CoreState.Inactive
                self.data = Data()

            if self.state not in [CoreState.Stopping, CoreState.Exiting, CoreState.Resuming, CoreState.Restarting]:
                await self.notify_clients()

    def experiment_loop(self):
        """Worker-thread body: iterate while Running, return on Stopping/Inactive/Exiting.

        An exception from an iteration is queued on ``_exception_queue`` and
        logged, and the state moves to ``Pausing``. The thread itself keeps
        running. In any other state the thread sleeps for 0.1 s and re-checks.
        """
        while True:
            if self.state == CoreState.Running:
                logger.info(f'Iteration: {self.data._completed_iterations}, Data count: {len(self.data)}')
                try:
                    self.experiment_iteration()
                except Exception as ex:
                    self._exception_queue.put(ex)
                    self.state = CoreState.Pausing
                    logger.exception(ex)
            elif self.state in [CoreState.Stopping, CoreState.Inactive, CoreState.Exiting]:
                return
            else:
                time.sleep(.1)

    @staticmethod
    def _pop_forced(forced_queue):
        """Return ``(True, item)`` for the next queued item, or ``(False, None)`` if empty.

        This uses ``get_nowait``, so the experiment thread can never block. A
        separate ``empty()`` check followed by a blocking ``get()`` could hang
        if another thread drained the queue in between.
        """
        try:
            return True, forced_queue.get_nowait()
        except Empty:
            return False, None

    def experiment_iteration(self):
        """Run one measure/update/train cycle inside ``self.data.iteration()``.

        Steps:

        1. With fresh data, read the current position and choose targets. A
           forced position is used if one is queued; otherwise the adaptive
           engine is asked.
        2. Obtain measurements. A forced measurement is used if one is queued.
           Otherwise any new targets are sent to the execution engine and its
           measurements are polled.
        3. Inject new measurements and update the adaptive engine. If there are
           none, sleep ``SLEEP_FOR_FRESH_DATA_TIME`` seconds.
        4. Train only when fresh data exists.
        """
        with self.data.iteration():
            if self._has_fresh_data:
                with log_time('getting position', cumulative_key='getting position'):
                    position = self.execution_engine.get_position()
                    if position is None:
                        position = [0] * self.data.dimensionality
                    position = tuple(position)

                found_position, forced_position = self._pop_forced(self._forced_position_queue)
                if found_position:
                    targets = [forced_position]
                else:
                    with log_time('getting targets', cumulative_key='getting targets'):
                        targets = self.adaptive_engine.request_targets(position)

            found_measurement, forced_measurement = self._pop_forced(self._forced_measurement_queue)
            if found_measurement:
                new_measurements = [forced_measurement]
            else:
                if self._has_fresh_data:
                    # `targets` was bound above: _has_fresh_data is unchanged since then.
                    with log_time('updating targets', cumulative_key='updating targets'):
                        self.execution_engine.update_targets(targets)
                    self._has_fresh_data = False

                with log_time('getting measurements', cumulative_key='getting measurements'):
                    new_measurements = self.execution_engine.get_measurements()

            if len(new_measurements):
                self._has_fresh_data = True
                with log_time('stashing new measurements', cumulative_key='injecting new measurements'):
                    self.data.inject_new(new_measurements)
                with log_time('updating engine with new measurements', cumulative_key='updating engine with new measurements'):
                    self.adaptive_engine.update_measurements(self.data)
                with log_time('updating metrics', cumulative_key='updating metrics'):
                    self.adaptive_engine.update_metrics(self.data)
            else:
                time.sleep(SLEEP_FOR_FRESH_DATA_TIME)

            if self._has_fresh_data:
                with log_time('training', cumulative_key='training'):
                    self.adaptive_engine.train()
            else:
                logger.info('Current data is stale. Waiting for an update with fresh data.')

    async def notify_clients(self):
        """Hook awaited by ``_main`` between state-machine steps; a no-op here."""

    @property
    def graphs(self):
        """Return a new list of graphs: execution engine's, then adaptive engine's, then the core's own.

        Engines without a ``graphs`` attribute, or with a falsy one, contribute
        nothing. Each source is passed through ``list()`` so that tuple-valued
        ``graphs`` attributes also concatenate.
        """
        execution_graphs = getattr(self.execution_engine, 'graphs', []) or []
        adaptive_graphs = getattr(self.adaptive_engine, 'graphs', []) or []
        return list(execution_graphs) + list(adaptive_graphs) + list(self._graphs)

    @graphs.setter
    def graphs(self, graphs):
        raise NotImplementedError('Updating graphs on server not supported yet.')

    def update_graph(self, new_graph):
        """Replace the first graph whose ``id`` equals ``new_graph.id``.

        Searches the execution engine's graphs, then the adaptive engine's,
        then the core's own, and assigns into the matching list in place.

        Raises:
            ValueError: if no graph with a matching ``id`` exists.
        """
        execution_graphs = getattr(self.execution_engine, 'graphs', []) or []
        adaptive_graphs = getattr(self.adaptive_engine, 'graphs', []) or []
        for graph_list in (execution_graphs, adaptive_graphs, self._graphs):
            for i, old_graph in enumerate(graph_list):
                if old_graph.id == new_graph.id:
                    # In-place assignment only persists if the engine's `graphs`
                    # attribute is the list it actually stores — TODO confirm.
                    graph_list[i] = new_graph
                    return
        raise ValueError('Graph not found in graphs lists.')

    def initialize_data(self, x, y, v):
        """Replace ``self.data`` with the given positions/scores/variances and feed it to the adaptive engine.

        Dimensionality is taken from ``len(x[0])``, so ``x`` must be non-empty.
        """
        with log_time('updating engine with initial measurements'):
            self.data = Data(dimensionality=len(x[0]), positions=x, scores=y, variances=v)
            self.adaptive_engine.update_measurements(self.data)
class ZMQCore(Core):
    """A ``Core`` that serves client requests over a ZeroMQ REP socket.

    Each received request is dispatched by its class name to a
    ``respond_<ClassName>`` method. Unknown request types receive an
    ``UnknownResponse``. The server is started lazily on the first
    ``notify_clients`` call.

    SECURITY: requests are unpickled (``recv_pyobj``) from a socket bound to
    all interfaces by default. Any peer that can reach the port can execute
    arbitrary code, so expose it only to trusted networks.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.context = None
        self.poller = None
        self.socket = None

    def start_server(self, address="tcp://*:5555"):
        """Create the ZMQ context, bind a REP socket to ``address`` and register it for polling."""
        import zmq
        from zmq.asyncio import Context, Poller

        self.poller = Poller()
        self.context = Context()
        self.socket = self.context.socket(zmq.REP)
        self.socket.bind(address)
        self.poller.register(self.socket, zmq.POLLIN)

    def respond_FullDataRequest(self, request):
        """Return all data, read under the data's read lock."""
        data = self.data  # lock and read the same object
        with data.r_lock():
            return FullDataResponse(data.as_dict())

    def respond_PartialDataRequest(self, request):
        """Return data from ``request.iteration`` onward while Running; otherwise the current state."""
        data = self.data  # lock and slice the same object
        if data and request.iteration <= len(data) and self.state == CoreState.Running:
            with data.r_lock():
                partial_data = data[request.iteration:]
            return PartialDataResponse(partial_data.as_dict(), request.iteration)
        return StateResponse(self.state)

    def respond_PushDataRequest(self, request):
        """Replace the core's data with ``Data(**request.data)``."""
        self.data = Data(**request.data)
        return PushDataResponse()

    def respond_StartRequest(self, request):
        """Resume from Paused, or start from Inactive; other states are left unchanged."""
        if self.state == CoreState.Paused:
            self.state = CoreState.Resuming
        elif self.state == CoreState.Inactive:
            self.state = CoreState.Starting
        return StateResponse(self.state)

    def respond_StopRequest(self, request):
        """Request Stopping and wait for the experiment thread, if any, to return.

        The join runs on the event-loop thread, so request handling pauses
        until the in-flight iteration finishes.
        """
        self.state = CoreState.Stopping
        # The thread is only created on Starting; a Stop sent before that must not crash.
        if self.experiment_thread is not None:
            self.experiment_thread.join()
        return StateResponse(self.state)

    def respond_ExitRequest(self, request):
        """Request Exiting, which ends both the state machine and the experiment thread."""
        self.state = CoreState.Exiting
        return StateResponse(self.state)

    def respond_PauseRequest(self, request):
        """Request Pausing."""
        self.state = CoreState.Pausing
        return StateResponse(self.state)

    def respond_StateRequest(self, request):
        """Report one pending experiment exception if there is one, otherwise the current state."""
        try:
            return ExceptionResponse(self._exception_queue.get_nowait())
        except Empty:
            return StateResponse(self.state)

    def respond_GetParametersRequest(self, request):
        """Return the saved state of the adaptive engine's parameters."""
        return GetParametersResponse(self.adaptive_engine.parameters.saveState())

    def respond_SetParameterRequest(self, request):
        """Set the parameter at ``request.child_path`` to ``request.value``."""
        self.adaptive_engine.parameters.child(*request.child_path).setValue(request.value)
        return SetParameterResponse(True)

    def respond_MeasureRequest(self, request):
        """Queue ``request.position`` to be used instead of the adaptive engine's next targets."""
        self._forced_position_queue.put(request.position)
        return MeasureResponse(True)

    def respond_ConnectRequest(self, request):
        """Acknowledge a client connection with the current state."""
        return ConnectResponse(self.state)

    def respond_PullGraphsRequest(self, request):
        """Return all graphs known to the core and its engines."""
        return GraphsResponse(self.graphs)

    def respond_PushGraphsRequest(self, request):
        """Replace graphs by id.

        Stops at the first unknown id; graphs earlier in the request have
        already been applied by then.
        """
        for graph in request.graphs:
            try:
                self.update_graph(graph)
            except ValueError:
                return ExceptionResponse("Graph ID not found in server's graphs.")
        return StateResponse(self.state)

    @staticmethod
    def _drain(forced_queue):
        """Discard every queued item using only the thread-safe ``Queue`` API."""
        while True:
            try:
                forced_queue.get_nowait()
            except Empty:
                return

    def respond_ReplayRequest(self, request):
        """Replace the forced queues with ``request.positions`` and ``request.measurements``."""
        # Drain through the locked Queue API instead of clearing `.queue`
        # directly: the experiment thread consumes these queues concurrently.
        self._drain(self._forced_measurement_queue)
        self._drain(self._forced_position_queue)
        for position in request.positions:
            self._forced_position_queue.put(position)
        for measurement in request.measurements:
            self._forced_measurement_queue.put(measurement)
        logger.info(f'Queue lengths: {self._forced_measurement_queue.qsize()} {self._forced_position_queue.qsize()}')
        return ReplayResponse(True)

    def _dispatch(self, request):
        """Route ``request`` to ``respond_<ClassName>``; unknown types give ``UnknownResponse``.

        An exception raised by a responder is logged and returned to the client
        as an ``ExceptionResponse``.
        """
        responder = getattr(self, f'respond_{request.__class__.__name__}', None)
        if responder is None:
            return UnknownResponse()
        try:
            return responder(request)
        except Exception as ex:
            logger.exception(ex)
            return ExceptionResponse(ex)

    async def notify_clients(self):
        """Answer any pending client requests, then yield to the event loop for 0.1 s."""
        import asyncio

        import zmq

        if not self.poller:
            self.start_server()

        # pyzmq poll timeouts are in milliseconds, so this is effectively a
        # non-blocking check. The sleep at the end throttles _main instead.
        sockets = dict(await self.poller.poll(timeout=.1))

        for socket in sockets:
            try:
                # SECURITY: unpickles untrusted network input (see class docstring).
                request = await socket.recv_pyobj(zmq.NOBLOCK)
            except zmq.ZMQError as ex:
                logger.exception(ex)
                continue

            # A REP socket must send a reply for every message it receives before
            # it can receive again, so empty/unknown requests are answered too.
            logger.info(f"Received request: {request}")
            with log_time('preparing response', cumulative_key='preparing response'):
                response = self._dispatch(request)
            logger.info(f'Sending response: {response}')
            await socket.send_pyobj(response)

            if isinstance(response, UnknownResponse):
                logger.error(f'Unknown request received: {request}')

        # Yield to the event loop rather than blocking it with time.sleep.
        await asyncio.sleep(.1)

    def exit_later(self):
        """Request Exiting without waiting for the experiment thread."""
        self.state = CoreState.Exiting

    def exit(self):
        """Request Exiting and wait for the experiment thread, if one was started."""
        self.exit_later()
        if self.experiment_thread is not None:
            self.experiment_thread.join()