Expression

$$
h(\mathbf{x}) = \max(\mathbf{W}_1 \mathbf{x} + \mathbf{b}_1, \mathbf{W}_2 \mathbf{x} + \mathbf{b}_2, \ldots, \mathbf{W}_k \mathbf{x} + \mathbf{b}_k)
$$

1. Enhanced Expressive Power

A Maxout unit takes the element-wise maximum of k affine functions, so its output is piecewise linear and convex in the input; with enough pieces it can approximate any convex function arbitrarily well. This gives neural networks considerable flexibility and expressive power: a Maxout unit can learn anything from a simple linear response to a complex nonlinear pattern. This is particularly useful when the decision boundaries are complex or the data distribution is highly variable.
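As a rough illustration of the expression above, the following pure-Python sketch evaluates a single scalar-output Maxout unit as the maximum of k affine pieces. The function name `maxout` and the example weights are assumptions chosen for illustration and do not come from any library. With two pieces of slope +1 and -1 the unit reproduces the convex function |x|.

```python
def maxout(x, weights, biases):
    """Return max_i (w_i . x + b_i) for a single Maxout unit.

    x       : list of input features
    weights : list of k weight vectors, one per affine piece
    biases  : list of k scalars, one per affine piece
    """
    pieces = [
        sum(w_j * x_j for w_j, x_j in zip(w, x)) + b
        for w, b in zip(weights, biases)
    ]
    return max(pieces)


# Two pieces with slopes +1 and -1 give the absolute value, a convex function.
abs_weights, abs_biases = [[1.0], [-1.0]], [0.0, 0.0]
print(maxout([-3.0], abs_weights, abs_biases))  # 3.0
print(maxout([2.5], abs_weights, abs_biases))   # 2.5
```

A full Maxout layer would apply the same idea to every output coordinate, which is what the matrix form of the expression describes.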

2. Comparison with ReLU

Compared to ReLU (Rectified Linear Unit), Maxout is more general. ReLU is a simple yet highly effective activation function defined as $f(x)=\max(0,x)$. Its main advantages are computational simplicity and mitigation of the vanishing gradient problem. However, ReLU is one-sided: it outputs zero for every negative input. Maxout, in contrast, learns the shape of its activation and can respond to both positive and negative inputs. ReLU itself is the special case of a two-piece Maxout in which one piece is fixed at zero.

3. Trade-offs in Practical Applications

While Maxout offers greater expressive power in theory, it also carries a heavier parameter burden (k sets of weights and biases per unit instead of one), which raises computational cost and increases the risk of overfitting. In practice, the choice of activation function is therefore a trade-off among expressive power, computational efficiency, and ease of use.
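The parameter overhead can be made concrete with a little arithmetic. The sketch below compares an ordinary fully connected layer with a Maxout layer of the same input and output size; the function names and the layer sizes (256 inputs, 128 outputs, k = 4) are arbitrary assumptions for illustration.

```python
def dense_param_count(n_in, n_out):
    """Weights plus biases of one ordinary fully connected layer."""
    return n_in * n_out + n_out


def maxout_param_count(n_in, n_out, k):
    """A Maxout layer keeps k affine maps, so it costs k times the dense count."""
    return k * dense_param_count(n_in, n_out)


print(dense_param_count(256, 128))      # 32896
print(maxout_param_count(256, 128, 4))  # 131584
```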


Idea

This post collects several problems I ran into while training multi-class classification models with PyTorch. It serves as a reminder of the concepts involved and of how to handle issues that might come up again in the future.

Code

Create the data with preprocessing

During preprocessing, note that y_blob is converted to a LongTensor. When nn.CrossEntropyLoss is given class indices as targets, the target tensor (the labels) must be of type torch.long, because each entry is interpreted as an integer class index rather than as a floating-point value.

import torch
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split

device = "cuda" if torch.cuda.is_available() else "cpu"

NUM_CLASSES = 4
NUM_FEATURES = 2
RANDOM_SEED = 42

# create multiclass data
X_blob, y_blob = make_blobs(n_samples=1000,
                            n_features=NUM_FEATURES,
                            centers=NUM_CLASSES,
                            cluster_std=1.5,
                            random_state=RANDOM_SEED)

# transform from numpy arrays to tensors
X_blob = torch.from_numpy(X_blob).type(torch.float)
y_blob = torch.from_numpy(y_blob).type(torch.LongTensor)  # must be long type because loss functions do not accept float indices

# split the data
X_blob_train, X_blob_test, y_blob_train, y_blob_test = train_test_split(X_blob,
                                                                        y_blob,
                                                                        test_size=0.2,
                                                                        random_state=RANDOM_SEED)

# plot the data
plt.figure(figsize=(10, 7))
plt.scatter(X_blob[:, 0], X_blob[:, 1], c=y_blob, cmap=plt.cm.RdYlBu)

Build the model

The constructor can take several explicit parameters (input_features, output_features, hidden_units), but they are only used to build the layers. During training, the forward function takes a single argument: the input batch x, whose last dimension must match input_features.

from torch import nn

class BlobModel(nn.Module):
    def __init__(self, input_features, output_features, hidden_units=8):
        super().__init__()
        # three linear layers in sequence (no non-linear activation in between)
        self.linear_layer_stack = nn.Sequential(
            nn.Linear(in_features=input_features, out_features=hidden_units),
            nn.Linear(in_features=hidden_units, out_features=hidden_units),
            nn.Linear(in_features=hidden_units, out_features=output_features),
        )

    def forward(self, x):
        # returns raw logits, one column per class
        return self.linear_layer_stack(x)

model_4 = BlobModel(input_features=NUM_FEATURES,
                    output_features=NUM_CLASSES,
                    hidden_units=8).to(device)

Define loss function and optimizer

# CrossEntropyLoss is the standard choice for multi-class classification
loss_fn = nn.CrossEntropyLoss()

# the most common optimizers are SGD and Adam
optimizer = torch.optim.SGD(params=model_4.parameters(), lr=0.01)

Train the model

Note that nn.CrossEntropyLoss() expects raw logits as input: it applies log-softmax internally, so it should not be given values that have already passed through softmax. We still compute y_pred with softmax and argmax because we need the predicted class labels to calculate the accuracy.
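The relationship between logits, the integer class index, and the loss can be sketched in plain Python. This is a conceptual illustration for a single sample, not PyTorch's implementation; the function name `cross_entropy_from_logits` and the example values are assumptions.

```python
import math


def cross_entropy_from_logits(logits, target_index):
    """Cross-entropy for one sample: -log(softmax(logits)[target_index])."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    # the target is used as a list index, which is why it has to be an integer
    return log_sum_exp - logits[target_index]


logits = [2.0, 0.5, -1.0, 0.1]
print(cross_entropy_from_logits(logits, 0))  # small loss: class 0 has the largest logit
print(cross_entropy_from_logits(logits, 2))  # larger loss: class 2 has the smallest logit
```

Because the softmax is folded into the loss, passing probabilities instead of logits would apply the normalization twice and distort the result.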

Also note an important detail: dim=1 means the operation runs across the columns within each row. For a logits tensor of shape (samples, classes), softmax(dim=1) normalizes the class scores of each row, and argmax(dim=1) returns the index of the largest column in each row. In other words, dim=1 stands for “keep the row fixed and compute the result over the columns of that row”.
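The row-wise behavior can be mimicked with nested lists, where each inner list is one sample. The helpers `softmax_rows` and `argmax_rows` below are hypothetical stand-ins written for this illustration; they only imitate what dim=1 does on a 2-D tensor.

```python
import math


def softmax_rows(matrix):
    """Apply softmax across the columns of each row (the dim=1 direction)."""
    result = []
    for row in matrix:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        result.append([e / total for e in exps])
    return result


def argmax_rows(matrix):
    """Index of the largest column in each row (the dim=1 direction)."""
    return [max(range(len(row)), key=row.__getitem__) for row in matrix]


logits = [[2.0, 0.5, -1.0],  # sample 0: class 0 scores highest
          [0.1, 0.2, 3.0]]   # sample 1: class 2 scores highest
probs = softmax_rows(logits)
print(argmax_rows(probs))   # [0, 2]
print(argmax_rows(logits))  # [0, 2]
```

Softmax is monotonic within a row, so taking the argmax of the logits gives the same labels as taking the argmax of the probabilities.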

# accuracy_fn is not defined in the original post; this minimal version is an assumption
def accuracy_fn(y_true, y_pred):
    correct = torch.eq(y_true, y_pred).sum().item()
    return correct / len(y_pred) * 100

torch.manual_seed(42)
torch.cuda.manual_seed(42)

X_blob_train, X_blob_test = X_blob_train.to(device), X_blob_test.to(device)
y_blob_train, y_blob_test = y_blob_train.to(device), y_blob_test.to(device)

epochs = 1000

for epoch in range(epochs):
    model_4.train()

    y_logits = model_4(X_blob_train)
    y_pred = torch.softmax(y_logits, dim=1).argmax(dim=1)  # logits -> probabilities -> labels

    loss = loss_fn(y_logits, y_blob_train)  # the loss takes the raw logits
    acc = accuracy_fn(y_true=y_blob_train, y_pred=y_pred)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # test
    model_4.eval()
    with torch.inference_mode():
        test_logits = model_4(X_blob_test)
        test_pred = torch.softmax(test_logits, dim=1).argmax(dim=1)  # same conversion as above

        test_loss = loss_fn(test_logits, y_blob_test)
        test_acc = accuracy_fn(y_true=y_blob_test, y_pred=test_pred)

    if epoch % 100 == 0:
        print(f"Epoch: {epoch} | Loss: {loss:.4f}, Acc: {acc:.2f}% | Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")

Evaluate the model

model_4.eval()
with torch.inference_mode():
    y_logits = model_4(X_blob_test)

# remember to manually activate the logits by applying softmax and argmax
y_pred_probs = torch.softmax(y_logits, dim=1)
y_preds = torch.argmax(y_pred_probs, dim=1)

# plot_decision_boundary is not defined in this post; it is assumed to be a plotting helper available elsewhere
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title("Train")
plot_decision_boundary(model_4, X_blob_train, y_blob_train)
plt.subplot(1, 2, 2)
plt.title("Test")
plot_decision_boundary(model_4, X_blob_test, y_blob_test)

Problems

LC39: Combination Sum I

In this problem there is no limit on how many times a single element can be used, so we can use backtracking and start each recursive call from the index we are currently at (rather than from the next one).

Note that we should not recurse over all the elements of the candidates array at every level, because we do not want to output the same combination in different orders, e.g. [2,3,3], [3,2,3], [3,3,2]. Therefore the index never moves backward.

Note that the candidates array contains only distinct elements, so duplicates in the array cannot cause the same combination to be counted more than once.

Also note that the candidates array is not sorted, so we cannot stop the loop as soon as the sum exceeds the target at some index; we have to keep traversing the whole candidates array. This suggests sorting the array first.

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

class Solution {
    public List<List<Integer>> combinationSum(int[] candidates, int target) {
        List<List<Integer>> ans = new ArrayList<>();
        dfs(ans, new LinkedList<>(), candidates, target, 0);
        return ans;
    }

    // temp is declared as LinkedList so that removeLast() is available on Java versions before 21
    public void dfs(List<List<Integer>> ans, LinkedList<Integer> temp,
                    int[] candidates, int remaining, int start) {
        // base case
        if (remaining == 0) {
            ans.add(new LinkedList<>(temp)); // store a copy of the current combination
            return;
        } else if (remaining < 0) {
            return;
        }

        for (int i = start; i < candidates.length; i++) {
            temp.add(candidates[i]);
            dfs(ans, temp, candidates, remaining - candidates[i], i); // i, not i + 1: reuse is allowed
            temp.removeLast();
        }
    }
}

// time: there is no tight closed form; a loose upper bound is O(n^(T/M + 1)), where n = candidates.length,
//       T = target and M = the smallest candidate (depth at most T/M, branching factor at most n)
// in terms of the search tree itself it is O(S), where S is the number of tree nodes visited

// space: O(T/M) for the recursion depth, which is O(target) when the smallest candidate is 1

If we sort the array beforehand, we can prune the search tree, which reduces the runtime. Note that the worst-case time complexity does not change, but the algorithm does less work in practice.

public void dfs(List<List<Integer>> ans, LinkedList<Integer> temp,
                int[] candidates, int remaining, int start) {
    if (remaining == 0) {
        ans.add(new ArrayList<>(temp));
        return;
    }

    for (int i = start; i < candidates.length; ++i) {
        // candidates is sorted, so every later candidate would overshoot as well
        if (remaining - candidates[i] < 0) break;
        temp.add(candidates[i]);
        dfs(ans, temp, candidates, remaining - candidates[i], i);
        temp.removeLast();
    }
}

// sort the candidates array with Arrays.sort(candidates) in the main function before the first dfs call
// inside the loop we check whether candidates[i] exceeds the remaining sum; if it does, we break and return to the previous level of the recursion
// this saves some overhead in both time and space but does not save the algorithm from the worst case
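To see the effect of sorting and pruning without changing the set of results, the following Python sketch counts how many recursive calls each variant makes. The function name `count_calls` and the `prune` flag are assumptions introduced for this illustration; the logic otherwise mirrors the two dfs versions described in this section.

```python
def count_calls(candidates, target, prune):
    """Enumerate all combinations and count how many dfs calls were needed."""
    if prune:
        candidates = sorted(candidates)  # pruning is only valid on sorted input
    calls = 0
    found, path = [], []

    def dfs(remaining, start):
        nonlocal calls
        calls += 1
        if remaining == 0:
            found.append(path[:])  # store a copy of the current combination
            return
        if remaining < 0:
            return
        for i in range(start, len(candidates)):
            if prune and candidates[i] > remaining:
                break  # every later candidate is at least as large
            path.append(candidates[i])
            dfs(remaining - candidates[i], i)  # i, not i + 1: reuse is allowed
            path.pop()

    dfs(target, 0)
    return calls, found


plain_calls, plain_found = count_calls([7, 3, 6, 2], 7, prune=False)
pruned_calls, pruned_found = count_calls([7, 3, 6, 2], 7, prune=True)
print(plain_calls, pruned_calls)  # the pruned variant makes fewer calls
print(pruned_found)               # [[2, 2, 3], [7]]
```

Both variants find the same combinations; the pruned one simply never descends into branches that already exceed the target.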

LC40: Combination Sum II

From the previous problem, we found that pruning is only possible when the candidates array is sorted.
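The post does not include a solution for LC40, so the following is only a sketch under the assumption of the usual problem statement: each candidate may be used at most once, and the candidates array may contain duplicate values. Sorting then serves two purposes: it enables the pruning noted above, and it places equal values next to each other so that repeated combinations can be skipped. The function name `combination_sum2` is chosen for this illustration.

```python
def combination_sum2(candidates, target):
    candidates = sorted(candidates)
    ans, path = [], []

    def dfs(remaining, start):
        if remaining == 0:
            ans.append(path[:])  # store a copy of the current combination
            return
        for i in range(start, len(candidates)):
            if candidates[i] > remaining:
                break  # sorted input: later candidates are too large as well
            if i > start and candidates[i] == candidates[i - 1]:
                continue  # the same value was already tried at this depth
            path.append(candidates[i])
            dfs(remaining - candidates[i], i + 1)  # i + 1: each element is used once
            path.pop()

    dfs(target, 0)
    return ans


print(combination_sum2([10, 1, 2, 7, 6, 1, 5], 8))
# [[1, 1, 6], [1, 2, 5], [1, 7], [2, 6]]
```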

LC216: Combination Sum III

Find all valid combinations of k numbers that sum up to n such that the following conditions are true:

  • Only numbers 1 through 9 are used.
  • Each number is used at most once.

Return a list of all possible valid combinations. The list must not contain the same combination twice.

Solution

We can use depth-first search with backtracking: we go as deep as the numbers can still sum up to the target and backtrack through all the possible combinations. We loop from 1 to 9, each time adding a single number to the sum and doing the check, and we recursively try every combination.

  • Base case: if sum == target and the number of elements equals k, add a NEW copy of the temp list to the answer list
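The search described above can be sketched in Python as follows. This is one possible arrangement rather than the post's own code: the names `combination_sum3`, `path`, and `remaining` are assumptions, and the early `break` is an optional pruning step that relies on the numbers 1 through 9 being tried in increasing order.

```python
def combination_sum3(k, n):
    ans, path = [], []

    def dfs(remaining, start):
        if len(path) == k:
            if remaining == 0:
                ans.append(list(path))  # append a NEW list, not the live path
            return
        for num in range(start, 10):
            if num > remaining:
                break  # larger numbers cannot fit either
            path.append(num)
            dfs(remaining - num, num + 1)  # num + 1: each number is used at most once
            path.pop()

    dfs(n, 1)
    return ans


print(combination_sum3(3, 7))  # [[1, 2, 4]]
print(combination_sum3(3, 9))  # [[1, 2, 6], [1, 3, 5], [2, 3, 4]]
```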

Takeaways

  • When we add the temp list to the answer, we should be aware that Java stores a reference here, not a copy.
// adds a reference to temp itself to the ans list
// later add/remove calls on temp at other recursion levels will also change this stored entry
ans.add(temp);

// adds a reference to a new list copied from temp, which later changes to temp do not affect
ans.add(new ArrayList<>(temp));
  • If we define the answer list as a private field outside the main function, there can be safety problems. For example, if two threads share the same Solution object, both calls write to the same ans list; making the field private does not make it thread-safe.
class Solution {
    private List<List<Integer>> ans = new ArrayList<>();

    public List<List<Integer>> combinationSum3(int k, int n) {
        // ... (rest of the class not shown in the original)

Solution solution = new Solution();

// Thread A (k1, n1, k2 and n2 stand for arbitrary arguments)
new Thread(() -> solution.combinationSum3(k1, n1)).start();

// Thread B
new Thread(() -> solution.combinationSum3(k2, n2)).start();
  • LinkedList, ArrayList and ArrayDeque behave differently:
    LinkedList provides removeLast() directly, and adding or removing at either the head or the tail is O(1).
    ArrayList has no removeLast() method before Java 21, but remove(size() - 1) does the same job in O(1); adding or removing at the head is O(n).
    ArrayDeque is a resizable, array-based double-ended queue: adding or removing at either end is amortized O(1), but it offers no index-based access and searching it is O(n).

Previously

This problem is a more complex version of LC102: Binary Tree Level Order Traversal. For that version, a single Queue is enough to implement a FIFO, level-by-level traversal:

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;

// TreeNode is the node class provided by LeetCode
class Solution {
    public List<List<Integer>> levelOrder(TreeNode root) {
        if (root == null) {
            return new ArrayList<>();
        }

        List<List<Integer>> ans = new ArrayList<>();
        Queue<TreeNode> queue = new LinkedList<>();
        queue.offer(root);

        while (!queue.isEmpty()) {
            List<Integer> lst = new ArrayList<>();
            int size = queue.size(); // store how many nodes are in the current level
            while (size > 0) {
                TreeNode curr = queue.poll();
                lst.add(curr.val);
                if (curr.left != null) {
                    queue.add(curr.left);
                }
                if (curr.right != null) {
                    queue.add(curr.right);
                }
                size--;
            }
            ans.add(lst);
        }
        return ans;
    }
}

// time: O(n)
// space: O(n) (the queue itself needs O(k), where k is the largest number of nodes in one level, n in the worst case)

Solution

For this problem, which requires a zigzag traversal, there are two ways to solve it: either collect the nodes in reversed order when the level is odd (level 0 is the first level), or collect the nodes normally and then reverse the list. It turns out that both can be done in the same time complexity. What’s more, the first method can itself be implemented in two different ways.

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;

class Solution {
    public List<List<Integer>> zigzagLevelOrder(TreeNode root) {
        if (root == null) {
            return new ArrayList<>();
        }

        List<List<Integer>> ans = new ArrayList<>();
        Queue<TreeNode> queue = new LinkedList<>();
        queue.offer(root);
        boolean oddLevel = false; // 0 is first level

        while (!queue.isEmpty()) {
            // declared as LinkedList so that addFirst() is available on Java versions before 21
            LinkedList<Integer> lst = new LinkedList<>();
            int size = queue.size();
            while (size > 0) {
                TreeNode curr = queue.poll();
                if (oddLevel) {
                    lst.addFirst(curr.val); // builds the level in reversed order
                } else {
                    lst.add(curr.val); // default add is addLast
                }

                if (curr.left != null) {
                    queue.add(curr.left);
                }
                if (curr.right != null) {
                    queue.add(curr.right);
                }
                size--;
            }
            ans.add(lst);
            oddLevel = !oddLevel;
        }
        return ans;
    }
}

// time: O(n)
// space: O(n) (the queue itself needs O(k), same idea as above)

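Here we can also use Collections.reverse(List<E> list) to reverse the list of each odd level. A single reversal takes time linear in the length of that list, and since every node belongs to exactly one level, all the reversals together cost O(n), so the total time is still O(n).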
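The alternative of collecting each level normally and reversing the odd ones can be sketched in Python. The `TreeNode` class and the function name `zigzag_level_order` are minimal stand-ins written for this example. Each node is appended once and takes part in at most one reversal.

```python
from collections import deque


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def zigzag_level_order(root):
    if root is None:
        return []
    ans = []
    queue = deque([root])
    odd_level = False  # level 0 is the first level
    while queue:
        level = []
        for _ in range(len(queue)):  # exactly the nodes of the current level
            node = queue.popleft()
            level.append(node.val)
            if node.left:
                queue.append(node.left)
            if node.right:
                queue.append(node.right)
        if odd_level:
            level.reverse()  # linear in the size of this level only
        ans.append(level)
        odd_level = not odd_level
    return ans


root = TreeNode(3, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7)))
print(zigzag_level_order(root))  # [[3], [20, 9], [15, 7]]
```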

Takeaways

  • LinkedList has the methods addFirst(E) and addLast(E); the default add(E) behaves like addLast(E), and both run in O(1).
  • In Java, the List<E> interface does not expose the underlying node objects, no matter whether the implementation is a LinkedList or an ArrayList.
  • Collections.reverse(List<E> list) runs in O(n), where n is the size of the list being reversed.

Previously

Some other problems are involved, for example:

Problems

LC143

Solution

To make it work, we find the middle node and split the linked list into two halves. Then we reverse the second half and get its new head (which was the tail). Finally we merge the two lists.

// ListNode is the node class provided by LeetCode
class Solution {
    public void reorderList(ListNode head) {
        if (head == null || head.next == null) {
            return;
        }

        ListNode midNode = findMid(head);
        ListNode midNext = midNode.next;
        midNode.next = null; // cut the list into two halves
        ListNode revHead = reverseList(midNext);
        head = mergeList(head, revHead);
    }

    // findMid uses 2 pointers and runs through the list once, O(n)
    public ListNode findMid(ListNode head) {
        ListNode slow = head;
        ListNode fast = head;
        while (fast.next != null && fast.next.next != null) {
            slow = slow.next;
            fast = fast.next.next;
        }
        return slow;
    }

    // reverseList runs through the list once, O(n); head must not be null
    public ListNode reverseList(ListNode head) {
        ListNode prev = null;
        ListNode curr = head;
        ListNode next = curr.next;
        while (next != null) {
            curr.next = prev;
            prev = curr;
            curr = next;
            next = next.next;
        }
        curr.next = prev;
        return curr;
    }

    // mergeList interleaves the 2 half lists and takes O(n)
    public ListNode mergeList(ListNode one, ListNode two) {
        ListNode dummy = new ListNode(0);
        ListNode curr = dummy;
        while (two != null) {
            curr.next = one;
            one = one.next;
            curr.next.next = two;
            two = two.next;
            curr = curr.next.next;
        }
        if (one != null) {
            curr.next = one; // the first half may have one extra node
        }
        return dummy.next;
    }
}

// time: O(n)
// space: O(1)

Takeaways

  • Don’t forget to cut the connection between the first half and the second half of the linked list with midNode.next = null; if we forget to remove it, the linked list will end up with a cycle.
  • The second half is never longer than the first (second.size() <= first.size()) because we split at ListNode midNext = midNode.next; therefore, when we merge the 2 lists, it is enough to check while (two != null).
  • A dummy node is useful here: it makes the code uniform in each step (otherwise the first step would differ from the following steps).
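The split, reverse, and merge steps of the solution can also be sketched compactly in Python. The `ListNode` class and the helpers `from_list` and `to_list` are minimal stand-ins written for this example. This variant interleaves the two halves without a dummy node, which is a different arrangement from the merge in the solution above.

```python
class ListNode:
    def __init__(self, val, next=None):
        self.val = val
        self.next = next


def reorder_list(head):
    if head is None or head.next is None:
        return
    # 1. find the middle with slow/fast pointers
    slow = fast = head
    while fast.next and fast.next.next:
        slow = slow.next
        fast = fast.next.next
    # 2. cut the list in two and reverse the second half
    second = slow.next
    slow.next = None  # without this cut the result would contain a cycle
    prev = None
    while second:
        nxt = second.next
        second.next = prev
        prev = second
        second = nxt
    # 3. interleave; the second half is never longer than the first
    first, second = head, prev
    while second:
        first_next, second_next = first.next, second.next
        first.next = second
        second.next = first_next
        first, second = first_next, second_next


def from_list(values):
    head = None
    for v in reversed(values):
        head = ListNode(v, head)
    return head


def to_list(head):
    out = []
    while head:
        out.append(head.val)
        head = head.next
    return out


head = from_list([1, 2, 3, 4, 5])
reorder_list(head)
print(to_list(head))  # [1, 5, 2, 4, 3]
```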

Syntax and Module Loading

CommonJS modules are synchronously loaded, meaning that the module files are loaded and parsed during the runtime as the code executes. This approach is well-suited for server-side environments where files are typically locally available and can be loaded quickly.

ES6 modules are designed to support asynchronous loading, allowing modules to be loaded over the network. This feature is advantageous in browser environments, enabling scripts to be loaded in parallel while the page loads.

In CommonJS, the server loads the modules in order, waiting for each require() call to finish before execution continues. Since the files are stored locally on the server, this usually does not take long.

// export module (math.js)
module.exports = {
  add: function(a, b) {
    return a + b;
  },
  subtract: function(a, b) {
    return a - b;
  }
};

// import module
const math = require('./math');
console.log(math.add(1, 2)); // output: 3

Imagine you are browsing a webpage with a complex frontend. ES6 modules support asynchronous loading, which means that before some slow modules have been loaded completely, you can already see and interact with the parts whose modules are loaded.

// math.js
import moduleA from 'moduleA'; // imports the default export of another module

// named exports
export function add(a, b) {
  return a + b;
}

export function subtract(a, b) {
  return a - b;
}

// default export
export default class Math {
  constructor() {}
  multiply(a, b) {
    return a * b;
  }
}

// app.js
// import named exports
// (native ES modules in browsers and Node.js need the file extension; many bundlers also accept './math')
import { add, subtract } from './math.js';
console.log(add(3, 2)); // output: 5

// import all
import * as math from './math.js';
console.log(math.subtract(5, 2)); // output: 3

// import default
import Math from './math.js';
const mathInstance = new Math();
console.log(mathInstance.multiply(3, 2)); // output: 6

Design Philosophy

CommonJS was designed with server-side applications in mind, where modules are loaded and parsed as needed.

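ES6 modules are designed to allow static analysis before the code runs: because import and export declarations are fixed at the top level, tools can apply static optimizations such as tree shaking (dropping exports that are never imported). More complex patterns are supported as well, such as partial (named) imports and dynamic imports through import().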

Interoperability

In modern JavaScript development, Node.js environments support ES6 module syntax, but this typically requires specific configuration (such as using the .mjs file extension or setting "type": "module" in package.json). This enables the use of ES6 module syntax in Node.js, while also supporting the import of CommonJS modules.

Nowadays, ES6 modules are widely used on the server side too. Which of the two to choose depends on the circumstances.

Button

If we want to quote another post, we can use a button.

Previous Chapter Next Chapter

Mermaid

Visit https://github.com/mermaid-js/mermaid to check the usage.

graph TD
A[Hard] -->|Text| B(Round)
B --> C{Decision}
C -->|One| D[Result 1]
C -->|Two| E[Result 2]

This is Tab 1.

public class Tab1 {

}

This is Tab 2.

public class Tab2 {

}

This is Tab 3.

public class Tab3 {

}

This is Tab 4.

public class Tab4 {

}

Math Equations

Check for math equations: https://theme-next.js.org/docs/third-party-services/math-equations

Test the math equations:
$$\begin{equation} \label{eq1}
e=mc^2
\end{equation}$$
