Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions pipeline_sindy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
def main(
model: str = None,
data: str = None,
save: bool = False,
save: bool = False, #just change and save the parameters to not rerun it

# generated dataset parameters
participant_id: int = None,
Expand Down Expand Up @@ -164,8 +164,8 @@ def main(
if use_optuna and verbose:
print("\nUsing Optuna to find optimal optimizer configuration for each participant")

# setup the SINDy-agent
agent_spice, loss_spice = fit_spice(
# Setup the SINDy-agent
agent_spice, filtered_participant_ids, loss_spice = fit_spice(
rnn_modules=list_rnn_modules,
control_signals=list_control_parameters,
agent=agent_rnn,
Expand All @@ -185,12 +185,16 @@ def main(
verbose=verbose,
use_optuna=use_optuna,
filter_bad_participants=filter_bad_participants,
)
)

# Update participant_ids with the filtered ones if filtering was applied
if filter_bad_participants:
participant_ids = filtered_participant_ids

# If no participants remain (e.g. all were filtered out), the SPICE model could not be fit, so return early
if len(participant_ids) == 0:
print("ERROR: Failed to fit SPICE model. Returning None.")
return None, None, None
return None, None, None, []

# save spice modules
if save:
Expand All @@ -202,6 +206,11 @@ def main(
file_spice = os.path.join(*file_spice)
save_spice(agent_spice=agent_spice, file=file_spice)
print("Saved SPICE parameters to file " + file_spice)

# save the kept participant IDs (filtered only if filter_bad_participants is set) next to the SPICE parameters
ids_file = file_spice.replace('.pkl', '_participant_ids.npy')
np.save(ids_file, np.array(participant_ids, dtype=int))
print("Saved filtered participant IDs to file " + ids_file)

# ---------------------------------------------------------------------------------------------------
# Analysis
Expand Down Expand Up @@ -271,7 +280,7 @@ def main(
features['beta_reward'][pid] = betas['x_value_reward']
features['beta_choice'][pid] = betas['x_value_choice']

return agent_spice, features, loss_spice
return agent_spice, features, loss_spice, participant_ids


if __name__=='__main__':
Expand All @@ -296,7 +305,7 @@ def main(

args = parser.parse_args()

agent_spice, features, loss = main(
agent_spice, features, loss, participant_ids = main(
model=args.model,
data=args.data,
save=args.save,
Expand Down
44 changes: 33 additions & 11 deletions resources/sindy_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,10 @@ def fit_spice(
get_loss (bool, optional): Whether to compute loss. Defaults to False.
verbose (bool, optional): Whether to print verbose output. Defaults to False.
use_optuna (bool, optional): Whether to use Optuna for optimizer selection. Defaults to False.
filter_bad_participants (bool, optional): Whether to filter out badly fitted participants. Defaults to False.

Returns:
Tuple[AgentSpice, float]: The SPICE agent and its loss
Tuple[AgentSpice, List[int], float]: The SPICE agent, list of well-fitted participant IDs, and loss
"""

if participant_id is not None:
Expand Down Expand Up @@ -286,16 +287,19 @@ def fit_spice(
# set up a SINDy-based agent by replacing the RNN-modules with the respective SINDy-model
agent_spice = AgentSpice(model_rnn=deepcopy(agent._model), sindy_modules=sindy_models, n_actions=agent._n_actions, deterministic=deterministic)

# remove badly fitted participants
if filter_bad_participants:
agent_spice, participant_ids = remove_bad_participants(
# Initialize filtered_ids with all participant_ids
filtered_ids = np.array(participant_ids)

# Filter badly fitted participants if requested
if filter_bad_participants and data is not None:
agent_spice, filtered_ids = remove_bad_participants(
agent_spice=agent_spice,
agent_rnn=agent,
dataset=data,
participant_ids=participant_ids,
verbose=verbose,
)

# compute loss
loss = None
if get_loss and data is None:
Expand All @@ -305,10 +309,28 @@ def fit_spice(
n_trials_total = 0
mapping_modules_values = {module: 'x_value_choice' if 'choice' in module else 'x_value_reward' for module in agent_spice._model.submodules_sindy}
n_parameters = agent_spice.count_parameters(mapping_modules_values=mapping_modules_values)
for pid in participant_ids:
xs, ys = data.xs.cpu().numpy(), data.ys.cpu().numpy()
probs = get_update_dynamics(experiment=xs[pid], agent=agent_spice)[1]
loss += loss_metric(data=ys[pid, :len(probs)], probs=probs, n_parameters=n_parameters[pid])

# Use filtered_ids for loss calculation if filtering was applied
ids_to_use = filtered_ids if filter_bad_participants else participant_ids

for pid in ids_to_use:
if pid not in agent_spice._model.submodules_sindy[list(agent_spice._model.submodules_sindy.keys())[0]]:
continue

mask_participant_id = data.xs[:, 0, -1] == pid
if not mask_participant_id.any():
continue

participant_data = DatasetRNN(*data[mask_participant_id])
xs, ys = participant_data.xs.cpu().numpy(), participant_data.ys.cpu().numpy()

probs = get_update_dynamics(experiment=xs, agent=agent_spice)[1]
loss += loss_metric(data=ys[0, :len(probs)], probs=probs, n_parameters=n_parameters[pid])
n_trials_total += len(probs)
loss = loss/n_trials_total
return agent_spice, loss

if n_trials_total > 0:
loss = loss / n_trials_total
else:
loss = float('inf') # If no valid trials, set loss to infinity

return agent_spice, filtered_ids, loss
49 changes: 22 additions & 27 deletions resources/sindy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,25 +289,24 @@ def remove_bad_participants(agent_spice: AgentSpice, agent_rnn: AgentNetwork, da
"""Check for participants whose SPICE model is badly fitted w.r.t. the SPICE-RNN, remove them from the SPICE agent, and return the agent together with the IDs of the well-fitted participants.

Args:
agent_spice (AgentSpice): _description_
agent_rnn (AgentNetwork): _description_
dataset_test (DatasetRNN): _description_
participant_ids (Iterable[int]): _description_
verbose (bool, optional): _description_. Defaults to False.
agent_spice (AgentSpice): SPICE agent to filter
agent_rnn (AgentNetwork): Reference RNN agent
dataset (DatasetRNN): Dataset to evaluate on
participant_ids (Iterable[int]): Participant IDs to check
trial_likelihood_difference_threshold (float, optional): Threshold for filtering. Defaults to 0.05.
verbose (bool, optional): Whether to print verbose output. Defaults to False.

Returns:
AgentSpice: SPICE agent without badly fitted participants
Iterable[int]: List of well-fitted participants
"""
# if verbose:
print("\nFiltering badly fitted SPICE models...")

filtered_participant_ids = []
# Convert participant_ids to a list of integers
participant_ids = list(map(int, participant_ids))
removed_participants = []
good_participants = []

# Create a copy of the valid participant IDs
valid_participant_ids = list(participant_ids)

removed_pids = []
for pid in tqdm(participant_ids):
# Skip if participant is not in the SPICE model
if pid not in agent_spice._model.submodules_sindy[list(agent_spice._model.submodules_sindy.keys())[0]]:
Expand Down Expand Up @@ -337,32 +336,28 @@ def remove_bad_participants(agent_spice: AgentSpice, agent_rnn: AgentNetwork, da
spice_per_action_likelihood = np.exp(ll_spice/(n_trials_test*agent_rnn._n_actions))
rnn_per_action_likelihood = np.exp(ll_rnn/(n_trials_test*agent_rnn._n_actions))

# Idea for filter criteria:
# If accuracy is very low for SPICE (near chance) but not so low for RNN then bad SPICE fitting (at least a bit higher than chance)
# TODO: Check for better filter criteria
# Filter out participants where SPICE performs much worse than RNN
if rnn_per_action_likelihood - spice_per_action_likelihood > trial_likelihood_difference_threshold:
if verbose:
print(f'SPICE trial likelihood ({spice_per_action_likelihood:.2f}) is unplausibly low compared to RNN trial likelihood ({rnn_per_action_likelihood:.2f}).')
print(f'SPICE optimizer may be badly parameterized. Skipping participant {pid}.')
print(f'SPICE trial likelihood ({spice_per_action_likelihood:.2f}) is unplausibly low compared to RNN trial likelihood ({rnn_per_action_likelihood:.2f}).')
print(f'SPICE optimizer may be badly parameterized. Skipping participant {pid}.')

# Remove this participant from the SPICE model
for module in agent_spice._model.submodules_sindy:
if pid in agent_spice._model.submodules_sindy[module]:
del agent_spice._model.submodules_sindy[module][pid]

# Remove from valid participant IDs
if pid in valid_participant_ids:
valid_participant_ids.remove(pid)
removed_pids.append(pid)
removed_participants.append(np.int32(pid))
else:
# Keep track of filtered (good) participants
filtered_participant_ids.append(pid)
good_participants.append(np.int32(pid))

# Convert to numpy arrays for consistency
good_participants = np.array(good_participants)
removed_participants = np.array(removed_participants)

if verbose:
print(f"\nAfter filtering: {len(valid_participant_ids)} of {len(participant_ids)} participants have well-fitted SPICE models.")
print(f"Removed participants: {removed_pids}")
print(f"After filtering: {len(good_participants)} of {len(participant_ids)} participants have well-fitted SPICE models.")
print(f"Removed participants: {removed_participants}")

return agent_spice, np.array(valid_participant_ids)
return agent_spice, good_participants


def save_spice(agent_spice: AgentSpice, file: str):
Expand Down
19 changes: 6 additions & 13 deletions tests/test_pipeline_sindy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,10 @@
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import pipeline_sindy


agent_spice, features, loss = pipeline_sindy.main(

# data='data/parameter_recovery/data_32p_0.csv',
# model='params/parameter_recovery/params_32p_0.pkl',

model = 'params/eckstein2022/rnn_eckstein2022_l1_0_0001_l2_0_0001.pkl',
data = 'data/eckstein2022/eckstein2022.csv',
save = True,

# general recovery parameters
agent_spice, features, loss, participant_ids = pipeline_sindy.main(
model='params/eckstein2022/rnn_eckstein2022_l1_0_0001_l2_0_0001.pkl',
data='data/eckstein2022/eckstein2022.csv',
save=True,
participant_id=None,
filter_bad_participants=True,
use_optuna=False,
Expand All @@ -39,9 +32,9 @@
alpha_choice=1.,
counterfactual=False,
alpha_counterfactual=0.,

analysis=True,
get_loss=True,
)

print(loss)
print(f"Final loss: {loss:.4f}")
print("Kept participant IDs:", participant_ids)