Skip to content

Commit

Permalink
Merge pull request erdewit#205 from antequant/typing-no-dataclasses
Browse files Browse the repository at this point in the history
Adding more typing annotations to external APIs
  • Loading branch information
erdewit authored Dec 29, 2019
2 parents 340fca5 + 1df8044 commit d59c042
Show file tree
Hide file tree
Showing 10 changed files with 844 additions and 320 deletions.
4 changes: 2 additions & 2 deletions ib_insync/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@

__all__ = ['util', 'Event']
for _m in (
objects, contract, order, ticker, ib,
client, wrapper, flexreport, ibcontroller):
objects, contract, order, ticker, ib, # type: ignore
client, wrapper, flexreport, ibcontroller): # type: ignore
__all__ += _m.__all__

del sys
22 changes: 21 additions & 1 deletion ib_insync/contract.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, Optional

from ib_insync.objects import Object
from ib_insync.objects import ComboLeg, DeltaNeutralContract, Object

__all__ = (
'Contract Stock Option Future ContFuture Forex Index CFD '
Expand Down Expand Up @@ -98,6 +99,25 @@ class Contract(Object):
)
__slots__ = defaults.keys()

secType: str
conId: int
symbol: str
lastTradeDateOrContractMonth: str
strike: float
right: str
multiplier: str
exchange: str
primaryExchange: str
currency: str
localSymbol: str
tradingClass: str
includeExpired: bool
secIdType: str
secId: str
comboLegsDescrip: str
comboLegs: Optional[List[ComboLeg]]
deltaNeutralContract: Optional[DeltaNeutralContract]

@staticmethod
def create(**kwargs):
"""
Expand Down
111 changes: 64 additions & 47 deletions ib_insync/ib.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import datetime
import time
from contextlib import suppress
from typing import List, Iterator, Awaitable, Union
from typing import List, Iterator, Awaitable, Optional, Union

from eventkit import Event

Expand Down Expand Up @@ -286,7 +286,7 @@ def isConnected(self) -> bool:
timeRangeAsync = staticmethod(util.timeRangeAsync)
waitUntil = staticmethod(util.waitUntil)

def _run(self, *awaitables: List[Awaitable]):
def _run(self, *awaitables: Awaitable):
return util.run(*awaitables, timeout=self.RequestTimeout)

def waitOnUpdate(self, timeout: float = 0) -> bool:
Expand Down Expand Up @@ -510,7 +510,7 @@ def realtimeBars(self) -> BarList:
Get a list of all live updated bars. These can be 5 second realtime
bars or live updated historical bars.
"""
return list(self.wrapper.reqId2Subscriber.values())
return BarList(self.wrapper.reqId2Subscriber.values())

def newsTicks(self) -> List[NewsTick]:
"""
Expand All @@ -526,7 +526,7 @@ def newsBulletins(self) -> List[NewsBulletin]:
return list(self.wrapper.newsBulletins.values())

def reqTickers(
self, *contracts: List[Contract],
self, *contracts: Contract,
regulatorySnapshot: bool = False) -> List[Ticker]:
"""
Request and return a list of snapshot tickers.
Expand All @@ -542,7 +542,7 @@ def reqTickers(
self.reqTickersAsync(
*contracts, regulatorySnapshot=regulatorySnapshot))

def qualifyContracts(self, *contracts: List[Contract]) -> List[Contract]:
def qualifyContracts(self, *contracts: Contract) -> List[Contract]:
"""
Fully qualify the given contracts in-place. This will fill in
the missing fields in the contract, especially the conId.
Expand Down Expand Up @@ -1003,7 +1003,7 @@ def cancelRealTimeBars(self, bars: RealTimeBarList):
self.wrapper.endSubscription(bars)

def reqHistoricalData(
self, contract: Contract, endDateTime: object,
self, contract: Contract, endDateTime: Union[datetime.datetime, datetime.date, str, None],
durationStr: str, barSizeSetting: str,
whatToShow: str, useRTH: bool,
formatDate: int = 1, keepUpToDate: bool = False,
Expand Down Expand Up @@ -1613,8 +1613,8 @@ def replaceFA(self, faDataType: int, xml: str):
# now entering the parallel async universe

async def connectAsync(
self, host='127.0.0.1', port=7497, clientId=1,
timeout=2, readonly=False, account=''):
self, host: str = '127.0.0.1', port: int = 7497,
clientId: int = 1, timeout: float = 2, readonly: bool = False, account: str = ''):

async def connect():
self.wrapper.clientId = clientId
Expand Down Expand Up @@ -1643,7 +1643,7 @@ async def connect():
self._logger.warn('Already connected')
return self

async def qualifyContractsAsync(self, *contracts):
async def qualifyContractsAsync(self, *contracts: Contract) -> List[Contract]:
detailsLists = await asyncio.gather(
*(self.reqContractDetailsAsync(c) for c in contracts))
result = []
Expand All @@ -1669,7 +1669,7 @@ async def qualifyContractsAsync(self, *contracts):
result.append(contract)
return result

async def reqTickersAsync(self, *contracts, regulatorySnapshot=False):
async def reqTickersAsync(self, *contracts: Contract, regulatorySnapshot: bool = False) -> List[Ticker]:
futures = []
tickers = []
for contract in contracts:
Expand All @@ -1685,30 +1685,30 @@ async def reqTickersAsync(self, *contracts, regulatorySnapshot=False):
self.wrapper.endTicker(ticker, 'snapshot')
return tickers

def whatIfOrderAsync(self, contract, order):
def whatIfOrderAsync(self, contract: Contract, order: Order) -> Awaitable[OrderState]:
whatIfOrder = Order(**order.dict()).update(whatIf=True)
reqId = self.client.getReqId()
future = self.wrapper.startReq(reqId, contract)
self.client.placeOrder(reqId, contract, whatIfOrder)
return future

def reqCurrentTimeAsync(self):
def reqCurrentTimeAsync(self) -> Awaitable[datetime.datetime]:
future = self.wrapper.startReq('currentTime')
self.client.reqCurrentTime()
return future

def reqAccountUpdatesAsync(self, account):
def reqAccountUpdatesAsync(self, account: str) -> Awaitable[None]:
future = self.wrapper.startReq('accountValues')
self.client.reqAccountUpdates(True, account)
return future

def reqAccountUpdatesMultiAsync(self, account, modelCode=''):
def reqAccountUpdatesMultiAsync(self, account: str, modelCode: str = '') -> Awaitable[None]:
reqId = self.client.getReqId()
future = self.wrapper.startReq(reqId)
self.client.reqAccountUpdatesMulti(reqId, account, modelCode, False)
return future

def reqAccountSummaryAsync(self):
def reqAccountSummaryAsync(self) -> Awaitable[None]:
reqId = self.client.getReqId()
future = self.wrapper.startReq(reqId)
tags = (
Expand All @@ -1724,40 +1724,40 @@ def reqAccountSummaryAsync(self):
self.client.reqAccountSummary(reqId, 'All', tags)
return future

def reqOpenOrdersAsync(self):
def reqOpenOrdersAsync(self) -> Awaitable[List[Order]]:
future = self.wrapper.startReq('openOrders')
self.client.reqOpenOrders()
return future

def reqAllOpenOrdersAsync(self):
def reqAllOpenOrdersAsync(self) -> Awaitable[List[Order]]:
future = self.wrapper.startReq('openOrders')
self.client.reqAllOpenOrders()
return future

def reqCompletedOrdersAsync(self, apiOnly):
def reqCompletedOrdersAsync(self, apiOnly: bool) -> Awaitable[List[Trade]]:
future = self.wrapper.startReq('completedOrders')
self.client.reqCompletedOrders(apiOnly)
return future

def reqExecutionsAsync(self, execFilter=None):
def reqExecutionsAsync(self, execFilter: ExecutionFilter = None) -> Awaitable[List[Fill]]:
execFilter = execFilter or ExecutionFilter()
reqId = self.client.getReqId()
future = self.wrapper.startReq(reqId)
self.client.reqExecutions(reqId, execFilter)
return future

def reqPositionsAsync(self):
def reqPositionsAsync(self) -> Awaitable[List[Position]]:
future = self.wrapper.startReq('positions')
self.client.reqPositions()
return future

def reqContractDetailsAsync(self, contract):
def reqContractDetailsAsync(self, contract: Contract) -> Awaitable[List[ContractDetails]]:
reqId = self.client.getReqId()
future = self.wrapper.startReq(reqId, contract)
self.client.reqContractDetails(reqId, contract)
return future

async def reqMatchingSymbolsAsync(self, pattern):
async def reqMatchingSymbolsAsync(self, pattern: str) -> Optional[List[ContractDescription]]:
reqId = self.client.getReqId()
future = self.wrapper.startReq(reqId)
self.client.reqMatchingSymbols(reqId, pattern)
Expand All @@ -1766,20 +1766,24 @@ async def reqMatchingSymbolsAsync(self, pattern):
return future.result()
except asyncio.TimeoutError:
self._logger.error('reqMatchingSymbolsAsync: Timeout')
return None

async def reqMarketRuleAsync(self, marketRuleId):
async def reqMarketRuleAsync(self, marketRuleId: int) -> Optional[PriceIncrement]:
future = self.wrapper.startReq(f'marketRule-{marketRuleId}')
try:
self.client.reqMarketRule(marketRuleId)
await asyncio.wait_for(future, 1)
return future.result()
except asyncio.TimeoutError:
self._logger.error('reqMarketRuleAsync: Timeout')
return None

def reqHistoricalDataAsync(
self, contract, endDateTime,
durationStr, barSizeSetting, whatToShow, useRTH,
formatDate=1, keepUpToDate=False, chartOptions=None):
self, contract: Contract, endDateTime: Union[datetime.datetime, datetime.date, str, None],
durationStr: str, barSizeSetting: str,
whatToShow: str, useRTH: bool,
formatDate: int = 1, keepUpToDate: bool = False,
chartOptions: Optional[List[TagValue]] = None) -> Awaitable[BarDataList]:
reqId = self.client.getReqId()
bars = BarDataList()
bars.reqId = reqId
Expand All @@ -1802,9 +1806,12 @@ def reqHistoricalDataAsync(
return future

def reqHistoricalTicksAsync(
self, contract, startDateTime, endDateTime,
numberOfTicks, whatToShow, useRth,
ignoreSize=False, miscOptions=None):
self, contract: Contract,
startDateTime: Union[str, datetime.date],
endDateTime: Union[str, datetime.date],
numberOfTicks: int, whatToShow: str, useRth: bool,
ignoreSize: bool = False,
miscOptions: Optional[List[TagValue]] = None) -> Awaitable[List]:
reqId = self.client.getReqId()
future = self.wrapper.startReq(reqId, contract)
start = util.formatIBDatetime(startDateTime)
Expand All @@ -1815,35 +1822,36 @@ def reqHistoricalTicksAsync(
return future

def reqHeadTimeStampAsync(
self, contract, whatToShow, useRTH, formatDate):
self, contract: Contract, whatToShow: str,
useRTH: bool, formatDate: int) -> Awaitable[datetime.datetime]:
reqId = self.client.getReqId()
future = self.wrapper.startReq(reqId, contract)
self.client.reqHeadTimeStamp(
reqId, contract, whatToShow, useRTH, formatDate)
return future

def reqMktDepthExchangesAsync(self):
def reqMktDepthExchangesAsync(self) -> Awaitable[List[DepthMktDataDescription]]:
future = self.wrapper.startReq('mktDepthExchanges')
self.client.reqMktDepthExchanges()
return future

def reqHistogramDataAsync(self, contract, useRTH, period):
def reqHistogramDataAsync(self, contract: Contract, useRTH: bool, period: str) -> Awaitable[List[HistogramData]]:
reqId = self.client.getReqId()
future = self.wrapper.startReq(reqId, contract)
self.client.reqHistogramData(reqId, contract, useRTH, period)
return future

def reqFundamentalDataAsync(
self, contract, reportType, fundamentalDataOptions=None):
self, contract: Contract, reportType: str, fundamentalDataOptions: Optional[List[TagValue]] = None) -> Awaitable[str]:
reqId = self.client.getReqId()
future = self.wrapper.startReq(reqId, contract)
self.client.reqFundamentalData(
reqId, contract, reportType, fundamentalDataOptions)
return future

async def reqScannerDataAsync(
self, subscription, scannerSubscriptionOptions=None,
scannerSubscriptionFilterOptions=None):
self, subscription: ScannerSubscription, scannerSubscriptionOptions: Optional[List[TagValue]] = None,
scannerSubscriptionFilterOptions: Optional[List[TagValue]] = None) -> ScanDataList:
dataList = self.reqScannerSubscription(
subscription, scannerSubscriptionOptions,
scannerSubscriptionFilterOptions)
Expand All @@ -1852,13 +1860,15 @@ async def reqScannerDataAsync(
self.client.cancelScannerSubscription(dataList.reqId)
return future.result()

def reqScannerParametersAsync(self):
def reqScannerParametersAsync(self) -> Awaitable[str]:
future = self.wrapper.startReq('scannerParams')
self.client.reqScannerParameters()
return future

async def calculateImpliedVolatilityAsync(
self, contract, optionPrice, underPrice, implVolOptions):
self, contract: Contract,
optionPrice: float, underPrice: float,
implVolOptions: List[TagValue] = None) -> Optional[OptionComputation]:
reqId = self.client.getReqId()
future = self.wrapper.startReq(reqId, contract)
self.client.calculateImpliedVolatility(
Expand All @@ -1868,12 +1878,14 @@ async def calculateImpliedVolatilityAsync(
return future.result()
except asyncio.TimeoutError:
self._logger.error('calculateImpliedVolatilityAsync: Timeout')
return
return None
finally:
self.client.cancelCalculateImpliedVolatility(reqId)

async def calculateOptionPriceAsync(
self, contract, volatility, underPrice, optPrcOptions):
self, contract: Contract,
volatility: float, underPrice: float,
optPrcOptions=None) -> Optional[OptionComputation]:
reqId = self.client.getReqId()
future = self.wrapper.startReq(reqId, contract)
self.client.calculateOptionPrice(
Expand All @@ -1883,36 +1895,40 @@ async def calculateOptionPriceAsync(
return future.result()
except asyncio.TimeoutError:
self._logger.error('calculateOptionPriceAsync: Timeout')
return
return None
finally:
self.client.cancelCalculateOptionPrice(reqId)

def reqSecDefOptParamsAsync(
self, underlyingSymbol, futFopExchange,
underlyingSecType, underlyingConId):
self, underlyingSymbol: str,
futFopExchange: str, underlyingSecType: str,
underlyingConId: int) -> Awaitable[List[OptionChain]]:
reqId = self.client.getReqId()
future = self.wrapper.startReq(reqId)
self.client.reqSecDefOptParams(
reqId, underlyingSymbol, futFopExchange,
underlyingSecType, underlyingConId)
return future

def reqNewsProvidersAsync(self):
def reqNewsProvidersAsync(self) -> Awaitable[List[NewsProvider]]:
future = self.wrapper.startReq('newsProviders')
self.client.reqNewsProviders()
return future

def reqNewsArticleAsync(
self, providerCode, articleId, newsArticleOptions):
self, providerCode: str, articleId: str, newsArticleOptions: Optional[List[TagValue]]) -> Awaitable[NewsArticle]:
reqId = self.client.getReqId()
future = self.wrapper.startReq(reqId)
self.client.reqNewsArticle(
reqId, providerCode, articleId, newsArticleOptions)
return future

async def reqHistoricalNewsAsync(
self, conId, providerCodes, startDateTime, endDateTime,
totalResults, historicalNewsOptions=None):
self, conId: int, providerCodes: str,
startDateTime: Union[str, datetime.date],
endDateTime: Union[str, datetime.date],
totalResults: int,
historicalNewsOptions: List[TagValue] = None) -> Optional[HistoricalNews]:
reqId = self.client.getReqId()
future = self.wrapper.startReq(reqId)
start = util.formatIBDatetime(startDateTime)
Expand All @@ -1925,8 +1941,9 @@ async def reqHistoricalNewsAsync(
return future.result()
except asyncio.TimeoutError:
self._logger.error('reqHistoricalNewsAsync: Timeout')
return None

async def requestFAAsync(self, faDataType):
async def requestFAAsync(self, faDataType: int):
future = self.wrapper.startReq('requestFA')
self.client.requestFA(faDataType)
try:
Expand Down
Loading

0 comments on commit d59c042

Please sign in to comment.