如何使用函數式 API 重寫這個順序 API tensorflow 模型?

用戶17020095

我正在嘗試使用功能 API 重寫我的工作順序模型。這是我的順序模型:

num_classes = 3
# Define a simple sequential model
def create_model():
  model = Sequential([
    layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
    layers.Conv2D(16, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(num_classes)
  ])
  
  model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
  
  return model
  
# Create a basic model instance
model = create_model()

# Display the model's architecture
model.summary()

時序模型的模型總結:

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
rescaling_2 (Rescaling)      (None, 180, 180, 3)       0
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 180, 180, 16)      448
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 90, 90, 16)        0
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 90, 90, 32)        4640
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 45, 45, 32)        0
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 45, 45, 64)        18496
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 22, 22, 64)        0
_________________________________________________________________
flatten_1 (Flatten)          (None, 30976)             0
_________________________________________________________________
dense_2 (Dense)              (None, 128)               3965056
_________________________________________________________________
dense_3 (Dense)              (None, 3)                 387
=================================================================
Total params: 3,989,027
Trainable params: 3,989,027
Non-trainable params: 0
_________________________________________________________________

這是我將其重寫為功能模型的嘗試。

num_classes = 3
input_shape=(img_height, img_width, 3)
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.models import Model

def create_model():
  model_input = Input(shape=input_shape) 
  # how to include preprocessing layer that I have in my sequential model: layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
  x = Conv2D(16, 3, activation='relu',padding='same')(model_input) 
  x = MaxPooling2D()(x) 
  x = Conv2D(32, 3, activation='relu',padding='same')(model_input) 
  x = MaxPooling2D()(x) 
  x = Conv2D(64, 3, activation='relu',padding='same')(model_input) 
  x = MaxPooling2D()(x) 
  x = Flatten()(x)
  outputs = Dense(num_classes, activation='relu')(x)

  model = Model(model_input, x,)

  model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

  return model

# Create a basic model instance
model = create_model()
 
# Display the model's architecture
model.summary()

keras.utils.plot_model(model, "model_with_shape_info.png", show_shapes=True)

功能模型模型匯總:

_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_10 (InputLayer)       [(None, 180, 180, 3)]     0         
                                                                 
 conv2d_23 (Conv2D)          (None, 180, 180, 64)      1792      
                                                                 
 max_pooling2d_23 (MaxPoolin  (None, 90, 90, 64)       0         
 g2D)                                                            
                                                                 
 flatten_4 (Flatten)         (None, 518400)            0         
                                                                 
=================================================================
Total params: 1,792
Trainable params: 1,792
Non-trainable params: 0
_________________________________________________________________

我試圖在我的順序模型中逐行確定要包含在我的功能模型中的層。你能幫我理解如何正確地將我的順序模型重寫為功能模型嗎?感謝您的幫助。

編輯:嘗試編譯和訓練功能模型。

model2.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

epochs=10
history = model2.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)
單獨在一起

您似乎缺少一些圖層並且沒有正確連接它們。試試這個,兩個模型應該有相同的層數和訓練參數:

import tensorflow as tf

model1 = tf.keras.Sequential([
  tf.keras.layers.Rescaling(1./255, input_shape=(180, 180, 3)),
  tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(3)
])

model_input = tf.keras.layers.Input(shape=(180, 180, 3)) 
x = tf.keras.layers.Rescaling(1./255)(model_input) 
x = tf.keras.layers.Conv2D(16, 3, activation='relu',padding='same')(x)
x = tf.keras.layers.MaxPooling2D()(x) 
x = tf.keras.layers.Conv2D(32, 3, activation='relu',padding='same')(x) 
x = tf.keras.layers.MaxPooling2D()(x) 
x = tf.keras.layers.Conv2D(64, 3, activation='relu',padding='same')(x) 
x = tf.keras.layers.MaxPooling2D()(x) 
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(128, activation='relu')(x)
outputs = tf.keras.layers.Dense(3)(x)

model2 = tf.keras.Model(model_input, outputs)

print(model1.summary())

print(model2.summary())

이 기사는 인터넷에서 수집됩니다. 재 인쇄 할 때 출처를 알려주십시오.

