// Global spacing used by layout():
//   treeNodeStep      - vertical gap between neighbouring sibling subtrees;
//   treeLevelStep     - horizontal distance between consecutive tree levels;
//   treeMinNodeHeight - minimum vertical extent reserved for a leaf.
real treeNodeStep = 0.3cm,
     treeLevelStep = 2.2cm,
     treeMinNodeHeight = 0.6cm;
// One node of a left-to-right tree diagram.
struct TreeNode
{
  // Tree links: the enclosing node (left null for the root by makeNode)
  // and the ordered child list (layout places children[0] topmost).
  TreeNode parent;
  TreeNode[] children;

  // Payload: the node's rendered frame, the text written along the edge to
  // the parent (an '@' in it requests an arrowhead, see drawAll), and the
  // side of the edge that text goes on (true = N/above, false = S/below).
  frame content;
  string prob;
  bool labelpos;

  // Layout state: position (set by layout, made absolute by drawAll) and the
  // vertical offset drawAll adds to the parent's y.
  pair pos;
  real adjust;
}

// Link child under parent: append it to parent's child list and record the
// back-reference. Any previous parent of child is not updated.
void add( TreeNode child, TreeNode parent )
{
  parent.children.push( child );
  child.parent = parent;
}

// Build a node that displays frame f, carries edge text p and label side l
// (true = above the edge), and attach it under parent when one is given.
// Returns the new node.
TreeNode makeNode( TreeNode parent = null, frame f, string p, bool l )
{
  TreeNode node = new TreeNode;
  node.labelpos = l;
  node.prob = p;
  node.content = f;
  if( parent == null ) return node;
  add( node, parent );
  return node;
}

// Convenience overload: render label into a fresh frame, then delegate to
// the frame-based makeNode with the same parent, edge text and label side.
TreeNode makeNode( TreeNode parent = null, Label label, string p, bool l )
{
  frame rendered;
  label( rendered, label );
  return makeNode( parent, rendered, p, l );
}

// Recursively lay out the subtree rooted at node; return its vertical extent.
//
// For every child of node this sets:
//   - child.pos: x = level*treeLevelStep; y = the child's centre when the
//     children are stacked upward starting from the last child (so
//     children[0] ends up topmost), separated by treeNodeStep;
//   - child.adjust: minus the midpoint of the lowest and highest child
//     centres, which drawAll adds so the children straddle the parent.
//
// A leaf reserves at least treeMinNodeHeight; an inner node reserves the
// larger of its own content height and the stacked children (gaps included).
real layout( int level, TreeNode node )
{
  real ownHeight = (max(node.content)-min(node.content)).y;
  int nKids = node.children.length;
  if( nKids == 0 ) return max( treeMinNodeHeight, ownHeight );

  real[] sizes = new real[nKids];
  real stacked = 0;      // total height of the children placed so far
  real highest = 0;
  real lowest = 1e10;
  for( int k = nKids-1; k >= 0; --k )
    {
    TreeNode kid = node.children[k];
    sizes[k] = layout( level+1, kid );
    kid.pos = (level*treeLevelStep, stacked + sizes[k]/2);
    highest = max(highest, kid.pos.y);
    lowest = min(lowest, kid.pos.y);
    stacked += sizes[k] + treeNodeStep;
    }

  real centre = (highest+lowest)/2;
  for( TreeNode kid : node.children ) kid.adjust = -centre;

  return max( ownHeight, sum(sizes)+treeNodeStep*(sizes.length-1) );
}

// Finalise node's absolute position, add its content to frame f, draw the
// edge back to its parent, then recurse into the children (preorder, so a
// parent's pos is already absolute when its children read it).
//
// Positions are finalised by mutating node.pos and node.content in place,
// so this is meant to run once per tree.
//
// node.prob conventions:
//   - an '@' anywhere in it requests an arrowhead at the child end of the
//     edge (BeginArrow on a path that starts at the child); every '@' is
//     stripped from the printed text;
//   - the remaining text, if non-empty, is written along the edge: north of
//     it when node.labelpos is true, south otherwise.
void drawAll( TreeNode node, frame f )
{
  pair pos;
  if( node.parent != null ) pos = (0,node.parent.pos.y+node.adjust);
  else pos = (0,node.adjust);
  node.pos += pos;
  node.content = shift(node.pos)*node.content;
  add( f, node.content );

  if( node.parent != null )
    {
    string proba=node.prob;
    bool arrow=(find(proba,"@",0)!=-1);
    if (arrow) proba=replace(proba,"@","");
    path p = point(node.content, W)--point(node.parent.content,E);

    // Stroke each edge exactly once. Previously the edge was drawn a second
    // time, either by a duplicated unlabelled draw or by draw(Label, p, ...)
    // without the arrowhead, which re-stroked the line over the arrow.
    arrowbar head = None;
    if (arrow) head = BeginArrow(2mm);

    // Test the '@'-stripped text so a bare "@" does not emit an empty label.
    if( proba != "" )
      {
      pair side = node.labelpos ? N : S;
      draw(Label(scale(#SCALE#)*proba), p, side, currentpen, head);
      }
    else draw(p, currentpen, head);
    }

  for( int i = 0; i < node.children.length; ++i ) drawAll( node.children[i], f );
}

// Lay out the tree rooted at root (root at the origin, first child level
// one treeLevelStep to the right), render it, and add the node frame to the
// current picture at pos.
//
// NOTE(review): drawAll strokes edges with draw(p, ...) directly onto the
// current picture, while nodes go into the local frame placed at pos; a
// non-zero pos therefore looks like it moves nodes but not edges - confirm.
// NOTE(review): drawAll shifts node.pos/node.content in place, so drawing
// the same tree twice appears unsafe - confirm before reusing trees.
void draw( TreeNode root, pair pos )
{
  root.pos = (0,0);
  layout( 1, root );
  frame canvas;
  drawAll( root, canvas );
  add( canvas, pos );
}
