138 lines
4.8 KiB
Python
138 lines
4.8 KiB
Python
from contextlib import closing
from functools import wraps

import cherrypy
import sqlalchemy
from cherrypy.process import plugins
from sqlalchemy import func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.session import Session
from sqlalchemy.pool import NullPool
|
|
|
|
|
|
# Declarative base class returned by declarative_base(); the tables registered on
# Base.metadata are created by get_db_engine() and SAEnginePlugin.start().
# NOTE(review): sqlalchemy.ext.declarative.declarative_base is reportedly deprecated
# since SQLAlchemy 1.4 in favour of sqlalchemy.orm.declarative_base — confirm against
# the pinned SQLAlchemy version.
Base = declarative_base()
|
|
|
|
|
|
# Extra create_engine() keyword arguments per database backend, keyed by the prefix
# a database URI starts with. get_engine_options() tries the keys in insertion
# order and uses the first match, so keep the order stable.
engine_specific_options = {
    "sqlite": {
        # Allow the sqlite3 connection to be used from threads other than its creator.
        "connect_args": {"check_same_thread": False},
        # No connection pooling for SQLite.
        "poolclass": NullPool,
        "pool_pre_ping": True,
    },
    "mysql": {
        # Test pooled connections for liveness before handing them out.
        "pool_pre_ping": True,
    },
}
|
|
|
|
|
|
def get_engine_options(uri):
    """Return the extra ``create_engine`` keyword arguments for *uri*.

    The keys of ``engine_specific_options`` are tried in insertion order and the
    first one that *uri* starts with wins.

    :param uri: database URI, e.g. ``sqlite:///app.db`` or ``mysql+pymysql://...``.
    :returns: a new dict of keyword arguments, or an empty dict when no prefix
        matches. The top-level dict and any nested dict values (such as
        ``connect_args``) are copied, so callers may modify the result without
        altering the module-level defaults used for every other engine.
    """
    for engine_prefix, options in engine_specific_options.items():
        if uri.startswith(engine_prefix):
            # Copy rather than hand out the shared module-level dict; nested dicts
            # are copied one level deep too. A full deepcopy is avoided because
            # connect_args values (e.g. SSL contexts) are not always copyable.
            return {name: dict(setting) if isinstance(setting, dict) else setting
                    for name, setting in options.items()}
    return {}
|
|
|
|
|
|
def get_db_engine(uri, debug=False):
    """Create a SQLAlchemy engine for *uri* and create any missing tables.

    Backend-specific keyword arguments come from get_engine_options(); every
    table registered on ``Base.metadata`` is then created on the new engine.

    :param uri: database URI.
    :param debug: passed to ``create_engine`` as ``echo`` (log emitted SQL).
    :returns: the created engine.
    """
    backend_options = get_engine_options(uri)
    new_engine = sqlalchemy.create_engine(uri, echo=debug, **backend_options)
    Base.metadata.create_all(new_engine)
    return new_engine
|
|
|
|
|
|
def create_db_sessionmaker(engine):
    """Return a ``sessionmaker`` factory whose sessions are bound to *engine*.

    :param engine: the SQLAlchemy engine new sessions should use.
    :returns: a configured ``sessionmaker`` (call it to obtain a session).
    """
    factory = sessionmaker(bind=engine)
    return factory
|
|
|
|
|
|
def driver_statement(statements):
    """
    Select a value from the passed dict based on the SQL driver in use. Must be used in request context.

    For example, SQLite and MySQL use different date functions; this function can be used to build
    queries supporting either::

        date_format = driver_statement({"sqlite": lambda date_format, value: func.strftime(date_format, value),
                                        "mysql": lambda date_format, value: func.date_format(value, date_format)})

        rows = db.query(PhotoSet.id, date_format('%Y', PhotoSet.date).label('year')).all()

    :param statements: dict of driver_type -> value. The driver name is read from the engine behind
        ``cherrypy.request.db`` and differs per DBAPI library (e.g. ``pymysql``, ``pysqlite``), so a key
        matches when it is a substring of that driver name. Keys are tried in dict order; the first match wins.
    :type statements: dict
    :returns: the value stored under the first matching key.
    :raises NotImplementedError: if no key matches the driver in use. This is a subclass of ``Exception``,
        so existing ``except Exception`` handlers still catch it.
    """
    driver = cherrypy.request.db.connection().engine.driver
    for key, value in statements.items():
        if key in driver:
            return value
    # Specific exception type instead of a bare Exception, so callers can catch it precisely.
    raise NotImplementedError(f"Statement not supported for driver {driver}")
|
|
|
|
|
|
def date_format(date_format, value):
    """Format the date expression *value* with *date_format* using the active driver's SQL function.

    SQLite is given ``strftime(format, value)`` and MySQL ``date_format(value, format)``;
    the variant is picked through driver_statement(), so this must be called in request context.

    :param date_format: format string, e.g. ``'%Y-%m-%d'``.
    :param value: SQL expression or column holding the date.
    :returns: the SQLAlchemy function expression built for the current driver.
    """
    def _sqlite(fmt, column):
        return func.strftime(fmt, column)

    def _mysql(fmt, column):
        # MySQL's DATE_FORMAT takes the value first and the format second.
        return func.date_format(column, fmt)

    formatter = driver_statement({"sqlite": _sqlite, "mysql": _mysql})
    return formatter(date_format, value)
