diff --git a/setup.py b/setup.py index a2598dd..fd83dae 100644 --- a/setup.py +++ b/setup.py @@ -1,29 +1,9 @@ -from setuptools import setup, find_packages, Extension -from Cython.Build import cythonize +from setuptools import setup, find_packages from pathlib import Path -import os - -extra_compile_args = [] -if os.name == "nt": # Windows - extra_compile_args = ["/w"] -else: # macOS and Linux - extra_compile_args = [ - "-Wno-unreachable-code", - "-Wno-unreachable-code-fallthrough", - "-O3", - ] this_directory = Path(__file__).parent long_description = (this_directory / "README.md").read_text() -extensions = [ - Extension( - "synapse.utils.ndtp", - ["synapse/utils/ndtp.pyx"], - extra_compile_args=extra_compile_args, - ), -] - setup( name="science-synapse", version="2.4.0", @@ -33,14 +13,9 @@ packages=find_packages(include=["synapse", "synapse.*"]), long_description=long_description, long_description_content_type="text/markdown", - ext_modules=cythonize( - extensions, - compiler_directives={"language_level": "3"}, - ), python_requires=">=3.9", install_requires=[ "coolname", - "crcmod", "dearpygui", "grpcio-tools", "numexpr>=2.8.7", diff --git a/synapse/nodes/__init__.py b/synapse/nodes/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/synapse/server/nodes/base.py b/synapse/server/nodes/base.py index 42a4143..265ee13 100644 --- a/synapse/server/nodes/base.py +++ b/synapse/server/nodes/base.py @@ -4,7 +4,7 @@ from synapse.api.node_pb2 import NodeConfig, NodeType from synapse.server.status import Status -from synapse.utils.ndtp_types import SynapseData +from synapse.utils.types import SynapseData class BaseNode(object): diff --git a/synapse/server/nodes/spectral_filter.py b/synapse/server/nodes/spectral_filter.py index 3c1ad96..b077b19 100644 --- a/synapse/server/nodes/spectral_filter.py +++ b/synapse/server/nodes/spectral_filter.py @@ -11,7 +11,7 @@ ) from synapse.server.nodes.base import BaseNode from synapse.server.status import Status -from 
synapse.utils.ndtp_types import ElectricalBroadbandData, SynapseData +from synapse.utils.types import ElectricalBroadbandData, SynapseData def get_filter_coefficients(method, low_cutoff_hz, high_cutoff_hz, sample_rate): diff --git a/synapse/simulator/nodes/broadband_source.py b/synapse/simulator/nodes/broadband_source.py index 469a04f..bbe4deb 100644 --- a/synapse/simulator/nodes/broadband_source.py +++ b/synapse/simulator/nodes/broadband_source.py @@ -10,7 +10,7 @@ from synapse.api.tap_pb2 import TapConnection, TapType from synapse.api.datatype_pb2 import BroadbandFrame from synapse.server.status import Status -from synapse.utils.ndtp_types import ElectricalBroadbandData +from synapse.utils.types import ElectricalBroadbandData def r_sample(bit_width: int): return random.randint(0, 2**bit_width - 1) diff --git a/synapse/simulator/nodes/spike_source.py b/synapse/simulator/nodes/spike_source.py index 5afb7ee..5574fb3 100644 --- a/synapse/simulator/nodes/spike_source.py +++ b/synapse/simulator/nodes/spike_source.py @@ -6,7 +6,7 @@ from synapse.api.nodes.spike_source_pb2 import SpikeSourceConfig from synapse.server.nodes.base import BaseNode from synapse.server.status import Status -from synapse.utils.ndtp_types import SpiketrainData +from synapse.utils.types import SpiketrainData def r_sample(bit_width: int): return random.randint(0, 2**bit_width - 1) diff --git a/synapse/tests/simulator/test_broadband_source.py b/synapse/tests/simulator/test_broadband_source.py index 0f2e8c9..53188da 100644 --- a/synapse/tests/simulator/test_broadband_source.py +++ b/synapse/tests/simulator/test_broadband_source.py @@ -4,7 +4,7 @@ from synapse.simulator.nodes.broadband_source import BroadbandSource from synapse.api.nodes.signal_config_pb2 import SignalConfig, ElectrodeConfig from synapse.api.channel_pb2 import Channel -from synapse.utils.ndtp_types import ElectricalBroadbandData +from synapse.utils.types import ElectricalBroadbandData @pytest.mark.asyncio diff --git 
a/synapse/tests/test_ndtp.py b/synapse/tests/test_ndtp.py deleted file mode 100644 index e3d694d..0000000 --- a/synapse/tests/test_ndtp.py +++ /dev/null @@ -1,436 +0,0 @@ -import struct - -import pytest - -from synapse.api.datatype_pb2 import DataType -from synapse.utils.ndtp import ( - NDTP_VERSION, - NDTPHeader, - NDTPMessage, - NDTPPayloadBroadband, - NDTPPayloadBroadbandChannelData, - NDTPPayloadSpiketrain, - to_bytes, - to_ints, -) - - -def test_to_bytes(): - assert to_bytes([1, 2, 3, 0], bit_width=2) == (bytearray(b"\x6C"), 0) - - assert to_bytes([1, 2, 3, 2, 1], bit_width=2) == (bytearray(b"\x6E\x40"), 2) - - assert to_bytes([7, 5, 3, 1], bit_width=12) == ( - bytearray(b"\x00\x70\x05\x00\x30\x01"), - 0, - ) - - assert to_bytes([-7, -5, -3, -1], bit_width=12, is_signed=True) == ( - bytearray(b"\xFF\x9F\xFB\xFF\xDF\xFF"), - 0, - ) - - assert to_bytes( - [7, 5, 3], bit_width=12, existing=bytearray(b"\x01\x00"), writing_bit_offset=4 - ) == (bytearray(b"\x01\x00\x07\x00\x50\x03"), 0) - - assert to_bytes( - [-7, -5, -3], - bit_width=12, - existing=bytearray(b"\x01\x00"), - writing_bit_offset=4, - is_signed=True, - ) == (bytearray(b"\x01\x0F\xF9\xFF\xBF\xFD"), 0) - - assert to_bytes([7, 5, 3], bit_width=12) == (bytearray(b"\x00p\x05\x000"), 4) - - assert to_bytes([1, 2, 3, 4], bit_width=8) == (bytearray(b"\x01\x02\x03\x04"), 0) - - res, offset = to_bytes([7, 5, 3], bit_width=12) - assert res == bytearray(b"\x00p\x05\x000") - assert len(res) == 5 - assert offset == 4 - - res, offset = to_bytes( - [3, 5, 7], bit_width=12, existing=res, writing_bit_offset=offset - ) - assert res == bytearray(b"\x00\x70\x05\x00\x30\x03\x00\x50\x07") - assert len(res) == 9 - assert offset == 0 - - # 8 doesn't fit in 3 bits - with pytest.raises(ValueError): - to_bytes([8], 3) - - # Invalid bit width - with pytest.raises(ValueError): - to_bytes([1, 2, 3, 0], 0) - - -def test_to_ints(): - res, offset, _ = to_ints(b"\x6C", 2) - assert res == [1, 2, 3, 0] - assert offset == 8 - - res, 
offset, _ = to_ints(b"\x6C", 2, 3) - assert res == [1, 2, 3] - assert offset == 6 - - res, offset, _ = to_ints(b"\x00\x70\x05\x00\x30\x01", 12) - assert res == [7, 5, 3, 1] - assert offset == 48 - - res, offset, _ = to_ints(b"\x6C", 2, 3, 2) - assert res == [2, 3, 0] - assert offset == 6 + 2 - - res, offset, _ = to_ints(b"\x00\x07\x00\x50\x03", 12, 3, 4) - assert res == [7, 5, 3] - assert offset == 36 + 4 - - res, offset, _ = to_ints(b"\xFF\xF9\xFF\xBF\xFD", 12, 3, 4, is_signed=True) - assert res == [-7, -5, -3] - assert offset == 36 + 4 - - arry = bytearray(b"\x6E\x40") - res, offset, arry = to_ints(arry, 2, 1) - assert res == [1] - assert offset == 2 - - res, offset, arry = to_ints(arry, 2, 1, offset) - assert res == [2] - assert offset == 4 - - res, offset, arry = to_ints(arry, 2, 1, offset) - assert res == [3] - assert offset == 6 - - res, offset, arry = to_ints(arry, 2, 1, offset) - assert res == [2] - assert offset == 8 - - # Invalid bit width - with pytest.raises(ValueError): - to_ints(b"\x01", 0) - - # Incomplete value - with pytest.raises(ValueError): - to_ints(b"\x01", 3) - - # Insufficient data - with pytest.raises(ValueError): - to_ints(b"\x01\x02", 3) - - -def test_ndtp_payload_broadband(): - bit_width = 12 - sample_rate = 3 - is_signed = False - channels = [ - NDTPPayloadBroadbandChannelData( - channel_id=0, - channel_data=[1, 2, 3], - ), - NDTPPayloadBroadbandChannelData( - channel_id=1, - channel_data=[4, 5, 6], - ), - NDTPPayloadBroadbandChannelData( - channel_id=2, - channel_data=[3000, 2000, 1000], - ), - ] - - payload = NDTPPayloadBroadband(is_signed, bit_width, sample_rate, channels) - p = payload.pack() - hexstring = " ".join(f"{i:02x}" for i in p) - assert hexstring == "18 00 00 03 00 00 03 00 00 00 00 03 00 10 02 00 30 00 00 10 00 30 04 00 50 06 00 00 02 00 03 bb 87 d0 3e 80" - - assert p[0] == (bit_width << 1) | (is_signed << 0) - - # number of channels - assert p[1] == 0 - assert p[2] == 0 - assert p[3] == 3 - - # sample rate - assert p[4] 
== 0 - assert p[5] == 0 - assert p[6] == 3 - - # ch 0 channel_id, 0 (24 bits, 3 bytes) - assert p[7] == 0 - assert p[8] == 0 - assert p[9] == 0 - - # ch 0 num_samples, 3 (16 bits, 2 bytes) - assert p[10] == 0 - assert p[11] == 3 - - # ch 0 channel_data, 1, 2, 3 (12 bits, 1.5 bytes each) - # 0000 0000 0001 0000 0000 0010 0000 0000 0011 .... - assert p[12] == 0 - assert p[13] == 16 - assert p[14] == 2 - assert p[15] == 0 - assert p[16] >= 3 - - # ch 1 channel_id, 1 (24 bits, 3 bytes, starting from 4 bit offset) - # 0011 0000 0000 0000 0000 0000 0001 .... - assert p[16] == 48 - assert p[17] == 0 - assert p[18] == 0 - assert p[19] >= 16 - - # ch 1 num_samples, 3 (16 bits, 2 bytes, starting from 4 bit offset) - # 0001 0000 0000 0000 0011 .... - assert p[19] == 16 - assert p[20] == 0 - assert p[21] >= 48 - - # ch 1 channel_data, 4, 5, 6 (12 bits, 1.5 bytes each) - # 0011 0000 0000 0100 0000 0000 0101 0000 0000 0110 - assert p[21] == 48 - assert p[22] == 4 - assert p[23] == 0 - assert p[24] == 80 - assert p[25] >= 6 - - # ch 2 channel_id, 2 (24 bits, 3 bytes) - # 0000 0000 0000 0000 0000 0010 - assert p[26] == 0 - assert p[27] == 0 - assert p[28] == 2 - - # ch 2 num_samples, 3 (16 bits, 2 bytes) - # 0000 0000 0000 0011 - assert p[29] == 0 - assert p[30] == 3 - - # ch 2 channel_data, 3000, 2000, 1000 (12 bits, 1.5 bytes each) - # 1011 1011 1000 0111 1101 0000 0011 1110 1000 .... 
- assert p[31] == 187 - assert p[32] == 135 - assert p[33] == 208 - assert p[34] == 62 - assert p[35] >= 128 - - u = NDTPPayloadBroadband.unpack(p) - assert u.bit_width == bit_width - assert u.is_signed == is_signed - assert len(u.channels) == 3 - - assert u.channels[0].channel_id == 0 - assert list(u.channels[0].channel_data) == [1, 2, 3] - - assert u.channels[1].channel_id == 1 - assert list(u.channels[1].channel_data) == [4, 5, 6] - - assert u.channels[2].channel_id == 2 - assert list(u.channels[2].channel_data) == [3000, 2000, 1000] - - assert p[0] >> 1 == bit_width - - assert (p[1] << 16) | (p[2] << 8) | p[3] == 3 - p = p[7:] - - unpacked, offset, p = to_ints(p, bit_width=24, count=1) - assert unpacked[0] == 0 - assert offset == 24 - - unpacked, offset, p = to_ints(p, bit_width=16, count=1, start_bit=offset) - assert unpacked[0] == 3 - assert offset == 16 - - unpacked, offset, p = to_ints(p, bit_width=bit_width, count=3, start_bit=offset) - assert unpacked == [1, 2, 3] - assert offset == 36 - -def test_ndtp_payload_broadband_large(): - n_samples = 20000 - bit_width = 16 - sample_rate = 100000 - is_signed = False - channels = [ - NDTPPayloadBroadbandChannelData( - channel_id=0, - channel_data=[i for i in range(n_samples)], - ), - NDTPPayloadBroadbandChannelData( - channel_id=1, - channel_data=[i + 1 for i in range(n_samples)], - ), - NDTPPayloadBroadbandChannelData( - channel_id=2, - channel_data=[i + 2 for i in range(n_samples)], - ), - ] - - payload = NDTPPayloadBroadband(is_signed, bit_width, sample_rate, channels) - packed = payload.pack() - - unpacked = NDTPPayloadBroadband.unpack(packed) - assert unpacked.bit_width == bit_width - assert unpacked.is_signed == is_signed - assert len(unpacked.channels) == 3 - - assert unpacked.channels[0].channel_id == 0 - assert list(unpacked.channels[0].channel_data) == [i for i in range(n_samples)] - - assert unpacked.channels[1].channel_id == 1 - assert list(unpacked.channels[1].channel_data) == [i + 1 for i in 
range(n_samples)] - - assert unpacked.channels[2].channel_id == 2 - assert list(unpacked.channels[2].channel_data) == [i + 2 for i in range(n_samples)] - - -def test_ndtp_payload_spiketrain(): - samples = [0, 1, 2, 3, 4, 5, 6] - - payload = NDTPPayloadSpiketrain(10, samples) - packed = payload.pack() - hexstring = " ".join(f"{i:02x}" for i in packed) - print(hexstring) - - assert packed[0] == 0 - assert packed[1] == 0 - assert packed[2] == 0 - assert packed[3] == 7 - assert packed[4] == 10 - - # 0000 0001 0010 0011 0100 0101 0110 0000 - assert packed[5] == 1 - assert packed[6] == 35 - assert packed[7] == 69 - assert packed[8] == 96 - - - unpacked = NDTPPayloadSpiketrain.unpack(packed) - - assert unpacked == payload - assert unpacked.bin_size_ms == 10 - assert list(unpacked.spike_counts) == samples - - print("2s") - samples = [2, 2, 2, 2, 2, 2, 2, 2, 2, 2] - - payload = NDTPPayloadSpiketrain(10, samples) - packed = payload.pack() - hexstring = " ".join(f"{i:02x}" for i in packed) - print(hexstring) - - assert packed[3] == 10 - assert packed[4] == 10 - assert packed[5] == 34 - assert packed[6] == 34 - assert packed[7] == 34 - assert packed[8] == 34 - assert packed[8] == 34 - - unpacked = NDTPPayloadSpiketrain.unpack(packed) - - assert unpacked == payload - assert unpacked.bin_size_ms == 10 - assert list(unpacked.spike_counts) == samples - - -def test_ndtp_header(): - header = NDTPHeader(DataType.kBroadband, 1234567890, 42) - packed = header.pack() - unpacked = NDTPHeader.unpack(packed) - assert unpacked == header - - # Invalid version - with pytest.raises(ValueError): - NDTPHeader.unpack(b"\x00" + packed[1:]) - - # Data too smol - with pytest.raises(ValueError): - NDTPHeader.unpack( - struct.pack(">B", NDTP_VERSION) - + struct.pack(">B", DataType.kBroadband) - + struct.pack(">Q", 123) - ) - -def test_ndtp_message_broadband(): - header = NDTPHeader(DataType.kBroadband, timestamp=1234567890, seq_number=42) - payload = NDTPPayloadBroadband( - bit_width=12, - 
sample_rate=3, - is_signed=False, - channels=[ - NDTPPayloadBroadbandChannelData( - channel_id=c, - channel_data=[s * 1000 for s in range(c + 1)], - ) - for c in range(3) - ], - ) - message = NDTPMessage(header, payload) - - packed = message.pack() - assert message._crc16 == 19660 - - hexstring = " ".join(f"{i:02x}" for i in packed) - assert hexstring == "01 02 00 00 00 00 49 96 02 d2 00 2a 18 00 00 03 00 00 03 00 00 00 00 01 00 00 00 00 10 00 20 00 3e 80 00 00 20 00 30 00 3e 87 d0 4c cc" - - unpacked = NDTPMessage.unpack(packed) - assert message._crc16 == 19660 - - assert unpacked.header == message.header - assert isinstance(unpacked.payload, NDTPPayloadBroadband) - assert unpacked.payload == message.payload - -def test_ndtp_message_broadband_large(): - header = NDTPHeader(DataType.kBroadband, timestamp=1234567890, seq_number=42) - payload = NDTPPayloadBroadband( - bit_width=16, - sample_rate=36000, - is_signed=False, - channels=[ - NDTPPayloadBroadbandChannelData( - channel_id=c, - channel_data=[i for i in range(10000)], - ) - for c in range(20) - ], - ) - message = NDTPMessage(header, payload) - - packed = message.pack() - assert message._crc16 == 32263 - - unpacked = NDTPMessage.unpack(packed) - assert unpacked._crc16 == 32263 - - assert unpacked.header == message.header - assert isinstance(unpacked.payload, NDTPPayloadBroadband) - assert unpacked.payload == message.payload - - u_payload = unpacked.payload - assert u_payload.bit_width == payload.bit_width - assert u_payload.sample_rate == payload.sample_rate - assert u_payload.is_signed == payload.is_signed - assert len(u_payload.channels) == len(payload.channels) - for i, c in enumerate(payload.channels): - assert u_payload.channels[i].channel_id == c.channel_id - assert list(u_payload.channels[i].channel_data) == list(c.channel_data) - -def test_ndtp_message_spiketrain(): - header = NDTPHeader(DataType.kSpiketrain, timestamp=1234567890, seq_number=42) - payload = NDTPPayloadSpiketrain(bin_size_ms=10, 
spike_counts=[1, 2, 3, 2, 1]) - message = NDTPMessage(header, payload) - - packed = message.pack() - unpacked = NDTPMessage.unpack(packed) - - assert unpacked.header == message.header - assert isinstance(unpacked.payload, NDTPPayloadSpiketrain) - assert unpacked.payload == message.payload - - with pytest.raises(ValueError): - NDTPMessage.unpack(b"\x00" * (NDTPHeader.STRUCT.size + 8)) # Invalid data type - - -if __name__ == "__main__": - pytest.main() diff --git a/synapse/utils/ndtp.pyx b/synapse/utils/ndtp.pyx deleted file mode 100644 index b19df2c..0000000 --- a/synapse/utils/ndtp.pyx +++ /dev/null @@ -1,585 +0,0 @@ -import cython -import struct -from typing import List, Tuple - -from cython cimport boundscheck, wraparound -from cpython.buffer cimport PyBUF_SIMPLE -from cpython.bytes cimport PyBytes_FromStringAndSize -from cython.view cimport array as cvarray -from libc.stdint cimport uint8_t, uint16_t, uint64_t, int64_t - -from synapse.api.datatype_pb2 import DataType -import crcmod -from crcmod import * - - -cdef int DATA_TYPE_K_BROADBAND = DataType.kBroadband -cdef int DATA_TYPE_K_SPIKETRAIN = DataType.kSpiketrain - -NDTP_VERSION = 0x01 -cdef int NDTPPayloadSpiketrain_BIT_WIDTH = 4 - -CRC_16 = crcmod.predefined.mkCrcFun('crc-16') - -@boundscheck(False) -@wraparound(False) -def to_bytes( - values, - int bit_width, - existing: bytearray = None, - int writing_bit_offset = 0, - bint is_signed = False, - byteorder: str = 'big' -) -> Tuple[bytearray, int]: - cdef int num_values = len(values) - cdef int num_bits_to_write = num_values * bit_width - - # Initialize buffer - cdef bytearray buffer - cdef int buffer_length - cdef int bit_offset = 0 - if existing is None: - buffer = bytearray() - buffer_length = 0 - else: - buffer = existing - buffer_length = len(buffer) - if buffer_length > 0: - if writing_bit_offset > 0: - bit_offset = (buffer_length - 1) * 8 + writing_bit_offset - else: - bit_offset = (buffer_length) * 8 - - cdef int total_bits_needed = bit_offset + 
num_bits_to_write - cdef int total_bytes_needed = (total_bits_needed + 7) // 8 - - # Extend buffer if necessary - if len(buffer) < total_bytes_needed: - buffer.extend([0] * (total_bytes_needed - len(buffer))) - - # Get a writable memoryview of the buffer - cdef unsigned char[::1] buffer_view = buffer - - cdef int64_t min_value, max_value - if is_signed: - min_value = -(1 << (bit_width - 1)) - max_value = (1 << (bit_width - 1)) - 1 - else: - min_value = 0 - max_value = (1 << bit_width) - 1 - - cdef int64_t value - cdef uint64_t value_unsigned - cdef int bits_remaining, byte_index, bit_index, bits_in_current_byte, shift - cdef unsigned char bits_to_write - cdef bint byteorder_is_little - - if byteorder == 'little': - byteorder_is_little = True - elif byteorder == 'big': - byteorder_is_little = False - else: - raise ValueError("Invalid byteorder: " + byteorder) - - for py_value in values: - value = py_value - if not (min_value <= value <= max_value): - raise ValueError("Value " + str(value) + " cannot be represented in " + str(bit_width) + " bits") - - # Handle negative values for signed integers - if is_signed and value < 0: - value_unsigned = (1 << bit_width) + value # Two's complement - else: - value_unsigned = value - - bits_remaining = bit_width - while bits_remaining > 0: - byte_index = bit_offset // 8 - bit_index = bit_offset % 8 - - bits_in_current_byte = min(8 - bit_index, bits_remaining) - shift = bits_remaining - bits_in_current_byte - - # Extract the bits to write - bits_to_write = (value_unsigned >> shift) & ((1 << bits_in_current_byte) - 1) - - if byteorder_is_little: - # Align bits to the correct position in the byte - bits_to_write <<= bit_index - else: - bits_to_write <<= (8 - bit_index - bits_in_current_byte) - - # Write bits into the buffer - buffer_view[byte_index] |= bits_to_write - - bits_remaining -= bits_in_current_byte - bit_offset += bits_in_current_byte - - final_bit_offset = bit_offset % 8 - if final_bit_offset == 0 and total_bytes_needed < 
len(buffer): - # Trim the extra byte if we've not used any bits in it - buffer = buffer[:total_bytes_needed] - - return buffer, final_bit_offset - - -@boundscheck(False) -@wraparound(False) -def to_ints( - data, - int bit_width, - int count = 0, - int start_bit = 0, - bint is_signed = False, - byteorder: str = 'big' -) -> Tuple[List[int], int, object]: - if bit_width <= 0: - raise ValueError("bit width must be > 0") - - cdef int truncate_bytes = start_bit // 8 - start_bit = start_bit % 8 - - data = data[truncate_bytes:] - - # Convert data to a memoryview - cdef const unsigned char[::1] data_view - - if isinstance(data, (bytes, bytearray)): - data_view = data - else: - raise TypeError("Unsupported data type: " + str(type(data))) - - cdef Py_ssize_t data_len = len(data_view) - - if count > 0 and data_len < (bit_width * count + 7) // 8: - raise ValueError( - "insufficient data for " + str(count) + " x " + str(bit_width) + " bit values " + - "(expected " + str((bit_width * count + 7) // 8) + " bytes, given " + str(data_len) + " bytes)" - ) - - cdef int current_value = 0 - cdef int bits_in_current_value = 0 - cdef int mask = (1 << bit_width) - 1 - cdef int total_bits_read = 0 - cdef int byte_index, bit_index, bit - cdef int start - cdef int value_index = 0 - cdef int max_values = count if count > 0 else (data_len * 8) // bit_width - if max_values == 0: - raise ValueError("max_values must be > 0 (got " + str(len(data)) + " data, " + str(count) + " count, bit width " + str(bit_width) + ")") - cdef int[::1] values_array = cython.view.array(shape=(max_values,), itemsize=cython.sizeof(cython.int), format="i") - cdef int sign_bit = 1 << (bit_width - 1) - cdef uint8_t byte - - for byte_index in range(data_len): - byte = data_view[byte_index] - - if byteorder == 'little': - start = start_bit if byte_index == 0 else 0 - for bit_index in range(start, 8): - bit = (byte >> bit_index) & 1 - current_value |= bit << bits_in_current_value - bits_in_current_value += 1 - total_bits_read 
+= 1 - - if bits_in_current_value == bit_width: - if is_signed: - if current_value & sign_bit: - current_value = current_value - (1 << bit_width) - else: - current_value = current_value & mask - values_array[value_index] = current_value - value_index += 1 - current_value = 0 - bits_in_current_value = 0 - - if count > 0 and value_index == count: - end_bit = start_bit + total_bits_read - return [values_array[i] for i in range(value_index)], end_bit, data - - elif byteorder == 'big': - start = start_bit if byte_index == 0 else 0 - for bit_index in range(7 - start, -1, -1): - bit = (byte >> bit_index) & 1 - current_value = (current_value << 1) | bit - bits_in_current_value += 1 - total_bits_read += 1 - - if bits_in_current_value == bit_width: - if is_signed: - if current_value & sign_bit: - current_value = current_value - (1 << bit_width) - else: - current_value = current_value & mask - values_array[value_index] = current_value - value_index += 1 - current_value = 0 - bits_in_current_value = 0 - - if count > 0 and value_index == count: - end_bit = start_bit + total_bits_read - return [values_array[i] for i in range(value_index)], end_bit, data - - else: - raise ValueError("Invalid byteorder: " + byteorder) - - if bits_in_current_value > 0: - if bits_in_current_value == bit_width: - if is_signed and (current_value & sign_bit): - current_value = current_value - (1 << bit_width) - else: - current_value = current_value & mask - values_array[value_index] = current_value - value_index += 1 - elif count == 0: - raise ValueError( - str(bits_in_current_value) + " bits left over, not enough to form a complete value of bit width " + str(bit_width) - ) - - if count > 0: - value_index = min(value_index, count) - - end_bit = start_bit + total_bits_read - return [values_array[i] for i in range(value_index)], end_bit, data - - -cdef class NDTPPayloadBroadbandChannelData: - cdef public int channel_id - cdef public object channel_data - - def __init__(self, int channel_id, 
channel_data): - self.channel_id = channel_id - self.channel_data = channel_data - - def __eq__(self, other): - if not isinstance(other, NDTPPayloadBroadbandChannelData): - return False - return ( - self.channel_id == other.channel_id and - list(self.channel_data) == list(other.channel_data) - ) - - def __ne__(self, other): - return not self.__eq__(other) - - -cdef class NDTPPayloadBroadband: - cdef public bint is_signed - cdef public int bit_width - cdef public int sample_rate - cdef public list channels # List of NDTPPayloadBroadbandChannelData objects - - def __init__(self, bint is_signed, int bit_width, int sample_rate, channels): - self.is_signed = is_signed - self.bit_width = bit_width - self.sample_rate = sample_rate - self.channels = channels - - def pack(self): - cdef int n_channels = len(self.channels) - cdef bytearray payload = bytearray() - - # First byte: bit width and signed flag - payload += struct.pack( - ">B", ((self.bit_width & 0x7F) << 1) | (1 if self.is_signed else 0) - ) - - # Next three bytes: number of channels (24-bit integer) - payload += n_channels.to_bytes(3, byteorder='big', signed=False) - - # Next three bytes: sample rate (24-bit integer) - payload += self.sample_rate.to_bytes(3, byteorder='big', signed=False) - - cdef NDTPPayloadBroadbandChannelData c - bit_offset = 0 - for c in self.channels: - payload, bit_offset = to_bytes( - values=[c.channel_id], - bit_width=24, - is_signed=False, - existing=payload, - writing_bit_offset=bit_offset, - ) - - payload, bit_offset = to_bytes( - values=[len(c.channel_data)], - bit_width=16, - is_signed=False, - existing=payload, - writing_bit_offset=bit_offset, - ) - - payload, bit_offset = to_bytes( - values=c.channel_data, - bit_width=self.bit_width, - is_signed=self.is_signed, - existing=payload, - writing_bit_offset=bit_offset, - ) - - return payload - - @staticmethod - def unpack(data): - if isinstance(data, bytes): - data = bytearray(data) - - cdef int payload_h_size = 7 - cdef int len_data = 
len(data) - if len_data < payload_h_size: - raise ValueError( - "Invalid broadband data size " + str(len_data) + ": expected at least " + str(payload_h_size) + " bytes" - ) - - cdef int bit_width = data[0] >> 1 - cdef bint is_signed = (data[0] & 1) == 1 - cdef int num_channels = int.from_bytes(data[1:4], 'big') - cdef int sample_rate = int.from_bytes(data[4:7], 'big') - - cdef list channels = [] - cdef int channel_id, num_samples - cdef list channel_data - cdef NDTPPayloadBroadbandChannelData channel - - truncated = data[7:] - bit_offset = 0 - - for c in range(num_channels): - a_channel_id, bit_offset, truncated = to_ints(data=truncated, bit_width=24, count=1, start_bit=bit_offset, is_signed=False) - channel_id = a_channel_id[0] - - a_num_samples, bit_offset, truncated = to_ints(data=truncated, bit_width=16, count=1, start_bit=bit_offset, is_signed=False) - num_samples = a_num_samples[0] - - channel_data, bit_offset, truncated = to_ints(data=truncated, bit_width=bit_width, count=num_samples, start_bit=bit_offset, is_signed=is_signed) - - channel = NDTPPayloadBroadbandChannelData(channel_id, channel_data) - channels.append(channel) - - return NDTPPayloadBroadband(is_signed, bit_width, sample_rate, channels) - - def __eq__(self, other): - if not isinstance(other, NDTPPayloadBroadband): - return False - return ( - self.is_signed == other.is_signed and - self.bit_width == other.bit_width and - self.sample_rate == other.sample_rate and - self.channels == other.channels - ) - - def __ne__(self, other): - return not self.__eq__(other) - - -cdef class NDTPPayloadSpiketrain: - cdef public int bin_size_ms - cdef public int[::1] spike_counts # Memoryview of integers - - def __init__(self, bin_size_ms, spike_counts): - cdef int size, i - self.bin_size_ms = bin_size_ms - self.spike_counts = None - - if isinstance(spike_counts, list): - size = len(spike_counts) - self.spike_counts = cython.view.array( - shape=(size,), - itemsize=cython.sizeof(cython.int), - format="i", - ) - for 
i in range(size): - self.spike_counts[i] = spike_counts[i] - else: - # Assume it's already a memoryview or array - self.spike_counts = spike_counts - - def pack(self): - cdef bytearray payload = bytearray() - cdef int spike_counts_len = len(self.spike_counts) - cdef int max_value = (1 << NDTPPayloadSpiketrain_BIT_WIDTH) - 1 # Maximum value for the given bit width - cdef int[::1] clamped_counts = cython.view.array(shape=(spike_counts_len,), itemsize=cython.sizeof(cython.int), format="i") - cdef int i - - # Clamp the values - for i in range(spike_counts_len): - clamped_counts[i] = min(self.spike_counts[i], max_value) - - # Pack the number of spikes (4 bytes) - payload += struct.pack(">I", spike_counts_len) - - # Pack the bin_size (1 byte) - payload += struct.pack(">B", self.bin_size_ms) - - # Pack clamped spike counts - spike_counts_bytes, _ = to_bytes( - clamped_counts, NDTPPayloadSpiketrain_BIT_WIDTH, is_signed=False - ) - payload += spike_counts_bytes - return payload - - @staticmethod - def unpack(data): - if isinstance(data, bytes): - data = bytearray(data) - - cdef str msg; - cdef int len_data = len(data) - if len_data < 5: - msg = "Invalid spiketrain data size " - msg += str(len_data) - msg += " bytes: expected at least 5 bytes" - raise ValueError(msg) - - cdef int num_spikes = struct.unpack(">I", data[:4])[0] - cdef int bin_size_ms = struct.unpack(">B", data[4:5])[0] - cdef bytearray payload = data[5:] - cdef int bits_needed = num_spikes * NDTPPayloadSpiketrain_BIT_WIDTH - cdef int bytes_needed = (bits_needed + 7) // 8 - - if len(payload) < bytes_needed: - msg = "Insufficient data for spiketrain data (expected " - msg += str(bytes_needed) - msg += "bytes for " - msg += str(num_spikes) - msg += " spikes, got " - msg += str(len(payload)) - msg += ")" - raise ValueError(msg) - - # Unpack spike_counts - spike_counts, _, _ = to_ints( - payload[:bytes_needed], NDTPPayloadSpiketrain_BIT_WIDTH, num_spikes, is_signed=False - ) - - return 
NDTPPayloadSpiketrain(bin_size_ms, spike_counts) - - def __eq__(self, other): - if not isinstance(other, NDTPPayloadSpiketrain): - return False - return list(self.spike_counts) == list(other.spike_counts) - - def __ne__(self, other): - return not self.__eq__(other) - - -cdef class NDTPHeader: - cdef public int data_type - cdef public long long timestamp - cdef public int seq_number - - STRUCT = struct.Struct(">BBQH") # Define as a Python class attribute - - def __init__(self, int data_type, long long timestamp, int seq_number): - self.data_type = data_type - self.timestamp = timestamp - self.seq_number = seq_number - - def __eq__(self, other): - if not isinstance(other, NDTPHeader): - return False - return ( - self.data_type == other.data_type and - self.timestamp == other.timestamp and - self.seq_number == other.seq_number - ) - - def __ne__(self, other): - return not self.__eq__(other) - - def pack(self): - cdef bytes packed_data = self.STRUCT.pack( - NDTP_VERSION, self.data_type, self.timestamp, self.seq_number - ) - return bytearray(packed_data) - - @staticmethod - def unpack(data): - if isinstance(data, bytes): - data = bytearray(data) - - cdef int expected_size = NDTPHeader.STRUCT.size - if len(data) < expected_size: - raise ValueError( - "Invalid header size " + str(len(data)) + ": expected " + str(expected_size) - ) - - version, data_type, timestamp, seq_number = NDTPHeader.STRUCT.unpack(bytes(data[:expected_size])) - if version != NDTP_VERSION: - raise ValueError( - "Incompatible version " + str(version) + ": expected " + hex(NDTP_VERSION) + ", got " + hex(version) - ) - - return NDTPHeader(data_type, timestamp, seq_number) - - -cdef class NDTPMessage: - cdef public NDTPHeader header - cdef public object payload - cdef public int _crc16 - - - - def __init__(self, NDTPHeader header, payload=None): - self.header = header - self.payload = payload - - @staticmethod - def crc16(bytearray data, int poly=8005, int init=0) -> int: - import crcmod - crc = 
CRC_16(data) - return crc - - @staticmethod - def crc16_verify(bytearray data, int crc16): - cdef bint result = NDTPMessage.crc16(data) == crc16 - return result - - def pack(self): - cdef bytearray message = bytearray() - cdef bytearray header_bytes = self.header.pack() - cdef bytearray payload_bytes = self.payload.pack() if self.payload else bytearray() - cdef int crc - cdef bytes crc_bytes - - message += header_bytes - message += payload_bytes - - self._crc16 = NDTPMessage.crc16(message) - crc_bytes = struct.pack(">H", self._crc16) - - message += crc_bytes # Appending bytes to bytearray is acceptable - - return message - - @staticmethod - def unpack(data): - if isinstance(data, bytes): - data = bytearray(data) - - cdef int header_size = NDTPHeader.STRUCT.size - cdef NDTPHeader header - cdef int crc16_value - cdef object pbytes - cdef int pdtype - cdef object payload = None - - header = NDTPHeader.unpack(data[:header_size]) - crc16_value = (data[-2] << 8) | data[-1] - - pbytes = data[header_size:-2] - pdtype = header.data_type - - if pdtype == DataType.kBroadband: - payload = NDTPPayloadBroadband.unpack(pbytes) - elif pdtype == DataType.kSpiketrain: - payload = NDTPPayloadSpiketrain.unpack(pbytes) - else: - raise ValueError("unknown data type " + str(pdtype)) - - if not NDTPMessage.crc16_verify(data[:-2], crc16_value): - raise ValueError("CRC16 verification failed (expected " + str(crc16_value) + ")") - - msg = NDTPMessage(header, payload) - msg._crc16 = crc16_value - return msg diff --git a/synapse/utils/ndtp_types.py b/synapse/utils/ndtp_types.py deleted file mode 100644 index 081a436..0000000 --- a/synapse/utils/ndtp_types.py +++ /dev/null @@ -1,165 +0,0 @@ -import math -from typing import List, Tuple, Union - -import numpy as np - -from synapse.api.datatype_pb2 import DataType -from synapse.utils.ndtp import ( - NDTPHeader, - NDTPMessage, - NDTPPayloadBroadband, - NDTPPayloadBroadbandChannelData, - NDTPPayloadSpiketrain, -) - -MAX_CH_PAYLOAD_SIZE_BYTES = 1400 
- -def chunk_channel_data(bit_width: int, ch_data: List[float], max_payload_size_bytes: int): - n_packets = math.ceil(len(ch_data) * bit_width / (max_payload_size_bytes * 8)) - n_pts_per_packet = math.ceil(len(ch_data) / n_packets) - - for i in range(n_packets): - start_idx = i * n_pts_per_packet - end_idx = min(start_idx + n_pts_per_packet, len(ch_data)) - yield ch_data[start_idx:end_idx] - -class ElectricalBroadbandData: - """Electrical broadband data from neural recordings. - - Attributes: - t0 (int): Start timestamp in nanoseconds - is_signed (bool): Whether the data is represented using signed integers - bit_width (int): Number of bits used to represent each sample - samples (Tuple[int, List[float]]): Tuple of (channel_id, data_samples) - sample_rate (float): Sample rate in Hz - """ - __slots__ = ["data_type", "t0", "is_signed", "bit_width", "samples", "sample_rate"] - - def __init__(self, t0, bit_width, samples: Tuple[int, List[float]], sample_rate, is_signed=True): - self.data_type = DataType.kBroadband - - self.t0 = t0 # ns - self.is_signed = is_signed - self.bit_width = bit_width - self.samples = samples - self.sample_rate = sample_rate - - def pack(self, seq_number: int) -> Tuple[List[bytes], int]: - packets = [] - seq = seq_number - - try: - for ch_samples in self.samples: - ch_id = ch_samples[0] - ch_data = ch_samples[1] - if (len(ch_data) == 0): - continue - - n_samples = 0 - - for ch_sample_sub in chunk_channel_data(self.bit_width, ch_data, MAX_CH_PAYLOAD_SIZE_BYTES): - t_offset = round(n_samples * 1e6 / self.sample_rate) - timestamp = self.t0 + t_offset - msg = NDTPMessage( - header=NDTPHeader( - data_type=DataType.kBroadband, - timestamp=timestamp, - seq_number=seq, - ), - payload=NDTPPayloadBroadband( - is_signed=self.is_signed, - bit_width=self.bit_width, - sample_rate=self.sample_rate, - channels=[ - NDTPPayloadBroadbandChannelData( - channel_id=ch_id, channel_data=ch_sample_sub - ) - ], - ), - ) - n_samples += len(ch_sample_sub) - packed = 
msg.pack() - packets.append(packed) - seq = (seq + 1) % 2**16 - except Exception as e: - print(f"Error packing NDTP message: {e}") - - return packets, seq - - @staticmethod - def from_ndtp_message(msg: NDTPMessage): - dtype = np.int16 if msg.payload.is_signed else np.uint16 - return ElectricalBroadbandData( - t0=msg.header.timestamp, - bit_width=msg.payload.bit_width, - is_signed=msg.payload.is_signed, - sample_rate=msg.payload.sample_rate, - samples=[ - (ch.channel_id, np.array(ch.channel_data, dtype=dtype)) - for ch in msg.payload.channels - ], - ) - - @staticmethod - def unpack(data): - u = NDTPMessage.unpack(data) - return ElectricalBroadbandData.from_ndtp_message(u) - - def to_list(self): - return [ - self.t0, - [ - (int(channel_id), samples.tolist()) - for channel_id, samples in self.samples - ], - ] - -class SpiketrainData: - """Binned spike train data from neural recordings. - - Attributes: - t0 (int): Start timestamp in nanoseconds - bin_size_ms (float): Size of each time bin in milliseconds - spike_counts (List[int]): Number of spikes in each time bin - """ - __slots__ = ["data_type", "t0", "bin_size_ms", "spike_counts"] - - def __init__(self, t0, bin_size_ms, spike_counts): - self.data_type = DataType.kSpiketrain - self.t0 = t0 # ns - self.bin_size_ms = bin_size_ms - self.spike_counts = spike_counts - - def pack(self, seq_number: int): - message = NDTPMessage( - header=NDTPHeader( - data_type=DataType.kSpiketrain, - timestamp=self.t0, - seq_number=seq_number, - ), - payload=NDTPPayloadSpiketrain( - bin_size_ms=self.bin_size_ms, - spike_counts=self.spike_counts - ), - ) - seq_number = (seq_number + 1) % 2**16 - return [message.pack()], seq_number - - @staticmethod - def from_ndtp_message(msg: NDTPMessage): - return SpiketrainData( - t0=msg.header.timestamp, - bin_size_ms=msg.payload.bin_size_ms, - spike_counts=msg.payload.spike_counts, - ) - - @staticmethod - def unpack(data): - u = NDTPMessage.unpack(data) - return SpiketrainData.from_ndtp_message(u) - - 
"""
Plain data containers for node-to-node hand-off inside the process.

Nodes in the simulator and server exchange these lightweight dataclasses
directly; nothing here is wire format. When data leaves the process it is
serialized to protobuf messages instead (e.g., BroadbandFrame, SpikeFrame).
"""

from dataclasses import dataclass, field
from typing import List, Tuple, Union

from synapse.api.datatype_pb2 import DataType


@dataclass
class ElectricalBroadbandData:
    """A chunk of broadband samples read from neural electrodes.

    Attributes:
        t0: Start timestamp, nanoseconds since epoch.
        bit_width: Bits per sample as produced by the ADC (e.g., 12, 16).
        samples: One ``(channel_id, samples)`` tuple per channel.
        sample_rate: Sampling rate in Hz.
        is_signed: True when samples are signed integers.
    """

    # Timestamp of the first sample, in nanoseconds.
    t0: int = 0
    # ADC resolution of each sample.
    bit_width: int = 16
    # Per-channel data: (channel_id, list of raw sample values).
    samples: List[Tuple[int, List[int]]] = field(default_factory=list)
    # Hz.
    sample_rate: int = 30000
    # Signedness of the raw sample values.
    is_signed: bool = True

    @property
    def data_type(self) -> DataType:
        """Protobuf DataType tag for this payload kind."""
        return DataType.kBroadband


@dataclass
class SpiketrainData:
    """Binned spike counts.

    Attributes:
        t0: Start timestamp, nanoseconds since epoch.
        bin_size_ms: Width of each bin in milliseconds.
        spike_counts: Spike counts, one entry per channel.
    """

    # Timestamp of the bin start, in nanoseconds.
    t0: int = 0
    # Bin width, milliseconds.
    bin_size_ms: float = 20.0
    # NOTE(review): presumably one count per channel for this bin window —
    # the class this replaces documented counts per time bin; confirm with callers.
    spike_counts: List[int] = field(default_factory=list)

    @property
    def data_type(self) -> DataType:
        """Protobuf DataType tag for this payload kind."""
        return DataType.kSpiketrain


# Any payload a node may emit or consume.
SynapseData = Union[ElectricalBroadbandData, SpiketrainData]