Skip to content

Commit 6e72dd6

Browse files
committed
Add L3
1 parent a4d054c commit 6e72dd6

7 files changed

Lines changed: 263 additions & 2 deletions

File tree

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55

66
|| 内容 | 资料 | 代码 |
77
|---|---|---|---|
8-
| 1 | 使用Python编写神经网络 | [课件](notes/L1-使用python编写神经网络.pdf) | [factorial-relu](code/l1-factorial-relu.py), [factorial-sigmoid](code/l1-factorial-sigmoid.py) |
8+
| 1 | 使用Python编写神经网络 | [课件](notes/L1-使用Python编写神经网络.pdf) | [factorial-relu](code/l1-factorial-relu.py), [factorial-sigmoid](code/l1-factorial-sigmoid.py) |
99
| 2 | Python程序的解析和运行 | [课件](notes/L2-Python程序的解析和运行.pdf) | [regex](code/l2-regex.py), [pytool](code/l2-pytool.py), [relu](code/l2-relu.py) |
10-
| 3 | Python与Native Code交互 | | |
10+
| 3 | Python与Native Code交互 | [课件](notes/L3-Python与native交互.pdf) | [bench_dot(py)](code/l3-1-bench_dot.py), [bench_dot(c)](code/l3-1-bench_dot.c), [bench_numba](code/l3-2-bench_numba.py), [decorator(py)](code/l3-3-deco.py), [decorator(c)](code/l3-3-vec.c) |
1111
| 4 | Python多线程 | | |
1212
| 5 | 课堂反转 | |
1313
| 6 | 使用PyTorch编写和训练神经网络模型 | | |

code/l3-1-bench_dot.c

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#include <stddef.h>
2+
3+
double dot_c(const double *a, const double *b, size_t n) {
4+
double s = 0.0;
5+
for (size_t i = 0; i < n; i++) {
6+
s += a[i] * b[i];
7+
}
8+
return s;
9+
}

code/l3-1-bench_dot.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import ctypes
2+
import random
3+
import time
4+
import numpy as np
5+
from numba import njit
6+
7+
# -------------------------------
8+
# Configuration
9+
# -------------------------------
10+
N = 5_000_000 # vector length
11+
SEED = 42
12+
13+
random.seed(SEED)
14+
15+
# -------------------------------
16+
# Data preparation
17+
# -------------------------------
18+
a_list = [random.random() for _ in range(N)]
19+
b_list = [random.random() for _ in range(N)]
20+
21+
a_np = np.array(a_list, dtype=np.float64)
22+
b_np = np.array(b_list, dtype=np.float64)
23+
24+
25+
# -------------------------------
26+
# Implementations
27+
# -------------------------------
28+
def dot_python(a, b):
29+
s = 0.0
30+
for i in range(len(a)):
31+
s += a[i] * b[i]
32+
return s
33+
34+
35+
def dot_numpy(a, b):
36+
return np.dot(a, b)
37+
38+
# -------------------------------
39+
# Benchmark helper
40+
# -------------------------------
41+
def benchmark(func, *args):
42+
times = []
43+
start = time.perf_counter()
44+
func(*args)
45+
end = time.perf_counter()
46+
return end - start
47+
48+
49+
lib = ctypes.cdll.LoadLibrary("./libdot.so")
50+
51+
lib.dot_c.restype = ctypes.c_double
52+
lib.dot_c.argtypes = [
53+
ctypes.POINTER(ctypes.c_double),
54+
ctypes.POINTER(ctypes.c_double),
55+
ctypes.c_int,
56+
]
57+
58+
def dot_c(a, b):
59+
return lib.dot_c(
60+
a.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
61+
b.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
62+
a.size
63+
)
64+
65+
# -------------------------------
66+
# Run benchmarks
67+
# -------------------------------
68+
t_py = benchmark(dot_python, a_list, b_list)
69+
t_np = benchmark(dot_numpy, a_np, b_np)
70+
t_c = benchmark(dot_c, a_np, b_np)
71+
72+
# -------------------------------
73+
# Results
74+
# -------------------------------
75+
print(f"Python loop : {t_py:.4f} s")
76+
print(f"NumPy dot : {t_np:.4f} s")
77+
print(f"C dot : {t_np:.4f} s")

code/l3-2-bench_numba.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import random
2+
import time
3+
import numpy as np
4+
from numba import njit
5+
6+
N = 5_000_000
7+
8+
def dot_python(a, b):
9+
s = 0.0
10+
for i in range(len(a)):
11+
s += a[i] * b[i]
12+
return s
13+
14+
@njit
15+
def dot_numba(a, b):
16+
s = 0.0
17+
for i in range(len(a)):
18+
s += a[i] * b[i]
19+
return s
20+
21+
# 生成数据
22+
a_list = [random.random() for _ in range(N)]
23+
b_list = [random.random() for _ in range(N)]
24+
25+
a_np = np.array(a_list)
26+
b_np = np.array(b_list)
27+
28+
# Python benchmark
29+
t0 = time.perf_counter()
30+
dot_python(a_np, b_np)
31+
t1 = time.perf_counter()
32+
33+
# Numba 第一次调用(包含 JIT 编译)
34+
t2 = time.perf_counter()
35+
dot_numba(a_np, b_np)
36+
t3 = time.perf_counter()
37+
38+
# Numba 第二次调用(已编译)
39+
t4 = time.perf_counter()
40+
dot_numba(a_np, b_np)
41+
t5 = time.perf_counter()
42+
43+
print("Python time:", t1 - t0)
44+
print("Numba first call (compile + run):", t3 - t2)
45+
print("Numba second call (run only):", t5 - t4)

code/l3-3-deco.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import ast
2+
import inspect
3+
import ctypes
4+
import numpy as np
5+
6+
# ============================================================
7+
# Native primitive kernels (precompiled shared library)
8+
# ============================================================
9+
10+
lib = ctypes.CDLL("./libvec.so")
11+
12+
lib.vec_elem_mul.argtypes = lib.vec_elem_add.argtypes = [
13+
ctypes.POINTER(ctypes.c_double),
14+
ctypes.POINTER(ctypes.c_double),
15+
ctypes.POINTER(ctypes.c_double),
16+
ctypes.c_int,
17+
]
18+
lib.vec_elem_add.restype = lib.vec_elem_mul.restype = None
19+
20+
21+
# ============================================================
22+
# @kernel decorator: Python AST -> lowering -> native execution
23+
# ============================================================
24+
25+
def kernel(func):
26+
"""
27+
Treat the function body as an embedded DSL.
28+
The body is parsed, not executed.
29+
"""
30+
# -------- Parse Python AST --------
31+
src = inspect.getsource(func)
32+
tree = ast.parse(src)
33+
func_def = tree.body[0]
34+
35+
# Expect:
36+
# return <expr>
37+
return_stmt = func_def.body[0]
38+
39+
expr = return_stmt.value
40+
arg_names = [arg.arg for arg in func_def.args.args]
41+
42+
# -------- Lowering: AST -> primitive ops --------
43+
def lower(node, env, n):
44+
if isinstance(node, ast.Name):
45+
return env[node.id]
46+
47+
if isinstance(node, ast.BinOp):
48+
left = lower(node.left, env, n)
49+
right = lower(node.right, env, n)
50+
out = np.empty_like(left)
51+
52+
if isinstance(node.op, ast.Mult):
53+
lib.vec_elem_mul(
54+
left.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
55+
right.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
56+
out.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
57+
n,
58+
)
59+
elif isinstance(node.op, ast.Add):
60+
lib.vec_elem_add(
61+
left.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
62+
right.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
63+
out.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
64+
n,
65+
)
66+
else:
67+
raise NotImplementedError("Unsupported operator")
68+
69+
return out
70+
71+
raise NotImplementedError("Unsupported AST node")
72+
73+
# -------- Runtime wrapper --------
74+
def wrapper(*args):
75+
arrays = [np.asarray(a, dtype=np.float64) for a in args]
76+
n = arrays[0].size
77+
78+
env = dict(zip(arg_names, arrays))
79+
result = lower(expr, env, n)
80+
return result
81+
82+
return wrapper
83+
84+
85+
# ============================================================
86+
# User-defined kernels (pure Python, no strings)
87+
# ============================================================
88+
89+
@kernel
90+
def vec_elem_mul(a, b):
91+
92+
# element-wise multiplication
93+
return a * b
94+
95+
@kernel
96+
def vec_elem_add(a, b):
97+
# element-wise addition
98+
return a + b
99+
100+
@kernel
101+
def vec_elem_fma(a, b, c):
102+
# element-wise fused multiply-add: a * b + c
103+
return (a * b) + c
104+
105+
# ============================================================
106+
# Test
107+
# ============================================================
108+
109+
if __name__ == "__main__":
110+
a = [1.0, 2.0, 3.0]
111+
b = [4.0, 5.0, 6.0]
112+
c = [10.0, 10.0, 10.0]
113+
114+
print("mul:", vec_elem_mul(a, b))
115+
print("add:", vec_elem_add(a, b))
116+
print("fma:", vec_elem_fma(a, b, c))
117+

code/l3-3-vec.c

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#include <stdio.h>
2+
3+
void vec_elem_mul(double* a, double* b, double* c, int n) {
4+
for (int i = 0; i < n; i++) {
5+
c[i] = a[i] * b[i];
6+
}
7+
}
8+
9+
void vec_elem_add(double* a, double* b, double* c, int n) {
10+
for (int i=0; i<n; i++) {
11+
c[i] = a[i] + b[i];
12+
}
13+
}

notes/L3-Python与native交互.pdf

426 KB
Binary file not shown.

0 commit comments

Comments
 (0)