1- #pragma once
1+ #pragma once
22
33#include < matazure/lambda_tensor.hpp>
44#include < matazure/tensor.hpp>
55
66namespace matazure {
77
8- enum struct data_type {
9- undefined = 0 ,
8+ enum dtype {
9+ none = 0 ,
1010 dt_uint8,
1111 dt_uint16,
1212 dt_uint32,
@@ -24,67 +24,67 @@ template <typename _T>
2424struct get_data_type_traits ;
2525template <>
2626struct get_data_type_traits <std::uint8_t > {
27- const static data_type value = data_type ::dt_uint8;
27+ const static dtype value = dtype ::dt_uint8;
2828};
2929template <>
3030struct get_data_type_traits <std::uint16_t > {
31- const static data_type value = data_type ::dt_uint16;
31+ const static dtype value = dtype ::dt_uint16;
3232};
3333template <>
3434struct get_data_type_traits <std::uint32_t > {
35- const static data_type value = data_type ::dt_uint32;
35+ const static dtype value = dtype ::dt_uint32;
3636};
3737template <>
3838struct get_data_type_traits <std::uint64_t > {
39- const static data_type value = data_type ::dt_uint64;
39+ const static dtype value = dtype ::dt_uint64;
4040};
4141template <>
4242struct get_data_type_traits <std::int8_t > {
43- const static data_type value = data_type ::dt_int8;
43+ const static dtype value = dtype ::dt_int8;
4444};
4545template <>
4646struct get_data_type_traits <std::int16_t > {
47- const static data_type value = data_type ::dt_int16;
47+ const static dtype value = dtype ::dt_int16;
4848};
4949template <>
5050struct get_data_type_traits <std::int32_t > {
51- const static data_type value = data_type ::dt_int32;
51+ const static dtype value = dtype ::dt_int32;
5252};
5353template <>
5454struct get_data_type_traits <std::int64_t > {
55- const static data_type value = data_type ::dt_int64;
55+ const static dtype value = dtype ::dt_int64;
5656};
5757
58- // template <> struct get_data_type_traits<std::float> { const static data_type value =
59- // data_type ::dt_float16; };
58+ // template <> struct get_data_type_traits<std::float> { const static dtype value =
59+ // dtype ::dt_float16; };
6060template <>
6161struct get_data_type_traits <float > {
62- const static data_type value = data_type ::dt_float32;
62+ const static dtype value = dtype ::dt_float32;
6363};
6464template <>
6565struct get_data_type_traits <double > {
66- const static data_type value = data_type ::dt_float64;
66+ const static dtype value = dtype ::dt_float64;
6767};
6868
69- inline int_t get_data_type_size (data_type type ) {
70- switch (type ) {
71- case data_type ::dt_uint8:
69+ inline int_t get_data_type_size (dtype data_type ) {
70+ switch (data_type ) {
71+ case dtype ::dt_uint8:
7272 return 1 ;
73- case data_type ::dt_int8:
73+ case dtype ::dt_int8:
7474 return 1 ;
75- case data_type ::dt_uint16:
75+ case dtype ::dt_uint16:
7676 return 2 ;
77- case data_type ::dt_int16:
77+ case dtype ::dt_int16:
7878 return 2 ;
79- case data_type ::dt_uint32:
79+ case dtype ::dt_uint32:
8080 return 4 ;
81- case data_type ::dt_int32:
81+ case dtype ::dt_int32:
8282 return 4 ;
83- case data_type ::dt_float16:
83+ case dtype ::dt_float16:
8484 return 2 ;
85- case data_type ::dt_float32:
85+ case dtype ::dt_float32:
8686 return 4 ;
87- case data_type ::dt_float64:
87+ case dtype ::dt_float64:
8888 return 8 ;
8989 default :
9090 MATAZURE_ASSERT (false , " unreachable" );
@@ -99,74 +99,72 @@ class dynamic_tensor {
9999
100100 dynamic_tensor () {}
101101
102- dynamic_tensor (data_type type, shape_type ts_shape)
103- : type_(type),
104- ts_shape_ (ts_shape),
105- size_(reduce(ts_shape_, 1 , [](int_t x0, int_t x1) { return x0 * x1; })) {
106- auto p_mem_ = new byte[size_ * element_size ()];
107- sp_mem_.reset (p_mem_, [](byte* p) { delete[] p; });
108- }
109-
110- dynamic_tensor (data_type type, shape_type ts_shape, shared_ptr<void > sp_mem)
111- : type_(type),
112- ts_shape_ (ts_shape),
113- size_(reduce(ts_shape_, 1 , [](int_t x0, int_t x1) { return x0 * x1; })),
114- sp_mem_ (std::static_pointer_cast<byte>(sp_mem)) {}
102+ dynamic_tensor (dtype data_type, shape_type shape, shape_type stride, shared_ptr<void > sp_mem)
103+ : data_type_(data_type),
104+ shape_ (shape),
105+ stride_(stride),
106+ size_(reduce(shape_, 1 , [](int_t x0, int_t x1) { return x0 * x1; })),
107+ sp_mem_ (std::static_pointer_cast<void >(sp_mem)) {}
115108
116109 dynamic_tensor (const dynamic_tensor& other)
117- : type_(other.type_),
118- ts_shape_(other.ts_shape_),
110+ : data_type_(other.data_type_),
111+ shape_(other.shape_),
112+ stride_(other.stride_),
119113 size_(other.size_),
120114 sp_mem_(other.sp_mem_) {}
121115
122116 dynamic_tensor& operator =(const dynamic_tensor& other) {
123- type_ = other.type_ ;
124- ts_shape_ = other.ts_shape_ ;
117+ data_type_ = other.data_type_ ;
118+ shape_ = other.shape_ ;
119+ stride_ = other.stride_ ;
125120 size_ = other.size_ ;
126121 sp_mem_ = other.sp_mem_ ;
127122 return *this ;
128123 }
129124
130- data_type type () const { return type_; }
125+ dtype dtype () const { return data_type_; }
126+
127+ shape_type shape () const { return shape_; }
131128
132- shape_type shape () const { return ts_shape_ ; }
129+ shape_type stride () const { return stride_ ; }
133130
134- int_t shape (int_t i) const { return ts_shape_ [i]; }
131+ int_t shape (int_t i) const { return shape_ [i]; }
135132
136- int_t rank () const { return ts_shape_ .size (); }
133+ int_t rank () const { return shape_ .size (); }
137134
138135 int_t size () const { return size_; }
139136
140- template <typename _Type = byte >
137+ template <typename _Type = void >
141138 shared_ptr<_Type> shared_data () {
142139 auto sp_mem = sp_mem_;
143140 shared_ptr<_Type> sp_tmp (data<_Type>(), [sp_mem](_Type* p) {});
144141 return sp_tmp;
145142 }
146143
147- template <typename _Type = byte >
144+ template <typename _Type = void >
148145 shared_ptr<const _Type> shared_data () const {
149146 auto sp_mem = sp_mem_;
150147 shared_ptr<const _Type> sp_tmp (data<_Type>(), [sp_mem](const _Type* p) {});
151148 return sp_tmp;
152149 }
153150
154- template <typename _Type = byte >
151+ template <typename _Type = void >
155152 _Type* data () {
156153 return reinterpret_cast <_Type*>(sp_mem_.get ());
157154 }
158155
159- template <typename _Type = byte >
156+ template <typename _Type = void >
160157 const _Type* data () const {
161158 return reinterpret_cast <_Type*>(sp_mem_.get ());
162159 }
163160
164- int_t element_size () const { return get_data_type_size (type_ ); }
161+ int_t element_size () const { return get_data_type_size (data_type_ ); }
165162
166163 private:
167- data_type type_;
168- shared_ptr<byte> sp_mem_ = nullptr ;
169- shape_type ts_shape_;
164+ enum dtype data_type_;
165+ shared_ptr<void > sp_mem_ = nullptr ;
166+ shape_type shape_;
167+ shape_type stride_;
170168 int_t size_;
171169};
172170
@@ -175,8 +173,10 @@ dynamic_tensor dynamic_tensor_wrap(_Tensor ts) {
175173 auto rank = _Tensor::rank;
176174 dynamic_tensor::shape_type shape (rank);
177175 copy (ts.shape (), shape);
178- shared_ptr<byte> sp_tmp (reinterpret_cast <byte*>(ts.data ()), [ts](byte* p) {});
179- return dynamic_tensor (get_data_type_traits<typename _Tensor::value_type>::value, shape, sp_tmp);
176+ dynamic_tensor::shape_type stride (rank);
177+ copy (ts.layout ().stride (), stride);
178+ shared_ptr<void > sp_tmp (reinterpret_cast <void *>(ts.data ()), [ts](void * p) {});
179+ return dynamic_tensor (get_data_type_traits<typename _Tensor::value_type>::value, shape, stride, sp_tmp);
180180}
181181
182182} // namespace matazure
0 commit comments