Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ __pycache__
htmlcov/
coverage.xml
.coverage

/notebooks/*.html
temp
204 changes: 106 additions & 98 deletions TCT/TCT.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,8 @@ def visulization_one_hop_ranking_input_as_list(result_ranked_by_primary_infores,
output_png1="NE_heatmap1.png",
output_png2="NE_heatmap2.png"
):
# edited Dec 5, 2023

# if result_parsed is empty, print a message and return an empty dataframe
predicates_list = []
primary_infore_list = []
aggregator_infore_list = []
Expand Down Expand Up @@ -745,106 +746,109 @@ def visulization_one_hop_ranking(result_ranked_by_primary_infores,result_parsed
output_png2="NE_heatmap2.png"
):
# edited Dec 5, 2023
predicates_list = []
primary_infore_list = []
aggregator_infore_list = []

for i in range(0, result_ranked_by_primary_infores.shape[0]):
oupput_node = result_ranked_by_primary_infores['output_node'][i]
type_of_node = result_ranked_by_primary_infores['type_of_nodes'][i]
if type_of_node == 'object':
subject = input_query
object = oupput_node
else:
subject = oupput_node
object = input_query

predicates_list = predicates_list + result_parsed[subject + "_" + object]['predicate']
primary_infore_list = primary_infore_list + result_parsed[subject + "_" + object]['primary_knowledge_source']

if 'aggregator_knowledge_source' in result_parsed[subject + "_" + object]:
aggregator_infore_list = aggregator_infore_list + result_parsed[subject + "_" + object]['aggregator_knowledge_source']
aggregator_infore_list = list(set(aggregator_infore_list))

predicates_list = list(set(predicates_list))
primary_infore_list = list(set(primary_infore_list))
if result_parsed == {}:
print("No results found in result_parsed. Please check your input data.")
return pd.DataFrame() # Return an empty DataFrame if there are no results

else:
predicates_list = []
primary_infore_list = []
aggregator_infore_list = []

for i in range(0, result_ranked_by_primary_infores.shape[0]):
oupput_node = result_ranked_by_primary_infores['output_node'][i]
type_of_node = result_ranked_by_primary_infores['type_of_nodes'][i]
if type_of_node == 'object':
subject = input_query
object = oupput_node
else:
subject = oupput_node
object = input_query

predicates_list = predicates_list + result_parsed[subject + "_" + object]['predicate']
primary_infore_list = primary_infore_list + result_parsed[subject + "_" + object]['primary_knowledge_source']

predicates_by_nodes = {}
for predict in predicates_list:
predicates_by_nodes[predict] = []
if 'aggregator_knowledge_source' in result_parsed[subject + "_" + object]:
aggregator_infore_list = aggregator_infore_list + result_parsed[subject + "_" + object]['aggregator_knowledge_source']
aggregator_infore_list = list(set(aggregator_infore_list))

primary_infore_by_nodes = {}
for predict in primary_infore_list:
primary_infore_by_nodes[predict] = []
predicates_list = list(set(predicates_list))
primary_infore_list = list(set(primary_infore_list))

aggregator_infore_by_nodes = {}
for predict in aggregator_infore_list:
aggregator_infore_by_nodes[predict] = []

names = []
for i in range(0, result_ranked_by_primary_infores.shape[0]):
#for i in range(0, 10):
oupput_node = result_ranked_by_primary_infores['output_node'].values[i]
names.append(oupput_node)
type_of_node = result_ranked_by_primary_infores['type_of_nodes'].values[i]
if type_of_node == 'object':
subject = input_query
object = oupput_node
else:
subject = oupput_node
object = input_query
new_id = subject + "_" + object
predicates_by_nodes = {}
for predict in predicates_list:
predicates_by_nodes[predict] = []

cur_primary_infore = result_parsed[new_id]['primary_knowledge_source']
primary_infore_by_nodes = {}
for predict in primary_infore_list:
if predict in cur_primary_infore:
primary_infore_by_nodes[predict].append(1)
primary_infore_by_nodes[predict] = []

aggregator_infore_by_nodes = {}
for predict in aggregator_infore_list:
aggregator_infore_by_nodes[predict] = []

names = []
for i in range(0, result_ranked_by_primary_infores.shape[0]):
#for i in range(0, 10):
oupput_node = result_ranked_by_primary_infores['output_node'].values[i]
names.append(oupput_node)
type_of_node = result_ranked_by_primary_infores['type_of_nodes'].values[i]
if type_of_node == 'object':
subject = input_query
object = oupput_node
else:
primary_infore_by_nodes[predict].append(0)
subject = oupput_node
object = input_query
new_id = subject + "_" + object

cur_primary_infore = result_parsed[new_id]['primary_knowledge_source']
for predict in primary_infore_list:
if predict in cur_primary_infore:
primary_infore_by_nodes[predict].append(1)
else:
primary_infore_by_nodes[predict].append(0)



cur_predicates = result_parsed[new_id]['predicate']
for predict in predicates_list:
if predict in cur_predicates:
predicates_by_nodes[predict].append(1)
else:
predicates_by_nodes[predict].append(0)

#convert = False

#for item in colnames:
# if 'NCBIGene' in item:
# convert = True
#if convert:
#Gene_id_map = Gene_id_converter(colnames, "http://127.0.0.1:8000/query_name_by_id") # option 1
#Gene_id_map = Generate_Gene_id_map() # option 2
cur_predicates = result_parsed[new_id]['predicate']
for predict in predicates_list:
if predict in cur_predicates:
predicates_by_nodes[predict].append(1)
else:
predicates_by_nodes[predict].append(0)

dic_id_map = ID_convert_to_preferred_name_nodeNormalizer(names)
new_colnames = []
for item in names:
if item in dic_id_map:
new_colnames.append(dic_id_map[item])
else:
new_colnames.append(item)
dic_id_map = ID_convert_to_preferred_name_nodeNormalizer(names)
new_colnames = []
for item in names:
if item in dic_id_map:
new_colnames.append(dic_id_map[item])
else:
new_colnames.append(item)

#else:
# new_colnames = colnames

primary_infore_by_nodes_df = pd.DataFrame(primary_infore_by_nodes)
primary_infore_by_nodes_df.index = new_colnames
primary_infore_by_nodes_df = primary_infore_by_nodes_df.T
primary_infore_by_nodes_df = pd.DataFrame(primary_infore_by_nodes)
primary_infore_by_nodes_df.index = new_colnames
primary_infore_by_nodes_df = primary_infore_by_nodes_df.T


predicates_by_nodes_df = pd.DataFrame(predicates_by_nodes)
predicates_by_nodes_df.index = new_colnames
predicates_by_nodes_df = predicates_by_nodes_df.T
predicates_by_nodes_df = pd.DataFrame(predicates_by_nodes)
predicates_by_nodes_df.index = new_colnames
predicates_by_nodes_df = predicates_by_nodes_df.T

plot_heatmap(primary_infore_by_nodes_df, num_of_nodes, fontsize, title_fontsize,output_png1)
plot_heatmap(predicates_by_nodes_df, num_of_nodes, fontsize, title_fontsize,output_png2)
if not primary_infore_by_nodes_df.empty:
plot_heatmap(primary_infore_by_nodes_df, num_of_nodes, fontsize, title_fontsize, output_png1)
else:
print("No primary infores found in primary_infore_by_nodes_df.")

return(predicates_by_nodes_df)
if not predicates_by_nodes_df.empty:
plot_heatmap(predicates_by_nodes_df, num_of_nodes, fontsize, title_fontsize, output_png2)
else:
print("No predicates found in predicates_by_nodes_df.")
return pd.DataFrame() # Return empty if there's no predicate data to plot

return(predicates_by_nodes_df)

def plot_heatmap(predicates_by_nodes_df,num_of_nodes = 20,
fontsize = 6,
Expand All @@ -862,22 +866,26 @@ def plot_heatmap(predicates_by_nodes_df,num_of_nodes = 20,

# create the heatmap
# heatmap with border
p1 = sns.heatmap(df, cmap="Blues", cbar=False, ax=ax, linecolor='grey', linewidth=0.2)
# Adjust font size for x and y tick labels
p1.set_xticklabels(p1.get_xticklabels(), rotation=90, fontsize=fontsize)
p1.set_yticklabels(p1.get_yticklabels(), fontsize=fontsize)

#p1.set_title(title)
#p1.set_ylabel(ylab)
print(p1.get_xticklabels())
# set xticklabels with colnames

#p1.set_xticklabels(colnames, rotation=90, fontsize = fontsize)
plt.xticks(ticks=range(len(df.columns)), labels=df.columns)

# set title font size
p1.title.set_size(title_fontsize)
plt.show()
if df.empty:
print("No data to plot in the heatmap. Please check your input data.")
return()
else:
p1 = sns.heatmap(df, cmap="Blues", cbar=False, ax=ax, linecolor='grey', linewidth=0.2)
# Adjust font size for x and y tick labels
p1.set_xticklabels(p1.get_xticklabels(), rotation=90, fontsize=fontsize)
p1.set_yticklabels(p1.get_yticklabels(), fontsize=fontsize)

#p1.set_title(title)
#p1.set_ylabel(ylab)
print(p1.get_xticklabels())
# set xticklabels with colnames

#p1.set_xticklabels(colnames, rotation=90, fontsize = fontsize)
plt.xticks(ticks=range(len(df.columns)), labels=df.columns)

# set title font size
p1.title.set_size(title_fontsize)
plt.show()
# save the figure
#plt.savefig(output_png, bbox_inches='tight', dpi=300)

Expand Down
Loading
Loading