-
Notifications
You must be signed in to change notification settings - Fork 0
/
operation.cpp
102 lines (78 loc) · 2.08 KB
/
operation.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#include <math.h>
#include "operation.h"
// Constructs a Variable whose stored value is initialised to `_value`.
// `result` is not touched here; it is filled in by forward().
Variable::Variable(double _value)
    : value(_value) {
}
void Variable::forward() {
this->result = this->value;
}
// Backward pass: intentionally empty. A Variable does not forward its gradient
// anywhere; presumably it is a leaf of the graph whose accumulated `grad` is
// read by the caller (NOTE(review): confirm against operation.h / callers).
void Variable::backward() {
    // no inputs to propagate to
}
// Replaces the stored value. `result` is not updated here, so the new value
// only becomes visible after the next call to forward().
void Variable::set(double value) { this->value = value; }
void Add::forward() {
this->result = (this->ins[0]->result) + (this->ins[1]->result);
}
void Add::backward() {
this->ins[0]->grad += this->grad;
this->ins[1]->grad += this->grad;
}
void Mul::forward() {
this->result = (this->ins[0]->result) * (this->ins[1]->result);
}
void Mul::backward() {
this->ins[0]->grad += (this->grad) * (this->ins[1]->result);
this->ins[1]->grad += (this->grad) * (this->ins[0]->result);
}
void Div::forward() {
this->result = (this->ins[0]->result) / (this->ins[1]->result);
}
void Div::backward() {
double &x = this->ins[0]->result;
double &y = this->ins[1]->result;
this->ins[0]->grad += (this->grad) / y;
this->ins[1]->grad += (this->grad) * - x / (y * y);
}
void Log::forward() {
this->result = log(this->ins[0]->result);
}
void Log::backward() {
this->ins[0]->grad += (this->grad) / (this->ins[0]->result);
}
void Exp::forward() {
this->result = exp(this->ins[0]->result);
}
void Exp::backward() {
this->ins[0]->grad += (this->grad) * (this->result);
}
void Sin::forward() {
this->result = sin(this->ins[0]->result);
}
void Sin::backward() {
this->ins[0]->grad += (this->grad) * cos(this->ins[0]->result);
}
void Cos::forward() {
this->result = cos(this->ins[0]->result);
}
void Cos::backward() {
this->ins[0]->grad += (this->grad) * -sin(this->ins[0]->result);
}
void Asin::forward() {
this->result = asin(this->ins[0]->result);
}
void Asin::backward() {
double &x = this->ins[0]->result;
this->ins[0]->grad += (this->grad) / sqrt(1 - x * x);
}
void Acos::forward() {
this->result = acos(this->ins[0]->result);
}
void Acos::backward() {
double &x = this->ins[0]->result;
this->ins[0]->grad += (this->grad) / -sqrt(1 - x * x);
}
void Atan::forward() {
this->result = atan(this->ins[0]->result);
}
void Atan::backward() {
double &x = this->ins[0]->result;
this->ins[0]->grad += (this->grad) / (1 + x * x);
}