A straightforward solution using O(mn) space is probably a bad idea.
A simple improvement uses O(m + n) space, but it is still not the best solution.
Could you devise a constant space solution?
// Constant space method
// Zeroes, in place, every row and every column of `matrix` that contains a 0.
// Assumes a rectangular matrix (every row as long as matrix[0]).
void setZeroes(vector<vector<int>>& matrix) { // time: O(m * n); space: O(1)
    // Guard: the marker scheme below reads matrix[0] and matrix[i][0], which do
    // not exist for an empty matrix or a matrix whose first row is empty.
    if (matrix.empty() || matrix[0].empty()) return;

    // Use the 1st row to record each column, and the 1st column to record each
    // row. matrix[0][0] is shared and represents row 0, so a separate flag
    // records whether column 0 itself must be cleared.
    const int m = static_cast<int>(matrix.size());
    const int n = static_cast<int>(matrix[0].size());
    bool clear_col_0 = false;

    // Pass 1: scan every cell and set the markers in the first row/column.
    for (int i = 0; i < m; ++i) {
        if (matrix[i][0] == 0) clear_col_0 = true;
        for (int j = 1; j < n; ++j) {
            if (matrix[i][j] == 0) matrix[i][0] = matrix[0][j] = 0;
        }
    }

    // Pass 2: clear cells starting from the bottom-right. Row 0 holds the
    // column markers, so it must be overwritten last, after every other row
    // has consulted it. Column 0 of each row is cleared only after that row's
    // own marker (matrix[i][0]) has been used.
    for (int i = m - 1; i >= 0; --i) {
        for (int j = n - 1; j >= 1; --j) {
            if (matrix[i][0] == 0 || matrix[0][j] == 0) matrix[i][j] = 0;
        }
        if (clear_col_0) matrix[i][0] = 0;
    }
}
// Constant space method
// Zeroes, in place, every row and every column of `matrix` that contains a 0.
//
// Row 0 and column 0 double as marker storage: for r, c >= 1,
// matrix[0][c] == 0 means "clear column c" and matrix[r][0] == 0 means
// "clear row r". matrix[0][0] is never written as a marker. Two booleans
// remember whether row 0 / column 0 originally held a zero, because once the
// markers are written that information would otherwise be lost.
void setZeroes(vector<vector<int>>& matrix) { // time: O(m * n); space: O(1)
    if (matrix.empty() || matrix[0].empty()) return;  // nothing to clear

    const int rows = static_cast<int>(matrix.size());
    const int cols = static_cast<int>(matrix[0].size());

    // Capture the original state of the marker column and marker row.
    bool firstColHasZero = false;
    for (int r = 0; r < rows; ++r) {
        if (matrix[r][0] == 0) firstColHasZero = true;
    }
    bool firstRowHasZero = false;
    for (int c = 0; c < cols; ++c) {
        if (matrix[0][c] == 0) firstRowHasZero = true;
    }

    // Record every interior zero in the marker row and marker column.
    for (int r = 1; r < rows; ++r) {
        for (int c = 1; c < cols; ++c) {
            if (matrix[r][c] != 0) continue;
            matrix[r][0] = 0;
            matrix[0][c] = 0;
        }
    }

    // Clear interior cells whose row or column is marked. The row marker
    // lives in column 0, which this loop never writes, so it can be read once.
    for (int r = 1; r < rows; ++r) {
        vector<int>& row = matrix[r];
        const bool rowMarked = (row[0] == 0);
        for (int c = 1; c < cols; ++c) {
            if (rowMarked || matrix[0][c] == 0) row[c] = 0;
        }
    }

    // Finally clear the marker row/column themselves if they held a zero.
    if (firstRowHasZero) {
        for (int& cell : matrix[0]) cell = 0;
    }
    if (firstColHasZero) {
        for (vector<int>& row : matrix) row[0] = 0;
    }
}