Skip to content

Commit 4289284

Browse files
author
Zhang Zhimin
committed
Fixes layout stride as the numpy strides
1 parent feb1209 commit 4289284

File tree

10 files changed

+170
-129
lines changed

10 files changed

+170
-129
lines changed

.clang-format

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# We'll use defaults from the Google style, but with 4 columns indentation.
33
BasedOnStyle: Google
44
IndentWidth: 4
5-
ColumnLimit: 100
5+
ColumnLimit: 120
66
---
77
Language: Cpp
88
# Force pointers to the type for C++.
@@ -12,7 +12,6 @@ IndentWidth: 4
1212
---
1313
Language: JavaScript
1414
# Use 100 columns for JS.
15-
ColumnLimit: 100
1615
---
1716
Language: Proto
1817
# Don't format .proto files.

include/matazure/allocator.hpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22

3-
#include <malloc.h>
3+
#include <stdlib.h>
44
#include <matazure/config.hpp>
55

66
namespace matazure {
@@ -13,23 +13,24 @@ class aligned_allocator : public std::allocator<_Type> {
1313
aligned_allocator& operator=(const aligned_allocator& rhs) { return *this; }
1414

1515
_Type* allocate(size_t size) {
16-
#ifdef __GNUC__
17-
_Type* data = reinterpret_cast<_Type*>(memalign(_Alignment, size * sizeof(_Type)));
16+
_Type* data = nullptr;
17+
#ifdef __WIN32
18+
data = reinterpret_cast<_Type*>(_aligned_malloc(size * sizeof(_Type), _Alignment));
1819
#else
19-
_Type* data = reinterpret_cast<_Type*>(_aligned_malloc(size * sizeof(_Type), _Alignment));
20+
posix_memalign(reinterpret_cast<void **>(&data), _Alignment, size * sizeof(_Type));
2021
#endif
2122
return data;
2223
}
2324

2425
void deallocate(_Type* p, size_t size) {
25-
#ifdef __GNUC__
26-
free(p);
27-
#else
26+
#ifdef __WIN32
2827
_aligned_free(p);
28+
#else
29+
free(p);
2930
#endif
3031
}
3132

3233
~aligned_allocator(){};
3334
};
3435

35-
} // namespace matazure
36+
} // namespace matazure

include/matazure/config.hpp

Lines changed: 60 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -120,19 +120,16 @@ struct blank_t {};
120120
} // namespace matazure
121121

122122
// for assert
123-
#define MATAZURE_STATIC_ASSERT_DIM_MATCHED(T1, T2) \
124-
static_assert(T1::rank == T2::rank, "the rank is not matched")
123+
#define MATAZURE_STATIC_ASSERT_DIM_MATCHED(T1, T2) static_assert(T1::rank == T2::rank, "the rank is not matched")
125124

126125
#define MATAZURE_STATIC_ASSERT_VALUE_TYPE_MATCHED(T1, T2) \
127126
static_assert(std::is_same<typename T1::value_type, typename T2::value_type>::value, \
128127
"the value type is not matched")
129128

130-
#define MATAZURE_STATIC_ASSERT_MEMORY_TYPE_MATCHED(T1, T2) \
131-
static_assert(std::is_same<runtime_t<T1>, runtime_t<T2>>::value, \
132-
"the memory type is not matched")
129+
#define MATAZURE_STATIC_ASSERT_MEMORY_TYPE_MATCHED(T1, T2) \
130+
static_assert(std::is_same<runtime_t<T1>, runtime_t<T2>>::value, "the memory type is not matched")
133131

134-
#define MATAZURE_STATIC_ASSERT_MATRIX_RANK(T) \
135-
static_assert(T::rank == 2, "the matrix rank should be 2")
132+
#define MATAZURE_STATIC_ASSERT_MATRIX_RANK(T) static_assert(T::rank == 2, "the matrix rank should be 2")
136133

137134
#define MATAZURE_CURRENT_FUNCTION "(unknown)"
138135

@@ -149,29 +146,68 @@ struct blank_t {};
149146
#define MATAZURE_UNLIKELY(x) x
150147
#endif
151148

