Skip to content

Commit

Permalink
adding some paper sizing and font size features
Browse files Browse the repository at this point in the history
  • Loading branch information
Cody Karcher committed Sep 26, 2023
1 parent e417da6 commit e9232d2
Showing 1 changed file with 146 additions and 11 deletions.
157 changes: 146 additions & 11 deletions pyomo/util/latex_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@
from pyomo.core.base.constraint import ScalarConstraint, IndexedConstraint
from pyomo.common.collections.component_map import ComponentMap
from pyomo.common.collections.component_set import ComponentSet
from pyomo.core.expr.template_expr import (
NPV_Numeric_GetItemExpression,
NPV_Structural_GetItemExpression,
Numeric_GetAttrExpression
)
from pyomo.core.expr.numeric_expr import NPV_SumExpression
from pyomo.core.base.block import IndexedBlock

from pyomo.core.base.external import _PythonCallbackFunctionID

Expand Down Expand Up @@ -314,12 +321,17 @@ def handle_functionID_node(visitor, node, *args):


def handle_indexTemplate_node(visitor, node, *args):
if node._set in ComponentSet(visitor.setMap.keys()):
# already detected set, do nothing
pass
else:
visitor.setMap[node._set] = 'SET%d'%(len(visitor.setMap.keys())+1)

return '__I_PLACEHOLDER_8675309_GROUP_%s_%s__' % (
node._group,
visitor.setMap[node._set],
)


def handle_numericGIE_node(visitor, node, *args):
joinedName = args[0]

Expand Down Expand Up @@ -350,6 +362,40 @@ def handle_templateSumExpression_node(visitor, node, *args):
def handle_param_node(visitor, node):
return visitor.parameterMap[node]

def handle_str_node(visitor, node):
return node.replace('_', '\\_')

def handle_npv_numericGetItemExpression_node(visitor, node, *args):
joinedName = args[0]

pstr = ''
pstr += joinedName + '_{'
for i in range(1, len(args)):
pstr += args[i]
if i <= len(args) - 2:
pstr += ','
else:
pstr += '}'
return pstr

def handle_npv_structuralGetItemExpression_node(visitor, node, *args):
joinedName = args[0]

pstr = ''
pstr += joinedName + '['
for i in range(1, len(args)):
pstr += args[i]
if i <= len(args) - 2:
pstr += ','
else:
pstr += ']'
return pstr

def handle_indexedBlock_node(visitor, node, *args):
return str(node)

def handle_numericGetAttrExpression_node(visitor, node, *args):
return args[0] + '.' + args[1]

class _LatexVisitor(StreamBasedExpressionVisitor):
def __init__(self):
Expand Down Expand Up @@ -388,10 +434,24 @@ def __init__(self):
TemplateSumExpression: handle_templateSumExpression_node,
ScalarParam: handle_param_node,
_ParamData: handle_param_node,
IndexedParam: handle_param_node,
NPV_Numeric_GetItemExpression: handle_npv_numericGetItemExpression_node,
IndexedBlock: handle_indexedBlock_node,
NPV_Structural_GetItemExpression: handle_npv_structuralGetItemExpression_node,
str: handle_str_node,
Numeric_GetAttrExpression: handle_numericGetAttrExpression_node,
NPV_SumExpression: handle_sumExpression_node,
}

def exitNode(self, node, data):
return self._operator_handles[node.__class__](self, node, *data)
try:
return self._operator_handles[node.__class__](self, node, *data)
except:
print(node.__class__)
print(node)
print(data)

return 'xxx'

