Maxout’s capability to approximate any convex function grants neural networks a significant degree of flexibility and expressive power. This means Maxout units can learn anything from simple linear responses to very complex nonlinear patterns. This capability is particularly useful in applications where the decision boundaries are complex or when the data distribution is highly variable.
2. Comparison with ReLU
Compared to ReLU (Rectified Linear Unit), Maxout offers a broader range of behaviors. ReLU is a simple yet highly effective activation function defined as $f(x) = \max(0, x)$. Its main advantages are computational simplicity and mitigation of the vanishing gradient problem. However, ReLU is one-sided: it activates only for positive inputs. In contrast, Maxout can adapt to both positive and negative changes in its inputs, providing a more complex nonlinear response.
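To make this concrete, a Maxout unit computes $h_i(x) = \max_{j \in [1,k]} (W_j x + b_j)_i$, the maximum over $k$ learned affine pieces. Below is a minimal PyTorch-style sketch of this idea (the class name, shapes, and default $k$ are my own assumptions, not a reference implementation):

import torch
from torch import nn

class Maxout(nn.Module):
    # max over k linear pieces: h(x) = max_j (x @ W_j + b_j)
    def __init__(self, in_features, out_features, k=2):
        super().__init__()
        self.k = k
        self.out_features = out_features
        self.linear = nn.Linear(in_features, out_features * k)

    def forward(self, x):
        z = self.linear(x)                                     # (..., out_features * k)
        z = z.view(*x.shape[:-1], self.out_features, self.k)   # split into k pieces
        return z.max(dim=-1).values                            # elementwise max over the pieces

y = Maxout(3, 5)(torch.randn(4, 3))  # -> shape (4, 5)

With k = 2 pieces a Maxout unit can already represent ReLU exactly (by learning one piece fixed at zero), which illustrates why it is strictly more expressive.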
3. Trade-offs in Practical Applications
While Maxout offers superior expressive power in theory, it also carries a higher parameter burden (multiple sets of weights per neuron), which can lead to higher computational costs and an increased risk of overfitting. Therefore, the choice of activation function in practice often involves a trade-off among expressive power, computational efficiency, and ease of use.
This post collects several problems I ran into while training multi-class classification models with PyTorch. It serves as a reminder of concepts and troubleshooting techniques that may come up again in the future.
Code
Create and preprocess the data
During preprocessing, note that y_blob is converted to a LongTensor: in PyTorch, when computing the loss with nn.CrossEntropyLoss, the target (label) tensor must be of type torch.long, because the loss function expects the targets to contain class indices as 64-bit integers.
# transform from numpy arrays to tensors
X_blob = torch.from_numpy(X_blob).type(torch.float)
y_blob = torch.from_numpy(y_blob).type(torch.LongTensor)  # must be long: the loss function does not accept float indices
# split the data
X_blob_train, X_blob_test, y_blob_train, y_blob_test = train_test_split(
    X_blob, y_blob, test_size=0.2, random_state=RANDOM_SEED
)
# plot the data
plt.figure(figsize=(10, 7))
plt.scatter(X_blob[:, 0], X_blob[:, 1], c=y_blob, cmap=plt.cm.RdYlBu)
Build the model
We can define the constructor with multiple explicit parameters, but during training only the input tensor is passed to the model, because the forward function takes just one parameter. A sketch of such a model follows.
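Here is a minimal sketch of such a model (the hidden size, class name, and the assumption of 2 input features and 4 classes are mine, matching the blob plot above rather than any original code):

class BlobModel(nn.Module):
    def __init__(self, input_features, output_features, hidden_units=8):
        super().__init__()
        self.linear_layer_stack = nn.Sequential(
            nn.Linear(in_features=input_features, out_features=hidden_units),
            nn.ReLU(),
            nn.Linear(in_features=hidden_units, out_features=output_features),
        )

    def forward(self, x):
        # forward takes only one parameter: the input tensor
        return self.linear_layer_stack(x)

model_4 = BlobModel(input_features=2, output_features=4)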
# CrossEntropyLoss is the standard choice for multi-class classification
loss_fn = nn.CrossEntropyLoss()
# the most common optimizers are SGD and Adam
optimizer = torch.optim.SGD(params=model_4.parameters(), lr=0.01)
Train the model
Note that nn.CrossEntropyLoss() accepts only logits as input (it does not want values that have already gone through softmax, since it applies log-softmax internally). However, we still compute a y_pred after softmax because we need it to calculate the accuracy.
Also note one very important thing here: dim=1 means the metric is computed within each row, across its columns, so both the softmax and argmax calls produce one result per row. dim=1 literally stands for "keeping the row fixed, compute the result from the different columns in that row."
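For context, here is a minimal training-loop sketch built from the pieces above (the epoch count and the inline accuracy computation are my assumptions, not the original code); the testing block and the print statement below would sit inside the same loop:

for epoch in range(1001):
    model_4.train()

    y_logits = model_4(X_blob_train)                        # raw logits, no softmax
    y_pred = torch.softmax(y_logits, dim=1).argmax(dim=1)   # class indices, for accuracy only

    loss = loss_fn(y_logits, y_blob_train)                  # logits + long labels
    acc = (y_pred == y_blob_train).float().mean() * 100

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()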
# test
model_4.eval()
with torch.inference_mode():
    test_logits = model_4(X_blob_test)
    test_pred = torch.softmax(test_logits, dim=1).argmax(dim=1)  # note here
    test_loss = loss_fn(test_logits, y_blob_test)                # (assumed) metrics used in the print below
    test_acc = (test_pred == y_blob_test).float().mean() * 100
if epoch % 100 == 0:
    print(f"Epoch: {epoch} | Loss: {loss:.4f}, Acc: {acc:.2f}% | Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")
Evaluate the model
model_4.eval()
with torch.inference_mode():
    y_logits = model_4(X_blob_test)
# remember to manually activate the logits by applying softmax and argmax
y_pred_probs = torch.softmax(y_logits, dim=1)
y_preds = torch.argmax(y_pred_probs, dim=1)
In this problem there is no restriction on using a single element multiple times, so we can use the backtracking algorithm and start the next recursion from the index we are at in each step.
Note that we should not recurse over all the elements of the candidates array, because we don't want to output the same solution in every possible order, e.g. [2,3,3], [3,2,3], [3,3,2]; therefore we never move the start index backward.
Note that the candidates array contains only distinct elements, so we will not count the same solution in different orders multiple times because of duplicates in the array.
Also note that the candidates array is not sorted, so we can't prune the search as soon as the running sum exceeds the target at some index; we must keep traversing the whole candidates array. This suggests sorting the array at the very beginning.
class Solution {
    public List<List<Integer>> combinationSum(int[] candidates, int target) {
        List<List<Integer>> ans = new ArrayList<>();
        dfs(ans, new LinkedList<Integer>(), candidates, target, 0);
        return ans;
    }

    public void dfs(List<List<Integer>> ans, LinkedList<Integer> temp, int[] candidates, int remaining, int start) {
        // base case
        if (remaining == 0) {
            ans.add(new LinkedList<>(temp));
            return;
        } else if (remaining < 0) {
            return;
        }

        for (int i = start; i < candidates.length; i++) {
            temp.add(candidates[i]);
            dfs(ans, temp, candidates, remaining - candidates[i], i);  // i, not i + 1: elements may be reused
            temp.removeLast();
        }
    }
}
// time: the time complexity is not fixed; O(n * 2^n) is the worst case, where all combinations are explored
// however, if we consider the search tree itself, it is O(S), where S is the number of nodes of the valid search tree
// space: O(target), the recursion depth of the longest valid combination
If we sort the array beforehand, we can prune the search tree, which reduces the runtime overhead. Note that the worst-case time complexity won't change, but the algorithm is certainly improved in practice.
public void dfs(List<List<Integer>> ans, LinkedList<Integer> temp, int[] candidates, int remaining, int start) {
    if (remaining == 0) {
        ans.add(new ArrayList<>(temp));
        return;
    }

    for (int i = start; i < candidates.length; ++i) {
        if (remaining - candidates[i] < 0) break;  // candidates is sorted, so no later element can fit
        temp.add(candidates[i]);
        dfs(ans, temp, candidates, remaining - candidates[i], i);
        temp.removeLast();
    }
}
// we can sort the candidates array with Arrays.sort(candidates) in the main function
// here we check whether candidates[i] exceeds the remaining target; if it does, we break and return to the previous recursion level
// this saves some overhead in both time and space but won't save the algorithm from the worst case
From the previous problem, we found that pruning can only happen when the candidates array is sorted.
LC216: Combination Sum III
Find all valid combinations of k numbers that sum up to n such that the following conditions are true:
Only numbers 1 through 9 are used.
Each number is used at most once.
Return a list of all possible valid combinations. The list must not contain the same combination twice.
Solution
We can use depth-first search to descend toward the deepest element that can still sum up to the target, and backtrack to enumerate all the possible combinations. We loop from 1 to 9, and each time we add one single number to the sum and check it. We recursively try every combination.
Base case: if sum == target and the count of elements equals k, add a NEW copy of the temp list to the answer list (see the sketch below).
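A minimal backtracking sketch of this idea (the pruning details and names are mine, not from a reference solution):

class Solution {
    public List<List<Integer>> combinationSum3(int k, int n) {
        List<List<Integer>> ans = new ArrayList<>();
        dfs(ans, new LinkedList<>(), k, n, 1);
        return ans;
    }

    private void dfs(List<List<Integer>> ans, LinkedList<Integer> temp, int k, int remaining, int start) {
        // base case: k numbers found that sum to n; copy temp instead of adding the reference
        if (remaining == 0 && temp.size() == k) {
            ans.add(new ArrayList<>(temp));
            return;
        }
        if (remaining <= 0 || temp.size() == k) {
            return;
        }

        for (int i = start; i <= 9; i++) {
            if (remaining - i < 0) break;             // 1..9 is already sorted, so we can prune
            temp.add(i);
            dfs(ans, temp, k, remaining - i, i + 1);  // i + 1: each number is used at most once
            temp.removeLast();
        }
    }
}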
Takeaways
When we add the temp list to the answer, we should be aware that adding it directly only copies the reference.
// add the reference to temp to the ans list
// this can cause a problem when we later modify temp at another level of the recursion
ans.add(temp);
// add a reference to a new list copied from the temp list
ans.add(new ArrayList<>(temp));
If we define the answer list as a private field outside the main function, there can be safety problems. For example, if two threads share the same Solution object, then even though the ans list is private, there is still a thread-safety problem.
class Solution {
    private List<List<Integer>> ans = new ArrayList<>();

    public List<List<Integer>> combinationSum3(int k, int n) {
Solution solution = new Solution();

// Thread A
new Thread(() -> solution.combinationSum3(k1, n1)).start();

// Thread B
new Thread(() -> solution.combinationSum3(k2, n2)).start();
LinkedList, ArrayList, and ArrayDeque differ here:
LinkedList provides removeLast() directly, and adding or deleting at both head and tail is O(1).
ArrayList does not provide a removeLast() method (until Java 21), but remove(size() - 1) does the same thing in amortized O(1); adding or deleting at the head is O(n) because all elements shift.
ArrayDeque is a dynamic double-ended queue: inserting and deleting at both ends is amortized O(1), but it offers no index-based access, so it is not as efficient as ArrayList for lookups, and searching by value is O(n).
The short demo below illustrates the equivalent operations.
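A toy comparison of the operations mentioned above (my own demo, not from the original post):

import java.util.*;

public class ListDemo {
    public static void main(String[] args) {
        LinkedList<Integer> linked = new LinkedList<>(List.of(1, 2, 3));
        linked.removeLast();             // O(1)
        linked.addFirst(0);              // O(1) at the head

        ArrayList<Integer> array = new ArrayList<>(List.of(1, 2, 3));
        array.remove(array.size() - 1);  // amortized O(1) at the tail
        array.add(0, 0);                 // O(n): shifts every element

        ArrayDeque<Integer> deque = new ArrayDeque<>(List.of(1, 2, 3));
        deque.removeLast();              // O(1)
        deque.addFirst(0);               // O(1)

        System.out.println(linked + " " + array + " " + deque);
    }
}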
This problem is a more complex version of LC102: Binary Tree Level Order Traversal. In the previous version, we can simply use a Queue to implement a FIFO traversal and solve the problem:
class Solution {
    public List<List<Integer>> levelOrder(TreeNode root) {
        if (root == null) {
            return new ArrayList<>();
        }

        List<List<Integer>> ans = new ArrayList<>();
        Queue<TreeNode> queue = new LinkedList<>();
        queue.offer(root);

        while (!queue.isEmpty()) {
            List<Integer> lst = new ArrayList<>();
            int size = queue.size();  // store how many nodes are in this level
            while (size > 0) {
                TreeNode curr = queue.poll();
                lst.add(curr.val);
                if (curr.left != null) {
                    queue.add(curr.left);
                }
                if (curr.right != null) {
                    queue.add(curr.right);
                }
                size--;
            }
            ans.add(lst);
        }
        return ans;
    }
}
// time: O(n)
// space: O(n) (O(k) actually, where k is the largest number of nodes in one level; worst case n)
Solution
For this problem, where a zigzag traversal is required, we have two ways to solve it: either retrieve the nodes in reversed order when the level is odd (0 is the first level), or retrieve the nodes normally but reverse the list afterwards. It turns out both can be done in the same time complexity. What's more, the first method can itself be implemented in two different ways; one of them is sketched below.
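Here is a sketch of the first approach (assuming the same TreeNode definition as above), building each level as a LinkedList and calling addFirst on odd levels:

class Solution {
    public List<List<Integer>> zigzagLevelOrder(TreeNode root) {
        List<List<Integer>> ans = new ArrayList<>();
        if (root == null) {
            return ans;
        }

        Queue<TreeNode> queue = new LinkedList<>();
        queue.offer(root);
        boolean leftToRight = true;

        while (!queue.isEmpty()) {
            LinkedList<Integer> lst = new LinkedList<>();
            int size = queue.size();
            while (size > 0) {
                TreeNode curr = queue.poll();
                if (leftToRight) {
                    lst.addLast(curr.val);   // normal order
                } else {
                    lst.addFirst(curr.val);  // reversed order, O(1) at the head of a LinkedList
                }
                if (curr.left != null) queue.add(curr.left);
                if (curr.right != null) queue.add(curr.right);
                size--;
            }
            ans.add(lst);
            leftToRight = !leftToRight;
        }
        return ans;
    }
}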
// time: O(n)
// space: O(n) (O(k) actually, same idea as above)
Here we could also use Collections.reverse(List<E> list) to reverse each level's list. A single call is O(k) for a level with k nodes, and since the level sizes sum to n, the extra reversals cost O(n) in total: they add another pass over each level but do not change the overall asymptotic complexity.
Takeaways
LinkedList has the methods addFirst(E) and addLast(E), with the default add(E) being addLast(E); both run in O(1).
In Java, we can't directly get at the underlying node objects when we are given a List<E> lst, no matter whether it's a LinkedList or an ArrayList.
Collections.reverse(List<E> list) time complexity is O(n).
To make it work, we should find the middle node and split the linked list into two halves. Then we reverse the second half and get its new head (which was the tail). Finally, we merge the two lists.
// findMid uses two pointers and runs through the list once, O(n)
public ListNode findMid(ListNode head) {
    ListNode slow = head;
    ListNode fast = head;
    while (fast.next != null && fast.next.next != null) {
        slow = slow.next;
        fast = fast.next.next;
    }
    return slow;
}
// reverseList runs through the list once, O(n)
public ListNode reverseList(ListNode head) {
    ListNode prev = null;
    ListNode curr = head;
    ListNode next = curr.next;
    while (next != null) {
        curr.next = prev;
        prev = curr;
        curr = next;
        next = next.next;
    }
    curr.next = prev;
    return curr;
}
// mergeList merges the two half lists and takes O(n)
public ListNode mergeList(ListNode one, ListNode two) {
    ListNode dummy = new ListNode(0);
    ListNode curr = dummy;
    while (two != null) {
        curr.next = one;
        one = one.next;
        curr.next.next = two;
        two = two.next;
        curr = curr.next.next;
    }
    if (one != null) {
        curr.next = one;
    }
    return dummy.next;
}
// time: O(n)
// space: O(1)
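The main method that glues the three helpers together is not shown above; a minimal sketch consistent with the takeaways below might be:

public void reorderList(ListNode head) {
    if (head == null || head.next == null) {
        return;
    }
    ListNode midNode = findMid(head);
    ListNode midNext = midNode.next;
    midNode.next = null;  // break the link, otherwise the linked list has a cycle
    mergeList(head, reverseList(midNext));
}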
Takeaways
Don't forget to remove the connection between the first half and the second half of the linked list with midNode.next = null; if we forget to remove it, the linked list will have a cycle.
The second half always satisfies second.size() <= first.size() because we choose ListNode midNext = midNode.next; therefore, when we merge the two lists together, we only need to check while (two != null).
The dummy node is useful here: it makes the code uniform at each step (otherwise the first step would be different from the following steps).
CommonJS modules are loaded synchronously, meaning that module files are loaded and evaluated at runtime, as the code executes. This approach is well-suited to server-side environments, where files are typically available locally and can be loaded quickly.
ES6 modules are designed to support asynchronous loading, allowing modules to be loaded over the network. This feature is advantageous in browser environments, enabling scripts to be loaded in parallel while the page loads.
In CommonJS, the server loads each module in order and then executes it. Since the files are stored locally on the server, this usually doesn't take long.
// export module
module.exports = {
    add: function(a, b) {
        return a + b;
    },
    subtract: function(a, b) {
        return a - b;
    }
};
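The matching consumer side would use require (assuming the file above is saved as math.js; the filename is illustrative):

// import the module synchronously
const math = require('./math');

console.log(math.add(2, 3));       // 5
console.log(math.subtract(5, 2));  // 3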
Imagine you are browsing a webpage with a complex frontend. ES6 enables asynchronous loading, which means that before some slow modules have loaded thoroughly, you can already see and interact with the modules that have finished loading.
import moduleA from 'moduleA';  // import the default export of moduleA

export function add(a, b) {
    return a + b;
}

export function subtract(a, b) {
    return a - b;
}

export default class Math {
    constructor() {}

    multiply(a, b) {
        return a * b;
    }
}
CommonJS was designed with server-side applications in mind, where modules are loaded and parsed as needed.
ES6 modules are designed to allow static analysis at compile time, supporting static optimizations and more complex import/export patterns, such as partial imports (tree shaking) and dynamic imports.
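For example (a small sketch; './math.js' is an illustrative path), a named import lets a bundler tree-shake the unused exports, while import() loads a module on demand at runtime:

// partial import: only add is pulled in, so bundlers can drop subtract
import { add } from './math.js';

// dynamic import: returns a promise that resolves to the module namespace
import('./math.js').then((math) => {
    console.log(math.add(1, 2));
});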
Interoperability
In modern JavaScript development, Node.js environments have started to support ES6 module syntax, but this typically requires specific configuration, such as using the .mjs file extension or setting "type": "module" in package.json. This enables the use of ES6 module syntax in Node.js while still supporting the import of CommonJS modules.
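For instance, a minimal package.json opting a whole package into ES modules looks like this (the package name is illustrative):

{
  "name": "my-app",
  "type": "module"
}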
Nowadays, ES6 modules are widely used on the server side too; the choice between the two depends on the circumstances.