Update.
[mygptrnn.git] / memload.py
1 #!/usr/bin/env python
2
3 import torch
4
5 from setuptools import setup
6 from torch.utils.cpp_extension import BuildExtension, CppExtension
7
8 cpp_source = """
9 std::vector<torch::Tensor> greedy_lines_allocation(torch::Tensor load_start, float decay, torch::Tensor line_requests) {
10   auto nb_lines = load_start.size(1);
11   auto batch_size = line_requests.size(0);
12   auto nb_heads = line_requests.size(1);
13   auto T = line_requests.size(2);
14
15   auto load_start_a = load_start.accessor<float,2>();
16   auto line_requests_a = line_requests.accessor<float,3>();
17
18   auto load = torch::empty({batch_size, nb_lines, T});
19   auto load_a = load.accessor<float,3>();
20
21   auto allocation_result = torch::empty({batch_size,nb_heads,T},torch::TensorOptions().dtype(torch::kInt64));
22   auto allocation_result_a = allocation_result.accessor<long,3>();
23
24   for(int n = 0; n < batch_size; n++) {
25     for(int t = 0; t < T; t++) {
26       for(int l = 0; l < nb_lines; l++) {
27         if(t == 0) {
28           load[n][l][t] = decay * load_start_a[n][l];
29         } else {
30           load[n][l][t] = decay * load[n][l][t-1];
31         }
32       }
33       for(int h = 0; h < nb_heads; h++) {
34         if(line_requests_a[n][h][t] > 0) {
35           int l_lowest_load;
36           for(int l = 0; l < nb_lines; l++) {
37             if(l == 0 || load_a[n][l][t]<load_a[n][l_lowest_load][t]) l_lowest_load=l;
38           }
39           if(load_a[n][l_lowest_load][t] < line_requests_a[n][h][t]) {
40             allocation_result_a[n][h][t] = l_lowest_load;
41             load_a[n][l_lowest_load][t] = line_requests_a[n][h][t];
42           } else {
43             allocation_result_a[n][h][t] = -1;
44           }
45         } else {
46           allocation_result_a[n][h][t] = -1;
47         }
48       }
49     }
50   }
51
52   return {allocation_result,load};
53 }
54 """
55
56 ######################################################################
57
58 allocator_module = torch.utils.cpp_extension.load_inline(
59     name="allocator_module",
60     cpp_sources=[cpp_source],
61     functions=["greedy_lines_allocation"],
62     build_directory="/tmp/",
63     # verbose=True,
64 )
65
66 lines_allocation = allocator_module.greedy_lines_allocation
67
68 ######################################################################
69
70 if __name__ == "__main__":
71     N, H, L, T = 1, 1, 3, 20
72
73     load_start = torch.rand(N, L)
74     requests = (2 * torch.rand(N, H, T) - 1).clamp(min=0)
75
76     print("load_start", load_start)
77
78     print("requests", requests)
79
80     alloc, load = lines_allocation(load_start, 0.99, requests)
81
82     print("alloc", alloc)
83
84     print("load", load)