forked from basicmachines-co/basic-memory
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdb.py
More file actions
211 lines (167 loc) · 7.25 KB
/
db.py
File metadata and controls
211 lines (167 loc) · 7.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import asyncio
from contextlib import asynccontextmanager
from enum import Enum, auto
from pathlib import Path
from typing import AsyncGenerator, Optional
from basic_memory.config import BasicMemoryConfig, ConfigManager
from alembic import command
from alembic.config import Config
from loguru import logger
from sqlalchemy import text
from sqlalchemy.ext.asyncio import (
create_async_engine,
async_sessionmaker,
AsyncSession,
AsyncEngine,
async_scoped_session,
)
from basic_memory.repository.search_repository import SearchRepository
# Module level state: one engine/session-maker pair is shared per process.
# Created lazily by get_or_create_db() and torn down by shutdown_db().
_engine: Optional[AsyncEngine] = None
_session_maker: Optional[async_sessionmaker[AsyncSession]] = None
# Set by run_migrations() after a successful upgrade so repeated calls are
# no-ops within one process; reset by shutdown_db()/engine_session_factory().
_migrations_completed: bool = False
class DatabaseType(Enum):
    """Kinds of SQLite databases this module can connect to."""

    MEMORY = auto()
    FILESYSTEM = auto()

    @classmethod
    def get_db_url(cls, db_path: Path, db_type: "DatabaseType") -> str:
        """Get SQLAlchemy URL for database path."""
        if db_type is not cls.MEMORY:
            # File-backed database at the given path.
            return f"sqlite+aiosqlite:///{db_path}"  # pragma: no cover
        logger.info("Using in-memory SQLite database")
        return "sqlite+aiosqlite://"
def get_scoped_session_factory(
    session_maker: async_sessionmaker[AsyncSession],
) -> async_scoped_session:
    """Return a session registry scoped to the current asyncio task.

    Sessions obtained from the returned factory are keyed by
    asyncio.current_task, so concurrent tasks never share a session.
    """
    scope = asyncio.current_task
    return async_scoped_session(session_maker, scopefunc=scope)
@asynccontextmanager
async def scoped_session(
    session_maker: async_sessionmaker[AsyncSession],
) -> AsyncGenerator[AsyncSession, None]:
    """Yield a task-scoped session with managed lifecycle.

    Commits on clean exit, rolls back and re-raises on error, and always
    closes the session and removes it from the task-scoped registry.

    Args:
        session_maker: Session maker to create scoped sessions from
    """
    registry = get_scoped_session_factory(session_maker)
    db_session = registry()
    try:
        # SQLite only enforces foreign keys when told to, per connection.
        await db_session.execute(text("PRAGMA foreign_keys=ON"))
        yield db_session
        await db_session.commit()
    except Exception:
        await db_session.rollback()
        raise
    finally:
        await db_session.close()
        await registry.remove()
def _create_engine_and_session(
    db_path: Path, db_type: DatabaseType = DatabaseType.FILESYSTEM
) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]:
    """Build a fresh engine plus a session maker bound to it.

    Does not touch module-level state; callers decide whether to cache
    the returned pair.
    """
    db_url = DatabaseType.get_db_url(db_path, db_type)
    logger.debug(f"Creating engine for db_url: {db_url}")
    # check_same_thread=False allows the sqlite connection to be used
    # from a different thread than the one that created it.
    created_engine = create_async_engine(db_url, connect_args={"check_same_thread": False})
    created_maker = async_sessionmaker(created_engine, expire_on_commit=False)
    return created_engine, created_maker
async def get_or_create_db(
    db_path: Path,
    db_type: DatabaseType = DatabaseType.FILESYSTEM,
    ensure_migrations: bool = True,
) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]:  # pragma: no cover
    """Get or create database engine and session maker.

    Lazily initializes the shared module-level engine/session-maker pair on
    first call; subsequent calls return the cached pair.

    Args:
        db_path: Filesystem path of the SQLite database file.
        db_type: Whether to use a file-backed or in-memory database.
        ensure_migrations: When True, run alembic migrations after creating
            the engine for the first time.

    Raises:
        RuntimeError: If engine or session maker creation failed.
    """
    global _engine, _session_maker

    if _engine is None:
        _engine, _session_maker = _create_engine_and_session(db_path, db_type)

        # Run migrations automatically unless explicitly disabled
        # NOTE(review): reconstructed nesting — migrations appear to run only on
        # first engine creation; confirm against upstream history.
        if ensure_migrations:
            app_config = ConfigManager().config
            await run_migrations(app_config, db_type)

    # These checks should never fail since we just created the engine and session maker
    # if they were None, but we'll check anyway for the type checker
    if _engine is None:
        logger.error("Failed to create database engine", db_path=str(db_path))
        raise RuntimeError("Database engine initialization failed")

    if _session_maker is None:
        logger.error("Failed to create session maker", db_path=str(db_path))
        raise RuntimeError("Session maker initialization failed")

    return _engine, _session_maker
async def shutdown_db() -> None:  # pragma: no cover
    """Dispose of the shared engine and reset all module-level state."""
    global _engine, _session_maker, _migrations_completed

    if _engine is None:
        return
    await _engine.dispose()
    _engine = None
    _session_maker = None
    _migrations_completed = False
@asynccontextmanager
async def engine_session_factory(
    db_path: Path,
    db_type: DatabaseType = DatabaseType.MEMORY,
) -> AsyncGenerator[tuple[AsyncEngine, async_sessionmaker[AsyncSession]], None]:
    """Create engine and session factory.

    Note: This is primarily used for testing where we want a fresh database
    for each test. For production use, use get_or_create_db() instead.

    Unlike get_or_create_db(), this overwrites the module-level engine and
    session maker for the duration of the context and clears them (plus the
    migrations flag) on exit, disposing of the engine.
    """
    global _engine, _session_maker, _migrations_completed

    db_url = DatabaseType.get_db_url(db_path, db_type)
    logger.debug(f"Creating engine for db_url: {db_url}")
    _engine = create_async_engine(db_url, connect_args={"check_same_thread": False})
    try:
        _session_maker = async_sessionmaker(_engine, expire_on_commit=False)

        # Verify that engine and session maker are initialized
        if _engine is None:  # pragma: no cover
            logger.error("Database engine is None in engine_session_factory")
            raise RuntimeError("Database engine initialization failed")

        if _session_maker is None:  # pragma: no cover
            logger.error("Session maker is None in engine_session_factory")
            raise RuntimeError("Session maker initialization failed")

        yield _engine, _session_maker
    finally:
        # Always tear down so the next test starts from a clean slate.
        if _engine:
            await _engine.dispose()
            _engine = None
            _session_maker = None
            _migrations_completed = False
async def run_migrations(
    app_config: BasicMemoryConfig,
    database_type: DatabaseType = DatabaseType.FILESYSTEM,
    force: bool = False,
) -> None:  # pragma: no cover
    """Run any pending alembic migrations.

    Configures alembic entirely in code (no alembic.ini required), upgrades
    the schema to "head", then initializes the search index schema. Sets the
    module-level completion flag so repeated calls are no-ops unless *force*
    is True.

    Args:
        app_config: Config providing the database path.
        database_type: Which database URL form to hand to alembic.
        force: Re-run migrations even if already completed this session.

    Raises:
        Exception: Re-raises whatever alembic or the search-index init raises.
    """
    global _migrations_completed

    # Skip if migrations already completed unless forced
    if _migrations_completed and not force:
        logger.debug("Migrations already completed in this session, skipping")
        return

    logger.info("Running database migrations...")
    try:
        # Get the absolute path to the alembic directory relative to this file
        alembic_dir = Path(__file__).parent / "alembic"
        config = Config()

        # Set required Alembic config options programmatically.
        # "%%" escapes "%" for configparser-style interpolation in the template.
        config.set_main_option("script_location", str(alembic_dir))
        config.set_main_option(
            "file_template",
            "%%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s",
        )
        config.set_main_option("timezone", "UTC")
        config.set_main_option("revision_environment", "false")
        config.set_main_option(
            "sqlalchemy.url", DatabaseType.get_db_url(app_config.database_path, database_type)
        )

        command.upgrade(config, "head")
        logger.info("Migrations completed successfully")

        # Get session maker - ensure we don't trigger recursive migration calls
        if _session_maker is None:
            _, session_maker = _create_engine_and_session(app_config.database_path, database_type)
        else:
            session_maker = _session_maker

        # initialize the search Index schema
        # the project_id is not used for init_search_index, so we pass a dummy value
        await SearchRepository(session_maker, 1).init_search_index()

        # Mark migrations as completed
        _migrations_completed = True
    except Exception as e:  # pragma: no cover
        logger.error(f"Error running migrations: {e}")
        raise