diff --git a/tests/test_connections.py b/tests/test_connections.py index 75bb3af..f6516fb 100644 --- a/tests/test_connections.py +++ b/tests/test_connections.py @@ -50,6 +50,22 @@ def _test_isolation_level_read_only( maybe_await(cursor.execute_scheme("DROP TABLE foo")) + def _test_commit_rollback_after_begin( + self, + connection: dbapi.Connection, + isolation_level: str, + ) -> None: + connection.set_isolation_level(isolation_level) + + for _ in range(10): + maybe_await(connection.begin()) + maybe_await(connection.commit()) + + for _ in range(10): + maybe_await(connection.begin()) + maybe_await(connection.rollback()) + + def _test_connection(self, connection: dbapi.Connection) -> None: maybe_await(connection.commit()) maybe_await(connection.rollback()) @@ -377,6 +393,26 @@ def test_isolation_level_read_only( connection, isolation_level, read_only ) + @pytest.mark.parametrize( + ("isolation_level"), + [ + (dbapi.IsolationLevel.SERIALIZABLE), + (dbapi.IsolationLevel.AUTOCOMMIT), + (dbapi.IsolationLevel.ONLINE_READONLY), + (dbapi.IsolationLevel.ONLINE_READONLY_INCONSISTENT), + (dbapi.IsolationLevel.STALE_READONLY), + (dbapi.IsolationLevel.SNAPSHOT_READONLY), + ], + ) + def test_commit_rollback_after_begin( + self, + isolation_level: str, + connection: dbapi.Connection, + ) -> None: + self._test_commit_rollback_after_begin( + connection, isolation_level + ) + def test_connection(self, connection: dbapi.Connection) -> None: self._test_connection(connection) @@ -448,6 +484,29 @@ async def test_isolation_level_read_only( read_only, ) + @pytest.mark.asyncio + @pytest.mark.parametrize( + ("isolation_level"), + [ + (dbapi.IsolationLevel.SERIALIZABLE), + (dbapi.IsolationLevel.AUTOCOMMIT), + (dbapi.IsolationLevel.ONLINE_READONLY), + (dbapi.IsolationLevel.ONLINE_READONLY_INCONSISTENT), + (dbapi.IsolationLevel.STALE_READONLY), + (dbapi.IsolationLevel.SNAPSHOT_READONLY), + ], + ) + async def test_commit_rollback_after_begin( + self, + isolation_level: str, + connection: dbapi.AsyncConnection, + ) -> None: + await greenlet_spawn( + self._test_commit_rollback_after_begin, + connection, + isolation_level + ) + @pytest.mark.asyncio async def test_connection(self, connection: dbapi.AsyncConnection) -> None: await greenlet_spawn(self._test_connection, connection) diff --git a/ydb_dbapi/connections.py b/ydb_dbapi/connections.py index 345af88..61c692f 100644 --- a/ydb_dbapi/connections.py +++ b/ydb_dbapi/connections.py @@ -234,20 +234,22 @@ def begin(self) -> None: @handle_ydb_errors def commit(self) -> None: - if self._tx_context and self._tx_context.tx_id: + if self._tx_context: settings = self._get_request_settings() self._tx_context.commit(settings=settings) - self._session_pool.release(self._session) self._tx_context = None + if self._session: + self._session_pool.release(self._session) self._session = None @handle_ydb_errors def rollback(self) -> None: - if self._tx_context and self._tx_context.tx_id: + if self._tx_context: settings = self._get_request_settings() self._tx_context.rollback(settings=settings) - self._session_pool.release(self._session) self._tx_context = None + if self._session: + self._session_pool.release(self._session) self._session = None @handle_ydb_errors @@ -424,21 +426,23 @@ async def begin(self) -> None: @handle_ydb_errors async def commit(self) -> None: - if self._session and self._tx_context and self._tx_context.tx_id: + if self._tx_context: settings = self._get_request_settings() await self._tx_context.commit(settings=settings) + self._tx_context = None + if self._session: await self._session_pool.release(self._session) self._session = None - self._tx_context = None @handle_ydb_errors async def rollback(self) -> None: - if self._session and self._tx_context and self._tx_context.tx_id: + if self._tx_context: settings = self._get_request_settings() await self._tx_context.rollback(settings=settings) + self._tx_context = None + if self._session: await self._session_pool.release(self._session) self._session = None - self._tx_context = None @handle_ydb_errors async def close(self) -> None: