"""
All symbolic functions take ast expression trees (not expression strings) as input.
Such a tree can be constructed as: ast.parse(s).body[0].value
"""
import ast
import copy
import functools
import os

from yaml import ScalarNode
# Function names that may appear in model equations.
# NOTE(review): `list_variables` builds its whitelist from `functions`
# (the keys of language.functions), not from this list; this module-level
# list appears unused in the visible code — confirm before relying on it.
known_functions = ['log','exp','sin','cos','tan','max','min','Max','Min','Heaviside','abs','sign','IfThenElse','IfThen','Positive','Negative','LambertW','myzif']
# Type alias used in the annotations of this module.
# NOTE(review): the module docstring says inputs are expression trees
# (``ast.parse(s).body[0].value``), i.e. ``ast.expr`` nodes rather than the
# ``ast.Expr`` statement wrapper; in the visible code this alias is only
# used in annotations.
Expression = ast.Expr
from ast import UnaryOp, UAdd, USub, Name, Load, Call
from ast import NodeTransformer
from typing import Tuple, List
from snowdrop.src.preprocessor.language import functions as functions_dict
from snowdrop.src.preprocessor.codegen import to_source
# Names of all functions known to the model language; `list_variables`
# uses them as its default function whitelist.
functions = list(functions_dict.keys())
[docs]
def list_variables(expr: Expression, funs: List[str]=None, vars: List[str]=None) -> List[Tuple[str,int]]:
    """
    Return the time-subscripted variables that occur in an expression.

    Parameters:
        expr : ast expression tree to scan.
        funs : extra function names accepted in addition to the language's
            built-in functions.
        vars : names that are treated as variables.

    Returns:
        List of ``(name, lead_lag)`` tuples, in the order they were found.

    Raises:
        Exception: 'Symbolic error.' carrying a ``problems`` attribute with
        every issue reported by ``ListSymbols``.
    """
    extra_functions = [] if funs is None else funs
    known_vars = [] if vars is None else vars
    visitor = ListSymbols(known_functions=functions + extra_functions,
                          known_variables=known_vars)
    visitor.visit(expr)
    if not visitor.problems:
        # Each entry is ((name, lead_lag), column); keep only the symbol.
        return [entry[0] for entry in visitor.variables]
    error = Exception('Symbolic error.')
    error.problems = visitor.problems
    raise error
[docs]
def time_shift(expr: Expression, n, vars: List[str] = None) -> Expression:
    """
    Shift the timing of the given variables in an equation.

    Example (pseudo-code)::

        time_shift(a + b(1) + c, 1, ['b', 'c'])  ->  a + b(2) + c(1)

    Parameters:
        expr : ast expression tree; it is deep-copied, not modified.
        n : number of periods to shift by.
        vars : names of the variables whose timing is shifted
            (default: none).

    Returns:
        A new expression tree with shifted timings.
    """
    if vars is None:
        # Fresh list per call instead of a list default shared by all calls.
        vars = []
    eexpr = copy.deepcopy(expr)
    # NOTE(review): TimeShiftTransformer is neither defined nor imported in
    # the visible part of this module; confirm where it is provided.
    return TimeShiftTransformer(shift=n, variables=vars).visit(eexpr)
[docs]
def stringify_variable(arg: Tuple[str, int]) -> str:
    """
    Encode a variable name together with its lead or lag.

    ``(x, 0)`` -> ``'x__'``, ``(x, -k)`` -> ``'x__mk_'``, ``(x, k)`` -> ``'x__pk_'``.
    """
    name, shift = arg[0], arg[1]
    if shift > 0:
        return name + '__p' + str(shift) + '_'
    if shift < 0:
        return name + '__m' + str(-shift) + '_'
    if shift == 0:
        return name + '__'
    return None
[docs]
def log_stringify_variable(arg: Tuple[str, int]) -> str:
    """
    Encode a time-shifted variable wrapped in a ``log`` call.

    ``(x, 0)`` -> ``'log(x__)'``, ``(x, -k)`` -> ``'log(x__mk_)'``,
    ``(x, k)`` -> ``'log(x__pk_)'``.
    """
    name, shift = arg[0], arg[1]
    if shift > 0:
        suffix = '__p' + str(shift) + '_'
    elif shift < 0:
        suffix = '__m' + str(-shift) + '_'
    elif shift == 0:
        suffix = '__'
    else:
        return None
    return 'log(' + name + suffix + ')'
[docs]
def stringify_parameter(p: str) -> str:
    """Return the textual form of a parameter (its default ``format`` representation)."""
    return f'{p}'
[docs]
def stringify(arg) -> str:
    """
    Stringify a variable or a parameter.

    A ``str`` is treated as a parameter name; a ``(name, lead_lag)`` tuple
    with a ``str`` name and an ``int`` lead/lag is treated as a time-shifted
    variable.

    Raises:
        Exception: if `arg` has neither form.
    """
    if isinstance(arg, str):
        return stringify_parameter(arg)
    is_variable = (
        isinstance(arg, tuple)
        and len(arg) == 2
        and isinstance(arg[0], str)
        and isinstance(arg[1], int)
    )
    if is_variable:
        return stringify_variable(arg)
    raise Exception("Unknown canonical form: {}".format(arg))
[docs]
def stringify_symbol(arg) -> str:
    """
    Stringify a symbol: a parameter name or a ``(name, lead_lag)`` variable.

    Raises:
        Exception: if `arg` is neither a ``str`` nor a ``(str, int)`` pair.
    """
    if (isinstance(arg, tuple) and len(arg) == 2
            and isinstance(arg[0], str) and isinstance(arg[1], int)):
        return stringify_variable(arg)
    if isinstance(arg, str):
        return stringify_parameter(arg)
    raise Exception("Unknown canonical form: {}".format(arg))
[docs]
def destringify(s: str, variables: List[str] = None) -> Tuple[int, int]:
    """
    Find the index and the lead/lag of a variable from its encoded name.

    Inverse of ``stringify_variable``: ``'x__'`` -> 0, ``'x__m2_'`` -> -2,
    ``'x__p3_'`` -> +3 (a bare number after ``__`` is read as a lead).
    Offsets written as ``_plus_N`` / ``_minus_N`` are added on top.

    Parameters:
        s : str
            Encoded variable name.
        variables : list of str, optional
            Variable names used to look up the position of the base name.

    Returns:
        ``(j, i)``: ``j`` is the position of the base name in `variables`
        (0 when it is not found), ``i`` is the lead (> 0) or lag (< 0).
    """
    if variables is None:
        variables = []
    i = 0
    if "__" in s:
        ind = s.rindex("__")
        if not s.endswith("__"):
            lead_lag = s[ind:].replace("_", "")
            # 'm' (lag) and 'p' (lead) are mutually exclusive: once the 'm'
            # prefix is stripped, the numeric remainder must not be re-read
            # as a positive lead by the numeric fallback.
            if "m" in lead_lag:
                lead_lag = lead_lag[1:]
                if lead_lag.isnumeric():
                    i = -int(lead_lag)
            elif "p" in lead_lag:
                lead_lag = lead_lag[1:]
                if lead_lag.isnumeric():
                    i = int(lead_lag)
            elif lead_lag.isnumeric():
                i = int(lead_lag)
        # Drop the '__...' suffix (also for the plain 'x__' form) so only
        # the base name remains.
        s = s[:ind]
    if "_plus_" in s:
        ind = s.rindex("_plus_")
        v = s[:ind]
        offset = s[ind + len("_plus_"):]
        if offset.isnumeric():
            i += int(offset)
    elif "_minus_" in s:
        ind = s.rindex("_minus_")
        v = s[:ind]
        offset = s[ind + len("_minus_"):]
        if offset.isnumeric():
            i -= int(offset)
    else:
        v = s
    j = variables.index(v) if v in variables else 0
    return (j, i)
