From 641fd87d367be0dd2b14f689db0aa642ff2528f0 Mon Sep 17 00:00:00 2001 From: Martyna Plomecka Date: Sun, 18 May 2025 18:39:25 -0400 Subject: [PATCH] Keep removed subject IDs --- pipeline_sindy.py | 23 +++++++++++------ resources/sindy_training.py | 44 ++++++++++++++++++++++++-------- resources/sindy_utils.py | 49 ++++++++++++++++-------------------- tests/test_pipeline_sindy.py | 19 +++++--------- 4 files changed, 77 insertions(+), 58 deletions(-) diff --git a/pipeline_sindy.py b/pipeline_sindy.py index 2372623c..7bc312cb 100644 --- a/pipeline_sindy.py +++ b/pipeline_sindy.py @@ -19,7 +19,7 @@ def main( model: str = None, data: str = None, - save: bool = False, + save: bool = False, #just change and save the parameters to not rerun it # generated dataset parameters participant_id: int = None, @@ -164,8 +164,8 @@ def main( if use_optuna and verbose: print("\nUsing Optuna to find optimal optimizer configuration for each participant") - # setup the SINDy-agent - agent_spice, loss_spice = fit_spice( + # Setup the SINDy-agent + agent_spice, filtered_participant_ids, loss_spice = fit_spice( rnn_modules=list_rnn_modules, control_signals=list_control_parameters, agent=agent_rnn, @@ -185,12 +185,16 @@ def main( verbose=verbose, use_optuna=use_optuna, filter_bad_participants=filter_bad_participants, - ) + ) + + # Update participant_ids with the filtered ones if filtering was applied + if filter_bad_participants: + participant_ids = filtered_participant_ids # If agent_spice is None, we couldn't fit the model, so return early if len(participant_ids) == 0: print("ERROR: Failed to fit SPICE model. Returning None.") - return None, None, None + return None, None, None, [] # save spice modules if save: @@ -202,6 +206,11 @@ def main( file_spice = os.path.join(*file_spice) save_spice(agent_spice=agent_spice, file=file_spice) print("Saved SPICE parameters to file " + file_spice) + + # save the filtered participant IDs ---- + ids_file = file_spice.replace('.pkl', '_participant_ids.npy') + np.save(ids_file, np.array(participant_ids, dtype=int)) + print("Saved filtered participant IDs to file " + ids_file) # --------------------------------------------------------------------------------------------------- # Analysis @@ -271,7 +280,7 @@ def main( features['beta_reward'][pid] = betas['x_value_reward'] features['beta_choice'][pid] = betas['x_value_choice'] - return agent_spice, features, loss_spice + return agent_spice, features, loss_spice, participant_ids if __name__=='__main__': @@ -296,7 +305,7 @@ def main( args = parser.parse_args() - agent_spice, features, loss = main( + agent_spice, features, loss, participant_ids = main( model=args.model, data=args.data, save=args.save, diff --git a/resources/sindy_training.py b/resources/sindy_training.py index cd382d15..54b04c40 100644 --- a/resources/sindy_training.py +++ b/resources/sindy_training.py @@ -166,9 +166,10 @@ def fit_spice( get_loss (bool, optional): Whether to compute loss. Defaults to False. verbose (bool, optional): Whether to print verbose output. Defaults to False. use_optuna (bool, optional): Whether to use Optuna for optimizer selection. Defaults to False. + filter_bad_participants (bool, optional): Whether to filter out badly fitted participants. Defaults to False. Returns: - Tuple[AgentSpice, float]: The SPICE agent and its loss + Tuple[AgentSpice, List[int], float]: The SPICE agent, list of well-fitted participant IDs, and loss """ if participant_id is not None: @@ -286,16 +287,19 @@ def fit_spice( # set up a SINDy-based agent by replacing the RNN-modules with the respective SINDy-model agent_spice = AgentSpice(model_rnn=deepcopy(agent._model), sindy_modules=sindy_models, n_actions=agent._n_actions, deterministic=deterministic) - # remove badly fitted participants - if filter_bad_participants: - agent_spice, participant_ids = remove_bad_participants( + # Initialize filtered_ids with all participant_ids + filtered_ids = np.array(participant_ids) + + # Filter badly fitted participants if requested + if filter_bad_participants and data is not None: + agent_spice, filtered_ids = remove_bad_participants( agent_spice=agent_spice, agent_rnn=agent, dataset=data, participant_ids=participant_ids, verbose=verbose, ) - + # compute loss loss = None if get_loss and data is None: @@ -305,10 +309,28 @@ def fit_spice( n_trials_total = 0 mapping_modules_values = {module: 'x_value_choice' if 'choice' in module else 'x_value_reward' for module in agent_spice._model.submodules_sindy} n_parameters = agent_spice.count_parameters(mapping_modules_values=mapping_modules_values) - for pid in participant_ids: - xs, ys = data.xs.cpu().numpy(), data.ys.cpu().numpy() - probs = get_update_dynamics(experiment=xs[pid], agent=agent_spice)[1] - loss += loss_metric(data=ys[pid, :len(probs)], probs=probs, n_parameters=n_parameters[pid]) + + # Use filtered_ids for loss calculation if filtering was applied + ids_to_use = filtered_ids if filter_bad_participants else participant_ids + + for pid in ids_to_use: + if pid not in agent_spice._model.submodules_sindy[list(agent_spice._model.submodules_sindy.keys())[0]]: + continue + + mask_participant_id = data.xs[:, 0, -1] == pid + if not mask_participant_id.any(): + continue + + participant_data = DatasetRNN(*data[mask_participant_id]) + xs, ys = participant_data.xs.cpu().numpy(), participant_data.ys.cpu().numpy() + + probs = get_update_dynamics(experiment=xs, agent=agent_spice)[1] + loss += loss_metric(data=ys[0, :len(probs)], probs=probs, n_parameters=n_parameters[pid]) n_trials_total += len(probs) - loss = loss/n_trials_total - return agent_spice, loss \ No newline at end of file + + if n_trials_total > 0: + loss = loss / n_trials_total + else: + loss = float('inf') # If no valid trials, set loss to infinity + + return agent_spice, filtered_ids, loss \ No newline at end of file diff --git a/resources/sindy_utils.py b/resources/sindy_utils.py index 107f63c9..bf1ee020 100644 --- a/resources/sindy_utils.py +++ b/resources/sindy_utils.py @@ -289,25 +289,24 @@ def remove_bad_participants(agent_spice: AgentSpice, agent_rnn: AgentNetwork, da """Check for badly fitted participants in the SPICE models w.r.t. the SPICE-RNN and return only the IDs of the well-fitted participants. Args: - agent_spice (AgentSpice): _description_ - agent_rnn (AgentNetwork): _description_ - dataset_test (DatasetRNN): _description_ - participant_ids (Iterable[int]): _description_ - verbose (bool, optional): _description_. Defaults to False. + agent_spice (AgentSpice): SPICE agent to filter + agent_rnn (AgentNetwork): Reference RNN agent + dataset (DatasetRNN): Dataset to evaluate on + participant_ids (Iterable[int]): Participant IDs to check + trial_likelihood_difference_threshold (float, optional): Threshold for filtering. Defaults to 0.05. + verbose (bool, optional): Whether to print verbose output. Defaults to False. Returns: AgentSpice: SPICE agent without badly fitted participants Iterable[int]: List of well-fitted participants """ - # if verbose: print("\nFiltering badly fitted SPICE models...") - filtered_participant_ids = [] + # Convert participant_ids to a list of integers + participant_ids = list(map(int, participant_ids)) + removed_participants = [] + good_participants = [] - # Create a copy of the valid participant IDs - valid_participant_ids = list(participant_ids) - - removed_pids = [] for pid in tqdm(participant_ids): # Skip if participant is not in the SPICE model if pid not in agent_spice._model.submodules_sindy[list(agent_spice._model.submodules_sindy.keys())[0]]: @@ -337,32 +336,28 @@ def remove_bad_participants(agent_spice: AgentSpice, agent_rnn: AgentNetwork, da spice_per_action_likelihood = np.exp(ll_spice/(n_trials_test*agent_rnn._n_actions)) rnn_per_action_likelihood = np.exp(ll_rnn/(n_trials_test*agent_rnn._n_actions)) - # Idea for filter criteria: - # If accuracy is very low for SPICE (near chance) but not so low for RNN then bad SPICE fitting (at least a bit higher than chance) - # TODO: Check for better filter criteria + # Filter out participants where SPICE performs much worse than RNN if rnn_per_action_likelihood - spice_per_action_likelihood > trial_likelihood_difference_threshold: - if verbose: - print(f'SPICE trial likelihood ({spice_per_action_likelihood:.2f}) is unplausibly low compared to RNN trial likelihood ({rnn_per_action_likelihood:.2f}).') - print(f'SPICE optimizer may be badly parameterized. Skipping participant {pid}.') + print(f'SPICE trial likelihood ({spice_per_action_likelihood:.2f}) is unplausibly low compared to RNN trial likelihood ({rnn_per_action_likelihood:.2f}).') + print(f'SPICE optimizer may be badly parameterized. Skipping participant {pid}.') # Remove this participant from the SPICE model for module in agent_spice._model.submodules_sindy: if pid in agent_spice._model.submodules_sindy[module]: del agent_spice._model.submodules_sindy[module][pid] - # Remove from valid participant IDs - if pid in valid_participant_ids: - valid_participant_ids.remove(pid) - removed_pids.append(pid) + removed_participants.append(np.int32(pid)) else: - # Keep track of filtered (good) participants - filtered_participant_ids.append(pid) + good_participants.append(np.int32(pid)) + + # Convert to numpy arrays for consistency + good_participants = np.array(good_participants) + removed_participants = np.array(removed_participants) - if verbose: - print(f"\nAfter filtering: {len(valid_participant_ids)} of {len(participant_ids)} participants have well-fitted SPICE models.") - print(f"Removed participants: {removed_pids}") + print(f"After filtering: {len(good_participants)} of {len(participant_ids)} participants have well-fitted SPICE models.") + print(f"Removed participants: {removed_participants}") - return agent_spice, np.array(valid_participant_ids) + return agent_spice, good_participants def save_spice(agent_spice: AgentSpice, file: str): diff --git a/tests/test_pipeline_sindy.py b/tests/test_pipeline_sindy.py index 32b90d39..816c6fc3 100644 --- a/tests/test_pipeline_sindy.py +++ b/tests/test_pipeline_sindy.py @@ -4,17 +4,10 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) import pipeline_sindy - -agent_spice, features, loss = pipeline_sindy.main( - - # data='data/parameter_recovery/data_32p_0.csv', - # model='params/parameter_recovery/params_32p_0.pkl', - - model = 'params/eckstein2022/rnn_eckstein2022_l1_0_0001_l2_0_0001.pkl', - data = 'data/eckstein2022/eckstein2022.csv', - save = True, - - # general recovery parameters +agent_spice, features, loss, participant_ids = pipeline_sindy.main( + model='params/eckstein2022/rnn_eckstein2022_l1_0_0001_l2_0_0001.pkl', + data='data/eckstein2022/eckstein2022.csv', + save=True, participant_id=None, filter_bad_participants=True, use_optuna=False, @@ -39,9 +32,9 @@ alpha_choice=1., counterfactual=False, alpha_counterfactual=0., - analysis=True, get_loss=True, ) -print(loss) \ No newline at end of file +print(f"Final loss: {loss:.4f}") +print("Kept participant IDs:", participant_ids) \ No newline at end of file