From 7de8074ec1d4dbab7b8f55bbd81b9c99ad248d56 Mon Sep 17 00:00:00 2001 From: Aniketos07 <51787443+Aniketos07@users.noreply.github.com> Date: Thu, 27 Jan 2022 16:43:33 +0530 Subject: [PATCH] tf.Session deprecated, changed to tf.compat.v1.Session --- python/dpu_utils/tfutils/tfvariablesaver.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/dpu_utils/tfutils/tfvariablesaver.py b/python/dpu_utils/tfutils/tfvariablesaver.py index 9f77cf3..5b7b7be 100644 --- a/python/dpu_utils/tfutils/tfvariablesaver.py +++ b/python/dpu_utils/tfutils/tfvariablesaver.py @@ -10,7 +10,7 @@ class TFVariableSaver: def __init__(self): self.__saved_variables = {} # type: Dict[str, np.ndarray] - def save_all(self, session: tf.Session, exclude_variable: Optional[Callable[[str], bool]]=None) -> None: + def save_all(self, session: tf.compat.v1.Session, exclude_variable: Optional[Callable[[str], bool]]=None) -> None: self.__saved_variables = {} for variable in session.graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES): assert variable.name not in self.__saved_variables @@ -21,7 +21,7 @@ def save_all(self, session: tf.Session, exclude_variable: Optional[Callable[[str def has_saved_variables(self) -> bool: return len(self.__saved_variables) > 0 - def restore_saved_values(self, session: tf.Session) -> None: + def restore_saved_values(self, session: tf.v1.compat.Session) -> None: assert len(self.__saved_variables) > 0 save_ops = [] with tf.name_scope("restore"):