#!/usr/bin/env python3

import os
import sys
import time
import numpy
import argparse
import subprocess
import multiprocessing

from collections.abc import Iterable
from scipy.spatial.transform import Rotation, Slerp

ROOT = os.path.abspath(os.path.dirname(__file__))


def input_file(fname):
  """Return the absolute path of *fname* inside the script directory."""
  return os.path.join(ROOT, fname)


def exout_file(fname):
  """Return the path of *fname* inside the ``output_pysp`` directory next to this script."""
  return os.path.join(ROOT, 'output_pysp', fname)


# Degrees -> radians conversion factor.
GRAD2RAD = numpy.pi/180.0

# All 16 per-component sign patterns for a 4-component quaternion, ordered
# like a 4-bit counter where +1.0 comes before -1.0 and the last component
# changes fastest.
QUAT_SIGNS = tuple((s0, s1, s2, s3)
                   for s0 in (+1.0, -1.0)
                   for s1 in (+1.0, -1.0)
                   for s2 in (+1.0, -1.0)
                   for s3 in (+1.0, -1.0))


def axis(apba_dir, trsm_dir, return_fail=False, *, tol=1e-8):
  """Build an orthonormal frame from an apico-basal and a transmural direction.

  The frame matrix is assembled row-wise as ``[e0, e1, e2]``:
    * ``e1`` -- normalized ``apba_dir``;
    * ``e2`` -- component of ``trsm_dir`` orthogonal to ``e1`` (Gram-Schmidt),
      normalized;
    * ``e0`` -- ``e1 x e2``, so the matrix is a proper rotation.

  Args:
    apba_dir: 3-vector (numpy array), apico-basal direction.
    trsm_dir: 3-vector (numpy array), transmural direction.
    return_fail: if True, also return the failure flag.
    tol: the frame is considered degenerate when the orthogonalised
      transmural vector is shorter than ``tol * |trsm_dir|`` (i.e. the two
      directions are parallel up to rounding).

  Returns:
    A ``Rotation`` built from the frame matrix, or the identity rotation when
    the frame is degenerate (zero-length ``apba_dir``/``trsm_dir``, parallel
    directions, or non-finite values). With ``return_fail=True`` the tuple
    ``(rotation, fail)`` is returned, ``fail`` being True in the degenerate case.
  """
  mat, fail = numpy.identity(3), True

  norm_apba = numpy.linalg.norm(apba_dir)
  norm_trsm = numpy.linalg.norm(trsm_dir)
  # NaN norms make these comparisons False, so non-finite input also fails.
  if norm_apba > 0.0 and norm_trsm > 0.0:
    e1 = apba_dir/norm_apba
    # Gram-Schmidt: remove the apico-basal component from the transmural
    # direction. This only degenerates when the directions are parallel
    # (tmp ~ 0); orthogonal directions are the well-conditioned case.
    tmp = trsm_dir - e1.dot(trsm_dir)*e1
    norm_tmp = numpy.linalg.norm(tmp)
    if norm_tmp > tol*norm_trsm:
      e2 = tmp/norm_tmp
      e0 = numpy.cross(e1, e2)
      mat = numpy.array([e0, e1, e2])
      fail = False

  rot = Rotation.from_matrix(mat)
  if return_fail:
    return rot, fail

  return rot


def orient(rot:Rotation, alpha, beta):
  """Re-orient the frame *rot* by the angles *alpha* and *beta* (radians).

  The result is the rotation whose matrix equals
  ``Rx(-beta) @ Rz(-alpha) @ rot.as_matrix()``, where ``Rz``/``Rx`` are the
  elementary rotations about the z/x axes. Since the frame vectors are the
  matrix rows, ``Rz`` mixes the first two rows and ``Rx`` the last two.
  """
  neg_alpha = -alpha
  cos_a, sin_a = numpy.cos(neg_alpha), numpy.sin(neg_alpha)
  cos_b, sin_b = numpy.cos(beta), numpy.sin(beta)

  about_z = Rotation.from_matrix([[cos_a, -sin_a, 0.0],
                                  [sin_a,  cos_a, 0.0],
                                  [0.0,    0.0,   1.0]])
  about_x = Rotation.from_matrix([[1.0,  0.0,   0.0],
                                  [0.0,  cos_b, sin_b],
                                  [0.0, -sin_b, cos_b]])

  # Composition p*q applies q first, i.e. matrix p @ q.
  tilted = about_z*rot
  return about_x*tilted


def bislerp(rotA:Rotation, rotB:Rotation, t):
  qA, qB = rotA.as_quat(scalar_first=True), rotB.as_quat(scalar_first=True)

  qM, max_norm = qA, 0.0
  for sign in QUAT_SIGNS:
    qT = (sign[0]*qA[0], sign[1]*qA[1], sign[2]*qA[2], sign[3]*qA[3])
    qP = (qT[0]*qB[0] - qT[1]*qB[1] - qT[2]*qB[2] - qT[3]*qB[3],
          qT[0]*qB[1] + qT[1]*qB[0] + qT[2]*qB[3] - qT[3]*qB[2],
          qT[0]*qB[2] - qT[1]*qB[3] + qT[2]*qB[0] + qT[3]*qB[1],
          qT[0]*qB[3] + qT[1]*qB[2] - qT[2]*qB[1] + qT[3]*qB[0])
    norm = numpy.sqrt(qP[0]*qP[0] + qP[1]*qP[1] + qP[2]*qP[2] + qP[3]*qP[3])
    if norm > max_norm:
      qM, max_norm = qT, norm

  rotM = Rotation.from_quat(qM, scalar_first=True)
  slerp = Slerp((0.0, 1.0), Rotation.concatenate([rotM, rotB]))

  return slerp(t)


def assign_fibers_sheets(phi_epi, phi_lv, phi_rv, psi_ab,
                         grad_phi_epi, grad_phi_lv, grad_phi_rv, grad_psi_ab,
                         alpha_endo, alpha_epi, beta_endo, beta_epi,
                         num_workers=16):
  """Compute per-entry fiber and sheet directions from scalar fields and their gradients.

  For every index ``i`` local frames are built with ``axis`` from the
  apico-basal gradient ``grad_psi_ab[i]`` and the LV (negated), RV and epi
  transmural gradients, rotated with ``orient`` by the prescribed angles and
  blended with ``bislerp``:
    * no usable LV/RV frame -> the epi frame alone (if usable);
    * otherwise a septal frame (LV, RV or their bislerp blend at
      ``d = phi_rv/(phi_lv+phi_rv)``), blended with the epi frame at
      ``d = phi_epi`` when that one is usable.

  Args:
    phi_epi, phi_lv, phi_rv, psi_ab: 1-D arrays of length N.
    grad_phi_epi, grad_phi_lv, grad_phi_rv, grad_psi_ab: (N, 3) arrays.
    alpha_endo, alpha_epi, beta_endo, beta_epi: angles in degrees
      (converted with ``GRAD2RAD``).
    num_workers: number of worker processes; clamped to ``[1, N]``.

  Returns:
    float32 array of shape (N, 2, 3): ``[:, 0]`` and ``[:, 1]`` are the first
    and second rows of the final frame matrix. Entries whose frame could not
    be computed are left as zeros (their count is reported on stderr).

  Raises:
    RuntimeError: if a worker process exits with a non-zero status.

  NOTE(review): ``assign_fs`` is a closure used as a ``multiprocessing.Process``
  target; this presumably requires the 'fork' start method, since closures
  cannot be pickled under 'spawn'.
  """

  assert (phi_epi.ndim == 1) and (grad_phi_epi.ndim == 2) and  (grad_phi_epi.shape[1] == 3) and \
         (phi_epi.size == grad_phi_epi.shape[0]) and (phi_epi.size == phi_lv.size) and \
         (phi_epi.size == phi_rv.size) and (phi_epi.size == psi_ab.size) and \
         (grad_phi_epi.shape == grad_phi_lv.shape) and (grad_phi_epi.shape == grad_phi_rv.shape) and \
         (grad_phi_epi.shape == grad_psi_ab.shape)

  # Angle laws (degrees -> radians): *_s for the septal blend, *_w for the wall.
  alpha_s = lambda d: (alpha_endo*(1.0-d) - alpha_endo*d)*GRAD2RAD
  alpha_w = lambda d: (alpha_endo*(1.0-d) + alpha_epi*d)*GRAD2RAD
  beta_s = lambda d: (beta_endo*(1.0-d) - beta_endo*d)*GRAD2RAD
  beta_w = lambda d: (beta_endo*(1.0-d) + beta_epi*d)*GRAD2RAD

  num = phi_epi.shape[0]
  # Avoid a ZeroDivisionError for num_workers <= 0 and idle extra processes.
  num_workers = max(1, min(int(num_workers), num))

  def frame_at(i):
    """Return the blended frame ``Rotation`` for index i, or None if no axis is usable."""
    apba_dir = grad_psi_ab[i]
    ax_l, fail_l = axis(apba_dir, -grad_phi_lv[i], return_fail=True)
    ax_r, fail_r = axis(apba_dir, +grad_phi_rv[i], return_fail=True)
    ax_e, fail_e = axis(apba_dir, +grad_phi_epi[i], return_fail=True)

    if fail_l and fail_r:
      # No usable ventricular frame: fall back to the epicardial one alone.
      if fail_e:
        return None
      d = phi_epi[i]
      return orient(ax_e, alpha_w(d), beta_w(d))

    denom = phi_lv[i] + phi_rv[i]
    d = phi_rv[i]/denom if denom > 0.0 else 0.0
    if fail_l:
      rotV = orient(ax_r, alpha_s(d), beta_s(d))
    elif fail_r:
      rotV = orient(ax_l, alpha_s(d), beta_s(d))
    else:
      rotVl = orient(ax_l, alpha_s(d), beta_s(d))
      rotVr = orient(ax_r, alpha_s(d), beta_s(d))
      rotV = bislerp(rotVl, rotVr, d)

    if fail_e:
      return rotV
    d = phi_epi[i]
    rotE = orient(ax_e, alpha_w(d), beta_w(d))
    return bislerp(rotV, rotE, d)

  def assign_fs(shared_array, start, end):
    """Worker: fill entries [start, end) of the shared float32 output buffer."""
    arr = numpy.frombuffer(shared_array.get_obj(), dtype=numpy.float32)
    count_fails = 0
    for i in range(start, end):
      try:
        rotFST = frame_at(i)
        matFST = None if rotFST is None else rotFST.as_matrix()
      except Exception:
        # Best effort per point: a failing point is zeroed and counted.
        matFST = None
      if matFST is None:
        count_fails += 1
        matFST = numpy.zeros((3, 3), dtype=numpy.float32)
      arr[6*i:(6*i+3)] = matFST[0]
      arr[(6*i+3):(6*i+6)] = matFST[1]
    if count_fails > 0:
      sys.stderr.write('assign_fibers_sheets: {} of {} entries in [{}, {}) could not be assigned\n'
                       .format(count_fails, end-start, start, end))

  # Zero-initialized shared output: 2 rows x 3 components per entry.
  shared_array = multiprocessing.Array('f', 6*num)
  chunk_size = num // num_workers
  processes = []

  for w in range(num_workers):
    start = w*chunk_size
    # The last process also covers the remainder.
    end = num if w == (num_workers-1) else (w+1)*chunk_size
    p = multiprocessing.Process(target=assign_fs, args=(shared_array, start, end))
    p.start()
    processes.append(p)

  for p in processes:
    p.join()

  failed = [p.exitcode for p in processes if p.exitcode != 0]
  if failed:
    raise RuntimeError('assign_fibers_sheets: {} worker process(es) failed (exit codes {})'
                       .format(len(failed), failed))

  # View the shared buffer as a numpy array for use in the main process.
  lon = numpy.frombuffer(shared_array.get_obj(), dtype=numpy.float32)

  return lon.reshape(num, 2, 3)


def write_vtx(filename, vtx, domain='extra'):
  """Write vertex indices to a ``.vtx`` file.

  Layout: first line the number of indices, second line *domain*, then the
  sorted indices one per line.

  Args:
    filename: output path.
    vtx: array-like of integer indices (any shape; it is flattened).
    domain: text of the second header line, ``'extra'`` by default.
  """
  npvtx = numpy.sort(numpy.asarray(vtx, dtype=numpy.int64).ravel())
  # Count the flattened indices actually written; len(vtx) would only count
  # the rows of a multi-dimensional input.
  hdr = '{}\n{}'.format(npvtx.size, domain)
  numpy.savetxt(filename, npvtx, fmt='%d', header=hdr, comments='')


def read_vtx_from_tri_surf(file_name):
  """Return the sorted unique vertex indices referenced by a triangle ``.surf`` file.

  The first line is a header and is skipped; columns 1-3 of every following
  line (e.g. ``Tr i j k``) are read as int64 vertex indices.
  Returns None if *file_name* does not end in ``.surf`` or does not exist.
  """
  is_surf = file_name.endswith('.surf')
  if not os.path.exists(file_name) or not is_surf:
    return None
  triangles = numpy.loadtxt(file_name, dtype=numpy.int64, skiprows=1, usecols=(1, 2, 3))
  return numpy.unique(triangles)


def write_dummy_tri_surf_from_vtx(file_name, vtx):
  """Write a vertex set as a degenerate ("dummy") triangle ``.surf`` file.

  The first line holds the number of triangles; every unique vertex ``v``
  becomes one collapsed triangle ``Tr v v v``.

  Args:
    file_name: output path.
    vtx: array-like of vertex indices (duplicates are removed).
  """
  vtx = numpy.unique(vtx)
  lines = ['{}\n'.format(len(vtx))]
  lines.extend('Tr {0} {0} {0}\n'.format(v) for v in vtx)
  # The context manager closes the file even if a write fails.
  with open(file_name, 'w') as fp:
    fp.writelines(lines)


def _read_carp_header_count(file_name, txt_ext, bin_ext):
  """Return the leading count of a CARP text (*txt_ext*) or binary (*bin_ext*) file, else -1."""
  if file_name.endswith(txt_ext):
    with open(file_name, 'r') as fp:
      header = fp.readline().encode()
  elif file_name.endswith(bin_ext):
    with open(file_name, 'rb') as fp:
      # Binary files start with a 1024-byte header block; only the text
      # before the first NUL byte is meaningful and the padding need not be
      # valid UTF-8.
      header = fp.read(1024).split(b'\x00', 1)[0]
  else:
    return -1
  tokens = header.split()
  if not tokens:
    raise ValueError('missing count in header of {}'.format(file_name))
  return int(tokens[0])


def read_num_points(file_name):
  """Return the point count stored in the header of a ``.pts``/``.bpts`` file.

  Returns -1 for any other file extension.
  """
  return _read_carp_header_count(file_name, '.pts', '.bpts')


def read_num_elements(file_name):
  """Return the element count stored in the header of an ``.elem``/``.belem`` file.

  Returns -1 for any other file extension.
  """
  return _read_carp_header_count(file_name, '.elem', '.belem')


def write_fibers(file_name, fibers, txt_float_fmt='%.6f'):
  """Write per-element direction vectors to a ``.lon`` (text) or ``.blon`` (binary) file.

  *fibers* is converted to float32 and reshaped to ``(num_elem, 3*num_fib)``.
  A ``.blon`` file gets a 1024-byte NUL-padded header
  ``"<num_fib> <num_elem> <endianness> 666"`` (endianness 0 = little,
  1 = big) followed by the raw float32 data. Any other name is written as
  text with ``txt_float_fmt`` and a single header line ``<num_fib>``.
  """
  data = numpy.array(fibers, dtype=numpy.float32)
  n_elem = data.shape[0]
  data = data.reshape(n_elem, -1)
  n_fib = data.shape[1] // 3

  if not file_name.endswith('.blon'):
    numpy.savetxt(file_name, data, fmt=txt_float_fmt, header=str(n_fib), comments='')
    return

  byte_order_flag = 1 if sys.byteorder != 'little' else 0
  header = '{:d} {:d} {:d} {:d}'.format(n_fib, n_elem, byte_order_flag, 666).encode('utf-8')
  padded = bytearray(1024)
  padded[:len(header)] = header
  with open(file_name, 'wb') as fp:
    fp.write(bytes(padded))
    fp.write(data.tobytes())


class Mapping:
  """Index mapping between a sub-mesh and its parent mesh.

  ``<base_name>.nod`` and ``<base_name>.eidx`` are read as raw int64 arrays;
  entry ``k`` is used as the parent-mesh index of sub-mesh node (resp.
  element) ``k`` by the methods below.
  NOTE(review): int64 is assumed for both files -- confirm it matches the
  integer width of the meshtool build that writes them.
  """

  def __init__(self, base_name):
    nod_file = '{}.nod'.format(base_name)
    self._node_s2m = numpy.fromfile(nod_file, dtype=numpy.int64)
    eidx_file = '{}.eidx'.format(base_name)
    self._elem_s2m = numpy.fromfile(eidx_file, dtype=numpy.int64)

  def _s2m(self, nodal):
    """Return the sub-to-parent index array for nodes (``nodal``) or elements."""
    return self._node_s2m if nodal else self._elem_s2m

  def _scatter(self, sub_data, data, nodal):
    """Set ``data[s2m[k]] = sub_data[k]`` for every mapped index ``k`` and return *data*.

    Raises:
      IndexError: if *sub_data* has fewer entries than the mapping.
    """
    s2m = self._s2m(nodal)
    if len(sub_data) < s2m.size:
      raise IndexError('sub-mesh data has {} entries but the mapping has {}'
                       .format(len(sub_data), s2m.size))
    # Vectorized scatter (the mapping is run on whole meshes, a Python loop
    # is slow); surplus trailing entries of sub_data are ignored.
    data[s2m] = sub_data[:s2m.size]
    return data

  def prolongate(self, sub_data, data_size, *, nodal_data=True):
    """Return a zero-filled parent-mesh array of length *data_size* holding *sub_data*.

    The result keeps the dtype and trailing dimensions of *sub_data*.
    """
    data_shape = list(sub_data.shape)
    data_shape[0] = data_size
    data = numpy.zeros(data_shape, dtype=sub_data.dtype)
    return self._scatter(sub_data, data, nodal_data)

  def restrict(self, data, *, nodal_data=None, nodel_data=True):
    """Gather parent-mesh *data* (along axis 0) onto the sub-mesh.

    ``nodal_data`` selects node (True) or element (False) indices. The
    misspelled ``nodel_data`` keyword is kept for backward compatibility and
    is used only when ``nodal_data`` is not given.
    """
    nodal = nodel_data if nodal_data is None else nodal_data
    return numpy.take(data, self._s2m(nodal), axis=0)

  def insert_data(self, sub_data, data, *, nodal_data=True):
    """Write *sub_data* into the parent-mesh array *data* in place and return it."""
    assert (sub_data.ndim == data.ndim) and \
           ((sub_data.ndim <= 1) or (sub_data.shape[1:] == data.shape[1:])) and \
           (sub_data.dtype == data.dtype)

    return self._scatter(sub_data, data, nodal_data)

  def map_selection(self, sel, *, nodal_selection=True, map_forward=True):
    """Map sub-mesh indices *sel* to parent-mesh indices (only ``map_forward=False`` is supported)."""
    if map_forward:
      raise NotImplementedError('Mapping.map_selection(map_forward=True) not implemented yet !')
    return numpy.take(self._s2m(nodal_selection), sel)


def print_cmd(cmd):
  """Echo a command given as a list of arguments, framed by blank lines."""
  line = ' '.join(cmd)
  print('\n{}\n'.format(line))


# Running total of meshtool invocations; each mt_* wrapper increments it.
COUNT_MESHTOOL_CALLS = 0


def mt_extract_elemtags(mesh_base, mesh_fmt, output_file, *, stdout=None):
  """Run ``meshtool extract tags`` on a mesh.

  Args:
    mesh_base: mesh basename (``-msh``).
    mesh_fmt: mesh input format (``-ifmt``).
    output_file: output data file (``-odat``).
    stdout: optional destination for meshtool's output (stderr is merged into it).

  Raises:
    subprocess.CalledProcessError: if meshtool exits with a non-zero status.
  """
  global COUNT_MESHTOOL_CALLS

  options = (('msh', mesh_base), ('ifmt', mesh_fmt), ('odat', output_file))
  cmd = ['meshtool', 'extract', 'tags'] + ['-{}={}'.format(key, val) for key, val in options]
  print_cmd(cmd)
  COUNT_MESHTOOL_CALLS += 1
  subprocess.check_call(cmd, stdout=stdout, stderr=subprocess.STDOUT)


def mt_extract_surfaces(mesh_base, mesh_fmt, surf_ops, *, smesh_fmt=None, etags_file=None, stdout=None):
  """Run ``meshtool extract surface`` for several surfaces in one call.

  Args:
    mesh_base: mesh basename (``-msh``).
    mesh_fmt: mesh input format (``-ifmt``).
    surf_ops: iterable of ``(surf_name, operation)`` pairs; operations are
      joined with ``;`` and the output names ``<mesh_base>_<surf_name>`` with ``,``.
    smesh_fmt: output format (``-ofmt``), defaults to *mesh_fmt*.
    etags_file: optional element-tag file (``-tag_file``).
    stdout: optional destination for meshtool output (stderr is merged into it).

  Raises:
    subprocess.CalledProcessError: if meshtool exits with a non-zero status.
  """
  global COUNT_MESHTOOL_CALLS

  pairs = list(surf_ops)
  op_str = ';'.join(op for _, op in pairs)
  surf_str = ','.join('{}_{}'.format(mesh_base, name) for name, _ in pairs)
  out_fmt = mesh_fmt if smesh_fmt is None else smesh_fmt

  cmd = ['meshtool', 'extract', 'surface',
         '-msh={}'.format(mesh_base),
         '-ifmt={}'.format(mesh_fmt),
         '-op={}'.format(op_str),
         '-surf={}'.format(surf_str)]
  if out_fmt is not None:
    cmd.append('-ofmt={}'.format(out_fmt))
  if etags_file is not None:
    cmd.append('-tag_file={}'.format(etags_file))
  print_cmd(cmd)
  COUNT_MESHTOOL_CALLS += 1
  subprocess.check_call(cmd, stdout=stdout, stderr=subprocess.STDOUT)


def mt_map_surfaces(mesh_base, smesh_base, *, files=None, stdout=None):
  """Run ``meshtool map`` to map ``.vtx``/``.surf`` files onto a sub-mesh.

  Args:
    mesh_base: parent mesh basename; input files are taken from its directory.
    smesh_base: sub-mesh basename (``-submsh``); outputs go to its directory.
    files: None to map every ``*.vtx``/``*.surf``; a single file stem (str);
      or an iterable of stems. Each stem ``s`` maps ``<dir>/s.vtx`` and
      ``<dir>/s.surf``.
    stdout: optional destination for meshtool output (stderr is merged into it).

  Raises:
    subprocess.CalledProcessError: if meshtool exits with a non-zero status.
  """
  global COUNT_MESHTOOL_CALLS

  mesh_dir = os.path.dirname(mesh_base)

  # A str is itself Iterable and would otherwise be split into characters.
  if isinstance(files, str):
    files = [files]
  if files is not None and isinstance(files, Iterable):
    files2map = ','.join('{0}.vtx,{0}.surf'.format(os.path.join(mesh_dir, file)) for file in files)
  else:
    files2map = '{0}.vtx,{0}.surf'.format(os.path.join(mesh_dir, '*'))
  cmd = ['meshtool', 'map',
        '-submsh={}'.format(smesh_base),
        '-files={}'.format(files2map),
        '-outdir={}'.format(os.path.dirname(smesh_base))]
  print_cmd(cmd)
  COUNT_MESHTOOL_CALLS += 1
  subprocess.check_call(cmd, stdout=stdout, stderr=subprocess.STDOUT)


def mt_extract_mesh_and_map_surfaces(mesh_base, mesh_fmt, tags_str, smesh_base, *,
                                     smesh_fmt=None, files=None, return_mapping=False,
                                     stdout=None):
  """Extract a tagged sub-mesh with meshtool and map surface files onto it.

  Runs ``meshtool extract mesh -tags=<tags_str>`` and then ``mt_map_surfaces``
  with the same *files* argument.

  Args:
    mesh_base: parent mesh basename.
    mesh_fmt: parent mesh format.
    tags_str: element tags to extract (``-tags``).
    smesh_base: sub-mesh basename.
    smesh_fmt: sub-mesh output format, defaults to *mesh_fmt*.
    files: which ``.vtx``/``.surf`` files to map (see ``mt_map_surfaces``).
    return_mapping: if True, return a ``Mapping`` for the new sub-mesh.
    stdout: optional destination for meshtool output (stderr is merged into it).

  Returns:
    ``Mapping(smesh_base)`` if *return_mapping* is True, else None.
  """
  global COUNT_MESHTOOL_CALLS

  if smesh_fmt is None:
    smesh_fmt = mesh_fmt
  cmd = ['meshtool', 'extract', 'mesh',
         '-msh={}'.format(mesh_base),
         '-ifmt={}'.format(mesh_fmt),
         '-tags={}'.format(tags_str),
         '-submsh={}'.format(smesh_base),
         '-ofmt={}'.format(smesh_fmt)]
  print_cmd(cmd)
  COUNT_MESHTOOL_CALLS += 1
  subprocess.check_call(cmd, stdout=stdout, stderr=subprocess.STDOUT)

  # Reuse the shared helper so both entry points build the same -files list.
  mt_map_surfaces(mesh_base, smesh_base, files=files, stdout=stdout)

  return Mapping(smesh_base) if return_mapping else None


def mt_interpolate_node2elem(mesh_base, data_file, *, output_file=None, normalize=False, stdout=None):
  """Run ``meshtool interpolate node2elem`` on a nodal data file.

  Args:
    mesh_base: mesh basename (``-omsh``).
    data_file: input data (``-idat``).
    output_file: output data (``-odat``); defaults to *data_file*, so the
      input is overwritten.
    normalize: if true, pass ``-norm``.
    stdout: optional destination for meshtool output (stderr is merged into it).
  """
  global COUNT_MESHTOOL_CALLS

  target = data_file if output_file is None else output_file
  cmd = ['meshtool', 'interpolate', 'node2elem']
  cmd.extend(['-omsh={}'.format(mesh_base),
              '-idat={}'.format(data_file),
              '-odat={}'.format(target)])
  cmd.extend(['-norm'] if normalize else [])
  print_cmd(cmd)
  COUNT_MESHTOOL_CALLS += 1
  subprocess.check_call(cmd, stdout=stdout, stderr=subprocess.STDOUT)


def mt_interpolate_elem2node(mesh_base, data_file, *, output_file=None, normalize=False, stdout=None):
  """Run ``meshtool interpolate elem2node`` on an element data file.

  Args:
    mesh_base: mesh basename (``-omsh``).
    data_file: input data (``-idat``).
    output_file: output data (``-odat``); defaults to *data_file*, matching
      ``mt_interpolate_node2elem``.
    normalize: if true, pass ``-norm``.
    stdout: optional destination for meshtool output (stderr is merged into it).

  Raises:
    subprocess.CalledProcessError: if meshtool exits with a non-zero status.
  """
  global COUNT_MESHTOOL_CALLS

  if output_file is None:
    # Without this default meshtool would receive the literal '-odat=None'.
    output_file = data_file
  # Element -> node is the direction this wrapper is named for.
  cmd = ['meshtool', 'interpolate', 'elem2node',
         '-omsh={}'.format(mesh_base),
         '-idat={}'.format(data_file),
         '-odat={}'.format(output_file)]
  if normalize:
    cmd += ['-norm']
  print_cmd(cmd)
  COUNT_MESHTOOL_CALLS += 1
  subprocess.check_call(cmd, stdout=stdout, stderr=subprocess.STDOUT)


def mt_generate_distance_field(mesh_base, mesh_fmt, output_file, start_surf, *,
                               end_surf=None, stdout=None):
  """Run ``meshtool generate distancefield`` starting at *start_surf*.

  Args:
    mesh_base: mesh basename (``-msh``).
    mesh_fmt: mesh input format (``-ifmt``).
    output_file: output data file (``-odat``).
    start_surf: start surface (``-ssurf``).
    end_surf: optional end surface (``-esurf``).
    stdout: optional destination for meshtool output (stderr is merged into it).
  """
  global COUNT_MESHTOOL_CALLS

  options = [('msh', mesh_base), ('ifmt', mesh_fmt), ('odat', output_file), ('ssurf', start_surf)]
  if end_surf is not None:
    options.append(('esurf', end_surf))
  cmd = ['meshtool', 'generate', 'distancefield']
  cmd += ['-{}={}'.format(key, val) for key, val in options]
  print_cmd(cmd)
  COUNT_MESHTOOL_CALLS += 1
  subprocess.check_call(cmd, stdout=stdout, stderr=subprocess.STDOUT)


def mt_extract_gradient(mesh_base, mesh_fmt, data_file, grad_file, *, normalize=False, stdout=None):
  """Run ``meshtool extract gradient`` (with ``-mode=1``) on a data file.

  Args:
    mesh_base: mesh basename (``-msh``).
    mesh_fmt: mesh input format (``-ifmt``).
    data_file: input data (``-idat``).
    grad_file: output gradient data (``-odat``).
    normalize: if true, pass ``-norm``.
    stdout: optional destination for meshtool output (stderr is merged into it).
  """
  global COUNT_MESHTOOL_CALLS

  flags = (('msh', mesh_base), ('ifmt', mesh_fmt), ('idat', data_file),
           ('odat', grad_file), ('mode', 1))
  cmd = ['meshtool', 'extract', 'gradient'] + ['-{}={}'.format(k, v) for k, v in flags]
  if normalize:
    cmd.append('-norm')
  print_cmd(cmd)
  COUNT_MESHTOOL_CALLS += 1
  subprocess.check_call(cmd, stdout=stdout, stderr=subprocess.STDOUT)


def mt_smooth_elem_data(mesh_base, data_file, *, output_file=None, normalize=False, stdout=None):
  """Smooth element data by interpolating elem -> node and back node -> elem.

  Two meshtool calls are made: ``interpolate elem2node`` from *data_file*
  into *output_file*, then ``interpolate node2elem`` from *output_file* into
  itself.

  Args:
    mesh_base: mesh basename (``-omsh``).
    data_file: element input data.
    output_file: final element output; defaults to *data_file* (in place).
    normalize: if true, pass ``-norm`` to both calls.
    stdout: optional destination for meshtool output (stderr is merged into it).
  """
  global COUNT_MESHTOOL_CALLS

  final_file = data_file if output_file is None else output_file
  # The second pass reads the intermediate nodal result of the first pass.
  passes = (('elem2node', data_file), ('node2elem', final_file))
  for direction, source in passes:
    cmd = ['meshtool', 'interpolate', direction,
           '-omsh={}'.format(mesh_base),
           '-idat={}'.format(source),
           '-odat={}'.format(final_file)]
    if normalize:
      cmd.append('-norm')
    print_cmd(cmd)
    COUNT_MESHTOOL_CALLS += 1
    subprocess.check_call(cmd, stdout=stdout, stderr=subprocess.STDOUT)


def mt_smooth_mesh(mesh_base, mesh_fmt, smesh_base, tags_str, *, smesh_fmt=None,
                   num_it=100, smth_coeff=0.14, lapl_level=1, threshold=0.95, stdout=None):
  """Run ``meshtool smooth mesh`` on the given tag regions.

  Args:
    mesh_base: input mesh basename (``-msh``).
    mesh_fmt: input format (``-ifmt``).
    smesh_base: output mesh basename (``-outmsh``).
    tags_str: tag specification (``-tags``).
    smesh_fmt: output format (``-ofmt``), defaults to *mesh_fmt*.
    num_it: value for ``-iter``.
    smth_coeff: value for ``-smth``.
    lapl_level: value for ``-lpc_lvl``.
    threshold: value for ``-thr``.
    stdout: optional destination for meshtool output (stderr is merged into it).
  """

  global COUNT_MESHTOOL_CALLS

  out_fmt = mesh_fmt if smesh_fmt is None else smesh_fmt
  flags = [('msh', mesh_base), ('ifmt', mesh_fmt), ('tags', tags_str),
           ('thr', threshold), ('iter', num_it), ('smth', smth_coeff),
           ('lpc_lvl', lapl_level), ('outmsh', smesh_base), ('ofmt', out_fmt)]
  cmd = ['meshtool', 'smooth', 'mesh']
  cmd.extend('-{}={}'.format(name, value) for name, value in flags)
  print_cmd(cmd)
  COUNT_MESHTOOL_CALLS += 1
  subprocess.check_call(cmd, stdout=stdout, stderr=subprocess.STDOUT)


def mt_extrude_image(image, mode, radius, tag, output_image, *, new_tag=None, tags=None, stdout=None):
  """Run ``meshtool itk extrude`` on an image.

  Args:
    image: input image (``-msh``).
    mode: value for ``-mode``.
    radius: value for ``-rad``.
    tag: value for ``-regtag``.
    output_image: output image (``-outmsh``).
    new_tag: optional value for ``-newtag``.
    tags: optional ``-tags`` value: an int, an already comma-separated str,
      or an iterable of tags (joined with ``,``).
    stdout: optional destination for meshtool output (stderr is merged into it).

  Raises:
    subprocess.CalledProcessError: if meshtool exits with a non-zero status.
  """
  global COUNT_MESHTOOL_CALLS

  cmd = ['meshtool', 'itk', 'extrude',
         '-msh={}'.format(image),
         '-mode={}'.format(mode),
         '-rad={}'.format(radius),
         '-regtag={}'.format(tag),
         '-outmsh={}'.format(output_image)]
  if new_tag is not None:
    cmd += ['-newtag={}'.format(new_tag)]
  if isinstance(tags, (str, int, numpy.integer)):
    # A str would otherwise be joined character by character and a single
    # int tag (not Iterable) would be silently dropped.
    cmd += ['-tags={}'.format(tags)]
  elif tags is not None and isinstance(tags, Iterable):
    cmd += ['-tags={}'.format(','.join(map(str, tags)))]
  print_cmd(cmd)
  COUNT_MESHTOOL_CALLS += 1
  subprocess.check_call(cmd, stdout=stdout, stderr=subprocess.STDOUT)


def mt_resample_image(image, ref_factor, output_image, *, stdout=None):
  """Run ``meshtool itk resample`` with refinement factor *ref_factor* (``-ref``).

  Args:
    image: input image (``-msh``).
    ref_factor: value for ``-ref``.
    output_image: output image (``-outmsh``).
    stdout: optional destination for meshtool output (stderr is merged into it).
  """
  global COUNT_MESHTOOL_CALLS

  cmd = ['meshtool', 'itk', 'resample']
  for name, value in (('msh', image), ('ref', ref_factor), ('outmsh', output_image)):
    cmd.append('-{}={}'.format(name, value))
  print_cmd(cmd)
  COUNT_MESHTOOL_CALLS += 1
  subprocess.check_call(cmd, stdout=stdout, stderr=subprocess.STDOUT)


def mt_mesh_from_image(image, mesh_base, mesh_fmt, *, scale=1.0, tags=None, tetrahedralize=False, stdout=None):
  """Generate a mesh from an image with ``meshtool itk mesh``.

  Args:
    image: input image (``-msh``).
    mesh_base: output mesh basename (``-outmsh``).
    mesh_fmt: output format (``-ofmt``).
    scale: value for ``-scale``.
    tags: optional ``-tags`` value: an int, an already comma-separated str,
      or an iterable of tags (joined with ``,``).
    tetrahedralize: if True, additionally run ``meshtool convert ... -make_tet=2``
      in place on the generated mesh.
    stdout: optional destination for meshtool output (stderr is merged into it).

  Raises:
    subprocess.CalledProcessError: if a meshtool call exits with a non-zero status.
  """
  global COUNT_MESHTOOL_CALLS

  cmd = ['meshtool', 'itk', 'mesh',
         '-msh={}'.format(image),
         '-outmsh={}'.format(mesh_base),
         '-ofmt={}'.format(mesh_fmt),
         '-scale={}'.format(scale)]
  if isinstance(tags, (str, int, numpy.integer)):
    # A str would otherwise be joined character by character and a single
    # int tag (not Iterable) would be silently dropped.
    cmd += ['-tags={}'.format(tags)]
  elif tags is not None and isinstance(tags, Iterable):
    cmd += ['-tags={}'.format(','.join(map(str, tags)))]
  print_cmd(cmd)
  COUNT_MESHTOOL_CALLS += 1
  subprocess.check_call(cmd, stdout=stdout, stderr=subprocess.STDOUT)

  if not tetrahedralize:
    return

  cmd = ['meshtool', 'convert',
         '-imsh={}'.format(mesh_base),
         '-ifmt={}'.format(mesh_fmt),
         '-omsh={}'.format(mesh_base),
         '-ofmt={}'.format(mesh_fmt),
         '-make_tet=2']
  print_cmd(cmd)
  COUNT_MESHTOOL_CALLS += 1
  subprocess.check_call(cmd, stdout=stdout, stderr=subprocess.STDOUT)


def main(num_procs, refine_img=False, compute_fibers=True, mesh_format='carp_txt', stdout=None):
  assert mesh_format in ('carp_txt', 'carp_bin')

  pnts_ext, elem_ext, lons_ext = 'pts', 'elem', 'lon'
  if mesh_format == 'carp_bin':
    pnts_ext, elem_ext, lons_ext = 'bpts', 'belem', 'blon'

  input_image = input_file('img_bivbp_seg.vtk')
  output_image = exout_file('image_bivbp_seg_extrude_lvbp.vtk')
  mt_extrude_image(input_image, 'out', 2, 31, output_image, new_tag=41, tags=[0], stdout=stdout)

  input_image = output_image
  output_image = exout_file('image_bivbp_seg_extrude_lvbp_rvbp.vtk')
  mt_extrude_image(input_image, 'out', 2, 36, output_image, new_tag=46, tags=[0], stdout=stdout)

  if refine_img:
    input_image = output_image
    output_image = exout_file('image_bivbp_seg_extrude_lvbp_rvbp_ref.vtk')
    mt_resample_image(input_image, 2.0, output_image, stdout=stdout)

  input_image = output_image
  bivbp_mesh_base, bivbp_mesh_fmt = exout_file('mesh'), mesh_format
  mt_mesh_from_image(input_image, bivbp_mesh_base, bivbp_mesh_fmt, scale=1000.0, tetrahedralize=True, stdout=stdout)

  #bivbp_mesh_base, bivbp_mesh_fmt = exout_file('mesh'), mesh_format
  mt_smooth_mesh(bivbp_mesh_base, bivbp_mesh_fmt, bivbp_mesh_base, '1/2/3/4/6/7/31/36',
                 num_it=300, lapl_level=2, smth_coeff=0.25, stdout=stdout)

  # surface opteration
  ops = [('lvepi', '1,2,3,4-31,41'),          # lv-epi surface
         ('rvepi', '6,7-1,2,3,4,36,46'),      # rv-epi surface
         ('lvendo', '1,2,3,4:31'),            # lv-endo surface
         ('rvendo', '1,2,3,4,6,7:36'),        # rv-endo surface
         ('lvbase', '1,2,3,4:41'),            # lv-base surface
         ('rvbase', '1,2,3,4,6,7:46'),        # rv-base surface
         ('base', '1,2,3,4,6,7:41,46'),       # base surface
         ('junc', '1,2,3,4:6,7'),             # junction surface
         ('rvantpost', '6:7'),                # rv-ant-post interface
         ('lvseptfree', '1,2:3,4'),           # lv-sept-free interface
         ('lvantpost', '1,3:2,4')]            # lv-ant-post interface

  mt_extract_surfaces(bivbp_mesh_base, bivbp_mesh_fmt, ops, stdout=stdout)

  bivbp_junc_vtx = read_vtx_from_tri_surf('{}_{}.surf'.format(bivbp_mesh_base, 'junc'))
  bivbp_rvantpost_vtx = read_vtx_from_tri_surf('{}_{}.surf'.format(bivbp_mesh_base, 'rvantpost'))
  bivbp_lvseptfree_vtx = read_vtx_from_tri_surf('{}_{}.surf'.format(bivbp_mesh_base, 'lvseptfree'))
  bivbp_lvantpost_vtx = read_vtx_from_tri_surf('{}_{}.surf'.format(bivbp_mesh_base, 'lvantpost'))

  bivbp_rvapex_vtx = numpy.intersect1d(bivbp_rvantpost_vtx, bivbp_junc_vtx)
  bivbp_lvapex_vtx = numpy.intersect1d(bivbp_lvseptfree_vtx, bivbp_lvantpost_vtx)
  write_dummy_tri_surf_from_vtx('{}_{}.surf'.format(bivbp_mesh_base, 'rvapex'), bivbp_rvapex_vtx)
  write_dummy_tri_surf_from_vtx('{}_{}.surf'.format(bivbp_mesh_base, 'lvapex'), bivbp_lvapex_vtx)


  biv_root_dir = os.path.join(ROOT, 'output_pysp', 'biv')
  biv_exout_file = lambda fname :  os.path.join(biv_root_dir, fname)
  if not os.path.exists(biv_root_dir):
    os.mkdir(biv_root_dir)
  biv_mesh_base, biv_mesh_fmt = biv_exout_file('mesh'), mesh_format
  bivbp_biv_map = mt_extract_mesh_and_map_surfaces(bivbp_mesh_base, bivbp_mesh_fmt, '1,2,3,4,6,7', biv_mesh_base,
                                                   smesh_fmt=biv_mesh_fmt, return_mapping=True, stdout=stdout)

  # ===========================================================================
  # UVC GENERATION
  # ===========================================================================

  lv_root_dir = os.path.join(biv_root_dir, 'lv')
  lv_exout_file = lambda fname :  os.path.join(lv_root_dir, fname)
  if not os.path.exists(lv_root_dir):
    os.mkdir(lv_root_dir)
  lv_mesh_base, lv_mesh_fmt = lv_exout_file('mesh'), mesh_format
  biv_lv_map = mt_extract_mesh_and_map_surfaces(biv_mesh_base, biv_mesh_fmt, '1,2,3,4', lv_mesh_base,
                                                smesh_fmt=lv_mesh_fmt, return_mapping=True, stdout=stdout)

  lvpstfree_root_dir = os.path.join(lv_root_dir, 'lvpstfree')
  lvpstfree_exout_file = lambda fname :  os.path.join(lvpstfree_root_dir, fname)
  if not os.path.exists(lvpstfree_root_dir):
    os.mkdir(lvpstfree_root_dir)
  lvpstfree_mesh_base, lvpstfree_mesh_fmt = lvpstfree_exout_file('mesh'), mesh_format
  lv_lvpstfree_map = mt_extract_mesh_and_map_surfaces(lv_mesh_base, lv_mesh_fmt, '2', lvpstfree_mesh_base,
                                                      smesh_fmt=lvpstfree_mesh_fmt, return_mapping=True, stdout=stdout)

  lvpstsept_root_dir = os.path.join(lv_root_dir, 'lvpstsept')
  lvpstsept_exout_file = lambda fname :  os.path.join(lvpstsept_root_dir, fname)
  if not os.path.exists(lvpstsept_root_dir):
    os.mkdir(lvpstsept_root_dir)
  lvpstsept_mesh_base, lvpstsept_mesh_fmt = lvpstsept_exout_file('mesh'), mesh_format
  lv_lvpstsept_map = mt_extract_mesh_and_map_surfaces(lv_mesh_base, lv_mesh_fmt, '4', lvpstsept_mesh_base,
                                                      smesh_fmt=lvpstsept_mesh_fmt, return_mapping=True, stdout=stdout)

  lvantsept_root_dir = os.path.join(lv_root_dir, 'lvantsept')
  lvantsept_exout_file = lambda fname :  os.path.join(lvantsept_root_dir, fname)
  if not os.path.exists(lvantsept_root_dir):
    os.mkdir(lvantsept_root_dir)
  lvantsept_mesh_base, lvantsept_mesh_fmt = lvantsept_exout_file('mesh'), mesh_format
  lv_lvantsept_map = mt_extract_mesh_and_map_surfaces(lv_mesh_base, lv_mesh_fmt, '3', lvantsept_mesh_base,
                                                      smesh_fmt=lvantsept_mesh_fmt, return_mapping=True, stdout=stdout)

  lvantfree_root_dir = os.path.join(lv_root_dir, 'lvantfree')
  lvantfree_exout_file = lambda fname :  os.path.join(lvantfree_root_dir, fname)
  if not os.path.exists(lvantfree_root_dir):
    os.mkdir(lvantfree_root_dir)
  lvantfree_mesh_base, lvantfree_mesh_fmt = lvantfree_exout_file('mesh'), mesh_format
  lv_lvantfree_map = mt_extract_mesh_and_map_surfaces(lv_mesh_base, lv_mesh_fmt, '1', lvantfree_mesh_base,
                                                      smesh_fmt=lvantfree_mesh_fmt, return_mapping=True, stdout=stdout)

  rv_root_dir = os.path.join(biv_root_dir, 'rv')
  rv_exout_file = lambda fname :  os.path.join(rv_root_dir, fname)
  if not os.path.exists(rv_root_dir):
    os.mkdir(rv_root_dir)
  rv_mesh_base, rv_mesh_fmt = rv_exout_file('mesh'), mesh_format
  biv_rv_map = mt_extract_mesh_and_map_surfaces(biv_mesh_base, biv_mesh_fmt, '6,7', rv_mesh_base,
                                                smesh_fmt=rv_mesh_fmt, return_mapping=True, stdout=stdout)

  rvpst_root_dir = os.path.join(rv_root_dir, 'rvpst')
  rvpst_exout_file = lambda fname :  os.path.join(rvpst_root_dir, fname)
  if not os.path.exists(rvpst_root_dir):
    os.mkdir(rvpst_root_dir)
  rvpst_mesh_base, rvpst_mesh_fmt = rvpst_exout_file('mesh'), mesh_format
  rv_rvpst_map = mt_extract_mesh_and_map_surfaces(rv_mesh_base, rv_mesh_fmt, '7', rvpst_mesh_base,
                                                  smesh_fmt=rvpst_mesh_fmt, return_mapping=True, stdout=stdout)

  rvant_root_dir = os.path.join(rv_root_dir, 'rvant')
  rvant_exout_file = lambda fname :  os.path.join(rvant_root_dir, fname)
  if not os.path.exists(rvant_root_dir):
    os.mkdir(rvant_root_dir)
  rvant_mesh_base, rvant_mesh_fmt = rvant_exout_file('mesh'), mesh_format
  rv_rvant_map = mt_extract_mesh_and_map_surfaces(rv_mesh_base, rv_mesh_fmt, '6', rvant_mesh_base,
                                                  smesh_fmt=rvant_mesh_fmt, return_mapping=True, stdout=stdout)

  mt_generate_distance_field(lvpstfree_mesh_base, lvpstfree_mesh_fmt, lvpstfree_exout_file('circ.dat'),
                             '{}_{}'.format(lvpstfree_mesh_base, 'lvantpost'),
                             end_surf='{}_{}'.format(lvpstfree_mesh_base, 'lvseptfree'), stdout=stdout)
  mt_generate_distance_field(lvpstsept_mesh_base, lvpstsept_mesh_fmt, lvpstsept_exout_file('circ.dat'),
                             '{}_{}'.format(lvpstsept_mesh_base, 'lvseptfree'),
                             end_surf='{}_{}'.format(lvpstsept_mesh_base, 'lvantpost'), stdout=stdout)
  mt_generate_distance_field(lvantsept_mesh_base, lvantsept_mesh_fmt, lvantsept_exout_file('circ.dat'),
                             '{}_{}'.format(lvantsept_mesh_base, 'lvantpost'),
                             end_surf='{}_{}'.format(lvantsept_mesh_base, 'lvseptfree'), stdout=stdout)
  mt_generate_distance_field(lvantfree_mesh_base, lvantfree_mesh_fmt, lvantfree_exout_file('circ.dat'),
                             '{}_{}'.format(lvantfree_mesh_base, 'lvseptfree'),
                             end_surf='{}_{}'.format(lvantfree_mesh_base, 'lvantpost'), stdout=stdout)

  mt_generate_distance_field(rvpst_mesh_base, rvpst_mesh_fmt, rvpst_exout_file('circ.dat'),
                             '{}_{}'.format(rvpst_mesh_base, 'junc'),
                             end_surf='{}_{}'.format(rvpst_mesh_base, 'rvantpost'), stdout=stdout)
  mt_generate_distance_field(rvant_mesh_base, rvant_mesh_fmt, rvant_exout_file('circ.dat'),
                             '{}_{}'.format(rvant_mesh_base, 'rvantpost'),
                             end_surf='{}_{}'.format(rvant_mesh_base, 'junc'), stdout=stdout)

  mt_generate_distance_field(lv_mesh_base, lv_mesh_fmt, lv_exout_file('apba.dat'),
                             '{}_{}'.format(lv_mesh_base, 'lvapex'),
                             end_surf='{}_{}'.format(lv_mesh_base, 'lvbase'), stdout=stdout)
  mt_generate_distance_field(rv_mesh_base, rv_mesh_fmt, rv_exout_file('apba.dat'),
                             '{}_{}'.format(rv_mesh_base, 'rvapex'),
                             end_surf='{}_{}'.format(rv_mesh_base, 'rvbase'), stdout=stdout)

  mt_generate_distance_field(lv_mesh_base, lv_mesh_fmt, lv_exout_file('endo.dat'),
                             '{}_{}'.format(lv_mesh_base, 'lvendo'),
                             end_surf='{}_{}'.format(lv_mesh_base, 'lvepi'), stdout=stdout)
  mt_generate_distance_field(rv_mesh_base, rv_mesh_fmt, rv_exout_file('endo.dat'),
                             '{}_{}'.format(rv_mesh_base, 'rvendo'),
                             end_surf='{}_{}'.format(rv_mesh_base, 'rvepi'), stdout=stdout)

  lvpstfree_circ_data = numpy.loadtxt(lvpstfree_exout_file('circ.dat'), dtype=numpy.float64)
  lvpstsept_circ_data = numpy.loadtxt(lvpstsept_exout_file('circ.dat'), dtype=numpy.float64)
  lvantsept_circ_data = numpy.loadtxt(lvantsept_exout_file('circ.dat'), dtype=numpy.float64)
  lvantfree_circ_data = numpy.loadtxt(lvantfree_exout_file('circ.dat'), dtype=numpy.float64)

  rvpst_circ_data = numpy.loadtxt(rvpst_exout_file('circ.dat'), dtype=numpy.float64)
  rvant_circ_data = numpy.loadtxt(rvant_exout_file('circ.dat'), dtype=numpy.float64)

  lv_apba_data = numpy.loadtxt(lv_exout_file('apba.dat'), dtype=numpy.float64)
  lv_endo_data = numpy.loadtxt(lv_exout_file('endo.dat'), dtype=numpy.float64)
  rv_apba_data = numpy.loadtxt(rv_exout_file('apba.dat'), dtype=numpy.float64)
  rv_endo_data = numpy.loadtxt(rv_exout_file('endo.dat'), dtype=numpy.float64)

  num_lv_pnts = read_num_points('{}.{}'.format(lv_mesh_base, pnts_ext))
  num_rv_pnts = read_num_points('{}.{}'.format(rv_mesh_base, pnts_ext))
  num_biv_pnts = read_num_points('{}.{}'.format(biv_mesh_base, pnts_ext))

  rv_circ_data = numpy.zeros(num_rv_pnts, dtype=numpy.float64)
  rv_rvpst_map.insert_data((rvpst_circ_data-1.0)*numpy.pi*0.5, rv_circ_data)
  rv_rvant_map.insert_data(rvant_circ_data*numpy.pi*0.5, rv_circ_data)

  lv_circ_data = numpy.zeros(num_lv_pnts, dtype=numpy.float64)
  lv_lvpstfree_map.insert_data(lvpstfree_circ_data*numpy.pi*0.5-numpy.pi, lv_circ_data)
  lv_lvpstsept_map.insert_data(lvpstsept_circ_data*numpy.pi*0.5-numpy.pi*0.5, lv_circ_data)
  lv_lvantsept_map.insert_data(lvantsept_circ_data*numpy.pi*0.5, lv_circ_data)
  lv_lvantfree_map.insert_data(lvantfree_circ_data*numpy.pi*0.5+numpy.pi*0.5, lv_circ_data)

  uvc_circ_data = numpy.zeros(num_biv_pnts, dtype=numpy.float64)
  biv_rv_map.insert_data(rv_circ_data, uvc_circ_data)
  biv_lv_map.insert_data(lv_circ_data, uvc_circ_data)

  uvc_apba_data = numpy.zeros(num_biv_pnts, dtype=numpy.float64)
  biv_rv_map.insert_data(rv_apba_data, uvc_apba_data)
  biv_lv_map.insert_data(lv_apba_data, uvc_apba_data)

  uvc_endo_data = numpy.zeros(num_biv_pnts, dtype=numpy.float64)
  biv_rv_map.insert_data(rv_endo_data, uvc_endo_data)
  biv_lv_map.insert_data(lv_endo_data, uvc_endo_data)

  uvc_vent_data = numpy.zeros(num_biv_pnts, dtype=numpy.float64)
  lv_vent_data = numpy.full(num_lv_pnts, -1.0, dtype=numpy.float64)
  rv_vent_data = numpy.full(num_rv_pnts, +1.0, dtype=numpy.float64)
  biv_rv_map.insert_data(rv_vent_data, uvc_vent_data)
  biv_lv_map.insert_data(lv_vent_data, uvc_vent_data)

  biv_uvcs = numpy.column_stack((uvc_apba_data, uvc_endo_data, uvc_circ_data, uvc_vent_data))
  uvcs_header = str(num_biv_pnts)
  numpy.savetxt(biv_exout_file('mesh.uvc'), biv_uvcs, header=uvcs_header, comments='')

  if not compute_fibers:
    return 0.0

  # ===========================================================================
  # FIBER GENERATION
  # ===========================================================================

  # ---------------------------------------------------------------------------
  # CREATE DISTANCE-FIELD SOLUTIONS
  # ---------------------------------------------------------------------------

  mt_generate_distance_field(biv_mesh_base, biv_mesh_fmt, biv_exout_file('lvendo_dist.dat'),
                             '{}_{}'.format(biv_mesh_base, 'lvendo'), stdout=stdout)
  biv_lvendo_dist = numpy.loadtxt(biv_exout_file('lvendo_dist.dat'), dtype=numpy.float64)
  biv_lvepi_vtx = read_vtx_from_tri_surf('{}_{}.surf'.format(biv_mesh_base, 'lvepi'))
  biv_lvendo_threshold = numpy.max(biv_lvendo_dist.take(biv_lvepi_vtx))+biv_lvendo_dist.max()*0.01
  mt_interpolate_node2elem(biv_mesh_base, biv_exout_file('lvendo_dist.dat'),
                           output_file=biv_exout_file('lvendo_dist_elem.dat'))
  biv_lvendo_dist_elem = numpy.loadtxt(biv_exout_file('lvendo_dist_elem.dat'), dtype=numpy.float64)
  biv_lvendo_elem_sel = numpy.where(biv_lvendo_dist_elem < biv_lvendo_threshold)[0]
  bivbp_lvendo_elem_sel = bivbp_biv_map.map_selection(biv_lvendo_elem_sel, nodal_selection=False, map_forward=False)
  print('\n'+'#'*80)
  print('LV-endo threshold value: {}'.format(biv_lvendo_threshold))
  print('#'*80+'\n')

  mt_generate_distance_field(biv_mesh_base, biv_mesh_fmt, biv_exout_file('rvendo_dist.dat'),
                             '{}_{}'.format(biv_mesh_base, 'rvendo'), stdout=stdout)
  biv_rvendo_dist = numpy.loadtxt(biv_exout_file('rvendo_dist.dat'), dtype=numpy.float64)
  biv_rvepi_vtx = read_vtx_from_tri_surf('{}_{}.surf'.format(biv_mesh_base, 'rvepi'))
  biv_rvendo_threshold = numpy.max(biv_rvendo_dist.take(biv_rvepi_vtx))+biv_rvendo_dist.max()*0.01
  mt_interpolate_node2elem(biv_mesh_base, biv_exout_file('rvendo_dist.dat'),
                           output_file=biv_exout_file('rvendo_dist_elem.dat'))
  biv_rvendo_dist_elem = numpy.loadtxt(biv_exout_file('rvendo_dist_elem.dat'), dtype=numpy.float64)
  biv_rvendo_elem_sel = numpy.where(biv_rvendo_dist_elem < biv_rvendo_threshold)[0]
  bivbp_rvendo_elem_sel = bivbp_biv_map.map_selection(biv_rvendo_elem_sel, nodal_selection=False, map_forward=False)
  print('\n'+'#'*80)
  print('RV-endo threshold value: {}'.format(biv_rvendo_threshold))
  print('#'*80+'\n')

  mt_extract_elemtags(bivbp_mesh_base, bivbp_mesh_fmt, exout_file('mesh_etags.dat'), stdout=stdout)
  bivbp_elem_tags = numpy.loadtxt(exout_file('mesh_etags.dat'), dtype=numpy.int16)

  bivbp_lvendo_elem_tags = numpy.copy(bivbp_elem_tags)
  bivbp_lvendo_elem_tags[bivbp_lvendo_elem_sel] = 5
  numpy.savetxt(exout_file('mesh_lvendo_etags.tags'), bivbp_lvendo_elem_tags, fmt='%d')
  numpy.savetxt(exout_file('mesh_lvendo_etags.dat'), bivbp_lvendo_elem_tags, fmt='%d')

  bivbp_rvendo_elem_tags = numpy.copy(bivbp_elem_tags)
  bivbp_rvendo_elem_tags[bivbp_rvendo_elem_sel] = 8
  numpy.savetxt(exout_file('mesh_rvendo_etags.tags'), bivbp_rvendo_elem_tags, fmt='%d')
  numpy.savetxt(exout_file('mesh_rvendo_etags.dat'), bivbp_rvendo_elem_tags, fmt='%d')

  ops = [('lvextepi', '5-31,41,46')]    # lv-epi surface
  mt_extract_surfaces(bivbp_mesh_base, bivbp_mesh_fmt, ops,
                      etags_file=exout_file('mesh_lvendo_etags.tags'), stdout=stdout)
  mt_map_surfaces(bivbp_mesh_base, biv_mesh_base, files=['{}_{}'.format(bivbp_mesh_base, 'lvextepi')], stdout=stdout)
  mt_generate_distance_field(biv_mesh_base, biv_mesh_fmt, biv_exout_file('lvendo_sol_raw.dat'),
                             '{}_{}'.format(biv_mesh_base, 'lvextepi'),
                             end_surf='{}_{}'.format(biv_mesh_base, 'lvendo'), stdout=stdout)

  ops = [('rvextepi', '8-36,41,46')]    # lv-epi surface
  mt_extract_surfaces(bivbp_mesh_base, bivbp_mesh_fmt, ops,
                      etags_file=exout_file('mesh_rvendo_etags.tags'), stdout=stdout)
  mt_map_surfaces(bivbp_mesh_base, biv_mesh_base, files=['{}_{}'.format(bivbp_mesh_base, 'rvextepi')], stdout=stdout)
  mt_generate_distance_field(biv_mesh_base, biv_mesh_fmt, biv_exout_file('rvendo_sol_raw.dat'),
                             '{}_{}'.format(biv_mesh_base, 'rvextepi'),
                             end_surf='{}_{}'.format(biv_mesh_base, 'rvendo'), stdout=stdout)

  biv_lvendo_data = numpy.loadtxt(biv_exout_file('lvendo_sol_raw.dat'))
  biv_lvendo_data[biv_lvendo_dist >= biv_lvendo_threshold] = 0.0
  
  biv_rvendo_data = numpy.loadtxt(biv_exout_file('rvendo_sol_raw.dat'))
  biv_rvendo_data[biv_rvendo_dist >= biv_rvendo_threshold] = 0.0

  tmp = numpy.multiply(biv_lvendo_data, biv_rvendo_data)
  biv_lvendo_data = 2.0*biv_lvendo_data - numpy.power(biv_lvendo_data, 2.0) - tmp
  biv_rvendo_data = 2.0*biv_rvendo_data - numpy.power(biv_rvendo_data, 2.0) - tmp
  biv_epi_data = 1.0-(biv_lvendo_data+biv_rvendo_data)

  numpy.savetxt(biv_exout_file('lvendo_sol.dat'), biv_lvendo_data)
  numpy.savetxt(biv_exout_file('rvendo_sol.dat'), biv_rvendo_data)
  numpy.savetxt(biv_exout_file('epi_sol.dat'), biv_epi_data)

  # ---------------------------------------------------------------------------
  # CREATE FIBERS
  # ---------------------------------------------------------------------------

  mt_generate_distance_field(biv_mesh_base, biv_mesh_fmt, biv_exout_file('apba_sol.dat'),
                             '{}_{}'.format(biv_mesh_base, 'lvapex'),
                             end_surf='{}_{}'.format(biv_mesh_base, 'base'), stdout=stdout)

  mt_extract_gradient(biv_mesh_base, biv_mesh_fmt, biv_exout_file('lvendo_sol.dat'),
                      biv_exout_file('lvendo_sol'), normalize=True, stdout=stdout)
  mt_extract_gradient(biv_mesh_base, biv_mesh_fmt, biv_exout_file('rvendo_sol.dat'),
                      biv_exout_file('rvendo_sol'), normalize=True, stdout=stdout)
  mt_extract_gradient(biv_mesh_base, biv_mesh_fmt, biv_exout_file('epi_sol.dat'),
                      biv_exout_file('epi_sol'), normalize=True, stdout=stdout)
  mt_extract_gradient(biv_mesh_base, biv_mesh_fmt, biv_exout_file('apba_sol.dat'),
                      biv_exout_file('apba_sol'), normalize=True, stdout=stdout)

  for _ in range(5):
    mt_smooth_elem_data(biv_mesh_base, biv_exout_file('lvendo_sol.grad.vec'), normalize=True, stdout=stdout)
    mt_smooth_elem_data(biv_mesh_base, biv_exout_file('rvendo_sol.grad.vec'), normalize=True, stdout=stdout)
    mt_smooth_elem_data(biv_mesh_base, biv_exout_file('epi_sol.grad.vec'), normalize=True, stdout=stdout)
    mt_smooth_elem_data(biv_mesh_base, biv_exout_file('apba_sol.grad.vec'), normalize=True, stdout=stdout)

  mt_interpolate_node2elem(biv_mesh_base, biv_exout_file('lvendo_sol.dat'),
                           output_file=biv_exout_file('lvendo_sol_elem.dat'), stdout=stdout)
  mt_interpolate_node2elem(biv_mesh_base, biv_exout_file('rvendo_sol.dat'),
                           output_file=biv_exout_file('rvendo_sol_elem.dat'), stdout=stdout)
  mt_interpolate_node2elem(biv_mesh_base, biv_exout_file('epi_sol.dat'),
                           output_file=biv_exout_file('epi_sol_elem.dat'), stdout=stdout)
  mt_interpolate_node2elem(biv_mesh_base, biv_exout_file('apba_sol.dat'),
                           output_file=biv_exout_file('apba_sol_elem.dat'), stdout=stdout)

  phi_lv = numpy.loadtxt(biv_exout_file('lvendo_sol_elem.dat'), dtype=numpy.float64)
  phi_rv = numpy.loadtxt(biv_exout_file('rvendo_sol_elem.dat'), dtype=numpy.float64)
  phi_epi = numpy.loadtxt(biv_exout_file('epi_sol_elem.dat'), dtype=numpy.float64)
  psi_ab = numpy.loadtxt(biv_exout_file('apba_sol_elem.dat'), dtype=numpy.float64)

  grad_phi_lv = numpy.loadtxt(biv_exout_file('lvendo_sol.grad.vec'), dtype=numpy.float64)
  grad_phi_rv = numpy.loadtxt(biv_exout_file('rvendo_sol.grad.vec'), dtype=numpy.float64)
  grad_phi_epi = numpy.loadtxt(biv_exout_file('epi_sol.grad.vec'), dtype=numpy.float64)
  grad_psi_ab = numpy.loadtxt(biv_exout_file('apba_sol.grad.vec'), dtype=numpy.float64)

  start = time.time()
  elem_lon = assign_fibers_sheets(phi_epi, phi_lv, phi_rv, psi_ab,
                                  grad_phi_epi, grad_phi_lv, grad_phi_rv, grad_psi_ab,
                                  40.0, -50.0, -65.0, 25.0,
                                  num_workers=num_procs)
  end = time.time()
  fib_time = end-start
  print('\n\nFIBER COMPUTATION: {:.6f} sec'.format(fib_time))

  write_fibers(biv_exout_file('mesh.{}'.format(lons_ext)), elem_lon)
  return fib_time


if __name__ == '__main__':
  # Script entry point: read the worker count from OMP_NUM_THREADS, parse CLI
  # flags, run main(), and append timing statistics to a per-thread-count log.

  # Force the 'fork' start method for any multiprocessing workers created by
  # main() (num_procs is forwarded as num_workers to the fiber assignment).
  # NOTE(review): 'fork' is not available on Windows.
  multiprocessing.set_start_method('fork', force=True)

  num_procs_env = os.environ.get('OMP_NUM_THREADS', None)
  if num_procs_env is None:
    raise RuntimeError('Please set OMP_NUM_THREADS environment variable !')
  try:
    num_procs = int(num_procs_env)
  except ValueError as err:
    # Report a bad value the same way as the other configuration errors above
    # and below, instead of surfacing a bare int() ValueError.
    raise RuntimeError('OMP_NUM_THREADS must be an integer, got {!r} !'.format(num_procs_env)) from err
  if num_procs < 1:
    raise RuntimeError('Please set number of threads to a value greater than 0 !')

  parser = argparse.ArgumentParser()
  parser.add_argument('--refine-img', action='store_true', default=False,
                      help='Refine image before extracting mesh')
  parser.add_argument('--uvcs-only', action='store_true', default=False,
                      help='Compute UVCs only')
  args = parser.parse_args()

  start = time.time()
  # main() returns the fiber-computation time (0.0 when fibers are skipped).
  fib_time = main(num_procs, refine_img=args.refine_img, compute_fibers=not args.uvcs_only, stdout=None)
  end = time.time()
  tot_time = end-start
  print('Done, took {:.6f} sec'.format(tot_time))

  # Append one line per run: total time, fiber time, number of meshtool calls.
  # The log is written to the current working directory. Using a context
  # manager guarantees the file is closed even if the write fails.
  with open('pysp_wkflw_n{}.log'.format(num_procs), 'a') as fp:
    fp.write('{:.6f}  {:.6f}  {:d}\n'.format(tot_time, fib_time, COUNT_MESHTOOL_CALLS))