|
|
|
|
|
|
class DbAlias:
    """
    Short alias for ``cherrypy.request.db``, the database session bound to the current request.

    The ``db`` attribute does not exist until a request is received, so it cannot simply be
    assigned to another variable at import time. An instance of this class instead acts as a
    proxy: every attribute not found on the instance itself is looked up on whatever
    ``cherrypy.request.db`` is at that moment.
    """

    def __getattr__(self, attr):
        # Only invoked for names missing from the proxy; resolve the target lazily.
        target = cherrypy.request.db
        return getattr(target, attr)
|
|
|
|
|
|
# Module-level proxy: attribute access such as `db.query(...)` is forwarded by
# DbAlias.__getattr__ to cherrypy.request.db at the time of the lookup.
db = DbAlias()
|
|
|
|
|
|
class SAEnginePlugin(plugins.SimplePlugin):
    """CherryPy bus plugin holding the application's SQLAlchemy engine.

    It creates the tables registered on ``Base.metadata`` in ``start`` and answers
    messages on the custom ``'bind'`` channel by binding the published session
    (or session factory) to its engine.
    """

    def __init__(self, bus, dbcon):
        """
        :param bus: the CherryPy bus to attach to (presumably ``cherrypy.engine``).
        :param dbcon: SQLAlchemy engine used for table creation and session binding.
        """
        super().__init__(bus)
        self.sa_engine = dbcon
        # 'bind' is published by SATool.bind_session at the start of each request.
        self.bus.subscribe("bind", self.bind)

    def start(self):
        """Handler for the bus ``start`` channel: create missing tables on the engine."""
        Base.metadata.create_all(self.sa_engine)

    def bind(self, session):
        """Handler for ``'bind'``: configure *session* to use this plugin's engine."""
        session.configure(bind=self.sa_engine)
|
|
|
|
|
|
class SATool(cherrypy.Tool):
    """CherryPy tool that gives every request a SQLAlchemy session at ``cherrypy.request.db``.

    Before the request body is processed, the session registry is bound to an engine
    through the ``'bind'`` bus channel and exposed on the request. When the resource
    finishes, the session is committed (rolled back and the error re-raised if the
    commit fails) and then removed from the registry.
    """

    def __init__(self):
        # Priority 49: slightly earlier than the Sessions tool, which is 50 or 60.
        super().__init__('before_request_body', self.bind_session, priority=49)

        # scoped_session registry used for every request handled by this tool.
        # NOTE(review): `autocommit` is no longer accepted by SQLAlchemy 2.0 sessions —
        # confirm the pinned SQLAlchemy version.
        self.session = sqlalchemy.orm.scoped_session(
            sqlalchemy.orm.sessionmaker(autoflush=True, autocommit=False))

    def _setup(self):
        """Attach the tool to the current request.

        In addition to the hook installed by ``cherrypy.Tool``, the commit/cleanup
        step is attached to ``'on_end_resource'`` with priority 80.
        """
        super()._setup()
        cherrypy.request.hooks.attach('on_end_resource', self.commit_transaction, priority=80)

    def bind_session(self):
        # Let bus listeners (e.g. SAEnginePlugin.bind) bind the registry to their
        # engine, then hand the registry to the request handler.
        cherrypy.engine.publish('bind', self.session)
        cherrypy.request.db = self.session

    def commit_transaction(self):
        """Commit the request's work, roll back if the commit fails, and always discard the session.

        ``cherrypy.request.db`` is cleared first, so the session is no longer reachable
        through the request afterwards.

        NOTE(review): confirm whether CherryPy also runs 'on_end_resource' after a handler
        raises; if it does, work done before the error is committed here.
        """
        cherrypy.request.db = None
        try:
            self.session.commit()
        except Exception:
            self.session.rollback()
            raise
        finally:
            # Drop this thread's session so the next request starts with a fresh one.
            self.session.remove()
|
|
|
|
|
|
def cursorwrap(func):
    """
    Provide a cursor (a session) to the wrapped method as its first argument after ``self``.

    This assumes the wrapped function belongs to an object: the cursor is obtained by calling
    the object's ``session`` attribute, which is assumed to be a sessionmaker-like callable, and
    it is closed when the wrapped call returns or raises.

    If the caller already passed a ``Session``, ``scoped_session`` or ``DbAlias`` as the first
    positional argument after ``self``, the call is passed through unchanged and no new session
    is opened.

    ``functools.wraps`` is applied so the decorated method keeps its ``__name__``, ``__doc__``
    and attributes (previously they were replaced by those of the inner wrapper).
    """
    @wraps(func)
    def wrapped(*args, **kwargs):
        self = args[0]
        # Passthru if someone already passed a session.
        if len(args) >= 2 and isinstance(args[1], (Session, sqlalchemy.orm.scoping.scoped_session, DbAlias)):
            return func(*args, **kwargs)
        with closing(self.session()) as cursor:
            return func(self, cursor, *args[1:], **kwargs)

    return wrapped
|