From 553b3aae7c0052c847493ef8c3dea9515012fddc Mon Sep 17 00:00:00 2001
From: wenxingfang <1473717798@qq.com>
Date: Fri, 16 Jun 2023 16:32:22 +0800
Subject: [PATCH] add TrackHeedSimTool for DC cell simulation

---
 Examples/options/tut_detsim_SDT_Heed.py       | 207 ++++++
 .../src/Edm4hepWriterAnaElemTool.cpp          |  19 +
 .../DetSimAna/src/Edm4hepWriterAnaElemTool.h  |   3 +
 Simulation/DetSimDedx/CMakeLists.txt          |  13 +
 .../DetSimDedx/src/TrackHeedSimTool.cpp       | 646 ++++++++++++++++++
 Simulation/DetSimDedx/src/TrackHeedSimTool.h  | 156 +++++
 .../include/DetSimInterface/IDedxSimTool.h    |   2 +
 cmake/FindOnnxRuntime.cmake                   |  35 +
 8 files changed, 1081 insertions(+)
 create mode 100644 Examples/options/tut_detsim_SDT_Heed.py
 create mode 100644 Simulation/DetSimDedx/src/TrackHeedSimTool.cpp
 create mode 100644 Simulation/DetSimDedx/src/TrackHeedSimTool.h
 create mode 100644 cmake/FindOnnxRuntime.cmake

diff --git a/Examples/options/tut_detsim_SDT_Heed.py b/Examples/options/tut_detsim_SDT_Heed.py
new file mode 100644
index 00000000..049c3c22
--- /dev/null
+++ b/Examples/options/tut_detsim_SDT_Heed.py
@@ -0,0 +1,207 @@
+#!/usr/bin/env python
+
+import os
+import sys
+# sys.exit(0)
+
+from Gaudi.Configuration import *
+
+##############################################################################
+# Random Number Svc
+##############################################################################
+from Configurables import RndmGenSvc, HepRndm__Engine_CLHEP__RanluxEngine_
+
+seed = [42]
+
+# rndmengine = HepRndm__Engine_CLHEP__RanluxEngine_() # The default engine in Gaudi
+rndmengine = HepRndm__Engine_CLHEP__HepJamesRandom_("RndmGenSvc.Engine") # The default engine in Geant4
+rndmengine.SetSingleton = True
+rndmengine.Seeds = seed
+
+rndmgensvc = RndmGenSvc("RndmGenSvc")
+rndmgensvc.Engine = rndmengine.name()
+
+
+##############################################################################
+# Event Data Svc
+##############################################################################
+from Configurables import k4DataSvc
+dsvc = k4DataSvc("EventDataSvc")
+
+
+##############################################################################
+# Geometry Svc
+##############################################################################
+
+# geometry_option = "CepC_v4-onlyTracker.xml"
+geometry_option = "det.xml"
+#geometry_option = "CepC_v4.xml"
+det_root = "DETDRIFTCHAMBERROOT"
+#det_root = "DETCEPCV4ROOT"#"DETDRIFTCHAMBERROOT"
+if not os.getenv(det_root):
+    print("Can't find the geometry. Please setup envvar %s."%det_root )
+    sys.exit(-1)
+
+geometry_path = os.path.join(os.getenv(det_root), "compact", geometry_option)
+if not os.path.exists(geometry_path):
+    print("Can't find the compact geometry file: %s"%geometry_path)
+    sys.exit(-1)
+
+from Configurables import GeomSvc
+geosvc = GeomSvc("GeomSvc")
+print('geometry_path=',geometry_path)
+geosvc.compact = geometry_path
+
+##############################################################################
+# Physics Generator
+##############################################################################
+from Configurables import GenAlgo
+from Configurables import GtGunTool
+from Configurables import StdHepRdr
+from Configurables import SLCIORdr
+from Configurables import HepMCRdr
+from Configurables import GenPrinter
+
+gun = GtGunTool("GtGunTool")
+# gun.Particles = ["pi+"]
+# gun.EnergyMins = [100.] # GeV
+# gun.EnergyMaxs = [100.] # GeV
+
+gun.Particles = ["e-"]
+
+# gun.PositionXs = [100.] # mm
+# gun.PositionYs = [100.] # mm
+# gun.PositionZs = [0.] # mm
+
+
+gun.EnergyMins = [10] # GeV
+gun.EnergyMaxs = [10] # GeV
+
+gun.ThetaMins = [80] # rad; 45deg
+gun.ThetaMaxs = [90.] # rad; 45deg
+
+gun.PhiMins = [0] # rad; 0deg
+gun.PhiMaxs = [360.] # rad; 360deg
+
+# stdheprdr = StdHepRdr("StdHepRdr")
+# stdheprdr.Input = "/cefs/data/stdhep/CEPC250/2fermions/E250.Pbhabha.e0.p0.whizard195/bhabha.e0.p0.00001.stdhep"
+
+# lciordr = SLCIORdr("SLCIORdr")
+# lciordr.Input = "/cefs/data/stdhep/lcio250/signal/Higgs/E250.Pbbh.whizard195/E250.Pbbh_X.e0.p0.whizard195/Pbbh_X.e0.p0.00001.slcio"
+
+# hepmcrdr = HepMCRdr("HepMCRdr")
+# hepmcrdr.Input = "example_UsingIterators.txt"
+
+genprinter = GenPrinter("GenPrinter")
+
+genalg = GenAlgo("GenAlgo")
+genalg.GenTools = ["GtGunTool"]
+# genalg.GenTools = ["StdHepRdr"]
+# genalg.GenTools = ["StdHepRdr", "GenPrinter"]
+# genalg.GenTools = ["SLCIORdr", "GenPrinter"]
+# genalg.GenTools = ["HepMCRdr", "GenPrinter"]
+
+##############################################################################
+# Detector Simulation
+##############################################################################
+from Configurables import DetSimSvc
+
+detsimsvc = DetSimSvc("DetSimSvc")
+
+# from Configurables import ExampleAnaElemTool
+# example_anatool = ExampleAnaElemTool("ExampleAnaElemTool")
+
+from Configurables import DetSimAlg
+
+detsimalg = DetSimAlg("DetSimAlg")
+detsimalg.RandomSeeds = seed
+
+if int(os.environ.get("VIS", 0)):
+    detsimalg.VisMacs = ["vis.mac"]
+
+detsimalg.RunCmds = [
+#    "/tracking/verbose 1",
+]
+
+from Configurables import DummyFastSimG4Tool
+dummy_fastsim_tool = DummyFastSimG4Tool("DummyFastSimG4Tool")
+
+detsimalg.FastSimG4Tools = [
+#    "DummyFastSimG4Tool"
+]
+
+detsimalg.AnaElems = [
+    # example_anatool.name()
+    # "ExampleAnaElemTool"
+    "Edm4hepWriterAnaElemTool"
+]
+detsimalg.RootDetElem = "WorldDetElemTool"
+
+from Configurables import AnExampleDetElemTool
+example_dettool = AnExampleDetElemTool("AnExampleDetElemTool")
+
+from Configurables import CalorimeterSensDetTool
+from Configurables import DriftChamberSensDetTool
+
+calo_sensdettool = CalorimeterSensDetTool("CalorimeterSensDetTool")
+driftchamber_sensdettool = DriftChamberSensDetTool("DriftChamberSensDetTool")
+
+#dedxoption = "DummyDedxSimTool"
+#dedxoption = "BetheBlochEquationDedxSimTool"
+dedxoption = "TrackHeedSimTool"
+
+driftchamber_sensdettool.DedxSimTool = dedxoption
+
+from Configurables import DummyDedxSimTool
+from Configurables import BetheBlochEquationDedxSimTool
+from Configurables import TrackHeedSimTool
+
+if dedxoption == "DummyDedxSimTool":
+    dedx_simtool = DummyDedxSimTool("DummyDedxSimTool")
+elif dedxoption == "BetheBlochEquationDedxSimTool":
+    dedx_simtool = BetheBlochEquationDedxSimTool("BetheBlochEquationDedxSimTool")
+    dedx_simtool.material_Z = 2
+    dedx_simtool.material_A = 4
+    dedx_simtool.scale = 10
+    dedx_simtool.resolution = 0.0001
+elif dedxoption == "TrackHeedSimTool":
+    dedx_simtool = TrackHeedSimTool("TrackHeedSimTool")
+    dedx_simtool.only_primary = False#True
+    dedx_simtool.use_max_step = True#True
+    dedx_simtool.max_step = 1#mm
+    #dedx_simtool.he   = 50
+    #dedx_simtool.isob = 50
+    #dedx_simtool.gas_file ="/junofs/users/wxfang/MyGit/tmp/check_G4FastSim_20210121/CEPCSW/Digitisers/DigiGarfield/He_50_isobutane_50.gas" 
+    dedx_simtool.he   = 90
+    dedx_simtool.isob = 10
+    #dedx_simtool.gas_file ="/junofs/users/wxfang/MyGit/tmp/check_G4FastSim_20210121/CEPCSW/Digitisers/DigiGarfield/he_90_isobutane_10.gas" 
+    #dedx_simtool.IonMobility_file ="/junofs/users/wxfang/MyGit/tmp/check_G4FastSim_20210121/CEPCSW/Digitisers/DigiGarfield/IonMobility_He+_He.txt" 
+    dedx_simtool.gas_file         ="he_90_isobutane_10.gas"
+    dedx_simtool.IonMobility_file ="IonMobility_He+_He.txt"
+    dedx_simtool.save_mc = True##IF this is False then ... 
+    dedx_simtool.debug = False 
+    dedx_simtool.sim_pulse = True
+    #dedx_simtool.model='/junofs/users/wxfang/MyGit/tmp/fork_cepcsw_20220418/CEPCSW/Digitisers/SimCurrentONNX/src/model_test.onnx'
+    #dedx_simtool.model='/junofs/users/wxfang/MyGit/tmp/fork_cepcsw_20220418/CEPCSW/Digitisers/SimCurrentONNX/src/model_90He10C4H10_18mm.onnx'
+    dedx_simtool.model='model_90He10C4H10_18mm.onnx'
+    dedx_simtool.batchsize = 100
+
+##############################################################################
+# POD I/O
+##############################################################################
+from Configurables import PodioOutput
+out = PodioOutput("outputalg")
+out.filename = "detsim_heed.root"
+out.outputCommands = ["keep *"]
+
+##############################################################################
+# ApplicationMgr
+##############################################################################
+
+from Configurables import ApplicationMgr
+ApplicationMgr( TopAlg = [genalg, detsimalg, out],
+                EvtSel = 'NONE',
+                EvtMax = 20,
+                ExtSvc = [rndmengine, rndmgensvc, dsvc, geosvc],
+                OutputLevel=INFO
+)
diff --git a/Simulation/DetSimAna/src/Edm4hepWriterAnaElemTool.cpp b/Simulation/DetSimAna/src/Edm4hepWriterAnaElemTool.cpp
index cfbf31df..8e121bac 100644
--- a/Simulation/DetSimAna/src/Edm4hepWriterAnaElemTool.cpp
+++ b/Simulation/DetSimAna/src/Edm4hepWriterAnaElemTool.cpp
@@ -14,6 +14,7 @@
 #include "DDG4/Geant4HitCollection.h"
 #include "DDG4/Geant4Data.h"
 #include "DetSimSD/Geant4Hits.h"
