# CounterfactualFairness.jl

CounterfactualFairness.jl is a Julia package that provides an interface for causal inference and counterfactual fairness. This project was completed under the mentorship of Zenna Tavares, Moritz Schauer, Jiahao Chen and Sebastian Vollmer.

Link to the GitHub repository: https://github.com/zenna/CounterfactualFairness.jl

Link to the previous blog post: https://nextjournal.com/archanarw/counterfactual-fairness-blogpost-1

### Brief Walkthrough of CounterfactualFairness.jl

The package is designed around Pearl's Ladder of Causation, and thus supports all three rungs: association (constructing the causal model), interventions, and counterfactuals.

#### Importing Required Packages

The packages required for the following demonstration:

- CounterfactualFairness (`arw` branch)
- Omega (`lang` branch)
- CausalInference
- Distributions (version 0.25.11)
- MLJ
- Flux (version 0.12.6)
- For visualization: Plots, GraphPlot, Colors, PrettyPrinting

Since CounterfactualFairness.jl depends on some unregistered packages, the following must also be added for it to precompile successfully:

- InferenceBase
- SoftPredicates
- ReplicaExchange
- OmegaCore
- OmegaMH

```julia
using Pkg

# Unregistered subpackages from the Omega.jl monorepo
Pkg.add(PackageSpec(url="https://github.com/zenna/Omega.jl", rev="lang", subdir="InferenceBase"))
Pkg.add(PackageSpec(url="https://github.com/zenna/Omega.jl", rev="lang", subdir="SoftPredicates"))
Pkg.add(PackageSpec(url="https://github.com/zenna/Omega.jl", rev="lang", subdir="ReplicaExchange"))
Pkg.add(PackageSpec(url="https://github.com/zenna/Omega.jl", rev="102cc01d1f7dbb4a4caad822746ced6fa5c7164b", subdir="OmegaCore"))
Pkg.add(PackageSpec(url="https://github.com/zenna/Omega.jl", rev="102cc01d1f7dbb4a4caad822746ced6fa5c7164b", subdir="OmegaMH"))
Pkg.add(PackageSpec(url="https://github.com/zenna/Omega.jl", rev="lang"))
Pkg.add(PackageSpec(url="https://github.com/zenna/CounterfactualFairness.jl", rev="arw"))

# Registered packages
Pkg.add("CausalInference")
Pkg.add(name = "Distributions", version = "0.25.11")
Pkg.add("MLJ")
Pkg.add(name = "Flux", version = "0.12.6")
Pkg.add("GraphPlot")
Pkg.add("Plots")
Pkg.add("PrettyPrinting")
Pkg.add("Colors")
Pkg.add("DataFrames")
```

```julia
using Omega, OmegaCore
using DataFrames, CounterfactualFairness, CausalInference
using Distributions, MLJ, Flux
using GraphPlot, Plots, Colors, PrettyPrinting
```

#### Association

Using CounterfactualFairness.jl, you may construct a causal model in the following ways -

Automatically from data -

`prob_causal_graph(df)` constructs a causal model from the dataframe `df` (assuming Gaussian mechanisms). Internally, it uses `pcalg` from CausalInference.jl to construct the causal graph.

Alternatively, by loading a causal model from CounterfactualFairness.jl -

```julia
cm = ;
gplot(dag(cm), nodelabel = ([variable(cm, i).name for i in 1:nv(cm)]), nodefillc = colorant"seagreen2", edgestrokec = colorant"black", layout = shell_layout, NODESIZE = 0.4/sqrt(nv(cm)))
```
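As an aside, the PC-algorithm step that underlies `prob_causal_graph` can be run directly with CausalInference.jl. The sketch below (the data, variable names, and significance threshold are illustrative, not from the package) estimates a graph from synthetic data with a collider X → Z ← Y:

```julia
using CausalInference, DataFrames, Random

Random.seed!(1)
n = 1000
x = randn(n)
y = randn(n)
z = x .+ y .+ 0.25 .* randn(n)   # collider structure: X → Z ← Y
df_toy = DataFrame(X = x, Y = y, Z = z)

# PC algorithm with a Gaussian conditional-independence test
# at significance level 0.01; returns the estimated graph
est_g = pcalg(df_toy, 0.01, gausscitest)
```

With a collider, the PC algorithm can orient both edges into Z; with only two variables, the edge direction would remain unresolved.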

By entering the distributions over the exogenous variables and the functions over the endogenous variables -

Consider a causal graph in which a day's temperature (X) causes both the sales at an ice-cream shop (Y) and the number of crimes (Z). We may represent this model using CounterfactualFairness.jl as given below -

```julia
g = CausalModel();  # Empty causal graph
```

Adding exogenous variables (variables that have no explicit cause within the model) -

```julia
U₁ = add_exo_variable!(g, :U₁, 1 ~ Normal(24, 8));
U₂ = add_exo_variable!(g, :U₂, 1 ~ Normal(15, 3));
U₃ = add_exo_variable!(g, :U₃, 1 ~ Normal(2, 1));
```

Adding endogenous variables (variables whose values depend on other variables in the model) -

```julia
Temp = add_endo_variable!(g, :Temp, identity, U₁);
IceCreamSales = add_endo_variable!(g, :IceCreamSales, *, Temp, U₂);
Crime = add_endo_variable!(g, :Crime, /, Temp, U₃);
```

Visualizing the graph -

```julia
gplot(dag(g), nodelabel = ([variable(g, i).name for i in 1:nv(g)]), nodefillc = colorant"seagreen2", edgestrokec = colorant"black", layout = stressmajorize_layout, NODESIZE = 0.4/sqrt(nv(g)))
```

To apply a context (values of the exogenous variables) to the model, we may use `apply_context`, or call `g(ω)`, where `ω` is a random variable as defined in Omega.

```julia
apply_context(g, (U₁ = 26.2, U₂ = 14.8, U₃ = 2.)) |> pprint
g(defω(), return_type = NamedTuple) |> pprint
```

#### Interventions

An intervention fixes the value of a particular variable in the model and modifies the entire model accordingly: all incoming edges to that variable are removed, since its value no longer depends on its causes. Interventions may be computed as follows -

(Continuing with the previous example)

```julia
i = CounterfactualFairness.Intervention(:Temp, 24.)  # Fixing the value of Temp to 24
i |> pprint

m = apply_intervention(g, i)
intervened_model = randsample(ω -> m(ω))
intervened_model |> pprint

gplot(dag(m), nodelabel = ([variable(m, i).name for i in 1:nv(m)]), nodefillc = colorant"seagreen2", edgestrokec = colorant"black", layout = stressmajorize_layout, NODESIZE = 0.4/sqrt(nv(m)))
```
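To make the graph surgery concrete, here is a plain-Julia sketch, independent of the package, of what the intervention do(Temp = 24) does to the structural equations of the example model:

```julia
# Structural equations of the example model, written as plain functions.
# Each takes a NamedTuple of exogenous noise values u = (u1, u2, u3).
temp(u)     = u.u1
icecream(u) = temp(u) * u.u2
crime(u)    = temp(u) / u.u3

# do(Temp = 24): replace the equation for Temp with the constant 24.
# Downstream equations are unchanged; the edge U₁ → Temp is gone.
temp_do(u)     = 24.0
icecream_do(u) = temp_do(u) * u.u2
crime_do(u)    = temp_do(u) / u.u3

u = (u1 = 26.2, u2 = 14.8, u3 = 2.0)
icecream(u)     # depends on u1 through temp
icecream_do(u)  # no longer depends on u1: 24.0 * 14.8
```

Changing `u1` changes `icecream(u)` but leaves `icecream_do(u)` untouched, which is exactly the independence that removing the incoming edge encodes.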

#### Counterfactuals

A counterfactual query has the form

P(Y_{X ← x′}(U) | V = v),

where V contains the observed variables: it asks what Y would have been had X been set to x′, given that V = v was observed.

To obtain counterfactuals, we condition `:Crime` on the observed values and apply the intervention Temp ← 24 -

```julia
count = ω -> counterfactual(:Crime, (IceCreamSales = 340., ), i, g, ω);
randsample(count)
```
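Conceptually, this follows Pearl's three steps: abduction (infer the exogenous variables from the observation), action (apply the intervention), and prediction (recompute the downstream variables). A deterministic plain-Julia sketch on this example, assuming for illustration that the noise values U₂ and U₃ are known exactly:

```julia
# Observation: IceCreamSales = 340. Assume the noise values
# u2 = 14.8 and u3 = 2.0 are known, so abduction is exact.
u2, u3 = 14.8, 2.0

# 1. Abduction: invert IceCreamSales = Temp * u2 to recover Temp.
temp_abduced = 340.0 / u2

# 2. Action: apply the intervention Temp ← 24.
temp_cf = 24.0

# 3. Prediction: recompute Crime = Temp / u3 under the intervened model.
crime_cf = temp_cf / u3

# For comparison, the factual Crime implied by the observation:
crime_factual = temp_abduced / u3
```

In the package the abduction step is probabilistic (the posterior over U is inferred by conditioning), which is why `counterfactual` is sampled with `randsample` rather than evaluated once.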

##### Computing counterfactuals using MLJ

Wrapper to compute counterfactuals for each observation in a given dataset:

```julia
toy = # Synthetic causal model
cfw = CounterfactualWrapper(test = gausscitest, p = 0.1, cf = :Y, interventions = CounterfactualFairness.Intervention(:A, 40.))
```

Now we may use `cfw` in the `fit`/`transform` workflow in MLJ.
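In MLJ's standard workflow this looks roughly as follows (a sketch, assuming `df` holds the observations; the exact column layout expected by `CounterfactualWrapper` is not shown in this post):

```julia
using MLJ

# Bind the wrapper to the data and fit it
mach = machine(cfw, df)
fit!(mach)

# transform computes the counterfactual value of :Y for each row of df
df_cf = transform(mach, df)
```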

#### Training a neural network in a way that the predictor is counterfactually fair

##### Check for Sufficient Condition for Counterfactual Fairness

Lemma 1: Let G be the causal graph of the given model (U, V, F). Then Ŷ will be counterfactually fair if it is a function of the non-descendants of A.

To check the sufficient condition given above, `isNonDesc` returns `true` if the condition is satisfied and `false` if it isn't.

```julia
isNonDesc(g, (:IceCreamSales, :U₃), (:Temp,))   # false, since IceCreamSales is a descendant of Temp
isNonDesc(g, (:U₁, :IceCreamSales), (:Crime,))  # true, since neither U₁ nor IceCreamSales is a descendant of Crime
```

##### Training using MLJ Interface

Creating a synthetic dataset -

```julia
toy = # Synthetic causal model

n = 500
X = (CausalVar(toy, :X1), CausalVar(toy, :X2), CausalVar(toy, :X3), CausalVar(toy, :X4))
U = (CausalVar(toy, :U₁), CausalVar(toy, :U₂), CausalVar(toy, :U₃), CausalVar(toy, :U₄), CausalVar(toy, :U₅))
A = CausalVar(toy, :A)
Y = CausalVar(toy, :Y)

df = DataFrame(
    X1 = randsample(ω -> X[1](ω), n),
    X2 = randsample(ω -> X[2](ω), n),
    X3 = randsample(ω -> X[3](ω), n),
    X4 = randsample(ω -> X[4](ω), n),
    A = randsample(ω -> A(ω), n),
    Y = randsample(ω -> Y(ω), n)
)
pprint(df)
```

Wrapper for adversarial learning for counterfactual fairness -

```julia
model = AdversarialWrapper(cm = toy,
    grp = :A,
    latent = [:U₁, :U₂, :U₃, :U₄, :U₅],
    observed = [:X1, :X2, :X3, :X4],
    predictor = Chain(Dense(4, 3), Dense(3, 2), Dense(2, 1)),
    adversary = Chain(Dense(5, 3), Dense(3, 2, relu)),
    loss = Flux.Losses.logitbinarycrossentropy,
    iters = 5)
model |> pprint
```

Now, `model` fits into the same framework as other wrappers in MLJ and can be used like any other model: we can train it with `fit!` and obtain predictions with `predict`.
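A sketch of that workflow (assuming `df` from above; the exact feature/target signature expected by `AdversarialWrapper` is an assumption here, not spelled out in this post):

```julia
using MLJ

# Bind the wrapped model to features and target, then train
mach = machine(model, df[:, [:X1, :X2, :X3, :X4, :A]], df.Y)
fit!(mach)

# Counterfactually fair predictions for the given rows
ŷ = predict(mach, df[:, [:X1, :X2, :X3, :X4, :A]])
```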

### Future Work

- Correct the computation of path-specific interventions, which is currently broken in the package.
- Benchmark counterfactual explanations.
- Add recourse methods.