Skip to content

Commit 37940ff

Browse files
committed
update module 3
1 parent 5d89fe7 commit 37940ff

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

project/parallel_check.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@
2929
# MM
3030
print("MATRIX MULTIPLY")
3131
out, a, b = (
32-
minitorch.zeros((10, 10)),
33-
minitorch.zeros((10, 20)),
34-
minitorch.zeros((20, 10)),
32+
minitorch.zeros((1, 10, 10)),
33+
minitorch.zeros((1, 10, 20)),
34+
minitorch.zeros((1, 20, 10)),
3535
)
3636
tmm = minitorch.fast_ops.tensor_matrix_multiply
3737

tests/test_tensor_general.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,16 @@ def test_sum_practice5():
157157
out = b2.sum(0)
158158
assert_close(s, out[0])
159159

160+
@pytest.mark.task3_3
161+
def test_sum_practice_other_dims():
162+
x = [[random.random() for i in range(32)] for j in range(16)]
163+
b = minitorch.tensor(x)
164+
s = b.sum(1)
165+
b2 = minitorch.tensor(x, backend=shared["cuda"])
166+
out = b2.sum(1)
167+
for i in range(16):
168+
assert_close(s[i, 0], out[i, 0])
169+
160170
@pytest.mark.task3_4
161171
def test_mul_practice1():
162172
x = [[random.random() for i in range(2)] for j in range(2)]

0 commit comments

Comments
 (0)