Skip to content

Commit

Permalink
Minor changes, mostly docs
Browse files Browse the repository at this point in the history
  • Loading branch information
joschu committed Aug 25, 2015
1 parent 54669e1 commit 5a07cba
Show file tree
Hide file tree
Showing 19 changed files with 434 additions and 411 deletions.
2 changes: 1 addition & 1 deletion cgt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .api import *
from .display import print_tree, print_expr, print_text, as_dot
from .compilation import function, numeric_eval, profiler
from .core import grad, get_config, update_config, simplify, reset_config, Device, scoped_update_config, infer_shape
from .core import grad, get_config, update_config, simplify, reset_config, Device, scoped_update_config, infer_shape, count_nodes
from .ez import EasyCustomOp
try:
import cycgt
Expand Down
7 changes: 6 additions & 1 deletion cgt/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,11 @@ def getitem_nonfancy(arr, slis):
step = (1 if sli.step is None else sli.step)
if (isinstance(stop, int) and (stop < 0)):
stop = size(arr, ax) - stop
if isinstance(step, int):
assert step != 0
if step < 0:
raise NotImplementedError("negative `step parameter is not implemented. use flip(x,0) instead of x[::-1]")

out = core.Result(core.GetSli(ax), [out, start, stop, step])
ax += 1
if all(((x == 'k') for x in shapedesc)):
Expand Down Expand Up @@ -441,7 +446,7 @@ def repeat(x, repeats, axis):
"""
Like numpy.repeat
"""
return core.Result(core.Repeat([axis]), [x, constant(repeats)])
return core.Result(core.Repeat([axis]), [x, core.as_node(repeats)])

def reshape(x, shp):
"""
Expand Down
21 changes: 17 additions & 4 deletions cgt/compilation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from . import core, utils
import cgt
import ctypes, os.path as osp, hashlib, numpy as np, sys, subprocess, string, os, time, traceback
import ctypes, os.path as osp, hashlib, numpy as np, sys, subprocess, string, os, time, traceback, cPickle
from collections import defaultdict, namedtuple
from StringIO import StringIO
import logging
Expand Down Expand Up @@ -38,10 +38,16 @@ def _function_listout(inputs, outputs, dbg = None, updates=None, givens=None):
# Execution
# ================================================================

def python_only():
return not hasattr(cgt,"cycgt")

def determine_devices(nodes_sorted, updatetarg2src):
# Op definitions (available impls, inplace-ness, etc) define constraints
# on possible devices for a node

if python_only():
return {node:Device() for node in nodes_sorted}

# (1) Get available devices for nodes, determined by which impls are available and node types
compile_info = get_compile_info()

Expand Down Expand Up @@ -264,7 +270,6 @@ def get_callable(op, input_types, devtype, prefer_python=False):
else:
raise RuntimeError("Tried to put Op %s on the GPU but I only have a python impl :("%op)


def get_native_callable(op, input_types, devtype):
nci = op.get_native_compile_info(input_types, devtype)
nci.op_str = str(op)
Expand Down Expand Up @@ -711,6 +716,9 @@ def call_and_print(cmd):
ctypes.c_float : "float"
}


_struct_cache = {} # because creating ctypes.Structure class is slow for some reason

def _build_closure(triples):
if triples is None:
return ctypes.c_void_p(0)
Expand All @@ -719,8 +727,13 @@ def _build_closure(triples):
for (fieldname,fieldtype,val) in triples:
vals.append(val)
fields.append((fieldname,fieldtype))
class S(ctypes.Structure):
_fields_ = fields
try:
key = cPickle.dumps(fields)
S = _struct_cache[key]
except KeyError:
class S(ctypes.Structure):
_fields_ = fields
_struct_cache[key] = S
closure = S(*vals)
return closure

Expand Down
3 changes: 3 additions & 0 deletions cgt/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2666,6 +2666,9 @@ def topsorted(outputs):
return out

def count_nodes(outputs):
"""
Given a list of output nodes, compute the number of ancestors
"""
if isinstance(outputs, Node): outputs = [outputs]
return len(list(topsorted(outputs)))

Expand Down
Loading

0 comments on commit 5a07cba

Please sign in to comment.