# $SnapHashLicense:
# SnapLogic - Open source data services
# Copyright (C) 2008-2009, SnapLogic, Inc.  All rights reserved.
# See http://www.snaplogic.org for more information about
# the SnapLogic project. 
# This program is free software, distributed under the terms of
# the GNU General Public License Version 2. See the LEGAL file
# at the top of the source tree.
# "SnapLogic" is a trademark of SnapLogic, Inc.
# $

# $Id: DB2.py 8009 2009-06-26 04:30:52Z grisha $

# Doing this instead of "import DB2" because 
# our name convention for DB wrapper modules and
# classes happens to coincide with the DB-API driver
# in this case. 

import types
db2_driver = __import__('DB2')

from snaplogic.common.data_types import SnapDateTime,SnapNumber,SnapString

# The driver's own table of SQL type codes to type names.
TYPE_CODE_TO_NATIVE_TYPE = db2_driver.SQL_type_dict

# Native DB2 type name (lower case) to the Snap data type used to carry it.
# NOTE(review): the opening line of this literal was missing from the file;
# the name NATIVE_TYPE_TO_SNAP_TYPE is inferred from the
# native_type_to_snap_type() lookups made elsewhere in this module --
# confirm against SnapDBAdapter before relying on it.
NATIVE_TYPE_TO_SNAP_TYPE = {
               'character' : SnapString,
               'varchar' : SnapString,
               'longvarchar' : SnapString,
               'long varchar' : SnapString,
               'char' : SnapString,
               'binary' : SnapString,
               'varbinary' : SnapString,
               'longvarbinary' : SnapString,
               'long varbinary' : SnapString,
               'bigint' : SnapNumber,
               'integer' : SnapNumber,
               'smallint' : SnapNumber,
               # 'decimal' was listed twice; a repeated key in a dict literal
               # is redundant, so only one entry is kept.
               'decimal' : SnapNumber,
               'decfloat' : SnapNumber,
               'double' : SnapNumber,
               'real' : SnapNumber,
               'time' : SnapDateTime,
               'date' : SnapDateTime,
               'timestamp' : SnapDateTime,
               }

def _convert_result_rows(self, rows):
    This is a workaround for PyDB2 bug #2618159 
    TupleType = types.TupleType
    # Begin workaround
    if rows is None:
        return rows
    # End workaround
    for r in rows:
        for i in range(len(r)):
            if type(r[i]) == TupleType:
                r[i] = self._convert_result_col(r[i])
    return rows

# Install the None-tolerant row conversion defined in this module on the
# driver's Cursor class, replacing the driver's own method of the same name.
db2_driver.Cursor._convert_result_rows = _convert_result_rows

# See PyDB2 bug #2618188
# (https://sourceforge.net/tracker2/?func=detail&aid=2618188&group_id=67548&atid=518206)
# Register SQL type code -360 under the name 'DECFLOAT' in the driver's
# type table.
db2_driver.SQL_type_dict[-360] = 'DECFLOAT'

from datetime import datetime,time
from decimal import Decimal
import os
import sys
from time import strptime

from snaplogic.components.DBUtils import SnapDBAdapter,_acquireLock,_releaseLock
from snaplogic.common.snap_exceptions import SnapException,SnapComponentError

class DB2CursorWrapper(object):
    """
    A wrapper around DB API 2.0 cursor, to handle encoding and type conversion.

    Statements and unicode parameters are encoded to UTF-8 before they are
    handed to the native cursor, and fetched rows are converted to Snap data
    types. Everything that is not wrapped here is delegated to the native
    cursor through L{__getattr__}.
    """

    def __init__(self, cursor, snap_con):
        """
        Initialize the wrapper.

        @param cursor: DB API 2.0 cursor object, to which most requests will
        be delegated.
        @type cursor: cursor

        @param snap_con: Snap DB connection that produced this cursor
        @type snap_con: SnapDBAdapter
        """
        self._snap_con = snap_con
        self._delegate = cursor
        # Column description of the current statement plus, per Snap type,
        # a dict of {column index: native type name}. All of these are built
        # lazily by convert_row() and invalidated by execute().
        self._metadata = None
        self._date_fields = None
        self._num_fields = None
        self._str_fields = None

    def execute(self, operation, params = None):
        """
        Same as cursor.execute() specified in DB API 2.0, except that the
        statement and any unicode parameters are encoded to UTF-8 first.

        @param operation: SQL statement to execute
        @type operation: str or unicode

        @param params: values to bind, if any
        @type params: sequence

        @return: whatever the native cursor's execute() returns
        """
        # A new statement means new columns; forget the cached description.
        self._metadata = None
        operation = operation.encode('utf-8')
        new_params = None
        if params:
            new_params = []
            for p in params:
                if type(p) == unicode:
                    p = p.encode('utf-8')
                new_params.append(p)
        return self._delegate.execute(operation, new_params)

    def _build_field_maps(self):
        """
        Read the native cursor's description and record, per Snap type, which
        column indices need conversion and what their native type is.
        """
        self._metadata = self._delegate.description
        self._date_fields = None
        self._num_fields = None
        self._str_fields = None
        for i, col_metadata in enumerate(self._metadata):
            type_code = col_metadata[1]
            native_type = self._snap_con.type_code_to_native_type(type_code)
            snap_type = self._snap_con.native_type_to_snap_type(native_type)
            if snap_type == SnapNumber:
                if self._num_fields is None:
                    self._num_fields = {}
                self._num_fields[i] = native_type
            elif snap_type == SnapString:
                if self._str_fields is None:
                    self._str_fields = {}
                self._str_fields[i] = native_type
            elif snap_type == SnapDateTime:
                if self._date_fields is None:
                    self._date_fields = {}
                self._date_fields[i] = native_type

    def convert_row(self, row):
        """
        Convert a row of data in native data types into a row of Snap types.

        Numbers become Decimal, strings are decoded from UTF-8 and
        TIMESTAMP/DATE/TIME strings become datetime objects (TIME values are
        combined with today's date). None values are left as None.

        @param row: row returned by database
        @type row: tuple

        @return: row converted to Snap data types; the row itself is returned
                 unchanged when it is empty or no column needs conversion
        @rtype: list
        """
        # Metadata already examined and nothing needs converting.
        if self._metadata is not None and \
            self._date_fields is None and \
            self._num_fields is None and \
            self._str_fields is None:
            return row
        if not row:
            return row
        if self._metadata is None:
            self._build_field_maps()

        new_row = list(row)
        if self._num_fields is not None:
            for idx in self._num_fields:
                val = row[idx]
                if val is None:
                    continue
                t = type(val)
                if t == int or t == long or t == bool:
                    new_row[idx] = Decimal(val)
                else:
                    # Go through str() so that e.g. floats convert cleanly.
                    new_row[idx] = Decimal(str(val))
        if self._str_fields is not None:
            for idx in self._str_fields:
                val = row[idx]
                if val is None:
                    continue
                new_row[idx] = val.decode('utf-8')
        if self._date_fields is not None:
            for idx in self._date_fields:
                str_val = row[idx]
                if str_val is None:
                    continue
                type_name = self._date_fields[idx]
                if type_name == 'TIMESTAMP':
                    # Layout is YYYY-MM-DD-HH.MM.SS.<fraction>
                    dot_idx = str_val.rindex('.')
                    fraction = str_val[dot_idx+1:]
                    # The digits are a decimal fraction of a second: pad or
                    # cut to exactly six so int() yields microseconds.
                    micros = int((fraction + '000000')[:6])
                    tt = strptime(str_val[:dot_idx],'%Y-%m-%d-%H.%M.%S')
                    tt = list(tt[0:6])
                    tt.append(micros)
                    new_row[idx] = datetime(*tt)
                elif type_name == 'DATE':
                    tt = strptime(str_val,'%Y-%m-%d')
                    new_row[idx] = datetime(*tt[0:3])
                elif type_name == 'TIME':
                    tt = strptime(str_val,'%H:%M:%S')
                    dt = time(*tt[3:6])
                    # A bare time is anchored to today's date.
                    new_row[idx] = datetime.combine(datetime.today(), dt)
        return new_row

    def convert_results(self, rs):
        """
        Convert the result set from native data types to Snap data types.
        This is similar to L{convert_row}, except it acts on the entire result
        set.

        @param rs: Result set to convert
        @type rs: list or tuple

        @return: converted result set; rs itself when it is empty or no
                 column needs conversion
        @rtype: list
        """
        if self._metadata is not None and \
            self._str_fields is None and \
            self._num_fields is None and \
            self._date_fields is None:
            return rs
        if not rs:
            return rs
        return [self.convert_row(row) for row in rs]

    def fetchone(self):
        """
        Same as cursor.fetchone() specified in DB API 2.0, except returning
        Snap data types.
        """
        row = self._delegate.fetchone()
        if row is not None:
            row = self.convert_row(row)
        return row

    def fetchmany(self, size=None):
        """
        Same as cursor.fetchmany() specified in DB API 2.0, except returning
        Snap data types.
        """
        rs = self._delegate.fetchmany(size)
        return self.convert_results(rs)

    def fetchall(self):
        """
        Same as cursor.fetchall() specified in DB API 2.0, except returning
        Snap data types.
        """
        rs = self._delegate.fetchall()
        return self.convert_results(rs)

    def __getattr__(self, name):
        """
        Used to delegate to the native cursor object those methods that are not
        wrapped by this class.
        """
        # __getattr__ only runs when normal lookup fails. If _delegate itself
        # is missing (instance created without __init__), looking it up below
        # would re-enter this method forever.
        if name == '_delegate':
            raise AttributeError(name)
        return getattr(self._delegate, name)

