from collections import defaultdict
from contextlib import contextmanager
from copy import copy
from dataclasses import dataclass, field, asdict
from abc import ABC, abstractmethod
from typing import Tuple, Iterable, Set, List, Union
from loguru import logger
from pyqtgraph.parametertree import Parameter
from tsuchinoko.graphs import Graph
from tsuchinoko.utils.mutex import RWLock
@dataclass
class Data:
"""
A data class to track the data state of an experiment. This type can be appended to as new data is received. A
locking mechanism is provided to support read/write locking for parallelism.
"""
dimensionality: int = None
positions: list = field(default_factory=list)
scores: list = field(default_factory=list)
variances: list = field(default_factory=list)
metrics: dict = field(default_factory=lambda: defaultdict(list))
states: dict = field(default_factory=dict)
graphics_items: dict = field(default_factory=dict)
@property
def measurements(self):
return list(zip(self.positions, self.scores, self.variances, [{key: values[i] for key, values in self.metrics.items()} for i in range(len(self))]))
def __post_init__(self):
self._lock = RWLock()
self.w_lock = self._lock.w_locked
self.r_lock = self._lock.r_locked
self._completed_iterations = 0
def inject_new(self, data):
with self.w_lock():
for datum in data:
self.positions.append(datum[0])
self.scores.append(datum[1])
self.variances.append(datum[2])
for metric in datum[3]: # TODO: handle logical cases
if metric not in self.metrics:
self.metrics[metric] = []
self.metrics[metric].append(datum[3][metric])
def as_dict(self):
self_copy = copy(self)
self_copy.metrics = dict(self_copy.metrics)
return asdict(self_copy)
def __getitem__(self, item: Union[slice, str]):
if isinstance(item, str):
if item in self.metrics and item in self.states:
raise ValueError(f'{item} exists in both states and metrics.')
elif item in self.metrics:
return self.metrics[item]
elif item in self.states:
return self.states[item]
elif item.lower() in ['variances', 'scores', 'positions']:
return getattr(self, item.lower())
elif isinstance(item, slice):
if item.stop is None or item.stop > len(self):
item = slice(item.start, len(self), item.step)
return Data(self.dimensionality,
self.positions[item],
self.scores[item],
self.variances[item],
{key: value[item] for key, value in self.metrics.items()},
self.states,
self.graphics_items)
raise ValueError(f'Unknown item: {item}')
def __setitem__(self, key, value):
if isinstance(key, str):
self.metrics[key] = value
else:
raise ValueError()
def __len__(self):
return len(self.positions)
def extend(self, data: 'Data'):
with self.w_lock():
self.positions += data.positions
self.scores += data.scores
self.variances += data.variances
for key in set(self.metrics) | set(data.metrics):
self.metrics[key] += data.metrics.get(key, [])
self.dimensionality = data.dimensionality
self.graphics_items.update(data.graphics_items)
self.states = data.states
def __enter__(self):
self.w_lock().__enter__()
logger.exception(RuntimeError("deprecation in progress"))
def __exit__(self, exc_type, exc_val, exc_tb):
self.w_lock().__exit__(exc_type, exc_val, exc_tb)
def __bool__(self):
return bool(len(self))
@contextmanager
def iteration(self):
yield
self._completed_iterations += 1
[docs]
class Engine(ABC):
"""
The Adaptive Engine base class. This component is generally to be responsible for determining future measurement targets.
"""
dimensionality: int = None
parameters: Parameter = None
graphs: List['Graph'] = None
last_position: tuple = None
[docs]
@abstractmethod
def update_measurements(self, data: Data):
"""
Update internal variables with the provided new data
"""
...
[docs]
@abstractmethod
def request_targets(self, position: Tuple) -> Iterable[Tuple]:
"""
Determine new targets to be measured
Parameters
----------
position: tuple
The current 'position' of the experiment in the target domain.
Returns
-------
targets: array_like
The new targets to be measured.
"""
...
[docs]
@abstractmethod
def reset(self):
"""
Called when an experiment stops, or is about to start. Returns the engine to a clean state.
"""
...
[docs]
@abstractmethod
def train(self):
"""
Perform training. This can be short-circuited to only train on every N-th iteration, for example.
"""
...
[docs]
@abstractmethod
def update_metrics(self, data: Data):
"""
Calculates various metrics to drive visualizations for the client. The data object is expected to be mutated to
include these new values.
"""
...