EyeAI
TfLiteRuntime.hpp
1#pragma once
2
4#include "TfLiteUtils.hpp"
5#if EYE_AI_CORE_USE_PREBUILT_TFLITE
6#include "tflite/c/c_api.h" // IWYU pragma: export
7#include "tflite/delegates/gpu/delegate.h"
8#else
9#include "tensorflow/lite/c/c_api.h" // IWYU pragma: export
10#include "tensorflow/lite/delegates/gpu/delegate.h"
11#endif
12#include <memory>
13#include <span>
14#include <string>
15#include <string_view>
16
17using TfLiteLogWarningCallback = void (*)(std::string);
18using TfLiteLogErrorCallback = void (*)(std::string);
19
20/// Callbacks invoked by the tflite runtime, passed around as void* user_data.
21struct TfLiteReporterUserData {
22 TfLiteLogWarningCallback log_warning_callback;
23 TfLiteLogErrorCallback log_error_callback;
24
25 TfLiteReporterUserData(
26 TfLiteLogWarningCallback log_warning_callback,
27 TfLiteLogErrorCallback log_error_callback
28 )
29 : log_warning_callback(log_warning_callback),
30 log_error_callback(log_error_callback) {}
31};
32
33class ProfilingFrame;
34
36class TfLiteRuntime {
37 std::vector<int8_t> model_data;
38 FloatTensorFormat model_input_format;
39 FloatTensorFormat model_output_format;
40 std::unique_ptr<TfLiteModel, decltype(&TfLiteModelDelete)> model{
41 nullptr, TfLiteModelDelete
42 };
43 std::unique_ptr<TfLiteInterpreter, decltype(&TfLiteInterpreterDelete)>
44 interpreter{nullptr, TfLiteInterpreterDelete};
45 std::unique_ptr<
46 TfLiteInterpreterOptions,
47 decltype(&TfLiteInterpreterOptionsDelete)>
48 interpreter_options{nullptr, TfLiteInterpreterOptionsDelete};
50 std::unique_ptr<TfLiteDelegate, void(*)(TfLiteDelegate*)>
51 gpu_delegate{nullptr, TfLiteGpuDelegateV2Delete};
53 std::unique_ptr<TfLiteDelegate, void(*)(TfLiteDelegate*)>
54 npu_delegate{nullptr, TfLiteGpuDelegateV2Delete};
55
56 TfLiteReporterUserData reporter_user_data;
57
58 ProfilingFrame& profiling_frame;
59
60 public:
61 using CreateResult =
62 tl::expected<std::unique_ptr<TfLiteRuntime>, TfLiteCreateRuntimeError>;
63
64 /// Create a TfLiteRuntime instance.
65 [[nodiscard]] static CreateResult create(
66 std::vector<int8_t>&& model_data,
67 std::string_view delegate_serialization_dir,
68 std::string_view model_token,
69 FloatTensorFormat model_input_format,
70 FloatTensorFormat model_output_format,
71 TfLiteLogWarningCallback log_warning_callback,
72 TfLiteLogErrorCallback log_error_callback,
73 ProfilingFrame& profiling_frame,
74 NpuConfiguration npu_config,
75 bool enable_npu,
76 std::string skel_library_dir
77 );
78
80 ~TfLiteRuntime();
89 [[nodiscard]] std::optional<TfLiteRunInferenceError>
90 run_inference(std::span<float> input, std::span<float> output);
91
92 template<FloatTensorFormat InputFormat, FloatTensorFormat OutputFormat>
93 [[nodiscard]] tl::
94 expected<FloatTensorBuffer<OutputFormat>, TfLiteRunInferenceError>
95 run_inference(FloatTensorBuffer<InputFormat>& input) {
96
97 if (model_input_format != InputFormat) {
98 return tl::unexpected(
100 .provided = InputFormat, .expected = model_input_format
101 }
102 );
103 }
104 if (model_output_format != OutputFormat) {
105 return tl::unexpected(
107 .provided = OutputFormat, .expected = model_output_format
108 }
109 );
110 }
111
112 size_t output_element_count = 1;
113 for (const int dim : get_output_shape()) {
114 output_element_count *= dim;
115 }
116 FloatTensorBuffer<OutputFormat> output{
117 std::vector<float>(output_element_count)
118 };
119
120 if (const auto error = run_inference(input.data(), output.data())) {
121 return tl::unexpected(*error);
122 }
123
124 return output;
125 }
126
127 [[nodiscard]] std::span<const int> get_input_shape() const;
128
129 [[nodiscard]] std::span<const int> get_output_shape() const;
130
131 TfLiteRuntime(TfLiteRuntime&&) = delete;
132 TfLiteRuntime(const TfLiteRuntime&) = delete;
133 void operator=(TfLiteRuntime&&) = delete;
134 void operator=(const TfLiteRuntime&) = delete;
135
136 private:
137 explicit TfLiteRuntime(
138 std::vector<int8_t>&& model_data,
139 FloatTensorFormat model_input_format,
140 FloatTensorFormat model_output_format,
141 TfLiteReporterUserData error_reporter_user_data,
142 ProfilingFrame& profiling_frame
143 )
144 : model_data(std::move(model_data)),
145 model_input_format(model_input_format),
146 model_output_format(model_output_format),
147 reporter_user_data(error_reporter_user_data),
148 profiling_frame(profiling_frame) {}
149
150 [[nodiscard]] std::optional<TfLiteInvokeInterpreterError> invoke();
151
152 [[nodiscard]] std::optional<TfLiteLoadInputError>
153 load_input(std::span<const float> input);
154
155 [[nodiscard]] std::optional<TfLiteReadOutputError>
156 read_output(std::span<float> output);
157};
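Typical use of this class is to build an instance through the static create() factory and then drive it with run_inference(). The following is a minimal usage sketch rather than code from the EyeAI sources: the helper name run_model, the FloatTensorFormat::NHWC enumerator, the delegate cache path, the model token, and the way the NpuConfiguration value is obtained are all placeholders, and the include path may differ in the actual tree.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

#include "TfLiteRuntime.hpp"

bool run_model(std::vector<int8_t> model_bytes,
               ProfilingFrame& frame,
               NpuConfiguration npu_config) {
    TfLiteRuntime::CreateResult result = TfLiteRuntime::create(
        std::move(model_bytes),
        "/tmp/delegate_cache",   // delegate_serialization_dir (placeholder)
        "my_model",              // model_token (placeholder)
        FloatTensorFormat::NHWC, // assumed enumerator, see TensorBuffer.hpp
        FloatTensorFormat::NHWC,
        [](std::string msg) { std::fprintf(stderr, "tflite warning: %s\n", msg.c_str()); },
        [](std::string msg) { std::fprintf(stderr, "tflite error: %s\n", msg.c_str()); },
        frame,
        npu_config,
        /*enable_npu=*/false,
        /*skel_library_dir=*/"");
    if (!result) {
        return false; // result.error() holds a TfLiteCreateRuntimeError
    }
    TfLiteRuntime& runtime = **result;

    // Size the I/O buffers from the shapes reported by the model.
    std::size_t in_count = 1;
    for (const int dim : runtime.get_input_shape()) in_count *= dim;
    std::size_t out_count = 1;
    for (const int dim : runtime.get_output_shape()) out_count *= dim;

    std::vector<float> input(in_count, 0.0f);
    std::vector<float> output(out_count);
    if (const auto error = runtime.run_inference(input, output)) {
        return false; // *error describes the inference failure
    }
    // output now holds the model's result in model_output_format.
    return true;
}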
FloatTensorFormat
Definition TensorBuffer.hpp:73
TensorBuffer< float, FloatTensorFormat, FORMAT > FloatTensorBuffer
Definition TensorBuffer.hpp:106
void(*)(std::string) TfLiteLogWarningCallback
Definition TfLiteRuntime.hpp:17
void(*)(std::string) TfLiteLogErrorCallback
Definition TfLiteRuntime.hpp:18
NpuConfiguration
Definition TfLiteUtils.hpp:48
collection of profile records from different threads (lock-free, thread-safe)
Definition Profiling.hpp:52
TfLiteRuntime(const TfLiteRuntime &)=delete
std::span< const int > get_input_shape() const
Definition TfLiteRuntime.cpp:171
tl::expected< std::unique_ptr< TfLiteRuntime >, TfLiteCreateRuntimeError > CreateResult
Definition TfLiteRuntime.hpp:61
~TfLiteRuntime()
Definition TfLiteRuntime.cpp:136
void operator=(const TfLiteRuntime &)=delete
void operator=(TfLiteRuntime &&)=delete
TfLiteRuntime(TfLiteRuntime &&)=delete
tl::expected< FloatTensorBuffer< OutputFormat >, TfLiteRunInferenceError > run_inference(FloatTensorBuffer< InputFormat > &input)
Definition TfLiteRuntime.hpp:95
std::span< const int > get_output_shape() const
Definition TfLiteRuntime.cpp:178
std::optional< TfLiteRunInferenceError > run_inference(std::span< float > input, std::span< float > output)
Run inference on the model; make sure input and output have the right number of elements.
Definition TfLiteRuntime.cpp:156
static CreateResult create(std::vector< int8_t > &&model_data, std::string_view delegate_serialization_dir, std::string_view model_token, FloatTensorFormat model_input_format, FloatTensorFormat model_output_format, TfLiteLogWarningCallback log_warning_callback, TfLiteLogErrorCallback log_error_callback, ProfilingFrame &profiling_frame, NpuConfiguration npu_config, bool enable_npu, std::string skel_library_dir)
Create a TfLiteRuntime instance.
Definition TfLiteRuntime.cpp:18
Definition TfLiteUtils.hpp:231
Definition TfLiteUtils.hpp:239
std::span< T > data()
Definition TensorBuffer.hpp:34
Callbacks invoked by the tflite runtime, passed around as void* user_data.
Definition TfLiteRuntime.hpp:21
TfLiteLogWarningCallback log_warning_callback
Definition TfLiteRuntime.hpp:22
TfLiteLogErrorCallback log_error_callback
Definition TfLiteRuntime.hpp:23
TfLiteReporterUserData(TfLiteLogWarningCallback log_warning_callback, TfLiteLogErrorCallback log_error_callback)
Definition TfLiteRuntime.hpp:25
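The TfLiteReporterUserData entry above notes that the warning and error callbacks are handed to the TensorFlow Lite runtime as an opaque void* user_data pointer. A common way to bridge that gap is a small trampoline like the sketch below. This is only an illustration of the pattern: the reporter in TfLiteRuntime.cpp may be wired differently, and the name forward_to_error_callback is invented here. TfLiteInterpreterOptionsSetErrorReporter() is the TensorFlow Lite C API hook such a trampoline would be registered with.

#include <cstdarg>
#include <cstdio>
#include <string>

#include "TfLiteRuntime.hpp"

// Signature expected by TfLiteInterpreterOptionsSetErrorReporter(): a
// printf-style callback that receives the opaque user_data pointer back.
static void forward_to_error_callback(void* user_data, const char* format, va_list args) {
    char buffer[512];
    std::vsnprintf(buffer, sizeof(buffer), format, args);
    auto* reporter = static_cast<TfLiteReporterUserData*>(user_data);
    reporter->log_error_callback(std::string{buffer});
}

// Registration (options points to a live TfLiteInterpreterOptions and
// reporter_user_data must outlive the interpreter):
//   TfLiteInterpreterOptionsSetErrorReporter(
//       options, forward_to_error_callback, &reporter_user_data);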