침해가 발생한 경우 연락 주시기 바랍니다[email protected] 삭제

에서 수정
0

몇 마디 만하겠습니다

0리뷰
로그인참여 후 검토

관련 기사

如果我使用依賴注入,如何在 Web API 控制器構造函數中傳遞多個接口參數?

如何使用內置於 API 函數 random() 中的 Urbandictionary API

Как сгенерировать уникальные имена для сохраненной модели каждой эпохи с помощью API Keras от TensorFlow

使用各種損失函數評估預訓練的 Tensorflow keras 模型

將 Java 函數式 API 與 Spring Cloud Data Flow 和 Polled Consumers 結合使用

ASP.NET Core Web API - 如何將屬性組合作為原始數據存儲到另一個模型的單列中

為什麼來自 Tensorflow 2 對象檢測 API 的微調模型的 mAP 低?

如何監控 Lambda 運行時,包括每個函數的 API 網關開銷?

Tensorflow 2.0 및 Java API

Tensorflow Estimator API : 요약

NodeJS REST API模型和服务结构

如何在Rest API的Request Body模型類中使用Inteface作為字段類型

在類函數中獲取 API

Firebase Rest Api 使用接口回調從我的模型類讀取數據返回 null

如何在 ReactJS(函數組件)中將數據從 API 渲染到表格

如何使用跨函數重寫相同的代碼

如何為 Tensorflow LSTM 類編寫自定義調用函數?

如何使用 do 和 "by" 重寫這個已棄用的表達式,以及 "groupby" (Julia)

如何使用數組 Json 解碼 API 數據?

Tensorflow Detection API의 SSD 앵커

Tensorflow 데이터 세트 API

TensorFlow Object Detection API 확장

Tensorflow 객체 감지 API

TensorFlow Dataset API 파싱 오류

Tensorflow Serving REST API Throwing 오류

TensorFlow Object Detection API 오류

Tensorflow Java API set Placeholder for categorical columns

TensorFlow C API 로깅 설정

如何重用keras函數模型的層

TOP 리스트

  1. 1

    Ionic 2 로더가 적시에 표시되지 않음

  2. 2

    JSoup javax.net.ssl.SSLHandshakeException : <url>과 일치하는 주체 대체 DNS 이름이 없습니다.

  3. 3

    std :: regex의 일관성없는 동작

  4. 4

    Xcode10 유효성 검사 : 이미지에 투명성이 없지만 여전히 수락되지 않습니까?

  5. 5

    java.lang.UnsatisfiedLinkError : 지정된 모듈을 찾을 수 없습니다

  6. 6

    rclone으로 원격 디렉토리의 모든 파일을 삭제하는 방법은 무엇입니까?

  7. 7

    상황에 맞는 메뉴 색상

  8. 8

    SMTPException : 전송 연결에서 데이터를 읽을 수 없음 : net_io_connectionclosed

  9. 9

    정점 셰이더에서 카메라에서 개체까지의 XY 거리

  10. 10

    Windows cmd를 통해 Anaconda 환경에서 Python 스크립트 실행

  11. 11

    다음 컨트롤이 추가되었지만 사용할 수 없습니다.

  12. 12

    C #에서 'System.DBNull'형식의 개체를 'System.String'형식으로 캐스팅 할 수 없습니다.

  13. 13

    JNDI를 사용하여 Spring Boot에서 다중 데이터 소스 구성

  14. 14

    Cassandra에서 버전이 지정된 계층의 효율적인 모델링

  15. 15

    복사 / 붙여 넣기 비활성화

  16. 16

    Android Kotlin은 다른 활동에서 함수를 호출합니다.

  17. 17

    Google Play Console에서 '예기치 않은 오류가 발생했습니다. 나중에 다시 시도해주세요. (7100000)'오류를 수정하는 방법은 무엇입니까?

  18. 18

    SQL Server-현명한 데이터 문제 받기

  19. 19

    Seaborn에서 축 제목 숨기기

  20. 20

    ArrayBufferLike의 typescript 정의의 깊은 의미

  21. 21

    Kubernetes Horizontal Pod Autoscaler (HPA) 테스트

뜨겁다태그

보관