util.py :  » Database » SQLAlchemy » SQLAlchemy-0.6.0 » lib » sqlalchemy » orm » Python Open Source

Home
Python Open Source
1.3.1.2 Python
2.Ajax
3.Aspect Oriented
4.Blog
5.Build
6.Business Application
7.Chart Report
8.Content Management Systems
9.Cryptographic
10.Database
11.Development
12.Editor
13.Email
14.ERP
15.Game 2D 3D
16.GIS
17.GUI
18.IDE
19.Installer
20.IRC
21.Issue Tracker
22.Language Interface
23.Log
24.Math
25.Media Sound Audio
26.Mobile
27.Network
28.Parser
29.PDF
30.Project Management
31.RSS
32.Search
33.Security
34.Template Engines
35.Test
36.UML
37.USB Serial
38.Web Frameworks
39.Web Server
40.Web Services
41.Web Unit
42.Wiki
43.Windows
44.XML
Python Open Source » Database » SQLAlchemy 
SQLAlchemy » SQLAlchemy 0.6.0 » lib » sqlalchemy » orm » util.py
# mapper/util.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

import sqlalchemy.exceptions as sa_exc
from sqlalchemy import sql,util
from sqlalchemy.sql import expression,util
from sqlalchemy.orm.interfaces import MapperExtension,EXT_CONTINUE,PropComparator,\
                                        MapperProperty, AttributeExtension
from sqlalchemy.orm import attributes,exc

mapperlib = None

all_cascades = frozenset(("delete", "delete-orphan", "all", "merge",
                          "expunge", "save-update", "refresh-expire",
                          "none"))

_INSTRUMENTOR = ('mapper', 'instrumentor')

class CascadeOptions(object):
    """Keeps track of the options sent to relationship().cascade"""

    def __init__(self, arg=""):
        if not arg:
            values = set()
        else:
            values = set(c.strip() for c in arg.split(','))
        self.delete_orphan = "delete-orphan" in values
        self.delete = "delete" in values or "all" in values
        self.save_update = "save-update" in values or "all" in values
        self.merge = "merge" in values or "all" in values
        self.expunge = "expunge" in values or "all" in values
        self.refresh_expire = "refresh-expire" in values or "all" in values

        if self.delete_orphan and not self.delete:
            util.warn("The 'delete-orphan' cascade option requires "
                        "'delete'.  This will raise an error in 0.6.")

        for x in values:
            if x not in all_cascades:
                raise sa_exc.ArgumentError("Invalid cascade option '%s'" % x)

    def __contains__(self, item):
        return getattr(self, item.replace("-", "_"), False)

    def __repr__(self):
        return "CascadeOptions(%s)" % repr(",".join(
            [x for x in ['delete', 'save_update', 'merge', 'expunge',
                         'delete_orphan', 'refresh-expire']
             if getattr(self, x, False) is True]))


class Validator(AttributeExtension):
    """Runs a validation method on an attribute value to be set or appended.

    The Validator class is used by the :func:`~sqlalchemy.orm.validates`
    decorator, and direct access is usually not needed.

    """

    def __init__(self, key, validator):
        """Construct a new Validator.

            key - name of the attribute to be validated;
            will be passed as the second argument to
            the validation method (the first is the object instance itself).

            validator - an function or instance method which accepts
            three arguments; an instance (usually just 'self' for a method),
            the key name of the attribute, and the value.  The function should
            return the same value given, unless it wishes to modify it.

        """
        self.key = key
        self.validator = validator

    def append(self, state, value, initiator):
        return self.validator(state.obj(), self.key, value)

    def set(self, state, value, oldvalue, initiator):
        return self.validator(state.obj(), self.key, value)

def polymorphic_union(table_map, typecolname, aliasname='p_union'):
    """Create a ``UNION`` statement used by a polymorphic mapper.

    See  :ref:`concrete_inheritance` for an example of how
    this is used.
    """

    colnames = set()
    colnamemaps = {}
    types = {}
    for key in table_map.keys():
        table = table_map[key]

        # mysql doesnt like selecting from a select; make it an alias of the select
        if isinstance(table, sql.Select):
            table = table.alias()
            table_map[key] = table

        m = {}
        for c in table.c:
            colnames.add(c.key)
            m[c.key] = c
            types[c.key] = c.type
        colnamemaps[table] = m

    def col(name, table):
        try:
            return colnamemaps[table][name]
        except KeyError:
            return sql.cast(sql.null(), types[name]).label(name)

    result = []
    for type, table in table_map.iteritems():
        if typecolname is not None:
            result.append(sql.select([col(name, table) for name in colnames] +
                             [sql.literal_column(
                                sql_util._quote_ddl_expr(type)).label(typecolname)
                              ],
                             from_obj=[table]))
        else:
            result.append(sql.select([col(name, table) for name in colnames],
                                     from_obj=[table]))
    return sql.union_all(*result).alias(aliasname)

