""" Integration test — Hypothesis #2: oracle pool concurrency — no state leak between connections. Verifies that: 1. Two concurrent asyncio tasks can acquire connections from the same pool simultaneously 2. Each connection operates independently (no shared cursor state, no cross-connection leakage) 3. session_callback (if set) runs per-connection, not per-pool Uses the 'mariusm_test' pool (ROA_WEB user). Requires live Oracle connection. Run: cd /workspace/roa2web python -m pytest backend/modules/service_auto/tests/test_pool_concurrency.py -v -m integration """ import asyncio import os import sys import time import pytest sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..', '..')) from shared.database.oracle_pool import OracleMultiPool SECRETS_FILE = os.path.join( os.path.dirname(__file__), '..', '..', '..', 'secrets', 'mariusm_test.oracle_pass' ) HOST = "10.0.20.121" PORT = 1521 SERVICE_NAME = "ROA" USER = "ROA_WEB" @pytest.fixture(scope="module") def pool(): """Dedicated OracleMultiPool instance (not the global one) for isolation.""" with open(SECRETS_FILE) as f: pwd = f.read().strip() p = OracleMultiPool.__new__(OracleMultiPool) p._pools = {} p._pool_configs = {} p._initialized = False import asyncio as _asyncio p._pool_lock = _asyncio.Lock() p.register_server( server_id="concurrency_test", host=HOST, port=PORT, user=USER, password=pwd, service_name=SERVICE_NAME, min_connections=2, max_connections=5, ) yield p # Cleanup import asyncio as _asyncio _asyncio.get_event_loop().run_until_complete(p.close_pool()) @pytest.mark.integration @pytest.mark.asyncio async def test_two_concurrent_connections_return_correct_results(pool): """ Two asyncio tasks run simultaneously on the same pool. Each uses a different bind value — results must not cross. """ results = {} async def query_task(task_id: int, expected_val: int): async with pool.get_connection("concurrency_test") as conn: with conn.cursor() as cur: # Short sleep to maximise overlap window await asyncio.sleep(0.01) cur.execute("SELECT :v FROM DUAL", {"v": expected_val}) row = cur.fetchone() results[task_id] = row[0] await asyncio.gather( query_task(1, 111), query_task(2, 222), ) assert results[1] == 111, f"Task 1 expected 111, got {results[1]}" assert results[2] == 222, f"Task 2 expected 222, got {results[2]}" @pytest.mark.integration @pytest.mark.asyncio async def test_session_callback_runs_per_connection(pool): """ Register a session_callback that writes to a list. Verify it fires each time a new connection is acquired. session_callback must NOT bleed state across connections. """ callback_log = [] def schema_callback(connection, requested_tag): """Simulates ALTER SESSION SET CURRENT_SCHEMA; logs invocation.""" callback_log.append(id(connection)) # Register a second server config with session_callback with open(SECRETS_FILE) as f: pwd = f.read().strip() pool.register_server( server_id="cb_test", host=HOST, port=PORT, user=USER, password=pwd, service_name=SERVICE_NAME, min_connections=1, max_connections=3, session_callback=schema_callback, ) # Acquire two connections sequentially (pool min=1 so first reuses, second may create) conn_ids = [] async with pool.get_connection("cb_test") as conn: conn_ids.append(id(conn)) async with pool.get_connection("cb_test") as conn: conn_ids.append(id(conn)) # Callback must have fired at least once (at pool creation / first acquire) assert len(callback_log) >= 1, ( "session_callback never called — pool did not invoke it on connection creation" ) await pool.close_pool("cb_test") @pytest.mark.integration @pytest.mark.asyncio async def test_ten_concurrent_queries_no_errors(pool): """ Stress: 10 concurrent queries on pool with max_connections=5. All must complete without errors (pool queues excess requests via POOL_GETMODE_WAIT). """ errors = [] async def single_query(i: int): try: async with pool.get_connection("concurrency_test") as conn: with conn.cursor() as cur: cur.execute("SELECT :i FROM DUAL", {"i": i}) val = cur.fetchone()[0] assert val == i, f"Expected {i}, got {val}" except Exception as e: errors.append(f"task {i}: {e}") t0 = time.perf_counter() await asyncio.gather(*[single_query(i) for i in range(10)]) elapsed = time.perf_counter() - t0 assert not errors, f"Errors in concurrent queries: {errors}" print(f"\n[CONCURRENCY] 10 queries completed in {elapsed*1000:.0f}ms, no errors ✅")