1+ import pprint
12from enum import Enum
2- from functools import reduce
33from random import choices
44from typing import Dict
55import ast
66import operator as op
7+ import numpy as np
78
89from pix_framework .statistics .distribution import DurationDistribution
910
1011from prosimos .exceptions import InvalidEventAttributeException
12+ import math
1113
1214
1315class EVENT_ATTR_TYPE (Enum ):
1416 DISCRETE = "discrete"
1517 CONTINUOUS = "continuous"
1618 EXPRESSION = "expression"
19+ DTREE = "dtree"
1720
1821
1922operators = {
@@ -24,18 +27,67 @@ class EVENT_ATTR_TYPE(Enum):
2427 ast .FloorDiv : op .floordiv
2528}
2629
30+ def fix (mean ):
31+ return DurationDistribution (name = "fix" , mean = mean , var = 0.0 , std = 0.0 , minimum = mean , maximum = mean ).generate_sample (1 )[0 ]
2732
28- def parse_discrete_value (value_info_arr ):
29- prob_arr = []
30- options_arr = []
31- for item in value_info_arr :
32- options_arr .append (item ["key" ])
33- prob_arr .append (float (item ["value" ]))
3433
35- return {
36- "options" : options_arr ,
37- "probabilities" : prob_arr
38- }
34+ def uniform (minimum , maximum ):
35+ return DurationDistribution (name = "uniform" , minimum = minimum , maximum = maximum ).generate_sample (1 )[0 ]
36+
37+
38+ def norm (mean , std , minimum = None , maximum = None ):
39+ return DurationDistribution (name = "norm" , mean = mean , std = std , minimum = minimum , maximum = maximum ).generate_sample (1 )[0 ]
40+
41+
42+ def triang (c , minimum , maximum ):
43+ return DurationDistribution (name = "triang" , mean = c , minimum = minimum , maximum = maximum ).generate_sample (1 )[0 ]
44+
45+
46+ def expon (mean , minimum = None , maximum = None ):
47+ return DurationDistribution (name = "expon" , mean = mean , minimum = minimum , maximum = maximum ).generate_sample (1 )[0 ]
48+
49+
50+ def lognorm (mean , var , minimum = None , maximum = None ):
51+ return DurationDistribution (name = "lognorm" , mean = mean , var = var , minimum = minimum , maximum = maximum ).generate_sample (1 )[0 ]
52+
53+
54+ def gamma (mean , var , minimum = None , maximum = None ):
55+ return DurationDistribution (name = "gamma" , mean = mean , var = var , minimum = minimum , maximum = maximum ).generate_sample (1 )[0 ]
56+
57+
58+ distributions = {
59+ 'fix' : fix ,
60+ 'uniform' : uniform ,
61+ 'norm' : norm ,
62+ 'triang' : triang ,
63+ 'expon' : expon ,
64+ 'lognorm' : lognorm ,
65+ 'gamma' : gamma
66+ }
67+
68+ math_functions = {name : getattr (math , name ) for name in dir (math ) if callable (getattr (math , name ))}
69+
70+
71+ def parse_discrete_value (value_info ):
72+ if isinstance (value_info , list ):
73+ prob_arr = []
74+ options_arr = []
75+ for item in value_info :
76+ options_arr .append (item ["key" ])
77+ prob_arr .append (float (item ["value" ]))
78+
79+ return {
80+ "type" : "discrete" ,
81+ "options" : options_arr ,
82+ "probabilities" : prob_arr
83+ }
84+ elif isinstance (value_info , dict ):
85+ return {
86+ "type" : "markov" ,
87+ "transitions" : value_info
88+ }
89+ else :
90+ raise ValueError ("Unsupported value_info format for discrete value" )
3991
4092
4193def parse_continuous_value (value_info ) -> "DurationDistribution" :
@@ -57,7 +109,8 @@ def _eval(node):
57109 elif isinstance (node , ast .UnaryOp ):
58110 return operators [type (node .op )](_eval (node .operand ))
59111 elif isinstance (node , ast .Name ):
60- return vars_dict [node .id ]
112+ if node .id in vars_dict :
113+ return vars_dict [node .id ]
61114 elif isinstance (node , ast .Str ):
62115 return node .s
63116 elif isinstance (node , ast .Compare ):
@@ -67,14 +120,33 @@ def _eval(node):
67120 return all (_eval (value ) for value in node .values )
68121 elif type (node .op ) is ast .Or :
69122 return any (_eval (value ) for value in node .values )
123+ elif isinstance (node , ast .Call ):
124+ func_name = node .func .id
125+ args = [_eval (arg ) for arg in node .args ]
126+ if func_name in distributions :
127+ return distributions [func_name ](* args )
128+ elif func_name in math_functions :
129+ args = [_eval (arg ) for arg in node .args ]
130+ try :
131+ return math_functions [func_name ](* args )
132+ except OverflowError :
133+ return np .finfo (np .float32 ).max
134+ except ValueError as e :
135+ return 0
70136 else :
71- return None
137+ return 0
72138 except (SyntaxError , ZeroDivisionError , TypeError , KeyError ):
73- return None
139+ return 0
74140
75141 return _eval (tree .body )
76142
77143
144+ def evaluate_dtree (dtree , vars_dict ):
145+ for conditions , formula in dtree :
146+ if conditions is True or all (eval_expr (cond , vars_dict ) for cond in conditions ):
147+ return eval_expr (formula , vars_dict )
148+ return None
149+
78150class EventAttribute :
79151 def __init__ (self , event_id , name , event_attr_type , value ):
80152 self .event_id : str = event_id
@@ -87,27 +159,76 @@ def __init__(self, event_id, name, event_attr_type, value):
87159 self .value = parse_continuous_value (value )
88160 elif self .event_attr_type == EVENT_ATTR_TYPE .EXPRESSION :
89161 self .value = value
162+ elif self .event_attr_type == EVENT_ATTR_TYPE .DTREE :
163+ self .value = value
90164 else :
91165 raise Exception (f"Not supported event attribute { type } " )
92166
93167 self .validate ()
94168
95- def get_next_value (self , all_attributes = {} ):
169+ def get_next_value (self , all_attributes ):
96170 if self .event_attr_type == EVENT_ATTR_TYPE .DISCRETE :
97- one_choice_arr = choices (self .value ["options" ], self .value ["probabilities" ])
98- return one_choice_arr [0 ]
171+ if self .value ["type" ] == "markov" :
172+ current_value = all_attributes .get (self .name , None )
173+ next_state = self .get_next_markov_state (current_value )
174+ if next_state is not None :
175+ all_attributes [self .name ] = next_state
176+ return next_state
177+ return current_value
178+ else :
179+ one_choice_arr = choices (self .value ["options" ], self .value ["probabilities" ])
180+ return one_choice_arr [0 ]
181+
99182 elif self .event_attr_type == EVENT_ATTR_TYPE .EXPRESSION :
100- return eval_expr (self .value , all_attributes )
183+ result = eval_expr (self .value , all_attributes )
184+ if isinstance (result , (int , float , np .number )) and not isinstance (result , bool ):
185+ if result == 0 : # Specifically handle zero without adjusting to tiny (in case of any errors in eval)
186+ return 0
187+ elif result == float ('inf' ):
188+ return np .finfo (np .float32 ).max
189+ elif result == - float ('inf' ):
190+ return np .finfo (np .float32 ).min
191+ elif abs (result ) < np .finfo (np .float32 ).tiny :
192+ return np .finfo (np .float32 ).tiny
193+ elif abs (result ) > np .finfo (np .float32 ).max :
194+ return np .finfo (np .float32 ).max if result > 0 else np .finfo (np .float32 ).min
195+ else :
196+ return result
197+ else :
198+ return result
199+ elif self .event_attr_type == EVENT_ATTR_TYPE .DTREE :
200+ result = evaluate_dtree (self .value , all_attributes )
201+ if result is not None :
202+ return result
203+ return 0
101204 else :
102205 return self .value .generate_sample (1 )[0 ]
103206
207+ def get_next_markov_state (self , current_value ):
208+ transitions = self .value ["transitions" ]
209+ if current_value in transitions :
210+ current_transitions = transitions [current_value ]
211+ options , probabilities = zip (* current_transitions .items ())
212+ return choices (options , probabilities )[0 ]
213+ return current_value
214+
104215 def validate (self ):
105- if self .event_attr_type == EVENT_ATTR_TYPE .DISCRETE :
106- actual_sum_probabilities = reduce (lambda acc , item : acc + item , self .value ["probabilities" ], 0 )
216+ epsilon = 1e-6
107217
108- if actual_sum_probabilities != 1 :
109- raise InvalidEventAttributeException (
110- f"Event attribute ${ self .name } : probabilities' sum should be equal to 1" )
218+ if self .event_attr_type == EVENT_ATTR_TYPE .DISCRETE :
219+ if self .value ["type" ] == "discrete" :
220+ actual_sum_probabilities = sum (self .value ["probabilities" ])
221+
222+ if not (1 - epsilon <= actual_sum_probabilities <= 1 + epsilon ):
223+ raise InvalidEventAttributeException (
224+ f"Event attribute { self .name } : probabilities' sum should be equal to 1" )
225+ elif self .value ["type" ] == "markov" :
226+ for state , transitions in self .value ["transitions" ].items ():
227+ actual_sum_probabilities = sum (transitions .values ())
228+ if not (1 - epsilon <= actual_sum_probabilities <= 1 + epsilon ):
229+ raise InvalidEventAttributeException (
230+ f"Event attribute { self .name } , state { state } : "
231+ f"probabilities' sum should be equal to 1" )
111232
112233 return True
113234
0 commit comments