diff --git a/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp b/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp index dd7c9d2e..603967aa 100644 --- a/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp +++ b/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp @@ -169,7 +169,6 @@ struct DuckDBPyConnection : public enable_shared_from_this { //! MemoryFileSystem used to temporarily store file-like objects for reading shared_ptr internal_object_filesystem; case_insensitive_map_t> registered_functions; - case_insensitive_set_t registered_objects; public: explicit DuckDBPyConnection() { diff --git a/src/duckdb_py/include/duckdb_python/python_replacement_scan.hpp b/src/duckdb_py/include/duckdb_python/python_replacement_scan.hpp index 8e329ea7..f655a138 100644 --- a/src/duckdb_py/include/duckdb_python/python_replacement_scan.hpp +++ b/src/duckdb_py/include/duckdb_python/python_replacement_scan.hpp @@ -4,10 +4,25 @@ #include "duckdb/common/case_insensitive_map.hpp" #include "duckdb/parser/tableref.hpp" #include "duckdb/function/replacement_scan.hpp" +#include "duckdb_python/python_dependency.hpp" #include "duckdb_python/pybind11/pybind_wrapper.hpp" namespace duckdb { +class PythonRegisteredObjectState : public ClientContextState { +public: + static constexpr const char *Key = "python_registered_objects"; + + void Register(const string &name, const py::object &object); + void Unregister(const string &name); + py::object Get(const string &name); + bool Contains(const string &name); + +private: + mutex lock; + case_insensitive_map_t> registered_objects; +}; + struct PythonReplacementScan { public: static unique_ptr Replace(ClientContext &context, ReplacementScanInput &input, diff --git a/src/duckdb_py/pyconnection.cpp b/src/duckdb_py/pyconnection.cpp index a2350ef8..ecbb8fd4 100644 --- a/src/duckdb_py/pyconnection.cpp +++ b/src/duckdb_py/pyconnection.cpp @@ -1,5 +1,7 @@ #include "duckdb_python/pyconnection/pyconnection.hpp" +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/default/default_types.hpp" #include "duckdb/common/arrow/arrow.hpp" #include "duckdb/common/enums/profiler_format.hpp" #include "duckdb/common/types.hpp" @@ -52,6 +54,17 @@ shared_ptr DuckDBPyConnection::import_cache = nullptr; PythonEnvironmentType DuckDBPyConnection::environment = PythonEnvironmentType::NORMAL; // NOLINT: allow global std::string DuckDBPyConnection::formatted_python_version = ""; +static shared_ptr GetPythonRegisteredObjectState(ClientContext &context) { + return context.registered_state->GetOrCreate(PythonRegisteredObjectState::Key); +} + +static bool TemporaryObjectExists(ClientContext &context, const string &name) { + auto &catalog = Catalog::GetCatalog(context, TEMP_CATALOG); + EntryLookupInfo lookup_info(CatalogType::TABLE_ENTRY, name); + auto entry = catalog.GetEntry(context, DEFAULT_SCHEMA, lookup_info, OnEntryNotFound::RETURN_NULL); + return entry != nullptr; +} + DuckDBPyConnection::~DuckDBPyConnection() { try { py::gil_scoped_release gil; @@ -743,11 +756,16 @@ shared_ptr DuckDBPyConnection::RegisterPythonObject(const st const py::object &python_object) { auto &connection = con.GetConnection(); auto &client = *connection.context; - auto object = PythonReplacementScan::ReplacementObject(python_object, name, client); - auto view_rel = make_shared_ptr(connection.context, std::move(object), name); - bool replace = registered_objects.count(name); - view_rel->CreateView(name, replace, true); - registered_objects.insert(name); + auto registered_state = GetPythonRegisteredObjectState(client); + if (!registered_state->Contains(name)) { + bool temp_object_exists = false; + client.RunFunctionInTransaction([&]() { temp_object_exists = TemporaryObjectExists(client, name); }, false); + if (temp_object_exists) { + throw CatalogException("View with name \"%s\" already exists!", name); + } + } + PythonReplacementScan::ReplacementObject(python_object, name, client); + registered_state->Register(name, python_object); return shared_from_this(); } @@ -1821,15 +1839,12 @@ unordered_set DuckDBPyConnection::GetTableNames(const string &query, boo shared_ptr DuckDBPyConnection::UnregisterPythonObject(const string &name) { auto &connection = con.GetConnection(); - if (!registered_objects.count(name)) { + auto registered_state = GetPythonRegisteredObjectState(*connection.context); + if (!registered_state->Contains(name)) { return shared_from_this(); } D_ASSERT(py::gil_check()); - py::gil_scoped_release release; - // FIXME: DROP TEMPORARY VIEW? doesn't exist? - const auto quoted_name = SQLQuotedIdentifier::ToString(name); - connection.Query("DROP VIEW " + quoted_name + ""); - registered_objects.erase(name); + registered_state->Unregister(name); return shared_from_this(); } diff --git a/src/duckdb_py/python_replacement_scan.cpp b/src/duckdb_py/python_replacement_scan.cpp index 8bff9e8f..de8cb06d 100644 --- a/src/duckdb_py/python_replacement_scan.cpp +++ b/src/duckdb_py/python_replacement_scan.cpp @@ -16,6 +16,34 @@ namespace duckdb { +void PythonRegisteredObjectState::Register(const string &name, const py::object &object) { + py::gil_scoped_acquire gil; + lock_guard guard(lock); + registered_objects[name] = PythonDependencyItem::Create(object); +} + +void PythonRegisteredObjectState::Unregister(const string &name) { + py::gil_scoped_acquire gil; + lock_guard guard(lock); + registered_objects.erase(name); +} + +py::object PythonRegisteredObjectState::Get(const string &name) { + py::gil_scoped_acquire gil; + lock_guard guard(lock); + auto entry = registered_objects.find(name); + if (entry == registered_objects.end()) { + return py::none(); + } + auto &dependency = entry->second->Cast(); + return dependency.object->obj; +} + +bool PythonRegisteredObjectState::Contains(const string &name) { + lock_guard guard(lock); + return registered_objects.find(name) != registered_objects.end(); +} + static void CreateArrowScan(const string &name, py::object entry, TableFunctionRef &table_function, vector> &children, ClientProperties &client_properties, PyArrowObjectType type, DatabaseInstance &db) { @@ -238,6 +266,16 @@ static unique_ptr ReplaceInternal(ClientContext &context, const string return nullptr; } + auto registered_objects = + context.registered_state->Get(PythonRegisteredObjectState::Key); + if (registered_objects) { + py::gil_scoped_acquire acquire; + auto entry = registered_objects->Get(table_name); + if (!entry.is_none()) { + return PythonReplacementScan::TryReplacementObject(entry, table_name, context); + } + } + lookup_result = context.TryGetCurrentSetting("python_scan_all_frames", result); D_ASSERT((bool)lookup_result); auto scan_all_frames = result.GetValue(); diff --git a/tests/fast/pandas/test_pandas_unregister.py b/tests/fast/pandas/test_pandas_unregister.py index c89ae320..41243a5c 100644 --- a/tests/fast/pandas/test_pandas_unregister.py +++ b/tests/fast/pandas/test_pandas_unregister.py @@ -1,5 +1,6 @@ import gc import tempfile +import weakref import pandas as pd import pytest @@ -50,3 +51,20 @@ def test_pandas_unregister2(self, duckdb_cursor): with pytest.raises(duckdb.CatalogException, match="Table with name dataframe does not exist"): connection.execute("SELECT * FROM dataframe;").fetchdf() connection.close() + + def test_pandas_unregister_releases_object_inside_transaction(self, duckdb_cursor): + duckdb_cursor.execute("CREATE TABLE t(i BIGINT)") + duckdb_cursor.begin() + + df = pd.DataFrame({"i": [1, 2, 3]}) + ref = weakref.ref(df) + + duckdb_cursor.register("dataframe", df) + duckdb_cursor.execute("INSERT INTO t SELECT * FROM dataframe") + duckdb_cursor.unregister("dataframe") + + del df + gc.collect() + + assert ref() is None + duckdb_cursor.rollback()