VC Dimension: Measuring Hypothesis Complexity
When you train a machine learning model, you choose a set of possible functions, called a hypothesis class, from which your algorithm selects the best fit for the data. However, not all hypothesis classes are equally complex. Some can represent a wide variety of patterns, while others are more limited. This complexity plays a crucial role in how well your model generalizes to unseen data: a more complex hypothesis class can fit the training data more closely, but it also runs a higher risk of overfitting, that is, performing poorly on new examples. To analyze and control this trade-off, you need a mathematically rigorous way to quantify the complexity of a hypothesis class.
The VC dimension (Vapnik–Chervonenkis dimension) of a hypothesis class is the largest number of points that can be labeled in all possible ways (that is, "shattered") by hypotheses from that class. To shatter a set of points means that, for every possible assignment of labels to those points, there exists a hypothesis in the class that correctly classifies them.
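To make "every possible assignment of labels" concrete: a set of n points admits 2^n distinct labelings, and shattering requires a hypothesis for each one. The minimal sketch below (the three point values are arbitrary, chosen only for illustration) enumerates all eight labelings of three points:

from itertools import product

# Three arbitrary points on the real line (any distinct values would do).
points = [0.2, 0.5, 0.8]

# A set of n points has 2 ** n possible label assignments; shattering
# requires some hypothesis in the class to realize every one of them.
for labeling in product([0, 1], repeat=len(points)):
    print(labeling)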
Suppose your hypotheses are all intervals on the real number line, and your task is to classify points as "inside" (positive) or "outside" (negative) an interval. With one point, you can label it either way by choosing an interval that contains it or one that misses it, so one point can be shattered. With two points, you can choose intervals to achieve all four possible labelings. But with three points, you cannot achieve all eight possible labelings with a single interval: labeling the two outer points positive and the middle one negative is impossible, because any interval containing both outer points must also contain every point between them. Thus, three points cannot be shattered by intervals, and the VC dimension of intervals on the real line is 2.
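As a quick check of the three-point argument above, the short sketch below (with arbitrarily chosen point positions) tries every interval whose endpoints coincide with the data points and confirms that the "positive, negative, positive" labeling is never produced:

points = [0.1, 0.5, 0.9]   # three sorted points; the exact positions are arbitrary
target = [1, 0, 1]         # outer points positive, middle point negative

# Any interval that contains at least one point can be shrunk so that its
# endpoints are data points, so it is enough to check those intervals here.
achievable = any(
    [int(points[i] <= x <= points[j]) for x in points] == target
    for i in range(len(points))
    for j in range(i, len(points))
)
print(achievable)  # False: no interval separates the middle point from the outer two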
Formally, the VC dimension of a hypothesis class H is the largest integer d such that there exists a set of d points that can be shattered by H. If sets of arbitrary size can be shattered, the VC dimension is infinite.
import numpy as np
import matplotlib.pyplot as plt

def can_shatter(points):
    """Return True if every labeling of `points` is realized by some interval."""
    n = len(points)
    for labels in range(2 ** n):
        label_bits = [(labels >> k) & 1 for k in range(n)]
        # An interval placed away from all points realizes the all-negative labeling.
        if not any(label_bits):
            continue
        found = False
        # It suffices to try intervals whose endpoints coincide with data points:
        # any interval containing a nonempty subset can be shrunk to this form.
        for i in range(n):
            for j in range(i, n):
                interval = (points[i], points[j])
                predicted = [int(interval[0] <= x <= interval[1]) for x in points]
                if predicted == label_bits:
                    found = True
                    break
            if found:
                break
        if not found:
            return False
    return True

max_points = 5
shatterable = []
for k in range(1, max_points + 1):
    # Place k points equally spaced on [0, 1]
    points = np.linspace(0, 1, k)
    shatterable.append(1 if can_shatter(points) else 0)

plt.figure(figsize=(6, 3))
plt.bar(range(1, max_points + 1), shatterable)
plt.xlabel("Number of points")
plt.ylabel("Can be shattered (1=True, 0=False)")
plt.title("VC Dimension for Intervals on the Real Line")
plt.xticks(range(1, max_points + 1))
plt.ylim(-0.1, 1.1)
plt.show()
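Building on the code above, the largest number of points that can_shatter still reports as shatterable serves as an empirical estimate of the VC dimension. The short follow-up below (reusing can_shatter and max_points from the block above) prints that value, which for intervals is 2; equally spaced points suffice here because, for intervals, only the ordering of the points matters, not their exact positions.

# Largest set size (up to max_points) that can still be shattered by intervals;
# this matches the VC dimension of intervals on the real line, which is 2.
largest_shattered = max(
    (k for k in range(1, max_points + 1) if can_shatter(np.linspace(0, 1, k))),
    default=0,
)
print("Estimated VC dimension of intervals:", largest_shattered)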