model.fit ()을 통해 LSTM 모델의 셀 상태를 추출하는 방법은 무엇입니까?

감각

내 LSTM 모델은 다음과 같으며 state_c를 얻고 싶습니다.

def _get_model(input_shape, latent_dim, num_classes):

  inputs = Input(shape=input_shape)
  lstm_lyr,state_h,state_c = LSTM(latent_dim,dropout=0.1,return_state = True)(inputs)
  fc_lyr = Dense(num_classes)(lstm_lyr)
  soft_lyr = Activation('relu')(fc_lyr)
  model = Model(inputs, [soft_lyr,state_c])
  model.compile(optimizer='adam', loss='mse', metrics=['accuracy'])
return model
model =_get_model((n_steps_in, n_features),latent_dim ,n_steps_out)
history = model.fit(X_train,Y_train)

그러나 나는 state_c역사에서 추출 할 수 없다 . 반환하는 방법?

Akshay Sehgal

LSTM 계층이 이미 state_cwith 플래그를 반환하고 있기 때문에 "state_c를 얻는 방법"이 무슨 뜻인지 잘 모르겠습니다 return_state=True. 이 경우 다중 출력 모델을 훈련 시키려고한다고 가정합니다. 현재는 단일 출력 만 있지만 모델은 여러 출력으로 컴파일됩니다.

다음은 다중 출력 모델로 작업하는 방법입니다.

from tensorflow.keras import layers, Model, utils

def _get_model(input_shape, latent_dim, num_classes):
    inputs = layers.Input(shape=input_shape)
    lstm_lyr,state_h,state_c = layers.LSTM(latent_dim,dropout=0.1,return_state = True)(inputs)
    fc_lyr = layers.Dense(num_classes)(lstm_lyr)
    soft_lyr = layers.Activation('relu')(fc_lyr)
    model = Model(inputs, [soft_lyr,state_c])   #<------- One input, 2 outputs
    model.compile(optimizer='adam', loss='mse')
    return model


#Dummy data
X = np.random.random((100,15,5))
y1 = np.random.random((100,4))
y2 = np.random.random((100,7))

model =_get_model((15, 5), 7 , 4)
model.fit(X, [y1,y2], epochs=4) #<--------- #One input, 2 outputs
Epoch 1/4
4/4 [==============================] - 2s 6ms/step - loss: 0.6978 - activation_9_loss: 0.2388 - lstm_9_loss: 0.4591
Epoch 2/4
4/4 [==============================] - 0s 6ms/step - loss: 0.6615 - activation_9_loss: 0.2367 - lstm_9_loss: 0.4248
Epoch 3/4
4/4 [==============================] - 0s 7ms/step - loss: 0.6349 - activation_9_loss: 0.2392 - lstm_9_loss: 0.3957
Epoch 4/4
4/4 [==============================] - 0s 8ms/step - loss: 0.6053 - activation_9_loss: 0.2392 - lstm_9_loss: 0.3661

이 기사는 인터넷에서 수집됩니다. 재 인쇄 할 때 출처를 알려주십시오.

침해가 발생한 경우 연락 주시기 바랍니다[email protected] 삭제

에서 수정
0

몇 마디 만하겠습니다

0리뷰
로그인참여 후 검토

관련 기사

셀레늄을 통해 <br> 태그로 구분 된 요소에서 부분 텍스트를 추출하는 방법은 무엇입니까?

코드를 통해 JavaFX에서 GridPane의 셀을 병합하는 방법은 무엇입니까?

Laravel의 다른 모델을 통해 Eloquent 관계를 설정하는 방법은 무엇입니까?

js를 통해 테이블 셀의 내용을 곱하는 방법은 무엇입니까?

tensorflow의 RNN 모델에서 셀 상태와 숨겨진 상태를 추출하는 방법은 무엇입니까?

주의 메커니즘을 위해 LSTM의 이전 출력 및 숨겨진 상태를 사용하는 방법은 무엇입니까?

tensorflow 모델의 출력을 추출하는 방법은 무엇입니까?

LSTM에서 입력 당 1 개 이상의 출력을 생성하는 방법은 무엇입니까?

Fable을 통해 F # 모듈의 공용 함수를 Javascript에 노출하는 방법은 무엇입니까?

다른 모델을 통해 모델에서 템플릿의 가치를 얻는 방법은 무엇입니까?

