Compiling mathematical expressions

Hello. In this essay, I'll show you how I implemented compilation of mathematical (numerical and logical) expressions into a delegate using Linq Expression.





Navigation: Problem · Compilation rules · Compiler · Default rules · Nice API · Performance · Usage examples · Conclusion · Links





What do we want?

We want to compile the expression into a function of an arbitrary number of arguments of an arbitrary type, not only numeric, but also boolean. For instance,





var func = "x + sin(y) + 2ch(0)".Compile<Complex, double, Complex>("x", "y");
Console.WriteLine(func(new(3, 4), 1.2d));
>>> (5.932039085967226, 4)
  
var func2 = "x > 3 and (a implies b)".Compile<int, bool, bool, bool>("x", "a", "b");
Console.WriteLine(func2(4, false, true));
>>> True
      
      



What do we have?

Since I'm building this inside an existing symbolic algebra library, which already has a parser and an expression tree, we can go straight to compilation.





There is a base class, Entity, and a tree of types derived from it:





Entity
|
+--Operators
  |
  +--Sumf
  |
  +--Minusf
  |
  ...
+--Trigonometry
  |
  +--Sinf
  |
  +--Cosf
  |
  ...
+--Discrete
  |
  +--Andf
  |
  +--Lessf
  |
  ...
      
      



This is what the type tree looks like. An expression tree is just a graph where the children of a node are the operands of an operator / function.





Each type is either abstract (used only to group related types) or sealed. The sealed ones are the actual operators, functions, constants and other entities that occur in an expression (plus, sine, conjunction, a number, a set, etc.).





For example, this is how the sum operator is defined.
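The definition itself is not reproduced in this text. As a rough, self-contained sketch of the idea (not AngouriMath's actual code), the toy records below model the hierarchy: an abstract Entity base, one sealed record per node kind, and the IUnaryNode / IBinaryNode interfaces that the compiler later pattern-matches on. The member names NodeChild, NodeFirstChild, NodeSecondChild and DirectChildren are taken from the compiler code in this article; everything else (Variable, Sinf.Argument, Sumf.Augend/Addend) is hypothetical.

```csharp
using System;
using System.Collections.Generic;

// Toy model for illustration only, not AngouriMath's real definitions
public interface IUnaryNode { Entity NodeChild { get; } }
public interface IBinaryNode { Entity NodeFirstChild { get; } Entity NodeSecondChild { get; } }

// Abstract base: only used to generalize the concrete node types
public abstract record Entity
{
    // The operands of an operator/function are the node's children
    public abstract IEnumerable<Entity> DirectChildren { get; }
}

// Sealed records are the actual entities that occur in an expression
public sealed record Variable(string Name) : Entity
{
    public override IEnumerable<Entity> DirectChildren => Array.Empty<Entity>();
}

public sealed record Sinf(Entity Argument) : Entity, IUnaryNode
{
    public Entity NodeChild => Argument;
    public override IEnumerable<Entity> DirectChildren => new[] { Argument };
}

public sealed record Sumf(Entity Augend, Entity Addend) : Entity, IBinaryNode
{
    public Entity NodeFirstChild => Augend;
    public Entity NodeSecondChild => Addend;
    public override IEnumerable<Entity> DirectChildren => new[] { Augend, Addend };
}
```

In this sketch, records also give structural equality: two separately built copies of sin(x) + y compare equal and have the same hash code, which is convenient when subexpressions are used as dictionary keys.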





Compilation protocol

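Compilation rules may differ from task to task (the last example in this article, for instance, compiles an expression into string concatenation). So the rules that turn each kind of Entity into a Linq expression are kept out of the compiler and passed to it as a protocol.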





The protocol is a record of converters:





public sealed record CompilationProtocol
{
	public Func<Entity, Expression> ConstantConverter { get; init; }

	public Func<Expression, Expression, Entity, Expression> BinaryNodeConverter { get; init; }

	public Func<Expression, Entity, Expression> UnaryNodeConverter { get; init; }

	public Func<IEnumerable<Expression>, Entity, Expression> AnyArgumentConverter { get; init; }
}
      
      



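ConstantConverter turns constants into Linq expressions, UnaryNodeConverter handles nodes with one argument, BinaryNodeConverter handles nodes with two, and AnyArgumentConverter handles nodes with any number of arguments.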





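Apart from ConstantConverter, each converter receives the already compiled children (as Linq.Expression) together with the node itself, and returns a Linq.Expression for that node.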





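The node is passed as a "type holder": its type tells the converter which operator or function it is dealing with.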
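To make the idea of a swappable protocol concrete, here is a trimmed-down, self-contained sketch that uses only System.Linq.Expressions. MiniProtocol and ProtocolDemo are hypothetical names. The node is represented by an operator string instead of an Entity, and only the binary rule is modeled. Because the rules are init-only properties of a record, a variant protocol can be derived with a `with` expression that overrides one rule and keeps the rest.

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical, trimmed-down protocol: only the binary rule is modeled, and the
// node ("type holder") is an operator string instead of an Entity
public sealed record MiniProtocol
{
    // Default rule: numeric addition and multiplication
    public Func<Expression, Expression, string, Expression> BinaryNodeConverter { get; init; } =
        (left, right, op) => op switch
        {
            "+" => Expression.Add(left, right),
            "*" => Expression.Multiply(left, right),
            _ => throw new NotSupportedException($"No rule for '{op}'")
        };
}

public static class ProtocolDemo
{
    // Compiles "x op y" for two parameters of type T, leaving the node itself to the protocol
    public static Func<T, T, T> CompileBinary<T>(MiniProtocol protocol, string op)
    {
        var x = Expression.Parameter(typeof(T), "x");
        var y = Expression.Parameter(typeof(T), "y");
        var body = protocol.BinaryNodeConverter(x, y, op);
        return Expression.Lambda<Func<T, T, T>>(body, x, y).Compile();
    }
}
```

For example, `new MiniProtocol() with { BinaryNodeConverter = (a, b, op) => Expression.Call(concat, a, b) }`, where `concat` is `string.Concat(string, string)`, turns "+" into string concatenation without touching CompileBinary.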





Compiler

The main compilation method has the following signature:





internal static TDelegate Compile<TDelegate>(
            Entity expr, 
            Type? returnType,
            CompilationProtocol protocol,
            IEnumerable<(Type type, Variable variable)> typesAndNames
            ) where TDelegate : Delegate
      
      



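  1. Entity expr

    - the expression to compile.

  2. Type? returnType

    - the type the result is converted to. If it is null, the result is returned as is, without a conversion.

  3. CompilationProtocol protocol

    - the rules that turn each node into a Linq.Expression.

  4. IEnumerable<(Type type, Variable variable)> typesAndNames

    - (type, variable) pairs describing the parameters of the resulting delegate, in order. For example, if x is an int and y is a Complex, pass new[] { (typeof(int), "x"), (typeof(Complex), "y") }.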







Its implementation:





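internal static TDelegate Compile<TDelegate>(Entity expr, Type? returnType, CompilationProtocol protocol, IEnumerable<(Type type, Variable variable)> typesAndNames) where TDelegate : Delegate
{
	// Cache of already built subexpressions; initially it only maps each variable to a parameter
	var subexpressionsCache = typesAndNames.ToDictionary(c => (Entity)c.variable, c => Expression.Parameter(c.type));
	// The delegate's parameters, in the same order as typesAndNames (copied, since the cache will grow)
	var functionArguments = subexpressionsCache.Select(c => c.Value).ToArray();
	// Local variables that will hold intermediate results
	var localVars = new List<ParameterExpression>();
	// Assignments of subexpressions to those locals, in evaluation order
	var variableAssignments = new List<Expression>();

	// Build the tree; BuildTree fills localVars and variableAssignments along the way
	var tree = BuildTree(expr, subexpressionsCache, variableAssignments, localVars, protocol);
	// A block that declares the locals, runs the assignments and returns the value of the tree
	var treeWithLocals = Expression.Block(localVars, variableAssignments.Append(tree));
	// If returnType is specified, convert the result to it
	Expression entireExpression = returnType is not null ? Expression.Convert(treeWithLocals, returnType) : treeWithLocals;
	// Wrap everything into a lambda over the function's parameters
	var finalLambda = Expression.Lambda<TDelegate>(entireExpression, functionArguments);

	// Compile it into a delegate
	return finalLambda.Compile();
}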
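For a concrete picture of what Compile assembles, the sketch below builds by hand the tree that would correspond to x * x with an int parameter and returnType = typeof(double). It has one local, one assignment, the final value as the last expression of the block, and then the Convert and the lambda. AssembleDemo is a hypothetical name, and this is plain System.Linq.Expressions, not the library code.

```csharp
using System;
using System.Linq.Expressions;

public static class AssembleDemo
{
    // Hand-built equivalent of the last steps of Compile for "x * x" with returnType = double
    public static Func<int, double> Build()
    {
        var x = Expression.Parameter(typeof(int), "x");
        var var1 = Expression.Variable(typeof(int), "var1");
        var treeWithLocals = Expression.Block(
            new[] { var1 },                                      // localVars
            Expression.Assign(var1, Expression.Multiply(x, x)),  // variableAssignments
            var1);                                               // the tree; its value is the block's value
        // returnType is not null, so the block is converted to it
        Expression entireExpression = Expression.Convert(treeWithLocals, typeof(double));
        return Expression.Lambda<Func<int, double>>(entireExpression, x).Compile();
    }
}
```

Calling `AssembleDemo.Build()(7)` evaluates 7 * 7 in the block and returns it converted to double.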
      
      



Compile itself is short: all the real work happens in BuildTree, which turns an Entity into a Linq expression tree. Its signature:





internal static Expression BuildTree(
	Entity expr, 
	Dictionary<Entity, ParameterExpression> cachedSubexpressions, 
	List<Expression> variableAssignments, 
	List<ParameterExpression> newLocalVars,
	CompilationProtocol protocol)
      
      



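BuildTree takes the following parameters:
  1. Entity expr

    - the subexpression to build.

  2. Dictionary<Entity, ParameterExpression> cachedSubexpressions

    - the subexpressions that have already been built (initially only the function parameters), each mapped to the variable that holds it.

  3. List<Expression> variableAssignments

    - the assignments of subexpressions to local variables, in evaluation order.

  4. List<ParameterExpression> newLocalVars

    - the local variables created by BuildTree (they are declared later, in the enclosing block).

  5. CompilationProtocol protocol

    - the rules that turn an Entity into a Linq.Expression. BuildTree only walks the tree and leaves the conversion of each node to the protocol.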





And here is the (abridged) body of BuildTree:





internal static Expression BuildTree(Entity expr, ...)
{
	// If this subexpression was already built (or is a parameter), reuse its variable
	if (cachedSubexpressions.TryGetValue(expr, out var readyVar))
		return readyVar;

	Expression subTree = expr switch
	{
		...

		// Constants are converted by ConstantConverter
		Entity.Boolean or Number => protocol.ConstantConverter(expr),

		// Nodes with one, two or any number of children: build the children first, then convert the node
		IUnaryNode oneArg
			=> protocol.UnaryNodeConverter(BuildTree(oneArg.NodeChild, ...), expr),

		IBinaryNode twoArg
			=> protocol.BinaryNodeConverter(
				BuildTree(twoArg.NodeFirstChild, ...),
				BuildTree(twoArg.NodeSecondChild, ...),
				expr),

		var other => protocol.AnyArgumentConverter(
				other.DirectChildren.Select(c => BuildTree(c, ...)),
			expr)
	};

	// Create a new local variable of the subtree's type
	var newVar = Expression.Variable(subTree.Type);

	// Emit an assignment like var5 = subTree
	variableAssignments.Add(Expression.Assign(newVar, subTree));

	// Remember that this subexpression now lives in newVar
	cachedSubexpressions[expr] = newVar;

	// Register the variable so that it gets declared in the block
	newLocalVars.Add(newVar);

	return newVar;
}
      
      



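So every node is built exactly once and stored in a local variable. If the same subexpression occurs again, BuildTree finds it in the cache and reuses the variable instead of computing the subexpression a second time.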
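The same caching scheme can be tried without the library. The sketch below is a hypothetical, minimal BuildTree over three toy node types (Var, Sin and Add, defined here only for the example). It follows the steps of the code above: check the cache, build the children, assign the result to a fresh local, and cache that local. It also reports how many locals were created, so one can check that sin(x) + sin(x) produces only two of them.

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;

// Toy nodes for this example only (not AngouriMath types); records compare structurally
public abstract record Node;
public sealed record Var(string Name) : Node;
public sealed record Sin(Node Arg) : Node;
public sealed record Add(Node Left, Node Right) : Node;

public static class MiniCompiler
{
    public static (Func<double, double> Func, int LocalCount) Compile(Node expr, Var arg)
    {
        var parameter = Expression.Parameter(typeof(double), arg.Name);
        // The cache starts out with the only parameter
        var cache = new Dictionary<Node, ParameterExpression> { [arg] = parameter };
        var locals = new List<ParameterExpression>();
        var assignments = new List<Expression>();

        var result = BuildTree(expr, cache, assignments, locals);
        assignments.Add(result); // the last expression is the value of the block
        var body = Expression.Block(locals, assignments);
        return (Expression.Lambda<Func<double, double>>(body, parameter).Compile(), locals.Count);
    }

    static Expression BuildTree(Node expr, Dictionary<Node, ParameterExpression> cache,
                                List<Expression> assignments, List<ParameterExpression> locals)
    {
        // Already built (or a parameter): reuse the variable that holds it
        if (cache.TryGetValue(expr, out var ready))
            return ready;

        // Build the children first, then the node itself
        Expression subTree = expr switch
        {
            Sin s => Expression.Call(typeof(Math).GetMethod(nameof(Math.Sin), new[] { typeof(double) })!,
                                     BuildTree(s.Arg, cache, assignments, locals)),
            Add a => Expression.Add(BuildTree(a.Left, cache, assignments, locals),
                                    BuildTree(a.Right, cache, assignments, locals)),
            _ => throw new ArgumentException($"Unknown node {expr}")
        };

        // Store the node in a fresh local: localN = subTree
        var newVar = Expression.Variable(subTree.Type);
        assignments.Add(Expression.Assign(newVar, subTree));
        cache[expr] = newVar;
        locals.Add(newVar);
        return newVar;
    }
}
```

Because the second sin(x) is equal to the first one as a record, it hits the cache, and the block ends up with one local for sin(x) and one for the sum.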





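Note that the compiler itself knows nothing about particular operators or types: everything that turns a node into a Linq.Expression comes from the protocol. Does that mean the user has to write a protocol for every call?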





Default rules (assumptions)

Compile<TDelegate>(Entity, Type?, CompilationProtocol, IEnumerable<(Type, Variable)>) accepts any protocol, but the library also provides default rules, described below.





The default rules support the following types: bool, int, long, float, double, Complex and BigInteger.





ConstantConverter:





It converts an Entity constant into a Linq constant:





public static Expression ConverterConstant(Entity e)
	=> e switch
	{
		Number n => Expression.Constant(DownCast(n)),
		Entity.Boolean b => Expression.Constant((bool)b),
		_ => throw new AngouriBugException("Undefined constant type")
	};
      
      



Booleans are simply cast to bool.





DownCast converts an Entity.Number into the matching .NET type:





private static object DownCast(Number num)
{
	if (num is Integer)
		return (long)num;
	if (num is Real)
		return (double)num;
	if (num is Number.Complex)
		return (System.Numerics.Complex)num;
	throw new InvalidProtocolProvided("Undefined type, provide valid compilation protocol");
}
      
      



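DownCast returns object, and that is fine: Expression.Constant(object) takes the type of the constant from the runtime type of the value, so a boxed long becomes a constant of type long.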
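A quick way to check that boxing does not lose the type is to look at the Type of the resulting ConstantExpression. DownCastLike below is a hypothetical stand-in for DownCast that, like it, returns the number as object. Each arm is cast to object explicitly, so the switch does not pick Complex as a common type and convert everything to it.

```csharp
using System;
using System.Linq.Expressions;
using System.Numerics;

public static class ConstantDemo
{
    // Hypothetical stand-in for DownCast: returns the number boxed as object
    public static object DownCastLike(string kind) => kind switch
    {
        "integer" => (object)42L,
        "real" => (object)0.5,
        "complex" => (object)new Complex(1, 2),
        _ => throw new ArgumentException(kind)
    };

    // Expression.Constant(object) takes its Type from the runtime type of the value
    // (for a null value the Type is object)
    public static Type ConstantType(object value) => Expression.Constant(value).Type;
}
```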





UnaryNodeConverter:





It maps every one-argument node to the corresponding Linq.Expression:





public static Expression OneArgumentEntity(Expression e, Entity typeHolder)
  => typeHolder switch
	{
		Sinf =>         Expression.Call(GetDef("Sin", 1, e.Type), e),
		...
		Cosecantf =>    Expression.Call(GetDef("Csc", 1, e.Type), e),

		Arcsinf =>      Expression.Call(GetDef("Asin", 1, e.Type), e),
		...
		Arccosecantf => Expression.Call(GetDef("Acsc", 1, e.Type), e),

		Absf =>         Expression.Call(GetDef("Abs", 1, e.Type), e),
		Signumf =>      Expression.Call(GetDef("Sgn", 1, e.Type), e),

		Notf =>         Expression.Not(e),

		_ => throw new AngouriBugException("A node seems to be not added")
	};
      
      



(Some cases are omitted.) Apart from Notf, every case is a method call, and the method itself is found by GetDef.
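Notf is the one rule here that is not a method call, and Expression.Not means different things for different operand types: on bool it is a logical negation, on integer types it is a bitwise complement. The hypothetical helper below compiles x => Not(x) for any T to show both cases.

```csharp
using System;
using System.Linq.Expressions;

public static class NotDemo
{
    // Compiles x => Not(x); Expression.Not throws at construction for types without a Not operation
    public static Func<T, T> CompileNot<T>()
    {
        var x = Expression.Parameter(typeof(T), "x");
        return Expression.Lambda<Func<T, T>>(Expression.Not(x), x).Compile();
    }
}
```

For bool this gives `true -> false`. For int it gives the one's complement, so `5 -> -6`.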





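GetDef finds a method by its name, number of arguments and argument type. Such methods live in Math and Complex. One could pick between Math, Complex and BigInteger with a chain of ifs, but Math also lacks some overloads: for example, there is no int Pow(int, int), so that approach would not cover everything.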





So I wrote a class, MathAllMethods (generated with T4), that collects all the needed overloads in one place.





GetDef then simply looks the method up in MathAllMethods.
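MathAllMethods is generated, so it isn't shown here. To illustrate the gap it fills, here is a hypothetical GetDef that searches only Math and then Complex via reflection and requires an exact parameter-type match. Without that check, the default reflection binder would also accept widening, for instance returning Math.Sin(double) for an int argument. With only these two classes, there is no Csc and no int Pow(int, int) to find.

```csharp
using System;
using System.Linq;
using System.Numerics;
using System.Reflection;

public static class GetDefSketch
{
    // Hypothetical search order; the article uses a single T4-generated MathAllMethods class instead
    static readonly Type[] Sources = { typeof(Math), typeof(Complex) };

    // Finds a public static method `name` with `argCount` parameters, all exactly of `argType`
    public static MethodInfo GetDef(string name, int argCount, Type argType)
    {
        var argTypes = Enumerable.Repeat(argType, argCount).ToArray();
        foreach (var source in Sources)
        {
            var method = source.GetMethod(name, BindingFlags.Public | BindingFlags.Static,
                                          null, argTypes, null);
            // The default binder allows widening (e.g. int -> double), so check the types exactly
            if (method is not null && method.GetParameters().All(p => p.ParameterType == argType))
                return method;
        }
        throw new MissingMethodException($"No {name} with {argCount} argument(s) of type {argType}");
    }
}
```

The returned MethodInfo can be passed straight to Expression.Call, as the UnaryNodeConverter above does with GetDef.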





BinaryNodeConverter:





It converts a node with two arguments into a Linq.Expression:





public static Expression TwoArgumentEntity(Expression left, Expression right, Entity typeHolder)
{
	var typeToCastTo = MaxType(left.Type, right.Type);
	if (left.Type != typeToCastTo)
		left = Expression.Convert(left, typeToCastTo);
	if (right.Type != typeToCastTo)
		right = Expression.Convert(right, typeToCastTo);
	return typeHolder switch
	{
		Sumf => Expression.Add(left, right),
		...
		Andf => Expression.And(left, right),
		...
		Lessf => Expression.LessThan(left, right),
		...
		_ => throw new AngouriBugException("A node seems to be not added")
	};
}
      
      



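Before the operation is applied, both operands are upcast to a common type chosen by MaxType. To do this, each type is assigned a level: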





Complex:   10
double:     9
float:      8
long:       8
BigInteger: 8
int:        7
      
      



If the two types are the same, MaxType returns that type. For example, MaxType(int, int) -> int.





If type A is on a higher level than type B, B is cast to A. For example, MaxType(long, double) -> double.





If the types differ but are on the same level, both are cast to the type one level higher. For example, MaxType(long, float) -> double.
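The three rules fit in a few lines. Below is a sketch of MaxType built from the level table above. It is not the library's code, and the TypeAtLevel map is a hypothetical helper. Only level 8 holds several types, so the "one level higher" case only ever needs level 9, which is double.

```csharp
using System;
using System.Collections.Generic;
using System.Numerics;

public static class Upcast
{
    // Levels from the table above
    static readonly Dictionary<Type, int> Level = new()
    {
        [typeof(Complex)] = 10,
        [typeof(double)] = 9,
        [typeof(float)] = 8,
        [typeof(long)] = 8,
        [typeof(BigInteger)] = 8,
        [typeof(int)] = 7,
    };

    // Hypothetical helper: the type one level up when two different types share a level
    static readonly Dictionary<int, Type> TypeAtLevel = new() { [9] = typeof(double) };

    // Sketch of the three rules; throws KeyNotFoundException for types outside the table
    public static Type MaxType(Type a, Type b)
    {
        if (a == b)
            return a;                           // MaxType(int, int) -> int
        int levelA = Level[a], levelB = Level[b];
        if (levelA != levelB)
            return levelA > levelB ? a : b;     // MaxType(long, double) -> double
        return TypeAtLevel[levelA + 1];         // MaxType(long, float) -> double
    }
}
```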





Once both operands have the same type, the operation is chosen by the type of the node: Sumf becomes Expression.Add, Andf becomes Expression.And, and so on.
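The upcast is not optional: Expression.Add rejects an int and a double operand with an InvalidOperationException. The sketch below mirrors TwoArgumentEntity with the target type passed in explicitly. BinaryDemo and the string node names are hypothetical stand-ins for the real node types.

```csharp
using System;
using System.Linq.Expressions;

public static class BinaryDemo
{
    // Upcasts both operands to `target`, then picks the operation by node kind (hypothetical names)
    public static Expression Binary(string node, Expression left, Expression right, Type target)
    {
        if (left.Type != target) left = Expression.Convert(left, target);
        if (right.Type != target) right = Expression.Convert(right, target);
        return node switch
        {
            "Sumf" => Expression.Add(left, right),
            "Andf" => Expression.And(left, right),       // on bool operands: logical AND
            "Lessf" => Expression.LessThan(left, right), // result type is bool
            _ => throw new ArgumentException($"Unknown node {node}")
        };
    }
}
```

With an int and a double, `Binary("Sumf", i, d, typeof(double))` converts i to double first, which makes the Add valid.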





.





, . , .





Nice API

At this point, the API looks like this:





public TDelegate Compile<TDelegate>(CompilationProtocol protocol, Type returnType, IEnumerable<(Type type, Variable variable)> typesAndNames) where TDelegate : Delegate
      
      



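It works, but it is not convenient: the caller has to spell out the delegate type and pass the types and variables as a list of tuples. Generic overloads, where the argument types are type parameters, would be nicer. Writing one by hand for every number of arguments is tedious, so they are generated with a T4 Text Template. Here is one of them: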





// One of the generated overloads (for three arguments)
public Func<TIn1, TIn2, TIn3, TOut> Compile<TIn1, TIn2, TIn3, TOut>(Variable var1, Variable var2, Variable var3)
            // new() creates a CompilationProtocol with the default rules
            => IntoLinqCompiler.Compile<Func<TIn1, TIn2, TIn3, TOut>>(this, typeof(TOut), new(),
                new[] { (typeof(TIn1), var1), (typeof(TIn2), var2), (typeof(TIn3), var3) });
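The point of these overloads is that the generic arguments double as the parameter types and the return type, so the caller never writes typeof(...). The sketch below shows the same pattern in plain System.Linq.Expressions for two arguments. TypedCompile is a hypothetical name, and the body is supplied as a callback instead of being parsed from an Entity.

```csharp
using System;
using System.Linq.Expressions;

public static class TypedCompile
{
    // Hypothetical two-argument helper in the spirit of the generated overloads:
    // TIn1 and TIn2 become the parameter types, TOut the return type
    public static Func<TIn1, TIn2, TOut> Compile<TIn1, TIn2, TOut>(
        Func<ParameterExpression, ParameterExpression, Expression> buildBody)
    {
        var var1 = Expression.Parameter(typeof(TIn1), "var1");
        var var2 = Expression.Parameter(typeof(TIn2), "var2");
        Expression body = buildBody(var1, var2);
        // Plays the role of returnType = typeof(TOut); here we convert only when the types differ
        if (body.Type != typeof(TOut))
            body = Expression.Convert(body, typeof(TOut));
        return Expression.Lambda<Func<TIn1, TIn2, TOut>>(body, var1, var2).Compile();
    }
}
```

For example, `TypedCompile.Compile<int, int, long>((a, b) => Expression.Add(a, b))` yields a `Func<int, int, long>` whose int sum is converted to long.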
      
      



The Compile overloads for one to eight arguments are generated by this T4 template:
<# for (var i = 1; i <= 8; i++) { #>
        public Func<<# for(var t=1;t<=i;t++){ #>TIn<#= t #>, <# } #>TOut> Compile<<# for(var t=1;t<=i;t++){ #>TIn<#= t #>, <# } #>TOut>(Variable var1<# for(var t=2; t<=i; t++){ #>, Variable var<#= t #><# } #>)
            => IntoLinqCompiler.Compile<Func<<# for(var t=1;t<=i;t++){ #>TIn<#= t #>, <# } #>TOut>>(this, typeof(TOut), new(), 
                new[] { (typeof(TIn1), var1)<# for(var t=2;t<=i;t++){ #>, (typeof(TIn<#= t #>), var<#= t #>) <# } #> });
<# } #>
      
      



The extension methods for strings are generated the same way, for example:





public static Func<TIn1, TIn2, TOut> Compile<TIn1, TIn2, TOut>(this string @this, Variable var1, Variable var2)
	=> IntoLinqCompiler.Compile<Func<TIn1, TIn2, TOut>>(@this, typeof(TOut), new(), 
		new[] { (typeof(TIn1), var1), (typeof(TIn2), var2)  });
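To see what one iteration of the T4 loop produces, the signature line can be rebuilt with ordinary string operations. OverloadSignature is a hypothetical helper that mirrors the template's two inner loops: one emits "TIn1, TIn2, ..." with a trailing ", " before TOut, and the other emits "Variable var1, Variable var2, ...".

```csharp
using System.Linq;

public static class OverloadSignature
{
    // Builds the same signature line as one iteration of the T4 loop, for i arguments
    public static string For(int i)
    {
        // "TIn1, TIn2, ..., " (each followed by ", ", as in the template)
        var typeParams = string.Concat(Enumerable.Range(1, i).Select(t => $"TIn{t}, "));
        // "Variable var1, Variable var2, ..."
        var parameters = string.Join(", ", Enumerable.Range(1, i).Select(t => $"Variable var{t}"));
        return $"public Func<{typeParams}TOut> Compile<{typeParams}TOut>({parameters})";
    }
}
```

For i = 2 this gives `public Func<TIn1, TIn2, TOut> Compile<TIn1, TIn2, TOut>(Variable var1, Variable var2)`.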
      
      



Performance

The benchmarks compare four cases:





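BenchNormalSimple - a simple expression, written directly in C#.

BenchMySimple - the same simple expression, compiled with the compiler described above.

BenchNormalComplicated - a complicated expression, written directly in C#.

BenchMyComplicated - the same complicated expression, compiled with the compiler described above.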





|                 Method |       Mean |    Error |   StdDev |
|----------------------- |-----------:|---------:|---------:|
|      BenchNormalSimple |   189.1 ns |  3.75 ns |  5.83 ns |
|          BenchMySimple |   195.7 ns |  3.92 ns |  5.50 ns |
| BenchNormalComplicated | 1,383.0 ns | 26.82 ns | 35.80 ns |
|     BenchMyComplicated |   293.6 ns |  5.74 ns |  8.77 ns |
      
      



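On the simple expression, the compiled function is about as fast as the hand-written one (195.7 ns vs 189.1 ns). On the complicated one, it is almost five times faster (293.6 ns vs 1,383.0 ns), most likely because repeated subexpressions are computed only once.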
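The effect of caching can be isolated with a small experiment. Compiled expression trees do not merge repeated nodes on their own: if the same call node appears twice in a tree, it is executed twice. The sketch below counts calls to a stand-in "expensive" function (a hypothetical Expensive method) in f(x) + f(x), built once naively and once with the result stored in a local, as BuildTree does.

```csharp
using System;
using System.Linq.Expressions;
using System.Reflection;

public static class CachingDemo
{
    public static int Calls;

    // Stand-in for an expensive function; counts how often it runs
    public static double Expensive(double x) { Calls++; return Math.Sin(x); }

    static readonly MethodInfo ExpensiveMethod = typeof(CachingDemo).GetMethod(nameof(Expensive))!;

    // f(x) + f(x) with the call node used twice: the call runs twice
    public static Func<double, double> Naive()
    {
        var x = Expression.Parameter(typeof(double), "x");
        var call = Expression.Call(ExpensiveMethod, x);
        return Expression.Lambda<Func<double, double>>(Expression.Add(call, call), x).Compile();
    }

    // Same expression with the call stored in a local first: the call runs once
    public static Func<double, double> Cached()
    {
        var x = Expression.Parameter(typeof(double), "x");
        var local = Expression.Variable(typeof(double), "var1");
        var body = Expression.Block(new[] { local },
            Expression.Assign(local, Expression.Call(ExpensiveMethod, x)),
            Expression.Add(local, local));
        return Expression.Lambda<Func<double, double>>(body, x).Compile();
    }
}
```

Both delegates return the same value. Only the number of calls differs, which is exactly what the caching in BuildTree saves on expressions with repeated subexpressions.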





Usage examples





var func = "sin(x)".Compile<double, double>("x");
Console.WriteLine(func(Math.PI / 2));
>>> 1

var func1 = "a > b".Compile<float, int, bool>("a", "b");
Console.WriteLine(func1(5.4f, 4));
Console.WriteLine(func1(4f, 4));
>>> True
>>> False

var cr = new CompilationProtocol()
{ 
    ConstantConverter = ent => Expression.Constant(ent.ToString()),
    BinaryNodeConverter = (a, b, t) => t switch
    {
        Sumf => Expression.Call(typeof(string)
            .GetMethod("Concat", new[] { typeof(string), typeof(string) }) ?? throw new Exception(), a, b),
        _ => throw new Exception()
    }
};
var func2 = "a + b + c + 1234"
    .Compile<Func<string, string, string, string>>(
        cr, typeof(string), 
        
        new[] { 
            (typeof(string), Var("a")), 
            (typeof(string), Var("b")), 
            (typeof(string), Var("c")) }

        );
Console.WriteLine(func2("White", "Black", "Goose"));
>>> WhiteBlackGoose1234
      
      



( - , , . , ).





  1. Linq.Expression



    .





  2. , , .





  3. , , .





  4. , , . , .





, - . , - . , , ( ).





Thanks for reading! Links:





  1. GitHub of the AngouriMath project, within which I implemented the compilation





  2. Compilation here





  3. Compilation tests can be found here