[docs]
@functools.lru_cache(maxsize=None)
def _get_parser():
    """Build the Lark parser for ``grammar.lark`` (next to this module) once and reuse it."""
    from lark.lark import Lark
    grammar_path = os.path.join(os.path.dirname(__file__), "grammar.lark")
    # Context manager so the grammar file handle is closed promptly.
    with open(grammar_path, "rt", encoding="utf-8") as grammar_file:
        grammar = grammar_file.read()
    return Lark(
        grammar,
        start=[
            "start",
            "variable",
            "equation_block",
            "assignment_block",
            "complementarity_block",
        ],
    )


def parse_string(text, start=None):
    """
    Parse model text with the Lark grammar stored next to this module.

    Parameters:
        text : str or yaml.ScalarNode
            Text to parse. For a YAML string node with position marks, the
            raw source slice is parsed (without the ``>``/``|`` indicator).
        start : str, optional
            Grammar start rule (default ``"start"``); one of ``start``,
            ``variable``, ``equation_block``, ``assignment_block``,
            ``complementarity_block``.

    Returns:
        The Lark parse tree.

    Raises:
        lark.exceptions.UnexpectedInput: on a parse error. For a YAML node
        with a start mark, the reported line/column are offset by the
        node's start position.
    """
    from lark.exceptions import UnexpectedInput, UnexpectedCharacters
    # Building a Lark parser is expensive; it is created once per process.
    parser = _get_parser()
    if start is None:
        start = "start"
    if isinstance(text, ScalarNode):
        if text.tag != "tag:yaml.org,2002:str" or text.start_mark is None:
            txt = text.value
        else:
            buffer = text.end_mark.buffer
            txt = buffer[text.start_mark.pointer:text.end_mark.pointer]
            if text.style in (">", "|"):
                # Drop the block-scalar indicator character.
                txt = txt[1:]
    else:
        txt = text
    try:
        return parser.parse(txt, start)
    except (UnexpectedInput, UnexpectedCharacters) as e:
        sm = text.start_mark if isinstance(text, ScalarNode) else None
        # Only remap positions when the node knows where it starts.
        if sm is not None:
            if text.style not in (">", "|"):
                new_column = sm.column + e.column
                new_line = sm.line + e.line
            else:
                new_line = sm.line + e.line
                new_column = e.column
            newargs = list(e.args)
            newargs[0] = e.args[0].replace(f"line {e.line}", f"line {new_line}")
            newargs[0] = newargs[0].replace(f"col {e.column}", f"col {new_column}")
            e.args = tuple(newargs)
            e.line = new_line
            e.column = new_column
        # Always propagate the parse error, remapped or not.
        raise
[docs]
def normalize(expr: Expression, variables: List[str] = None)->Expression:
    """
    Replace calls to variables by their time-subscripted names.

    Uses ``ExpressionNormalizer`` on a deep copy of `expr`, so the input
    tree is not modified.

    Parameters:
        expr : ast expression tree.
        variables : names of the model variables (default: none).

    Returns:
        The normalized tree, or ``None`` if normalization failed; the error
        and the offending expression are printed in that case.
    """
    if variables is None:
        variables = []
    try:
        en = ExpressionNormalizer(variables=variables)
        e = en.visit(copy.deepcopy(expr))
    except Exception as err:
        # Best-effort as before (report and return None), but without
        # swallowing KeyboardInterrupt/SystemExit, and showing the cause.
        print("Error: ")
        print(err)
        print(to_source(expr))
        e = None
    return e
[docs]
def log_normalize(expr: Expression, variables: List[str] = None, log_variables: List[str] = None)->Expression:
    """
    Replace calls to variables by their time-subscripted names, wrapping
    the names of `log_variables` in ``log(...)``.

    Uses ``ExpressionLogNormalizer`` on a deep copy of `expr`, so the input
    tree is not modified.

    Parameters:
        expr : ast expression tree.
        variables : names of the model variables (default: none).
        log_variables : variables to be replaced by a log of their
            time-subscripted name (default: none).

    Returns:
        The normalized tree, or ``None`` if normalization failed; the error
        and the offending expression are printed in that case.
    """
    # Fresh lists per call instead of shared list defaults.
    if variables is None:
        variables = []
    if log_variables is None:
        log_variables = []
    try:
        en = ExpressionLogNormalizer(variables=variables, log_variables=log_variables)
        e = en.visit(copy.deepcopy(expr))
    except Exception as err:
        # Best-effort as before (report and return None), but without
        # swallowing KeyboardInterrupt/SystemExit, and showing the cause.
        print("Error: ")
        print(err)
        print(to_source(expr))
        e = None
    return e
[docs]
def std_tsymbol(tsymbol):
    """
    Return the standardized name of a time-subscripted symbol.

    ``(s, 0)`` -> ``'_s_'``, ``(s, -k)`` -> ``'_s_mk_'``, ``(s, k)`` -> ``'_s__k_'``.
    """
    symbol, date = tsymbol
    if date == 0:
        return f'_{symbol}_'
    if date < 0:
        return f'_{symbol}_m{-date!s}_'
    if date > 0:
        return f'_{symbol}__{date!s}_'
    return None
[docs]
class StandardizeDatesSimple(NodeTransformer):
    """
    Replaces calls to variables by time subscripts.

    For every name in `variables`, a bare ``x`` becomes the name
    ``std_tsymbol((x, 0))`` and ``x(k)`` / ``x(+k)`` / ``x(-k)`` becomes
    ``std_tsymbol((x, k))``. Calls to anything else are rebuilt with their
    arguments visited.
    """

    def __init__(self, variables):
        # Names to standardize; everything else is left untouched.
        self.variables = variables

    @staticmethod
    def _number(node):
        """Return the value of a numeric literal node."""
        # ``.n`` is a deprecated alias (removed in Python 3.14); since
        # Python 3.8 literals are ``ast.Constant`` nodes carrying ``.value``.
        return node.value if isinstance(node, ast.Constant) else node.n

    def visit_Name(self, node):
        """Visitor for Name node."""
        name = node.id
        if name in self.variables:
            return Name(id=std_tsymbol((name, 0)), ctx=Load())
        return node

    def visit_Call(self, node):
        """Visitor for Call node."""
        name = node.func.id
        if name not in self.variables:
            return Call(func=node.func, args=[self.visit(e) for e in node.args], keywords=[])
        # Only variables carry a subscript; reading it here (not up front)
        # keeps zero-argument function calls from raising IndexError.
        if not node.args:
            raise Exception("Unrecognized subscript.")
        arg = node.args[0]
        if isinstance(arg, UnaryOp):
            # we have s(+1) or s(-1)
            if isinstance(arg.op, UAdd):
                date = self._number(arg.operand)
            elif isinstance(arg.op, USub):
                date = -self._number(arg.operand)
            else:
                raise Exception("Unrecognized subscript.")
        else:
            date = self._number(arg)
        newname = std_tsymbol((name, date))
        if newname is None:
            # Returning None from a NodeTransformer visitor would silently
            # delete the node from the tree.
            raise Exception("Unrecognized subscript.")
        return Name(newname, Load())
[docs]
def timeshift(expr, variables, shift):
    """
    Return a copy of `expr` with the timing of `variables` shifted by `shift`.

    Same operation as ``time_shift`` with a different argument order; the
    input tree is deep-copied, not modified.
    """
    eexpr = copy.deepcopy(expr)
    # Keyword arguments, matching the call in ``time_shift``, so both helpers
    # bind the same constructor parameters regardless of their order.
    # NOTE(review): TimeShiftTransformer is neither defined nor imported in
    # the visible part of this module; confirm where it is provided.
    return TimeShiftTransformer(variables=variables, shift=shift).visit(eexpr)
