> [!tldr] GAM Variable Effect
> Plots a variable's effect on the linear predictor of a [[Generalized Linear Models|GLM]] or [[Generalized Additive Models|GAM]] fitted with the `glum` package.
```python
def plot_variable_effect(model, data, feature, interaction=None):
    """Plot the fitted contribution of ``feature`` to the linear predictor.

    Draws a rug plot of ``feature`` from ``data``, then overlays a line whose
    height is the sum of (model-matrix column * coefficient) over every model
    column whose name contains ``feature``. When ``interaction`` is given it
    is used as the hue of both layers.

    Args:
        model: Fitted estimator exposing ``X_model_spec_``, ``feature_names_``
            and ``coef_table()`` (presumably a formula-based ``glum`` model;
            confirm against the caller).
        data: DataFrame holding ``feature``, ``interaction`` (if given) and
            any other columns the model spec reads.
        feature: Name of the column whose effect is plotted.
        interaction: Optional name of a column used to colour the plot.
    """
    # The rug is drawn from the full data so it shows every observation,
    # not just the de-duplicated evaluation grid built below.
    sns.rugplot(data=data, x=feature, hue=interaction, alpha=0.2)

    # One row per distinct feature value (per interaction level, if any).
    key_cols = [feature, interaction] if interaction else [feature]
    grid = data.drop_duplicates(key_cols).sort_values(feature)

    model_matrix = model.X_model_spec_.get_model_matrix(grid).toarray()
    transformed_features = pd.DataFrame(model_matrix, columns=model.feature_names_)

    # NOTE(review): this is a substring match, so a feature named "age" would
    # also pick up columns of a feature named "mileage" -- confirm the naming
    # scheme of feature_names_ makes this safe for the features in use.
    related_cols = [col for col in transformed_features.columns if feature in col]

    # NOTE(review): assumes coef_table() yields a Series of coefficients
    # indexed by feature name, so the product aligns column-wise -- confirm.
    coefs = model.coef_table().loc[related_cols]
    fitted_relationship = (transformed_features[related_cols] * coefs).sum(axis=1)

    # All columns are taken positionally from the same `grid` rows and are
    # assembled in one step, so the hue labels cannot drift away from the
    # fitted values they belong to. `grid` is already ordered by `feature`,
    # hence no further sort is needed.
    plot_columns = {
        feature: grid[feature].to_numpy(),
        'fitted_relationship': fitted_relationship.to_numpy(),
    }
    if interaction:
        plot_columns[interaction] = grid[interaction].to_numpy()
    plot_data = pd.DataFrame(plot_columns)

    # Referencing the column by name gives the y-axis a label.
    sns.lineplot(data=plot_data, x=feature, y='fitted_relationship', hue=interaction)
```