+#include <DetSimInterface/IDedxSimTool.h>
 
 DECLARE_COMPONENT(Edm4hepWriterAnaElemTool)
 
@@ -42,6 +43,7 @@ Edm4hepWriterAnaElemTool::BeginOfRunAction(const G4Run*) {
     } else {
         error() << "Failed to find GeomSvc." << endmsg;
     }
+
 }
 
 void
@@ -87,6 +89,14 @@ Edm4hepWriterAnaElemTool::BeginOfEventAction(const G4Event* anEvent) {
     // reset
     m_track2primary.clear();
 
+    auto SimPIonCol =  m_SimPrimaryIonizationCol.createAndPut();
+    ToolHandleArray<IDedxSimTool> tmp_m_dedx_tools;
+    tmp_m_dedx_tools.push_back("TrackHeedSimTool");
+    for (auto dedxtool: tmp_m_dedx_tools) {
+        debug() << "reset dedx_tool:" <<dedxtool << endmsg;
+        dedxtool->reset();
+    }
+ 
 }
 
 void
@@ -94,6 +104,15 @@ Edm4hepWriterAnaElemTool::EndOfEventAction(const G4Event* anEvent) {
 
     msg() << "mcCol size (after simulation) : " << mcCol->size() << endmsg;
     // save all data
+    auto SimPrimaryIonizationCol =  m_SimPrimaryIonizationCol.get();
+    msg() << "SimPrimaryIonizationCol size ="<<SimPrimaryIonizationCol->size()<<endmsg;
+    ToolHandleArray<IDedxSimTool> tmp_m_dedx_tools;
+    tmp_m_dedx_tools.push_back("TrackHeedSimTool");
+    for (auto dedxtool: tmp_m_dedx_tools) {
+        debug() << "call endOfEvent() for dedx_tool:" << dedxtool << endmsg;
+        dedxtool->endOfEvent();
+    }
+ 
 
     // create collections.
     auto trackercols = m_trackerCol.createAndPut();
diff --git a/Simulation/DetSimAna/src/Edm4hepWriterAnaElemTool.h b/Simulation/DetSimAna/src/Edm4hepWriterAnaElemTool.h
index e4415b17..60cc930a 100644
--- a/Simulation/DetSimAna/src/Edm4hepWriterAnaElemTool.h
+++ b/Simulation/DetSimAna/src/Edm4hepWriterAnaElemTool.h
@@ -15,6 +15,7 @@
 #include "edm4hep/SimTrackerHitCollection.h"
 #include "edm4hep/SimCalorimeterHitCollection.h"
 #include "edm4hep/CaloHitContributionCollection.h"
+#include "edm4hep/SimPrimaryIonizationClusterCollection.h"
 
 class Edm4hepWriterAnaElemTool: public extends<AlgTool, IAnaElemTool> {
 
@@ -129,6 +130,8 @@ private:
             "DriftChamberHitsCollection", 
             Gaudi::DataHandle::Writer, this};
 
+    // for ionized electron
+    DataHandle<edm4hep::SimPrimaryIonizationClusterCollection>    m_SimPrimaryIonizationCol{"SimPrimaryIonizationClusterCollection", Gaudi::DataHandle::Writer, this};
 
 private:
     // in order to associate the hit contribution with the primary track,
diff --git a/Simulation/DetSimDedx/CMakeLists.txt b/Simulation/DetSimDedx/CMakeLists.txt
index fb8e4d3a..03bf40d9 100644
--- a/Simulation/DetSimDedx/CMakeLists.txt
+++ b/Simulation/DetSimDedx/CMakeLists.txt
@@ -1,15 +1,28 @@
 
 find_package(Geant4 REQUIRED ui_all vis_all)
 include(${Geant4_USE_FILE})
+find_package(Garfield REQUIRED)
+message("libonnxruntime ${OnnxRuntime_LIBRARY}")
+find_package(OnnxRuntime REQUIRED)
 
+message("libonnxruntime ${OnnxRuntime_LIBRARY}")
 gaudi_add_module(DetSimDedx
                  SOURCES src/DummyDedxSimTool.cpp
                          src/BetheBlochEquationDedxSimTool.cpp
                          src/GFDndxSimTool.cpp
+                         src/TrackHeedSimTool.cpp
 
                  LINK DetSimInterface
+                      DetInterface
+                      DetSegmentation
                       ${DD4hep_COMPONENT_LIBRARIES}
                       Gaudi::GaudiKernel
                       EDM4HEP::edm4hep EDM4HEP::edm4hepDict
+                      k4FWCore::k4FWCore
+                      Garfield::Garfield
+                      ${OnnxRuntime_LIBRARY}
+                      #/cvmfs/sft.cern.ch/lcg/views/LCG_103/x86_64-centos7-gcc11-opt/lib/libonnxruntime.so
+                      ${CLHEP_LIBRARIES}
+ 
 )
 
diff --git a/Simulation/DetSimDedx/src/TrackHeedSimTool.cpp b/Simulation/DetSimDedx/src/TrackHeedSimTool.cpp
new file mode 100644
index 00000000..fb03cefe
--- /dev/null
+++ b/Simulation/DetSimDedx/src/TrackHeedSimTool.cpp
@@ -0,0 +1,646 @@
+#include "TrackHeedSimTool.h"
+#include "G4Step.hh"
+#include "G4SystemOfUnits.hh"
+#include <G4VProcess.hh>
+#include "G4TransportationManager.hh"
+#include "G4Navigator.hh"
+#include "G4VPhysicalVolume.hh"
+#include "G4LogicalVolume.hh"
+#include <G4VTouchable.hh>
+#include "DDG4/Geant4Converter.h"
+#include "DetSegmentation/GridDriftChamber.h"
+
+
+#include <math.h>
+#include <cmath>
+#include <iostream>
+#include <time.h>
+#include "CLHEP/Random/RandGauss.h"
+
+
+DECLARE_COMPONENT(TrackHeedSimTool)
+
+double TrackHeedSimTool::dedx(const edm4hep::MCParticle& mc) {return 0;}
+double TrackHeedSimTool::dndx(double betagamma) {return 0;}
+
+double TrackHeedSimTool::dedx(const G4Step* Step)
+{
+
+    clock_t t0 = clock();
+    double de = 0;
+    float cm_to_mm = 10; 
+    G4Step* aStep = const_cast<G4Step*>(Step);
+    G4Track* g4Track  =  aStep->GetTrack();
+    int pdg_code        = g4Track->GetParticleDefinition()->GetPDGEncoding();
+    G4double pdg_mass   = g4Track->GetParticleDefinition()->GetPDGMass();
+    G4double pdg_charge = g4Track->GetParticleDefinition()->GetPDGCharge();
+    const G4VProcess* creatorProcess = g4Track->GetCreatorProcess();
+    const G4String tmp_str_pro = (creatorProcess !=0) ? creatorProcess->GetProcessName() : "normal";
+    G4double gammabeta=aStep->GetPreStepPoint()->GetBeta() * aStep->GetPreStepPoint()->GetGamma();
+    if(g4Track->GetParticleDefinition()->GetPDGCharge() ==0) return 0;//skip neutral particle 
+    if(gammabeta<0.01)return m_eps;//too low momentum
+    if(m_only_primary.value() && g4Track->GetParentID() != 0) return m_eps;
+    if(g4Track->GetKineticEnergy() <=0) return 0;
+    if(pdg_code == 11 && (tmp_str_pro=="phot" || tmp_str_pro=="hIoni" || tmp_str_pro=="eIoni" || tmp_str_pro=="muIoni" || tmp_str_pro=="ionIoni" ) ) return m_eps;//skip the electron produced by Ioni, because it is already simulated by TrackHeed
+    if(m_particle_map.find(pdg_code) == m_particle_map.end() ) return m_eps;
+    edm4hep::SimPrimaryIonizationClusterCollection* SimPrimaryIonizationCol = nullptr;
+    edm4hep::MCParticleCollection* mcCol = nullptr;
+    try{
+        SimPrimaryIonizationCol =  const_cast<edm4hep::SimPrimaryIonizationClusterCollection*>(m_SimPrimaryIonizationCol.get());
+        mcCol = const_cast<edm4hep::MCParticleCollection*>(m_mc_handle.get());
+    }
+    catch(...){
+        G4cout<<"Error! Can't find collection in event, please check it have been createAndPut() in Begin of event"<<G4endl;
+        G4cout<<"SimPrimaryIonizationCol="<<SimPrimaryIonizationCol<<",mcCol="<<mcCol<<G4endl;
+        throw "stop here!";
+    }
+
+    G4double track_KE   = aStep->GetPreStepPoint()->GetKineticEnergy();
+    G4double track_time = aStep->GetPreStepPoint()->GetGlobalTime();
+    G4double track_dx     = aStep->GetPreStepPoint()->GetMomentumDirection ().x();
+    G4double track_dy     = aStep->GetPreStepPoint()->GetMomentumDirection ().y();
+    G4double track_dz     = aStep->GetPreStepPoint()->GetMomentumDirection ().z();
+    G4double track_length = aStep->GetStepLength();
+    G4double position_x   = aStep->GetPreStepPoint()->GetPosition().x();
+    G4double position_y   = aStep->GetPreStepPoint()->GetPosition().y();
+    G4double position_z   = aStep->GetPreStepPoint()->GetPosition().z();
+    int track_ID = g4Track->GetTrackID();
+    int Parent_ID = g4Track->GetParentID();
+    bool update_ke = true;
+    if(m_use_max_step.value()){
+        bool do_sim = false;
+        if(m_isFirst){
+            m_pre_x  = aStep->GetPreStepPoint()->GetPosition().x();
+            m_pre_y  = aStep->GetPreStepPoint()->GetPosition().y();
+            m_pre_z  = aStep->GetPreStepPoint()->GetPosition().z();
+            m_pre_dx = aStep->GetPreStepPoint()->GetMomentumDirection().x();
+            m_pre_dy = aStep->GetPreStepPoint()->GetMomentumDirection().y();
+            m_pre_dz = aStep->GetPreStepPoint()->GetMomentumDirection().z();
+            m_pre_t  = aStep->GetPreStepPoint()->GetGlobalTime();
+            m_post_point = aStep->GetPostStepPoint();
+            m_total_range += track_length;
+            m_current_track_ID = g4Track->GetTrackID();
+            m_current_Parent_ID = g4Track->GetParentID();
+            m_pdg_code = g4Track->GetParticleDefinition()->GetPDGEncoding(); 
+            m_isFirst = false;    
+            m_pa_KE =  aStep->GetPreStepPoint()->GetKineticEnergy();
+        }
+        else{
+        
+            if(g4Track->GetTrackID() != m_current_track_ID){
+                do_sim = true;
+                m_change_track = true;
+                update_ke = false;
+            }
+            else{
+                m_post_point = aStep->GetPostStepPoint();
+                m_total_range += track_length;
+            }
+        }
+        if(m_total_range/CLHEP::mm >= m_max_step.value()){
+            do_sim = true;
+        }
+        if(do_sim){
+            track_KE = m_pa_KE;
+            pdg_code = m_pdg_code;
+            track_length = m_total_range;
+            track_time = m_pre_t;
+            track_dx = m_pre_dx;
+            track_dy = m_pre_dy;
+            track_dz = m_pre_dz;
+            position_x = m_pre_x;
+            position_y = m_pre_y;
+            position_z = m_pre_z;
+            track_ID = m_current_track_ID;
+            Parent_ID = m_current_Parent_ID;
+            if(m_change_track){
+                m_pre_x  = aStep->GetPreStepPoint()->GetPosition().x();
+                m_pre_y  = aStep->GetPreStepPoint()->GetPosition().y();
+                m_pre_z  = aStep->GetPreStepPoint()->GetPosition().z();
+                m_pre_dx = aStep->GetPreStepPoint()->GetMomentumDirection().x();
+                m_pre_dy = aStep->GetPreStepPoint()->GetMomentumDirection().y();
+                m_pre_dz = aStep->GetPreStepPoint()->GetMomentumDirection().z();
+                m_pre_t  = aStep->GetPreStepPoint()->GetGlobalTime();
+                m_post_point = aStep->GetPostStepPoint(); 
+                m_total_range = aStep->GetStepLength();
+                m_current_track_ID = g4Track->GetTrackID();
+                m_current_Parent_ID = g4Track->GetParentID();
+                m_pdg_code = g4Track->GetParticleDefinition()->GetPDGEncoding(); 
+                m_change_track = false;
+            }
+            else{
+                m_pre_x  = aStep->GetPostStepPoint()->GetPosition().x();
+                m_pre_y  = aStep->GetPostStepPoint()->GetPosition().y();
+                m_pre_z  = aStep->GetPostStepPoint()->GetPosition().z();
+                m_pre_dx = aStep->GetPostStepPoint()->GetMomentumDirection().x();
+                m_pre_dy = aStep->GetPostStepPoint()->GetMomentumDirection().y();
+                m_pre_dz = aStep->GetPostStepPoint()->GetMomentumDirection().z();
+                m_pre_t  = aStep->GetPostStepPoint()->GetGlobalTime();
+                m_total_range = 0;
+            }
+        }
+        else return m_eps;
+    }
+
+    float init_x = 10;//cm
+    float init_y = -10;//cm
+    /*
+    if(pdg_code == 11 && track_KE/CLHEP::keV < m_delta_threshold.value()){
+        int nc = 0, ni=0;
+        m_track->TransportDeltaElectron(init_x, init_y, 0, track_time/CLHEP::ns, track_KE/CLHEP::eV, track_dx, track_dy, track_dz, nc, ni);
+        for (int j = 0; j < nc; ++j) {
+            double xe = 0., ye = 0., ze = 0., te = 0., ee = 0.;
+            double dx = 0., dy = 0., dz = 0.;
+            m_track->GetElectron(j, xe, ye, ze, te, ee, dx, dy, dz);
+            auto ehit  = SimIonizationCol->create();
+            ehit.setTime(te);
+            double epos[3] = { cm_to_mm*( (xe - init_x)+position_x/CLHEP::cm) , cm_to_mm*((ye - init_y)+position_y/CLHEP::cm), cm_to_mm*(ze + position_z/CLHEP::cm)};
+            ehit.setPosition(edm4hep::Vector3d(epos));
+            ehit.setType(11);
+        }
+        g4Track->SetTrackStatus(fStopAndKill);
+        return 0;
+    }
+    */
+    clock_t t01 = clock();
+    //cmp.SetMagneticField(0., 0., -3.);
+    m_track->SetParticle(m_particle_map[pdg_code]);
+    
+    bool change_KE = false;
+    if( abs(m_previous_KE -(track_KE/CLHEP::eV) )/(track_KE/CLHEP::eV) > m_change_threshold.value()) {
+        change_KE = true;
+        m_previous_KE = track_KE/CLHEP::eV;
+    }
+    
+    bool change_ID = false;
+    if(m_previous_track_ID != track_ID ){
+        change_ID = true;
+        m_previous_track_ID = track_ID; 
+        m_previous_KE = track_KE/CLHEP::eV;
+    }
+    if(change_ID || change_KE ){
+        m_track->SetKineticEnergy(track_KE/CLHEP::eV);
+    }
+    
+  
+    m_track->EnableOneStepFly(true);
+    m_track->SetSteppingLimits( track_length/CLHEP::cm, 1000, 0.1, 0.2);
+    clock_t t012 = clock();
+    m_track->NewTrack(init_x, init_y, 0, track_time/CLHEP::ns, track_dx, track_dy, track_dz);//cm
+    double xc = 0., yc = 0., zc = 0., tc = 0., ec = 0., extra = 0.;
+    int nc = 0;
+    int ic = 0;
+    int first=true;
+    clock_t t02 = clock();
+    while (m_track->GetCluster(xc, yc, zc, tc, nc, ec, extra)) {
+        //auto chit = SimHitCol->create();
+        auto chit = SimPrimaryIonizationCol->create();
+        chit.setTime(tc);
+        double cpos[3] = { cm_to_mm*( (xc - init_x)+position_x/CLHEP::cm) , cm_to_mm*((yc - init_y)+position_y/CLHEP::cm), cm_to_mm*(zc + position_z/CLHEP::cm)};
+        chit.setPosition(edm4hep::Vector3d(cpos));
+        //float cmom[3]  = {0,0,0};
+        //getMom(ec, 1, 0, 0, cmom);//FIXME direction is not important?
+        chit.setType(0);//default
+        if(m_save_cellID) chit.setCellID( getCellID(cpos[0], cpos[1], cpos[2]) );
+        if(m_save_mc && Parent_ID == 0 && track_ID <= mcCol->size() && mcCol ){ 
+            chit.setMCParticle(  mcCol->at(track_ID-1) );
+            //std::cout<<"mc obj index="<<mcCol->at(track_ID-1).getObjectID().index<<std::endl;
+            //std::cout<<"mc obj index1="<<chit.getMCParticle().getObjectID().index<<std::endl;
+        }
+        de += ec;
+        for (int j = 0; j < nc; ++j) {
+            double xe = 0., ye = 0., ze = 0., te = 0., ee = 0.;
+            double dx = 0., dy = 0., dz = 0.;
+            m_track->GetElectron(j, xe, ye, ze, te, ee, dx, dy, dz);
+            //auto ehit = SimHitCol->create();
+            //auto ehit = SimIonizationCol->create();
+            //ehit.setPrimaryIonization(chit);
+            chit.addToElectronTime(te);
+            //ehit.setTime(te);
+            double epos[3] = { cm_to_mm*( (xe - init_x)+position_x/CLHEP::cm) , cm_to_mm*((ye - init_y)+position_y/CLHEP::cm), cm_to_mm*(ze + position_z/CLHEP::cm)};
+            //ehit.setPosition(edm4hep::Vector3d(epos));
+            //ehit.setPosition(edm4hep::Vector3d(epos));
+            chit.addToElectronPosition(edm4hep::Vector3d(epos));
+            //if(m_save_mc && Parent_ID == 0 && track_ID <= mcCol->size() && mcCol ){ 
+            //    ehit.setMCParticle(  mcCol->at(track_ID-1) );
+                //ehit.setMcParticleObjID( mcCol->at(track_ID-1).id() );
+                //ehit.setMcParticleColID( mcCol->at(track_ID-1).getObjectID().collectionID );
+            //}
+            /* //no sense of conductor electron
+            float emom[3] = {0,0,0};
+            getMom(ee, dx, dy, dz, emom); 
+            ehit.setMomentum(edm4hep::Vector3f(emom));
+            */
+            //if(m_save_cellID) ehit.setCellID( getCellID(epos[0], epos[1], epos[2]) );
+            if(m_save_cellID) chit.addToElectronCellID( getCellID(epos[0], epos[1], epos[2]) );
+            //ehit.setQuality(2);
+            //ehit.setType(0);//default
+        }
+    }
+    double Dedx = (de*1e-6/(track_length/CLHEP::cm) ) ;//MeV/cm
+    double new_KE = track_KE/CLHEP::MeV - de*1e-6; 
+    if( update_ke ){
+        g4Track->SetKineticEnergy ( new_KE*CLHEP::MeV );
+        aStep->GetPostStepPoint()->SetKineticEnergy ( new_KE*CLHEP::MeV );
+        m_pa_KE = new_KE;
+    }
+    else{
+        m_pa_KE = aStep->GetPreStepPoint()->GetKineticEnergy(); 
+    }
+    m_tot_edep += de;
+    m_tot_length += track_length;
+    return Dedx;
+}
+
+
+long long TrackHeedSimTool::getCellID(float x, float y, float z)
+{
+    float MM_2_CM = 0.1;
+    G4Navigator* gNavigator = G4TransportationManager::GetTransportationManager()->GetNavigatorForTracking();
+    G4ThreeVector global(x,y,z);
+    dd4hep::sim::Geant4VolumeManager volMgr = dd4hep::sim::Geant4Mapping::instance().volumeManager();
+    G4VPhysicalVolume* pv = gNavigator->LocateGlobalPointAndSetup( global, 0, true);
+    if(!pv) return 0;
+    G4TouchableHistory *hist = gNavigator->CreateTouchableHistory();
+    dd4hep::VolumeID volID  = volMgr.volumeID(hist);
+    const G4AffineTransform & affine = gNavigator->GetGlobalToLocalTransform();
+    G4ThreeVector local = affine.TransformPoint(global);
+    dd4hep::Position loc(local.x()*MM_2_CM, local.y()*MM_2_CM, local.z()*MM_2_CM);
+    dd4hep::Position glob(global.x()*MM_2_CM, global.y()*MM_2_CM, global.z()*MM_2_CM);
+    dd4hep::VolumeID cID = m_segmentation->cellID(loc,glob,volID);
+    
+    if(m_debug){
+        TVector3 Wstart(0,0,0);
+        TVector3 Wend  (0,0,0);
+        m_segmentation->cellposition(cID, Wstart, Wend);
+        std::cout<<"Name="<<pv->GetName()<<",CopyNo="<<pv->GetCopyNo()<<",cID="<<cID<<",volID="<<volID<<",glob="<<glob<<",loc="<<loc<<",ws_x="<<Wstart.X()<<",y="<<Wstart.Y()<<",z="<<Wstart.Z()<<",we_x="<<Wend.X()<<",y="<<Wend.Y()<<",z="<<Wend.Z()<<std::endl;
+    }
+    delete hist;
+    return cID;
+}
+
+void TrackHeedSimTool::getMom(float ee, float dx, float dy,float dz, float mom[3])
+{
+    double tot_E = 0.511*1e6 + ee;//eV
+    double Mom = sqrt(tot_E*tot_E - pow(0.511*1e6,2) );
+    double mom_direction =  sqrt(dx*dx + dy*dy + dz*dz);
+    if (mom_direction == 0){
+        mom[0] = 0;
+        mom[1] = 0;
+        mom[2] = 0;
+    }
+    else{
+        double scale = Mom/mom_direction;
+        mom[0] = scale*dx/1e9;
+        mom[1] = scale*dy/1e9;
+        mom[2] = scale*dz/1e9;
+    }
+}
+
+StatusCode TrackHeedSimTool::initialize()
+{
+
+  m_geosvc = service<IGeomSvc>("GeomSvc");
+  if ( !m_geosvc )  throw "TrackHeedSimTool :Failed to find GeomSvc ...";
+  m_dd4hep = m_geosvc->lcdd();
+  if ( !m_dd4hep )  throw "TrackHeedSimTool :Failed to get dd4hep::Detector ...";
+  m_readout = new dd4hep::Readout( m_dd4hep->readout(m_readout_name) );
+  if ( !m_readout )  throw "TrackHeedSimTool :Failed to get readout ...";
+  m_segmentation = dynamic_cast<dd4hep::DDSegmentation::GridDriftChamber*>(m_readout->segmentation().segmentation());
+  if ( !m_segmentation )  throw "TrackHeedSimTool :Failed to get segmentation ...";
+
+  m_particle_map[ 11] = "e-";
+  m_particle_map[-11] = "e+";
+  m_particle_map[ 13] = "mu-";
+  m_particle_map[-13] = "mu+";
+  m_particle_map[ 211] = "pi+";
+  m_particle_map[-211] = "pi-";
+  m_particle_map[ 321] = "K+";
+  m_particle_map[-321] = "K-";
+  m_particle_map[2212] = "p";
+  m_particle_map[-2212] = "pbar";
+  m_particle_map[700201] = "d";
+  m_particle_map[700202] = "alpha";
+
+  m_gas.SetComposition("he", m_he,"isobutane", m_isob);
+  m_gas.SetTemperature(293.15);
+  m_gas.SetPressure(760.0);
+  m_gas.SetMaxElectronEnergy(200.);
+  m_gas.EnablePenningTransfer(0.44, 0.0, "He");
+  m_gas.LoadGasFile(m_gas_file.value());
+  m_gas.LoadIonMobility(m_IonMobility.value());
+  //std::this_thread::sleep_for(std::chrono::milliseconds(m_delay_time));
+  //m_gas.LoadGasFile("/junofs/users/wxfang/MyGit/tmp/check_G4FastSim_20210121/CEPCSW/Digitisers/DigiGarfield/He_50_isobutane_50.gas");
+  //m_gas.LoadIonMobility("/junofs/users/wxfang/MyGit/tmp/check_G4FastSim_20210121/CEPCSW/Digitisers/DigiGarfield/IonMobility_He+_He.txt");
+  /*
+  m_gas.SetComposition("he", 90.,"isobutane", 10.);  // cepc gas
+  m_gas.SetPressure(760.0);
+  m_gas.SetTemperature(293.15);
+  m_gas.SetFieldGrid(100., 100000., 20, true);
+  m_gas.GenerateGasTable(10);
+  m_gas.WriteGasFile("he_90_isobutane_10.gas");
+  */
+
+  cmp.SetMedium(&m_gas);
+  // Field Wire radius [cm]
+  const double rFWire = 110.e-4;
+  // Signa Wire radius [cm]
+  const double rSWire = 25.e-4;
+  // Cell radius [cm]
+  float rCell = 50;//As the ionization process is almost not effected by cell geometry and wire voltage. Here the radius is to make sure the ionization process is completed.
+  // Voltages
+  const double vSWire = 2000.;
+  const double vFWire = 0.;
+  // Add the signal wire in the centre.
+  cmp.AddWire(0, 0, 2 * rSWire, vSWire, "s");
+  // Add the field wire around the signal wire.
+  cmp.AddWire(-rCell, -rCell, 2 * rFWire, vFWire, "f");
+  cmp.AddWire(    0., -rCell, 2 * rFWire, vFWire, "f");
+  cmp.AddWire( rCell, -rCell, 2 * rFWire, vFWire, "f");
+  cmp.AddWire(-rCell,     0., 2 * rFWire, vFWire, "f");
+  cmp.AddWire( rCell,     0., 2 * rFWire, vFWire, "f");
+  cmp.AddWire(-rCell,  rCell, 2 * rFWire, vFWire, "f");
+  cmp.AddWire(    0.,  rCell, 2 * rFWire, vFWire, "f");
+  cmp.AddWire( rCell,  rCell, 2 * rFWire, vFWire, "f");
+  if(m_BField !=0 ) cmp.SetMagneticField(0., 0., m_BField);
+  cmp.AddReadout("s");
+
+  
+  ///
+  /// Make a sensor.
+  ///
+  m_sensor = new Sensor(); 
+  m_sensor->AddComponent(&cmp);
+  m_sensor->AddElectrode(&cmp, "s");
+  // Set the signal time window. [ns]
+  const double tstep = 0.5;
+  const double tmin = -0.5 * 0.5;
+  const unsigned int nbins = 1000;
+  m_sensor->SetTimeWindow(tmin, tstep, nbins);
+  m_sensor->ClearSignal();
+
+
+
+  m_track = new Garfield::TrackHeed();
+  //track->EnableDebugging();
+  m_track->SetSensor(m_sensor);
+  m_track->EnableDeltaElectronTransport();
+  //track->DisableDeltaElectronTransport();
+  m_track->EnableMagneticField();
+  m_track->EnableElectricField();//almost no effect here
+
+  m_current_track_ID = 0;
+  m_previous_track_ID =0;
+  m_previous_KE = 0;
+  m_current_Parent_ID = -1;
+  m_change_track = false;
+  m_total_range = 0;
+  m_isFirst = false;
+  m_tot_edep = 0;
+  m_tot_length = 0;
+  m_pa_KE =0;
+  m_pdg_code = 0;
+  m_pre_x  = 0;
+  m_pre_y  = 0;
+  m_pre_z  = 0;
+  m_pre_dx = 0;
+  m_pre_dy = 0;
+  m_pre_dz = 0;
+  m_pre_t  = 0;
+
+
+   // for NN pulse simulation//
+   m_env = std::make_shared<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "ENV");
+   m_seesion_options = std::make_shared<Ort::SessionOptions>();
+   m_seesion_options->SetIntraOpNumThreads(m_intra_op_nthreads);
+   m_seesion_options->SetInterOpNumThreads(m_inter_op_nthreads);
+   if(m_debug) std::cout << "before load model " << m_model_file.value() << std::endl;
+   m_session = std::make_shared<Ort::Session>(*m_env, m_model_file.value().c_str(), *m_seesion_options);
+   if(m_debug) std::cout << "after load model " << m_model_file.value() << std::endl;
+   // lambda function to print the dims.
+   auto dims_str = [&](const auto& dims) {
+      return std::accumulate(dims.begin(), dims.end(), std::to_string(dims[0]),
+                             [](const std::string& a, int64_t b){
+                                 return a + "x" + std::to_string(b);
+                             });
+   };
+   // prepare the input
+   auto num_input_nodes = m_session->GetInputCount();
+   if(m_debug) std::cout << "num_input_nodes: " << num_input_nodes << std::endl;
+   for (size_t i = 0; i < num_input_nodes; ++i) {
+      auto name = m_session->GetInputNameAllocated(i, m_allocator);
+      m_inputNodeNameAllocatedStrings.push_back(std::move(name));
+      m_input_node_names.push_back(m_inputNodeNameAllocatedStrings.back().get());
+
+      Ort::TypeInfo type_info  = m_session->GetInputTypeInfo(i);
+      auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
+      auto dims = tensor_info.GetShape();
+      //dims[0] = 1; //wxfang, FIXME, if it is -1 (dynamic axis), need overwrite it manually
+      dims[0] = 10; //wxfang, FIXME, if it is -1 (dynamic axis), need overwrite it manually
+      m_input_node_dims.push_back(dims);
+
+
+      if(m_debug) std::cout<< "[" << i << "]"
+              << " input_name: " << m_inputNodeNameAllocatedStrings.back().get()
+              << " ndims: " << dims.size()
+              << " dims: " << dims_str(dims)
+              << std::endl;
+   }
+   // prepare the output
+   size_t num_output_nodes = m_session->GetOutputCount();
+   for(std::size_t i = 0; i < num_output_nodes; i++) {
+       auto output_name = m_session->GetOutputNameAllocated(i, m_allocator);
+       m_outputNodeNameAllocatedStrings.push_back(std::move(output_name));
+       m_output_node_names.push_back(m_outputNodeNameAllocatedStrings.back().get());
+       Ort::TypeInfo type_info        = m_session->GetOutputTypeInfo(i);
+       auto tensor_info               = type_info.GetTensorTypeAndShapeInfo();
+       ONNXTensorElementDataType type = tensor_info.GetElementType();
+       m_output_node_dims               = tensor_info.GetShape();
+       if(m_debug) std::cout << "[" << i << "]"
+               << " output_name: " << m_outputNodeNameAllocatedStrings.back().get()
+               << " ndims: " << m_output_node_dims.size()
+               << " dims: " << dims_str(m_output_node_dims)
+               << std::endl;
+
+   }
+
+  return StatusCode::SUCCESS;
+}
+
+void TrackHeedSimTool::wire_xy(float x1, float y1, float z1, float x2, float y2, float z2, float z, float &x, float &y){
+    //linear function: 
+    //(x-x1)/(x2-x1)=(y-y1)/(y2-y1)=(z-z1)/(z2-z1)
+    x = x1+(x2-x1)*(z-z1)/(z2-z1);
+    y = y1+(y2-y1)*(z-z1)/(z2-z1);
+}
+
+float TrackHeedSimTool::xy2phi(float x, float y){
+    float phi = acos(x/sqrt(x*x+y*y));
+    if(y < 0) phi = 2*M_PI-phi;
+    return phi; 
+}
+void TrackHeedSimTool::getLocal(float x1, float y1, float x2, float y2, float& dx, float& dy){
+    /*  .    
+        .     *(x2,y2)
+        .   .
+        . .    
+        *(x1, y1)        
+        .  
+        .  
+        . 
+        o   
+    */
+    float mo1 = sqrt(x1*x1+y1*y1);
+    float mo2 = sqrt((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1) );
+    float costheta = (x1*(x2-x1)+y1*(y2-y1))/(mo1*mo2);
+    dy = mo2*costheta;
+    dx = xy2phi(x2,y2)>xy2phi(x1,y1) ? mo2*sqrt(1-costheta*costheta) : -mo2*sqrt(1-costheta*costheta) ;
+}
+
+float* TrackHeedSimTool::NNPred(std::vector<float>& inputs)
+{
+
+
+    std::vector<Ort::Value> input_tensors;
+    auto& dims = m_input_node_dims[0];
+    //std::cout << "inputs.size()="<<inputs.size() << std::endl;
+    dims[0] = int(inputs.size()/3);
+    Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+    // prepare a dummy input for the model
+    
+    
+    auto input_tensor = Ort::Value::CreateTensor(info,
+                                                 inputs.data(),
+                                                 inputs.size(),
+                                                 dims.data(),
+                                                 dims.size());
+    
+    input_tensors.push_back(std::move(input_tensor));
+    auto output_tensors = m_session->Run(Ort::RunOptions{ nullptr }, m_input_node_names.data(), input_tensors.data(), input_tensors.size(), m_output_node_names.data(), m_output_node_names.size());
+    //const auto& output_tensor = output_tensors[0];
+    auto& output_tensor = output_tensors[0];
+    int num_elements = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount();
+    //std::cout << "output_tensor num_elements=" << num_elements<< std::endl;
+
+    float* vec2 = new float[num_elements];
+    std::memcpy(vec2, output_tensor.GetTensorMutableData<float>(), num_elements * sizeof(float));
+    /*
+    for (int k=0;k<num_elements;k++){
+       std::cout<<"k="<<k<< ",v=" <<vec2[k]<< std::endl;
+    }
+    */ 
+    return vec2;
+}
+
+
+void TrackHeedSimTool::endOfEvent() {
+    if(m_sim_pulse){
+        edm4hep::SimPrimaryIonizationClusterCollection* SimPrimaryIonizationCol = nullptr;
+        try{
+            SimPrimaryIonizationCol =  const_cast<edm4hep::SimPrimaryIonizationClusterCollection*>(m_SimPrimaryIonizationCol.get());
+        }
+        catch(...){
+            G4cout<<"Error! Can't find collection in event, please check it have been createAndPut() in Begin of event"<<G4endl;
+            G4cout<<"SimPrimaryIonizationCol="<<SimPrimaryIonizationCol<<G4endl;
+            throw "stop here!";
+        }
+        if(m_debug) G4cout<<"SimPrimaryIonizationCol size="<<SimPrimaryIonizationCol->size()<<G4endl;
+        clock_t t01 = clock();
+        std::vector<float> inputs;
+        std::vector<unsigned long> indexs_c;
+        std::vector<unsigned long> indexs_i;
+        std::map<unsigned long long, std::vector<std::pair<float, float> > > id_pulse_map;
+        for (unsigned long z=0; z<SimPrimaryIonizationCol->size(); z++) {
+            for (unsigned long k=0; k<SimPrimaryIonizationCol->at(z).electronCellID_size(); k++) {
+                //auto simIon = SimIonizationCol->at(k);
+                auto position = SimPrimaryIonizationCol->at(z).getElectronPosition(k);//mm
+                auto cellId = SimPrimaryIonizationCol->at(z).getElectronCellID(k);
+                TVector3 Wstart(0,0,0);
+                TVector3 Wend  (0,0,0);
+                m_segmentation->cellposition(cellId, Wstart, Wend);
+                float dd4hep_mm = dd4hep::mm;
+                Wstart =(1/dd4hep_mm)* Wstart;// from DD4HEP cm to mm
+                Wend   =(1/dd4hep_mm)* Wend  ;
+                //std::cout<<"cellid="<<cellId<<",s_x="<<Wstart.x()<<",s_y="<<Wstart.y()<<",s_z="<<Wstart.z()<<",E_x="<<Wend.x()<<",E_y="<<Wend.y()<<",E_z="<<Wend.z()<<std::endl;
+                float wire_x = 0;
+                float wire_y = 0;
+                double pos_z = position[2];
+                wire_xy(Wend.x(), Wend.y(), Wend.z(), Wstart.x(), Wstart.y(), Wstart.z(), pos_z, wire_x, wire_y);
+                float local_x = 0;
+                float local_y = 0;
+                getLocal(wire_x, wire_y, position[0], position[1], local_x, local_y);
+                //std::cout<<"pos_z="<<pos_z<<",wire_x="<<wire_x<<",wire_y="<<wire_y<<",position[0]="<<position[0]<<",position[1]="<<position[1]<<",local_x="<<local_x<<",local_y="<<local_y<<",dr="<<sqrt(local_x*local_x+local_y*local_y)<<",dr1="<<sqrt( (wire_x-position[0])*(wire_x-position[0])+(wire_y-position[1])*(wire_y-position[1]) )<<std::endl;
+                float m_x_scale = 1;
+                float m_y_scale = 1;
+                local_x = local_x/m_x_scale;//FIXME, default is 18mm x 18mm, the real cell size maybe a bit different. need konw size of cell for each layer, and do normalization
+                local_y = local_y/m_y_scale;
+                float noise = CLHEP::RandGauss::shoot(0,1);
+                inputs.push_back(local_x);
+                inputs.push_back(local_y);
+                inputs.push_back(noise  );
+                indexs_c.push_back(z);
+                indexs_i.push_back(k);
+                if(indexs_c.size()==m_batchsize){
+                    float* res = NNPred(inputs);
+                    for(unsigned int i=0; i<m_batchsize; i++){
+                        float tmp_time = res[i*2  ]*m_time_scale + m_time_shift;// in ns
+                        float tmp_amp  = res[i*2+1]*m_amp_scale  + m_amp_shift ;
+                        //unsigned long tmp_index = indexs.at(i);
+                        //tmp_pluse.setCellID(SimIonizationCol->at(tmp_index).getCellID());
+                        //tmp_pluse.setTime(tmp_time + SimIonizationCol->at(tmp_index).getTime());//ns
+                        //tmp_pluse.setValue(tmp_amp);
+                        //tmp_pluse.setType(SimIonizationCol->at(tmp_index).getType());
+                        //tmp_pluse.setSimIonization(SimIonizationCol->at(tmp_index));
+                        auto ion_time = SimPrimaryIonizationCol->at(indexs_c.at(i)).getElectronTime(indexs_i.at(i));
+                        id_pulse_map[indexs_c.at(i)].push_back(std::make_pair(tmp_time+ion_time,tmp_amp) );
+                    }
+                    inputs.clear();
+                    indexs_c.clear();
+                    indexs_i.clear();
+                    delete [] res;
+                }
+            } //end of k
+        }//end of z
+        if(indexs_c.size()!=0){
+            float* res = NNPred(inputs);
+            for(unsigned int i=0; i<indexs_c.size(); i++){
+                float tmp_time = res[i*2  ]*m_time_scale + m_time_shift;
+                float tmp_amp  = res[i*2+1]*m_amp_scale  + m_amp_shift ;
+                //tmp_pluse.setCellID(SimIonizationCol->at(tmp_index).getCellID());
+                //tmp_pluse.setTime(tmp_time + SimIonizationCol->at(tmp_index).getTime());//ns
+                //tmp_pluse.setValue(tmp_amp);
+                //tmp_pluse.setType(SimIonizationCol->at(tmp_index).getType());
+                //tmp_pluse.setSimIonization(SimIonizationCol->at(tmp_index));
+                //id_pulse_map[SimIonizationCol->at(tmp_index).getCellID()].push_back(std::make_pair(tmp_pluse.getTime(), tmp_pluse.getValue() ) );
+                auto ion_time = SimPrimaryIonizationCol->at(indexs_c.at(i)).getElectronTime(indexs_i.at(i));
+                id_pulse_map[indexs_c.at(i)].push_back(std::make_pair(tmp_time+ion_time,tmp_amp) );
+            }
+            inputs.clear();
+            indexs_c.clear();
+            indexs_i.clear();
+            delete [] res;
+        }
+        for(auto iter = id_pulse_map.begin(); iter != id_pulse_map.end(); iter++){
+            edm4hep::MutableSimPrimaryIonizationCluster dcIonCls = SimPrimaryIonizationCol->at(iter->first); 
+            for(unsigned int i=0; i< iter->second.size(); i++){
+                auto tmp_time = iter->second.at(i).first ;
+                auto tmp_amp  = iter->second.at(i).second;
+                dcIonCls.addToPulseTime(tmp_time);  
+                dcIonCls.addToPulseAmplitude(tmp_amp);  
+            }
+            if(dcIonCls.electronPosition_size() != dcIonCls.pulseTime_size()){
+                G4cout<<"Error ion size != pulse size"<<G4endl;
+                throw "stop here!";
+            }
+        }
+        clock_t t02 = clock();
+        if(m_debug) std::cout<<"time for Pulse Simulation=" << (double)(t02 - t01) / CLOCKS_PER_SEC <<" seconds"<< std::endl;
+    }
+}
+
+StatusCode TrackHeedSimTool::finalize()
+{
+    //if(m_debug)std::cout << "m_tot_edep="<<m_tot_edep<<" eV"<<std::endl;
+    //std::cout << "m_tot_length="<<m_tot_length<<" mm"<<std::endl;
+    return StatusCode::SUCCESS;
+}
diff --git a/Simulation/DetSimDedx/src/TrackHeedSimTool.h b/Simulation/DetSimDedx/src/TrackHeedSimTool.h
new file mode 100644
index 00000000..d152fb97
--- /dev/null
+++ b/Simulation/DetSimDedx/src/TrackHeedSimTool.h
@@ -0,0 +1,156 @@
+#ifndef TrackHeedSimTool_h
+#define TrackHeedSimTool_h
+
+#include "k4FWCore/DataHandle.h"
+#include "GaudiKernel/MsgStream.h"
+#include "DetSimInterface/IDedxSimTool.h"
+#include <GaudiKernel/AlgTool.h>
+#include "edm4hep/MCParticle.h"
+#include "edm4hep/MCParticleCollection.h"
+#include "edm4hep/SimPrimaryIonizationClusterCollection.h"
+#include "TVector3.h"
+#include <G4StepPoint.hh>
+
+#include "DD4hep/Segmentations.h"
+#include "DD4hep/Printout.h"
+#include "DD4hep/Detector.h"
+#include "DetInterface/IGeomSvc.h"
+#include "DetSegmentation/GridDriftChamber.h"
+
+
+#include "Garfield/ViewCell.hh"
+#include "Garfield/ViewDrift.hh"
+#include "Garfield/ViewSignal.hh"
+#include "Garfield/ViewMedium.hh"
+#include "Garfield/ComponentAnalyticField.hh"
+#include "Garfield/MediumMagboltz.hh"
+#include "Garfield/Sensor.hh"
+#include "Garfield/DriftLineRKF.hh"
+#include "Garfield/AvalancheMicroscopic.hh"
+#include "Garfield/AvalancheMC.hh"
+#include "Garfield/TrackHeed.hh"
+#include "Garfield/ComponentNeBem3d.hh"
+#include "Garfield/SolidWire.hh"
+#include "Garfield/GeometrySimple.hh"
+#include "Garfield/MediumConductor.hh"
+#include "Garfield/ViewField.hh"
+
+#include <map>
+#include <string>
+
+#include "core/session/onnxruntime_cxx_api.h"
+using namespace Garfield;
+
+class TrackHeedSimTool: public extends<AlgTool, IDedxSimTool> {
+    public:
+        using extends::extends;
+
+        StatusCode initialize() override;
+        StatusCode finalize() override;
+        double dedx(const G4Step* Step) override;
+        double dedx(const edm4hep::MCParticle& mc) override;
+        double dndx(double betagamma) override;
+        void getMom(float ee, float dx, float dy,float dz, float mom[3] );
+        void reset(){ 
+            m_isFirst = true;
+            m_previous_track_ID = 0;
+            m_previous_KE = 0;
+            m_tot_edep = 0;
+            //std::cout<<"m_tot_length="<<m_tot_length<<std::endl;
+            m_tot_length = 0;
+        }
+        void endOfEvent();
+        long long getCellID(float x, float y, float z);
+        void wire_xy(float x1, float y1, float z1, float x2, float y2, float z2, float z, float &x, float &y);
+        float* NNPred(std::vector<float>& inputs);
+        float xy2phi(float x, float y);
+        void getLocal(float x1, float y1, float x2, float y2, float& dx, float& dy);
+    private:
+        //ServiceHandle<IDataProviderSvc> m_eds;
+        SmartIF<IGeomSvc> m_geosvc;
+        dd4hep::Detector* m_dd4hep; 
+        dd4hep::Readout* m_readout;
+        dd4hep::DDSegmentation::GridDriftChamber* m_segmentation;
+        Gaudi::Property<std::string> m_readout_name{ this, "readout", "DriftChamberHitsCollection"};//readout for getting segmentation
+        Gaudi::Property<std::string> m_gas_file{ this, "gas_file", "He_50_isobutane_50.gas"};//gas
+        Gaudi::Property<std::string> m_IonMobility{ this, "IonMobility_file", "IonMobility_He+_He.txt"};
+        Gaudi::Property<float> m_isob  {this, "isob", 50, ""};
+        Gaudi::Property<float> m_he    {this, "he", 50, ""};
+        Gaudi::Property<bool> m_debug{this, "debug", false};
+        Gaudi::Property<bool> m_use_max_step{this, "use_max_step", false};
+        Gaudi::Property<bool> m_update_KE{this, "update_KE", true};
+        Gaudi::Property<float> m_max_step   {this, "max_step", 1};//mm
+        Gaudi::Property<bool> m_only_primary{this, "only_primary", false};
+        Gaudi::Property<bool> m_save_mc{this, "save_mc", false};
+        Gaudi::Property<bool> m_save_cellID{this, "save_cellID", true};
+        Gaudi::Property<float> m_delta_threshold{this, "delta_threshold", 50};//keV
+        Gaudi::Property<float> m_change_threshold {this, "change_threshold", 0.05};
+        Gaudi::Property<float> m_BField   {this, "BField", -3};
+        Gaudi::Property<float> m_eps     { this, "eps"   , 1e-6  };//very small value, it is returned dedx for unsimulated step (may needed for SimTrackerHit)
+        // Output collections
+        DataHandle<edm4hep::SimPrimaryIonizationClusterCollection>    m_SimPrimaryIonizationCol{"SimPrimaryIonizationClusterCollection", Gaudi::DataHandle::Writer, this};
+        // In order to associate MCParticle with contribution, we need to access MC Particle.
+        DataHandle<edm4hep::MCParticleCollection> m_mc_handle{"MCParticle", Gaudi::DataHandle::Writer, this};
+
+        TrackHeed* m_track;
+        ComponentNeBem3d m_nebem;
+        ComponentAnalyticField cmp;
+        GeometrySimple m_geo;
+        MediumConductor m_metal;
+        MediumMagboltz m_gas;
+        Sensor* m_sensor;
+        std::map<int, std::string> m_particle_map;
+        
+        int m_previous_track_ID;
+        float m_previous_KE;
+        int m_current_track_ID;
+        int m_current_Parent_ID;
+        int m_pdg_code;
+        G4StepPoint* m_pre_point;
+        G4StepPoint* m_post_point;
+        G4double m_total_range;
+        bool m_isFirst;
+        bool m_change_track;
+        edm4hep::MCParticle m_mc_paricle; 
+        float m_tot_edep;
+        float m_tot_length;
+        float m_pa_KE;
+  
+        G4double m_pre_x  ;
+        G4double m_pre_y  ;
+        G4double m_pre_z  ;
+        G4double m_pre_dx ;
+        G4double m_pre_dy ;
+        G4double m_pre_dz ;
+        G4double m_pre_t  ;
+  
+        //// sim pulse from NN /// 
+        Gaudi::Property<int> m_intra_op_nthreads{ this, "intraOpNumThreads", 1};
+        Gaudi::Property<int> m_inter_op_nthreads{ this, "interOpNumThreads", 1};
+        std::shared_ptr<Ort::Env> m_env;
+        std::shared_ptr<Ort::SessionOptions> m_seesion_options;
+        std::shared_ptr<Ort::Session> m_session;
+        Ort::AllocatorWithDefaultOptions m_allocator;
+        std::vector<const char*> m_input_node_names;
+        std::vector<std::vector<int64_t>> m_input_node_dims;
+        std::vector<const char*> m_output_node_names;
+        std::vector<int64_t> m_output_node_dims;
+        std::vector<Ort::AllocatedStringPtr> m_inputNodeNameAllocatedStrings;
+        std::vector<Ort::AllocatedStringPtr> m_outputNodeNameAllocatedStrings;
+  
+  
+        Gaudi::Property<bool> m_sim_pulse    { this, "sim_pulse"   , true  };
+        Gaudi::Property<std::string> m_model_file{ this, "model", "model_test.onnx"};
+        Gaudi::Property<int> m_batchsize     { this, "batchsize", 100};
+        Gaudi::Property<float> m_time_scale  { this, "time_scale", 99.0};
+        Gaudi::Property<float> m_time_shift  { this, "time_shift", 166.4};
+        Gaudi::Property<float> m_amp_scale   { this, "amp_scale" , 1e-2 };
+        Gaudi::Property<float> m_amp_shift   { this, "amp_shift" , 0    };
+        Gaudi::Property<float> m_x_scale     { this, "x_scale"   , 5.  };// in mm
+        Gaudi::Property<float> m_y_scale     { this, "y_scale"   , 5.  };// in mm
+  
+  
+
+};
+
+#endif
diff --git a/Simulation/DetSimInterface/include/DetSimInterface/IDedxSimTool.h b/Simulation/DetSimInterface/include/DetSimInterface/IDedxSimTool.h
index 8c804500..d4e8e9e4 100644
--- a/Simulation/DetSimInterface/include/DetSimInterface/IDedxSimTool.h
+++ b/Simulation/DetSimInterface/include/DetSimInterface/IDedxSimTool.h
@@ -28,6 +28,8 @@ public:
     virtual double dedx(const G4Step* aStep) = 0;
     virtual double dedx(const edm4hep::MCParticle& mc) = 0;
     virtual double dndx(double betagamma) = 0;