def identity_key(*args, **kwargs):
    """Get an identity key.

    Valid call signatures:

    * ``identity_key(class, ident)``

      class
          mapped class (must be a positional argument)

      ident
          primary key, if the key is composite this is a tuple


    * ``identity_key(instance=instance)``

      instance
          object instance (must be given as a keyword arg)

    * ``identity_key(class, row=row)``

      class
          mapped class (must be a positional argument)

      row
          result proxy row (must be given as a keyword arg)

    """
    if args:
        if len(args) == 1:
            class_ = args[0]
            try:
                row = kwargs.pop("row")
            except KeyError:
                ident = kwargs.pop("ident")
        elif len(args) == 2:
            class_, ident = args
        elif len(args) == 3:
            class_, ident = args
        else:
            raise sa_exc.ArgumentError("expected up to three "
                "positional arguments, got %s" % len(args))
        if kwargs:
            raise sa_exc.ArgumentError("unknown keyword arguments: %s"
                % ", ".join(kwargs.keys()))
        mapper = class_mapper(class_)
        if "ident" in locals():
            return mapper.identity_key_from_primary_key(ident)
        return mapper.identity_key_from_row(row)
    instance = kwargs.pop("instance")
    if kwargs:
        raise sa_exc.ArgumentError("unknown keyword arguments: %s"
            % ", ".join(kwargs.keys()))
    mapper = object_mapper(instance)
    return mapper.identity_key_from_instance(instance)

class ExtensionCarrier(dict):
    """Fronts an ordered collection of MapperExtension objects.

    Bundles multiple MapperExtensions into a unified callable unit,
    encapsulating ordering, looping and EXT_CONTINUE logic.  The
    ExtensionCarrier implements the MapperExtension interface, e.g.::

      carrier.after_insert(...args...)

    The dictionary interface provides containment for implemented
    method names mapped to a callable which executes that method
    for participating extensions.

    """

    interface = set(method for method in dir(MapperExtension)
                    if not method.startswith('_'))

    def __init__(self, extensions=None):
        self._extensions = []
        for ext in extensions or ():
            self.append(ext)

    def copy(self):
        return ExtensionCarrier(self._extensions)

    def push(self, extension):
        """Insert a MapperExtension at the beginning of the collection."""
        self._register(extension)
        self._extensions.insert(0, extension)

    def append(self, extension):
        """Append a MapperExtension at the end of the collection."""
        self._register(extension)
        self._extensions.append(extension)

    def __iter__(self):
        """Iterate over MapperExtensions in the collection."""
        return iter(self._extensions)

    def _register(self, extension):
        """Register callable fronts for overridden interface methods."""

        for method in self.interface.difference(self):
            impl = getattr(extension, method, None)
            if impl and impl is not getattr(MapperExtension, method):
                self[method] = self._create_do(method)

    def _create_do(self, method):
        """Return a closure that loops over impls of the named method."""

        def _do(*args, **kwargs):
            for ext in self._extensions:
                ret = getattr(ext, method)(*args, **kwargs)
                if ret is not EXT_CONTINUE:
                    return ret
            else:
                return EXT_CONTINUE
        _do.__name__ = method
        return _do

    @staticmethod
    def _pass(*args, **kwargs):
        return EXT_CONTINUE

    def __getattr__(self, key):
        """Delegate MapperExtension methods to bundled fronts."""

        if key not in self.interface:
            raise AttributeError(key)
        return self.get(key, self._pass)

class ORMAdapter(sql_util.ColumnAdapter):
    """Extends ColumnAdapter to accept ORM entities.

    The selectable is extracted from the given entity,
    and the AliasedClass if any is referenced.

    """
    def __init__(self, entity, equivalents=None, chain_to=None, adapt_required=False):
        self.mapper, selectable, is_aliased_class = _entity_info(entity)
        if is_aliased_class:
            self.aliased_class = entity
        else:
            self.aliased_class = None
        sql_util.ColumnAdapter.__init__(self, selectable, equivalents, chain_to, adapt_required=adapt_required)

    def replace(self, elem):
        entity = elem._annotations.get('parentmapper', None)
        if not entity or entity.isa(self.mapper):
            return sql_util.ColumnAdapter.replace(self, elem)
        else:
            return None

class AliasedClass(object):
    """Represents an "aliased" form of a mapped class for usage with Query.

    The ORM equivalent of a :func:`sqlalchemy.sql.expression.alias`
    construct, this object mimics the mapped class using a
    __getattr__ scheme and maintains a reference to a
    real :class:`~sqlalchemy.sql.expression.Alias` object.   
    
    Usage is via the :class:`~sqlalchemy.orm.aliased()` synonym::

        # find all pairs of users with the same name
        user_alias = aliased(User)
        session.query(User, user_alias).\\
                        join((user_alias, User.id > user_alias.id)).\\
                        filter(User.name==user_alias.name)

    """
    def __init__(self, cls, alias=None, name=None):
        self.__mapper = _class_to_mapper(cls)
        self.__target = self.__mapper.class_
        if alias is None:
            alias = self.__mapper._with_polymorphic_selectable.alias()
        self.__adapter = sql_util.ClauseAdapter(alias, equivalents=self.__mapper._equivalent_columns)
        self.__alias = alias
        # used to assign a name to the RowTuple object
        # returned by Query.
        self._sa_label_name = name
        self.__name__ = 'AliasedClass_' + str(self.__target)

    def __getstate__(self):
        return {'mapper':self.__mapper, 'alias':self.__alias, 'name':self._sa_label_name}

    def __setstate__(self, state):
        self.__mapper = state['mapper']
        self.__target = self.__mapper.class_
        alias = state['alias']
        self.__adapter = sql_util.ClauseAdapter(alias, equivalents=self.__mapper._equivalent_columns)
        self.__alias = alias
        name = state['name']
        self._sa_label_name = name
        self.__name__ = 'AliasedClass_' + str(self.__target)

    def __adapt_element(self, elem):
        return self.__adapter.traverse(elem)._annotate({'parententity': self, 'parentmapper':self.__mapper})

    def __adapt_prop(self, prop):
        existing = getattr(self.__target, prop.key)
        comparator = existing.comparator.adapted(self.__adapt_element)

        queryattr = attributes.QueryableAttribute(prop.key,
            impl=existing.impl, parententity=self, comparator=comparator)
        setattr(self, prop.key, queryattr)
        return queryattr

    def __getattr__(self, key):
        prop = self.__mapper._get_property(key, raiseerr=False)
        if prop:
            return self.__adapt_prop(prop)

        for base in self.__target.__mro__:
            try:
                attr = object.__getattribute__(base, key)
            except AttributeError:
                continue
            else:
                break
        else:
            raise AttributeError(key)

        if hasattr(attr, 'func_code'):
            is_method = getattr(self.__target, key, None)
            if is_method and is_method.im_self is not None:
                return util.types.MethodType(attr.im_func, self, self)
            else:
                return None
        elif hasattr(attr, '__get__'):
            return attr.__get__(None, self)
        else:
            return attr

    def __repr__(self):
        return '<AliasedClass at 0x%x; %s>' % (
            id(self), self.__target.__name__)

def _orm_annotate(element, exclude=None):
    """Deep copy the given ClauseElement, annotating each element with the "_orm_adapt" flag.

    Elements within the exclude collection will be cloned but not annotated.

    """
    return sql_util._deep_annotate(element, {'_orm_adapt':True}, exclude)

_orm_deannotate = sql_util._deep_deannotate

class _ORMJoin(expression.Join):
    """Extend Join to support ORM constructs as input."""

    __visit_name__ = expression.Join.__visit_name__

    def __init__(self, left, right, onclause=None, isouter=False, join_to_left=True):
        adapt_from = None

        if hasattr(left, '_orm_mappers'):
            left_mapper = left._orm_mappers[1]
            if join_to_left:
                adapt_from = left.right
        else:
            left_mapper, left, left_is_aliased = _entity_info(left)
            if join_to_left and (left_is_aliased or not left_mapper):
                adapt_from = left

        right_mapper, right, right_is_aliased = _entity_info(right)
        if right_is_aliased:
            adapt_to = right
        else:
            adapt_to = None

        if left_mapper or right_mapper:
            self._orm_mappers = (left_mapper, right_mapper)

            if isinstance(onclause, basestring):
                prop = left_mapper.get_property(onclause)
            elif isinstance(onclause, attributes.QueryableAttribute):
                if adapt_from is None:
                    adapt_from = onclause.__clause_element__()
                prop = onclause.property
            elif isinstance(onclause, MapperProperty):
                prop = onclause
            else:
                prop = None

            if prop:
                pj, sj, source, dest, secondary, target_adapter = prop._create_joins(
                                source_selectable=adapt_from,
                                dest_selectable=adapt_to,
                                source_polymorphic=True,
                                dest_polymorphic=True,
                                of_type=right_mapper)

                if sj is not None:
                    left = sql.join(left, secondary, pj, isouter)
                    onclause = sj
                else:
                    onclause = pj
                self._target_adapter = target_adapter

        expression.Join.__init__(self, left, right, onclause, isouter)

    def join(self, right, onclause=None, isouter=False, join_to_left=True):
        return _ORMJoin(self, right, onclause, isouter, join_to_left)

    def outerjoin(self, right, onclause=None, join_to_left=True):
        return _ORMJoin(self, right, onclause, True, join_to_left)

def join(left, right, onclause=None, isouter=False, join_to_left=True):
    """Produce an inner join between left and right clauses.

    In addition to the interface provided by
    :func:`~sqlalchemy.sql.expression.join()`, left and right may be mapped
    classes or AliasedClass instances. The onclause may be a
    string name of a relationship(), or a class-bound descriptor
    representing a relationship.

    join_to_left indicates to attempt aliasing the ON clause,
    in whatever form it is passed, to the selectable
    passed as the left side.  If False, the onclause
    is used as is.

    """
    return _ORMJoin(left, right, onclause, isouter, join_to_left)

def outerjoin(left, right, onclause=None, join_to_left=True):
    """Produce a left outer join between left and right clauses.

    In addition to the interface provided by
    :func:`~sqlalchemy.sql.expression.outerjoin()`, left and right may be mapped
    classes or AliasedClass instances. The onclause may be a
    string name of a relationship(), or a class-bound descriptor
    representing a relationship.

    """
    return _ORMJoin(left, right, onclause, True, join_to_left)

def with_parent(instance, prop):
    """Return criterion which selects instances with a given parent.

    instance
      a parent instance, which should be persistent or detached.

    property
      a class-attached descriptor, MapperProperty or string property name
      attached to the parent instance.

    \**kwargs
      all extra keyword arguments are propagated to the constructor of
      Query.

    """
    if isinstance(prop, basestring):
        mapper = object_mapper(instance)
        prop = mapper.get_property(prop, resolve_synonyms=True)
    elif isinstance(prop, attributes.QueryableAttribute):
        prop = prop.property

    return prop.compare(operators.eq, instance, value_is_parent=True)


def _entity_info(entity, compile=True):
    """Return mapping information given a class, mapper, or AliasedClass.

    Returns 3-tuple of: mapper, mapped selectable, boolean indicating if this
    is an aliased() construct.

    If the given entity is not a mapper, mapped class, or aliased construct,
    returns None, the entity, False.  This is typically used to allow
    unmapped selectables through.

    """
    if isinstance(entity, AliasedClass):
        return entity._AliasedClass__mapper, entity._AliasedClass__alias, True

    global mapperlib
    if mapperlib is None:
        from sqlalchemy.orm import mapperlib
    
    if isinstance(entity, mapperlib.Mapper):
        mapper = entity
        
    elif isinstance(entity, type):
        class_manager = attributes.manager_of_class(entity)
        
        if class_manager is None:
            return None, entity, False
            
        mapper = class_manager.mapper
    else:
        return None, entity, False
        
    if compile:
        mapper = mapper.compile()
    return mapper, mapper._with_polymorphic_selectable, False

def _entity_descriptor(entity, key):
    """Return attribute/property information given an entity and string name.

    Returns a 2-tuple representing InstrumentedAttribute/MapperProperty.

    """
    if isinstance(entity, AliasedClass):
        try:
            desc = getattr(entity, key)
            return desc, desc.property
        except AttributeError:
            raise sa_exc.InvalidRequestError("Entity '%s' has no property '%s'" % (entity, key))
            
    elif isinstance(entity, type):
        try:
            desc = attributes.manager_of_class(entity)[key]
            return desc, desc.property
        except KeyError:
            raise sa_exc.InvalidRequestError("Entity '%s' has no property '%s'" % (entity, key))
            
    else:
        try:
            desc = entity.class_manager[key]
            return desc, desc.property
        except KeyError:
            raise sa_exc.InvalidRequestError("Entity '%s' has no property '%s'" % (entity, key))

def _orm_columns(entity):
    mapper, selectable, is_aliased_class = _entity_info(entity)
    if isinstance(selectable, expression.Selectable):
        return [c for c in selectable.c]
    else:
        return [selectable]

def _orm_selectable(entity):
    mapper, selectable, is_aliased_class = _entity_info(entity)
    return selectable

def _is_aliased_class(entity):
    return isinstance(entity, AliasedClass)

def _state_mapper(state):
    return state.manager.mapper

def object_mapper(instance):
    """Given an object, return the primary Mapper associated with the object instance.

    Raises UnmappedInstanceError if no mapping is configured.

    """
    try:
        state = attributes.instance_state(instance)
        if not state.manager.mapper:
            raise exc.UnmappedInstanceError(instance)
        return state.manager.mapper
    except exc.NO_STATE:
        raise exc.UnmappedInstanceError(instance)

def class_mapper(class_, compile=True):
    """Given a class, return the primary Mapper associated with the key.

    Raises UnmappedClassError if no mapping is configured.

    """
    try:
        class_manager = attributes.manager_of_class(class_)
        mapper = class_manager.mapper

        # HACK until [ticket:1142] is complete
        if mapper is None:
            raise AttributeError

    except exc.NO_STATE:
        raise exc.UnmappedClassError(class_)

    if compile:
        mapper = mapper.compile()
    return mapper

def _class_to_mapper(class_or_mapper, compile=True):
    if _is_aliased_class(class_or_mapper):
        return class_or_mapper._AliasedClass__mapper
    elif isinstance(class_or_mapper, type):
        return class_mapper(class_or_mapper, compile=compile)
    elif hasattr(class_or_mapper, 'compile'):
        if compile:
            return class_or_mapper.compile()
        else:
            return class_or_mapper
    else:
        raise exc.UnmappedClassError(class_or_mapper)

def has_identity(object):
    state = attributes.instance_state(object)
    return _state_has_identity(state)

def _state_has_identity(state):
    return bool(state.key)

def _is_mapped_class(cls):
    global mapperlib
    if mapperlib is None:
        from sqlalchemy.orm import mapperlib
    if isinstance(cls, (AliasedClass, mapperlib.Mapper)):
        return True
    if isinstance(cls, expression.ClauseElement):
        return False
    if isinstance(cls, type):
        manager = attributes.manager_of_class(cls)
        return manager and _INSTRUMENTOR in manager.info
    return False

def instance_str(instance):
    """Return a string describing an instance."""

    return state_str(attributes.instance_state(instance))

def state_str(state):
    """Return a string describing an instance via its InstanceState."""

    if state is None:
        return "None"
    else:
        return '<%s at 0x%x>' % (state.class_.__name__, id(state.obj()))

def attribute_str(instance, attribute):
    return instance_str(instance) + "." + attribute

def state_attribute_str(state, attribute):
    return state_str(state) + "." + attribute

def identity_equal(a, b):
    if a is b:
        return True
    if a is None or b is None:
        return False
    try:
        state_a = attributes.instance_state(a)
        state_b = attributes.instance_state(b)
    except exc.NO_STATE:
        return False
    if state_a.key is None or state_b.key is None:
        return False
    return state_a.key == state_b.key


# TODO: Avoid circular import.
attributes.identity_equal = identity_equal
attributes._is_aliased_class = _is_aliased_class
attributes._entity_info = _entity_info
www.java2java.com | Contact Us
Copyright 2009 - 12 Demo Source and Support. All rights reserved.
All other trademarks are property of their respective owners.