photolib/photoapp/dbutils.py

137 lines
4.8 KiB
Python

from contextlib import closing
import functools

import cherrypy
from cherrypy.process import plugins
import sqlalchemy
from sqlalchemy import func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.session import Session
from sqlalchemy.pool import NullPool

# Declarative base for the application's ORM models; get_db_engine() and
# SAEnginePlugin.start() create the tables registered on its metadata.
Base = declarative_base()

# Extra keyword arguments for sqlalchemy.create_engine(), keyed by a prefix of the
# database URI (looked up by get_engine_options()).
engine_specific_options = {
    "sqlite": {
        # Allow the sqlite connection to be used from threads other than the one
        # that created it (presumably CherryPy worker threads).
        "connect_args": {"check_same_thread": False},
        # Open a brand-new connection per checkout instead of pooling connections.
        "poolclass": NullPool,
        # NOTE(review): with NullPool connections are never reused, so pre-ping is
        # probably redundant here.
        "pool_pre_ping": True,
    },
    # Test pooled connections for liveness before handing them out.
    "mysql": {"pool_pre_ping": True},
}
def get_engine_options(uri, option_map=None):
    """Return the extra ``create_engine`` keyword arguments for a database URI.

    :param uri: SQLAlchemy database URI, e.g. ``sqlite:///photos.db``.
    :param option_map: optional mapping of URI prefix -> options dict to consult instead of
        the module-level ``engine_specific_options``.
    :return: a new dict holding the options of the first prefix that ``uri`` starts with, or
        an empty dict when no prefix matches. A copy is returned so callers can add or change
        keys without altering the shared defaults.
    """
    if option_map is None:
        option_map = engine_specific_options
    for engine_prefix, options in option_map.items():
        if uri.startswith(engine_prefix):
            return dict(options)
    return {}
def get_db_engine(uri, debug=False):
    """Create a SQLAlchemy engine for ``uri`` and create any missing tables declared on ``Base``.

    :param uri: database URI; driver-specific options come from ``get_engine_options``.
    :param debug: passed to ``create_engine`` as ``echo`` (SQL statement logging).
    :return: the configured engine.
    """
    extra_options = get_engine_options(uri)
    new_engine = sqlalchemy.create_engine(uri, echo=debug, **extra_options)
    Base.metadata.create_all(new_engine)
    return new_engine
def create_db_sessionmaker(engine):
    """Build a ``scoped_session`` registry whose sessions are bound to ``engine``.

    Sessions produced by the registry autoflush and do not autocommit.
    """
    factory = sessionmaker(bind=engine, autoflush=True, autocommit=False)
    return sqlalchemy.orm.scoped_session(factory)
def driver_statement(statements, driver=None):
    """
    Select a value from the passed dict based on the sql driver in use.

    Sqlite and mysql use different date functions, so this can be used to build queries
    that work on either. For example::

        date_format = driver_statement({"sqlite": lambda date_format, value: func.strftime(date_format, value),
                                        "mysql": lambda date_format, value: func.date_format(value, date_format)})
        rows = db.query(PhotoSet.id, date_format('%Y-%m-%d', PhotoSet.date).label('year')).all()

    :param statements: dict of driver_type->value. Since sqlalchemy driver names vary (e.g. pymysql,
        pysqlite), a key matches when it is a substring of sqlalchemy's driver name. Keys are tried
        in dict order and the first match wins.
    :type statements: dict
    :param driver: sqlalchemy driver name to match against. When omitted it is read from the engine
        behind ``cherrypy.request.db``, which requires a request context.
    :type driver: str or None
    :return: the value stored under the first matching key.
    :raises NotImplementedError: if no key matches the driver. It subclasses ``Exception``, so
        existing ``except Exception`` handlers keep working.
    """
    if driver is None:
        driver = cherrypy.request.db.connection().engine.driver
    for key, value in statements.items():
        if key in driver:
            return value
    raise NotImplementedError(f"Statement not supported for driver {driver}")
def date_format(date_format, value):
    """Build a SQL expression that formats a date using the active driver's date function.

    Must be called in request context (the driver is looked up via ``driver_statement``).

    :param date_format: format string handed to the database's date function, e.g. ``'%Y-%m-%d'``.
    :param value: column or SQL expression holding the date.
    :return: ``strftime(date_format, value)`` on sqlite, ``date_format(value, date_format)`` on mysql.
    """
    def _sqlite_format(fmt, expr):
        return func.strftime(fmt, expr)

    def _mysql_format(fmt, expr):
        # mysql takes the arguments in the opposite order to sqlite
        return func.date_format(expr, fmt)

    formatter = driver_statement({"sqlite": _sqlite_format, "mysql": _mysql_format})
    return formatter(date_format, value)
class DbAlias(object):
    """
    Short alias for ``cherrypy.request.db``, the database session bound to the current request.

    The ``db`` attribute only exists once a request has been received, so it cannot simply be
    referenced from another module-level variable. An instance of this class acts as an object
    proxy instead: each attribute lookup it does not define itself is forwarded to
    ``cherrypy.request.db`` at the moment of access.
    """

    def __getattr__(self, attr):
        current_session = cherrypy.request.db
        return getattr(current_session, attr)


# Module-level proxy: ``db.query(...)`` behaves like ``cherrypy.request.db.query(...)``.
db = DbAlias()
class SAEnginePlugin(plugins.SimplePlugin):
    """CherryPy bus plugin holding the SQLAlchemy engine.

    On ``start`` it creates the tables declared on ``Base``. It also listens on the ``"bind"``
    bus channel and binds each published session to its engine.
    """

    def __init__(self, bus, dbcon):
        super().__init__(bus)
        # ``dbcon`` is the engine used for table creation and for binding sessions.
        self.sa_engine = dbcon
        self.bus.subscribe("bind", self.bind)

    def start(self):
        """Create the tables registered on ``Base.metadata``."""
        Base.metadata.create_all(self.sa_engine)

    def bind(self, session):
        """Bind ``session`` (as published on the ``"bind"`` channel) to this plugin's engine."""
        session.configure(bind=self.sa_engine)
class SATool(cherrypy.Tool):
    """CherryPy tool that provides a SQLAlchemy session per request.

    Before the request body is processed, the scoped session is published on the ``"bind"`` bus
    channel and exposed as ``cherrypy.request.db``. At ``on_end_resource`` the session is
    committed. If the commit raises, it is rolled back and the error re-raised. In every case
    the scoped session is removed afterwards.
    """

    def __init__(self):
        # priority 49: slightly earlier than Sessions tool, which is 50 or 60
        super().__init__('before_request_body', self.bind_session, priority=49)
        self.session = sqlalchemy.orm.scoped_session(sessionmaker(autoflush=True, autocommit=False))

    def _setup(self):
        # Besides the regular tool hook, register the end-of-resource commit hook.
        super()._setup()
        cherrypy.request.hooks.attach('on_end_resource', self.commit_transaction, priority=80)

    def bind_session(self):
        # Let the engine plugin bind the session, then expose it on the request.
        cherrypy.engine.publish('bind', self.session)
        cherrypy.request.db = self.session

    def commit_transaction(self):
        # NOTE(review): on_end_resource hooks presumably also run when the handler raised, in
        # which case work done before the error is committed too -- confirm this is intended.
        cherrypy.request.db = None
        scoped = self.session
        try:
            scoped.commit()
        except Exception:
            scoped.rollback()
            raise
        finally:
            scoped.remove()
def cursorwrap(func):
    """
    Decorator that passes a database session to the wrapped method as its first argument after ``self``.

    The wrapped function must be a method. The session is obtained by calling ``self.session()``, so the
    owning object's ``session`` attribute is assumed to be a sessionmaker-style callable. That session is
    closed via ``contextlib.closing`` when the method returns or raises. It is not committed here, so the
    method must commit any changes it wants to keep.

    If the caller already passes a session (a ``Session``, ``scoped_session`` or ``DbAlias``) as the first
    argument after ``self``, the call is forwarded unchanged and no new session is opened.

    The wrapper keeps the wrapped method's name, docstring and ``__wrapped__`` via ``functools.wraps``.
    """
    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        self = args[0]
        # passthru if someone already passed a session
        if len(args) >= 2 and isinstance(args[1], (Session, sqlalchemy.orm.scoping.scoped_session, DbAlias)):
            return func(*args, **kwargs)
        with closing(self.session()) as c:
            return func(self, c, *args[1:], **kwargs)
    return wrapped