Skip to content

Commit

Permalink
Update code to comply with stricter mypy checks.
Browse files Browse the repository at this point in the history
  • Loading branch information
erdewit committed Nov 11, 2022
1 parent cb68451 commit 2bd71f7
Show file tree
Hide file tree
Showing 10 changed files with 117 additions and 98 deletions.
8 changes: 4 additions & 4 deletions ib_insync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import struct
import time
from collections import deque
from typing import List, Optional
from typing import Deque, List, Optional

from eventkit import Event

Expand Down Expand Up @@ -127,8 +127,8 @@ def reset(self):
self._numBytesRecv = 0
self._numMsgRecv = 0
self._isThrottling = False
self._msgQ = deque()
self._timeQ = deque()
self._msgQ: Deque[str] = deque()
self._timeQ: Deque[float] = deque()

def serverVersion(self) -> int:
return self._serverVersion
Expand Down Expand Up @@ -267,7 +267,7 @@ def send(self, *fields):
msg.write('\0')
self.sendMsg(msg.getvalue())

def sendMsg(self, msg):
def sendMsg(self, msg: str):
loop = getLoop()
t = loop.time()
times = self._timeQ
Expand Down
7 changes: 4 additions & 3 deletions ib_insync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ def isConnected(self):
return self.transport is not None

def sendMsg(self, msg):
self.transport.write(msg)
self.numBytesSent += len(msg)
self.numMsgSent += 1
if self.transport:
self.transport.write(msg)
self.numBytesSent += len(msg)
self.numMsgSent += 1

def connection_lost(self, exc):
self.transport = None
Expand Down
20 changes: 8 additions & 12 deletions ib_insync/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import dataclasses
import logging
from datetime import datetime, timezone
from typing import Any, cast

from .contract import (
ComboLeg, Contract, ContractDescription, ContractDetails,
Expand Down Expand Up @@ -443,7 +444,7 @@ def execDetails(self, fields):

self.parse(c)
self.parse(ex)
time = parseIBDatetime(timeStr)
time = cast(datetime, parseIBDatetime(timeStr))
tz = self.wrapper.ib.TimezoneTWS
if tz:
time = tz.localize(time)
Expand Down Expand Up @@ -524,14 +525,9 @@ def tickOptionComputation(self, fields):

self.wrapper.tickOptionComputation(
int(reqId), int(tickTypeInt), int(tickAttrib),
float(impliedVol) if impliedVol != '-1' else None,
float(delta) if delta != '-2' else None,
float(optPrice) if optPrice != '-1' else None,
float(pvDividend) if pvDividend != '-1' else None,
float(gamma) if gamma != '-2' else None,
float(vega) if vega != '-2' else None,
float(theta) if theta != '-2' else None,
float(undPrice) if undPrice != '-1' else None)
float(impliedVol), float(delta), float(optPrice),
float(pvDividend), float(gamma), float(vega),
float(theta), float(undPrice))

def deltaNeutralValidation(self, fields):
_, _, reqId, conId, delta, price = fields
Expand Down Expand Up @@ -789,7 +785,7 @@ def tickByTick(self, fields):
if tickType in (1, 2):
price, size, mask, exchange, specialConditions = fields
mask = int(mask)
attrib = TickAttribLast(
attrib: Any = TickAttribLast(
pastLimit=bool(mask & 1),
unreported=bool(mask & 2))

Expand Down Expand Up @@ -908,7 +904,7 @@ def openOrder(self, fields):
numLegs = int(fields.pop(0))
c.comboLegs = []
for _ in range(numLegs):
leg = ComboLeg()
leg: Any = ComboLeg()
(
leg.conId,
leg.ratio,
Expand Down Expand Up @@ -1149,7 +1145,7 @@ def completedOrder(self, fields):
numLegs = int(fields.pop(0))
c.comboLegs = []
for _ in range(numLegs):
leg = ComboLeg()
leg: Any = ComboLeg()
(
leg.conId,
leg.ratio,
Expand Down
24 changes: 16 additions & 8 deletions ib_insync/flexreport.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@ class FlexReport:
minutes. In the weekends the query servers can be down.
"""

data: bytes
root: et.Element

def __init__(self, token=None, queryId=None, path=None):
"""
Download a report by giving a valid ``token`` and ``queryId``,
or load from file by giving a valid ``path``.
"""
self.data = None
self.root = None
if token and queryId:
self.download(token, queryId)
elif path:
Expand Down Expand Up @@ -74,13 +75,20 @@ def download(self, token, queryId):
data = resp.read()

root = et.fromstring(data)
if root.find('Status').text == 'Success':
code = root.find('ReferenceCode').text
baseUrl = root.find('Url').text
elem = root.find('Status')
if elem and elem.text == 'Success':
elem = root.find('ReferenceCode')
assert elem
code = elem.text
elem = root.find('Url')
assert elem
baseUrl = elem.text
_logger.info('Statement is being prepared...')
else:
errorCode = root.find('ErrorCode').text
errorMsg = root.find('ErrorMessage').text
elem = root.find('ErrorCode')
errorCode = elem.text if elem else ''
elem = root.find('ErrorMessage')
errorMsg = elem.text if elem else ''
raise FlexError(f'{errorCode}: {errorMsg}')

while True:
Expand All @@ -91,7 +99,7 @@ def download(self, token, queryId):
self.root = et.fromstring(self.data)
if self.root[0].tag == 'code':
msg = self.root[0].text
if msg.startswith('Statement generation in progress'):
if msg and msg.startswith('Statement generation in progress'):
_logger.info('still working...')
continue
else:
Expand Down
Loading

0 comments on commit 2bd71f7

Please sign in to comment.