Ignition v0.1.0 documentation

Source code for ignition.flame.tensors.printers

"""Simple printers for tensor expressions"""

import operator
from copy import copy
from sympy import Add, Number, Mul, latex, Pow, S

from tensor_expr import expr_rank
from tensor import Tensor
from basic_operators import Inner, Inverse, Transpose

# Default %-style format templates, keyed by operation name.  The printer
# functions in this module copy this table and override individual entries
# before handing it to ``print_visitor``.
defaults = {
    "add": " + ",
    "div_under_one": "1.0 / (%s)",
    "dot": "(%s, %s)",
    "inverse": "(%s)**-1",
    "neg": "-%s",
    "name_attr": "name",
    "pow": "(%s)**%s",
    "transpose": "T(%s)",
}


def numpy_print(expr):
    """Render a tensor expression with numpy-flavoured format templates.

    A private copy of the module-level ``defaults`` table is made so the
    shared table is never mutated; the ``"dot"``, ``"transpose"`` and
    ``"inverse"`` templates are then overridden before the expression and the
    table are handed to ``print_visitor``.

    Returns whatever ``print_visitor`` returns for ``expr``.
    """
    numpy_strs = dict(defaults)
    numpy_strs.update({
        "dot": "dot(%s, %s)",
        "transpose": "(%s).transpose()",
        "inverse": "inv(%s)",
    })
    return print_visitor(expr, numpy_strs)
def latex_print(expr):
    """Return the LaTeX form of a tensor expression.

    The work is delegated entirely to ``latex`` (imported from sympy at the
    top of this module); ``expr`` is passed through unchanged.
    """
    # Disabled print_visitor-based configuration, kept for reference:
    #   my_strs = copy(defaults)
    #   my_strs["dot"] = "%s %s"
    #   my_strs["inverse"] = "{%s}^{-1}"
    #   my_strs["name_attr"] = "latex"
    #   my_strs["transpose"] = "{%s}^t"
    #   my_strs["pow"] = "{%s}^{%s}"
    #   my_strs["div_under_one"] = "{%s}^{-1}"
    #   return print_visitor(expr, my_strs)
    rendered = latex(expr)
    return rendered
def wrap_long_latex_line(line, length=100):
    """Split a long LaTeX line into several ``eqnarray`` rows.

    While the remaining text is longer than ``length`` characters, it is cut
    at the first ``' + '`` or ``' - '`` found at or after index ``length``.
    Each cut inserts the row separator ``"\\\\\\\\\\n & &"`` (a LaTeX line
    break followed by two column markers), and the operator itself starts the
    next row.

    If a cut falls inside ``\\left( ... \\right)`` groups, the row is closed
    with one invisible ``\\right.`` per open group and the next row is
    reopened with the same number of ``\\left.`` so every row stays balanced.

    Parameters:
        line: the LaTeX text to wrap.
        length: rows longer than this many characters are candidates for
            wrapping (default 100).

    Returns:
        The wrapped text.  ``line`` is returned unchanged when it is short
        enough or contains no usable break point.
    """
    row_break = "\\\\\n & &"
    rows = []
    # Number of "\left." delimiters that were prepended to the current
    # remainder by the previous cut; they are still open at its start.
    carried = 0
    while len(line) > length:
        # Consider each operator separately: ``find`` returns -1 for a
        # missing operator, and taking ``min`` over the raw results would
        # let that -1 hide a perfectly good break point.  Position 0 is
        # rejected as well because cutting there would make no progress.
        candidates = [pos for pos in (line.find(' + ', length),
                                      line.find(' - ', length))
                      if pos > 0]
        if not candidates:
            break
        cut = min(candidates)
        head, line = line[:cut], line[cut:]
        # Groups still open at the cut: those carried over from the previous
        # row plus those opened and not yet closed inside this row.
        depth = carried + head.count("\\left(") - head.count("\\right)")
        if depth > 0:
            head += "\\right." * depth
            line = "\\left." * depth + line
        elif depth < 0:
            # Single-argument call form prints identically on Python 2 and 3.
            print("Unmatched parens somewhere.")
        carried = max(depth, 0)
        rows.append(head)
        rows.append(row_break)
    rows.append(line)
    return "".join(rows)
def update_dict_to_latex(update_dict, order):
    """Format an update dictionary as a LaTeX ``eqnarray*`` environment.

    Parameters:
        update_dict: mapping indexed by the entries of ``order``.  Each value
            is either a single expression or a ``set`` of alternative
            expressions.
        order: sequence of keys; rows are emitted in *reverse* of this order.

    Returns:
        A string that opens with ``\\begin{eqnarray*}``, contains one
        ``key &=& value`` entry per key, and closes with ``\\end{eqnarray*}``.
        Alternatives from a set value are separated by ``" &||& "``.  Every
        right-hand side is rendered with ``latex_print``, terminated with a
        LaTeX line break and passed through ``wrap_long_latex_line``.
    """
    def render_row(obj):
        # One right-hand side: LaTeX text, row terminator, then wrapping.
        return wrap_long_latex_line(latex_print(obj) + "\\\\\n")

    chunks = ["\\begin{eqnarray*}\n"]
    for key in reversed(order):
        chunks.append(latex_print(key) + " &=& ")
        value = update_dict[key]
        if not isinstance(value, set):
            chunks.append(render_row(value))
            continue
        # A one-element set yields just its row; larger sets get the
        # separator between consecutive alternatives only.
        chunks.append(" &||& ".join([render_row(eq) for eq in value]))
    chunks.append("\\end{eqnarray*}\n")
    # Disabled post-processing, kept for reference:
    #   ret_val = ret_val.replace('(', '\\left(')
    #   ret_val = ret_val.replace(')', '\\right)')
    #   ret_val = ret_val.replace('--', '')
    return "".join(chunks)