def analyze_variable(vr):
domainMap = {
Expand Down Expand Up @@ -571,6 +631,8 @@ def latex_printer(
use_equation_environment=False,
split_continuous_sets=False,
use_short_descriptors=False,
fontsize = None,
paper_dimensions=None,
):
"""This function produces a string that can be rendered as LaTeX
Expand Down Expand Up @@ -615,6 +677,49 @@ def latex_printer(

isSingle = False

fontSizes = ['\\tiny', '\\scriptsize', '\\footnotesize', '\\small', '\\normalsize', '\\large', '\\Large', '\\LARGE', '\\huge', '\\Huge']
fontSizes_noSlash = ['tiny', 'scriptsize', 'footnotesize', 'small', 'normalsize', 'large', 'Large', 'LARGE', 'huge', 'Huge']
fontsizes_ints = [ -4, -3, -2, -1, 0, 1, 2, 3, 4, 5 ]

if fontsize is None:
fontsize = 0

elif fontsize in fontSizes:
#no editing needed
pass
elif fontsize in fontSizes_noSlash:
fontsize = '\\' + fontsize
elif fontsize in fontsizes_ints:
fontsize = fontSizes[fontsizes_ints.index(fontsize)]
else:
raise ValueError('passed an invalid font size option %s'%(fontsize))

paper_dimensions_used = {}
paper_dimensions_used['height'] = 11.0
paper_dimensions_used['width'] = 8.5
paper_dimensions_used['margin_left'] = 1.0
paper_dimensions_used['margin_right'] = 1.0
paper_dimensions_used['margin_top'] = 1.0
paper_dimensions_used['margin_bottom'] = 1.0

if paper_dimensions is not None:
for ky in [ 'height', 'width', 'margin_left', 'margin_right', 'margin_top', 'margin_bottom' ]:
if ky in paper_dimensions.keys():
paper_dimensions_used[ky] = paper_dimensions[ky]
else:
if paper_dimensions_used['height'] >= 225 :
raise ValueError('Paper height exceeds maximum dimension of 225')
if paper_dimensions_used['width'] >= 225 :
raise ValueError('Paper width exceeds maximum dimension of 225')
if paper_dimensions_used['margin_left'] < 0.0:
raise ValueError('Paper margin_left must be greater than or equal to zero')
if paper_dimensions_used['margin_right'] < 0.0:
raise ValueError('Paper margin_right must be greater than or equal to zero')
if paper_dimensions_used['margin_top'] < 0.0:
raise ValueError('Paper margin_top must be greater than or equal to zero')
if paper_dimensions_used['margin_bottom'] < 0.0:
raise ValueError('Paper margin_bottom must be greater than or equal to zero')

if isinstance(pyomo_component, pyo.Objective):
objectives = [pyomo_component]
constraints = []
Expand Down Expand Up @@ -863,15 +968,22 @@ def latex_printer(
+ ' %s %s' % (visitor.walk_expression(con_template), trailingAligner)
)

# setMap = visitor.setMap
# Multiple constraints are generated using a set
if len(indices) > 0:
if indices[0]._set in ComponentSet(visitor.setMap.keys()):
# already detected set, do nothing
pass
else:
visitor.setMap[indices[0]._set] = 'SET%d'%(len(visitor.setMap.keys())+1)

idxTag = '__I_PLACEHOLDER_8675309_GROUP_%s_%s__' % (
indices[0]._group,
setMap[indices[0]._set],
visitor.setMap[indices[0]._set],
)
setTag = '__S_PLACEHOLDER_8675309_GROUP_%s_%s__' % (
indices[0]._group,
setMap[indices[0]._set],
visitor.setMap[indices[0]._set],
)

conLine += ' \\qquad \\forall %s \\in %s ' % (idxTag, setTag)
Expand Down Expand Up @@ -933,6 +1045,8 @@ def latex_printer(
'Variable is not a variable. Should not happen. Contact developers'
)

# print(varBoundData)

# print the accumulated data to the string
bstr = ''
appendBoundString = False
Expand All @@ -945,8 +1059,8 @@ def latex_printer(
and vbd['domainName'] == 'Reals'
):
# unbounded all real, do not print
if i <= len(varBoundData) - 2:
bstr = bstr[0:-2]
if i == len(varBoundData) - 1:
bstr = bstr[0:-4]
else:
if not useThreeAlgn:
algn = '& &'
Expand Down Expand Up @@ -998,11 +1112,19 @@ def latex_printer(
pstr += ' \\label{%s} \n' % (pyomo_component.name)
pstr += '\\end{equation} \n'


setMap = visitor.setMap
setMap_inverse = {vl: ky for ky, vl in setMap.items()}
# print(setMap)

# print('\n\n\n\n')
# print(pstr)

# Handling the iterator indices
defaultSetLatexNames = ComponentMap()
for i in range(0, len(setList)):
st = setList[i]
defaultSetLatexNames[st] = setList[i].name.replace('_', '\\_')
for ky,vl in setMap.items():
st = ky
defaultSetLatexNames[st] = st.name.replace('_', '\\_')
if st in ComponentSet(latex_component_map.keys()):
defaultSetLatexNames[st] = latex_component_map[st][0]#.replace('_', '\\_')

Expand Down Expand Up @@ -1034,7 +1156,7 @@ def latex_printer(

for ky, vl in setInfo.items():
ix = int(ky[3:]) - 1
setInfo[ky]['setObject'] = setList[ix]
setInfo[ky]['setObject'] = setMap_inverse[ky]#setList[ix]
setInfo[ky][
'setRegEx'
] = r'__S_PLACEHOLDER_8675309_GROUP_([0-9*])_%s__' % (ky)
Expand Down Expand Up @@ -1134,6 +1256,8 @@ def latex_printer(
latexLines[jj] = ln

pstr = '\n'.join(latexLines)
# print('\n\n\n\n')
# print(pstr)

vrIdx = 0
new_variableMap = ComponentMap()
Expand Down Expand Up @@ -1213,13 +1337,19 @@ def latex_printer(
for ky, vl in label_rep_dict.items():
label_rep_dict[ky] = vl.replace('{', '').replace('}', '').replace('\\', '')

# print('\n\n\n\n')
# print(pstr)

splitLines = pstr.split('\n')
for i in range(0, len(splitLines)):
if use_equation_environment:
splitLines[i] = multiple_replace(splitLines[i], rep_dict)
else:
if '\\label{' in splitLines[i]:
epr, lbl = splitLines[i].split('\\label{')
try:
epr, lbl = splitLines[i].split('\\label{')
except:
print(splitLines[i])
epr = multiple_replace(epr, rep_dict)
# rep_dict[ky] = vl.replace('_', '\\_')
lbl = multiple_replace(lbl, label_rep_dict)
Expand Down Expand Up @@ -1249,8 +1379,13 @@ def latex_printer(
fstr += '\\usepackage{amsmath} \n'
fstr += '\\usepackage{amssymb} \n'
fstr += '\\usepackage{dsfont} \n'
fstr += '\\usepackage[paperheight=%.4fin, paperwidth=%.4fin, left=%.4fin,right=%.4fin, top=%.4fin, bottom=%.4fin]{geometry} \n'%(
paper_dimensions_used['height'], paper_dimensions_used['width'],
paper_dimensions_used['margin_left'], paper_dimensions_used['margin_right'],
paper_dimensions_used['margin_top'], paper_dimensions_used['margin_bottom'] )
fstr += '\\allowdisplaybreaks \n'
fstr += '\\begin{document} \n'
fstr += fontsize + ' \n'
fstr += pstr + '\n'
fstr += '\\end{document} \n'

Expand Down

0 comments on commit e9232d2

Please sign in to comment.