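"""Gradio demo: extrapolating partially observed learning curves with LC-PFN.

The user edits a 50-step learning curve in a table, picks a cutoff and a
confidence level, and the LC-PFN model extrapolates the remaining steps
with an uncertainty band.
"""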
import gradio as gr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import lcpfn
import torch

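# Instantiate the LC-PFN model once at startup; it is reused for every request.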
model = lcpfn.LCPFN()

def line_plot_fn(data, cutoff, ci_form):
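    """Extrapolate the curve past `cutoff` and return a matplotlib figure."""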
    cutoff = int(cutoff)
    ci = int(ci_form)

    # Empty cells are only allowed as a contiguous block at the end of the table.
    empty_values = list(data[data.y == ""].index)

    if len(empty_values) > 0:
        is_tail = empty_values[-1] == 49
        is_contiguous = all(b - a == 1 for a, b in zip(empty_values, empty_values[1:]))
        if not (is_tail and is_contiguous):
            raise gr.Error("Please enter a valid learning curve.")
        data = data[data.y != ""].copy()

    if len(data) < cutoff:
        raise gr.Error(f"Cutoff ({cutoff}) cannot be greater than the number of data points ({len(data)}).")

    try:
        data["y"] = data["y"].astype(float)
    except (TypeError, ValueError):
        raise gr.Error("Please enter a valid learning curve.")

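    # The table always covers 50 steps; only the first `cutoff` points are treated as observed.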
    x = torch.arange(1, 51).unsqueeze(1)
    y = torch.from_numpy(data.y.values).float().unsqueeze(1)

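    # Map the chosen confidence level to lower/median/upper quantiles,
    # e.g. a 90% interval becomes qs = [0.05, 0.5, 0.95].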
    rest_prob = (1 - (ci / 100)) / 2
    predictions = model.predict_quantiles(x_train=x[:cutoff], y_train=y[:cutoff], x_test=x[(cutoff-1):], qs=[rest_prob, 0.5, 1-rest_prob])
    
    fig, ax = plt.subplots()

    # Plot the observed part of the curve (it may be shorter than 50 steps).
    ax.plot(x[:len(data)], data.y, "black", label="target")

    # plot extrapolation
    ax.plot(x[(cutoff-1):], predictions[:, 1], "blue", label="Extrapolation by PFN")
    ax.fill_between(
        x[(cutoff-1):].flatten(), predictions[:, 0], predictions[:, 2],
        color="blue", alpha=0.2, label=f"CI of {ci}%",
    )

    # plot cutoff
    ax.vlines(cutoff, 0, 1, linewidth=0.5, color="k", label="cutoff", linestyles="dashed")
    ax.set_ylim(0, 1)
    ax.set_xlim(0, 50)
    ax.legend(loc="lower right")
    ax.set_xlabel("t")
    ax.set_ylabel("y")

    return fig

# Build example inputs by sampling learning curves from the LC-PFN prior.
examples = []
for _ in range(10):
    prior = lcpfn.sample_from_prior(np.random)
    curve, curve_alt = prior()
    # The prior returns two variants of each curve; use either one at random for variety.
    if np.random.rand() < 0.5:
        curve = curve_alt
    df = pd.DataFrame.from_records(curve[:50][..., np.newaxis], columns=["y"])
    df["t"] = list(range(1, 51))
    examples.append([df[["t", "y"]], 10])


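# UI: editable learning-curve table, cutoff and confidence-interval controls, and the output plot.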
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            dataform = gr.Dataframe(
                value=examples[0][0],
                headers=["t", "y"],
                datatype=["number", "number"],
                row_count=(50, "fixed"),
                col_count=(2, "fixed"),
                type="pandas",
            )
            with gr.Row():
                cutoffform = gr.Number(label="cutoff", value=10)
                ci_form = gr.Dropdown(label="Confidence Interval", choices=[
                    ("90%", 90),
                    ("95%", 95),
                    ("99%", 99)
                ], value=90)
            btn = gr.Button("Run")
        outputform = gr.Plot()
    btn.click(fn=line_plot_fn, inputs=[dataform, cutoffform, ci_form], outputs=outputform)
    gr.Examples(examples, inputs=[dataform, cutoffform], label="Examples of synthetic learning curves")


if __name__ == "__main__":
    demo.launch()