|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118 |
- import sqlalchemy
- import cherrypy
- from cherrypy.process import plugins
- from sqlalchemy.ext.declarative import declarative_base
- from sqlalchemy.pool import NullPool
- from sqlalchemy.orm import sessionmaker
- from sqlalchemy import func
-
# Declarative base shared by the ORM models; ``Base.metadata.create_all(...)``
# (used below) creates the tables of every model that subclasses it.
Base = declarative_base()


# Extra ``create_engine`` keyword arguments, keyed by database URI prefix.
# ``get_engine_options`` tries the prefixes in this order and uses the first
# one the URI starts with.
engine_specific_options = {
    "sqlite": {
        # Let a sqlite3 connection be used from a thread other than the one
        # that opened it (presumably because requests run on worker threads).
        "connect_args": {"check_same_thread": False},
        # No pooling: every checkout opens a fresh connection.
        "poolclass": NullPool,
        "pool_pre_ping": True,
    },
    "mysql": {
        # Test pooled connections for liveness on checkout.
        "pool_pre_ping": True,
    },
}
-
-
def get_engine_options(uri):
    """Return the extra ``create_engine`` keyword arguments for *uri*.

    The keys of ``engine_specific_options`` are tried, in order, as URI
    prefixes; the options of the first prefix *uri* starts with are returned.

    :param uri: SQLAlchemy database URI, e.g. ``"sqlite:///app.db"``.
    :returns: the matching options dict, or a new empty dict if no prefix matches.
    """
    matches = (
        opts
        for prefix, opts in engine_specific_options.items()
        if uri.startswith(prefix)
    )
    return next(matches, {})
-
-
def get_db_engine(uri, debug=False):
    """Create an engine for *uri* and create any missing model tables.

    :param uri: SQLAlchemy database URI.
    :param debug: forwarded to ``create_engine`` as ``echo``.
    :returns: the newly created engine.
    """
    extra_options = get_engine_options(uri)
    new_engine = sqlalchemy.create_engine(uri, **extra_options, echo=debug)
    # Make sure every table declared on ``Base`` exists before first use.
    Base.metadata.create_all(new_engine)
    return new_engine
-
-
def get_db_session(uri):
    """Return a session *factory* bound to a new engine for *uri*.

    Despite the name, the result is a configured ``sessionmaker``, not a
    session: call it to obtain a ``Session``. Building the engine also
    creates any missing tables (see ``get_db_engine``).

    :param uri: SQLAlchemy database URI.
    """
    bound_engine = get_db_engine(uri)
    return sessionmaker(bind=bound_engine)
-
-
def driver_statement(statements):
    """
    Select a value from *statements* based on the SQL driver in use.

    Must be called in request context: the driver is read from the session in
    ``cherrypy.request.db``. For example, sqlite and mysql use different date
    functions, so a query supporting either can be built like this::

        date_format = driver_statement({"sqlite": lambda date_format, value: func.strftime(date_format, value),
                                        "mysql": lambda date_format, value: func.date_format(value, date_format)})

        rows = db.query(PhotoSet.id, date_format('%Y-%m-%d', PhotoSet.date).label('date')).all()

    :param statements: dict of driver_type -> value. Since DBAPI driver names vary
        (e.g. ``pymysql``, ``pysqlite``), a key matches when it is a substring of
        SQLAlchemy's driver name. Keys are tried in dict order; the first match wins.
    :type statements: dict
    :returns: the value stored under the first matching key.
    :raises NotImplementedError: if no key matches the driver in use. This is a
        subclass of ``Exception``, so existing ``except Exception`` handlers still
        catch it.
    """
    # Read the driver from the session's bind. Calling ``connection()`` would
    # check out a connection and begin a transaction just to read a name.
    # ``dialect.driver`` exists on both Engine and Connection binds.
    driver = cherrypy.request.db.get_bind().dialect.driver
    for key, value in statements.items():
        if key in driver:
            return value
    raise NotImplementedError(f"Statement not supported for driver {driver}")
-
-
def date_format(date_format, value):
    """Build a SQL expression that formats *value* using *date_format*.

    Dispatches on the driver in use via ``driver_statement``: sqlite gets
    ``strftime(date_format, value)`` and mysql gets
    ``date_format(value, date_format)``. Must be called in request context.

    NOTE(review): the format string is passed through unchanged, and sqlite's
    and MySQL's format specifiers are not identical for every field; confirm
    callers only use specifiers both understand (e.g. %Y, %m, %d).
    """
    def _sqlite_format(fmt, val):
        return func.strftime(fmt, val)

    def _mysql_format(fmt, val):
        # MySQL takes the value first and the format second.
        return func.date_format(val, fmt)

    formatter = driver_statement({"sqlite": _sqlite_format,
                                  "mysql": _mysql_format})
    return formatter(date_format, value)
-
-
class DbAlias:
    """
    Short alias for ``cherrypy.request.db``, the database session bound to the
    current request.

    ``cherrypy.request.db`` only exists while a request is being handled, so it
    cannot be captured in a module-level variable. An instance of this class is
    a proxy instead: any attribute lookup the proxy cannot satisfy itself is
    forwarded to whatever ``cherrypy.request.db`` is at that moment. Only reads
    are forwarded; assigning an attribute sets it on the proxy object itself.
    """

    def __getattr__(self, attr):
        current_session = cherrypy.request.db
        return getattr(current_session, attr)


# Module-level proxy, e.g. ``db.query(...)`` inside a request handler.
db = DbAlias()
-
-
class SAEnginePlugin(plugins.SimplePlugin):
    """
    CherryPy bus plugin that holds the SQLAlchemy engine.

    On construction it subscribes to the ``bind`` bus channel. Whatever is
    published there (``SATool.bind_session`` publishes its scoped session) is
    configured to use this plugin's engine.
    """

    def __init__(self, bus, dbcon):
        """
        :param bus: the CherryPy bus to attach to.
        :param dbcon: the SQLAlchemy engine that sessions are bound to.
        """
        super().__init__(bus)
        self.sa_engine = dbcon
        self.bus.subscribe("bind", self.bind)

    def start(self):
        """Create the tables of all models declared on ``Base`` using the engine."""
        Base.metadata.create_all(self.sa_engine)

    def bind(self, session):
        """Point *session* (anything exposing ``configure``) at the engine."""
        session.configure(bind=self.sa_engine)
-
-
class SATool(cherrypy.Tool):
    """
    CherryPy tool that gives each request a SQLAlchemy session.

    Before the request body is processed, the scoped session is bound to the
    engine (by publishing it on the ``bind`` bus channel) and exposed as
    ``cherrypy.request.db``. At ``on_end_resource`` the transaction is
    committed (or rolled back if the commit fails), and the session is
    removed from the scoped registry.
    """

    def __init__(self):
        super().__init__(
            'before_request_body',
            self.bind_session,
            priority=49,  # slightly earlier than Sessions tool, which is 50 or 60
        )
        # Scoped session registry shared by every request this tool serves.
        self.session = sqlalchemy.orm.scoped_session(
            sqlalchemy.orm.sessionmaker(autoflush=True, autocommit=False)
        )

    def _setup(self):
        """Install the tool's hook, plus the commit hook at ``on_end_resource``."""
        super()._setup()
        cherrypy.request.hooks.attach(
            'on_end_resource', self.commit_transaction, priority=80
        )

    def bind_session(self):
        """Bind the scoped session to the engine and expose it as ``request.db``."""
        scoped = self.session
        cherrypy.engine.publish('bind', scoped)
        cherrypy.request.db = scoped

    def commit_transaction(self):
        """Commit the request's work; if the commit fails, roll back and re-raise.

        ``request.db`` is cleared first. The session is always removed from the
        scoped registry afterwards, whether or not the commit succeeded.

        NOTE(review): confirm whether CherryPy also runs ``on_end_resource`` when
        the handler raised. If it does, work done before the error is committed.
        """
        cherrypy.request.db = None
        scoped = self.session
        try:
            scoped.commit()
        except Exception:
            scoped.rollback()
            raise
        finally:
            scoped.remove()
|