[docs]
class Compare:
    """
    Compares two ast tree instances.

    Names starting with ``_`` in the first (pattern) tree act as wildcards:
    the first occurrence binds the corresponding subtree of the second tree
    in ``self.d``; later occurrences must match a subtree that compares
    equal to that binding.
    .. currentmodule: preprocessor
    """

    # ``ast.Index`` is deprecated since Python 3.9 and no longer produced by
    # the parser; an empty tuple keeps the isinstance check valid if the
    # class is ever removed.
    _INDEX_TYPES = getattr(ast, "Index", ())

    def __init__(self):
        # Wildcard bindings: wildcard name -> subtree of B it matched.
        self.d = {}

    def compare(self, A, B):
        """
        Compare two nodes.

        Returns:
            True when `A` matches `B`, False otherwise.

        Raises:
            Exception: "Not implemented" for node types not handled here.
        """
        if isinstance(A, ast.Name) and (A.id[0] == '_'):
            if A.id not in self.d:
                self.d[A.id] = B
                return True
            return self.compare(self.d[A.id], B)
        if A.__class__ != B.__class__:
            return False
        if isinstance(A, ast.Name):
            return A.id == B.id
        elif isinstance(A, ast.Call):
            # all() short-circuits left to right, like the original loops,
            # so wildcard bindings are recorded in the same order.
            return (self.compare(A.func, B.func)
                    and len(A.args) == len(B.args)
                    and all(self.compare(a, b) for a, b in zip(A.args, B.args)))
        elif isinstance(A, ast.Constant):
            # Replaces the ast.Num/.n check (removed in Python 3.14); since
            # Python 3.8 every literal is an ast.Constant.
            return A.value == B.value
        elif isinstance(A, ast.Expr):
            return self.compare(A.value, B.value)
        elif isinstance(A, ast.Module):
            return (len(A.body) == len(B.body)
                    and all(self.compare(a, b) for a, b in zip(A.body, B.body)))
        elif isinstance(A, ast.BinOp):
            if not isinstance(A.op, B.op.__class__):
                return False
            return self.compare(A.left, B.left) and self.compare(A.right, B.right)
        elif isinstance(A, ast.UnaryOp):
            if not isinstance(A.op, B.op.__class__):
                return False
            return self.compare(A.operand, B.operand)
        elif isinstance(A, ast.Subscript):
            return self.compare(A.value, B.value) and self.compare(A.slice, B.slice)
        elif isinstance(A, self._INDEX_TYPES):
            return self.compare(A.value, B.value)
        elif isinstance(A, ast.Compare):
            return (self.compare(A.left, B.left)
                    and len(A.ops) == len(B.ops)
                    and all(self.compare(a, b) for a, b in zip(A.ops, B.ops))
                    and len(A.comparators) == len(B.comparators)
                    and all(self.compare(a, b)
                            for a, b in zip(A.comparators, B.comparators)))
        elif isinstance(A, ast.cmpop):
            # Comparison operators carry no fields: the class check above
            # already decided equality (covers In, Eq, LtE, Lt, Gt, ...).
            return True
        else:
            print(A.__class__)
            raise Exception("Not implemented")
[docs]
def compare(a,b):
    """Return True when tree `a` matches tree `b`, using a fresh ``Compare``."""
    return Compare().compare(a, b)
[docs]
def match(m,s):
    """
    Match a pattern tree `m` against a tree `s`.

    Names starting with ``_`` in `m` are wildcards (see ``Compare``).

    Returns:
        False when the trees do not match; otherwise the dict of wildcard
        bindings, or True when no wildcard was bound.
    """
    comp = Compare()
    val = comp.compare(m, s)
    if not val:
        # A failed comparison may already have bound some wildcards before
        # the mismatch was found; returning that (truthy) dict would report
        # a successful match.
        return False
    d = comp.d
    if len(d) == 0:
        return val
    return d
[docs]
class ListNames(ast.NodeVisitor):
    """Collect the identifier of every ``Name`` node, in visiting order."""

    def __init__(self):
        # Identifiers found so far (duplicates are kept).
        self.found = []

    def visit_Name(self, name):
        """Record the identifier of a Name node."""
        self.found += [name.id]
[docs]
def get_names(expr):
    """Return a new list with the identifiers of all Name nodes in `expr`, in visiting order."""
    collector = ListNames()
    collector.visit(expr)
    return list(collector.found)
[docs]
def eval_scalar(tree):
    """
    Evaluate a (signed) numeric literal AST node.

    Accepts a number constant, optionally wrapped in unary ``+`` or ``-``
    (possibly repeated, e.g. ``--1``).

    Returns:
        The numeric value.

    Raises:
        Exception: "Don't know how to do that." for any other node.
    """
    if isinstance(tree, ast.UnaryOp) and isinstance(tree.op, (ast.USub, ast.UAdd)):
        value = eval_scalar(tree.operand)
        return -value if isinstance(tree.op, ast.USub) else value
    # ast.Num / .n are deprecated aliases removed in Python 3.14; since
    # Python 3.8 numbers are ast.Constant nodes. Booleans are excluded, as
    # the old ast.Num check excluded them.
    if isinstance(tree, ast.Constant):
        value = tree.value
        if isinstance(value, (int, float, complex)) and not isinstance(value, bool):
            return value
    raise Exception("Don't know how to do that.")
[docs]
class ExpressionChecker(ast.NodeVisitor):
    """
    Checks AST expressions.

    Records the variables and functions encountered and collects problems
    (unknown names, invalid or disallowed timings).
    .. currentmodule: preprocessor
    """

    def __init__(self, spec_variables, known_functions, known_constants):
        # spec_variables: name -> allowed timings (None allows any timing).
        self.spec_variables = spec_variables
        self.known_functions = known_functions
        self.known_constants = known_constants
        self.functions = []  # (name, colno)
        self.variables = []  # (name, timing, colno)
        self.problems = []   # [name, timing, colno, kind, ...]

    def visit_Call(self, call):
        """Check a call: either a time-subscripted variable or a known function."""
        name = call.func.id
        colno = call.func.col_offset
        if name in self.spec_variables:
            try:
                # Explicit check instead of ``assert``, which ``python -O``
                # strips (a call like x(1, 2) would then be silently accepted).
                if len(call.args) != 1:
                    raise ValueError(
                        "Variable {} expects exactly one time subscript.".format(name))
                n = eval_scalar(call.args[0])
                allowed_timing = self.spec_variables[name]
                if allowed_timing is None or (n in allowed_timing):
                    self.variables.append((name, n, colno))
                else:
                    self.problems.append([name, n, colno, 'incorrect_timing', allowed_timing])
            except Exception as e:
                print(e)
                self.problems.append([name, None, colno, 'timing_error'])
        elif name in self.known_functions:
            self.functions.append((name, colno))
            for e in call.args:
                self.visit(e)
        else:
            self.problems.append([name, None, colno, 'unknown_function'])

    def visit_Name(self, name):
        """Check a bare name: a variable at timing 0 or a known constant."""
        colno = name.col_offset
        n = 0
        name = name.id
        if name in self.spec_variables:
            allowed_timing = self.spec_variables[name]
            if (allowed_timing is None) or (n in allowed_timing):
                self.variables.append((name, n, colno))
            else:
                self.problems.append([name, n, colno, 'incorrect_timing', allowed_timing])
        elif name not in self.known_constants:
            self.problems.append([name, 0, colno, 'unknown_variable'])
[docs]
def check_expression(expr, spec_variables, known_functions=[]):
    """
    Check an expression against the model language.

    Parameters:
        expr : ast tree to check.
        spec_variables : mapping name -> allowed timings (None allows any
            timing).
        known_functions : extra function names accepted in addition to the
            language's functions.

    Returns:
        dict with keys ``functions``, ``variables`` and ``problems`` as
        collected by ``ExpressionChecker``.
    """
    from snowdrop.src.preprocessor.language import functions, constants
    # known_functions is only read (concatenated into a new list), so the
    # list default is never mutated.
    checker = ExpressionChecker(
        spec_variables,
        list(functions.keys()) + known_functions,
        constants,
    )
    checker.visit(expr)
    report = {}
    report['functions'] = checker.functions
    report['variables'] = checker.variables
    report['problems'] = checker.problems
    return report
[docs]
class ExpressionNormalizer(NodeTransformer):
    """
    Replaces calls to variables by time subscripts.

    ``x(k)`` becomes the name ``stringify_variable((x, k))`` for every
    variable and for every called name that is not a known function; a bare
    variable ``x`` becomes ``stringify_variable((x, 0))``; other names go
    through ``stringify_parameter``.
    .. currentmodule: preprocessor.symbolic
    """

    def __init__(self, variables=None, functions=None):
        self.variables = variables if variables is not None else []
        if functions is None:
            self.functions = list(functions_dict.keys())
        else:
            self.functions = functions

    def visit_Name(self, node):
        """Rename a bare name (variable at timing 0, or parameter)."""
        name = node.id
        if name in self.variables:
            return Name(id=stringify_variable((name, 0)), ctx=Load())
        else:
            return Name(id=stringify_parameter(name), ctx=Load())

    def visit_Call(self, node):
        """Replace ``x(k)`` by its subscripted name; rebuild known-function calls."""
        name = node.func.id
        if name in self.variables or name not in self.functions:
            # Read the subscript only for symbols that need one, so that
            # zero-argument calls to known functions no longer raise IndexError.
            args = node.args[0] if node.args else None
            try:
                if args is None:
                    raise ValueError("missing time subscript")
                date = eval_scalar(args)
            except Exception as err:
                raise Exception("Unrecognized subscript: name ={} args={}".format(name, args)) from err
            return Name(stringify_variable((name, date)), Load())
        return Call(func=node.func, args=[self.visit(e) for e in node.args], keywords=[])
[docs]
class ExpressionLogNormalizer(NodeTransformer):
    """
    Replaces calls to variables by log function of their time subscripts.

    Like ``ExpressionNormalizer``, but names listed in `log_variables` are
    replaced by ``log_stringify_variable((x, k))``, i.e. wrapped in ``log(...)``.
    .. currentmodule: preprocessor.symbolic
    """

    def __init__(self, variables=None, log_variables=None, functions=None):
        # None defaults give each instance its own list instead of sharing
        # one list default across all instances.
        self.variables = variables if variables is not None else []
        self.log_variables = log_variables if log_variables is not None else []
        if functions is None:
            self.functions = list(functions_dict.keys())
        else:
            self.functions = functions

    def visit_Name(self, node):
        """Rename a bare name (log variable, variable at timing 0, or parameter)."""
        name = node.id
        if name in self.log_variables:
            return Name(id=log_stringify_variable((name, 0)), ctx=Load())
        elif name in self.variables:
            return Name(id=stringify_variable((name, 0)), ctx=Load())
        else:
            return Name(id=stringify_parameter(name), ctx=Load())

    def visit_Call(self, node):
        """Replace ``x(k)`` by its (log-)subscripted name; rebuild known-function calls."""
        name = node.func.id
        if name in self.variables or name not in self.functions:
            # Read the subscript only for symbols that need one, so that
            # zero-argument calls to known functions no longer raise IndexError.
            args = node.args[0] if node.args else None
            try:
                if args is None:
                    raise ValueError("missing time subscript")
                date = eval_scalar(args)
            except Exception as err:
                raise Exception("Unrecognized subscript: name ={} args={}".format(name, args)) from err
            if name in self.log_variables:
                newname = log_stringify_variable((name, date))
            else:
                newname = stringify_variable((name, date))
            return Name(newname, Load())
        return Call(func=node.func, args=[self.visit(e) for e in node.args], keywords=[])
