Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions tests/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,22 @@ def _test_isolation_level_read_only(

maybe_await(cursor.execute_scheme("DROP TABLE foo"))

def _test_commit_rollback_after_begin(
self,
connection: dbapi.Connection,
isolation_level: str,
) -> None:
connection.set_isolation_level(isolation_level)

for _ in range(10):
maybe_await(connection.begin())
maybe_await(connection.commit())

for _ in range(10):
maybe_await(connection.begin())
maybe_await(connection.rollback())


def _test_connection(self, connection: dbapi.Connection) -> None:
maybe_await(connection.commit())
maybe_await(connection.rollback())
Expand Down Expand Up @@ -377,6 +393,26 @@ def test_isolation_level_read_only(
connection, isolation_level, read_only
)

@pytest.mark.parametrize(
("isolation_level"),
[
(dbapi.IsolationLevel.SERIALIZABLE),
(dbapi.IsolationLevel.AUTOCOMMIT),
(dbapi.IsolationLevel.ONLINE_READONLY),
(dbapi.IsolationLevel.ONLINE_READONLY_INCONSISTENT),
(dbapi.IsolationLevel.STALE_READONLY),
(dbapi.IsolationLevel.SNAPSHOT_READONLY),
],
)
def test_commit_rollback_after_begin(
self,
isolation_level: str,
connection: dbapi.Connection,
) -> None:
self._test_commit_rollback_after_begin(
connection, isolation_level
)

def test_connection(self, connection: dbapi.Connection) -> None:
self._test_connection(connection)

Expand Down Expand Up @@ -448,6 +484,29 @@ async def test_isolation_level_read_only(
read_only,
)

@pytest.mark.asyncio
@pytest.mark.parametrize(
("isolation_level"),
[
(dbapi.IsolationLevel.SERIALIZABLE),
(dbapi.IsolationLevel.AUTOCOMMIT),
(dbapi.IsolationLevel.ONLINE_READONLY),
(dbapi.IsolationLevel.ONLINE_READONLY_INCONSISTENT),
(dbapi.IsolationLevel.STALE_READONLY),
(dbapi.IsolationLevel.SNAPSHOT_READONLY),
],
)
async def test_commit_rollback_after_begin(
self,
isolation_level: str,
connection: dbapi.AsyncConnection,
) -> None:
await greenlet_spawn(
self._test_commit_rollback_after_begin,
connection,
isolation_level
)

@pytest.mark.asyncio
async def test_connection(self, connection: dbapi.AsyncConnection) -> None:
await greenlet_spawn(self._test_connection, connection)
Expand Down
20 changes: 12 additions & 8 deletions ydb_dbapi/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,20 +234,22 @@ def begin(self) -> None:

@handle_ydb_errors
def commit(self) -> None:
if self._tx_context and self._tx_context.tx_id:
if self._tx_context:
settings = self._get_request_settings()
self._tx_context.commit(settings=settings)
self._session_pool.release(self._session)
self._tx_context = None
if self._session:
self._session_pool.release(self._session)
self._session = None

@handle_ydb_errors
def rollback(self) -> None:
if self._tx_context and self._tx_context.tx_id:
if self._tx_context:
settings = self._get_request_settings()
self._tx_context.rollback(settings=settings)
self._session_pool.release(self._session)
self._tx_context = None
if self._session:
self._session_pool.release(self._session)
self._session = None

@handle_ydb_errors
Expand Down Expand Up @@ -424,21 +426,23 @@ async def begin(self) -> None:

@handle_ydb_errors
async def commit(self) -> None:
if self._session and self._tx_context and self._tx_context.tx_id:
if self._tx_context:
settings = self._get_request_settings()
await self._tx_context.commit(settings=settings)
self._tx_context = None
if self._session:
await self._session_pool.release(self._session)
self._session = None
self._tx_context = None

@handle_ydb_errors
async def rollback(self) -> None:
if self._session and self._tx_context and self._tx_context.tx_id:
if self._tx_context:
settings = self._get_request_settings()
await self._tx_context.rollback(settings=settings)
self._tx_context = None
if self._session:
await self._session_pool.release(self._session)
self._session = None
self._tx_context = None

@handle_ydb_errors
async def close(self) -> None:
Expand Down