class DB2(SnapDBAdapter):
    """
    Implementation of L{SnapDBAdapter} for DB2.
    """

    # Used for a workaround for Python issue #5377
    # (http://bugs.python.org/issue5377). Computed once at class level so
    # that fix_bound_values() can be called before any cursor was created.
    _long_workaround = (sys.version_info[1] == 6 or
                        (sys.version_info[1] == 5 and sys.version_info[2] >= 2))

    def __init__(self, *args, **kwargs):
        """
        Open a connection to DB2 through the DB2 driver module.

        Keyword arguments read here: db, host, port, user, password.
        Positional arguments are accepted but not used.
        """
        dsn = 'driver={IBM DB2 ODBC DRIVER};database=%(db)s;hostname=%(host)s;port=%(port)s;protocol=tcpip;' % kwargs
        self._user = kwargs['user']
        passwd = kwargs['password']
        conn = db2_driver.connect(dsn=dsn,uid=self._user,pwd=passwd)
        super(DB2, self).__init__(conn, db2_driver)
        self._table_meta = {}

    def upsert(self, table, row, keys, table_metadata, cur=None):
        """
        DB2-specific implementation of L{SnapDBAdapter.upsert()}
        by using MERGE.

        @param table: name of the table to merge into
        @type table: str

        @param row: values to write, keyed by field name
        @type row: dict

        @param keys: names of the fields that identify an existing row
        @type keys: sequence

        @param table_metadata: metadata whose 'native_types' entry maps
                               upper-case field names to native type names
        @type table_metadata: dict

        @param cur: not used; a fresh cursor is always created
        """
        native_types = table_metadata['native_types']
        field_names = list(row.keys())
        bind_container = [row[f] for f in field_names]
        # Table and field names are interpolated into the statement because
        # identifiers cannot be bound; they must come from trusted metadata.
        inner_select_clause = ['CAST(? AS %s) AS %s' % (native_types[f.upper()], f) for f in field_names]
        sql = "MERGE INTO " + \
                table + \
                " t1 USING (SELECT " + \
                ', '.join(inner_select_clause) + \
                " FROM sysibm.sysdummy1) t2 ON ("
        match_clause = ["t1.%s = t2.%s" % (key, key) for key in keys]
        sql += ' AND '.join(match_clause)
        sql += ")"
        fields_to_set = list(set(field_names) - set(keys))
        # NOTE(review): the "WHEN MATCHED" line was missing from the file and
        # is reconstructed here. It is skipped when every field is a key,
        # since an UPDATE SET with no assignments is not valid SQL.
        if fields_to_set:
            update_clause = ["t1.%s = t2.%s" % (f, f) for f in fields_to_set]
            sql += " WHEN MATCHED THEN UPDATE SET "
            sql += ",".join(update_clause)
        sql += " WHEN NOT MATCHED THEN INSERT ("
        sql += ",".join(field_names)
        sql += ") VALUES ("
        sql += ",".join(['?' for f in field_names])
        sql += ")"
        # Every value is bound twice: once for the CAST(? ...) markers in the
        # inner SELECT and once for the markers in VALUES.
        bind_container = bind_container + bind_container
        # NOTE(review): the cur argument has always been overwritten here;
        # kept as is because reusing a caller's cursor could disturb a result
        # set the caller is still reading -- confirm the intent.
        cur = self.cursor()
        bind_container = self.fix_bound_values(bind_container)
        cur.execute(sql, bind_container)

    def cursor(self):
        """
        See L{SnapDBAdapter.cursor} and L{DB2CursorWrapper}
        """
        native_cursor = SnapDBAdapter.cursor(self)
        return DB2CursorWrapper(native_cursor, self)

    def fix_bound_values(self, record):
        """
        Given a record (really, a sequence) whose elements are
        Python objects, returns a sequence whose elements
        are of the type that the DB expects.

        unicode is encoded to UTF-8, datetime becomes a driver Timestamp and
        Decimal becomes int or float; all other values pass through as is.

        @param record: sequence of values
        @type record: sequence

        @return: a record with elements converted to types the DB expects.
        @rtype: list
        """
        new_result = []
        for value in record:
            value_t = type(value)
            if value_t == unicode:
                value = value.encode('utf-8')
            elif value_t == datetime:
                # NOTE(review): the last argument of this call was missing
                # from the file; microsecond is assumed -- confirm against
                # the driver's Timestamp() signature.
                value = db2_driver.Timestamp(value.year,
                                             value.month,
                                             value.day,
                                             value.hour,
                                             value.minute,
                                             value.second,
                                             value.microsecond)
            elif value_t == Decimal:
                int_value = int(value)
                # The second test catches exponent notation such as 1E-3,
                # which has no '.' but is not a whole number either.
                if '.' in str(value) or value != int_value:
                    value = float(value)
                else:
                    value = int_value
                    if self._long_workaround:
                        value = int(int_value)
            new_result.append(value)
        return new_result

    def get_default_schema(self):
        """
        See L{SnapDBAdapter.get_default_schema}. Default here is assumed
        to be the user connected to DB2.
        """
        return self._user

    def list_tables(self, schema = None):
        """
        See L{SnapDBAdapter.list_tables}.

        @param schema: schema to list; the default schema when empty
        @type schema: str

        @return: table names found in syscat.tables for the schema
        @rtype: list
        """
        if not schema:
            schema = self.get_default_schema()
        cur = self.cursor()
        sql = "SELECT tabname FROM syscat.tables WHERE LOWER(tabschema) = LOWER(CAST(? AS VARCHAR(128)))"
        cur.execute(sql, (schema,))
        return [row[0] for row in cur.fetchall()]

    def limit_rows_clause(self, limit=1):
        """
        See L{SnapDBAdapter.limit_rows_clause()}
        """
        return "FETCH FIRST %s ROWS ONLY OPTIMIZE FOR %s ROW FOR FETCH ONLY" % (limit, limit)

    def get_snap_view_metadata(self, table_name):
        """
        Describe a table using the DB2 catalog views.

        @param table_name: table name, as understood by _parse_table_name()
        @type table_name: str

        @return: dict with the keys 'schema', 'native_types' (column name to
                 native type, with length/scale where applicable), 'fields'
                 (tuple of (name, snap type, description) tuples) and
                 'primary_key' (list of column names, possibly empty)
        @rtype: dict

        @raise SnapComponentError: if the catalog has no columns for the table
        """
        view_def = {}
        field_defs = []
        (schema, table_name) = self._parse_table_name(table_name)
        view_def['schema'] = schema
        view_def['native_types'] = {}
        cur = self.cursor()
        # NOTE(review): WHERE, AND and the ORDER BY column were missing from
        # the file; colno (column ordinal in syscat.columns) is assumed.
        sql =   """
                SELECT * FROM syscat.columns
                WHERE
                LOWER(tabname) = LOWER(CAST(? AS VARCHAR(128)))
                AND
                LOWER(tabschema) = LOWER(CAST(? AS VARCHAR(128)))
                ORDER BY
                colno
                """
        cur.execute(sql, (table_name, schema))
        # Fetch through the native cursor: catalog rows are read raw, without
        # the wrapper's Snap type conversion.
        result = cur._delegate.fetchall()
        if not result:
            raise SnapComponentError("Table '%s' not found in schema '%s'" % (table_name, schema))
        # Column name to position; the first occurrence of a name wins.
        indices = {}
        for i, meta in enumerate(cur.description):
            col_name = meta[0]
            if col_name not in indices:
                indices[col_name] = i
        for row in result:
            # These we need for actual metadata
            name = row[indices['COLNAME']]
            data_type = row[indices['TYPENAME']]
            snap_type = self.native_type_to_snap_type(data_type)
            # The rest only feeds the human-readable description.
            desc = []
            nullable = row[indices['NULLS']]
            desc.append("Nullable: %s" % nullable)
            length = row[indices['LENGTH']]
            if length:
                desc.append("Length: %s" % length)
            scale = row[indices['SCALE']]
            if scale:
                desc.append("Scale: %s" % scale)
            native_type = data_type.upper()
            if native_type == 'VARCHAR' or native_type == 'CHARACTER':
                native_type += "(%s)" % length
            elif native_type == 'DECIMAL':
                native_type += "(%s,%s)" % (length, scale)
            view_def['native_types'][name] = native_type
            default = row[indices['DEFAULT']]
            if default:
                desc.append("Default: %s" % default)
            codepage = row[indices['CODEPAGE']]
            if codepage:
                desc.append("Codepage: %s" % codepage)
            collation = row[indices['COLLATIONNAME']]
            if collation:
                desc.append("Collation: %s" % collation)
            remarks = row[indices['REMARKS']]
            if remarks:
                desc.append("Remarks: %s" % remarks)
            desc = '; '.join(desc)
            field_def = (name, snap_type, desc,)
            field_defs.append(field_def)
        view_def['fields'] = tuple(field_defs)
        # NOTE(review): the last condition of this statement was missing from
        # the file; uniquerule = 'P' (primary key index) is assumed.
        pkey_sql = """
                    SELECT colnames FROM syscat.indexes WHERE
                    LOWER(tabschema) = LOWER(CAST(? AS VARCHAR(128)))  AND
                    LOWER(tabname) = LOWER(CAST(? AS VARCHAR(128)))  AND
                    uniquerule = 'P'
                    """
        cur.execute(pkey_sql, (schema, table_name,))
        pkey_rs = cur.fetchone()
        if pkey_rs:
            pkey_cols = pkey_rs[0]
            if pkey_cols:
                # colnames is split on '+'; the empty piece before the first
                # '+' is dropped.
                view_def['primary_key'] = pkey_cols.split('+')[1:]
        if 'primary_key' not in view_def:
            view_def['primary_key'] = []
        return view_def
