Diff to HTML by rtfpessoa

Files changed (3) hide show
  1. app-section-3/app/main.py 2023-04-08 18:02:25 → app-section-4/app/main.py 2023-04-18 12:47:49 +37 -1
  2. app-section-3/app/pydantic_models.py 2023-04-18 09:34:31 → app-section-4/app/pydantic_models.py 2023-04-18 09:43:32 +14 -0
  3. app-section-3/tests/test_app.py 2023-04-03 16:54:22 → app-section-4/tests/test_app.py 2023-04-18 09:35:23 +13 -0
app-section-3/app/main.py 2023-04-08 18:02:25 → app-section-4/app/main.py 2023-04-18 12:47:49 RENAMED
@@ -1,7 +1,26 @@
 
 
 
 
 
 
1
  from fastapi import FastAPI
2
  from .pydantic_models import Observation, Prediction
3
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  app = FastAPI()
6
 
7
 
@@ -14,4 +33,21 @@
14
  @app.post("/predict", status_code=201)
15
  def predict(obs: Observation) -> Prediction:
16
  """For now, just return a dummy prediction."""
17
- return Prediction(flower_type="setosa")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import importlib
3
+ from typing import List
4
+
5
+ import pandas as pd
6
+ from sklearn.linear_model import LogisticRegression
7
  from fastapi import FastAPI
8
  from .pydantic_models import Observation, Prediction
9
 
10
 
11
+ def load_model(model_name: str) -> LogisticRegression:
12
+ with importlib.resources.open_binary("app.models", model_name) as f:
13
+ model = pickle.load(f)
14
+ return model
15
+
16
+
17
+ MODEL_NAME = "iris_regression.pickle"
18
+ CLASS_FLOWER_MAPPING = {
19
+ 0: 'setosa',
20
+ 1: 'versicolor',
21
+ 2: 'virginica',
22
+ }
23
+ model = load_model(MODEL_NAME)
24
  app = FastAPI()
25
 
26
 
 
33
  @app.post("/predict", status_code=201)
34
  def predict(obs: Observation) -> Prediction:
35
  """For now, just return a dummy prediction."""
36
+ # .predict() gives us an array, but it has only one element
37
+ prediction = model.predict(obs.as_dataframe())[0]
38
+ flower_type = CLASS_FLOWER_MAPPING[prediction]
39
+ pred = Prediction(flower_type=flower_type)
40
+ return pred
41
+
42
+
43
+ @app.post("/batch_predict", status_code=201)
44
+ def batch_predict(batch: List[Observation]) -> List[Prediction]:
45
+ """Predict the flower type for a batch of observations."""
46
+ rows = [obs.as_row() for obs in batch]
47
+ df = pd.DataFrame(rows)
48
+ output_classes = model.predict(df)
49
+ preds = [
50
+ Prediction(flower_type=CLASS_FLOWER_MAPPING[output_class])
51
+ for output_class in output_classes
52
+ ]
53
+ return preds
app-section-3/app/pydantic_models.py 2023-04-18 09:34:31 → app-section-4/app/pydantic_models.py 2023-04-18 09:43:32 RENAMED
@@ -1,5 +1,6 @@
1
  from typing import Literal
2
 
 
3
  from pydantic import BaseModel
4
 
5
 
@@ -11,6 +12,19 @@
11
  sepal_width: float
12
  petal_length: float
13
  petal_width: float
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
 
16
  class Prediction(BaseModel):
 
1
  from typing import Literal
2
 
3
+ import pandas as pd
4
  from pydantic import BaseModel
5
 
6
 
 
12
  sepal_width: float
13
  petal_length: float
14
  petal_width: float
15
+
16
+ def as_dataframe(self) -> pd.DataFrame:
17
+ """Convert this record to a DataFrame with one row."""
18
+ return pd.DataFrame([self.as_row()])
19
+
20
+ def as_row(self) -> pd.Series:
21
+ row = pd.Series({
22
+ "sepal length (cm)": self.sepal_length,
23
+ "sepal width (cm)": self.sepal_width,
24
+ "petal length (cm)": self.petal_length,
25
+ "petal width (cm)": self.petal_width,
26
+ })
27
+ return row
28
 
29
 
30
  class Prediction(BaseModel):
app-section-3/tests/test_app.py 2023-04-03 16:54:22 → app-section-4/tests/test_app.py 2023-04-18 09:35:23 RENAMED
@@ -21,3 +21,16 @@
21
  assert response.status_code == 201
22
  payload = response.json()
23
  assert payload["flower_type"] == "setosa"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  assert response.status_code == 201
22
  payload = response.json()
23
  assert payload["flower_type"] == "setosa"
24
+
25
+ response = client.post(
26
+ "/predict",
27
+ json={
28
+ "sepal_length": 7.1,
29
+ "sepal_width": 3.5,
30
+ "petal_length": 3.0,
31
+ "petal_width": 0.8,
32
+ },
33
+ )
34
+ assert response.status_code == 201
35
+ payload = response.json()
36
+ assert payload["flower_type"] == "versicolor"