Skip to content
Snippets Groups Projects
Commit f50cd1af authored by fangwx@ihep.ac.cn's avatar fangwx@ihep.ac.cn
Browse files

update TrackHeedSimTool

parent 553b3aae
No related branches found
No related tags found
No related merge requests found
......@@ -19,6 +19,8 @@
<define>
<constant name="tracker_region_rmax" value="1723*mm" />
<constant name="tracker_region_zmax" value="3050*mm" />
<constant name="world_size" value="2226*mm"/>
<constant name="world_x" value="world_size"/>
<constant name="world_y" value="world_size"/>
......
......@@ -2,7 +2,9 @@
find_package(Geant4 REQUIRED ui_all vis_all)
include(${Geant4_USE_FILE})
find_package(Garfield REQUIRED)
message(Garfield::Garfield)
message("libonnxruntime ${OnnxRuntime_LIBRARY}")
message("libonnxruntime include ${OnnxRuntime_INCLUDE_DIR}")
find_package(OnnxRuntime REQUIRED)
message("libonnxruntime ${OnnxRuntime_LIBRARY}")
......@@ -20,7 +22,8 @@ gaudi_add_module(DetSimDedx
EDM4HEP::edm4hep EDM4HEP::edm4hepDict
k4FWCore::k4FWCore
Garfield::Garfield
${OnnxRuntime_LIBRARY}
OnnxRuntime
#${OnnxRuntime_LIBRARY}
#/cvmfs/sft.cern.ch/lcg/views/LCG_103/x86_64-centos7-gcc11-opt/lib/libonnxruntime.so
${CLHEP_LIBRARIES}
......
......@@ -428,9 +428,15 @@ StatusCode TrackHeedSimTool::initialize()
auto num_input_nodes = m_session->GetInputCount();
if(m_debug) std::cout << "num_input_nodes: " << num_input_nodes << std::endl;
for (size_t i = 0; i < num_input_nodes; ++i) {
#if (ORT_API_VERSION >=13)
auto name = m_session->GetInputNameAllocated(i, m_allocator);
m_inputNodeNameAllocatedStrings.push_back(std::move(name));
m_input_node_names.push_back(m_inputNodeNameAllocatedStrings.back().get());
#else
auto name = m_session->GetInputName(i, m_allocator);
m_inputNodeNameAllocatedStrings.push_back(name);
m_input_node_names.push_back(m_inputNodeNameAllocatedStrings.back());
#endif
Ort::TypeInfo type_info = m_session->GetInputTypeInfo(i);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
......@@ -441,7 +447,11 @@ StatusCode TrackHeedSimTool::initialize()
if(m_debug) std::cout<< "[" << i << "]"
#if (ORT_API_VERSION >=13)
<< " input_name: " << m_inputNodeNameAllocatedStrings.back().get()
#else
<< " input_name: " << m_inputNodeNameAllocatedStrings.back()
#endif
<< " ndims: " << dims.size()
<< " dims: " << dims_str(dims)
<< std::endl;
......@@ -449,15 +459,25 @@ StatusCode TrackHeedSimTool::initialize()
// prepare the output
size_t num_output_nodes = m_session->GetOutputCount();
for(std::size_t i = 0; i < num_output_nodes; i++) {
#if (ORT_API_VERSION >=13)
auto output_name = m_session->GetOutputNameAllocated(i, m_allocator);
m_outputNodeNameAllocatedStrings.push_back(std::move(output_name));
m_output_node_names.push_back(m_outputNodeNameAllocatedStrings.back().get());
#else
auto output_name = m_session->GetOutputName(i, m_allocator);
m_outputNodeNameAllocatedStrings.push_back(output_name);
m_output_node_names.push_back(m_outputNodeNameAllocatedStrings.back());
#endif
Ort::TypeInfo type_info = m_session->GetOutputTypeInfo(i);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
ONNXTensorElementDataType type = tensor_info.GetElementType();
m_output_node_dims = tensor_info.GetShape();
if(m_debug) std::cout << "[" << i << "]"
#if (ORT_API_VERSION >=13)
<< " output_name: " << m_outputNodeNameAllocatedStrings.back().get()
#else
<< " output_name: " << m_outputNodeNameAllocatedStrings.back()
#endif
<< " ndims: " << m_output_node_dims.size()
<< " dims: " << dims_str(m_output_node_dims)
<< std::endl;
......
......@@ -39,6 +39,7 @@
#include <string>
#include "core/session/onnxruntime_cxx_api.h"
#include "core/session/onnxruntime_c_api.h"
using namespace Garfield;
class TrackHeedSimTool: public extends<AlgTool, IDedxSimTool> {
......@@ -135,9 +136,13 @@ class TrackHeedSimTool: public extends<AlgTool, IDedxSimTool> {
std::vector<std::vector<int64_t>> m_input_node_dims;
std::vector<const char*> m_output_node_names;
std::vector<int64_t> m_output_node_dims;
#if (ORT_API_VERSION >=13)
std::vector<Ort::AllocatedStringPtr> m_inputNodeNameAllocatedStrings;
std::vector<Ort::AllocatedStringPtr> m_outputNodeNameAllocatedStrings;
#else
std::vector<const char*> m_inputNodeNameAllocatedStrings;
std::vector<const char*> m_outputNodeNameAllocatedStrings;
#endif
Gaudi::Property<bool> m_sim_pulse { this, "sim_pulse" , true };
Gaudi::Property<std::string> m_model_file{ this, "model", "model_test.onnx"};
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment