Source code for tsuchinoko.execution.bluesky_adaptive

import pickle
import time
from typing import Tuple, List

import numpy as np
import zmq
from loguru import logger

from . import Engine
from ..adaptive.gpCAM_in_process import GPCAMInProcessEngine

SLEEP_FOR_AGENT_TIME = .1
SLEEP_FOR_TSUCHINOKO_TIME = .1
FORCE_KICKSTART_TIME = 5


[docs] class BlueskyAdaptiveEngine(Engine): """ A `tsuchinoko.execution.Engine` that sends targets to Blueskly-Adaptive and receives back measured data. """ suggest_blacklist = ["x_data", "y_data", "noise_variances"] # keys with ragged state def __init__(self, adaptive_engine:GPCAMInProcessEngine, host: str = '127.0.0.1', port: int = 5557): """ Parameters ---------- host A host address target for the zmq socket. port The port used for the zmq socket. """ super(BlueskyAdaptiveEngine, self).__init__() self.adaptive_engine = adaptive_engine self.position = None self.context = None self.socket = None self.host = host self.port = port self.setup_socket() self._last_targets_sent = None # Lock sending new points until at least one from the previous list is measured self.has_fresh_points_on_server = False def setup_socket(self): self.context = zmq.Context() self.socket = self.context.socket(zmq.PAIR) # Attempt to bind, retry every second if fails while True: try: self.socket.bind(f"tcp://{self.host}:{self.port}") except zmq.ZMQError as ex: logger.info(f'Unable to bind to tcp://{self.host}:{self.port}. Retrying in 1 second...') logger.exception(ex) time.sleep(1) else: logger.info(f'Bound to tcp://{self.host}:{self.port}.') break
[docs] def update_targets(self, targets: List[Tuple]): if self.has_fresh_points_on_server: time.sleep(SLEEP_FOR_AGENT_TIME) # chill if the Agent hasn't measured any points from the previous list else: # checkpoint optimizer state gpcam_state = self.adaptive_engine.optimizer.__getstate__() sanitized_gpcam_state = dict( **{key if key not in self.suggest_blacklist else f"STATEDICT-{key}": np.asarray(val) for key, val in gpcam_state.items() if key in self.suggest_blacklist}) # send targets to TsuchinokoAgent self.has_fresh_points_on_server = self.send_payload({'candidate': targets, 'optimizer': sanitized_gpcam_state}) self._last_targets_sent = targets
[docs] def get_measurements(self) -> List[Tuple]: new_measurements = [] # get newly completed measurements from bluesky-adaptive; repeat until buffered payloads are exhausted while True: try: payload = self.recv_payload(flags=zmq.NOBLOCK) except zmq.ZMQError: break else: assert 'target_measured' in payload x, (y, v) = payload['target_measured'] new_measurements.append((x, y, v, {})) # stash the last position measured as the 'current' position of the instrument self.position = x if new_measurements: self.has_fresh_points_on_server = False return new_measurements
[docs] def get_position(self) -> Tuple: # return last measurement position received from bluesky-adaptive return self.position
def send_payload(self, payload: dict): logger.info(f'message: {payload}') try: self.socket.send(pickle.dumps(payload), flags=zmq.NOBLOCK) except zmq.error.Again: return False return True def recv_payload(self, flags=0) -> dict: payload_response = pickle.loads(self.socket.recv(flags=flags)) logger.info(f'response: {payload_response}') # if the returned message is the kickstart message, resend the last targets sent and check for more payloads if payload_response == {'send_targets': True}: self.has_fresh_points_on_server = False self.update_targets(self._last_targets_sent) payload_response = self.recv_payload(flags) return payload_response