+    virtual void reset() {}
+    virtual void endOfEvent() {}
 
 };
 
diff --git a/cmake/FindOnnxRuntime.cmake b/cmake/FindOnnxRuntime.cmake
new file mode 100644
index 00000000..5a9db01c
--- /dev/null
+++ b/cmake/FindOnnxRuntime.cmake
@@ -0,0 +1,35 @@
+# Find the ONNX Runtime include directory and library.
+#
+# This module defines the `onnxruntime` imported target that encodes all
+# necessary information in its target properties.
+
+find_library(
+  OnnxRuntime_LIBRARY
+  NAMES onnxruntime
+  PATH_SUFFIXES lib lib32 lib64
+  DOC "The ONNXRuntime library")
+  
+if(NOT OnnxRuntime_LIBRARY)
+  message(FATAL_ERROR "onnxruntime library not found")
+endif()
+
+find_path(
+  OnnxRuntime_INCLUDE_DIR
+  NAMES core/session/onnxruntime_cxx_api.h
+  PATH_SUFFIXES include include/onnxruntime
+  DOC "The ONNXRuntime include directory")
+  
+if(NOT OnnxRuntime_INCLUDE_DIR)
+  message(FATAL_ERROR "onnxruntime includes not found")
+endif()
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(
+  OnnxRuntime
+  REQUIRED_VARS OnnxRuntime_LIBRARY OnnxRuntime_INCLUDE_DIR)
+
+add_library(OnnxRuntime SHARED IMPORTED)
+set_property(TARGET OnnxRuntime PROPERTY IMPORTED_LOCATION ${OnnxRuntime_LIBRARY})
+set_property(TARGET OnnxRuntime PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${OnnxRuntime_INCLUDE_DIR})
+
+mark_as_advanced(OnnxRuntime_FOUND OnnxRuntime_INCLUDE_DIR OnnxRuntime_LIBRARY)
-- 
GitLab