import copy

import sqlalchemy
import cherrypy
from cherrypy.process import plugins
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.pool import NullPool
from sqlalchemy.orm import sessionmaker
from sqlalchemy import func

Base = declarative_base()

# Extra keyword arguments for sqlalchemy.create_engine(), keyed by the URI
# prefix they apply to (see get_engine_options).
engine_specific_options = {
    "sqlite": dict(connect_args={'check_same_thread': False},
                   poolclass=NullPool,
                   pool_pre_ping=True),
    "mysql": dict(pool_pre_ping=True),
}


def get_engine_options(uri):
    """
    Return the extra create_engine() options for the given database URI.

    The first entry of ``engine_specific_options`` whose key is a prefix of
    ``uri`` is used; a URI matching no entry gets no extra options.

    :param uri: SQLAlchemy database URI, e.g. ``sqlite:///app.db``
    :type uri: str
    :return: a deep copy of the matching options, so callers may modify the
        result without altering the module-level table
    :rtype: dict
    """
    for engine_prefix, options in engine_specific_options.items():
        if uri.startswith(engine_prefix):
            # Copy so a caller mutating the result (including nested
            # connect_args) cannot corrupt the shared defaults.
            return copy.deepcopy(options)
    return {}


def get_db_engine(uri, debug=False):
    """
    Create an engine for ``uri`` and create any missing tables declared on
    ``Base``.

    :param uri: SQLAlchemy database URI
    :type uri: str
    :param debug: passed to create_engine() as ``echo``
    :type debug: bool
    :return: the new engine
    """
    engine = sqlalchemy.create_engine(uri, echo=debug, **get_engine_options(uri))
    Base.metadata.create_all(engine)
    return engine


def get_db_session(uri):
    """
    Build a session factory bound to a new engine for ``uri``.

    Despite the name, this returns a ``sessionmaker`` (call it to obtain a
    session), not a session instance.

    :param uri: SQLAlchemy database URI
    :type uri: str
    :return: a configured sessionmaker
    """
    engine = get_db_engine(uri)
    session_factory = sessionmaker()
    session_factory.configure(bind=engine)
    return session_factory


def driver_statement(statements):
    """
    Select a value from the passed dict based on the sql driver in use.
    Must be used in request context (reads ``cherrypy.request.db``).

    For example: Sqlite and mysql use different date functions. This
    function can be used to build queries supporting either::

        date_format = driver_statement({
            "sqlite": lambda date_format, value: func.strftime(date_format, value),
            "mysql": lambda date_format, value: func.date_format(value, date_format)})
        rows = db.query(PhotoSet.id,
                        date_format('%Y', PhotoSet.date).label('year')).all()

    :param statements: dict of driver_type->value. Since sqlalchemy driver
        names vary per DBAPI (e.g. pymysql, pysqlite), a key matches when it
        is a substring of sqlalchemy's driver name; the first match in dict
        order wins.
    :type statements: dict
    :return: the value stored under the first matching key
    :raises NotImplementedError: if no key matches the driver in use (a
        subclass of ``Exception``, so ``except Exception`` handlers still
        catch it)
    """
    # get_bind() resolves the bound engine without checking out a connection
    # or beginning a transaction; ``.engine`` yields the Engine whether the
    # bind is an Engine or a Connection.
    driver = cherrypy.request.db.get_bind().engine.driver
    for key, value in statements.items():
        if key in driver:
            return value
    raise NotImplementedError(f"Statement not supported for driver {driver}")


def date_format(date_format, value):
    """
    Format a date SQL expression using the function the current driver
    provides: ``strftime(fmt, value)`` on SQLite, ``date_format(value, fmt)``
    on MySQL. Must be used in request context.

    :param date_format: format string, passed through unchanged. Note that
        SQLite strftime and MySQL DATE_FORMAT specifiers only partly overlap.
    :param value: SQL expression to format (e.g. a date column)
    :return: SQL function expression
    :raises NotImplementedError: for drivers other than sqlite/mysql
    """
    stmt = driver_statement({
        "sqlite": lambda fmt, val: func.strftime(fmt, val),
        "mysql": lambda fmt, val: func.date_format(val, fmt),
    })
    return stmt(date_format, value)


class DbAlias(object):
    """
    Shorter alias for ``cherrypy.request.db``, the database session bound to
    the current request.

    Since the ``db`` attribute doesn't exist until a request is received, we
    cannot simply reference it with another variable. An instance of this
    class acts as an object proxy: every attribute lookup is forwarded to
    whatever ``cherrypy.request.db`` is at the time of the lookup.
    """

    def __getattr__(self, attr):
        return getattr(cherrypy.request.db, attr)


db = DbAlias()


class SAEnginePlugin(plugins.SimplePlugin):
    """
    CherryPy bus plugin holding the SQLAlchemy engine.

    It creates missing tables on ``start`` and binds sessions to the engine
    whenever ``bind`` is published on the bus. Once subscribed, it also
    releases pooled connections on ``stop``.
    """

    def __init__(self, bus, dbcon):
        plugins.SimplePlugin.__init__(self, bus)
        self.sa_engine = dbcon
        self.bus.subscribe("bind", self.bind)

    def start(self):
        Base.metadata.create_all(self.sa_engine)

    def stop(self):
        # Close idle pooled connections on shutdown. dispose() replaces the
        # pool rather than invalidating the engine, so a later bus start
        # still works.
        self.sa_engine.dispose()

    def bind(self, session):
        session.configure(bind=self.sa_engine)


class SATool(cherrypy.Tool):
    """
    CherryPy tool that exposes a scoped SQLAlchemy session as
    ``cherrypy.request.db`` before the request body is read, and commits
    (or rolls back) and removes it in the ``on_end_resource`` hook.
    """

    def __init__(self):
        cherrypy.Tool.__init__(self, 'before_request_body',
                               self.bind_session,
                               priority=49)  # slightly earlier than Sessions tool, which is 50 or 60
        self.session = sqlalchemy.orm.scoped_session(
            sqlalchemy.orm.sessionmaker(autoflush=True, autocommit=False))

    def _setup(self):
        cherrypy.Tool._setup(self)
        cherrypy.request.hooks.attach('on_end_resource',
                                      self.commit_transaction,
                                      priority=80)

    def bind_session(self):
        # NOTE(review): this reconfigures the scoped_session on every request;
        # SQLAlchemy warns from configure() when another thread already holds
        # a session. Consider binding once at startup — confirm.
        cherrypy.engine.publish('bind', self.session)
        cherrypy.request.db = self.session

    def commit_transaction(self):
        # NOTE(review): this commits unconditionally. Confirm whether
        # on_end_resource also fires after a handler raised; if so, partial
        # work from a failed request is committed here.
        cherrypy.request.db = None
        try:
            self.session.commit()
        except Exception:
            self.session.rollback()
            raise
        finally:
            # Drop the thread-local session so the next request starts fresh.
            self.session.remove()