흰색 픽셀을 통해 연결된 두 점 사이의 거리를 찾는 방법은 무엇입니까?

주어진 셀의 값을 추출하는 방법은 무엇입니까?

페이지의 각 요소를 클릭하기 위해 html 요소 uisng python 셀레늄을 통해 루프를 만드는 방법은 무엇입니까?

SwiftUI에서 상태 변수를 통해 모양을 애니메이션하는 방법은 무엇입니까?

교차 검증 출력에서 최상의 모델을 사용하여 keras model.predict ()를 사용하는 방법은 무엇입니까?

모델에 연관이있는 경우 Ecto를 통해 행을 복제하는 방법은 무엇입니까?

셀레늄의 모델 대화 상자에서 텍스트를 추출하는 방법은 무엇입니까?

OCR을 위해 특정 색상의 픽셀을 추출하는 방법은 무엇입니까?

X, y 및 하나의 추가 배열을 생성하는 model.fit에 사용자 지정 데이터 생성기를 tensorflow.keras 모델에 입력하는 방법은 무엇입니까?

방향 변경을 통해 필터링 된 CursorAdapter의 상태를 유지하는 방법은 무엇입니까?

Vanila Tensorflow의 LSTM 셀에서 모든 가중치를 추출하는 방법은 무엇입니까?

XPath를 통해 이웃 속성 노드의 값을 추출하는 방법은 무엇입니까?

Scrapy의 소스 코드에서 xpath를 통해 섹션을 추출하는 방법은 무엇입니까?

fit_generator로 keras 모델을 피팅하는 동안 'MemoryError'를 수정하는 방법은 무엇입니까?

VBA를 통해 씽크 셀 개체의 이름을 변경하는 방법은 무엇입니까?

기록을 통해 확인란의 상태를 유지하지 않는 방법은 무엇입니까?

qml의 상위 구성 요소를 통해 C++ 모델을 설정하는 방법은 무엇입니까?

CNN-LSTM 모델에 model.fit() 함수를 적용하는 방법은 무엇입니까?

정규식을 통해 팬더의 목록에서 요소를 추출하는 방법은 무엇입니까?

TOP 리스트

  1. 1

    Matlab의 반복 Sortino 비율

  2. 2

    ImageJ-히스토그램 빈을 변경할 때 최대, 최소 값이 변경되는 이유는 무엇입니까?

  3. 3

    Excel : 합계가 N보다 크거나 같은 상위 값 찾기

  4. 4

    C #에서 'System.DBNull'형식의 개체를 'System.String'형식으로 캐스팅 할 수 없습니다.

  5. 5

    원-사각형 충돌의 충돌 측면을 찾는 문제

  6. 6

    Oracle VirtualBox-설치를 위해 게스트를 부팅 할 때 호스트 시스템이 충돌 함

  7. 7

    어떻게 아무리 "나쁜", ANY의 SSL 인증서와 HttpClient를 사용하지합니다

  8. 8

    Ubuntu는 GUI에서 암호로 사용자를 만듭니다.

  9. 9

    잘못된 상태 예외를 발생시키는 Apache PoolingHttpClientConnectionManager

  10. 10

    Python 사전을 사용하는 동안 "ValueError : could not convert string to float :"발생

  11. 11

    openCV python을 사용하여 텍스트 문서에서 워터 마크를 제거하는 방법은 무엇입니까?

  12. 12

    Vuetify 다중 선택 구성 요소에서 클릭 한 항목의 값 가져 오기

  13. 13

    C ++ VSCode에서 같은 줄에 중괄호 서식 지정

  14. 14

    Cassandra에서 버전이 지정된 계층의 효율적인 모델링

  15. 15

    JQuery datepicker 기능이 인식되지 않거나 새 프로젝트에서 작동하지 않음

  16. 16

    cuda 11.1에서 Pytorch를 사용할 때 PyTorch가 작동하지 않음: Dataloader

  17. 17

    jfreecharts에서 x 및 y 축 선을 조정하는 방법

  18. 18

    상황에 맞는 메뉴 색상

  19. 19

    마우스 휠 JQuery 이벤트 핸들러에 대한 방향 가져 오기

  20. 20

    매개 변수에서 쿼리 객체를 선언하는 방법은 무엇입니까?

  21. 21

    Maven은 아이 프로젝트 대상 폴더를 청소하지

뜨겁다태그

보관