#!/usr/bin/env python3

import os
import time
import numpy
import argparse
import pymeshtool
import multiprocessing

from scipy.spatial.transform import Rotation, Slerp



# Absolute path of the directory that contains this script.
ROOT = os.path.abspath(os.path.dirname(__file__))


def input_file(fname):
  """Return the path of `fname` located directly in the script directory."""
  return os.path.join(ROOT, fname)


def exout_file(fname):
  """Return the path of `fname` inside the `output_pymt` folder of the script directory."""
  return os.path.join(ROOT, 'output_pymt', fname)


# Factor converting an angle given in degrees to radians (pi/180).
GRAD2RAD = numpy.pi/180.0

# All 16 sign patterns for the four components of a quaternion. The first
# component varies slowest and the last one fastest, starting from all-positive.
QUAT_SIGNS = tuple((s0, s1, s2, s3)
                   for s0 in (+1.0, -1.0)
                   for s1 in (+1.0, -1.0)
                   for s2 in (+1.0, -1.0)
                   for s3 in (+1.0, -1.0))


def axis(apba_dir, trsm_dir, return_fail=False):
  """Build an orthonormal frame from two direction vectors and return it as a rotation.

  The second basis vector `e1` is the normalized `apba_dir`, the third one `e2`
  is `trsm_dir` with its component along `e1` removed and normalized, and the
  first one `e0` is the cross product `e1 x e2`. The rows of the resulting
  matrix are `(e0, e1, e2)`.

  Args:
    apba_dir: numpy vector with three components.
    trsm_dir: numpy vector with three components.
    return_fail: if True, also return the failure flag.

  Returns:
    The rotation, or the tuple `(rotation, fail)` if `return_fail` is set.
    When no frame can be built, the identity rotation is returned and
    `fail` is True.
  """
  mat, fail = numpy.identity(3), True

  norm_apba = numpy.linalg.norm(apba_dir)
  norm_trsm = numpy.linalg.norm(trsm_dir)
  dotp = apba_dir.dot(trsm_dir)

  # NOTE(review): exactly orthogonal inputs (dotp == 0) are reported as a
  # failure by this condition - confirm that this is intended.
  if norm_apba > 0.0 and norm_trsm > 0.0 and abs(dotp) > 0.0:
    e1 = apba_dir/norm_apba
    tmp = trsm_dir-e1.dot(trsm_dir)*e1
    norm_tmp = numpy.linalg.norm(tmp)
    # parallel directions leave no component orthogonal to e1; dividing by
    # the vanishing norm would put NaNs into the matrix, so treat it as failure
    if norm_tmp > 1.0e-12*norm_trsm:
      e2 = tmp/norm_tmp
      e0 = numpy.cross(e1, e2)
      mat = numpy.array([e0, e1, e2])
      fail = False

  rot = Rotation.from_matrix(mat)
  if return_fail:
    return rot, fail

  return rot


def orient(rot:Rotation, alpha, beta):
  """Rotate `rot` about the z-axis and afterwards about the x-axis.

  The z-rotation uses the angle `-alpha`; the x-rotation matrix is built from
  `cos(beta)` and `sin(beta)` such that it corresponds to the angle `-beta`.
  Both angles are given in radians.

  Returns:
    The composed rotation `R_x * (R_z * rot)`.
  """
  cos_a, sin_a = numpy.cos(-alpha), numpy.sin(-alpha)
  cos_b, sin_b = numpy.cos(beta), numpy.sin(beta)

  # matrix of the rotation around the z-axis
  about_z = Rotation.from_matrix([[+cos_a, -sin_a, 0.0],
                                  [+sin_a, +cos_a, 0.0],
                                  [0.0, 0.0, 1.0]])
  # matrix of the rotation around the x-axis
  about_x = Rotation.from_matrix([[1.0, 0.0, 0.0],
                                  [0.0, +cos_b, +sin_b],
                                  [0.0, -sin_b, +cos_b]])

  turned = about_z*rot
  return about_x*turned


def bislerp(rotA:Rotation, rotB:Rotation, t):
  """Interpolate between two rotations after picking a sign variant of `rotA`.

  Both rotations are converted to scalar-first quaternions. Every sign pattern
  in QUAT_SIGNS is applied component-wise to the quaternion of `rotA`, the
  Hamilton product of the candidate with the quaternion of `rotB` is formed,
  and the candidate with the largest product norm is kept. The result is the
  spherical linear interpolation between the kept candidate (time 0.0) and
  `rotB` (time 1.0), evaluated at `t`.

  NOTE(review): the Hamilton product of two unit quaternions has unit norm for
  every sign pattern, so this criterion presumably only separates candidates
  by rounding error - confirm whether a dot product with `qB` was intended.
  Flipping single components also changes the rotation the quaternion
  represents - verify that this is the desired candidate set.

  Args:
    rotA: rotation at interpolation parameter 0.
    rotB: rotation at interpolation parameter 1.
    t: interpolation parameter passed to the Slerp object.

  Returns:
    The interpolated rotation `slerp(t)`.
  """
  qA, qB = rotA.as_quat(scalar_first=True), rotB.as_quat(scalar_first=True)

  # start with the unmodified quaternion of rotA as the best candidate
  qM, max_norm = qA, 0.0
  for sign in QUAT_SIGNS:
    # candidate: quaternion of rotA with the current sign pattern applied
    qT = (sign[0]*qA[0], sign[1]*qA[1], sign[2]*qA[2], sign[3]*qA[3])
    # Hamilton product qT*qB (scalar component first)
    qP = (qT[0]*qB[0] - qT[1]*qB[1] - qT[2]*qB[2] - qT[3]*qB[3],
          qT[0]*qB[1] + qT[1]*qB[0] + qT[2]*qB[3] - qT[3]*qB[2],
          qT[0]*qB[2] - qT[1]*qB[3] + qT[2]*qB[0] + qT[3]*qB[1],
          qT[0]*qB[3] + qT[1]*qB[2] - qT[2]*qB[1] + qT[3]*qB[0])
    norm = numpy.sqrt(qP[0]*qP[0] + qP[1]*qP[1] + qP[2]*qP[2] + qP[3]*qP[3])
    # strict comparison: on ties the earlier candidate is kept
    if norm > max_norm:
      qM, max_norm = qT, norm

  rotM = Rotation.from_quat(qM, scalar_first=True)
  # key rotations rotM and rotB at the key times 0.0 and 1.0
  slerp = Slerp((0.0, 1.0), Rotation.concatenate([rotM, rotB]))

  return slerp(t)


def assign_fibers_sheets(phi_epi, phi_lv, phi_rv, psi_ab,
                         grad_phi_epi, grad_phi_lv, grad_phi_rv, grad_psi_ab,
                         alpha_endo, alpha_epi, beta_endo, beta_epi,
                         num_workers=16):
  """Compute two direction vectors per entry using parallel worker processes.

  For every entry `i` a local frame is built with `axis` from the gradient
  vectors, rotated with `orient` by angles interpolated between the endo and
  epi values, and the first two rows of the resulting 3x3 rotation matrix are
  stored (presumably fiber and sheet direction - verify against the consumer
  of the result).

  The range of entries is split into `num_workers` contiguous chunks, the last
  worker additionally handles the remainder. The workers write into a shared
  single-precision buffer. The worker is a nested function, which cannot be
  pickled, so the `fork` start method of multiprocessing is required.

  Args:
    phi_epi, phi_lv, phi_rv, psi_ab: 1D arrays of equal size.
    grad_phi_epi, grad_phi_lv, grad_phi_rv, grad_psi_ab: arrays of shape
      (size, 3) holding one vector per entry.
    alpha_endo, alpha_epi, beta_endo, beta_epi: angles in degrees.
    num_workers: number of worker processes, must be at least 1.

  Returns:
    A float32 array of shape (size, 2, 3). Entries for which no rotation could
    be determined, or for which the computation raised, are all zero.

  Raises:
    ValueError: if `num_workers` is smaller than 1.
  """

  assert (phi_epi.ndim == 1) and (grad_phi_epi.ndim == 2) and  (grad_phi_epi.shape[1] == 3) and \
         (phi_epi.size == grad_phi_epi.shape[0]) and (phi_epi.size == phi_lv.size) and \
         (phi_epi.size == phi_rv.size) and (phi_epi.size == psi_ab.size) and \
         (grad_phi_epi.shape == grad_phi_lv.shape) and (grad_phi_epi.shape == grad_phi_rv.shape) and \
         (grad_phi_epi.shape == grad_psi_ab.shape)

  if num_workers < 1:
    raise ValueError('num_workers must be at least 1')

  # angle profiles in radians for d in [0, 1]:
  # *_s runs from +endo (d=0) to -endo (d=1), *_w from endo (d=0) to epi (d=1)
  alpha_s = lambda d: (alpha_endo*(1.0-d) - alpha_endo*d)*GRAD2RAD
  alpha_w = lambda d: (alpha_endo*(1.0-d) + alpha_epi*d)*GRAD2RAD
  beta_s = lambda d: (beta_endo*(1.0-d) - beta_endo*d)*GRAD2RAD
  beta_w = lambda d: (beta_endo*(1.0-d) + beta_epi*d)*GRAD2RAD

  num = phi_epi.shape[0]
  # shared result buffer (6 floats per entry) and shared failure counter
  shared_array = multiprocessing.Array('f', 6*num)
  shared_fails = multiprocessing.Value('l', 0)
  chunk_size = num // num_workers
  processes = []

  def assign_fs(shared_array, shared_fails, start, end):
    """Fill the shared buffer for the entries in [start, end)."""
    arr = numpy.frombuffer(shared_array.get_obj(), dtype=numpy.float32)
    count_fails = 0
    for i in range(start, end):
      matFST = numpy.zeros((3, 3), dtype=numpy.float32)
      # reset per entry, otherwise the rotation of the previous entry would be
      # written whenever none of the branches below assigns a new one
      rotFST = None

      try:
        apba_dir = grad_psi_ab[i]
        trsm_dir_l, trsm_dir_r, trsm_dir_e = grad_phi_lv[i], grad_phi_rv[i], grad_phi_epi[i]

        ax_l, fail_l = axis(apba_dir, -trsm_dir_l, return_fail=True)
        ax_r, fail_r = axis(apba_dir, +trsm_dir_r, return_fail=True)
        ax_e, fail_e = axis(apba_dir, +trsm_dir_e, return_fail=True)

        # NOTE(review): this overrides the flags computed above, so only the
        # first branch below is ever taken - confirm whether this is
        # intentional or a debugging leftover.
        fail_l = fail_r = True

        if fail_l and fail_r:
          if not fail_e:
            d = phi_epi[i]
            rotFST = orient(ax_e, alpha_w(d), beta_w(d))
        else:
          if fail_l and not fail_r:
            d = phi_rv[i]/(phi_lv[i]+phi_rv[i]) if (phi_lv[i]+phi_rv[i]) > 0.0 else 0.0
            rotV = orient(ax_r, alpha_s(d), beta_s(d))
          elif not fail_l and fail_r:
            d = phi_rv[i]/(phi_lv[i]+phi_rv[i]) if (phi_lv[i]+phi_rv[i]) > 0.0 else 0.0
            rotV = orient(ax_l, alpha_s(d), beta_s(d))
          else:
            d = phi_rv[i]/(phi_lv[i]+phi_rv[i]) if (phi_lv[i]+phi_rv[i]) > 0.0 else 0.0
            rotVl = orient(ax_l, alpha_s(d), beta_s(d))
            rotVr = orient(ax_r, alpha_s(d), beta_s(d))
            rotV = bislerp(rotVl, rotVr, d)

          if fail_e:
            rotFST = rotV
          else:
            d = phi_epi[i]
            rotE = orient(ax_e, alpha_w(d), beta_w(d))
            rotFST = bislerp(rotV, rotE, d)

        if rotFST is not None:
          matFST = rotFST.as_matrix()

      except Exception:
        # best effort: the entry keeps its zero vectors, the failure is counted
        count_fails += 1

      finally:
        arr[6*i:(6*i+3)] = matFST[0]
        arr[(6*i+3):(6*i+6)] = matFST[1]

    if count_fails > 0:
      with shared_fails.get_lock():
        shared_fails.value += count_fails

  for i in range(num_workers):
    start = i * chunk_size
    # Make sure last process covers the remainder
    end = num if i == (num_workers-1) else (i+1)*chunk_size
    p = multiprocessing.Process(target=assign_fs, args=(shared_array, shared_fails, start, end))
    p.start()
    processes.append(p)

  for p in processes:
    p.join()

  if shared_fails.value > 0:
    print('WARNING: fiber assignment failed for {} of {} entries, '
          'zero vectors were stored instead'.format(shared_fails.value, num))

  # Convert shared array to numpy array for use in main process
  lon = numpy.frombuffer(shared_array.get_obj(), dtype=numpy.float32)

  return lon.reshape(num, 2, 3)


def write_vtx(filename, vtx, domain='extra'):
  """Write a sorted list of vertex indices with a two-line header.

  The first header line holds the number of indices written, the second one
  the `domain` string; afterwards one index per line follows in ascending
  order.

  Args:
    filename: file name or file handle accepted by `numpy.savetxt`.
    vtx: array-like of integer indices; nested input is flattened.
    domain: string written to the second header line.
  """
  npvtx = numpy.sort(numpy.array(vtx, dtype=numpy.int64).flatten())
  # count the flattened indices, `len(vtx)` would be wrong for nested input
  hdr = '{}\n{}'.format(npvtx.size, domain)
  numpy.savetxt(filename, npvtx, fmt='%ld', header=hdr, comments='')


def main(num_procs, refine_img=False, compute_fibers=True):
  """Run the workflow: image -> mesh -> UVCs -> (optionally) fibers.

  The segmented image `img_bivbp_seg.vtk` is read from the script directory,
  a tetrahedral mesh is extracted and smoothed, the surfaces needed as
  boundary conditions are extracted and the four coordinate fields are
  written to `msh_biv.uvc` next to the mesh `msh_biv` in the output
  directory. If requested, fiber and sheet directions are computed with
  `assign_fibers_sheets` and written to `msh_biv.lon`.

  Args:
    num_procs: number of threads for pymeshtool and number of worker
      processes for the fiber assignment.
    refine_img: resample the image (factor 2.0) before extracting the mesh.
    compute_fibers: if False, stop after the UVCs have been written.

  Returns:
    Tuple `(fib_time, io_time)` in seconds; `fib_time` is 0.0 if no fibers
    were computed.
  """
  pymeshtool.set_num_par_threads(num_procs)

  # the output folder may not exist yet, numpy.savetxt does not create it
  os.makedirs(os.path.dirname(exout_file('msh_biv')), exist_ok=True)

  io_time = 0.0
  start = time.time()
  img = pymeshtool.Image(input_file('img_bivbp_seg.vtk'))
  end = time.time()
  io_time += end-start

  img.extrude(pymeshtool.ImageExtrusionMode.outwards, 2, 31, new_tag=41, tags=[0])
  img.extrude(pymeshtool.ImageExtrusionMode.outwards, 2, 36, new_tag=46, tags=[0])

  if refine_img:
    img = img.resample(2.0)
    print(img.voxel_size)

  # Extract the BiV Mesh + Blood Pools + Lids
  bivbp_msh = img.extract_mesh(tags=[1,2,3,4,6,7,31,36,41,46], tetrahedralize_mesh=True, scale=1000.0)

  # Smooth all regions but the lids
  bivbp_msh.smooth_mesh([1,2,3,4,6,7,31,36], num_iterations=300, num_laplace_levels=2, smoothing_coeff=0.25)

  # surface operations
  ops = ['1,2,3,4-31,41',           # lv-epi surface
         '6,7-1,2,3,4,36,46',       # rv-epi surface
         '1,2,3,4:31',              # lv-endo surface
         '1,2,3,4,6,7:36',          # rv-endo surface
         '1,2,3,4:41',              # lv-base surface
         '1,2,3,4,6,7:46',          # rv-base surface
         '1,2,3,4:6,7',             # junction surface
         '6:7',                     # rv-ant-post interface
         '1,2:3,4',                 # lv-sept-free interface
         '1,3:2,4']                 # lv-ant-post interface

  # extract all the surfaces but keep the node indices
  surf_meshes = bivbp_msh.extract_surface(setops=';'.join(ops), reindex_nodes=False,
                                          return_mapping=False)

  # get the nodes spanning the surface
  bivbp_lvepi_vtx = surf_meshes[0].extract_element_nodes()
  bivbp_rvepi_vtx = surf_meshes[1].extract_element_nodes()
  bivbp_lvendo_vtx = surf_meshes[2].extract_element_nodes()
  bivbp_rvendo_vtx = surf_meshes[3].extract_element_nodes()
  bivbp_lvbase_vtx = surf_meshes[4].extract_element_nodes()
  bivbp_rvbase_vtx = surf_meshes[5].extract_element_nodes()
  bivbp_junc_vtx = surf_meshes[6].extract_element_nodes()
  bivbp_rvantpost_vtx = surf_meshes[7].extract_element_nodes()
  bivbp_lvfreesept_vtx = surf_meshes[8].extract_element_nodes()
  bivbp_lvantpost_vtx = surf_meshes[9].extract_element_nodes()

  # delete the surfaces (free memory)
  del surf_meshes

  # apex nodes: nodes shared by two interface surfaces
  bivbp_rvapex_vtx = numpy.intersect1d(bivbp_rvantpost_vtx, bivbp_junc_vtx)
  bivbp_lvapex_vtx = numpy.intersect1d(bivbp_lvfreesept_vtx, bivbp_lvantpost_vtx)

  # extract biv-mesh
  biv_msh, bivbp_biv_map = bivbp_msh.extract_mesh([1,2,3,4,6,7], return_mapping=True)

  # ===========================================================================
  # UVC GENERATION
  # ===========================================================================

  # extract lv-mesh and rv-mesh
  lv_msh, biv_lv_map = biv_msh.extract_mesh([1,2,3,4], return_mapping=True)
  rv_msh, biv_rv_map = biv_msh.extract_mesh([6,7], return_mapping=True)

  # map the node selections from the full mesh to the biv-mesh
  biv_lvepi_vtx = bivbp_biv_map.map_selection(bivbp_lvepi_vtx)
  biv_rvepi_vtx = bivbp_biv_map.map_selection(bivbp_rvepi_vtx)
  biv_lvendo_vtx = bivbp_biv_map.map_selection(bivbp_lvendo_vtx)
  biv_rvendo_vtx = bivbp_biv_map.map_selection(bivbp_rvendo_vtx)
  biv_lvbase_vtx = bivbp_biv_map.map_selection(bivbp_lvbase_vtx)
  biv_rvbase_vtx = bivbp_biv_map.map_selection(bivbp_rvbase_vtx)
  biv_junc_vtx = bivbp_biv_map.map_selection(bivbp_junc_vtx)
  biv_rvantpost_vtx = bivbp_biv_map.map_selection(bivbp_rvantpost_vtx)
  biv_lvfreesept_vtx = bivbp_biv_map.map_selection(bivbp_lvfreesept_vtx)
  biv_lvantpost_vtx = bivbp_biv_map.map_selection(bivbp_lvantpost_vtx)
  biv_rvapex_vtx = bivbp_biv_map.map_selection(bivbp_rvapex_vtx)
  biv_lvapex_vtx = bivbp_biv_map.map_selection(bivbp_lvapex_vtx)

  # save the mesh with simplified tags and restore the original tags
  # afterwards; a private copy is kept since it is not visible here whether
  # the `element_tags` getter returns a copy or a view
  biv_etags_orig = numpy.copy(biv_msh.element_tags)
  biv_msh.element_tags = (biv_etags_orig // 5)*5 + 1
  start = time.time()
  biv_msh.save(exout_file('msh_biv'), format=pymeshtool.MeshOutputFormat.carp_txt)
  end = time.time()
  io_time += end-start
  biv_msh.element_tags = biv_etags_orig

  lvpstfree_msh, lv_lvpstfree_map = lv_msh.extract_mesh(2, return_mapping=True)
  lvpstsept_msh, lv_lvpstsept_map = lv_msh.extract_mesh(4, return_mapping=True)
  lvantsept_msh, lv_lvantsept_map = lv_msh.extract_mesh(3, return_mapping=True)
  lvantfree_msh, lv_lvantfree_map = lv_msh.extract_mesh(1, return_mapping=True)

  # compose the mappings: biv-mesh -> lv-mesh -> lv-segment
  biv_lvpstfree_map = (biv_lv_map@lv_lvpstfree_map)
  biv_lvpstsept_map = (biv_lv_map@lv_lvpstsept_map)
  biv_lvantsept_map = (biv_lv_map@lv_lvantsept_map)
  biv_lvantfree_map = (biv_lv_map@lv_lvantfree_map)

  rvpst_msh, rv_rvpst_map = rv_msh.extract_mesh(7, return_mapping=True)
  rvant_msh, rv_rvant_map = rv_msh.extract_mesh(6, return_mapping=True)

  biv_rvpst_map = (biv_rv_map@rv_rvpst_map)
  biv_rvant_map = (biv_rv_map@rv_rvant_map)

  # distance fields per segment (input for the circumferential coordinate)
  lvpstfree_circ_data = lvpstfree_msh.generate_distancefield(biv_lvpstfree_map.map_selection(biv_lvantpost_vtx),
                                                             end_vtx=biv_lvpstfree_map.map_selection(biv_lvfreesept_vtx))
  lvpstsept_circ_data = lvpstsept_msh.generate_distancefield(biv_lvpstsept_map.map_selection(biv_lvfreesept_vtx),
                                                             end_vtx=biv_lvpstsept_map.map_selection(biv_lvantpost_vtx))
  lvantsept_circ_data = lvantsept_msh.generate_distancefield(biv_lvantsept_map.map_selection(biv_lvantpost_vtx),
                                                             end_vtx=biv_lvantsept_map.map_selection(biv_lvfreesept_vtx))
  lvantfree_circ_data = lvantfree_msh.generate_distancefield(biv_lvantfree_map.map_selection(biv_lvfreesept_vtx),
                                                             end_vtx=biv_lvantfree_map.map_selection(biv_lvantpost_vtx))

  rvpst_circ_data = rvpst_msh.generate_distancefield(biv_rvpst_map.map_selection(biv_junc_vtx),
                                                     end_vtx=biv_rvpst_map.map_selection(biv_rvantpost_vtx))
  rvant_circ_data = rvant_msh.generate_distancefield(biv_rvant_map.map_selection(biv_rvantpost_vtx),
                                                     end_vtx=biv_rvant_map.map_selection(biv_junc_vtx))

  # apex-to-base and endo-to-epi distance fields per ventricle
  lv_apba_data = lv_msh.generate_distancefield(biv_lv_map.map_selection(biv_lvapex_vtx),
                                               end_vtx=biv_lv_map.map_selection(biv_lvbase_vtx))
  lv_endo_data = lv_msh.generate_distancefield(biv_lv_map.map_selection(biv_lvendo_vtx),
                                               end_vtx=biv_lv_map.map_selection(biv_lvepi_vtx))
  rv_apba_data = rv_msh.generate_distancefield(biv_rv_map.map_selection(biv_rvapex_vtx),
                                               end_vtx=biv_rv_map.map_selection(biv_rvbase_vtx))
  rv_endo_data = rv_msh.generate_distancefield(biv_rv_map.map_selection(biv_rvendo_vtx),
                                               end_vtx=biv_rv_map.map_selection(biv_rvepi_vtx))

  # assemble the fields on the biv-mesh: rv data first, lv data inserted after
  uvc_apba_data = biv_rv_map.prolongate(rv_apba_data, default=-1.0)
  biv_lv_map.insert_data(lv_apba_data, uvc_apba_data)

  uvc_endo_data = biv_rv_map.prolongate(rv_endo_data, default=-1.0)
  biv_lv_map.insert_data(lv_endo_data, uvc_endo_data)

  # scale and shift the segment-wise fields to angles in radians
  uvc_circ_data = biv_rvpst_map.prolongate((rvpst_circ_data-1.0)*numpy.pi*0.5)
  biv_rvant_map.insert_data(rvant_circ_data*numpy.pi*0.5, uvc_circ_data)
  biv_lvpstfree_map.insert_data(lvpstfree_circ_data*numpy.pi*0.5-numpy.pi, uvc_circ_data)
  biv_lvpstsept_map.insert_data(lvpstsept_circ_data*numpy.pi*0.5-numpy.pi*0.5, uvc_circ_data)
  biv_lvantsept_map.insert_data(lvantsept_circ_data*numpy.pi*0.5, uvc_circ_data)
  biv_lvantfree_map.insert_data(lvantfree_circ_data*numpy.pi*0.5+numpy.pi*0.5, uvc_circ_data)

  # ventricle indicator: +1.0 by default, -1.0 on the lv-mesh
  uvc_vent_data = numpy.full(biv_msh.num_points, 1.0)
  lv_vent_data = numpy.full(lv_msh.num_points, -1.0)
  biv_lv_map.insert_data(lv_vent_data, uvc_vent_data)

  biv_uvcs = numpy.column_stack((uvc_apba_data, uvc_endo_data, uvc_circ_data, uvc_vent_data))
  uvcs_header = str(biv_msh.num_points)
  start = time.time()
  numpy.savetxt(exout_file('msh_biv.uvc'), biv_uvcs, header=uvcs_header, comments='')
  end = time.time()
  io_time += end-start

  if not compute_fibers:
    return 0.0, io_time

  # ===========================================================================
  # FIBER GENERATION
  # ===========================================================================

  # ---------------------------------------------------------------------------
  # CREATE DISTANCE-FIELD SOLUTIONS
  # ---------------------------------------------------------------------------

  biv_lvendo_dist = biv_msh.generate_distancefield(biv_lvendo_vtx)
  biv_lvendo_threshold = numpy.max(biv_lvendo_dist.take(biv_lvepi_vtx))+biv_lvendo_dist.max()*0.01
  biv_lvendo_dist_elem = biv_msh.interpolate_node2elem(biv_lvendo_dist)
  biv_lvendo_elem_sel = numpy.where(biv_lvendo_dist_elem < biv_lvendo_threshold)[0]
  bivbp_lvendo_elem_sel = bivbp_biv_map.map_selection(biv_lvendo_elem_sel, nodal_selection=False, map_forward=False)
  print('\n'+'#'*80)
  print('LV-endo threshold value: {}'.format(biv_lvendo_threshold))
  print('#'*80+'\n')

  biv_rvendo_dist = biv_msh.generate_distancefield(biv_rvendo_vtx)
  biv_rvendo_threshold = numpy.max(biv_rvendo_dist.take(biv_rvepi_vtx))+biv_rvendo_dist.max()*0.01
  biv_rvendo_dist_elem = biv_msh.interpolate_node2elem(biv_rvendo_dist)
  biv_rvendo_elem_sel = numpy.where(biv_rvendo_dist_elem < biv_rvendo_threshold)[0]
  bivbp_rvendo_elem_sel = bivbp_biv_map.map_selection(biv_rvendo_elem_sel, nodal_selection=False, map_forward=False)
  print('\n'+'#'*80)
  print('RV-endo threshold value: {}'.format(biv_rvendo_threshold))
  print('#'*80+'\n')

  bivbp_old_tags = numpy.copy(bivbp_msh.element_tags)

  # The temporary tags are built on a private copy and assigned through the
  # setter, so the result does not depend on whether the `element_tags`
  # getter hands out a view or a copy, and the saved original tags are never
  # modified in place.
  bivbp_tmp_tags = numpy.copy(bivbp_old_tags)
  bivbp_tmp_tags[bivbp_lvendo_elem_sel] = 5
  bivbp_msh.element_tags = bivbp_tmp_tags
  ops = ['5-31,41,46'] # extended lv-epi surface
  surf_meshes = bivbp_msh.extract_surface(setops=';'.join(ops), reindex_nodes=False,
                                          return_mapping=False)
  biv_extlvepi_vtx = bivbp_biv_map.map_selection(surf_meshes[0].extract_element_nodes())
  biv_lvendo_data = biv_msh.generate_distancefield(biv_extlvepi_vtx, end_vtx=biv_lvendo_vtx)
  biv_lvendo_data[biv_lvendo_dist >= biv_lvendo_threshold] = 0.0
  bivbp_msh.element_tags = bivbp_old_tags

  bivbp_tmp_tags = numpy.copy(bivbp_old_tags)
  bivbp_tmp_tags[bivbp_rvendo_elem_sel] = 8
  bivbp_msh.element_tags = bivbp_tmp_tags
  ops = ['8-36,41,46'] # extended rv-epi surface
  surf_meshes = bivbp_msh.extract_surface(setops=';'.join(ops), reindex_nodes=False,
                                          return_mapping=False)
  biv_extrvepi_vtx = bivbp_biv_map.map_selection(surf_meshes[0].extract_element_nodes())
  biv_rvendo_data = biv_msh.generate_distancefield(biv_extrvepi_vtx, end_vtx=biv_rvendo_vtx)
  biv_rvendo_data[biv_rvendo_dist >= biv_rvendo_threshold] = 0.0
  bivbp_msh.element_tags = bivbp_old_tags

  # combine the two endo fields and derive the epi field as the remainder to 1
  tmp = numpy.multiply(biv_lvendo_data, biv_rvendo_data)
  biv_lvendo_data = 2.0*biv_lvendo_data - numpy.power(biv_lvendo_data, 2.0) - tmp
  biv_rvendo_data = 2.0*biv_rvendo_data - numpy.power(biv_rvendo_data, 2.0) - tmp
  biv_epi_data = 1.0-(biv_lvendo_data+biv_rvendo_data)

  # ---------------------------------------------------------------------------
  # CREATE FIBERS
  # ---------------------------------------------------------------------------
  biv_apba_data = biv_msh.generate_distancefield(biv_lvapex_vtx, end_vtx=numpy.union1d(biv_lvbase_vtx, biv_rvbase_vtx))

  phi_lv = biv_lvendo_data
  phi_rv = biv_rvendo_data
  phi_epi = biv_epi_data
  psi_ab = biv_apba_data

  norm_thr = 1.0e-8
  _, grad_phi_lv = biv_msh.extract_gradient(phi_lv, normalize=True, norm_threshold=norm_thr)
  _, grad_phi_rv = biv_msh.extract_gradient(phi_rv, normalize=True, norm_threshold=norm_thr)
  _, grad_phi_epi = biv_msh.extract_gradient(phi_epi, normalize=True, norm_threshold=norm_thr)
  _, grad_psi_ab = biv_msh.extract_gradient(psi_ab, normalize=True, norm_threshold=norm_thr)

  # repeated element -> node -> element interpolation of the gradient fields
  for _ in range(5):
    grad_phi_lv = biv_msh.interpolate_node2elem(biv_msh.interpolate_elem2node(grad_phi_lv, normalize=True), normalize=True)
    grad_phi_rv = biv_msh.interpolate_node2elem(biv_msh.interpolate_elem2node(grad_phi_rv, normalize=True), normalize=True)
    grad_phi_epi = biv_msh.interpolate_node2elem(biv_msh.interpolate_elem2node(grad_phi_epi, normalize=True), normalize=True)
    grad_psi_ab = biv_msh.interpolate_node2elem(biv_msh.interpolate_elem2node(grad_psi_ab, normalize=True), normalize=True)

  phi_lv = biv_msh.interpolate_node2elem(phi_lv)
  phi_rv = biv_msh.interpolate_node2elem(phi_rv)
  phi_epi = biv_msh.interpolate_node2elem(phi_epi)
  psi_ab = biv_msh.interpolate_node2elem(psi_ab)

  start = time.time()
  elem_lon = assign_fibers_sheets(phi_epi, phi_lv, phi_rv, psi_ab,
                                  grad_phi_epi, grad_phi_lv, grad_phi_rv, grad_psi_ab,
                                  40.0, -50.0, -65.0, 25.0,
                                  num_workers=num_procs)
  end = time.time()
  fib_time = end-start
  print('\n\nFIBER COMPUTATION: {:.6f} sec'.format(fib_time))

  start = time.time()
  pymeshtool.save_fibers(elem_lon, exout_file('msh_biv.lon'))
  end = time.time()
  io_time += end-start

  return fib_time, io_time


if __name__ == '__main__':
  # the worker of the fiber assignment is a nested function that cannot be
  # pickled, so the worker processes have to be forked
  multiprocessing.set_start_method('fork', force=True)

  # parse the command line first so that `--help` works without the
  # environment variable being set
  parser = argparse.ArgumentParser()
  parser.add_argument('--refine-img', action='store_true', default=False,
                      help='Refine image before extracting mesh')
  parser.add_argument('--uvcs-only', action='store_true', default=False,
                      help='Compute UVCs only')
  args = parser.parse_args()

  # the number of threads/processes is taken from the environment
  num_procs = os.environ.get('OMP_NUM_THREADS', None)
  if num_procs is None:
    raise RuntimeError('Please set OMP_NUM_THREADS environment variable !')
  num_procs = int(num_procs)
  if num_procs < 1:
    raise RuntimeError('Please set number of threads to a value greater than 0 !')

  start = time.time()
  fib_time, io_time = main(num_procs, refine_img=args.refine_img, compute_fibers=not args.uvcs_only)
  end = time.time()
  tot_time = end-start
  print('Done, took {:.6f} sec'.format(tot_time))

  # append the timings (total, fibers, I/O) to the log file in the current
  # working directory; the context manager closes the file even on errors
  with open('pymt_wkflw_n{}.log'.format(num_procs), 'a') as fp:
    fp.write('{:.6f}  {:.6f}  {:.6f}\n'.format(tot_time, fib_time, io_time))