#!/usr/bin/env python
# coding: utf-8

# # Towards Execution

# ## Evaluation

# In[18]:


from pymbolic import parse, var
from pymbolic.mapper import Mapper

# Parse sqrt(x**2 + y**2), written with a 0.5 power, into a pymbolic expression tree.
expr = parse("(x**2 + y**2)**0.5")
# Bare expression as the last statement of the cell so the notebook displays it.
expr


# Write an evaluator:

# In[8]:


class Evaluator(Mapper):
    """Numerically evaluate a pymbolic expression tree.

    Variables are looked up by name in *context*; constants evaluate to
    themselves. Sums, products and powers are supported; any other node
    type falls through to the base ``Mapper``'s handling.
    """

    def __init__(self, context):
        # Mapping from variable name (a str, as in ``expr.name``) to its value.
        self.context = context

    def map_variable(self, expr):
        # A name missing from the context raises the mapping's lookup error
        # (KeyError for a plain dict).
        return self.context[expr.name]

    def map_constant(self, expr):
        return expr

    def map_sum(self, expr):
        return sum(self.rec(ch) for ch in expr.children)

    def map_product(self, expr):
        # Multiply the children left to right; an empty product is 1.
        result = 1
        for ch in expr.children:
            result = result * self.rec(ch)
        return result

    def map_power(self, expr):
        return self.rec(expr.base) ** self.rec(expr.exponent)


# In[9]:


# Evaluate expr at x=5, y=7: (25 + 49)**0.5 = sqrt(74), roughly 8.602.
Evaluator({"x": 5, "y": 7})(expr)


# ## Towards Executable Code

# In[25]:


class CodeWriter(Mapper):
    """Lower an expression into a sequence of assignments to temporaries.

    Every sum, product and power node is assigned to a fresh variable
    ``tmpN``. The generated statements accumulate in ``self.lines``, with
    children emitted before their parents so each operand is defined before
    it is used. Calling the mapper returns the variable (or leaf) that holds
    the value of the whole expression.
    """

    def __init__(self):
        # Generated statements, in execution order.
        self.lines = []
        # Counter used to produce unique temporary names.
        self.name_nr = 0

    def make_name(self):
        """Return a fresh temporary variable: ``tmp1``, ``tmp2``, ..."""
        self.name_nr += 1
        return var(f"tmp{self.name_nr}")

    def map_variable(self, expr):
        # Leaves need no temporary; they are referenced directly.
        return expr

    def map_constant(self, expr):
        return expr

    def _emit_nary(self, expr, op):
        """Emit ``tmp <- c1 op c2 op ...`` over *expr*'s children; return tmp."""
        # The temporary is numbered before the children are visited, so the
        # children get higher numbers, but their statements are appended first.
        tmp = self.make_name()
        operands = (" %s " % op).join(str(self.rec(ch)) for ch in expr.children)
        self.lines.append("%s <- %s" % (tmp, operands))
        return tmp

    def map_sum(self, expr):
        # A Sum's children tuple may have any length (e.g. "x + y + z"),
        # so do not assume exactly two operands.
        return self._emit_nary(expr, "+")

    def map_product(self, expr):
        return self._emit_nary(expr, "*")

    def map_power(self, expr):
        tmp = self.make_name()
        base = self.rec(expr.base)
        exponent = self.rec(expr.exponent)
        self.lines.append("%s <- %s ** %s" % (tmp, base, exponent))
        return tmp


# In[27]:


# Lower expr into temporaries, print the generated statements one per line,
# then a blank line followed by the variable holding the final value.
cw = CodeWriter()
result = cw(expr)
for stmt in cw.lines:
    print(stmt)
print("\n" + str(result))


# ## Common Subexpressions

# In[12]:


from pymbolic.mapper.c_code import CCodeMapper

# pymbolic's C-code stringifier (from pymbolic.mapper.c_code).
ccm = CCodeMapper()
x = parse("x")
# Bare expression: the notebook displays the string returned by the mapper.
ccm((x+4)**17)


# Often, the same subexpression occurs multiple times within a larger expression.

# In[14]:


# u is an ordinary expression that is referenced three times in expr below.
u = (x+4)**3

h = parse("h")

expr = u + 2*u*h + 4*u*h**2
# Display the C rendering of expr; u is not marked as a common subexpression yet.
ccm(expr)


# - Obviously, that doesn't lead to great code. In particular, the redundant subexpression `(x+4)**3` is carried through into the generated code.
# - First impulse: introduce temporary variables.
# - Resist that for a moment: use the expression itself as the identifier. (Is that valid?)

# In[15]:


from pymbolic.primitives import CommonSubexpression as CSE

# Wrap u in a CommonSubexpression so the mapper can give it a name rather
# than spelling it out at every use.
u = CSE((x+4)**3)

h = parse("h")

expr = u + 2*u*h + 4*u*h**2

result = ccm(expr)

# Print each (name, value) pair the mapper recorded for the CSEs it saw,
# then the expression text that refers to those names.
for cse_name, cse_value in ccm.cse_name_list:
    print(cse_name, "=", cse_value)

print(result)


# In[ ]:




