Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions csa/_elementary.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import random
import numpy
import copy
import hashlib

from . import connset as cs
from . import intervalset as iset
Expand Down Expand Up @@ -56,7 +57,7 @@ def iterator (self, low0, high0, low1, high1, state):
class ConstantRandomMask (cs.Mask):
tag = 'randomMask'

def __init__ (self, p):
def __init__ (self, p:float, seed: int = None):
cs.Mask.__init__ (self)
self.p = p
self.state = random.getstate ()
Expand Down Expand Up @@ -151,6 +152,7 @@ def startIteration (self, state):
obj.perTarget = numpy.random.multinomial (obj.N, [1.0 / N1] * N1)
return obj

class BaseRandomMask(cs.Mask):
def iterator (self, low0, high0, low1, high1, state):
m = self.mask.set1.count (0, low1)

Expand Down Expand Up @@ -243,7 +245,7 @@ def startIteration (self, state):
else:
seed = 'FanInRandomMask'
# Numpy.random.seed requires an unsigned 32 bit integer
numpy.random.seed (hash (seed) % (numpy.iinfo(numpy.uint32).max + 1))
numpy.random.seed (int(hashlib.md5(seed.encode()).hexdigest(), 16) % (numpy.iinfo(numpy.uint32.max + 1)))

selected = state['selected']
obj.mask = partitions[selected]
Expand Down