[docs]
class ListSymbols(ast.NodeVisitor):
    """
    Collects the symbols of an ast expression tree.

    Calls to known functions are recorded in ``functions`` and their
    arguments visited; any other call ``name(k)`` with a single numeric
    subscript is recorded as a variable ``((name, k), colno)``. Bare names
    are variables (timing 0) if known, otherwise constants.
    .. currentmodule: preprocessor.symbolic
    """

    def __init__(self, known_functions=None, known_variables=None):
        # None defaults give each instance its own list instead of sharing
        # one list default across all instances.
        self.known_functions = known_functions if known_functions is not None else []
        self.known_variables = known_variables if known_variables is not None else []
        self.functions = []   # (name, colno)
        self.variables = []   # ((name, timing), colno)
        self.constants = []   # (name, colno)
        self.problems = []

    def visit_Call(self, call):
        """Record a function call or a time-subscripted variable."""
        name = call.func.id
        colno = call.func.col_offset
        if name in self.known_functions:
            self.functions.append((name, colno))
            for arg in call.args:
                self.visit(arg)
            return
        try:
            # Explicit check instead of ``assert``, which ``python -O`` strips.
            if len(call.args) != 1:
                raise ValueError("expected exactly one time subscript")
            n = int(eval_scalar(call.args[0]))
            self.variables.append(((name, n), colno))
        except Exception:
            is_variable = name in self.known_variables or any(
                symbol[0] == name for symbol, _ in self.variables)
            if is_variable:
                self.problems.append([name, 0, colno, 'incorrect subscript'])
            else:
                self.problems.append([name, 0, colno, 'unknown_function'])

    def visit_Name(self, name):
        """Record a bare name as a variable (timing 0) or a constant."""
        colno = name.col_offset
        name = name.id
        if name in self.known_variables:
            self.variables.append(((name, 0), colno))
        elif name in self.known_functions:
            # NOTE(review): this entry has 3 fields, while visit_Call's
            # problems have 4 ([name, 0, colno, kind]); kept as-is because
            # callers may index it — confirm whether it should be unified.
            self.problems.append([name, colno, 'function_not_called'])
        else:
            self.constants.append((name, colno))
# def get_variables(variables, expr):
# ln = ListVariables(variables)
# ln.visit(expr)
# return ln.found
# def get_functions(variables, expr):
# ln = ListVariables(variables)
# ln.visit(expr)
# return ln.functions
# class ExpressionChecker(ast.NodeVisitor):
#
# def __init__(self, variables, functions):
#
# self.allowed_variables = variables
# self.functions = functions
# self.found = []
# self.problems = []
#
# def visit_Call(self, call):
# name = call.func.id
# if name in self.variables:
# assert(len(call.args)==1)
# print(call.args[0])
# n = eval_scalar(call.args[0])
# self.found.append((name, n))
# elif name in self.functions:
# self.functions.append(name)
# for e in call.args:
# self.visit(e)
# else:
# for e in call.args:
# self.visit(e)
#
# def visit_Name(self, name):
# name = name.id
# if name in self.variables:
# self.found.append((name,0))
#
# def check_expression(expr, variables, functions):
# ec = ExpressionChecker(variables, functions)
# pbs = ec.visit(ec)
if __name__ == '__main__':
    # Manual smoke test of the variable-name encoding helpers.
    var = ['g', 'p_pdot1', 'p_pdot2', 'p_pdot3', 'p_rs1', 'p_y1', 'p_y2', 'p_y3', 'p_pdot1__m1_']
    encoded = stringify((var[1], -1))
    print(encoded)
    last_name = var[-1]
    decoded = destringify(last_name, var)
    print(last_name, decoded)