152-
#if defined(MATAZURE_DISABLE_ASSERTS)
153-
154-
#define MATAZURE_ASSERT(expr, msg) ((void)0)
155-
156-
#else
157-
158149
namespace matazure {
159150

160-
class assert_failed : public std::runtime_error {
151+
class assert_failed : public std::exception {
152+
public:
153+
assert_failed(const string& expr, const string& function, const string& file, size_t line,
154+
const std::string& msg = " ")
155+
: _expr(expr), _function(function), _file(file), _line(line), _msg(msg) {
156+
_what_str = _expr + ", " + _function + ", " + _file + ", " + std::to_string(_line) + ", " + _msg;
157+
}
158+
159+
virtual const char* what() const noexcept override { return _what_str.c_str(); }
160+
161+
private:
162+
string _expr;
163+
string _function;
164+
string _file;
165+
size_t _line;
166+
string _msg;
167+
string _what_str;
168+
};
169+
170+
class verify_failed : public std::exception {
161171
public:
162-
assert_failed(const std::string& msg) : std::runtime_error(msg) {}
172+
verify_failed(const string& expr, const string& function, const string& file, size_t line,
173+
const std::string& msg = " ")
174+
: _expr(expr), _function(function), _file(file), _line(line), _msg(msg) {
175+
_what_str = _expr + ", " + _function + ", " + _file + ", " + std::to_string(_line) + ", " + _msg;
176+
}
177+
178+
virtual const char* what() const noexcept override { return _what_str.c_str(); }
179+
180+
private:
181+
string _expr;
182+
string _function;
183+
string _file;
184+
size_t _line;
185+
string _msg;
186+
string _what_str;
163187
};
164188

165-
inline void assertion_failed(char const* expr, char const* msg, char const* function,
166-
char const* file, long line) {
167-
throw assert_failed(std::string(msg));
168-
}
189+
inline void raise_assert_failed(const string& expr, const string& function, const string& file, long line,
190+
const string& msg = " ") {
191+
throw assert_failed(expr, function, file, line, msg);
192+
}
169193

170-
} // namespace matazure
194+
inline void raise_verify_failed(const string& expr, const string& function, const string& file, long line,
195+
const string& msg = " ") {
196+
throw verify_failed(expr, function, file, line, msg);
197+
}
171198

172-
#define MATAZURE_ASSERT(expr, msg) \
173-
(MATAZURE_LIKELY(!!(expr)) ? ((void)0) \
174-
: ::matazure::assertion_failed( \
175-
#expr, msg, MATAZURE_CURRENT_FUNCTION, __FILE__, __LINE__))
199+
} // namespace matazure
176200

201+
#if defined(MATAZURE_DISABLE_ASSERTS)
202+
#define MATAZURE_ASSERT(expr, msg) ((void)0)
203+
#else
204+
#define MATAZURE_ASSERT(expr, ...) \
205+
(MATAZURE_LIKELY(!!(expr)) \
206+
? ((void)0) \
207+
: ::matazure::raise_assert_failed(#expr, MATAZURE_CURRENT_FUNCTION, __FILE__, __LINE__, ##__VA_ARGS__))
177208
#endif
209+
210+
#define MATAZURE_VERIFY(expr, ...) \
211+
(MATAZURE_LIKELY(!!(expr)) \
212+
? ((void)0) \
213+
: ::matazure::raise_verify_failed(#expr, MATAZURE_CURRENT_FUNCTION, __FILE__, __LINE__, ##__VA_ARGS__))
Lines changed: 57 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
#pragma once
1+
#pragma once
22

33
#include <matazure/lambda_tensor.hpp>
44
#include <matazure/tensor.hpp>
55

66
namespace matazure {
77

8-
enum struct data_type {
9-
undefined = 0,
8+
enum dtype {
9+
none = 0,
1010
dt_uint8,
1111
dt_uint16,
1212
dt_uint32,
@@ -24,67 +24,67 @@ template <typename _T>
2424
struct get_data_type_traits;
2525
template <>
2626
struct get_data_type_traits<std::uint8_t> {
27-
const static data_type value = data_type::dt_uint8;
27+
const static dtype value = dtype::dt_uint8;
2828
};
2929
template <>
3030
struct get_data_type_traits<std::uint16_t> {
31-
const static data_type value = data_type::dt_uint16;
31+
const static dtype value = dtype::dt_uint16;
3232
};
3333
template <>
3434
struct get_data_type_traits<std::uint32_t> {
35-
const static data_type value = data_type::dt_uint32;
35+
const static dtype value = dtype::dt_uint32;
3636
};
3737
template <>
3838
struct get_data_type_traits<std::uint64_t> {
39-
const static data_type value = data_type::dt_uint64;
39+
const static dtype value = dtype::dt_uint64;
4040
};
4141
template <>
4242
struct get_data_type_traits<std::int8_t> {
43-
const static data_type value = data_type::dt_int8;
43+
const static dtype value = dtype::dt_int8;
4444
};
4545
template <>
4646
struct get_data_type_traits<std::int16_t> {
47-
const static data_type value = data_type::dt_int16;
47+
const static dtype value = dtype::dt_int16;
4848
};
4949
template <>
5050
struct get_data_type_traits<std::int32_t> {
51-
const static data_type value = data_type::dt_int32;
51+
const static dtype value = dtype::dt_int32;
5252
};
5353
template <>
5454
struct get_data_type_traits<std::int64_t> {
55-
const static data_type value = data_type::dt_int64;
55+
const static dtype value = dtype::dt_int64;
5656
};
5757

58-
// template <> struct get_data_type_traits<std::float> { const static data_type value =
59-
// data_type::dt_float16; };
58+
// template <> struct get_data_type_traits<std::float> { const static dtype value =
59+
// dtype::dt_float16; };
6060
template <>
6161
struct get_data_type_traits<float> {
62-
const static data_type value = data_type::dt_float32;
62+
const static dtype value = dtype::dt_float32;
6363
};
6464
template <>
6565
struct get_data_type_traits<double> {
66-
const static data_type value = data_type::dt_float64;
66+
const static dtype value = dtype::dt_float64;
6767
};
6868

69-
inline int_t get_data_type_size(data_type type) {
70-
switch (type) {
71-
case data_type::dt_uint8:
69+
inline int_t get_data_type_size(dtype data_type) {
70+
switch (data_type) {
71+
case dtype::dt_uint8:
7272
return 1;
73-
case data_type::dt_int8:
73+
case dtype::dt_int8:
7474
return 1;
75-
case data_type::dt_uint16:
75+
case dtype::dt_uint16:
7676
return 2;
77-
case data_type::dt_int16:
77+
case dtype::dt_int16:
7878
return 2;
79-
case data_type::dt_uint32:
79+
case dtype::dt_uint32:
8080
return 4;
81-
case data_type::dt_int32:
81+
case dtype::dt_int32:
8282
return 4;
83-
case data_type::dt_float16:
83+
case dtype::dt_float16:
8484
return 2;
85-
case data_type::dt_float32:
85+
case dtype::dt_float32:
8686
return 4;
87-
case data_type::dt_float64:
87+
case dtype::dt_float64:
8888
return 8;
8989
default:
9090
MATAZURE_ASSERT(false, "unreachable");
@@ -99,74 +99,72 @@ class dynamic_tensor {
9999

100100
dynamic_tensor() {}
101101

102-
dynamic_tensor(data_type type, shape_type ts_shape)
103-
: type_(type),
104-
ts_shape_(ts_shape),
105-
size_(reduce(ts_shape_, 1, [](int_t x0, int_t x1) { return x0 * x1; })) {
106-
auto p_mem_ = new byte[size_ * element_size()];
107-
sp_mem_.reset(p_mem_, [](byte* p) { delete[] p; });
108-
}
109-
110-
dynamic_tensor(data_type type, shape_type ts_shape, shared_ptr<void> sp_mem)
111-
: type_(type),
112-
ts_shape_(ts_shape),
113-
size_(reduce(ts_shape_, 1, [](int_t x0, int_t x1) { return x0 * x1; })),
114-
sp_mem_(std::static_pointer_cast<byte>(sp_mem)) {}
102+
dynamic_tensor(dtype data_type, shape_type shape, shape_type stride, shared_ptr<void> sp_mem)
103+
: data_type_(data_type),
104+
shape_(shape),
105+
stride_(stride),
106+
size_(reduce(shape_, 1, [](int_t x0, int_t x1) { return x0 * x1; })),
107+
sp_mem_(std::static_pointer_cast<void>(sp_mem)) {}
115108

116109
dynamic_tensor(const dynamic_tensor& other)
117-
: type_(other.type_),
118-
ts_shape_(other.ts_shape_),
110+
: data_type_(other.data_type_),
111+
shape_(other.shape_),
112+
stride_(other.stride_),
119113
size_(other.size_),
120114
sp_mem_(other.sp_mem_) {}
121115

122116
dynamic_tensor& operator=(const dynamic_tensor& other) {
123-
type_ = other.type_;
124-
ts_shape_ = other.ts_shape_;
117+
data_type_ = other.data_type_;
118+
shape_ = other.shape_;
119+
stride_ = other.stride_;
125120
size_ = other.size_;
126121
sp_mem_ = other.sp_mem_;
127122
return *this;
128123
}
129124

130-
data_type type() const { return type_; }
125+
dtype dtype() const { return data_type_; }
126+
127+
shape_type shape() const { return shape_; }
131128

132-
shape_type shape() const { return ts_shape_; }
129+
shape_type stride() const { return stride_; }
133130

134-
int_t shape(int_t i) const { return ts_shape_[i]; }
131+
int_t shape(int_t i) const { return shape_[i]; }
135132

136-
int_t rank() const { return ts_shape_.size(); }
133+
int_t rank() const { return shape_.size(); }
137134

138135
int_t size() const { return size_; }
139136

140-
template <typename _Type = byte>
137+
template <typename _Type = void>
141138
shared_ptr<_Type> shared_data() {
142139
auto sp_mem = sp_mem_;
143140
shared_ptr<_Type> sp_tmp(data<_Type>(), [sp_mem](_Type* p) {});
144141
return sp_tmp;
145142
}
146143

147-
template <typename _Type = byte>
144+
template <typename _Type = void>
148145
shared_ptr<const _Type> shared_data() const {
149146
auto sp_mem = sp_mem_;
150147
shared_ptr<const _Type> sp_tmp(data<_Type>(), [sp_mem](const _Type* p) {});
151148
return sp_tmp;
152149
}
153150

154-
template <typename _Type = byte>
151+
template <typename _Type = void>
155152
_Type* data() {
156153
return reinterpret_cast<_Type*>(sp_mem_.get());
157154
}
158155

159-
template <typename _Type = byte>
156+
template <typename _Type = void>
160157
const _Type* data() const {
161158
return reinterpret_cast<_Type*>(sp_mem_.get());
162159
}
163160

164-
int_t element_size() const { return get_data_type_size(type_); }
161+
int_t element_size() const { return get_data_type_size(data_type_); }
165162

166163
private:
167-
data_type type_;
168-
shared_ptr<byte> sp_mem_ = nullptr;
169-
shape_type ts_shape_;
164+
enum dtype data_type_;
165+
shared_ptr<void> sp_mem_ = nullptr;
166+
shape_type shape_;
167+
shape_type stride_;
170168
int_t size_;
171169
};
172170

@@ -175,8 +173,10 @@ dynamic_tensor dynamic_tensor_wrap(_Tensor ts) {
175173
auto rank = _Tensor::rank;
176174
dynamic_tensor::shape_type shape(rank);
177175
copy(ts.shape(), shape);
178-
shared_ptr<byte> sp_tmp(reinterpret_cast<byte*>(ts.data()), [ts](byte* p) {});
179-
return dynamic_tensor(get_data_type_traits<typename _Tensor::value_type>::value, shape, sp_tmp);
176+
dynamic_tensor::shape_type stride(rank);
177+
copy(ts.layout().stride(), stride);
178+
shared_ptr<void> sp_tmp(reinterpret_cast<void*>(ts.data()), [ts](void* p) {});
179+
return dynamic_tensor(get_data_type_traits<typename _Tensor::value_type>::value, shape, stride, sp_tmp);
180180
}
181181

182182
} // namespace matazure

0 commit comments

Comments
 (0)