diff --git a/sqlalchemy_pytds/dialect.py b/sqlalchemy_pytds/dialect.py index ddc920e..d0a0258 100644 --- a/sqlalchemy_pytds/dialect.py +++ b/sqlalchemy_pytds/dialect.py @@ -1,16 +1,18 @@ +import re from sqlalchemy import util +from sqlalchemy.engine.interfaces import ExecuteStyle +from sqlalchemy.dialects.mssql import SQL_VARIANT from sqlalchemy.dialects.mssql.base import ( MSDialect, MSExecutionContext, MSIdentifierPreparer, MSSQLCompiler, ) - +from pytds import tds_base, tds_types from .connector import PyTDSConnector _server_side_id = util.counter() - class SSCursor(object): def __init__(self, c): self._c = c @@ -155,6 +157,28 @@ def __init__(self, server_side_cursors=False, **params): self.use_scope_identity = True self.server_side_cursors = server_side_cursors + def do_execute(self, cursor, statement, parameters, context=None): + if context and context.isinsert: + tbl = context.compiled.compile_state.dml_table + #print('*'*20, 'do_execute') + #print('stmt:', statement) + #print('parm:', parameters) + for c in tbl._columns: + if isinstance(c.type, SQL_VARIANT): + todo = [name for name in parameters.keys() if c.name == name or re.match(c.name+r'__\d', name)] + #print('todo:', todo) + for name in todo: + v = parameters.get(name, None) + #print('cvt', name, v) + if isinstance(v, str): + assert len(v) < 8000 + parameters[name] = tds_base.Param(name=name, type=tds_types.NVarCharType(size=len(v)), value=v, flags=0) + elif isinstance(v, bytes): + assert len(v) < 8000 + parameters[name] = tds_base.Param(name=name, type=tds_types.NVarCharType(size=len(v)), value=v, flags=0) + #print('parm:', parameters) + cursor.execute(statement, parameters) + def set_isolation_level(self, connection, level): if level == "AUTOCOMMIT": connection.autocommit(True)