Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def queue_elements(bec_client_mock):
client = bec_client_mock
request_msg = messages.ScanQueueMessage(
scan_type="grid_scan",
parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}},
parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}},
queue="primary",
metadata={"RID": "something"},
)
Expand Down Expand Up @@ -52,7 +52,7 @@ def queue_elements(bec_client_mock):
def sample_request_msg():
return messages.ScanQueueMessage(
scan_type="grid_scan",
parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}},
parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}},
queue="primary",
metadata={"RID": "something"},
)
Expand Down Expand Up @@ -232,7 +232,7 @@ def test_available_req_blocks_multiple_blocks(bec_client_mock):

request_msg = messages.ScanQueueMessage(
scan_type="grid_scan",
parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}},
parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}},
queue="primary",
metadata={"RID": "test_rid"},
)
Expand Down
20 changes: 11 additions & 9 deletions bec_ipython_client/tests/client_tests/test_live_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def client_with_grid_scan(bec_client_mock):
client = bec_client_mock
request_msg = messages.ScanQueueMessage(
scan_type="grid_scan",
parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}},
parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}},
queue="primary",
metadata={"RID": "something"},
)
Expand Down Expand Up @@ -88,7 +88,7 @@ def test_sort_devices(self):
(
messages.ScanQueueMessage(
scan_type="grid_scan",
parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}},
parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}},
queue="primary",
metadata={"RID": "something"},
),
Expand Down Expand Up @@ -134,7 +134,7 @@ def test_wait_for_request_acceptance(self, client_with_grid_scan):
def test_run_update(self, bec_client_mock, scan_item):
request_msg = messages.ScanQueueMessage(
scan_type="grid_scan",
parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}},
parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}},
queue="primary",
metadata={"RID": "something"},
)
Expand All @@ -161,7 +161,7 @@ def test_run_update(self, bec_client_mock, scan_item):
def test_run_update_without_monitored_devices(self, bec_client_mock, scan_item):
request_msg = messages.ScanQueueMessage(
scan_type="grid_scan",
parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}},
parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}},
queue="primary",
metadata={"RID": "something"},
)
Expand Down Expand Up @@ -303,8 +303,10 @@ def test_print_table_data_hinted_value_with_precision(
@pytest.mark.parametrize(
"value,expected",
[
(np.int32(1), "1.00"),
(np.float64(1.00000), "1.00"),
# Commented out cases are not supported in unstructured serialized data, because msgpack doesn't distinguish
# lists, tuples, or sets. To support this, ScanMessage must be refactored to support the type information directly
# (np.int32(1), "1.00"),
# (np.float64(1.00000), "1.00"),
(0, "0.00"),
(1, "1.00"),
(0.000, "0.00"),
Expand All @@ -314,10 +316,10 @@ def test_print_table_data_hinted_value_with_precision(
("False", "False"),
("0", "0"),
("1", "1"),
((0, 1), "(0, 1)"),
# ((0, 1), "(0, 1)"),
({"value": 0}, "{'value': 0}"),
(np.array([0, 1]), "[0 1]"),
({1, 2}, "{1, 2}"),
# (np.array([0, 1]), "[0 1]"),
# ({1, 2}, "{1, 2}"),
],
)
def test_print_table_data_variants(self, client_with_grid_scan, value, expected):
Expand Down
2 changes: 1 addition & 1 deletion bec_ipython_client/tests/end-2-end/test_scans_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,7 @@ def test_scan_repeat_decorator(bec_ipython_client_fixture):
"update_frequency": 400,
},
"readoutPriority": "baseline",
"deviceTags": {"user motors"},
"deviceTags": ["user motors"],
"enabled": True,
"readOnly": False,
}
Expand Down
8 changes: 4 additions & 4 deletions bec_ipython_client/tests/end-2-end/test_scans_lib_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,13 +563,13 @@ def test_image_analysis(bec_client_lib):
dev.eiger.sim.select_model("gaussian")
dev.eiger.sim.params = {
"amplitude": 100,
"center_offset": np.array([0, 0]),
"covariance": np.array([[1, 0], [0, 1]]),
"center_offset": [0, 0],
"covariance": [[1, 0], [0, 1]],
"noise": "uniform",
"noise_multiplier": 10,
"hot_pixel_coords": np.array([[24, 24], [50, 20], [4, 40]]),
"hot_pixel_coords": [[24, 24], [50, 20], [4, 40]],
"hot_pixel_types": ["fluctuating", "constant", "fluctuating"],
"hot_pixel_values": np.array([1000.0, 10000.0, 1000.0]),
"hot_pixel_values": [1000.0, 10000.0, 1000.0],
}

res = scans.line_scan(dev.samx, -5, 5, steps=10, relative=False, exp_time=0)
Expand Down
16 changes: 12 additions & 4 deletions bec_lib/bec_lib/atlas_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,17 @@
import hashlib
import json
from enum import Enum
from typing import AbstractSet, Any, Literal, TypeVar

from pydantic import BaseModel, Field, PrivateAttr, create_model, field_validator, model_validator
from typing import AbstractSet, Annotated, Any, Literal, TypeVar

from pydantic import (
BaseModel,
Field,
PlainSerializer,
PrivateAttr,
create_model,
field_validator,
model_validator,
)
from pydantic_core import PydanticUndefined

from bec_lib.utils.json_extended import ExtendedEncoder
Expand Down Expand Up @@ -42,7 +50,7 @@ class _DeviceModelCore(BaseModel):
deviceConfig: dict | None = None
connectionTimeout: float = 5.0
description: str = ""
deviceTags: set[str] = set()
deviceTags: Annotated[set[str], Field(default_factory=set), PlainSerializer(list)]
needs: list[str] = []
onFailure: Literal["buffer", "retry", "raise"] = "retry"
readOnly: bool = False
Expand Down
33 changes: 33 additions & 0 deletions bec_lib/bec_lib/bec_serializable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import numpy as np
from pydantic import BaseModel, ConfigDict, computed_field


class BecCodecInfo(BaseModel):
type_name: str


class BECSerializable(BaseModel):
"""A base class for serializable BEC objects, especially BEC messages.
Fields in subclasses which use non-primitive types must be in structured,
type-hinted objects, and their encoders and JSON schema should be defined in
this class."""

model_config = ConfigDict(
json_schema_serialization_defaults_required=True,
arbitrary_types_allowed=True,
extra="forbid",
)

@computed_field()
@property
def bec_codec(self) -> BecCodecInfo:
return BecCodecInfo(type_name=self.__class__.__name__)

def __eq__(self, other):
if type(other) is not type(self):
return False
try:
np.testing.assert_equal(self.model_dump(), other.model_dump())
return True
except AssertionError:
return False
24 changes: 0 additions & 24 deletions bec_lib/bec_lib/codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,18 +98,6 @@ def decode(type_name: str, data: str) -> str:
return data


class PydanticEncoder(BECCodec):
obj_type: Type = BaseModel

@staticmethod
def encode(obj: BaseModel) -> dict:
return obj.model_dump()

@staticmethod
def decode(type_name: str, data: dict) -> dict:
return data


class EndpointInfoEncoder(BECCodec):
obj_type: Type = EndpointInfo

Expand All @@ -130,18 +118,6 @@ def decode(type_name: str, data: dict) -> EndpointInfo:
)


class SetEncoder(BECCodec):
obj_type: Type = set

@staticmethod
def encode(obj: set) -> list:
return list(obj)

@staticmethod
def decode(type_name: str, data: list) -> set:
return set(data)


class BECTypeEncoder(BECCodec):
obj_type: Type = type

Expand Down
8 changes: 6 additions & 2 deletions bec_lib/bec_lib/config_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from bec_lib.endpoints import MessageEndpoints
from bec_lib.file_utils import DeviceConfigWriter
from bec_lib.logger import bec_logger
from bec_lib.messages import ConfigAction
from bec_lib.messages import ConfigAction, sanitize_one_way_encodable
from bec_lib.utils.import_utils import lazy_import_from
from bec_lib.utils.json_extended import ExtendedEncoder

Expand Down Expand Up @@ -617,7 +617,11 @@ def send_config_request(
request_id = str(uuid.uuid4())
self._connector.send(
MessageEndpoints.device_config_request(),
DeviceConfigMessage(action=action, config=config, metadata={"RID": request_id}),
DeviceConfigMessage(
action=action,
config=sanitize_one_way_encodable(config),
metadata={"RID": request_id},
),
)

if wait_for_response:
Expand Down
27 changes: 16 additions & 11 deletions bec_lib/bec_lib/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import TYPE_CHECKING, Any, Callable, Iterable

import numpy as np
from pydantic import ConfigDict
from pydantic import ConfigDict, field_serializer
from rich.console import Console
from rich.table import Table
from typeguard import typechecked
Expand Down Expand Up @@ -169,7 +169,6 @@ def wait(self, timeout=None, raise_on_failure=True):
raise_on_failure (bool, optional): If True, an RPCError is raised if the request fails. Defaults to True.
"""
try:

if not self._status_done.wait(timeout):
raise TimeoutError("The request has not been completed within the specified time.")
finally:
Expand All @@ -185,11 +184,19 @@ def wait(self, timeout=None, raise_on_failure=True):
class _PermissiveDeviceModel(_DeviceModelCore):
model_config = ConfigDict(extra="allow")

@field_serializer("deviceTags")
def serialize_devicetags(self, value: set[str], info):
if info.mode == "json":
return list[value]
else:
return value


def set_device_config(device: "DeviceBase", config: dict | _PermissiveDeviceModel | None):
# device._config = config
device._config = ( # pylint: disable=protected-access
_PermissiveDeviceModel.model_validate(config).model_dump() if config is not None else None
_PermissiveDeviceModel.model_validate(config).model_dump(mode="python")
if config is not None
else None
)


Expand Down Expand Up @@ -346,7 +353,7 @@ def _prepare_rpc_msg(
client: BECClient = self.root.parent.parent
msg = messages.ScanQueueMessage(
scan_type="device_rpc",
parameter=params,
parameter=messages.sanitize_one_way_encodable(params),
queue=client.queue.get_default_scan_queue(), # type: ignore
metadata={"RID": request_id, "response": True},
)
Expand Down Expand Up @@ -742,7 +749,6 @@ def _compile_rich_str(obj: DeviceBase) -> str | None:


class DeviceBaseWithConfig(DeviceBase):

@property
def full_name(self):
"""Returns the full name of the device or signal, separated by "_" e.g. samx_velocity"""
Expand Down Expand Up @@ -777,10 +783,10 @@ def _update_config(self, update: dict) -> None:
action="update", config={self.name: update}
)

def get_device_tags(self) -> list:
def get_device_tags(self) -> set[str]:
"""get the device tags for this device"""
# pylint: disable=protected-access
return self.root._config.get("deviceTags", [])
return self.root._config.get("deviceTags", {})

@typechecked
def set_device_tags(self, val: Iterable):
Expand Down Expand Up @@ -1164,8 +1170,8 @@ def limits(self):
if not limit_msg:
return [0, 0]
limits = [
limit_msg.content["signals"].get("low", {}).get("value", 0),
limit_msg.content["signals"].get("high", {}).get("value", 0),
limit_msg.signals.get("low", {}).get("value", 0),
limit_msg.signals.get("high", {}).get("value", 0),
]
return limits

Expand Down Expand Up @@ -1203,7 +1209,6 @@ class Signal(AdjustableMixin, OphydInterfaceBase):


class ComputedSignal(Signal):

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._num_args_method = None
Expand Down
10 changes: 8 additions & 2 deletions bec_lib/bec_lib/devicemanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _rgetattr_safe(obj, attr, *args):
return None


class DeviceContainer(dict):
class DeviceContainer(dict[str, DeviceBase]):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
for arg in args:
Expand Down Expand Up @@ -693,15 +693,21 @@ def _get_redis_device_config(self) -> list:
return devices.content["resource"]

def _add_multiple_devices_with_log(self, devices: Iterable[tuple[dict, DeviceInfoMessage]]):
override = self._allow_override
try:
override = self._allow_override
self._allow_override = True
logs = (self._add_device(*conf_msg) for conf_msg in devices if conf_msg is not None)
if set(logs) == {None}:
logger.warning("No devices added!")
return
logger.info(f"Adding new devices:\n" + ", ".join(f"{name}: {t}" for name, t in logs)) # type: ignore # filtered
finally:
self._allow_override = override

def _add_device(self, dev: dict, msg: DeviceInfoMessage) -> tuple[str, str] | None:
if msg is None:
logger.error(f"No device info in Redis for: {dev}")
return None
name = msg.content["device"]
info = msg.content["info"]

Expand Down
Loading
Loading